exercise-00/pig_lite/decision_tree/dt_base.py

import scipy.stats as stats

from pig_lite.decision_tree.dt_node import DecisionTreeNodeBase


def entropy(y: list):
    """
    Compute the entropy of a binary label distribution.

    This function calculates the entropy of a binary classification label
    list `y` as a wrapper around `scipy.stats.entropy`. It assumes the
    labels are binary (0 or 1) and computes the proportion of positive
    labels (1s) to calculate the entropy.

    Parameters
    ----------
    y : list
        A list of binary labels (0 or 1).

    Returns
    -------
    float
        The entropy of the label distribution. If the list is empty,
        returns 0.0.

    Notes
    -----
    - Entropy is calculated using the formula:
      H = -p*log2(p) - (1-p)*log2(1-p)
      where `p` is the proportion of positive labels (1s).
    - If `y` is empty, entropy is defined as 0.0.

    Examples
    --------
    >>> entropy([0, 0, 1, 1])
    1.0
    >>> entropy([1, 1, 1, 1])
    0.0
    >>> entropy([])
    0.0
    """
    if len(y) == 0:
        return 0.0
    # Proportion of positive (1) labels; scipy's entropy handles the
    # p = 0 and p = 1 edge cases (0 * log 0 is treated as 0).
    positive = sum(y) / len(y)
    return stats.entropy([positive, 1 - positive], base=2)
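

# --- Illustrative sketch (not part of the original module) ------------------
# A minimal example of how `entropy` could drive a split criterion when
# building the tree. The helper name `information_gain`, its signature, and
# the assumption that `y_left`/`y_right` partition `y` are illustrative
# assumptions, not part of the assignment's API.
def information_gain(y: list, y_left: list, y_right: list) -> float:
    """Entropy reduction achieved by splitting `y` into `y_left` and `y_right`."""
    if len(y) == 0:
        return 0.0
    w_left = len(y_left) / len(y)
    w_right = len(y_right) / len(y)
    return entropy(y) - (w_left * entropy(y_left) + w_right * entropy(y_right))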


# These two dummy classes exist only so that this module can be imported and
# trees can be loaded from a pickle file before the students have implemented
# the real classes.
class DecisionTree:
    def __init__(self) -> None:
        # Placeholder root so print() works on a freshly created tree;
        # unpickled trees overwrite this with the stored root node.
        self.root = None

    def get_height(self, node):
        # Height of the subtree rooted at `node`; an empty subtree has height 0.
        if node is None:
            return 0
        return max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1

    def print(self):
        if self.root is not None:
            height = self.get_height(self.root)
            self.root.print_tree(height)


class DecisionTreeNode(DecisionTreeNodeBase):
    def __init__(self) -> None:
        # Intentionally empty stub: pickle restores the node's attributes
        # directly, so no initialization is needed here.
        pass
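

if __name__ == "__main__":
    # Quick sanity checks mirroring the docstring examples. This demo block
    # is an illustrative addition, not part of the original module.
    print(entropy([0, 0, 1, 1]))  # 1.0 (maximally mixed labels)
    print(entropy([1, 1, 1, 1]))  # 0.0 (pure labels)

    # Sketch of loading a stored tree, which is what the dummy classes above
    # make possible. "tree.pkl" is a hypothetical path, so this stays
    # commented out:
    # import pickle
    # with open("tree.pkl", "rb") as f:
    #     tree = pickle.load(f)
    # tree.print()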