initial commit
pig_lite/decision_tree/dt_base.py (new file, +61)
@@ -0,0 +1,61 @@
from pig_lite.decision_tree.dt_node import DecisionTreeNodeBase
import scipy.stats as stats


def entropy(y: list):
    """
    Compute the entropy of a binary label distribution.

    This function calculates the entropy of a binary classification label list `y` as a wrapper
    around `scipy.stats.entropy`. It assumes the labels are binary (0 or 1) and computes the
    proportion of positive labels (1s) to calculate the entropy.

    Parameters
    ----------
    y : list
        A list of binary labels (0 or 1).

    Returns
    -------
    float
        The entropy of the label distribution. If the list is empty, returns 0.0.

    Notes
    -----
    - Entropy is calculated using the formula:
      H = -p*log2(p) - (1-p)*log2(1-p)
      where `p` is the proportion of positive labels (1s).
    - If `y` is empty, entropy is defined as 0.0.

    Examples
    --------
    >>> entropy([0, 0, 1, 1])
    1.0

    >>> entropy([1, 1, 1, 1])
    0.0

    >>> entropy([])
    0.0
    """
    if len(y) == 0:
        return 0.0
    positive = sum(y) / len(y)
    return stats.entropy([positive, 1 - positive], base=2)


# These two dummy classes are only used so we can import them and load trees from a pickle file
# before they are implemented by the students.
class DecisionTree():
    def __init__(self) -> None:
        self.root = None  # set when a tree is built or loaded

    def get_height(self, node):
        if node is None:
            return 0
        return max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1

    def print(self):
        if self.root is not None:
            height = self.get_height(self.root)
            self.root.print_tree(height)


class DecisionTreeNode(DecisionTreeNodeBase):
    def __init__(self) -> None:
        pass
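A quick sanity check of `entropy` (a sketch, assuming the `pig_lite` package is importable; the label lists are made-up toy inputs and the expected values follow from the docstring formula):

    from pig_lite.decision_tree.dt_base import entropy

    print(entropy([0, 0, 1, 1]))   # maximally mixed labels -> 1.0 bit
    print(entropy([1, 1, 1, 1]))   # pure node -> 0.0
    print(entropy([0, 0, 0, 1]))   # p = 0.25 -> about 0.811 bits
    print(entropy([]))             # empty split is defined as 0.0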
pig_lite/decision_tree/dt_node.py (new file, +70)
@@ -0,0 +1,70 @@
from pig_lite.datastructures.queue import Queue


class DecisionTreeNodeBase():
    def __init__(self):
        self.label = None
        self.split_point = None
        self.split_feature = None
        self.left_child = None
        self.right_child = None

    def print_node(self, height, level=1):
        node_width = 10
        n_spaces = 2 ** (height - level - 1) * node_width - node_width // 2
        if n_spaces > 0:
            text = " " * n_spaces
        else:
            text = ""

        if self.label is None and self.split_feature is None:
            return f"{text} {text}"

        if self.label is not None:
            text = f"{text}( {self.label} ){text}"
        elif self.split_feature is not None:
            text_snippet = f"(x{self.split_feature}:{self.split_point:.2f})"
            if len(text_snippet) != node_width:
                text_snippet = f" {text_snippet}"
            text = f"{text}{text_snippet}{text}"
        return text

    def __str__(self):
        if self.label is not None:
            return f"({self.label})"

        str_value = f"{self.split_feature}:{self.split_point:.2f}|{self.left_child}{self.right_child}"
        return str_value

    def print_tree(self, height):
        visited = set()
        frontier = Queue()

        lines = ['']

        previous_level = 1
        frontier.put((self, 1))

        # Breadth-first traversal; empty placeholder nodes pad missing children
        # so that every level is printed with consistent spacing.
        while frontier.has_elements():
            current, level = frontier.get()
            if level > previous_level:
                lines.append('')
                previous_level = level
            lines[-1] += current.print_node(height, level)
            if current not in visited:
                visited.add(current)
                if current.left_child is not None:
                    frontier.put((current.left_child, level + 1))
                elif level < height:
                    frontier.put((DecisionTreeNodeBase(), level + 1))
                if current.right_child is not None:
                    frontier.put((current.right_child, level + 1))
                elif level < height:
                    frontier.put((DecisionTreeNodeBase(), level + 1))

        for line in lines:
            print(line)
        return None

    def split(self):
        raise NotImplementedError()
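A minimal sketch of how `__str__` and `print_tree` render a hand-built stump (assumes `pig_lite` and its `Queue` are importable; the split feature, threshold, and labels below are invented for illustration):

    from pig_lite.decision_tree.dt_node import DecisionTreeNodeBase

    root = DecisionTreeNodeBase()
    root.split_feature, root.split_point = 0, 1.5   # internal node: split on x0 at 1.50

    leaf_left, leaf_right = DecisionTreeNodeBase(), DecisionTreeNodeBase()
    leaf_left.label, leaf_right.label = 0, 1        # leaves carry class labels

    root.left_child, root.right_child = leaf_left, leaf_right

    print(root)         # compact string form: "0:1.50|(0)(1)"
    root.print_tree(2)  # two-level ASCII rendering, one line per level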
pig_lite/decision_tree/training_set.py (new file, +84)
@@ -0,0 +1,84 @@
import json
import warnings

import numpy as np
import matplotlib.pyplot as plt


class TrainingSet():
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def to_json(self):
        return json.dumps(dict(
            type=self.__class__.__name__,
            X=self.X.tolist(),
            y=self.y.tolist()
        ))

    @staticmethod
    def from_json(jsonstring):
        data = json.loads(jsonstring)
        return TrainingSet.from_dict(data)

    @staticmethod
    def from_dict(data):
        return TrainingSet(
            np.array(data['X']).squeeze(),
            np.array(data['y'])
        )

    def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
        if node.split_feature is None:
            return  # leaf nodes define no split boundary

        split_point = node.split_point
        limit_left_updated = limit_left
        limit_right_updated = limit_right
        limit_top_updated = limit_top
        limit_bottom_updated = limit_bottom

        if node.split_feature == 0:
            # vertical boundary: split on feature 0 (x-axis)
            if limit_bottom == limit_top:
                warnings.warn('limit_bottom equals limit_top; extending by 0.1')
                plt.plot([split_point, split_point], [limit_bottom - 0.1, limit_top + 0.1], color="purple", alpha=1 / level)
            else:
                plt.plot([split_point, split_point], [limit_bottom, limit_top], color="purple", alpha=1 / level)
            limit_left_updated = split_point
            limit_right_updated = split_point
        else:
            # horizontal boundary: split on feature 1 (y-axis)
            if limit_left == limit_right:
                warnings.warn('limit_left equals limit_right; extending by 0.1')
                plt.plot([limit_left - 0.1, limit_right + 0.1], [split_point, split_point], color="purple", alpha=1 / level)
            else:
                plt.plot([limit_left, limit_right], [split_point, split_point], color="purple", alpha=1 / level)
            limit_top_updated = split_point
            limit_bottom_updated = split_point

        if level == max_depth:
            return
        if node.left_child is not None:
            self.plot_node_boundaries(node.left_child, limit_left, limit_right_updated,
                                      limit_top_updated, limit_bottom, max_depth, level + 1)
        if node.right_child is not None:
            self.plot_node_boundaries(node.right_child, limit_left_updated, limit_right, limit_top,
                                      limit_bottom_updated, max_depth, level + 1)

    def visualize(self, tree=None, max_height=None):
        markers = ["x", "o"]  # one marker per class label
        for y in set(self.y):
            X = self.X[self.y == y, :]
            plt.scatter(X[:, 0], X[:, 1],
                        color=["red", "blue"][y],
                        marker=markers[y],
                        label="class: {}".format(y))

        if tree is not None:
            tree_height = tree.get_height(tree.root)
            if max_height is None or max_height > tree_height:
                max_height = tree_height
            self.plot_node_boundaries(tree.root,
                                      limit_left=min(self.X[:, 0]),
                                      limit_right=max(self.X[:, 0]),
                                      limit_top=max(self.X[:, 1]),
                                      limit_bottom=min(self.X[:, 1]),
                                      max_depth=max_height)  # TODO: make parameterizable
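A usage sketch tying the pieces together (assumes `pig_lite`, NumPy, and Matplotlib are installed; the toy data and the `plt.show()` call are illustrative, not part of the module):

    import numpy as np
    import matplotlib.pyplot as plt
    from pig_lite.decision_tree.training_set import TrainingSet

    # four 2-D points with binary labels
    X = np.array([[0.5, 1.0], [1.0, 2.0], [2.5, 0.5], [3.0, 2.5]])
    y = np.array([0, 0, 1, 1])

    ts = TrainingSet(X, y)
    restored = TrainingSet.from_json(ts.to_json())   # JSON round-trip keeps X and y
    assert np.array_equal(restored.y, ts.y)

    ts.visualize()   # scatter plot only; pass tree=... to overlay split boundaries
    plt.legend()
    plt.show()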