
    [Original] Spark DecisionTreeModel print

    Posted by fansy1990 on 2017-04-26 16:17:37

    Software version:

        Spark: 1.6.1

    Problem 1:

    When building a classification model with Spark's DecisionTree, it is useful to print the resulting decision tree. The model's toDebugString method produces a nested listing like the one below:
    DecisionTreeModel classifier of depth 7 with 45 nodes
      If (feature 22 <= 114.2)
       If (feature 27 <= 0.1108)
        If (feature 13 <= 45.19)
         If (feature 21 <= 32.84)
          Predict: 0.0
         Else (feature 21 > 32.84)
          If (feature 1 <= 22.61)
           If (feature 0 <= 11.49)
            Predict: 0.0
           Else (feature 0 > 11.49)
            Predict: 1.0
          Else (feature 1 > 22.61)
           Predict: 0.0
        Else (feature 13 > 45.19)
         If (feature 21 <= 22.13)
          Predict: 0.0
         Else (feature 21 > 22.13)
          If (feature 14 <= 0.004571)
           Predict: 0.0
          Else (feature 14 > 0.004571)
           Predict: 1.0
       Else (feature 27 > 0.1108)
        If (feature 21 <= 25.72)
         If (feature 24 <= 0.1786)
          If (feature 23 <= 809.7)
           Predict: 0.0
          Else (feature 23 > 809.7)
           If (feature 0 <= 14.02)
            Predict: 1.0
           Else (feature 0 > 14.02)
            Predict: 0.0
         Else (feature 24 > 0.1786)
          Predict: 1.0
        Else (feature 21 > 25.72)
         If (feature 7 <= 0.05266)
          If (feature 20 <= 15.5)
           Predict: 0.0
          Else (feature 20 > 15.5)
           If (feature 4 <= 0.09073)
            If (feature 10 <= 0.2406)
             Predict: 1.0
            Else (feature 10 > 0.2406)
             Predict: 0.0
           Else (feature 4 > 0.09073)
            Predict: 1.0
         Else (feature 7 > 0.05266)
          If (feature 12 <= 1.539)
           Predict: 0.0
          Else (feature 12 > 1.539)
           Predict: 1.0
      Else (feature 22 > 114.2)
       If (feature 27 <= 0.1397)
        If (feature 1 <= 14.96)
         Predict: 0.0
        Else (feature 1 > 14.96)
         If (feature 20 <= 18.79)
          If (feature 17 <= 0.009753)
           Predict: 1.0
          Else (feature 17 > 0.009753)
           If (feature 0 <= 17.3)
            Predict: 0.0
           Else (feature 0 > 17.3)
            Predict: 1.0
         Else (feature 20 > 18.79)
          Predict: 1.0
       Else (feature 27 > 0.1397)
        Predict: 1.0
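
    (For context, below is a minimal sketch of how such a classifier might be trained and printed. The data path, the CSV parsing, and the training parameters are illustrative assumptions, not details taken from this model.)

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree

    object TrainAndPrint {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("TrainAndPrint"))

        // Hypothetical input: one sample per line, label first, then numeric features.
        val data = sc.textFile("data.csv").map { line =>
          val parts = line.split(",").map(_.toDouble)
          LabeledPoint(parts.head, Vectors.dense(parts.tail))
        }

        // Binary classification; an empty categoricalFeaturesInfo map means
        // every feature is treated as continuous.
        val model = DecisionTree.trainClassifier(data, 2, Map[Int, Int](),
          "gini", 7, 32)

        // toDebugString produces the nested If/Else listing shown above.
        println(model.toDebugString)
      }
    }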

    However, output in this form is not very easy to read. Can the tree be printed as an actual tree diagram instead?

    Solution:

    Use the following code:
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.mllib.tree.model.Node

    /**
     * Utility for printing a DecisionTreeModel as a tree diagram.
     */
    object PrintUtils {

      /**
       * Print the decision tree.
       * @param model the trained DecisionTreeModel
       * @return a multi-line string containing the tree diagram
       */
      def printDecisionTree(model: DecisionTreeModel): String = {
        model.toString() + "\n" + printTree(model.topNode)
      }

      /** Print the root node with its right subtree above it and its left subtree below it. */
      def printTree(root: Node): String = {
        val right = root.rightNode.map(n => printTree(n, true, "")).getOrElse("")
        val rootStr = printNodeValue(root)
        val left = root.leftNode.map(n => printTree(n, false, "")).getOrElse("")
        right + rootStr + left
      }

      /** Render a single node: an internal node prints its split, a leaf prints its prediction. */
      def printNodeValue(root: Node): String = {
        val rootStr = root.split match {
          case Some(split) => "Feature:" + split.feature + " > " + split.threshold
          case None        => if (root.isLeaf) root.predict.toString() else ""
        }
        rootStr + "\n"
      }

      /** Recursively print a subtree, prefixing each line with the accumulated indent. */
      def printTree(root: Node, isRight: Boolean, indent: String): String = {
        // Right subtree first, so that it appears above its parent in the output.
        val right = root.rightNode
          .map(n => printTree(n, true, indent + (if (isRight) "        " else " |      ")))
          .getOrElse("")
        // A right child hangs from " /-----", a left child from " \-----".
        val branch = if (isRight) " /" else " \\"
        val rootStr = printNodeValue(root)
        val left = root.leftNode
          .map(n => printTree(n, false, indent + (if (isRight) " |      " else "        ")))
          .getOrElse("")
        right + indent + branch + "----- " + rootStr + left
      }

    }
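
    Note on the layout: the tree is printed sideways. For each node the right subtree is emitted first, so it appears above the node; the node itself is prefixed with " /----- " if it is a right child or " \----- " if it is a left child; the left subtree follows below. The indent string grows by eight spaces or by " |      " at each level, which draws the vertical connector lines between a parent and its children.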

    Call it like this:
    val modelPath = "..."
    val model = DecisionTreeModel.load(sc, modelPath)
    println(model.toDebugString)
    val str = PrintUtils.printDecisionTree(model)
    println(str)

    This prints:
    DecisionTreeModel classifier of depth 7 with 45 nodes
             /----- 1.0 (prob = 1.0)
     /----- Feature:27 > 0.1397
     |       |               /----- 1.0 (prob = 1.0)
     |       |       /----- Feature:20 > 18.79
     |       |       |       |               /----- 1.0 (prob = 1.0)
     |       |       |       |       /----- Feature:0 > 17.3
     |       |       |       |       |       \----- 0.0 (prob = 1.0)
     |       |       |       \----- Feature:17 > 0.009753
     |       |       |               \----- 1.0 (prob = 1.0)
     |       \----- Feature:1 > 14.96
     |               \----- 0.0 (prob = 1.0)
    Feature:22 > 114.2
     |                               /----- 1.0 (prob = 1.0)
     |                       /----- Feature:12 > 1.539
     |                       |       \----- 0.0 (prob = 1.0)
     |               /----- Feature:7 > 0.05266
     |               |       |               /----- 1.0 (prob = 1.0)
     |               |       |       /----- Feature:4 > 0.09073
     |               |       |       |       |       /----- 0.0 (prob = 1.0)
     |               |       |       |       \----- Feature:10 > 0.2406
     |               |       |       |               \----- 1.0 (prob = 1.0)
     |               |       \----- Feature:20 > 15.5
     |               |               \----- 0.0 (prob = 1.0)
     |       /----- Feature:21 > 25.72
     |       |       |       /----- 1.0 (prob = 1.0)
     |       |       \----- Feature:24 > 0.1786
     |       |               |               /----- 0.0 (prob = 1.0)
     |       |               |       /----- Feature:0 > 14.02
     |       |               |       |       \----- 1.0 (prob = 1.0)
     |       |               \----- Feature:23 > 809.7
     |       |                       \----- 0.0 (prob = 1.0)
     \----- Feature:27 > 0.1108
             |                       /----- 1.0 (prob = 1.0)
             |               /----- Feature:14 > 0.004571
             |               |       \----- 0.0 (prob = 1.0)
             |       /----- Feature:21 > 22.13
             |       |       \----- 0.0 (prob = 1.0)
             \----- Feature:13 > 45.19
                     |               /----- 0.0 (prob = 1.0)
                     |       /----- Feature:1 > 22.61
                     |       |       |       /----- 1.0 (prob = 1.0)
                     |       |       \----- Feature:0 > 11.49
                     |       |               \----- 0.0 (prob = 1.0)
                     \----- Feature:21 > 32.84
                             \----- 0.0 (prob = 1.0)

    Presented this way, the tree is much easier to read!

    Problem 2:

    What if the nodes in a DecisionTreeModel split on categorical (discrete) features?
    1. The Node defined in DecisionTreeModel can in fact describe a categorical split (see the sketch after this list);
    2. When the model was built here, the train method was given a plain RDD[LabeledPoint], so every split in this model is on a continuous feature; there are no categorical splits to print;
    3. My guess is that a later Spark version will add the corresponding support.
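
    On point 1: in MLlib 1.6 the Split attached to a node carries a featureType and a categories list, so printNodeValue could be extended to display a categorical split if a model ever contained one. The following is only a sketch of that idea (it assumes the featureType and categories fields of org.apache.spark.mllib.tree.model.Split), not something exercised by the model above:

    import org.apache.spark.mllib.tree.configuration.FeatureType
    import org.apache.spark.mllib.tree.model.Node

    // Variant of printNodeValue that also handles categorical splits:
    // a continuous split prints as "feature > threshold", a categorical
    // split as "feature in {categories}".
    def printNodeValue(root: Node): String = {
      val rootStr = root.split match {
        case Some(split) if split.featureType == FeatureType.Continuous =>
          "Feature:" + split.feature + " > " + split.threshold
        case Some(split) =>
          "Feature:" + split.feature + " in {" + split.categories.mkString(",") + "}"
        case None =>
          if (root.isLeaf) root.predict.toString() else ""
      }
      rootStr + "\n"
    }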


    Share, grow, enjoy

    Stay grounded, stay focused

    When reposting, please credit the original blog: http://blog.csdn.NET/fansy1990



