IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    [原]liblinear简单使用说明

    linger2012liu发表于 2015-09-22 19:52:52
    love 0

     liblinear简单使用说明


    liblinear适合解决大规模数据和高维稀疏特征的分类和回归问题。

     

    特征文件格式:跟libsvm的一致,每一行都是

    label index1:value1 index2:value2

    的稀疏向量的格式。

     

    离线的训练和测试阶段,为了方便,我是通过命令行来做的,不需要再写代码。

    其中liblinear封装了一个train和predict命令(java和C都有),我们只需要调整参数即可方便调用。

    模型训练好和测试符合我们要求之后,为了方便hive调用,需要java程序加载模型来做预测,这一步需要在java版的liblinear之上再做一次封装或者改动。

     

    训练阶段:

    train命令的使用

    1 最简单的使用方式,不调整任何参数,直接使用默认参数

    train train.txt model.txt

     

    2 常见参数调整

      -s 表示模型的类型,liblinear里面不止实现一种模型,里面还分为大的小的模型类别。

      值得注意的是,svm只能输出分类的label,lr可以输出分类的概率。

      目前我常使用 –s 0



     -wi 针对不同类别设置不同的惩罚因子


     通过调整此参数,可以调整预测类别的分布。

     比如男女分类中测试集的真实分布是55:45,男label为1,女label为2。

     某次模型训练后,对上面测试集的预测后的类别分布是3:7,说明模型偏向于女性,

     为了让预测分布跟真实分布一致,为了调整模型更偏向于男性,需要再次训练模型

     调大w1的值(男的label为1,所以是w1),故可以尝试 –w1 2来训练看看效果。

     如此不断尝试。

     

     

    测试阶段:

    1 默认参数使用方式

    predict test.txt model.txt predict.txt

    表示使用模型model.txt来预测测试集test.txt,结果保存在predict.txt

    2 常用参数

      -b 取值0和1。默认是0。如设置为1,表示输出预测分类的概率。

     

     

     

    java程序使用模型预测:

    为了方便hive调用,需要根据liblinear中的predict.java重写一个预测函数。

    predict.java原本是输入一个特征文件,输出一个预测文件。

    我们改写的预测函数,是输入一个字符串表示的特征向量,输出一个字符串表示的类别预测概率。

    package de.bwaldvogel.liblinear;
    
    import static de.bwaldvogel.liblinear.Linear.atof;
    import static de.bwaldvogel.liblinear.Linear.atoi;
    
    import java.io.File;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.StringTokenizer;
    import java.util.regex.Pattern;
    
    public class Predictor {
    
        private static boolean       flag_predict_probability = true;
    
        private static final Pattern COLON                    = Pattern.compile(":");
    
        /**
         * <p><b>Note: The streams are NOT closed</b></p>
         */
        static public String doPredict(Model model,String line) throws IOException {
    
    
            int nr_class = model.getNrClass();
            double[] prob_estimates = null;
            int n;
            int nr_feature = model.getNrFeature();
            if (model.bias >= 0)
                n = nr_feature + 1;
            else
                n = nr_feature;
    
            if (flag_predict_probability && !model.isProbabilityModel()) {
                throw new IllegalArgumentException("probability output is only supported for logistic regression");
            }
    
            if (flag_predict_probability) {
                prob_estimates = new double[nr_class];
            }
    
                List<Feature> x = new ArrayList<Feature>();
                StringTokenizer st = new StringTokenizer(line, " \t\n");
          
    
                while (st.hasMoreTokens()) {
                    String[] split = COLON.split(st.nextToken(), 2);
                    if (split == null || split.length < 2) {
                        throw new RuntimeException("Wrong input format at line "+line);
                    }
    
                    try {
                        int idx = atoi(split[0]);
                        double val = atof(split[1]);
    
                        // feature indices larger than those in training are not used
                        if (idx <= nr_feature) {
                            Feature node = new FeatureNode(idx, val);
                            x.add(node);
                        }
                    } catch (NumberFormatException e) {
                        throw new RuntimeException("Wrong input format at line " + line, e);
                    }
                }
    
                if (model.bias >= 0) {
                    Feature node = new FeatureNode(n, model.bias);
                    x.add(node);
                }
    
                Feature[] nodes = new Feature[x.size()];
                nodes = x.toArray(nodes);
    
                double predict_label;
                String res="";
                if (flag_predict_probability) {
                	int[] labels = model.getLabels();
                    assert prob_estimates != null;
                    predict_label = Linear.predictProbability(model, nodes, prob_estimates);
    
                   // System.out.printf("%g", predict_label);
                    for (int j = 0; j < model.nr_class; j++)
                    {
                    	res =res+ labels[j]+":"+prob_estimates[j]+";";
                    	//System.out.printf(" %g", prob_estimates[j]);
                    }
              
                    
                } else {
                    predict_label = Linear.predict(model, nodes);
                }
    
               // System.out.println(res);
                return res;
    
        }
    
    
        public static void main(String[] argv) throws IOException {
            flag_predict_probability = true;
            try {
              
               String line = "438:1.0 4659:1.0 4661:1.0 5026:1.0 5067:1.0 5914:1.0 6020:1.0 9924:1.0 13845:1.0 17295:1.0 19792:1.0 21466:1.0 22054:1.0 22095:1.0 22425:1.0 26541:1.0";
               
               
              String model_path="/model/AgePredicotr4LiblinearAdmaster.model";
    		InputStream inputStream = this.getClass().getResourceAsStream(model_path);
    		BufferedReader br=new BufferedReader(new InputStreamReader(inputStream));	
    		model = Linear.loadModel(br);	
    		br.close();
                
                String res = doPredict( model,line);
                System.out.println(res);
            }
            finally {
    
            }
        }
    }
    


    参考资料

    http://www.csie.ntu.edu.tw/~cjlin/liblinear/

    http://www.csie.ntu.edu.tw/~cjlin/papers/liblinear.pdf

    https://github.com/bwaldvogel/liblinear-java  java版linlinear


    本文作者:linger

    本文链接:http://blog.csdn.net/lingerlanlan/article/details/48659803






沪ICP备19023445号-2号
友情链接