initial commit

This commit is contained in:
2025-10-07 18:22:35 +02:00
commit c19876ed78
34 changed files with 3866 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
from pig_lite.decision_tree.dt_node import DecisionTreeNodeBase
import scipy.stats as stats
def entropy(y: list):
    """
    Return the Shannon entropy, in bits, of a list of binary labels.

    The share of positive labels is computed as ``sum(y) / len(y)`` and the
    resulting two-outcome distribution ``[p, 1 - p]`` is passed on to
    ``scipy.stats.entropy`` with ``base=2``.

    Parameters
    ----------
    y : list
        Binary labels (0 or 1).

    Returns
    -------
    float
        Entropy of the label distribution; ``0.0`` when ``y`` is empty.

    Notes
    -----
    - Formula: H = -p*log2(p) - (1-p)*log2(1-p), where ``p`` is the
      proportion of positive labels (1s).
    - An empty ``y`` is defined to have entropy 0.0.

    Examples
    --------
    >>> float(entropy([0, 0, 1, 1]))
    1.0
    >>> float(entropy([1, 1, 1, 1]))
    0.0
    >>> entropy([])
    0.0
    """
    n_labels = len(y)
    if n_labels == 0:
        # No samples: avoid the division by zero below.
        return 0.0

    p_positive = sum(y) / n_labels
    distribution = [p_positive, 1 - p_positive]
    return stats.entropy(distribution, base=2)
# These two dummy classes exist only so that they can be imported, and so that trees can be loaded from a pickle file, before the students have implemented them.
class DecisionTree():
    """
    Minimal decision tree container with height computation and printing.

    The tree is reached through the ``root`` attribute. Nodes are expected to
    expose ``left_child`` and ``right_child`` (used by `get_height`) and a
    ``print_tree(height)`` method (used by `print`).
    """

    def __init__(self) -> None:
        # An empty tree has no root. Defining the attribute here keeps
        # `print()` from raising AttributeError on a fresh instance.
        self.root = None

    def get_height(self, node):
        """
        Return the number of levels in the subtree rooted at ``node``.

        Parameters
        ----------
        node : object or None
            A node exposing ``left_child`` and ``right_child``, or ``None``.

        Returns
        -------
        int
            ``0`` for ``None``, otherwise one more than the height of the
            taller child subtree.
        """
        if node is None:
            return 0
        left_height = self.get_height(node.left_child)
        right_height = self.get_height(node.right_child)
        return max(left_height, right_height) + 1

    def print(self):
        """
        Print the tree via the root node's ``print_tree``; do nothing if empty.
        """
        # getattr (rather than self.root) also covers instances created
        # without running __init__, e.g. objects restored by unpickling.
        root = getattr(self, "root", None)
        if root is not None:
            root.print_tree(self.get_height(root))
class DecisionTreeNode(DecisionTreeNodeBase):
    """
    Placeholder decision tree node built on `DecisionTreeNodeBase`.

    It adds no behavior of its own; everything is inherited from the base
    class.
    """

    def __init__(self) -> None:
        # Run the base initializer so the attributes it defines exist on
        # every instance; overriding __init__ with `pass` skipped that.
        super().__init__()

View File

