Predicting Corporate Bankruptcy with Spark Random Forests

1. Random Forest Overview

[Machine Learning & Algorithm] Random Forest

Decision Trees and Random Forests

Random Forests: Principles and Applicable Scenarios (Overview)
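
The articles above cover the theory in depth. In short, each tree in a random forest is trained on a bootstrap sample of the rows (drawn with replacement), each split considers only a random subset of the features, and the trees' predictions are then combined. The plain-Scala sketch below illustrates only the two sampling steps. BaggingSketch and its methods are hypothetical teaching helpers, not Spark internals; in Spark these behaviours are controlled through RandomForestClassifier parameters such as subsamplingRate and featureSubsetStrategy.

```scala
import scala.util.Random

// Hypothetical teaching sketch of the randomness inside a random forest.
// This is NOT how Spark implements it; it only shows the two sampling ideas.
object BaggingSketch {

  // Bootstrap sample: draw n row indices uniformly WITH replacement from
  // 0 until n, so some rows appear several times and others not at all.
  def bootstrapIndices(n: Int, rng: Random): Vector[Int] =
    Vector.fill(n)(rng.nextInt(n))

  // Random feature subset for a single split: k distinct feature indices
  // out of numFeatures (sampling WITHOUT replacement), returned sorted.
  def featureSubset(numFeatures: Int, k: Int, rng: Random): Vector[Int] = {
    require(k >= 0 && k <= numFeatures, "k must be between 0 and numFeatures")
    rng.shuffle((0 until numFeatures).toVector).take(k).sorted
  }

  def main(args: Array[String]): Unit = {
    val rng = new Random(42L) // fixed seed so the demo is repeatable
    println(s"bootstrap rows: ${bootstrapIndices(10, rng)}")
    // e.g. 3 of the 6 bankruptcy features as split candidates
    println(s"candidate features for one split: ${featureSubset(6, 3, rng)}")
  }
}
```

Because the bootstrap draws with replacement, a typical sample repeats some rows and leaves others out; those left-out rows are what makes each tree slightly different.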

2. The Qualitative_Bankruptcy Dataset

Dataset description: Qualitative_Bankruptcy Data Set

Dataset download link: Download

The dataset contains 6 features and one label, all of which are qualitative (categorical) values:

    1. Industrial Risk: {P,A,N}
    2. Management Risk: {P,A,N}
    3. Financial Flexibility: {P,A,N}
    4. Credibility: {P,A,N}
    5. Competitiveness: {P,A,N}
    6. Operating Risk: {P,A,N}
    7. Class: {B,NB} 

(P = Positive, A = Average, N = Negative, B = Bankruptcy, NB = Non-Bankruptcy)

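We train a random forest that maps the first 6 features to the label, and use the model to predict whether a company will go bankrupt.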

A sample of the downloaded data:

    P,P,A,A,A,P,NB
    N,N,A,A,A,N,NB
    A,A,A,A,A,A,NB
    P,P,P,P,P,P,NB
    N,N,P,P,P,N,NB

So that Spark can build a DataFrame with named columns directly, we add a header row to the dataset:

    irisk,mrisk,fina,cred,comp,orisk,label
    P,P,A,A,A,P,NB
    N,N,A,A,A,N,NB
    A,A,A,A,A,A,NB
    P,P,P,P,P,P,NB
    N,N,P,P,P,N,NB
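
With the header row in place, reading the file with header=true gives every column a name instead of a positional default. As a rough plain-Scala illustration of that mapping, the hypothetical HeaderedCsvSketch below zips each data row with the header's column names. It assumes simple comma-separated values without quoting, which holds for this dataset, and is not Spark's CSV reader.

```scala
// Hypothetical plain-Scala illustration of reading a CSV file with a header:
// each data row becomes a map from column name to value.
object HeaderedCsvSketch {

  // The first line holds the column names; every following non-blank line
  // is split on commas and paired with those names.
  // Assumes no quoted fields or embedded commas (true for this dataset).
  def parse(lines: Seq[String]): Seq[Map[String, String]] =
    if (lines.isEmpty) Seq.empty
    else {
      val cols = lines.head.split(",").map(_.trim).toSeq
      lines.tail.filter(_.trim.nonEmpty).map { row =>
        cols.zip(row.split(",").map(_.trim).toSeq).toMap
      }
    }

  def main(args: Array[String]): Unit = {
    val sample = Seq(
      "irisk,mrisk,fina,cred,comp,orisk,label",
      "P,P,A,A,A,P,NB",
      "N,N,A,A,A,N,NB"
    )
    parse(sample).foreach(println)
  }
}
```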

3. Spark RandomForestClassifier

Using the Spark random forest classifier involves the following steps:

Data loading --> data preprocessing --> feature transformation --> feature selection --> model training --> model evaluation
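
The last steps of this workflow split the data into training and test sets and then score the model by accuracy. As a plain-Scala reference for what those two operations mean, the hypothetical SplitAndAccuracySketch below assigns each row to the training set with a fixed probability and computes accuracy as the fraction of correct predictions. It is an illustration under those assumptions, not Spark's randomSplit or MulticlassClassificationEvaluator.

```scala
import scala.util.Random

// Hypothetical plain-Scala sketch of a random train/test split and the
// accuracy metric. Not Spark's implementation; for illustration only.
object SplitAndAccuracySketch {

  // Each row independently lands in the training set with probability
  // trainFraction, so a 0.7 fraction gives only APPROXIMATELY 70/30 sizes.
  def randomSplit[A](rows: Seq[A], trainFraction: Double, seed: Long): (Seq[A], Seq[A]) = {
    val rng = new Random(seed)
    val draws = rows.map(_ => rng.nextDouble()) // one random draw per row
    val (train, test) = rows.zip(draws).partition { case (_, d) => d < trainFraction }
    (train.map(_._1), test.map(_._1))
  }

  // Accuracy: number of rows whose prediction equals the true label,
  // divided by the total number of rows.
  def accuracy(predictions: Seq[Double], labels: Seq[Double]): Double = {
    require(predictions.size == labels.size && labels.nonEmpty, "need equal, non-empty inputs")
    predictions.zip(labels).count { case (p, l) => p == l }.toDouble / labels.size
  }
}
```

Since each row is assigned independently, the split sizes vary slightly from run to run unless the seed is fixed; with only a few dozen test rows, a single misclassified row moves accuracy by several percentage points.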

Complete example code:


import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession

object RandomForestClassifierTest {
  
  def main(args: Array[String]): Unit = {

    // Create the SparkSession
    val spark = SparkSession
      .builder()
      .appName("RandomForestClassifierTest")
      .master("local[4]")
      .getOrCreate()


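    // Read the CSV data; header=true means the first line holds the column names.
    // Spark does not expand "~" like a shell does, so build the home-directory path explicitly.
    val dataPath = System.getProperty("user.home") + "/Qualitative_Bankruptcy.data.txt"
    val oriData = spark.read.option("header", "true")
      .csv(dataPath)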

    // Print the data schema:
    //    root
    //    |-- irisk: string (nullable = true)
    //    |-- mrisk: string (nullable = true)
    //    |-- fina: string (nullable = true)
    //    |-- cred: string (nullable = true)
    //    |-- comp: string (nullable = true)
    //    |-- orisk: string (nullable = true)
    //    |-- label: string (nullable = true)
    oriData.printSchema()

    // Show the first 5 sample rows (without truncating values):
    //    +-----+-----+----+----+----+-----+-----+
    //    |irisk|mrisk|fina|cred|comp|orisk|label|
    //    +-----+-----+----+----+----+-----+-----+
    //    |P    |P    |A   |A   |A   |P    |NB   |
    //    |N    |N    |A   |A   |A   |N    |NB   |
    //    |A    |A    |A   |A   |A   |A    |NB   |
    //    |P    |P    |P   |P   |P   |P    |NB   |
    //    |N    |N    |P   |P   |P   |N    |NB   |
    //    +-----+-----+----+----+----+-----+-----+
    oriData.show(5, false)


    // Feature preprocessing: convert each string feature column into a numeric index column
    val indexedIrisk = new StringIndexer().setInputCol("irisk").setOutputCol("indexedIrisk").fit(oriData)
    val indexedMrisk = new StringIndexer().setInputCol("mrisk").setOutputCol("indexedMrisk").fit(oriData)
    val indexedFina = new StringIndexer().setInputCol("fina").setOutputCol("indexedFina").fit(oriData)
    val indexedCred = new StringIndexer().setInputCol("cred").setOutputCol("indexedCred").fit(oriData)
    val indexedComp = new StringIndexer().setInputCol("comp").setOutputCol("indexedComp").fit(oriData)
    val indexedOrisk = new StringIndexer().setInputCol("orisk").setOutputCol("indexedOrisk").fit(oriData)


    // Combine all indexed feature columns into a single feature Vector
    val featuresAssembler = new VectorAssembler().setInputCols(Array("indexedIrisk", "indexedMrisk", "indexedFina",
      "indexedCred", "indexedComp",
      "indexedOrisk")).setOutputCol("features")

    // Label conversion: turn the string class labels into numeric class indices
    val indexedLabel = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(oriData)


    // Chain the preprocessing stages into a Pipeline; the stages run in order
    val pipeline = new Pipeline().setStages(Array(indexedIrisk, indexedMrisk, indexedFina, indexedCred,
      indexedComp, indexedOrisk, featuresAssembler, indexedLabel))

    val preModel = pipeline.fit(oriData)

    val preData = preModel.transform(oriData)


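    // Split the preprocessed data: 70% to train the random forest, 30% to test it.
    // A fixed seed (1234L is an arbitrary choice) makes the split reproducible.
    val Array(trainData, testData) = preData.randomSplit(Array(0.7, 0.3), seed = 1234L)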



    // Create the random forest classifier
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("features")
      .setNumTrees(10)

    // Train the random forest model
    val rfModel = rf.fit(trainData)

    // Use the trained model to make predictions on the test data
    val predictions = rfModel.transform(testData)


    // Create a classification evaluator; here we only look at the accuracy metric
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    // Evaluate the predictions:
    //    accuracy=1.0
    val accuracy = evaluator.evaluate(predictions)
    println("accuracy=" + accuracy)


    // Print the trained random forest model (the structure of every tree)
    println(rfModel.toDebugString)


    spark.stop()
  }
}
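
The trees printed below split on numeric sets such as {1.0,2.0} rather than on P/A/N, because StringIndexer replaces each string with an index. By default (stringOrderType "frequencyDesc") the most frequent value in a column gets 0.0, the next 1.0, and so on. The hypothetical IndexingSketch below mimics that encoding followed by the VectorAssembler step in plain Scala, for illustration only; the alphabetical tie-break is an assumption, and the real mapping for this dataset depends on the value counts in the full file.

```scala
// Hypothetical plain-Scala sketch of StringIndexer ("frequencyDesc" ordering)
// followed by VectorAssembler. Illustration only, not Spark's implementation.
object IndexingSketch {

  // Map each distinct value to an index: the most frequent value gets 0.0,
  // the next 1.0, and so on. Ties are broken alphabetically (an assumption).
  def fitIndex(values: Seq[String]): Map[String, Double] =
    values
      .groupBy(identity)
      .map { case (v, occurrences) => (v, occurrences.size) }
      .toSeq
      .sortBy { case (v, count) => (-count, v) }
      .zipWithIndex
      .map { case ((v, _), i) => (v, i.toDouble) }
      .toMap

  // Fit one index per column, then encode each row and "assemble" the
  // indices into one feature array, keeping the original column order.
  def assemble(rows: Seq[Seq[String]]): Seq[Array[Double]] =
    if (rows.isEmpty) Seq.empty
    else {
      val numCols = rows.head.size
      val indexers = (0 until numCols).map(c => fitIndex(rows.map(row => row(c))))
      rows.map(row => row.indices.map(c => indexers(c)(row(c))).toArray)
    }

  def main(args: Array[String]): Unit = {
    // The five sample rows shown earlier, without the label column
    val rows = Seq(
      Seq("P", "P", "A", "A", "A", "P"),
      Seq("N", "N", "A", "A", "A", "N"),
      Seq("A", "A", "A", "A", "A", "A"),
      Seq("P", "P", "P", "P", "P", "P"),
      Seq("N", "N", "P", "P", "P", "N")
    )
    assemble(rows).foreach(f => println(f.mkString("[", ", ", "]")))
  }
}
```

On these five rows, P and N tie in the irisk column, so the assumed tie-break decides their order; on the full dataset the counts, and therefore the indices in the printed trees, may differ.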


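The resulting random forest model, as printed by rfModel.toDebugString. Here feature 0 to feature 5 are the VectorAssembler inputs in order (indexedIrisk, indexedMrisk, indexedFina, indexedCred, indexedComp, indexedOrisk), and each Predict value is an indexedLabel class index: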

    RandomForestClassificationModel (uid=rfc_254bd6468fa2) with 10 trees
      Tree 0 (weight 1.0):
        If (feature 4 in {1.0,2.0})
         If (feature 3 in {1.0,2.0})
          Predict: 0.0
         Else (feature 3 not in {1.0,2.0})
          If (feature 2 in {1.0,2.0})
           Predict: 0.0
          Else (feature 2 not in {1.0,2.0})
           Predict: 1.0
        Else (feature 4 not in {1.0,2.0})
         Predict: 1.0
      Tree 1 (weight 1.0):
        If (feature 4 in {1.0,2.0})
         If (feature 3 in {1.0,2.0})
          Predict: 0.0
         Else (feature 3 not in {1.0,2.0})
          If (feature 5 in {0.0,1.0})
           Predict: 0.0
          Else (feature 5 not in {0.0,1.0})
           Predict: 1.0
        Else (feature 4 not in {1.0,2.0})
         Predict: 1.0
      Tree 2 (weight 1.0):
        If (feature 4 in {1.0,2.0})
         If (feature 5 in {0.0,1.0})
          Predict: 0.0
         Else (feature 5 not in {0.0,1.0})
          If (feature 0 in {0.0,1.0})
           Predict: 0.0
          Else (feature 0 not in {0.0,1.0})
           If (feature 1 in {1.0})
            Predict: 0.0
           Else (feature 1 not in {1.0})
            If (feature 4 in {1.0})
             Predict: 0.0
            Else (feature 4 not in {1.0})
             Predict: 1.0
        Else (feature 4 not in {1.0,2.0})
         Predict: 1.0
      Tree 3 (weight 1.0):
        If (feature 4 in {1.0,2.0})
         If (feature 2 in {1.0,2.0})
          Predict: 0.0
         Else (feature 2 not in {1.0,2.0})
          If (feature 3 in {1.0,2.0})
           Predict: 0.0
          Else (feature 3 not in {1.0,2.0})
           Predict: 1.0
        Else (feature 4 not in {1.0,2.0})
         Predict: 1.0
      Tree 4 (weight 1.0):
        If (feature 3 in {1.0,2.0})
         If (feature 4 in {1.0,2.0})
          Predict: 0.0
         Else (feature 4 not in {1.0,2.0})
          Predict: 1.0
        Else (feature 3 not in {1.0,2.0})
         If (feature 1 in {1.0,2.0})
          If (feature 4 in {1.0})
           Predict: 0.0
          Else (feature 4 not in {1.0})
           Predict: 1.0
         Else (feature 1 not in {1.0,2.0})
          If (feature 4 in {2.0})
           If (feature 2 in {1.0})
            Predict: 0.0
           Else (feature 2 not in {1.0})
            Predict: 1.0
          Else (feature 4 not in {2.0})
           Predict: 1.0
      Tree 5 (weight 1.0):
        If (feature 3 in {1.0,2.0})
         If (feature 2 in {1.0,2.0})
          Predict: 0.0
         Else (feature 2 not in {1.0,2.0})
          If (feature 4 in {1.0,2.0})
           Predict: 0.0
          Else (feature 4 not in {1.0,2.0})
           Predict: 1.0
        Else (feature 3 not in {1.0,2.0})
         If (feature 0 in {1.0,2.0})
          If (feature 4 in {1.0})
           Predict: 0.0
          Else (feature 4 not in {1.0})
           Predict: 1.0
         Else (feature 0 not in {1.0,2.0})
          Predict: 1.0
      Tree 6 (weight 1.0):
        If (feature 4 in {1.0,2.0})
         If (feature 3 in {1.0,2.0})
          Predict: 0.0
         Else (feature 3 not in {1.0,2.0})
          If (feature 4 in {1.0})
           Predict: 0.0
          Else (feature 4 not in {1.0})
           If (feature 5 in {1.0})
            Predict: 0.0
           Else (feature 5 not in {1.0})
            Predict: 1.0
        Else (feature 4 not in {1.0,2.0})
         Predict: 1.0
      Tree 7 (weight 1.0):
        If (feature 2 in {1.0,2.0})
         If (feature 0 in {1.0,2.0})
          Predict: 0.0
         Else (feature 0 not in {1.0,2.0})
          If (feature 4 in {1.0,2.0})
           Predict: 0.0
          Else (feature 4 not in {1.0,2.0})
           Predict: 1.0
        Else (feature 2 not in {1.0,2.0})
         If (feature 4 in {1.0})
          Predict: 0.0
         Else (feature 4 not in {1.0})
          Predict: 1.0
      Tree 8 (weight 1.0):
        If (feature 2 in {1.0,2.0})
         If (feature 0 in {1.0,2.0})
          Predict: 0.0
         Else (feature 0 not in {1.0,2.0})
          If (feature 1 in {1.0,2.0})
           Predict: 0.0
          Else (feature 1 not in {1.0,2.0})
           If (feature 4 in {1.0,2.0})
            Predict: 0.0
           Else (feature 4 not in {1.0,2.0})
            Predict: 1.0
        Else (feature 2 not in {1.0,2.0})
         If (feature 3 in {1.0})
          If (feature 0 in {1.0,2.0})
           Predict: 0.0
          Else (feature 0 not in {1.0,2.0})
           Predict: 1.0
         Else (feature 3 not in {1.0})
          If (feature 3 in {2.0})
           If (feature 4 in {1.0})
            Predict: 0.0
           Else (feature 4 not in {1.0})
            Predict: 1.0
          Else (feature 3 not in {2.0})
           Predict: 1.0
      Tree 9 (weight 1.0):
        If (feature 2 in {1.0,2.0})
         Predict: 0.0
        Else (feature 2 not in {1.0,2.0})
         If (feature 4 in {1.0,2.0})
          If (feature 3 in {1.0,2.0})
           Predict: 0.0
          Else (feature 3 not in {1.0,2.0})
           Predict: 1.0
         Else (feature 4 not in {1.0,2.0})
          Predict: 1.0
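
To make the printout concrete, the sketch below hand-transcribes Tree 0, Tree 1 and Tree 9 from the output above into Scala functions on the 6-element feature vector, then combines them by simple majority vote. The object and function names are hypothetical, only three of the ten trees are included, and hard majority voting is the textbook combination rule used here as an assumption; Spark's internal aggregation is not necessarily identical.

```scala
// Hypothetical hand transcription of three trees printed by toDebugString,
// plus a simple majority vote. Feature i is the i-th VectorAssembler input.
object ForestVoteSketch {

  type Tree = Array[Double] => Double

  // True when x is one of the listed category indices
  private def in(x: Double, set: Double*): Boolean = set.contains(x)

  // Tree 0 from the printout above
  val tree0: Tree = f =>
    if (in(f(4), 1.0, 2.0)) {
      if (in(f(3), 1.0, 2.0)) 0.0
      else if (in(f(2), 1.0, 2.0)) 0.0
      else 1.0
    } else 1.0

  // Tree 1 from the printout above
  val tree1: Tree = f =>
    if (in(f(4), 1.0, 2.0)) {
      if (in(f(3), 1.0, 2.0)) 0.0
      else if (in(f(5), 0.0, 1.0)) 0.0
      else 1.0
    } else 1.0

  // Tree 9 from the printout above
  val tree9: Tree = f =>
    if (in(f(2), 1.0, 2.0)) 0.0
    else if (in(f(4), 1.0, 2.0)) {
      if (in(f(3), 1.0, 2.0)) 0.0 else 1.0
    } else 1.0

  // Majority vote: the class index predicted by the most trees wins.
  // Ties go to the smallest class index (an arbitrary choice for this sketch).
  def vote(trees: Seq[Tree], features: Array[Double]): Double = {
    val counts = trees.map(t => t(features).toInt)
      .groupBy(identity)
      .map { case (label, votes) => (label, votes.size) }
    val best = counts.values.max
    counts.collect { case (label, c) if c == best => label }.min.toDouble
  }

  def main(args: Array[String]): Unit = {
    val forest = Seq(tree0, tree1, tree9)
    val x = Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)
    println(s"per-tree: ${forest.map(t => t(x)).mkString(",")}, vote: ${vote(forest, x)}")
  }
}
```

For the demo vector, Tree 1 predicts 0.0 while Tree 0 and Tree 9 predict 1.0, so the vote is 1.0: individual trees can disagree, and the ensemble smooths out those differences.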


Author: 时海