# Files
# exercise-00/pig_lite/decision_tree/training_set.py
# 2025-10-07 18:22:35 +02:00
#
# 85 lines
# 3.4 KiB
# Python

import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import warnings
class TrainingSet():
    """A labelled data set: feature matrix ``X`` and label vector ``y``.

    The plotting helpers (``visualize`` / ``plot_node_boundaries``) assume two
    features per sample (columns 0 and 1 of ``X``) and labels 0/1, because the
    labels are used directly as indices into two-element style lists.
    """

    def __init__(self, X, y):
        """Store features ``X`` and labels ``y`` as given (no copy, no conversion)."""
        self.X = X
        self.y = y

    def to_json(self):
        """Serialise the set to a JSON string.

        The payload has the keys ``type`` (the concrete class name), ``X`` and
        ``y`` (both as nested lists). ``np.asarray`` lets plain Python
        sequences be serialised as well as NumPy arrays.
        """
        return json.dumps(dict(
            type=self.__class__.__name__,
            X=np.asarray(self.X).tolist(),
            y=np.asarray(self.y).tolist(),
        ))

    @staticmethod
    def from_json(jsonstring):
        """Rebuild a ``TrainingSet`` from a string produced by ``to_json``."""
        data = json.loads(jsonstring)
        return TrainingSet.from_dict(data)

    @staticmethod
    def from_dict(data):
        """Rebuild a ``TrainingSet`` from a dict with keys ``'X'`` and ``'y'``.

        Any other key (such as ``'type'``) is ignored.
        """
        # NOTE(review): squeeze() drops every length-1 axis, so a single-sample
        # X of shape (1, d) becomes (d,) and can no longer be indexed as
        # X[:, 0] by visualize -- confirm this is intended.
        return TrainingSet(
            np.array(data['X']).squeeze(),
            np.array(data['y'])
        )

    def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
        """Recursively draw the split lines of ``node`` and its subtree.

        Each split is drawn as a purple line clipped to the rectangle
        ``[limit_left, limit_right] x [limit_bottom, limit_top]`` that the
        node's ancestors leave for it. The rectangle is then cut at the split
        point: the left child receives the part left of / below the split,
        and the right child receives the part right of / above it. Line
        opacity is ``1 / level``, so deeper splits are drawn fainter.

        Args:
            node: tree node; this method reads its ``split_feature``,
                ``split_point``, ``left_child`` and ``right_child`` attributes.
                ``split_feature == 0`` is drawn as a vertical line (split on
                x); any other value is drawn as a horizontal line (split on y).
            limit_left, limit_right: x-range of the current region.
            limit_top, limit_bottom: y-range of the current region.
            max_depth: deepest level to draw (the root is level 1). Values
                below 1 draw only the root's split.
            level: depth of ``node``; callers normally keep the default.
        """
        split_point = node.split_point
        alpha = 1 / level
        if node.split_feature == 0:
            # Vertical line x = split_point across the region's y-range.
            bottom, top = limit_bottom, limit_top
            if bottom == top:
                # A zero-height region would give an invisible line.
                warnings.warn('limit_bottom equals limit_top; extending by 0.1')
                bottom, top = bottom - 0.1, top + 0.1
            plt.plot([split_point, split_point], [bottom, top], color="purple", alpha=alpha)
            # Child regions as (left, right, top, bottom).
            left_region = (limit_left, split_point, limit_top, limit_bottom)
            right_region = (split_point, limit_right, limit_top, limit_bottom)
        else:
            # Horizontal line y = split_point across the region's x-range.
            left, right = limit_left, limit_right
            if left == right:
                # A zero-width region would give an invisible line.
                warnings.warn('limit_left equals limit_right; extending by 0.1')
                left, right = left - 0.1, right + 0.1
            plt.plot([left, right], [split_point, split_point], color="purple", alpha=alpha)
            left_region = (limit_left, limit_right, split_point, limit_bottom)
            right_region = (limit_left, limit_right, limit_top, split_point)
        # '>=' rather than '==' so a max_depth below 1 cannot disable the cutoff.
        if level >= max_depth:
            return
        if node.left_child is not None:
            self.plot_node_boundaries(node.left_child, *left_region, max_depth, level + 1)
        if node.right_child is not None:
            self.plot_node_boundaries(node.right_child, *right_region, max_depth, level + 1)

    def visualize(self, tree=None, max_height=None):
        """Scatter-plot the samples by class and optionally overlay a tree's splits.

        Class 0 is drawn as red ``"x"`` markers and class 1 as blue ``"o"``
        markers. This method does not call ``plt.show()`` or ``plt.legend()``.

        Args:
            tree: optional tree exposing ``root`` and ``get_height(node)``; its
                splits are drawn with ``plot_node_boundaries``, clipped to the
                bounding box of ``X``.
            max_height: optional cap on the number of tree levels to draw;
                ``None`` or a value above the tree height draws the whole tree.
        """
        markers = ["x", "o"]
        colors = ["red", "blue"]
        # Style is looked up by class label, so each class gets its own marker.
        for label in np.unique(self.y):
            points = self.X[self.y == label, :]
            plt.scatter(points[:, 0], points[:, 1],
                        color=colors[label],
                        marker=markers[label],
                        label="class: {}".format(label))
        if tree is None:
            return
        tree_height = tree.get_height(tree.root)
        if max_height is None or max_height > tree_height:
            max_height = tree_height
        self.plot_node_boundaries(tree.root,
                                  limit_left=self.X[:, 0].min(),
                                  limit_right=self.X[:, 0].max(),
                                  limit_top=self.X[:, 1].max(),
                                  limit_bottom=self.X[:, 1].min(),
                                  max_depth=max_height)  # TODO: make parameterizable