IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    [原]Spark MLlib之协同过滤

    liuzhoulong发表于 2017-03-23 14:16:20
    love 0

    Spark MLlib之协同过滤实例：


    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaDoubleRDD;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.DoubleFunction;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.mllib.recommendation.ALS;
    import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
    import org.apache.spark.mllib.recommendation.Rating;
    
    import scala.Tuple2;
    
    /**
     * Collaborative-filtering example using Spark MLlib's ALS (alternating least squares).
     *
     * <p>Reads {@code user,product,rating} lines, trains a matrix-factorization model, prints the
     * mean squared error of the model's predictions on the training ratings, then saves the model
     * and loads it back.
     *
     * <p>Usage: {@code SparkMLlibColbFilter [ratingsPath [modelPath]]}. Both arguments are optional;
     * the defaults are the paths this example originally hard-coded.
     */
    public class SparkMLlibColbFilter {

        /** Ratings file read when no path is given on the command line. */
        private static final String DEFAULT_DATA_PATH =
                "file:///data/hadoop/spark-2.0.0-bin-hadoop2.7/data/mllib/als/test.data";

        /** Model directory written/read when no path is given on the command line. */
        private static final String DEFAULT_MODEL_PATH = "target/tmp/myCollaborativeFilter";

        /**
         * Keys a rating by its (user, product) pair so actual and predicted ratings can be joined.
         * Declared once in a static context so it captures no enclosing instance when serialized.
         */
        private static final Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>> KEY_BY_USER_PRODUCT =
                new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
                    @Override
                    public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
                        return new Tuple2<>(new Tuple2<>(r.user(), r.product()), r.rating());
                    }
                };

        /**
         * Runs the example end to end.
         *
         * @param args optional {@code [ratingsPath [modelPath]]}; missing entries fall back to
         *     {@link #DEFAULT_DATA_PATH} and {@link #DEFAULT_MODEL_PATH}
         */
        public static void main(String[] args) {
            String path = args.length > 0 ? args[0] : DEFAULT_DATA_PATH;
            String modelPath = args.length > 1 ? args[1] : DEFAULT_MODEL_PATH;

            SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example");
            // JavaSparkContext is Closeable; close() stops the context even if a step below fails.
            try (JavaSparkContext sc = new JavaSparkContext(conf)) {

                // Load and parse the data. Blank lines (e.g. a trailing newline) are skipped and
                // fields are trimmed so stray whitespace does not break number parsing.
                JavaRDD<Rating> ratings = sc.textFile(path)
                        .filter(new Function<String, Boolean>() {
                            @Override
                            public Boolean call(String line) {
                                return !line.trim().isEmpty();
                            }
                        })
                        .map(new Function<String, Rating>() {
                            @Override
                            public Rating call(String line) {
                                String[] fields = line.split(",");
                                return new Rating(
                                        Integer.parseInt(fields[0].trim()),
                                        Integer.parseInt(fields[1].trim()),
                                        Double.parseDouble(fields[2].trim()));
                            }
                        })
                        // Reused for training, prediction input and the join below; cache it so the
                        // file is not re-read and re-parsed each time.
                        .cache();

                // Build the recommendation model using ALS.
                int rank = 10;
                int numIterations = 10;
                double regularization = 0.01;
                MatrixFactorizationModel model =
                        ALS.train(ratings.rdd(), rank, numIterations, regularization);

                // predict() takes an RDD of (user, product) pairs typed as Object because the Scala
                // signature uses primitive Int.
                JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
                        new Function<Rating, Tuple2<Object, Object>>() {
                            @Override
                            public Tuple2<Object, Object> call(Rating r) {
                                return new Tuple2<Object, Object>(r.user(), r.product());
                            }
                        });

                JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD.fromJavaRDD(
                        model.predict(userProducts.rdd()).toJavaRDD().map(KEY_BY_USER_PRODUCT));

                // Pair every observed rating with the model's prediction: (actual, predicted).
                JavaRDD<Tuple2<Double, Double>> ratesAndPreds =
                        JavaPairRDD.fromJavaRDD(ratings.map(KEY_BY_USER_PRODUCT))
                                .join(predictions)
                                .values();

                double mse = ratesAndPreds.mapToDouble(new DoubleFunction<Tuple2<Double, Double>>() {
                    @Override
                    public double call(Tuple2<Double, Double> pair) {
                        double err = pair._1() - pair._2();
                        return err * err;
                    }
                }).mean();

                System.out.println("Mean Squared Error = " + mse);

                // Save and load the model. NOTE(review): Spark's save is expected to refuse to
                // overwrite an existing directory, so pass a fresh modelPath or remove the old one
                // between runs — confirm for the Spark version in use.
                model.save(sc.sc(), modelPath);
                MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), modelPath);
                System.out.println("Reloaded model rank = " + sameModel.rank());
            }
        }
    }




沪ICP备19023445号-2号
友情链接