@@ -0,0 +1,70 @@
from pig_lite.datastructures.queue import Queue
class DecisionTreeNodeBase():
    """
    Base class for binary decision tree nodes, with text rendering helpers.

    A node is rendered in one of three ways:

    - with ``label`` set, it is shown as ``( label )``;
    - otherwise, with ``split_feature`` set, it is shown as
      ``(x<split_feature>:<split_point>)``;
    - with neither set, it is an empty placeholder; `print_tree` inserts
      such nodes to keep the rows of the printed tree aligned.
    """

    def __init__(self):
        self.label = None
        self.split_point = None
        self.split_feature = None
        self.left_child = None
        self.right_child = None

    def print_node(self, height, level=1):
        """
        Render this node as one fixed-width cell of a printed tree row.

        Parameters
        ----------
        height : int
            Total number of levels of the tree being printed.
        level : int, optional
            1-based level of this node; the bottom row is ``level == height``.

        Returns
        -------
        str
            The cell text, padded to ``node_width`` characters and surrounded
            on both sides by level-dependent spacing.
        """
        node_width = 10
        # Spacing shrinks with depth and reaches zero on the bottom row. It is
        # only ever multiplied into a string when positive, which is exactly
        # when the expression is an int (non-negative exponent).
        n_spaces = 2 ** (height - level - 1) * node_width - node_width // 2
        padding = " " * n_spaces if n_spaces > 0 else ""

        if self.label is None and self.split_feature is None:
            # Placeholder: occupy a full cell so the columns stay aligned.
            cell = " " * node_width
        elif self.label is not None:
            cell = f"( {self.label} )".center(node_width)
        else:
            # Right-justify so a 9-character snippet gets one leading space.
            cell = f"(x{self.split_feature}:{self.split_point:.2f})".rjust(node_width)
        return f"{padding}{cell}{padding}"

    def __str__(self):
        """
        Return a compact one-line representation of the subtree.

        Labelled nodes give ``(label)``; split nodes give
        ``feature:point|<left><right>`` with the children rendered
        recursively. A node with neither a label nor a split point gives
        ``(None)`` instead of failing while formatting ``None`` as a float.
        """
        if self.label is not None or self.split_point is None:
            return f"({self.label})"
        return f"{self.split_feature}:{self.split_point:.2f}|{self.left_child}{self.right_child}"

    def print_tree(self, height):
        """
        Print the subtree rooted at this node, one text row per level.

        Nodes are taken from the frontier queue level by level (this assumes
        the project's ``Queue`` is first-in-first-out). Missing children above
        the bottom row are replaced by empty placeholder nodes so that every
        row is fully populated.

        Parameters
        ----------
        height : int
            Number of levels to lay the tree out for.
        """
        visited = set()
        frontier = Queue()
        frontier.put((self, 1))
        lines = ['']
        previous_level = 1

        while frontier.has_elements():
            current, level = frontier.get()
            if level > previous_level:
                # First node of a deeper level: start a new output row.
                lines.append('')
                previous_level = level
            lines[-1] += current.print_node(height, level)

            if current in visited:
                continue
            visited.add(current)

            for child in (current.left_child, current.right_child):
                if child is not None:
                    frontier.put((child, level + 1))
                elif level < height:
                    frontier.put((DecisionTreeNodeBase(), level + 1))

        for line in lines:
            print(line)

    def split(self, *args, **kwargs):
        """
        Split hook to be provided by subclasses.

        Raises
        ------
        NotImplementedError
            Always; the base class does not implement splitting.
        """
        # `self` was missing, so calling node.split() raised TypeError
        # instead of the intended NotImplementedError.
        raise NotImplementedError()

View File

@@ -0,0 +1,84 @@
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import warnings
class TrainingSet():
def __init__(self, X, y):
self.X = X
self.y = y
def to_json(self):
return json.dumps(dict(
type=self.__class__.__name__,
X=self.X.tolist(),
y=self.y.tolist()
))
@staticmethod
def from_json(jsonstring):
data = json.loads(jsonstring)
return TrainingSet.from_dict(data)
@staticmethod
def from_dict(data):
return TrainingSet(
np.array(data['X']).squeeze(),
np.array(data['y'])
)
def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
split_point = node.split_point
limit_left_updated = limit_left
limit_right_updated = limit_right
limit_top_updated = limit_top
limit_bottom_updated = limit_bottom
if node.split_feature == 0:
if limit_bottom == limit_top:
warnings.warn('limit_bottom equals limit_top; extending by 0.1')
plt.plot([split_point, split_point], [limit_bottom - 0.1, limit_top + 0.1], color="purple", alpha=1 / level)
else:
plt.plot([split_point, split_point], [limit_bottom, limit_top], color="purple", alpha=1 / level)
limit_left_updated = split_point
limit_right_updated = split_point
else:
if limit_left == limit_right:
warnings.warn('limit_left equals limit_right; extending by 0.1')
plt.plot([limit_left - 0.1, limit_right + 0.1], [split_point, split_point], color="purple", alpha=1 / level)
else:
plt.plot([limit_left, limit_right], [split_point, split_point], color="purple", alpha=1 / level)
limit_top_updated = split_point
limit_bottom_updated = split_point
if level == max_depth:
return
if node.left_child is not None: self.plot_node_boundaries(node.left_child, limit_left, limit_right_updated,
limit_top_updated, limit_bottom, max_depth, level + 1)
if node.right_child is not None: self.plot_node_boundaries(node.right_child, limit_left_updated, limit_right, limit_top,
limit_bottom_updated, max_depth, level + 1)
def visualize(self, tree=None, max_height=None):
symbols = [["x", "o"][index] for index in self.y]
for y in set(self.y):
X = self.X[self.y == y, :]
plt.scatter(X[:, 0], X[:, 1],
color=["red", "blue"][y],
marker=symbols[y],
label="class: {}".format(y))
if tree is not None:
tree_height = tree.get_height(tree.root)
if max_height is None or max_height > tree_height:
max_height = tree_height
self.plot_node_boundaries(tree.root,
limit_left=min(self.X[:, 0]),
limit_right=max(self.X[:, 0]),
limit_top=max(self.X[:, 1]),
limit_bottom=min(self.X[:, 1]),
max_depth=max_height) # TODO: make parameterizable