initial commit
This commit is contained in:
84
pig_lite/decision_tree/training_set.py
Normal file
84
pig_lite/decision_tree/training_set.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
|
||||
class TrainingSet():
|
||||
def __init__(self, X, y):
|
||||
self.X = X
|
||||
self.y = y
|
||||
|
||||
def to_json(self):
|
||||
return json.dumps(dict(
|
||||
type=self.__class__.__name__,
|
||||
X=self.X.tolist(),
|
||||
y=self.y.tolist()
|
||||
))
|
||||
|
||||
@staticmethod
|
||||
def from_json(jsonstring):
|
||||
data = json.loads(jsonstring)
|
||||
return TrainingSet.from_dict(data)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data):
|
||||
return TrainingSet(
|
||||
np.array(data['X']).squeeze(),
|
||||
np.array(data['y'])
|
||||
)
|
||||
|
||||
def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
|
||||
|
||||
split_point = node.split_point
|
||||
limit_left_updated = limit_left
|
||||
limit_right_updated = limit_right
|
||||
limit_top_updated = limit_top
|
||||
limit_bottom_updated = limit_bottom
|
||||
|
||||
if node.split_feature == 0:
|
||||
if limit_bottom == limit_top:
|
||||
warnings.warn('limit_bottom equals limit_top; extending by 0.1')
|
||||
plt.plot([split_point, split_point], [limit_bottom - 0.1, limit_top + 0.1], color="purple", alpha=1 / level)
|
||||
else:
|
||||
plt.plot([split_point, split_point], [limit_bottom, limit_top], color="purple", alpha=1 / level)
|
||||
limit_left_updated = split_point
|
||||
limit_right_updated = split_point
|
||||
|
||||
else:
|
||||
if limit_left == limit_right:
|
||||
warnings.warn('limit_left equals limit_right; extending by 0.1')
|
||||
plt.plot([limit_left - 0.1, limit_right + 0.1], [split_point, split_point], color="purple", alpha=1 / level)
|
||||
else:
|
||||
plt.plot([limit_left, limit_right], [split_point, split_point], color="purple", alpha=1 / level)
|
||||
limit_top_updated = split_point
|
||||
limit_bottom_updated = split_point
|
||||
|
||||
if level == max_depth:
|
||||
return
|
||||
if node.left_child is not None: self.plot_node_boundaries(node.left_child, limit_left, limit_right_updated,
|
||||
limit_top_updated, limit_bottom, max_depth, level + 1)
|
||||
if node.right_child is not None: self.plot_node_boundaries(node.right_child, limit_left_updated, limit_right, limit_top,
|
||||
limit_bottom_updated, max_depth, level + 1)
|
||||
|
||||
def visualize(self, tree=None, max_height=None):
|
||||
symbols = [["x", "o"][index] for index in self.y]
|
||||
for y in set(self.y):
|
||||
X = self.X[self.y == y, :]
|
||||
plt.scatter(X[:, 0], X[:, 1],
|
||||
color=["red", "blue"][y],
|
||||
marker=symbols[y],
|
||||
label="class: {}".format(y))
|
||||
|
||||
if tree is not None:
|
||||
tree_height = tree.get_height(tree.root)
|
||||
if max_height is None or max_height > tree_height:
|
||||
max_height = tree_height
|
||||
self.plot_node_boundaries(tree.root,
|
||||
limit_left=min(self.X[:, 0]),
|
||||
limit_right=max(self.X[:, 0]),
|
||||
limit_top=max(self.X[:, 1]),
|
||||
limit_bottom=min(self.X[:, 1]),
|
||||
max_depth=max_height) # TODO: make parameterizable
|
||||
|
||||
Reference in New Issue
Block a user