在mahout的官网上面,有讲诉如何在命令行之中使用Logistic Regression对自带的donut.csv进行训练的例子。
现在我们要做的,是自己在java代码之中对iris的数据使用LR进行分析。
首先,我们要熟悉一下,使用LR需要哪些参数以及他们的作用。我们从《mahout实战》上面给出的命令行例子来了解一下:
$ bin/mahout trainlogistic --input donut.csv \ --output ./model \ --target color --categories 2 \ --predictors x y --types numeric \ --features 20 --passes 100 --rate 50
简单说明一下:
--input: 输入的文件
--output: 输出的模型存放的文件
--target: 目标变量名
--categories: 有几个分类
--predictors: 使用哪些属性进行预测。在上面的命令行之中只使用了x跟y两个属性
--type: 预测变量的类型,除了numeric, 还有word,text.
--passes: 对于小样本数据,可以多循环几次,对于大型数据样本,1次即可
--rate: 学习率
--features: 不知道中文如何描述,我对LR的理解还不够深入。英文描述:Sets the size of the internal feature vector to use in building the model. A larger value here can be helpful, especially with text-like input data
命令trainlogistic 对应着org.apache.mahout.classifier.sgd.TrainLogistic.java. 这是训练模型的代码。相应的,还有运行模型的代码:org.apache.mahout.classifier.sgd.RunLogistic.java
在大概了解之后,我们开始针对iris的数据进行实际操作一把:
package org.apache.mahout.classifier.sgd; import java.io.File; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.util.List; import java.util.Locale; import org.apache.commons.io.FileUtils; import org.apache.mahout.classifier.evaluation.Auc; import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.SequentialAccessSparseVector; import org.apache.mahout.math.Vector; import com.google.common.base.Charsets; import com.google.common.collect.Lists; public class IrisLRTest { private static LogisticModelParameters lmp; private static PrintWriter output; public static void main(String[] args) throws IOException { // 1: new lmp = new LogisticModelParameters(); output = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true); // 2: init params lmp.setLambda(0.001); lmp.setLearningRate(50); lmp.setMaxTargetCategories(3); //总共有3种iris lmp.setNumFeatures(4); //看起来除了class只有4种属性,先设定为4 List<String> targetCategories = Lists.newArrayList("Iris-setosa", "Iris-versicolor", "Iris-versicolor"); //这里使用的是guava里面的api lmp.setTargetCategories(targetCategories); lmp.setTargetVariable("class"); // 需要进行预测的是class属性 List<String> typeList = Lists.newArrayList("numeric", "numeric", "numeric", "numeric"); List<String> predictorList = Lists.newArrayList("sepallength", "sepalwidth", "petallength", "petalwidth"); lmp.setTypeMap(predictorList, typeList); // 3. load data List<String> raw = FileUtils.readLines(new File( "E:\\DataSet\\R\\iris.csv")); //使用common-io进行文件读取 String header = raw.get(0); List<String> content = raw.subList(1, raw.size()); // parse data CsvRecordFactory csv = lmp.getCsvRecordFactory(); csv.firstLine(header); // !!!Note: this is a initialize step, do not // skip this step // 4. begin to train OnlineLogisticRegression lr = lmp.createRegression(); for(int i = 0; i < 100; i++) { //对于小数据集我们多运行几次 for (String line : content) { Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); int targetValue = csv.processLine(line, input); lr.train(targetValue, input); // 核心的一句!!! } } // 5. show model performance: show classify score double correctRate = 0; double sampleCount = content.size(); for (String line : content) { Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); int target = csv.processLine(line, v); int score = lr.classifyFull(v).maxValueIndex(); // 分类核心语句!!! System.out.println("Target:" + target + "\tReal:" + score); if(score == target) { correctRate++; } } output.printf(Locale.ENGLISH, "Rate = %.2f%n", correctRate / sampleCount); } }
运行结果:Rate = 0.90
在上面的代码中,要注意的是:
1. 注意所有必需的参数一定要都设定好并设定正确
2. 在必要的参数初始化之后,才能正确的getCsvRecordFactor 跟 createRegression. 否则会遇到空指针异常
为了对模型进行调优,我们可以做如下事情:
1. 设定更大的numFeatures. 当前是4,我们设定为5、10、20 。。。
2. 设定更大的循环次数,当前是100, 我们可以设定为200、300 ==
最终,我设定的参数:
numFeature = 5
passes = 40
结果: Rate = 0.98
相关推荐
NULL 博文链接:https://rangerwolf.iteye.com/blog/2093940
mahout0.11版本,源码,可修改源码并自己编译,使用java语言编写,maven编译
mahout-examples-0.11.1 mahout-examples-0.11.1-job mahout-h2o_2.10-0.11.1 mahout-h2o_2.10-0.11.1-dependency-reduced mahout-hdfs-0.11.1 mahout-integration-0.11.1 mahout-math-0.11.1 mahout-math-0.11.1 ...
mahout-core-0.9.jar+mahout-core-0.8.jar+mahout-core-0.1.jar
mahout-integration-0.7mahout-integration-0.7mahout-integration-0.7mahout-integration-0.7
mahout是用来做大数据推荐系统和机器学习使用的框架,这个工具包官网下载非常慢,下载了一夜终于下载到了,刚好够上传的
主要是基于豆瓣电影的数据,进行分析,所以首先要爬取相关的电影数据,对应的源代码在DouBan_Spider目录下,主要是采用Python + BeautifulSoup + urllib进行数据采集 2:ETL预处理 3:数据分析 4:可视化 代码封装...
maven_mahout_template-mahout-0.8
官方下载的mahout-distribution-0.9.tar.gz 因为下载速度实在太慢,所以分享出来,方便大家下载使用。mahout-distribution-0.9.tar.gz
mahout-distribution-0.5-src.zip mahout 源码包
mahout-distribution-0.9-src.zip
重新编译mahout-examples-0.9-job.jar,增加分类指标:最小最大精度、召回率。详情见http://blog.csdn.net/u012948976/article/details/50203249
spring-mahout-demo-----一个简单的spring-mahout结合的例子,是很好的学习开发思路的例子。
mahout实战 源码 mahout实战 配套 mahout-distribution-0.5.tar.gz 版本
mahout测试数据 mahout测试数据 mahout测试数据 mahout测试数据 mahout测试数据 mahout测试数据
教你成功运行mahout的taste webapp例子,网上的很多资料说的不清楚,或者版本冲突。正确的版本是jdk1.6 maven3.0.5 mahout0.5 。 摸索良久,亲测有效!
mahout-math-0.8.jar mahout-math-0.8.jar
mahout0.8版的源代码~ 包括 core example等
mahout-distribution-0.10.0-src.tar.gz
mahout-0.9-cdh5.5.0.tar.gz