-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDecisionTree.cpp
executable file
·159 lines (124 loc) · 5.02 KB
/
DecisionTree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <utility>
#include <utility>
#include <iostream>
//
// Created by karl on 05.05.19.
//
#include "DecisionTree.h"
#include "FrequencyTable.h"
// Attaches `next` as the child reached when this node's attribute takes the
// value `attributeVal`. A child already registered under the same value is
// replaced.
void DecisionTreeNode::addNode(const std::string& attributeVal, std::shared_ptr<DecisionTreeNode> next) {
    nextNodes.insert_or_assign(attributeVal, std::move(next));
}
const std::map<std::string, std::shared_ptr<DecisionTreeNode>> &DecisionTreeNode::getNextNodes() const {
return nextNodes;
}
const std::string &DecisionTreeNode::getAttribute() const {
return attribute;
}
void DecisionTree::build(const std::vector<std::vector<std::string>>& data) {
root = buildRec(data);
}
// Prints the whole tree to std::cout, starting at the root with no
// indentation. Prints nothing when no tree is available.
void DecisionTree::print() {
    // The root is null until a tree has been stored in it; walking a null
    // root would dereference a null pointer.
    if (root == nullptr) {
        return;
    }
    printRec(root, 0);
}
void DecisionTree::printRec(const std::shared_ptr<DecisionTreeNode>& currentNode, int depth) {
auto nodes = currentNode->getNextNodes();
if (nodes.empty()) { // This is a leaf node
std::cout << "Decision: " << currentNode->getAttribute() << std::endl;
return;
}
std::cout << currentNode->getAttribute() << "?" << std::endl;
for (const auto& nextNode : nodes) {
for (int i = 0; i < depth; i++) {
std::cout << "\t";
}
std::cout << "-> " << nextNode.first << ": ";
printRec(nextNode.second, depth + 1);
}
}
std::shared_ptr<DecisionTreeNode> DecisionTree::buildRec(const std::vector<std::vector<std::string>> &data) {
// Pure decision?
std::string firstOutcome = data[1].back();
bool pure = true;
for (int i = 1; i < data.size(); i++) {
if (data[i].back() != firstOutcome) {
pure = false;
break;
}
}
if (pure) {
// Make leaf node here (store the outcome in the attribute of the node)
return std::make_shared<DecisionTreeNode>(DecisionTreeNode(firstOutcome));
}
// No more attributes?
if (data[0].size() == 1) { // Only the 'Play' column left
// Check which outcome is more common
std::map<std::string, int> outcomes;
for (int i = 1; i < data.size(); i++) {
outcomes[data[i][0]] += 1;
}
std::string highestKey;
int highestOccurance = 0;
for (const auto& outcome : outcomes) {
if (outcome.second > highestOccurance) {
highestOccurance = outcome.second;
highestKey = outcome.first;
}
}
return std::make_shared<DecisionTreeNode>(DecisionTreeNode(highestKey));
}
int colNumber = data.front().size();
// Check where the gain would be the highest
double highestGain = -1000; // TODO: Initialize properly
int colWithHighestGain = 0;
std::unique_ptr<FrequencyTable> frequencyTableWithHighestGain;
for (int col = 0; col < colNumber - 1; col++) {
FrequencyTable ft = FrequencyTable(data, col);
double gain = ft.getGain();
if (gain > highestGain) {
highestGain = gain;
colWithHighestGain = col;
frequencyTableWithHighestGain = std::make_unique<FrequencyTable>(ft);
}
}
// Build this node accordingly
std::shared_ptr<DecisionTreeNode> currentNode =
std::make_shared<DecisionTreeNode>(DecisionTreeNode(data.front()[colWithHighestGain]));
for (const auto &attribute : frequencyTableWithHighestGain->getAttributes()) {
// Delete the column we chose, and only take all entries with the current attribute
// for the new table
// TODO: Could be made more efficient by removing the column we chose before this loop
std::vector<std::vector<std::string>> new_data;
std::vector<std::string> firstLine = data[0];
firstLine.erase(firstLine.begin() + colWithHighestGain);
new_data.push_back(firstLine);
for (int i = 1; i < data.size(); i++) {
// If the attributes match, add this line to the new data
if (data[i][colWithHighestGain] == attribute) {
std::vector<std::string> line = data[i];
line.erase(line.begin() + colWithHighestGain);
new_data.push_back(line);
}
}
std::shared_ptr<DecisionTreeNode> nodeToInsert = buildRec(new_data);
// If the tree continues, stick it to this node (at this attribute)
if (nodeToInsert != nullptr) {
currentNode->addNode(attribute, nodeToInsert);
}
}
return currentNode;
}
std::string DecisionTree::classify(std::map<std::string, std::string> attributes) {
// Walk down the tree until we arrive at a leaf node
std::shared_ptr currentNode = root;
std::map<std::string, std::shared_ptr<DecisionTreeNode>> nextNodes = currentNode->getNextNodes();
do {
if (nextNodes[attributes[currentNode->getAttribute()]] == nullptr) {
break; // FIXME why can this be the case?
}
currentNode = nextNodes[attributes[currentNode->getAttribute()]];
nextNodes = currentNode->getNextNodes();
} while(!nextNodes.empty());
// Return the attribute, where leaf nodes store the classification
return currentNode->getAttribute();
}