当前位置 : 主页 > 编程语言 > java >

MLlib Java API：朴素贝叶斯分类 demo

来源:互联网 收集:自由互联 发布时间:2021-06-28
gistfile1.txt package @@@@@@@@@@@@@@import java.util.LinkedList;import java.util.List;import org.apache.spark.SparkConf;import org.apache.spark.api.java.JavaRDD;import org.apache.spark.api.java.JavaSparkContext;import org.apache.spark.mllib
gistfile1.txt
package @@@@@@@@@@@@@@

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

/**
 * Minimal Spark MLlib (RDD-based API) demo: trains a {@link NaiveBayes} classifier on a
 * handful of hard-coded labeled points, then classifies a single test vector and prints
 * the predicted label and the per-class probabilities.
 */
public class NaiveBayesModel_test {

	/**
	 * Runs the demo against a local Spark master using all available cores.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		// Connect and initialize.
		SparkConf conf = new SparkConf().setAppName("myApp").setMaster("local[*]");

		// try-with-resources stops the context even if training or prediction throws;
		// previously the context was never closed and the warning was suppressed.
		try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
			// Labeled training set: the first argument is the class label, the dense vector
			// holds the feature values. NaiveBayes.train(RDD) uses the default multinomial
			// model, which expects non-negative feature values.
			// When features are sparse, build the vector with strictly increasing indices, e.g.
			// Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 1.0})
			List<LabeledPoint> trainingPoints = Arrays.asList(
					new LabeledPoint(1.0, Vectors.dense(2.0, 3.0, 3.0)),
					new LabeledPoint(0.0, Vectors.dense(1.0, 6.0, 7.0)),
					new LabeledPoint(1.0, Vectors.dense(2.0, 2.0, 3.0)),
					new LabeledPoint(1.0, Vectors.dense(2.0, 1.0, 3.0)));

			// Typed RDD of LabeledPoint (not a raw JavaRDD), so no unchecked conversion is needed.
			JavaRDD<LabeledPoint> training = jsc.parallelize(trainingPoints);
			NaiveBayesModel nbModel = NaiveBayes.train(training.rdd());

			// Test input: a single vector (predict also accepts an RDD of vectors).
			Vector testVector = Vectors.dense(1.0, 1.0, 2.0);

			System.out.println(nbModel.predict(testVector)); // predicted class label
			System.out.println(nbModel.predictProbabilities(testVector)); // per-class probabilities
		}
	}
}
 
网友评论