数据挖掘笔记七:决策树

记得很久以前中学课堂上偷看杂志,就有一大堆“预测”,通常是第一个问题,回答几个选项,选项A跳到第二个问题,选项B跳到第三个问题…一直走下去,最后的答案就可以预测到一个结果,如果把这个题目的设计路线画出来,那么就是决策树。

决策树

当然,那些杂志99%都是一两个小编胡编乱造随手一画,并没有训练数据支持,也没有决策树算法,所以看过那些预测大多是欣然一笑,然后就忘了~。

决策树,Decision Tree,算法如其名,给定训练集数据之后,通过算法对数据进行分析,就可以生成出根上面题目差不多的树了。决策树所需要用到的主要概念主要是信息熵,当然也可用其它信息增益算法来计算。

计算公式倒是简单,但有一点要记住,有些属性是没什么用的,千万不能拿去算信息增益,比如生日,身份证号等。每个人的身份证号都不一样,使用身份证会让信息增益的值会直接得到1,然而并没有什么卵用,相关导致树变的太大,过学习了,会直接导致模型高分低能。

惩罚量

一般而言,决策树会生成出千奇百怪的树,对于决策树来说,最优的一般是很简单的,不一定是成绩最高的,简单就是层数少,分支少,这样结人看起来也爽。

方法很简单,就是用信息增益的值除以分支的个数。

如上所说的身份证号,生日等,虽然使用这些属性后

另外,除了通过香农熵计算信息增益外,还可以使用基尼不纯度来计算

算法

构建决策树的算法主要有:ID3,C45,C50,CART,这里先说一下ID3算法,主要是ID3是决策树的鼻祖,后面的一些算法主要解决了一些ID3的一些问题,有些支持数值型,有些使用基尼指数等等。

但原理上差不了太多,在SKLearn包里面,默认就是“最优”的CART算法,之后以用引号,还是具体场景具体分析了。ID3算法原理相对简单理解。

数据准备

还是使用上次的数据,为了计算方便,我就随便找了2000条数据,1500做为训练集,500条做为测试集,分别存为fit.txt和test.txt

只是看看原理,对于离散型数据就懒得去平滑了,比如负债,每周工作的时长,年龄等,就直接删掉了。

import numpy as np
training_arr = np.loadtxt('fit.txt', str, '#', ',')
testing_arr = np.loadtxt('test.txt', str, '#', ',')
# 删除年龄,每周时长等数据:说明:肯定会影响结果,这里只是懒得处理
training_arr = np.delete(training_arr,[0,8,9,10],axis=1)
testing_arr = np.delete(testing_arr,[0,8,9,10],axis=1)

算法代码

说明,由于本人很懒,时间不多,写了一段代码后发现开始有个递归没考虑到,不过思路还算是对的,就把代码帖出来了,后续有闲工夫再整这一段,其实了解了原理,可以参考时就把这些参数控制好,就自自己实现了ID3了。


from math import log

class DeciconTree:
    # 初始化,将数据集,测试目标集传入
    def __init__(self,dataset,labels):
        self.dataset = dataset
        self.labels = labels
        self.dataSize = self.dataset.shape[0]
        self.featureCount = self.dataset.shape[1]
        self.baseEntropy = self.getEntropy(self.labels)
    # 示出数组的不确定性
    def getEntropy(self,arr):
        labelCount = {}
        datasize = len(arr)
        for label in arr:
            if label not in labelCount.keys():
                labelCount[label]=0
            labelCount[label] += 1
            entropy = 0.0
        fieldSize = len(labelCount.keys())
        # 只有1个数据,纯了直接返回0
        if fieldSize == 1:
            return entropy
        for key in labelCount.keys():
            prob = float(labelCount[key]/datasize)
            entropy -= prob*log(prob,fieldSize)
        return entropy
        
    # 按列的值提出数组
    def splitDataSet(self,fieldIndex,value):
        result = [];
        for index in range(self.dataSize):
            if self.dataset[index][fieldIndex]==value:
                result.append(self.labels[index])
        return result
        
    # 计算所在列数字段的信息增益
    def getInformationGain(self,fieldIndex):
        featureList = [data[fieldIndex] for data in self.dataset]
        uniqueFeature = set(featureList)
        entropy = 0.0
        for featureValue in uniqueFeature:
            subLabel = self.splitDataSet(fieldIndex,featureValue)
            prob = float(len(subLabel)/self.dataSize)
            entropy += prob*self.getEntropy(subLabel)
        return self.baseEntropy -  entropy
    
    # 找到当前信息增益最大的属性
    def chooseBest(self):
        bestFieldIndex = 0
        bestInfomationGain = 0
        for index in range(self.featureCount):
            information = self.getInformationGain(index)
            if information > bestInfomationGain:
                bestFieldIndex = index
                bestInfomationGain = information
        return bestFieldIndex
        
tree = DeciconTree(training_arr,training_label)
print(tree.chooseBest()) #第四个字段最好,信息增益度为0.17 

Sklearn

在sklearn中,决策树的API是tree。使用tree.DecisionTreeClassifier可以构造一个树

主要参数如下,由于使用了cart算法,在sklearn默认使用基尼指数做为信息计算。使用参数时,可以根据实际情况进行灵活的构造出树,增加准确度,提升性能。

还有sklearn还带了工具可以把最终的结果导出为图片/PDF保存起来,这样就可以很清楚看到树是什么样子的了,充分体现了决策树的另外一个优点。

参数 类型 说明(加粗就是默认)
criterion string [“gini“,”entropy”] “gini”:基尼指数*,”entropy”:信息熵
splitter string [“best“,”random”]
max_features int, float, string or None 最大属性,选择也很多,一般用int就行了吧 ,默认为None,即 最大属性数 等于 属性总数
max_depth int 最大深度
min_samples_split int 2 最小多少条数据才会去切割
min_samples_leaf int 1 要生成一个叶子最少需要多少数据
min_weight_fraction_leaf int 0 最小的加权分数
max_leaf_nodes int 最大节点总数
presort bool False 最大深度
# 说明,以下是不能运行的伪代码,(数据需要预处理)
from sklearn import tree
clf = tree.DecisionTreeClassifier()
clf.fit(training_arr,training_label)

小结

决策树的优点明显,结果可以看的到,给普通用户也可以说的通。但缺点也有,刚才说的如果不控制还是比较容易发生过学习的问题,想算出最好的结果还需要过程,结果也需要剪枝处理。