IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    [转]决策树C4.5分类算法的C++实现

    zhoubl668发表于 2014-11-05 19:09:24
    love 0

    一、前言

    当年实习公司布置了一个任务让写一个决策树,以前并未接触数据挖掘的东西,但作为一个数据挖掘最基本的知识点,还是应该有所理解的。

      程序的源码可以点击这里进行下载,下面简要介绍一下决策树以及相关算法概念。

      决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。 数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测(就像上面的银行官员用他来预测贷款风险)。从数据产生决策树的机器学习技术叫做决策树学习, 通俗说就是决策树。(来自维基百科)

      1986年Quinlan提出了著名的ID3算法。在ID3算法的基础上,1993年Quinlan又提出了C4.5算法。为了适应处理大规模数据集的需要,后来又提出了若干改进的算法,其中SLIQ (super-vised learning in quest)和SPRINT (scalable parallelizableinduction of decision trees)是比较有代表性的两个算法,此处暂且略过。

      本文实现了C4.5的算法,在ID3的基础上计算信息增益,从而更加准确的反应信息量。其实通俗的说就是构建一棵加权的最短路径Haffman树,让权值最大的节点为父节点。

    二、基本概念

      下面简要介绍一下ID3算法:

      ID3算法的核心是:在决策树各级结点上选择属性时,用信息增益(information gain)作为属性的选择标准,以使得在每一个非叶结点进行测试时,能获得关于被测试记录最大的类别信息。

      其具体方法是:检测所有的属性,选择信息增益最大的属性产生决策树结点,由该属性的不同取值建立分支,再对各分支的子集递归调用该方法建立决策树结点的分支,直到所有子集仅包含同一类别的数据为止。最后得到一棵决策树,它可以用来对新的样本进行分类。

      某属性的信息增益按下列方法计算:


    信息熵是香农提出的,用于描述信息不纯度(不稳定性),其计算公式是Info(D)。

      其中:Pi为子集合中不同性(而二元分类即正样例和负样例)的样例的比例;j是属性A中的索引,D是集合样本,Dj是D中属性A上值等于j的样本集合。

    这样信息收益可以定义为样本按照某属性划分时造成熵减少的期望,可以区分训练样本中正负样本的能力。信息增益定义为结点与其子结点的信息熵之差,公式为Gain(A)。

      ID3算法的优点是:算法的理论清晰,方法简单,学习能力较强。其缺点是:只对比较小的数据集有效,且对噪声比较敏感,当训练数据集加大时,决策树可能会随之改变。

      C4.5算法继承了ID3算法的优点,并在以下几方面对ID3算法进行了改进:

      1) 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足,公式为GainRatio(A);

      2) 在树构造过程中进行剪枝;

      3) 能够完成对连续属性的离散化处理;

      4) 能够对不完整数据进行处理。

      C4.5算法与其它分类算法如统计方法、神经网络等比较起来有如下优点:产生的分类规则易于理解,准确率较高。其缺点是:在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效。此外,C4.5只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。


    三、数据集

    实现的C4.5数据集合如下:


    它记录了再不同的天气状况下,是否出去觅食的数据。

    四、程序代码

      程序引入状态树作为统计和计算属性的数据结构,它记录了每次计算后,各个属性的统计数据,其定义如下:

    [cpp] view plaincopyprint?在CODE上查看代码片派生到我的代码片
    1. struct attrItem
    2. {
    3. std::vector<int> itemNum; //itemNum[0] = itemLine.size()
    4. //itemNum[1] = decision num
    5. set<int> itemLine;
    6. };
    7. struct attributes
    8. {
    9. string attriName;
    10. vector<double> statResult;
    11. map attriItem;
    12. };
    13. vector statTree;

    决策树节点数据结构如下:

    [cpp] view plaincopyprint?在CODE上查看代码片派生到我的代码片
    1. struct TreeNode
    2. {
    3. std::string m_sAttribute;
    4. int m_iDeciNum;
    5. int m_iUnDecinum;
    6. std::vector m_vChildren;
    7. };

    程序源码如下所示(程序中有详细注解):

    [cpp] view plaincopyprint?在CODE上查看代码片派生到我的代码片
    1. #include "DecisionTree.h"
    2. int main(int argc, char* argv[]){
    3. string filename = "source.txt";
    4. DecisionTree dt ;
    5. int attr_node = 0;
    6. TreeNode* treeHead = nullptr;
    7. set<int> readLineNum;
    8. vector<int> readClumNum;
    9. int deep = 0;
    10. if (dt.pretreatment(filename, readLineNum, readClumNum) == 0)
    11. {
    12. dt.CreatTree(treeHead, dt.getStatTree(), dt.getInfos(), readLineNum, readClumNum, deep);
    13. }
    14. return 0;
    15. }
    16. /*
    17. * @function CreatTree 预处理函数,负责读入数据,并生成信息矩阵和属性标记
    18. * @param: filename 文件名
    19. * @param: readLineNum 可使用行set
    20. * @param: readClumNum 可用属性set
    21. * @return int 返回函数执行状态
    22. */
    23. int DecisionTree::pretreatment(string filename, set<int>& readLineNum, vector<int>& readClumNum)
    24. {
    25. ifstream read(filename.c_str());
    26. string itemline = "";
    27. getline(read, itemline);
    28. istringstream iss(itemline);
    29. string attr = "";
    30. while(iss >> attr)
    31. {
    32. attributes* s_attr = new attributes();
    33. s_attr->attriName = attr;
    34. //初始化属性名
    35. statTree.push_back(s_attr);
    36. //初始化属性映射
    37. attr_clum[attr] = attriNum;
    38. attriNum++;
    39. //初始化可用属性列
    40. readClumNum.push_back(0);
    41. s_attr = nullptr;
    42. }
    43. int i = 0;
    44. //添加具体数据
    45. while(true)
    46. {
    47. getline(read, itemline);
    48. if(itemline == "" || itemline.length() <= 1)
    49. {
    50. break;
    51. }
    52. vector infoline;
    53. istringstream stream(itemline);
    54. string item = "";
    55. while(stream >> item)
    56. {
    57. infoline.push_back(item);
    58. }
    59. infos.push_back(infoline);
    60. readLineNum.insert(i);
    61. i++;
    62. }
    63. read.close();
    64. return 0;
    65. }
    66. int DecisionTree::statister(vector>& infos, vector& statTree,
    67. set<int>& readLine, vector<int>& readClumNum)
    68. {
    69. //yes的总行数
    70. int deciNum = 0;
    71. //统计每一行
    72. set<int>::iterator iter_end = readLine.end();
    73. for (set<int>::iterator line_iter = readLine.begin(); line_iter != iter_end; ++line_iter)
    74. {
    75. bool decisLine = false;
    76. if (infos[*line_iter][attriNum - 1] == "yes")
    77. {
    78. decisLine = true;
    79. deciNum++;
    80. }
    81. //如果该列未被锁定并且为属性列,进行统计
    82. for (int i = 0; i < attriNum - 1; i++)
    83. {
    84. if (readClumNum[i] == 0)
    85. {
    86. std::string tempitem = infos[*line_iter][i];
    87. auto map_iter = statTree[i]->attriItem.find(tempitem);
    88. //没有找到
    89. if (map_iter == (statTree[i]->attriItem).end())
    90. {
    91. //新建
    92. attrItem* attritem = new attrItem();
    93. attritem->itemNum.push_back(1);
    94. decisLine ? attritem->itemNum.push_back(1) : attritem->itemNum.push_back(0);
    95. attritem->itemLine.insert(*line_iter);
    96. //建立属性名->item映射
    97. (statTree[i]->attriItem)[tempitem] = attritem;
    98. attritem = nullptr;
    99. }
    100. else
    101. {
    102. (map_iter->second)->itemNum[0]++;
    103. (map_iter->second)->itemLine.insert(*line_iter);
    104. if(decisLine)
    105. {
    106. (map_iter->second)->itemNum[1]++;
    107. }
    108. }
    109. }
    110. }
    111. }
    112. return deciNum;
    113. }
    114. /*
    115. * @function CreatTree 递归DFS创建并输出决策树
    116. * @param: treeHead 为生成的决定树
    117. * @param: statTree 为状态树,此树动态更新,但是由于是DFS对数据更新,所以不必每次新建状态树
    118. * @param: infos 数据信息
    119. * @param: readLine 当前在infos中所要进行统计的行数,由函数外给出
    120. * @param: deep 决定树的深度,用于打印
    121. * @return void
    122. */
    123. void DecisionTree::CreatTree(TreeNode* treeHead, vector& statTree, vector>& infos,
    124. set<int>& readLine, vector<int>& readClumNum, int deep)
    125. {
    126. //有可统计的行
    127. if (readLine.size() != 0)
    128. {
    129. string treeLine = "";
    130. for (int i = 0; i < deep; i++)
    131. {
    132. treeLine += "--";
    133. }
    134. //清空其他属性子树,进行递归
    135. resetStatTree(statTree, readClumNum);
    136. //统计当前readLine中的数据:包括统计哪几个属性、哪些行,
    137. //并生成statTree(由于公用一个statTree,所有用引用代替),并返回目的信息数
    138. int deciNum = statister(getInfos(), statTree, readLine, readClumNum);
    139. int lineNum = readLine.size();
    140. int attr_node = compuDecisiNote(statTree, deciNum, lineNum, readClumNum);//本条复制为局部变量
    141. //该列被锁定
    142. readClumNum[attr_node] = 1;
    143. //建立树根
    144. TreeNode* treeNote = new TreeNode();
    145. treeNote->m_sAttribute = statTree[attr_node]->attriName;
    146. treeNote->m_iDeciNum = deciNum;
    147. treeNote->m_iUnDecinum = lineNum - deciNum;
    148. if (treeHead == nullptr)
    149. {
    150. treeHead = treeNote; //树根
    151. }
    152. else
    153. {
    154. treeHead->m_vChildren.push_back(treeNote); //子节点
    155. }
    156. cout << "节点-"<< treeLine << ">" << statTree[attr_node]->attriName << endl;
    157. //从孩子分支进行递归
    158. for(map::iterator map_iterator = statTree[attr_node]->attriItem.begin();
    159. map_iterator != statTree[attr_node]->attriItem.end(); ++map_iterator)
    160. {
    161. //打印分支
    162. int sum = map_iterator->second->itemNum[0];
    163. int deci_Num = map_iterator->second->itemNum[1];
    164. cout << "分支--"<< treeLine << ">" << map_iterator->first << endl;
    165. //递归计算、创建
    166. if (deci_Num != 0 && sum != deci_Num )
    167. {
    168. //计算有效行数
    169. set<int> newReadLineNum = map_iterator->second->itemLine;
    170. //DFS
    171. CreatTree(treeNote, statTree, infos, newReadLineNum, readClumNum, deep + 1);
    172. }
    173. else
    174. {
    175. //建立叶子节点
    176. TreeNode* treeEnd = new TreeNode();
    177. treeEnd->m_sAttribute = statTree[attr_node]->attriName;
    178. treeEnd->m_iDeciNum = deci_Num;
    179. treeEnd->m_iUnDecinum = sum - deci_Num;
    180. treeNote->m_vChildren.push_back(treeEnd);
    181. //打印叶子
    182. if (deci_Num == 0)
    183. {
    184. cout << "叶子---"<< treeLine << ">no" << endl;
    185. }
    186. else
    187. {
    188. cout << "叶子---"<< treeLine << ">yes" << endl;
    189. }
    190. }
    191. }
    192. //还原属性列可用性
    193. readClumNum[attr_node] = 0;
    194. }
    195. }
    196. /*
    197. * @function compuDecisiNote 计算C4.5
    198. * @param: statTree 为状态树,此树动态更新,但是由于是DFS对数据更新,所以不必每次新建状态树
    199. * @param: deciNum Yes的数据量
    200. * @param: lineNum 计算set的行数
    201. * @param: readClumNum 用于计算的set
    202. * @return int 信息量最大的属性号
    203. */
    204. int DecisionTree::compuDecisiNote(vector& statTree, int deciNum, int lineNum, vector<int>& readClumNum)
    205. {
    206. double max_temp = 0;
    207. int max_attribute = 0;
    208. //总的yes行的信息量
    209. double infoD = info_D(deciNum, lineNum);
    210. for (int i = 0; i < attriNum - 1; i++)
    211. {
    212. if (readClumNum[i] == 0)
    213. {
    214. double splitInfo = 0.0;
    215. //info
    216. double info_temp = Info_attr(statTree[i]->attriItem, splitInfo, lineNum);
    217. statTree[i]->statResult.push_back(info_temp);
    218. //gain
    219. double gain_temp = infoD - info_temp;
    220. statTree[i]->statResult.push_back(gain_temp);
    221. //split_info
    222. statTree[i]->statResult.push_back(splitInfo);
    223. //gain_info
    224. double temp = gain_temp / splitInfo;
    225. statTree[i]->statResult.push_back(temp);
    226. //得到最大值*/
    227. if (temp > max_temp)
    228. {
    229. max_temp = temp;
    230. max_attribute = i;
    231. }
    232. }
    233. }
    234. return max_attribute;
    235. }
    236. /*
    237. * @function Info_attr info_D 总信息量
    238. * @param: deciNum 有效信息数
    239. * @param: sum 总信息量
    240. * @return double 总信息量比例
    241. */
    242. double DecisionTree::info_D(int deciNum, int sum)
    243. {
    244. double pi = (double)deciNum / (double)sum;
    245. double result = 0.0;
    246. if (pi == 1.0 || pi == 0.0)
    247. {
    248. return result;
    249. }
    250. result = pi * (log(pi) / log((double)2)) + (1 - pi)*(log(1 - pi)/log((double)2));
    251. return -result;
    252. }
    253. /*
    254. * @function Info_attr 总信息量
    255. * @param: deciNum 有效信息数
    256. * @param: sum 总信息量
    257. * @return double
    258. */
    259. double DecisionTree::Info_attr(map& attriItem, double& splitInfo, int lineNum)
    260. {
    261. double result = 0.0;
    262. for (map::iterator item = attriItem.begin();
    263. item != attriItem.end();
    264. ++item
    265. )
    266. {
    267. double pi = (double)(item->second->itemNum[0]) / (double)lineNum;
    268. splitInfo += pi * (log(pi) / log((double)2));
    269. double sub_attr = info_D(item->second->itemNum[1], item->second->itemNum[0]);
    270. result += pi * sub_attr;
    271. }
    272. splitInfo = -splitInfo;
    273. return result;
    274. }
    275. /*
    276. * @function resetStatTree 清理状态树
    277. * @param: statTree 状态树
    278. * @param: readClumNum 需要清理的属性set
    279. * @return void
    280. */
    281. void DecisionTree::resetStatTree(vector& statTree, vector<int>& readClumNum)
    282. {
    283. for (int i = 0; i < readClumNum.size() - 1; i++)
    284. {
    285. if (readClumNum[i] == 0)
    286. {
    287. map::iterator it_end = statTree[i]->attriItem.end();
    288. for (map::iterator it = statTree[i]->attriItem.begin();
    289. it != it_end; it++)
    290. {
    291. delete it->second;
    292. }
    293. statTree[i]->attriItem.clear();
    294. statTree[i]->statResult.clear();
    295. }
    296. }
    297. }

    五、结果分析

    程序输出结果为:



    以图形表示为:


    六、小结:

      1、在设计程序时,对程序逻辑有时会发生混乱,·后者在纸上仔细画了些草图才解决这些问题,画一个好图可以有效的帮助你理解程序的流程以及逻辑脉络,是需求分析时最为关键的基本功。

      2、在编写程序之初,一直在纠结用什么样的数据结构,后来经过几次在编程实现推敲,才确定最佳的数据结构,可见数据结构在程序中的重要性。

      3、决策树的编写,其实就是理论与实践的相结合,虽然理论上比较简单,但是实践中却会遇到这样那样的问题,而这些问题就是考验一个程序员对最基本的数据结构、算法的理解和熟练程度,所以,勤学勤练基本功依然是关键。

      4、程序的效率还有待提高,欢迎各路高手指正。




    http://blog.csdn.net/fy2462/article/details/31762429



沪ICP备19023445号-2号
友情链接