85 lines
3.4 KiB
Python
85 lines
3.4 KiB
Python
import json
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
import matplotlib.pyplot as plt
|
|
import warnings
|
|
|
|
class TrainingSet():
|
|
def __init__(self, X, y):
|
|
self.X = X
|
|
self.y = y
|
|
|
|
def to_json(self):
|
|
return json.dumps(dict(
|
|
type=self.__class__.__name__,
|
|
X=self.X.tolist(),
|
|
y=self.y.tolist()
|
|
))
|
|
|
|
@staticmethod
|
|
def from_json(jsonstring):
|
|
data = json.loads(jsonstring)
|
|
return TrainingSet.from_dict(data)
|
|
|
|
@staticmethod
|
|
def from_dict(data):
|
|
return TrainingSet(
|
|
np.array(data['X']).squeeze(),
|
|
np.array(data['y'])
|
|
)
|
|
|
|
def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
|
|
|
|
split_point = node.split_point
|
|
limit_left_updated = limit_left
|
|
limit_right_updated = limit_right
|
|
limit_top_updated = limit_top
|
|
limit_bottom_updated = limit_bottom
|
|
|
|
if node.split_feature == 0:
|
|
if limit_bottom == limit_top:
|
|
warnings.warn('limit_bottom equals limit_top; extending by 0.1')
|
|
plt.plot([split_point, split_point], [limit_bottom - 0.1, limit_top + 0.1], color="purple", alpha=1 / level)
|
|
else:
|
|
plt.plot([split_point, split_point], [limit_bottom, limit_top], color="purple", alpha=1 / level)
|
|
limit_left_updated = split_point
|
|
limit_right_updated = split_point
|
|
|
|
else:
|
|
if limit_left == limit_right:
|
|
warnings.warn('limit_left equals limit_right; extending by 0.1')
|
|
plt.plot([limit_left - 0.1, limit_right + 0.1], [split_point, split_point], color="purple", alpha=1 / level)
|
|
else:
|
|
plt.plot([limit_left, limit_right], [split_point, split_point], color="purple", alpha=1 / level)
|
|
limit_top_updated = split_point
|
|
limit_bottom_updated = split_point
|
|
|
|
if level == max_depth:
|
|
return
|
|
if node.left_child is not None: self.plot_node_boundaries(node.left_child, limit_left, limit_right_updated,
|
|
limit_top_updated, limit_bottom, max_depth, level + 1)
|
|
if node.right_child is not None: self.plot_node_boundaries(node.right_child, limit_left_updated, limit_right, limit_top,
|
|
limit_bottom_updated, max_depth, level + 1)
|
|
|
|
def visualize(self, tree=None, max_height=None):
|
|
symbols = [["x", "o"][index] for index in self.y]
|
|
for y in set(self.y):
|
|
X = self.X[self.y == y, :]
|
|
plt.scatter(X[:, 0], X[:, 1],
|
|
color=["red", "blue"][y],
|
|
marker=symbols[y],
|
|
label="class: {}".format(y))
|
|
|
|
if tree is not None:
|
|
tree_height = tree.get_height(tree.root)
|
|
if max_height is None or max_height > tree_height:
|
|
max_height = tree_height
|
|
self.plot_node_boundaries(tree.root,
|
|
limit_left=min(self.X[:, 0]),
|
|
limit_right=max(self.X[:, 0]),
|
|
limit_top=max(self.X[:, 1]),
|
|
limit_bottom=min(self.X[:, 1]),
|
|
max_depth=max_height) # TODO: make parameterizable
|
|
|