IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    [原]Spark加载PMML进行预测

    fansy1990发表于 2016-11-22 23:17:24
    love 0

    软件版本:

    CDH:5.8.0 , CDH-hadoop :2.6.0 ; CDH-spark :1.6.0 

    目标:

    使用Spark 加载PMML文件到模型,并使用Spark平台进行预测(这里测试使用的是Spark on YARN的方式)。

    具体小目标:

    1. 参考https://github.com/jpmml/jpmml-spark 实现,能运行简单例子;

    2. 直接读取HDFS上面的输入数据文件,使用PMML生成的模型进行预测;

    (第1点和第2点的不一样的地方体现在输入数据的构造上,可以参看下面的代码)

    具体步骤:

    1. 准备原始数据,原始数据包括PMML文件,以及测试数据;分别如下:

    <?xml version="1.0" encoding="UTF-8" standalone="yes"?>
    <PMML version="4.2" xmlns="http://www.dmg.org/PMML-4_2">
        <Header description="linear SVM">
            <Application name="Apache Spark MLlib"/>
            <Timestamp>2016-11-16T22:17:47</Timestamp>
        </Header>
        <DataDictionary numberOfFields="4">
            <DataField name="field_0" optype="continuous" dataType="double"/>
            <DataField name="field_1" optype="continuous" dataType="double"/>
            <DataField name="field_2" optype="continuous" dataType="double"/>
            <DataField name="target" optype="categorical" dataType="string"/>
        </DataDictionary>
        <RegressionModel modelName="linear SVM" functionName="classification" normalizationMethod="none">
            <MiningSchema>
                <MiningField name="field_0" usageType="active"/>
                <MiningField name="field_1" usageType="active"/>
                <MiningField name="field_2" usageType="active"/>
                <MiningField name="target" usageType="target"/>
            </MiningSchema>
            <RegressionTable intercept="0.0" targetCategory="1">
                <NumericPredictor name="field_0" coefficient="-0.36682158807862086"/>
                <NumericPredictor name="field_1" coefficient="3.8787681305811765"/>
                <NumericPredictor name="field_2" coefficient="-1.6134308474471166"/>
            </RegressionTable>
            <RegressionTable intercept="0.0" targetCategory="0"/>
        </RegressionModel>
    </PMML>
    
    以上pmml文件是由一个svm模型构建的,其输入有三个字段,有一个目标输出,代表类别;

    输入测试数据,如下:

    field_0,field_1,field_2
    98,97,96
    1,2,7
    这个数据由列名和数据组成,这里需要注意,列名需要和pmml里面的列名对应;

    2. 把https://github.com/jpmml/jpmml-spark工程下载到本地,并添加如下代码:

    package org.jpmml.spark;
    
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.ml.Transformer;
    import org.apache.spark.sql.*;
    import org.jpmml.evaluator.Evaluator;
    
    public class SVMEvaluationSparkExample {
    
    	static
    	public void main(String... args) throws Exception {
    
    		if(args.length != 3){
    			System.err.println("Usage: java " + SVMEvaluationSparkExample.class.getName() + " <PMML file> <Input file> <Output directory>");
    
    			System.exit(-1);
    		}
            /**
             * 根据pmml文件,构建模型
             */
            FileSystem fs = FileSystem.get(new Configuration());
            Evaluator evaluator = EvaluatorUtil.createEvaluator(fs.open(new Path(args[0])));
    
            TransformerBuilder modelBuilder = new TransformerBuilder(evaluator)
                    .withTargetCols()
                    .withOutputCols()
                    .exploded(true);
    
            Transformer transformer = modelBuilder.build();
    
            /**
             * 利用DataFrameReader从原始数据中构造 DataFrame对象
             * 需要原始数据包含列名
             */
            SparkConf conf = new SparkConf();
            try(JavaSparkContext sparkContext = new JavaSparkContext(conf)){
    
                SQLContext sqlContext = new SQLContext(sparkContext);
    
                DataFrameReader reader = sqlContext.read()
                        .format("com.databricks.spark.csv")
                        .option("header", "true")
                        .option("inferSchema", "true");
                DataFrame dataFrame = reader.load(args[1]);// 输入数据需要包含列名
    
                /**
                 * 使用模型进行预测
                 */
                dataFrame = transformer.transform(dataFrame);
    
                /**
                 * 写入数据
                 */
                DataFrameWriter writer = dataFrame.write()
                        .format("com.databricks.spark.csv")
                        .option("header", "true");
    
                writer.save(args[2]);
            }
    	}
    }
    这个代码主要实现的是小目标1,即参考jpmml-spark工程给的示例,编写代码;代码有四个部分,第一部分读取HDFS上面的PMML文件,然后构建模型;第二部分使用DataFrameReader根据输入数据构建DataFrame数据结构;第三部分,使用模型对构造的DataFrame数据进行预测;第四部分,把预测的结果写入HDFS。

    注意里面在构造数据的时候.option("header","true")是一定要加的,原因如下:1)原始数据中确实有列名;2)如果这里不加,那么将读取不到列名的相关信息,将不能和模型中的列名对应;(当然,下面有其他方法处理这种情况)。

    3. 上传测试数据以及pmml文件到HDFS,进行测试,代码如下:

    spark-submit --master yarn --class org.jpmml.spark.SVMEvaluationSparkExample /opt/tmp/example-1.0-SNAPSHOT.jar hdfs://quickstart.cloudera:8020/tmp/svm/part-00000 sample_test_data.txt sample_out00
    其中,example-1.0-SNAPSHOT.jar 是编译后的jar包;/tmp/svm/part-00000时svm模型的pmml文件;sample_test_data.txt 是测试数据;sample_out00是输出目录;

    查看结果:

    根据输出的结果,也可以看出预测结果是对的。

    4. 如何实现小目标2呢?

    编写代码:

    /*
     * Copyright (c) 2015 Villu Ruusmann
     *
     * This file is part of JPMML-Spark
     *
     * JPMML-Spark is free software: you can redistribute it and/or modify
     * it under the terms of the GNU Affero General Public License as published by
     * the Free Software Foundation, either version 3 of the License, or
     * (at your option) any later version.
     *
     * JPMML-Spark is distributed in the hope that it will be useful,
     * but WITHOUT ANY WARRANTY; without even the implied warranty of
     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     * GNU Affero General Public License for more details.
     *
     * You should have received a copy of the GNU Affero General Public License
     * along with JPMML-Spark.  If not, see <http://www.gnu.org/licenses/>.
     */
    package org.jpmml.spark;
    
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.ml.Transformer;
    import org.apache.spark.sql.*;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    import org.dmg.pmml.FieldName;
    import org.jpmml.evaluator.Evaluator;
    
    import java.util.ArrayList;
    import java.util.List;
    
    //import org.jpmml.evaluator.FieldValue;
    
    public class EvaluationSparkExample {
    
    	static
    	public void main(String... args) throws Exception {
    
    		if(args.length != 3){
    			System.err.println("Usage: java " + EvaluationSparkExample.class.getName() + " <PMML file> <Input file> <Output directory>");
    
    			System.exit(-1);
    		}
    
            /**
             * 构造模型
            */
            FileSystem fs = FileSystem.get(new Configuration());
            Evaluator evaluator = EvaluatorUtil.createEvaluator(fs.open(new Path(args[0])));
    
            TransformerBuilder modelBuilder = new TransformerBuilder(evaluator)
                    .withTargetCols()
                    .withOutputCols()
                    .exploded(true);
            Transformer transformer = modelBuilder.build();
    
            /**
             * 构造列名,schema
             */
            List<StructField> fields = new ArrayList<>();
            for (FieldName fieldName: evaluator.getActiveFields()) {
                fields.add(DataTypes.createStructField(fieldName.getValue(), DataTypes.StringType, true));
            }
            StructType schema = DataTypes.createStructType(fields);
    
            /**
             * 原始数据构造成DataFrame
             */
            SparkConf conf = new SparkConf();
            final String splitter = ",";
            try(JavaSparkContext sparkContext = new JavaSparkContext(conf)){
                JavaRDD<Row> data = sparkContext.textFile(args[1]).map(new Function<String, Row>() {
                    @Override
                    public Row call(String line) throws Exception {
                        String[] lineArr = line.split(splitter,-1);
                        return  RowFactory.create(lineArr);
                    }
                });
    
                SQLContext sqlContext = new SQLContext(sparkContext);
                DataFrame dataFrame = sqlContext.createDataFrame(data, schema);
    
                /**
                 * 预测,并生成新的DataFrame
                 */
                dataFrame = transformer.transform(dataFrame);
    
                /**
                 * 把评估后的数据写入HDFS,不要写入列名
                 */
                DataFrameWriter writer = dataFrame.write()
                        .format("com.databricks.spark.csv");
                writer.save(args[2]);
    
            }
    	}
    }
    这个代码和上一个代码的不同之处只是从原始测试数据中构造DataFrame不同,这里使用的PMML模型中的列名信息,代码参考:http://spark.apache.org/docs/1.6.0/sql-programming-guide.html#interoperating-with-rdds;同时,这时,原始测试数据就不需要再添加列名信息了。由于在代码中,在输出的时候也把列名信息给去掉了,所以只输出数据。运行后,其结果如下所示:


    其调用代码如下所示:

    spark-submit --master yarn --class org.jpmml.spark.EvaluationSparkExample /opt/tmp/example-1.0-SNAPSHOT.jar hdfs://quickstart.cloudera:8020/tmp/svm/part-00000 sample_test_data1.txt sample_out02
    其中,sample_test_data1.txt是没有列名的数据。


    分享,成长,快乐

    转载请注明blog地址:http://blog.csdn.net/fansy1990




沪ICP备19023445号-2号
友情链接