gistfile1.txt package @@@@@@@@@@@@@@import java.util.LinkedList;import java.util.List;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.mllib
package @@@@@@@@@@@@@@
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
/**
 * Minimal local-mode demo of Spark MLlib's RDD-based {@link NaiveBayes}.
 *
 * <p>It trains a model on four hand-written labeled points, then prints the predicted label and
 * the class probabilities for one test vector.
 */
public class NaiveBayesModel_test {

    /**
     * Starts a local Spark context, trains a Naive Bayes model and prints its output for a single
     * test vector. The Spark context is always closed before returning.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Connect and initialise: run Spark in-process using all available cores ("local[*]").
        SparkConf conf = new SparkConf().setAppName("myApp").setMaster("local[*]");

        // try-with-resources stops the context even if training or prediction throws.
        // Previously the context was never closed and the leak warning was suppressed.
        try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
            // Labeled training set. In each LabeledPoint the first argument is the class label
            // (e.g. 1.0) and the second is the feature vector (e.g. Vectors.dense(2.0, 3.0, 3.0)).
            // For sparse features, build the vector with Vectors.sparse(size, indices, values).
            // Spark expects those indices to be strictly increasing.
            List<LabeledPoint> trainingData = Arrays.asList(
                    new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0)),
                    new LabeledPoint(0.0, Vectors.dense(1.0, 6.0, 7.0)),
                    new LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 3.0)),
                    new LabeledPoint(1.0, Vectors.dense(2.0, 1.0, 3.0)));

            // Distribute the samples as an RDD whose element type is LabeledPoint (not List).
            JavaRDD<LabeledPoint> training = jsc.parallelize(trainingData);
            NaiveBayesModel nbModel = NaiveBayes.train(training.rdd());

            // Test input. A single Vector is used here; predict also accepts an RDD of vectors.
            Vector testVector = Vectors.dense(1.0, 1.0, 2.0);

            System.out.println(nbModel.predict(testVector)); // predicted class label
            System.out.println(nbModel.predictProbabilities(testVector)); // per-class probabilities
        }
    }
}
