// gistfile1.txt (gist preview header; the complete source follows on the next line)
package @@@@@@@@@@@@@@

import java.util.LinkedList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

/**
 * Small demo that trains an MLlib Naive Bayes classifier on a handful of in-memory samples
 * and classifies a single test vector, using a local Spark master.
 */
public class NaiveBayesModel_test {

    /**
     * Builds four labeled training samples, trains a {@link NaiveBayesModel} on them, and prints
     * the predicted label followed by the result of {@code predictProbabilities} for one vector.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        // Connect to Spark and initialize the context (all local cores).
        SparkConf conf = new SparkConf().setAppName("myApp").setMaster("local[*]");

        // try-with-resources guarantees the context is shut down even if training fails.
        try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
            // Labeled training set. Each sample is a LabeledPoint: the first argument is the
            // class label, the second is the feature vector. When features are sparse, the
            // vector can be built with Vectors.sparse(size, indices, values) instead.
            List<LabeledPoint> samples = new LinkedList<>();
            samples.add(new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0)));
            samples.add(new LabeledPoint(0.0, Vectors.dense(1.0, 6.0, 7.0)));
            samples.add(new LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 3.0)));
            samples.add(new LabeledPoint(1.0, Vectors.dense(2.0, 1.0, 3.0)));

            // Distribute the samples; the RDD's element type is LabeledPoint, not List.
            JavaRDD<LabeledPoint> training = jsc.parallelize(samples);
            NaiveBayesModel model = NaiveBayes.train(training.rdd());

            // Test input: a single vector (an RDD of vectors could be used as well).
            double[] features = { 1, 1, 2 };
            Vector query = Vectors.dense(features);

            System.out.println(model.predict(query)); // classification result
            System.out.println(model.predictProbabilities(query)); // probability values
        }
    }
}