Analyzing Accuracy, Precision, and F1 in Spark Through the Source Code








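Because I know that males far outnumber females, I ignore the features entirely and predict that everyone is male. I predicted that all 100 people are male, and 90 of them actually are, so:

Number of correct predictions = 90
Total number of predictions = 100
Accuracy = 90 / 100 = 90%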


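When the male-to-female ratio is this skewed, predicting "male" for everyone is enough to obtain a very high Accuracy. So when positive and negative samples are severely imbalanced, Accuracy stops being a meaningful metric.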
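The effect is easy to reproduce in a few lines. The following is a minimal sketch in plain Scala (no Spark), assuming a hypothetical encoding of 1 for male and 0 for female; the object and method names are invented for illustration.

```scala
object AccuracyDemo {
  /** Fraction of predictions that match the actual labels. */
  def accuracy(predicted: Seq[Int], actual: Seq[Int]): Double = {
    require(predicted.length == actual.length, "inputs must have equal length")
    val correct = predicted.zip(actual).count { case (p, a) => p == a }
    correct.toDouble / actual.length
  }

  // Hypothetical encoding: 1 = male, 0 = female.
  // 90 males and 10 females, as in the example above.
  val actual: Seq[Int] = Seq.fill(90)(1) ++ Seq.fill(10)(0)

  // A "model" that ignores every feature and always answers "male".
  val predicted: Seq[Int] = Seq.fill(100)(1)

  def main(args: Array[String]): Unit = {
    println("Accuracy = " + accuracy(predicted, actual))
  }
}
```

The predictor never identifies a single female, yet its accuracy is 90 / 100.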

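Precision and Recall

                    Actually true    Actually false
Predicted true      TP               FP
Predicted false     FN               TN

# The leading T or F says whether the prediction was correct
# The trailing P or N says whether the prediction was true or false
TP: predicted true, and the prediction was correct
FP: predicted true, but the prediction was wrong
TN: predicted false, and the prediction was correct
FN: predicted false, but the prediction was wrong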
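To make the four cells concrete, here is a small sketch that tallies them from (predicted, actual) pairs. It is plain Scala, and the `Confusion` case class and `confusion` function are hypothetical names used only for this illustration.

```scala
object ConfusionDemo {
  /** Counts of the four outcomes of a binary classifier. */
  case class Confusion(tp: Long, fp: Long, fn: Long, tn: Long)

  /** Tallies outcomes from (predicted, actual) pairs, where true means "positive". */
  def confusion(pairs: Seq[(Boolean, Boolean)]): Confusion =
    pairs.foldLeft(Confusion(0L, 0L, 0L, 0L)) {
      case (c, (true, true))   => c.copy(tp = c.tp + 1) // predicted true, correct
      case (c, (true, false))  => c.copy(fp = c.fp + 1) // predicted true, wrong
      case (c, (false, true))  => c.copy(fn = c.fn + 1) // predicted false, wrong
      case (c, (false, false)) => c.copy(tn = c.tn + 1) // predicted false, correct
    }
}
```

For example, `ConfusionDemo.confusion(Seq((true, true), (false, true), (false, false)))` yields one TP, one FN, and one TN.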





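If nothing is predicted as true, the denominator of Precision is 0, so the example is adjusted slightly. The adjusted data also makes the difference between Accuracy and Precision/Recall easier to see.

                    Actually true    Actually false
Predicted true      1                0
Predicted false     10               89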

Accuracy = (1 + 89) / (1 + 0 + 10 + 89) = 90 / 100 = 0.9
Precision = 1 / (1 + 0) = 1
Recall = 1 / (1 + 10) = 0.09090909
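The same arithmetic can be checked with a short sketch in plain Scala. The function names are made up for this example, and the handling of a zero denominator (1.0 for precision, 0.0 for recall) is an assumption chosen for the sketch.

```scala
object MetricsDemo {
  /** (TP + TN) / all samples. */
  def accuracy(tp: Long, fp: Long, fn: Long, tn: Long): Double =
    (tp + tn).toDouble / (tp + fp + fn + tn)

  /** TP / (TP + FP); assumed to be 1.0 when nothing is predicted true. */
  def precision(tp: Long, fp: Long): Double =
    if (tp + fp == 0) 1.0 else tp.toDouble / (tp + fp)

  /** TP / (TP + FN); assumed to be 0.0 when there are no actual positives. */
  def recall(tp: Long, fn: Long): Double =
    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)

  def main(args: Array[String]): Unit = {
    // Values from the table above: TP = 1, FP = 0, FN = 10, TN = 89.
    println("Accuracy  = " + accuracy(1, 0, 10, 89))
    println("Precision = " + precision(1, 0))
    println("Recall    = " + recall(1, 10))
  }
}
```

Accuracy stays at 0.9 while Recall drops to 1/11, which is exactly the gap that Accuracy alone hides.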














F1 = (2 * 1 * 0.09090909) / (1 + 0.09090909) = 0.1666666
F1 = (2 * 0.9 * 0.19) / (0.9 + 0.19) = 0.3137
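Written out with P for Precision and R for Recall (symbols introduced here for brevity), the first calculation is the harmonic mean of the two values and simplifies to an exact fraction:

```latex
F_1 = \frac{2PR}{P + R}
    = \frac{2 \cdot 1 \cdot \frac{1}{11}}{1 + \frac{1}{11}}
    = \frac{2/11}{12/11}
    = \frac{1}{6} \approx 0.1667
```

Because the harmonic mean is pulled toward the smaller of its two inputs, a perfect Precision cannot compensate for a Recall of 1/11.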


Adjusting the weights of Precision and Recall
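The body of this section is not shown here, so the following is only a sketch under the assumption that it refers to the usual F-beta generalization, F_beta = (1 + beta^2) * P * R / (beta^2 * P + R). With beta = 1 it reduces to F1, beta > 1 weights Recall more heavily, and beta < 1 weights Precision more heavily. The function name is hypothetical; in Spark, `BinaryClassificationMetrics.fMeasureByThreshold(beta)` accepts a beta parameter for the same purpose.

```scala
object FBetaDemo {
  /**
   * F-beta score. beta = 1.0 gives F1; larger beta favors recall,
   * smaller beta favors precision. Returns 0.0 when the denominator is zero
   * (a convention assumed for this sketch).
   */
  def fMeasure(precision: Double, recall: Double, beta: Double = 1.0): Double = {
    val b2 = beta * beta
    val denom = b2 * precision + recall
    if (denom == 0.0) 0.0 else (1 + b2) * precision * recall / denom
  }

  def main(args: Array[String]): Unit = {
    val (p, r) = (1.0, 1.0 / 11) // Precision and Recall from the example above
    println("F1   = " + fMeasure(p, r))
    println("F2   = " + fMeasure(p, r, beta = 2.0)) // recall-heavy: lower score
    println("F0.5 = " + fMeasure(p, r, beta = 0.5)) // precision-heavy: higher score
  }
}
```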



Spark source code analysis

Computing Precision, Recall, and F1 with the Spark API

Let's use the Spark API to compute the values we worked out by hand above.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.{SparkConf, SparkContext}

object Test {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("test").setMaster("local") // When debugging, be sure not to use local[*]
    val sc = new SparkContext(conf)
    sc.setLogLevel("ERROR")

    // First build the same data as in the table above
    /**
     *                   Actually true   Actually false
     * Predicted true    1               0
     * Predicted false   10              89
     */
    // Left: predicted probability of being true; right: actual label
    val TP = Array((1.0, 1.0)) // predicted true, actually true
    val TN = new Array[(Double, Double)](89) // predicted false, actually false
    for (i ...
    // [listing truncated in the source]

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("test").setMaster("local") // When debugging, be sure not to use local[*]
    val sc = new SparkContext(conf)
    sc.setLogLevel("ERROR")

    val TP = Array((1.0, 1.0))
    val TN = new Array[(Double, Double)](89)
    for (i ...
    // [listing truncated in the source; it resumes inside the combineByKey call that builds binnedCounts]
      ... c += label,
      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
    ).sortByKey(ascending = false)
    binnedCounts.collect().foreach(println)

    println("-- agg --")
    // agg is an array (collect returns an array)
    // The number of partitions was set to 2 earlier, so there are two entries here
    // Each entry is the sum of numPos and numNeg within one partition
    /**
     * {numPos: 6, numNeg: 0}
     * {numPos: 5, numNeg: 89}
     */
    val agg = binnedCounts.values.mapPartitions { iter =>
      val agg = new BinaryLabelCounter()
      iter.foreach(agg += _)
      Iterator(agg)
    }.collect()
    agg.foreach(println)

    // partitionwiseCumulativeCounts has one more element than there are partitions
    // Each row holds the initial numPos and numNeg of one partition; this matters later
    /**
     * {numPos: 0, numNeg: 0}    // initial value of the first partition: all zeros
     * {numPos: 6, numNeg: 0}    // after accumulating the first partition; this is the initial value of the
     *                           // second partition, and it also shows that the first partition holds 6 positives
     * {numPos: 11, numNeg: 89}  // the last entry is the total number of positive and negative samples;
     *                           // with only two partitions, accumulating both gives the total
     */
    println("-- partitionwiseCumulativeCounts --")
    val partitionwiseCumulativeCounts =
      // Start from a new BinaryLabelCounter and add the values in agg from left to right
      agg.scanLeft(new BinaryLabelCounter())(
        (agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c)
    partitionwiseCumulativeCounts.foreach(println)

    // Print the total number of positive and negative samples
    val totalCount = partitionwiseCumulativeCounts.last
    println(s"Total counts: $totalCount")

    // Print the number of partitions
    println("getNumPartitions = " + binnedCounts.getNumPartitions)

    // binnedCounts
    // After mapPartitionsWithIndex, binnedCounts becomes cumulativeCounts
    // First see how cumulativeCounts is derived; read this together with the cumulativeCounts data below
    /**
     * (1.0,{numPos: 1, numNeg: 0})   // the first row is unchanged
     * (0.6,{numPos: 5, numNeg: 0})   // first row plus second row gives the second row of cumulativeCounts
     * (0.4,{numPos: 5, numNeg: 0})   // the sum of the first three rows gives the third row of cumulativeCounts
     * (0.0,{numPos: 0, numNeg: 89})  // likewise, the sum of the first four rows gives the fourth row
     */

    // cumulativeCounts
    // What do the numbers in cumulativeCounts mean?
    /**
     * (1.0,{numPos: 1, numNeg: 0})    // with Threshold 1.0, one sample is predicted true
     * (0.6,{numPos: 6, numNeg: 0})    // with Threshold 0.6, six samples are predicted true
     * (0.4,{numPos: 11, numNeg: 0})   // and so on
     * (0.0,{numPos: 11, numNeg: 89})
     */
    println("-- cumulativeCounts --")
    // How is this implemented, given that the data lives in an RDD?
    // First, binnedCounts was sorted with sortByKey, so the records in each partition are ordered
    // Combined with the partition index and the partitionwiseCumulativeCounts computed earlier,
    // the cumulative values can be computed
    /**
     * partitionwiseCumulativeCounts
     * {numPos: 0, numNeg: 0}    the partition with index 0 starts with numPos and numNeg both 0
     * {numPos: 6, numNeg: 0}    after accumulating partition 0, the partition with index 1 starts with numPos = 6
     * {numPos: 11, numNeg: 89}
     */
    val cumulativeCounts = binnedCounts.mapPartitionsWithIndex(
      (index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => {
        val cumCount = partitionwiseCumulativeCounts(index)
        iter.map { case (score, c) =>
          // When index is 0, cumCount is {numPos: 0, numNeg: 0}: the first partition starts from zero
          //   first record:  (1.0,{numPos: 1, numNeg: 0}); after cumCount += c it is (1.0,{numPos: 1, numNeg: 0})
          //   second record: (0.6,{numPos: 5, numNeg: 0}); after cumCount += c it is (0.6,{numPos: 6, numNeg: 0})
          // When index is 1, cumCount is {numPos: 6, numNeg: 0}: the second partition starts with numPos = 6
          //   first record:  (0.4,{numPos: 5, numNeg: 0}); after cumCount += c it is (0.4,{numPos: 11, numNeg: 0})
          //   second record: (0.0,{numPos: 0, numNeg: 89}); after cumCount += c it is (0.0,{numPos: 11, numNeg: 89})
          cumCount += c
          (score, cumCount.clone())
        }
        // preservesPartitioning = true: the function passed to mapPartitionsWithIndex must not modify the keys
      }, preservesPartitioning = true)
    cumulativeCounts.collect().foreach(println)

    /**
     * BinaryConfusionMatrixImpl({numPos: 1, numNeg: 0},{numPos: 11, numNeg: 89})
     * This matrix should be read in the following form:
     *
     *                   Actually true   Actually false
     * Predicted true    1               0
     * Predicted false   11-1            89-0
     *
     * As the Threshold changes, the matrix changes too, and therefore so does the precision
     *
     * (1.0,BinaryConfusionMatrixImpl({numPos: 1, numNeg: 0},{numPos: 11, numNeg: 89}))
     * (0.6,BinaryConfusionMatrixImpl({numPos: 6, numNeg: 0},{numPos: 11, numNeg: 89}))
     * (0.4,BinaryConfusionMatrixImpl({numPos: 11, numNeg: 0},{numPos: 11, numNeg: 89}))
     * (0.0,BinaryConfusionMatrixImpl({numPos: 11, numNeg: 89},{numPos: 11, numNeg: 89}))
     */
    println("-- confusions --")
    val confusions = cumulativeCounts.map { case (score, cumCount) =>
      (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
    }
    confusions.collect().foreach(println)

    println("-- precision --")
    def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
      confusions.map { case (s, c) =>
        (s, y(c))
      }
    }
    createCurve(Precision).collect().foreach(println)

    sc.stop()
  }

  object Precision extends BinaryClassificationMetricComputer {
    override def apply(c: BinaryConfusionMatrix): Double = {
      val totalPositives = c.numTruePositives + c.numFalsePositives
      if (totalPositives == 0) {
        1.0
      } else {
        c.numTruePositives.toDouble / totalPositives
      }
    }
  }

  trait BinaryClassificationMetricComputer extends Serializable {
    def apply(c: BinaryConfusionMatrix): Double
  }

  class BinaryLabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {

    /** Processes a label. */
    def +=(label: Double): BinaryLabelCounter = {
      // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle
      // -1.0 for negative as well.
      if (label > 0.5) numPositives += 1L else numNegatives += 1L
      this
    }

    /** Merges another counter. */
    def +=(other: BinaryLabelCounter): BinaryLabelCounter = {
      numPositives += other.numPositives
      numNegatives += other.numNegatives
      this
    }

    override def clone: BinaryLabelCounter = {
      new BinaryLabelCounter(numPositives, numNegatives)
    }

    override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"
  }

  private case class BinaryConfusionMatrixImpl(count: BinaryLabelCounter,
                                               totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {

    /** number of true positives */
    override def numTruePositives: Long = count.numPositives

    /** number of false positives */
    override def numFalsePositives: Long = count.numNegatives

    /** number of false negatives */
    override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives

    /** number of true negatives */
    override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives

    /** number of positives */
    override def numPositives: Long = totalCount.numPositives

    /** number of negatives */
    override def numNegatives: Long = totalCount.numNegatives
  }

  private trait BinaryConfusionMatrix {
    /** number of true positives */
    def numTruePositives: Long

    /** number of false positives */
    def numFalsePositives: Long

    /** number of false negatives */
    def numFalseNegatives: Long

    /** number of true negatives */
    def numTrueNegatives: Long

    /** number of positives */
    def numPositives: Long = numTruePositives + numFalseNegatives

    /** number of negatives */
    def numNegatives: Long = numFalsePositives + numTrueNegatives
  }
}
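The three-step cumulative sum (per-partition totals, a scanLeft over those totals, then accumulation inside each partition from its starting offset) can also be followed without a cluster. The sketch below imitates it with ordinary Scala collections standing in for partitions. `Counter` is a simplified, immutable stand-in for `BinaryLabelCounter`, and all names are hypothetical; the data matches the binnedCounts values shown in the comments of the listing.

```scala
object CumulativeDemo {
  /** Immutable stand-in for the listing's BinaryLabelCounter. */
  case class Counter(numPos: Long, numNeg: Long) {
    def +(o: Counter): Counter = Counter(numPos + o.numPos, numNeg + o.numNeg)
  }

  // Two "partitions" of (score, counter) pairs, already sorted by descending score.
  val partitions: Seq[Seq[(Double, Counter)]] = Seq(
    Seq(1.0 -> Counter(1L, 0L), 0.6 -> Counter(5L, 0L)),
    Seq(0.4 -> Counter(5L, 0L), 0.0 -> Counter(0L, 89L))
  )

  // Step 1: one total per partition (the role of `agg` in the listing).
  val perPartition: Seq[Counter] =
    partitions.map(_.map(_._2).foldLeft(Counter(0L, 0L))(_ + _))

  // Step 2: running totals; element i is the starting count of partition i,
  // and the last element is the grand total.
  val offsets: Seq[Counter] = perPartition.scanLeft(Counter(0L, 0L))(_ + _)

  // Step 3: inside each partition, accumulate starting from that partition's offset.
  val cumulative: Seq[(Double, Counter)] =
    partitions.zip(offsets).flatMap { case (part, start) =>
      part.scanLeft(0.0 -> start) { case ((_, acc), (score, c)) =>
        score -> (acc + c)
      }.tail // drop the seed element
    }

  def main(args: Array[String]): Unit = {
    offsets.foreach(println)
    cumulative.foreach(println)
  }
}
```

Only the small `offsets` array has to be gathered in one place; each partition can then finish its own prefix sum independently, which is why the listing collects `agg` to the driver but leaves the rest of the data distributed.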

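That completes the analysis of how Precision is computed. So how should we understand the Threshold, and why is an RDD returned? precisionByThreshold returns one (threshold, precision) pair per threshold, which lets us see how precision changes as the Threshold changes.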
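As a rough illustration of such a curve, the sketch below derives precision, recall, and F1 at each threshold from the cumulativeCounts values shown in the listing's comments, then picks the threshold with the highest F1. It uses plain Scala collections instead of RDDs, and `Point`, `curve`, and `best` are hypothetical names.

```scala
object ThresholdDemo {
  // (threshold, true positives, false positives): at each threshold, the cumulative
  // numPos counts actual positives predicted true, numNeg counts actual negatives predicted true.
  val cumulative: Seq[(Double, Long, Long)] =
    Seq((1.0, 1L, 0L), (0.6, 6L, 0L), (0.4, 11L, 0L), (0.0, 11L, 89L))

  // Total number of actual positives in the data set.
  val totalPos: Long = 11L

  case class Point(threshold: Double, precision: Double, recall: Double, f1: Double)

  // One point per threshold, mirroring the shape of a by-threshold result.
  val curve: Seq[Point] = cumulative.map { case (t, tp, fp) =>
    val p = if (tp + fp == 0) 1.0 else tp.toDouble / (tp + fp)
    val r = tp.toDouble / totalPos
    val f1 = if (p + r == 0.0) 0.0 else 2 * p * r / (p + r)
    Point(t, p, r, f1)
  }

  // The threshold with the highest F1 balances precision against recall best.
  val best: Point = curve.maxBy(_.f1)

  def main(args: Array[String]): Unit = {
    curve.foreach(println)
    println("best threshold = " + best.threshold)
  }
}
```

On this toy data precision stays at 1.0 down to a threshold of 0.4 and collapses to 0.11 at 0.0, while recall only reaches 1.0 at 0.4, so 0.4 is the threshold with the highest F1.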

Choosing the Threshold

import com.leo.tianchi.test.Run.BinaryLabelCounter
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

object Test {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("test").setMaster("local") // When debugging, be sure not to use local[*]
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    sc.setLogLevel("ERROR")

    import sqlContext.implicits._

    // Change this to your own Spark home directory
    val training = sqlContext.read.format("libsvm").load("/usr/local/spark/spark-1.6.1-bin-hadoop2.6/data/mllib/sample_libsvm_data.txt")

    val lr = new LogisticRegression()
      .setMaxIter(100)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)

    val lrModel = lr.fit(training)
    val binarySummary = lrModel.summary.asInstanceOf[BinaryLogisticRegressionSummary]

    // Keep the probability of the positive class together with the actual label
    val scoreAndLabels = binarySummary.predictions.select("probability", "label").map {
      case Row(score: Vector, label: Double) => (score(1), label)
    }
    scoreAndLabels.collect().foreach(println)

    println("-- binnedCounts --")
    /**
     * Part of the data is extracted below for analysis
     *
     * The left column is the Threshold, sorted from high to low
     * At the start, numPos is always greater than 0 and numNeg is always 0;
     * in other words, when the predicted probability of true is high, the actual value is also true
     * In the middle, the predicted probabilities change little, but the actual values swing back and forth;
     * this is easy to understand: when we are only about 50% sure, as with a coin toss, the result keeps flipping
     * At the end, numNeg is always greater than 0 and numPos is always 0
     * (0.7858977614108025,{numPos: 1, numNeg: 0})
     * (0.6647454962187126,{numPos: 1, numNeg: 0})
     * (0.5408778070820107,{numPos: 1, numNeg: 0})
     * ... (rows omitted)
     * (0.3975829487342493,{numPos: 0, numNeg: 1})
     * (0.35639781721605096,{numPos: 1, numNeg: 0})
     * (0.33923223159640786,{numPos: 0, numNeg: 1})
     * ... (rows omitted)
     * (0.32419460909076375,{numPos: 0, numNeg: 3})
     * (0.31989741144442924,{numPos: 0, numNeg: 1})
     * (0.3170955715164504,{numPos: 0, numNeg: 1})
     */
    val binnedCounts = scoreAndLabels.combineByKey(
      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
    ).sortByKey(ascending = false)
    binnedCounts.collect.foreach(println)

    binarySummary.precisionByThreshold.show(100000)
    binarySummary.recallByThreshold.show(100000)
    val fMeasure = binarySummary.fMeasureByThreshold
    fMeasure.show(100000)

    /**
     * Of these three metrics, F1 is the natural choice for selecting the Threshold
     * Find the maximum F1; the threshold it corresponds to is the best threshold
     */
    val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
    val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
      .select("threshold").head().getDouble(0)
    println(bestThreshold)
  }
}

References

准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure






