first commit
This commit is contained in:
15
boards/cyclic_map.json
Normal file
15
boards/cyclic_map.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 0],
|
||||
[0, 0, 0],
|
||||
[0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[1, 1, 1],
|
||||
[1, 1, 1],
|
||||
[1, 1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [2, 2]
|
||||
}
|
||||
1
boards/large.json
Normal file
1
boards/large.json
Normal file
File diff suppressed because one or more lines are too long
19
boards/narrow_path.json
Normal file
19
boards/narrow_path.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1]
|
||||
],
|
||||
"costs": [
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [2, 4]
|
||||
}
|
||||
15
boards/no_possible_path.json
Normal file
15
boards/no_possible_path.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 1],
|
||||
[1, 1, 1],
|
||||
[0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[1, 1, 1],
|
||||
[1, 1, 1],
|
||||
[1, 1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [2, 2]
|
||||
}
|
||||
13
boards/start_is_goal.json
Normal file
13
boards/start_is_goal.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0],
|
||||
[0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[1, 1],
|
||||
[1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [0, 0]
|
||||
}
|
||||
17
boards/tiny0.json
Normal file
17
boards/tiny0.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 1, 0],
|
||||
[0, 1, 0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[4, 4, 3, 1, 0],
|
||||
[3, 4, 2, 2, 2],
|
||||
[4, 3, 3, 4, 0],
|
||||
[4, 4, 3, 3, 0],
|
||||
[4, 3, 1, 1, 2]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [4, 4]}
|
||||
17
boards/tiny1.json
Normal file
17
boards/tiny1.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 1, 1],
|
||||
[0, 0, 0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[4, 8, 8, 8, 8],
|
||||
[1, 4, 2, 2, 1],
|
||||
[1, 8, 8, 8, 1],
|
||||
[1, 4, 3, 3, 1],
|
||||
[1, 1, 1, 1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [2, 4]}
|
||||
17
boards/tiny2.json
Normal file
17
boards/tiny2.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0],
|
||||
[1, 0, 0, 1, 0],
|
||||
[1, 0, 1, 1, 0],
|
||||
[1, 0, 0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[4, 1, 2, 2, 3],
|
||||
[1, 4, 2, 2, 3],
|
||||
[1, 4, 2, 2, 3],
|
||||
[1, 4, 3, 3, 1],
|
||||
[1, 1, 1, 1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [2, 2]}
|
||||
17
boards/tiny3.json
Normal file
17
boards/tiny3.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 1, 1, 0],
|
||||
[0, 0, 0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[4, 1, 2, 2, 3],
|
||||
[1, 4, 2, 2, 3],
|
||||
[1, 4, 2, 2, 3],
|
||||
[1, 4, 3, 3, 1],
|
||||
[1, 1, 1, 1, 1]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [2, 2]}
|
||||
17
boards/tiny4.json
Normal file
17
boards/tiny4.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{"type": "Simple2DProblem",
|
||||
"board": [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 1, 0],
|
||||
[0, 1, 0, 1, 0],
|
||||
[0, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0]
|
||||
],
|
||||
"costs": [
|
||||
[3, 2, 4, 4, 1],
|
||||
[2, 3, 4, 2, 2],
|
||||
[3, 3, 3, 1, 4],
|
||||
[2, 4, 2, 2, 1],
|
||||
[4, 3, 4, 1, 2]
|
||||
],
|
||||
"start_state": [0, 0],
|
||||
"end_state": [4, 4]}
|
||||
1
pig_lite/.gitignore
vendored
Normal file
1
pig_lite/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
__pycache__
|
||||
8
pig_lite/.idea/.gitignore
generated
vendored
Normal file
8
pig_lite/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
24
pig_lite/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
24
pig_lite/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,24 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="11">
|
||||
<item index="0" class="java.lang.String" itemvalue="jupyter" />
|
||||
<item index="1" class="java.lang.String" itemvalue="umap-learn" />
|
||||
<item index="2" class="java.lang.String" itemvalue="matplotlib" />
|
||||
<item index="3" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="4" class="java.lang.String" itemvalue="tqdm" />
|
||||
<item index="5" class="java.lang.String" itemvalue="seaborn" />
|
||||
<item index="6" class="java.lang.String" itemvalue="captum" />
|
||||
<item index="7" class="java.lang.String" itemvalue="upsilonconf" />
|
||||
<item index="8" class="java.lang.String" itemvalue="pytorch" />
|
||||
<item index="9" class="java.lang.String" itemvalue="torchvision" />
|
||||
<item index="10" class="java.lang.String" itemvalue="scipy" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
pig_lite/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
pig_lite/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
pig_lite/.idea/misc.xml
generated
Normal file
7
pig_lite/.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.10" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
pig_lite/.idea/modules.xml
generated
Normal file
8
pig_lite/.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/pig_lite.iml" filepath="$PROJECT_DIR$/.idea/pig_lite.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
12
pig_lite/.idea/pig_lite.iml
generated
Normal file
12
pig_lite/.idea/pig_lite.iml
generated
Normal file
@@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
6
pig_lite/.idea/vcs.xml
generated
Normal file
6
pig_lite/.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
3
pig_lite/README.md
Normal file
3
pig_lite/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# pig_lite
|
||||
|
||||
This is PIG (Problem Instance Generator) Lite, a simplified and cleaned-up version of the framework previously used for the AI assignments.
|
||||
0
pig_lite/bayesian_net/__init__.py
Normal file
0
pig_lite/bayesian_net/__init__.py
Normal file
154
pig_lite/bayesian_net/bayesian_net.py
Normal file
154
pig_lite/bayesian_net/bayesian_net.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
#matplotlib.use('TkAgg')
|
||||
|
||||
|
||||
class BayesianNode:
    """Node of a BayesianNet: stores P(X | parents) for one boolean random variable.

    Attributes set by ``__init__``:
      rand_var -- name of the variable
      parents  -- list of parent variable names (from the whitespace-split string)
      children -- list of child nodes, initially empty
      cpt      -- normalised table ``{(v1, v2, ...): P(X=True | ...)}`` or the
                  falsy value that was passed in (e.g. ``None``)
    """

    def __init__(self, X: str, parents: str, cpt: dict = None):
        """
        X: String describing variable name

        parents: String containing parent variable names, separated with a whitespace

        cpt: the distribution P(X=true | parent1=v1, parent2=v2...). Accepted forms:
             - a bare number (no parents), stored as {(): p}; this includes 0 and 0.0
             - {True: p, False: q} (exactly one parent), stored as {(True,): p, (False,): q}
             - {(v1, v2, ...): p, ...} where each key has as many True/False values
               as there are parents

        Raises ValueError if X or parents is not a string, or if cpt is a truthy
        value of an unsupported type. Raises AssertionError if the table is malformed.
        """
        if not isinstance(X, str) or not isinstance(parents, str):
            raise ValueError("Use valid arguments - X and parents have to be strings (but at least one is not)!")
        self.rand_var = X
        self.parents = parents.split()
        self.children = []

        # Normalise the 0-parent and 1-parent shorthands to tuple-keyed dicts.
        # The type is tested directly (not via truthiness) so that a
        # probability of exactly 0 is kept instead of being silently dropped.
        if isinstance(cpt, (float, int)):
            cpt = {(): cpt}
        elif isinstance(cpt, dict):
            if cpt and isinstance(next(iter(cpt)), bool):
                # only one parent
                cpt = {(k, ): v for k, v in cpt.items()}
        elif cpt:
            raise ValueError("Define cpt with a valid data type (dict, or int).")

        # check format of cpt dict
        if cpt:
            for val, p in cpt.items():
                assert isinstance(val, tuple) and len(val) == len(self.parents)
                assert all(isinstance(v, bool) for v in val)
                assert 0 <= p <= 1

        self.cpt = cpt

    def __repr__(self):
        """ String representation of Bayesian Node. """
        return repr((self.rand_var, ' '.join(["parent(s):"] + self.parents)))

    def cond_probability(self, value: bool, event: dict):
        """
        Returns conditional probability P(X=value | event) for an atomic event,
        i.e. where each parent needs to be assigned a value.

        value: bool (value of this random variable)
        event: dict, assigning a value to each parent variable

        Returns None when the node has no (or an empty) cpt.
        """
        assert isinstance(value, bool)
        if self.cpt:
            prob_true = self.cpt[self.get_event_values(event)]
            return prob_true if value else 1 - prob_true

        return None

    def get_event_values(self, event: dict):
        """ Given an event (dict), returns tuple of values for all parents. """
        return tuple(event[p] for p in self.parents)
|
||||
|
||||
|
||||
class BayesianNet:
    """ Bayesian Network class for boolean random variables. Consists of BayesianNode-s. """

    def __init__(self, node_specs: list):
        """
        Creates BayesianNet with given node_specs. Nodes should be in causal order (parents before children).
        node_specs should be list of parameters for BayesianNode class.
        """
        self.nodes = []      # BayesianNode objects, in insertion order
        self.rand_vars = []  # variable names, parallel to self.nodes
        for spec in node_specs:
            self.add_node(spec)

    def add_node(self, node_spec):
        """
        Creates a BayesianNode from node_spec and adds it to the net.

        Raises ValueError if the variable already exists, or if any of its
        parents has not been added yet. Also registers the new node as a child
        of each of its parents.
        """
        node = BayesianNode(*node_spec)
        if node.rand_var in self.rand_vars:
            raise ValueError("Variable {} already exists in network, cannot be defined twice!".format(node.rand_var))
        if not all((parent in self.rand_vars) for parent in node.parents):
            raise ValueError("Parents do not all exist yet! Make sure to first add all parent nodes.")
        self.nodes.append(node)
        self.rand_vars.append(node.rand_var)
        for parent in node.parents:
            self.get_node_for_name(parent).children.append(node)

    def get_node_for_name(self, node_name):
        """
        Given the name of a random variable, returns the according BayesianNode of this network.

        Raises ValueError if no node with that name exists.
        """
        for n in self.nodes:
            if n.rand_var == node_name:
                return n

        raise ValueError("The variable {} does not exist in this network!".format(node_name))

    def __repr__(self):
        """ String representation of this Bayesian Network. """
        return "BayesianNet:\n{0!r}".format(self.nodes)

    def _get_depth(self, rand_var):
        """ Given random variable, returns "depth" of node in graph for plotting:
        0 for a node without parents, otherwise 1 + the largest parent depth. """
        node = self.get_node_for_name(rand_var)
        if len(node.parents) == 0:
            return 0

        return max(self._get_depth(p) for p in node.parents) + 1

    def draw(self, title, save_path=None):
        """
        Draws the BN with networkx. Requires title for plot.

        Nodes are laid out in rows by depth. The figure is written to
        save_path when one is given, otherwise it is shown interactively.
        """
        plt.figure(figsize=(14, 8))
        nx_bn = nx.DiGraph()
        nx_bn.add_nodes_from(self.rand_vars)
        # self.nodes is parallel to self.rand_vars, so iterating it directly
        # adds the same edges in the same order as a per-name lookup would.
        for node in self.nodes:
            for c in node.children:
                nx_bn.add_edge(node.rand_var, c.rand_var)

        depths = {rand_var: self._get_depth(rand_var) for rand_var in self.rand_vars}
        # counts[d] = number of nodes at depth d; np.unique returns the depths sorted.
        _, counts = np.unique(list(depths.values()), return_counts=True)
        # x slots per depth row: spread over [6, 14], or centred at 10 for a single node
        xs = [list(np.linspace(6, 14, c)) if c > 1 else [10] for c in counts]
        pos = {rand_var: (xs[depths[rand_var]].pop(), 10 - depths[rand_var] * 3) for rand_var in self.rand_vars}

        nx.set_node_attributes(nx_bn, pos, 'pos')
        nx.draw_networkx(nx_bn, arrows=True, pos=nx.get_node_attributes(nx_bn, "pos"),
                         node_shape="o", node_color="white", node_size=7000, edgecolors="gray")
        plt.title(title)
        plt.box(False)
        plt.margins(0.3)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=400)
        else:
            plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Demo: build a five-node network (Burglary, Earthquake, Alarm,
    # JohnCalls, MaryCalls) and draw it with an empty title.
    T, F = True, False
    alarm_cpt = {(T, T): 0.95, (T, F): 0.94, (F, T): 0.29, (F, F): 0.001}
    demo_specs = [
        ('Burglary', '', 0.001),
        ('Earthquake', '', {(): 0.002}),
        ('Alarm', 'Burglary Earthquake', alarm_cpt),
        ('JohnCalls', 'Alarm', {T: 0.90, F: 0.05}),
        ('MaryCalls', 'Alarm', {T: 0.70, F: 0.01}),
    ]
    bn = BayesianNet(demo_specs)
    bn.draw("")
|
||||
0
pig_lite/datastructures/__init__.py
Normal file
0
pig_lite/datastructures/__init__.py
Normal file
61
pig_lite/datastructures/priority_queue.py
Normal file
61
pig_lite/datastructures/priority_queue.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import heapq
|
||||
from functools import total_ordering
|
||||
|
||||
|
||||
# this annotation saves us some implementation work
|
||||
@total_ordering
class Item(object):
    """Heap entry: ordered by priority, with the insertion counter as tie-breaker.

    Only ``__lt__`` and ``__eq__`` are written out; ``total_ordering`` derives
    the remaining comparison operators. ``value`` never takes part in comparisons.
    """

    def __init__(self, insertion, priority, value):
        self.insertion, self.priority, self.value = insertion, priority, value

    def __lt__(self, other):
        # Equal priorities are resolved by insertion order, which makes the
        # ordering total; otherwise the priority alone decides.
        if self.priority == other.priority:
            return self.insertion < other.insertion
        return self.priority < other.priority

    def __eq__(self, other):
        same_priority = self.priority == other.priority
        return same_priority and self.insertion == other.insertion

    def __repr__(self):
        return f'({self.priority}, {self.insertion}, {self.value})'
|
||||
|
||||
|
||||
class PriorityQueue(object):
    """Min-priority queue on top of ``heapq``; ties come out in insertion order.

    Each pushed value is wrapped in an ``Item`` carrying a running insertion
    counter, so values themselves are never compared.
    """

    def __init__(self):
        self.heap = []
        self.insertion = 0  # counter handed to the next Item

    def has_elements(self):
        """Return True while at least one entry is queued."""
        return bool(self.heap)

    def put(self, priority, value):
        """Queue `value` under `priority`."""
        entry = Item(self.insertion, priority, value)
        heapq.heappush(self.heap, entry)
        self.insertion += 1

    def get(self, include_priority=False):
        """Pop the smallest entry; return `(priority, value)` if requested, else `value`."""
        entry = heapq.heappop(self.heap)
        return (entry.priority, entry.value) if include_priority else entry.value

    def __iter__(self):
        # Iterates a snapshot of the values in the heap list's internal order.
        snapshot = [entry.value for entry in self.heap]
        return iter(snapshot)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        joined = ','.join(str(entry.value) for entry in self.heap)
        return 'PriorityQueue [' + joined + ']'

    def __len__(self):
        return len(self.heap)
|
||||
27
pig_lite/datastructures/queue.py
Normal file
27
pig_lite/datastructures/queue.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Queue(object):
    """FIFO container backed by a ``deque``: `put` appends right, `get` pops left."""

    def __init__(self):
        self.d = deque()

    def put(self, v):
        """Append `v` at the back."""
        self.d.append(v)

    def get(self):
        """Remove and return the front element."""
        return self.d.popleft()

    def has_elements(self):
        """Return True while the queue is non-empty."""
        return bool(self.d)

    def __iter__(self):
        return iter(self.d)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        joined = ','.join(str(entry) for entry in self.d)
        return 'Queue [' + joined + ']'

    def __len__(self):
        return len(self.d)
|
||||
21
pig_lite/datastructures/stack.py
Normal file
21
pig_lite/datastructures/stack.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Stack(object):
    """LIFO container backed by a ``deque``: `put` and `get` both work on the right end."""

    def __init__(self):
        self.d = deque()

    def put(self, v):
        """Push `v` on top."""
        self.d.append(v)

    def get(self):
        """Remove and return the most recently pushed element."""
        return self.d.pop()

    def has_elements(self):
        """Return True while the stack is non-empty."""
        return bool(self.d)

    def __iter__(self):
        # Iterates bottom-to-top, i.e. in push order.
        return iter(self.d)

    def __repr__(self):
        joined = ','.join(str(entry) for entry in self.d)
        return 'Stack [' + joined + ']'
|
||||
61
pig_lite/decision_tree/dt_base.py
Normal file
61
pig_lite/decision_tree/dt_base.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from pig_lite.decision_tree.dt_node import DecisionTreeNodeBase
|
||||
import scipy.stats as stats
|
||||
|
||||
def entropy(y: list):
    """
    Return the base-2 entropy of a list of binary labels.

    The share of positive labels ``p = sum(y) / len(y)`` is computed and the
    two-point distribution ``[p, 1 - p]`` is passed to `scipy.stats.entropy`
    with ``base=2``.

    Parameters
    ----------
    y : list
        Binary labels (0 or 1).

    Returns
    -------
    float
        Entropy of the label distribution; 0.0 for an empty list.

    Notes
    -----
    For a binary distribution this corresponds to
        H = -p*log2(p) - (1-p)*log2(1-p)

    Examples
    --------
    >>> entropy([0, 0, 1, 1])
    1.0

    >>> entropy([1, 1, 1, 1])
    0.0

    >>> entropy([])
    0.0
    """
    n_labels = len(y)
    if n_labels == 0:
        return 0.0
    share_positive = sum(y) / n_labels
    distribution = [share_positive, 1 - share_positive]
    return stats.entropy(distribution, base=2)
|
||||
|
||||
# these two dummy classes are only used so we can import them and load trees from a pickle file before they are implemented by the students
|
||||
class DecisionTree():
    """Minimal tree container: holds a `root` node and offers height and printing helpers."""

    def __init__(self) -> None:
        # Start with an empty tree so print() works before any root is assigned.
        self.root = None

    def get_height(self, node):
        """Return the number of levels in the subtree under `node` (0 for None)."""
        if node is None:
            return 0
        return max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1

    def print(self):
        """Print the tree via the root node's `print_tree`; do nothing for an empty tree."""
        # getattr also covers instances that were created without running __init__.
        root = getattr(self, "root", None)
        if root is not None:
            height = self.get_height(root)
            root.print_tree(height)
|
||||
|
||||
class DecisionTreeNode(DecisionTreeNodeBase):
    """Placeholder node class; inherits all behaviour from DecisionTreeNodeBase."""

    def __init__(self) -> None:
        # Run the base initialiser so the instance gets the attributes the
        # inherited methods read.
        super().__init__()
|
||||
70
pig_lite/decision_tree/dt_node.py
Normal file
70
pig_lite/decision_tree/dt_node.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from pig_lite.datastructures.queue import Queue
|
||||
|
||||
class DecisionTreeNodeBase():
    """Binary decision-tree node with text rendering helpers.

    A node with `label` set is rendered as a leaf; a node with `split_feature`
    set is rendered as a split; a node with neither is rendered as a blank
    placeholder.
    """

    def __init__(self):
        self.label = None
        self.split_point = None
        self.split_feature = None
        self.left_child = None
        self.right_child = None

    def print_node(self, height, level=1):
        """Return this node's text cell for a tree of `height` levels, drawn at `level`."""
        node_width = 10
        # Padding on both sides shrinks with each level.
        n_spaces = 2 ** (height - level - 1) * node_width - node_width // 2
        if n_spaces > 0:
            text = " " * n_spaces
        else:
            text = ""

        if self.label is None and self.split_feature is None:
            # placeholder node
            return f"{text} {text}"

        if self.label is not None:
            text = f"{text}( {self.label} ){text}"
        elif self.split_feature is not None:
            text_snippet = f"(x{self.split_feature}:{self.split_point:.2f})"
            if len(text_snippet) != node_width:
                text_snippet = f" {text_snippet}"
            text = f"{text}{text_snippet}{text}"
        return text

    def __str__(self):
        """Compact form: "(label)" for a labelled node, else "feature:point|<left><right>"."""
        if self.label is not None:
            return f"({self.label})"

        str_value = f"{self.split_feature}:{self.split_point:.2f}|{self.left_child}{self.right_child}"
        return str_value

    def print_tree(self, height):
        """Print the subtree under this node level by level (breadth-first).

        Missing children above the last level are replaced by blank
        placeholder nodes so that later levels keep their column positions.
        """
        visited = set()
        frontier = Queue()

        lines = ['']

        previous_level = 1
        frontier.put((self, 1))

        while frontier.has_elements():
            current, level = frontier.get()
            if level > previous_level:
                # first node of a deeper level: start a new output line
                lines.append('')
                previous_level = level
            lines[-1] += current.print_node(height, level)
            if current in visited:
                continue
            visited.add(current)
            for child in (current.left_child, current.right_child):
                if child is not None:
                    frontier.put((child, level + 1))
                elif level < height:
                    frontier.put((DecisionTreeNodeBase(), level + 1))

        for line in lines:
            print(line)
        return None

    def split(self, *args, **kwargs):
        """Not implemented here; always raises NotImplementedError.

        The signature takes `self` (plus arbitrary arguments) so that calling
        it on an instance raises NotImplementedError rather than a TypeError
        about unexpected positional arguments.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
84
pig_lite/decision_tree/training_set.py
Normal file
84
pig_lite/decision_tree/training_set.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
|
||||
class TrainingSet():
    """Feature matrix `X` with labels `y`, plus JSON (de)serialisation and 2-D plotting."""

    # Plot styling per class label; indexed directly with the label (0 or 1).
    _CLASS_COLORS = ("red", "blue")
    _CLASS_MARKERS = ("x", "o")

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def to_json(self):
        """Return a JSON string with the class name, `X` and `y` (both via `.tolist()`)."""
        return json.dumps(dict(
            type=self.__class__.__name__,
            X=self.X.tolist(),
            y=self.y.tolist()
        ))

    @staticmethod
    def from_json(jsonstring):
        """Parse a JSON string and build a TrainingSet from it."""
        data = json.loads(jsonstring)
        return TrainingSet.from_dict(data)

    @staticmethod
    def from_dict(data):
        """Build a TrainingSet from a dict with keys 'X' and 'y'; `X` is squeezed."""
        return TrainingSet(
            np.array(data['X']).squeeze(),
            np.array(data['y'])
        )

    def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
        """Draw the split line of `node` inside the given rectangle, then recurse into its children.

        A split on feature 0 is drawn as a vertical line at `split_point`; any
        other split is drawn as a horizontal line. Lines get fainter with depth
        (alpha = 1 / level). Recursion stops once `level` equals `max_depth`.
        """
        split_point = node.split_point
        limit_left_updated = limit_left
        limit_right_updated = limit_right
        limit_top_updated = limit_top
        limit_bottom_updated = limit_bottom

        if node.split_feature == 0:
            if limit_bottom == limit_top:
                # degenerate extent: widen so the line is still visible
                warnings.warn('limit_bottom equals limit_top; extending by 0.1')
                plt.plot([split_point, split_point], [limit_bottom - 0.1, limit_top + 0.1], color="purple", alpha=1 / level)
            else:
                plt.plot([split_point, split_point], [limit_bottom, limit_top], color="purple", alpha=1 / level)
            limit_left_updated = split_point
            limit_right_updated = split_point

        else:
            if limit_left == limit_right:
                warnings.warn('limit_left equals limit_right; extending by 0.1')
                plt.plot([limit_left - 0.1, limit_right + 0.1], [split_point, split_point], color="purple", alpha=1 / level)
            else:
                plt.plot([limit_left, limit_right], [split_point, split_point], color="purple", alpha=1 / level)
            limit_top_updated = split_point
            limit_bottom_updated = split_point

        if level == max_depth:
            return
        # The left child is bounded by the split on its right/top side,
        # the right child on its left/bottom side.
        if node.left_child is not None:
            self.plot_node_boundaries(node.left_child, limit_left, limit_right_updated,
                                      limit_top_updated, limit_bottom, max_depth, level + 1)
        if node.right_child is not None:
            self.plot_node_boundaries(node.right_child, limit_left_updated, limit_right, limit_top,
                                      limit_bottom_updated, max_depth, level + 1)

    def visualize(self, tree=None, max_height=None):
        """Scatter-plot the samples per class and, if `tree` is given, overlay its split lines.

        `max_height` caps how many tree levels are drawn; it is clamped to the
        tree's actual height.
        """
        for y in set(self.y):
            X = self.X[self.y == y, :]
            # Colour and marker are both looked up by the class label itself.
            plt.scatter(X[:, 0], X[:, 1],
                        color=self._CLASS_COLORS[y],
                        marker=self._CLASS_MARKERS[y],
                        label="class: {}".format(y))

        if tree is not None:
            tree_height = tree.get_height(tree.root)
            if max_height is None or max_height > tree_height:
                max_height = tree_height
            self.plot_node_boundaries(tree.root,
                                      limit_left=min(self.X[:, 0]),
                                      limit_right=max(self.X[:, 0]),
                                      limit_top=max(self.X[:, 1]),
                                      limit_bottom=min(self.X[:, 1]),
                                      max_depth=max_height)  # TODO: make parameterizable
|
||||
|
||||
0
pig_lite/environment/__init__.py
Normal file
0
pig_lite/environment/__init__.py
Normal file
60
pig_lite/environment/base.py
Normal file
60
pig_lite/environment/base.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import json
|
||||
import hashlib
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Environment:
    """Base interface for environments; subclasses supply step/reset and the state/action counts."""

    def step(self, action):
        raise NotImplementedError()

    def reset(self):
        raise NotImplementedError()

    def get_n_actions(self):
        raise NotImplementedError()

    def get_n_states(self):
        raise NotImplementedError()

    def get_flat_policy(self, policy):
        """Return `[(state, action, policy[state, action]), ...]` in state-major order."""
        return [
            (state, action, policy[state, action])
            for state in range(self.get_n_states())
            for action in range(self.get_n_actions())
        ]

    def get_policy_hash(self, outcome):
        """Return the SHA-256 hex digest of the comma-joined flat form of `outcome.policy`."""
        entries = self.get_flat_policy(outcome.policy)
        serialized = ','.join(str(entry) for entry in entries)
        digest = hashlib.sha256(serialized.encode('UTF-8'))
        return digest.hexdigest()
|
||||
|
||||
|
||||
class Outcome:
    """Result record holding an episode count, a policy, and the V and Q arrays."""

    def __init__(self, n_episodes, policy, V, Q):
        self.n_episodes = n_episodes
        self.policy = policy
        self.V, self.Q = V, Q

    def get_n_episodes(self):
        return self.n_episodes

    def to_json(self):
        """Return a JSON string; `policy`, `V` and `Q` are converted with `.tolist()`."""
        payload = {
            'type': self.__class__.__name__,
            'n_episodes': self.n_episodes,
            'policy': self.policy.tolist(),
            'V': self.V.tolist(),
            'Q': self.Q.tolist(),
        }
        return json.dumps(payload)

    @staticmethod
    def from_json(jsonstring):
        """Parse a JSON string and rebuild an Outcome with numpy arrays."""
        data = json.loads(jsonstring)
        n_episodes = data['n_episodes']
        policy = np.array(data['policy'])
        V = np.array(data['V'])
        Q = np.array(data['Q'])
        return Outcome(n_episodes, policy, V, Q)
|
||||
360
pig_lite/environment/gridworld.py
Normal file
360
pig_lite/environment/gridworld.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from pig_lite.environment.base import Environment
|
||||
|
||||
# Index offsets applied to a cell, one per action index.
DELTAS = [(-1, 0), (+1, 0), (0, -1), (0, +1)]
# Presumably the human-readable name of the action at the same position in
# DELTAS -- verify against the plotting code.
NAMES = ['left', 'right', 'up', 'down']
|
||||
|
||||
def sample(rng, elements):
    """
    Pick one entry of `elements` at random and return it.

    The first field of every entry is used as its weight: a single
    `rng.uniform(0, 1)` draw is compared against the running sum of those
    weights, and the first entry whose cumulative weight exceeds the draw is
    returned. If no cumulative weight exceeds the draw, `np.argmax` yields
    index 0 and the first entry is returned.
    """
    cumulative = np.cumsum([entry[0] for entry in elements])
    draw = rng.uniform(0, 1)
    chosen = np.argmax(cumulative > draw)
    return elements[chosen]
|
||||
|
||||
|
||||
class Gridworld(Environment):
|
||||
def __init__(self, seed, dones, rewards, starts):
    """Store the grid description, seed a private RandomState and precompute the transitions."""
    self.seed = seed
    self.rng = np.random.RandomState(seed)
    self.dones, self.rewards, self.starts = dones, rewards, starts

    # fills self.P from dones and rewards
    self.__compute_P()
|
||||
|
||||
def reset(self):
    """ Draws a start state from `self.starts`, stores it as the current state and returns it. """
    _weight, start_state = sample(self.rng, self.starts)
    self.state = start_state
    return self.state
|
||||
|
||||
def step(self, action):
    """ Samples one transition for `action` from self.P at the current state, moves there,
    and returns (state, reward, done). """
    candidates = self.P[self.state][action]
    _weight, next_state, reward, done = sample(self.rng, candidates)
    self.state = next_state
    return self.state, reward, done
|
||||
|
||||
def get_n_actions(self):
    """ Number of actions an agent can choose from in this gridworld (always 4). """
    n_actions = 4
    return n_actions
|
||||
|
||||
def get_n_states(self):
    """ Number of states: the product of the dimensions of `self.dones`. """
    grid_shape = self.dones.shape
    return np.prod(grid_shape)
|
||||
|
||||
def get_gamma(self):
    """ Discount factor gamma of this gridworld (fixed at 0.99). """
    gamma = 0.99
    return gamma
|
||||
|
||||
def __compute_P(self):
    """ Builds the transition table and stores it in `self.P`.

    `self.P[state][action]` is a list of `(probability, next_state, reward, done)`
    tuples. Cells flagged in `self.dones` are absorbing. From any other cell an
    action moves in its own direction with probability 0.8 and in each of the
    two swapped-offset directions with probability 0.1; a move that would
    leave the grid keeps the agent in place.
    """
    size_i, size_j = self.dones.shape

    def flat_index(i, j):
        """ Maps grid coordinates to the integer state id. """
        return j * size_i + i

    self.P = dict()
    for i in range(size_i):
        for j in range(size_j):
            state = flat_index(i, j)

            if self.dones[i, j]:
                # make it absorbing
                self.P[state] = {
                    action: [(1, state, 0, True)]
                    for action in range(self.get_n_actions())
                }
                continue

            per_action = dict()
            for action, (dx, dy) in enumerate(DELTAS):
                outcomes = []
                for p, di, dj in ((0.8, dx, dy), (0.1, dy, dx), (0.1, -dy, -dx)):
                    ni, nj = i + di, j + dj
                    if 0 <= ni < size_i and 0 <= nj < size_j:
                        # we move
                        outcomes.append((p, flat_index(ni, nj), self.rewards[ni, nj], self.dones[ni, nj]))
                    else:
                        # stay in the same state, b/c we bounced
                        outcomes.append((p, state, self.rewards[i, j], self.dones[i, j]))
                per_action[action] = outcomes
            self.P[state] = per_action
|
||||
|
||||
def to_json(self):
    """ Serialises this gridworld (class name, seed, dones, rewards, starts) to a JSON string. """
    payload = {
        'type': self.__class__.__name__,
        'seed': self.seed,
        'dones': self.dones.tolist(),
        'rewards': self.rewards.tolist(),
        'starts': self.starts.tolist(),
    }
    return json.dumps(payload)
|
||||
|
||||
@staticmethod
|
||||
def from_json(jsonstring):
    """ Parses the given JSON string and builds a Gridworld from its contents. """
    parsed = json.loads(jsonstring)
    # from_dict performs the same array conversions and constructor call
    return Gridworld.from_dict(parsed)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data):
    """ Builds a Gridworld from a dict with the keys 'seed', 'dones', 'rewards' and 'starts'. """
    seed = data['seed']
    dones = np.array(data['dones'])
    rewards = np.array(data['rewards'])
    starts = np.array(data['starts'], dtype=np.int64)
    return Gridworld(seed, dones, rewards, starts)
|
||||
|
||||
@staticmethod
def get_random_instance(rng, size):
    """ Generates a random Gridworld of the given size, drawing all randomness from rng. """
    # the layout is drawn first and the seed afterwards, so the order in
    # which rng is consumed stays fixed
    layout = Gridworld.__generate(rng, size)
    seed = rng.randint(0, 2 ** 31)
    return Gridworld(seed, *layout)
|
||||
|
||||
@staticmethod
def __generate(rng, size):
    """ Helper function that creates dones, rewards, starts for Gridworld instance generation.

    About a tenth (at least one) of the interior cells become terminal
    cells with reward -100, the cell at [-1, -1] becomes a terminal cell
    with reward 100, every other cell has reward -1.
    Raises ValueError if size is smaller than 3, because such a grid has
    no interior cell to choose from.
    """
    if size < 3:
        raise ValueError(f'size must be at least 3, got {size}')

    dones = np.full((size, size), False, dtype=bool)
    rewards = np.full((size, size), -1, dtype=np.int8)

    # candidate cells exclude the outer ring of the grid
    interior = [(i, j) for i in range(1, size - 1) for j in range(1, size - 1)]
    indices = np.arange(len(interior))
    chosen = rng.choice(indices, max(1, len(indices) // 10), replace=False)

    for c in chosen:
        x, y = interior[c]
        dones[x, y] = True
        rewards[x, y] = -100

    starts = np.array([[1, 0]])
    dones[-1, -1] = True
    rewards[-1, -1] = 100

    return dones, rewards, starts
|
||||
|
||||
@staticmethod
def get_minimum_problem_size():
    """ Returns the smallest supported problem size (edge length of the grid). """
    return 3
|
||||
|
||||
def visualize(self, outcome, coords=None, grid=None):
    """ Visualisation function for gridworld; plots environment, policy, V and Q of the given outcome. """
    if outcome is None:
        # nothing was computed: only the environment itself can be shown
        policy = V = Q = None
    else:
        policy, V, Q = outcome.policy, outcome.V, outcome.Q

    self._plot_environment_and_policy(policy, V, Q, show_coordinates=coords, show_grid=grid)
|
||||
|
||||
def _plot_environment_and_policy(self, policy=None, V=None, Q=None, show_coordinates=False,
                                 show_grid=False, plot_filename=None, debug_info=False):
    """ Function that plots environment and policy.

    Draws a 2x2 figure: terminal states (with the policy on top), immediate
    rewards, the state value function V and the state action value
    function Q. Any of policy, V and Q may be None, in which case the
    corresponding drawing is left out. The figure is saved to
    plot_filename if one is given, and shown interactively otherwise.
    """
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
    dones_ax = axes[0, 0]
    rewards_ax = axes[0, 1]
    V_ax = axes[1, 0]
    Q_ax = axes[1, 1]

    dones_ax.set_title('Terminal States and Policy')
    dones_ax.imshow(self.dones.T, cmap='gray_r', vmin=0, vmax=4)

    rewards_ax.set_title('Immediate Rewards')
    rewards_ax.imshow(self.rewards.T, cmap='RdBu_r', vmin=-25, vmax=25)

    # visualize() passes None for everything that was not computed
    if policy is not None and len(policy) > 0:
        self._plot_policy(dones_ax, policy)

    w, _ = self.dones.shape
    V_ax.set_title('State Value Function $V(s)$')
    if V is not None:
        V_grid = V.reshape(self.dones.shape)
        # symmetric colour range, so that zero is mapped to the centre colour
        r = max(1e-13, np.max(np.abs(V_grid)))
        V_ax.imshow(V_grid, cmap='RdBu_r', vmin=-r, vmax=r)

        if debug_info:
            for s in range(len(V)):
                sy, sx = divmod(s, w)
                V_ax.text(sx, sy, f'{sx},{sy}:{s}',
                          color='w', fontdict=dict(size=6),
                          horizontalalignment='center', verticalalignment='center')

    Q_ax.set_title('State Action Value Function $Q(s, a)$')
    poly_patches_q_values = []
    if Q is not None:
        poly_patches_q_values = self._draw_Q(Q_ax, Q, debug_info)

    def format_coord(x, y):
        # shows the Q-value of the triangle below the mouse cursor
        for poly_patch, q_value in poly_patches_q_values:
            if poly_patch.contains_point(Q_ax.transData.transform((x, y))):
                return f'x:{x:4.2f} y:{y:4.2f} {q_value}'
        return f'x:{x:4.2f} y:{y:4.2f}'

    Q_ax.format_coord = format_coord

    for ax in [dones_ax, rewards_ax, V_ax, Q_ax]:
        ax.tick_params(
            top=show_coordinates,
            left=show_coordinates,
            labelleft=show_coordinates,
            labeltop=show_coordinates,
            right=False,
            bottom=False,
            labelbottom=False
        )

        # Major ticks
        s = self.dones.shape[0]
        ax.set_xticks(np.arange(0, s, 1))
        ax.set_yticks(np.arange(0, s, 1))

        # Minor ticks
        ax.set_xticks(np.arange(-.5, s, 1), minor=True)
        ax.set_yticks(np.arange(-.5, s, 1), minor=True)

    if show_grid:
        for color, ax in zip(['m', 'w', 'w'], [dones_ax, rewards_ax, V_ax]):
            # Gridlines based on minor ticks
            ax.grid(which='minor', color=color, linestyle='-', linewidth=1)

    plt.tight_layout()
    if plot_filename is not None:
        plt.savefig(plot_filename)
        plt.close(fig)
    else:
        plt.show()
|
||||
|
||||
def _plot_policy(self, ax, policy):
    """ Draws the policy as arrows on ax; arrow length is the probability policy[s, a]. """
    w, h = self.dones.shape
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))

    # one list of arrows per action, each holding one arrow per state
    n_actions = self.get_n_actions()
    arrows_per_action = [list() for _ in range(n_actions)]

    for s in range(self.get_n_states()):
        y, x = divmod(s, w)
        terminal = self.dones[x, y]
        for a in range(n_actions):
            if terminal:
                # no arrows are drawn in terminal states
                arrows_per_action[a].append((0., 0.))
                continue
            wdx, wdy = DELTAS[a]
            # the y-component is negated before it is handed to quiver
            direction = np.array([wdx, -wdy])
            arrows_per_action[a].append(direction * policy[s, a])

    # plot one quiver per action
    for arrows in arrows_per_action:
        q = np.array(arrows)
        ax.quiver(xx, yy, q[:, 0], q[:, 1], units='xy', scale=1.5)
|
||||
|
||||
def _draw_Q(self, ax, Q, debug_info):
    """ Draws Q onto ax: every non-terminal cell is split into four triangles, one per action.

    Returns a list of (polygon patch, label) tuples, one per drawn triangle.
    """
    # empty background image with the shape of the grid
    ax.imshow(np.zeros(self.dones.shape), cmap='gray_r')
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize
    from matplotlib.patches import Rectangle, Polygon
    w, h = self.dones.shape

    # symmetric colour range, so that zero is mapped to the centre colour
    r = max(1e-13, np.max(np.abs(Q)))
    sm = ScalarMappable(Normalize(vmin=-r, vmax=r), plt.get_cmap('RdBu_r'))

    hover_polygons = []
    for state in range(len(Q)):
        qs = Q[state]
        row, col = divmod(state, w)
        if self.dones[col, row]:
            continue

        # the cell is a unit square centred on (col, row)
        left, top = col - 0.5, row - 0.5
        right, bottom = left + 1, top + 1

        ax.add_artist(Rectangle(
            xy=(left, top),
            width=1,
            height=1,
            edgecolor='k',
            facecolor='none'
        ))

        ul, ur = (left, top), (right, top)
        ll, lr = (left, bottom), (right, bottom)
        m = (left + 0.5, top + 0.5)

        # triangles in action order: left, right, up, down
        action_polys = [
            [ul, m, ll],
            [ur, m, lr],
            [ul, m, ur],
            [ll, m, lr],
        ]
        for a, poly in enumerate(action_polys):
            poly_patch = Polygon(
                poly,
                edgecolor='k',
                linewidth=0.1,
                facecolor=sm.to_rgba(qs[a])
            )
            if debug_info:
                # write the transitions of (state, a) at the centroid of the triangle
                mmx = np.mean([px for px, _ in poly])
                mmy = np.mean([py for _, py in poly])
                sss = '\n'.join(map(str, self.P[state][a]))
                ax.text(mmx, mmy, f'{NAMES[a][0]}:{sss}',
                        fontdict=dict(size=5), horizontalalignment='center',
                        verticalalignment='center')

            hover_polygons.append((poly_patch, f'{NAMES[a]}:{qs[a]:4.2f}'))
            ax.add_artist(poly_patch)
    return hover_polygons
|
||||
0
pig_lite/game/__init__.py
Normal file
0
pig_lite/game/__init__.py
Normal file
87
pig_lite/game/base.py
Normal file
87
pig_lite/game/base.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import hashlib
|
||||
|
||||
class Node(object):
    """ Node of a game tree: a state, the move that led to it, and the player stored with it. """

    def __init__(self, parent, state, action, player, depth):
        self.parent, self.state, self.action = parent, state, action
        self.player, self.depth = player, depth

    def key(self):
        # if state is composed of other stuff (dict, set, ...)
        # make it a tuple containing hashable datatypes
        # (this is supposed to be overridden by subclasses)
        return tuple(self.state) + (self.player, )

    def __hash__(self):
        return hash(self.key())

    def __eq__(self, other):
        # only nodes of exactly the same type can be compared
        if type(self) is not type(other):
            raise ValueError('cannot simply compare two different node types')
        return self.key() == other.key()

    def __str__(self):
        return repr(self)

    def __repr__(self):
        template = 'Node(id:{}, parent:{}, state:{}, action:{}, player:{}, depth:{})'
        return template.format(
            id(self), id(self.parent), self.state, self.action, self.player, self.depth
        )

    def get_move_sequence(self):
        """ Returns the (player, action) tuples on the path from the root to this node. """
        moves = []
        node = self
        # walk up to the root; the root itself contributes no move
        while node.parent is not None:
            moves.append((node.player, node.action))
            node = node.parent
        moves.reverse()
        return moves

    def get_move_sequence_hash(self):
        """ Returns the SHA-256 hex digest of the ';'-joined move sequence. """
        joined = ';'.join(str(move) for move in self.get_move_sequence())
        return hashlib.sha256(joined.encode('UTF-8')).hexdigest()
|
||||
|
||||
class Game(object):
    """ Interface of a game; subclasses have to implement everything that raises NotImplementedError. """

    def get_number_of_expanded_nodes(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_start_node(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def winner(self, node):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def successors(self, node):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_max_player(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def to_json(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_move_sequence(self, end: Node):
        """ Returns the move sequence leading to 'end', or an empty list if 'end' is None. """
        return list() if end is None else end.get_move_sequence()

    def get_move_sequence_hash(self, end: Node):
        """ Returns the hash of the move sequence leading to 'end', or '' if 'end' is None. """
        return '' if end is None else end.get_move_sequence_hash()

    @staticmethod
    def from_json(jsonstring):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    @staticmethod
    def get_minimum_problem_size():
        """ To be implemented by subclasses. """
        raise NotImplementedError()
|
||||
371
pig_lite/game/tictactoe.py
Normal file
371
pig_lite/game/tictactoe.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from copy import deepcopy
|
||||
from pig_lite.game.base import Node, Game
|
||||
|
||||
|
||||
class TTTNode(Node):
    """ Game tree node whose state is a board stored as a numpy array. """

    def key(self):
        # all cells of the board, followed by the player stored in this node
        return (*self.state.flatten().tolist(), self.player)

    def __repr__(self):
        template = '"TTTNode(\nid:{}\nparent:{}\nboard:\n{}\nplayer:\n{}\naction:\n{}\ndepth:{})"'
        # this needs to be printed transposed, so it fits together with
        # how matplotlib's 'imshow' renders images
        board = self.state.T
        return template.format(id(self), id(self.parent), board, self.player, self.action, self.depth)

    def pretty_print(self):
        """ Shows the board as a small image and prints the number of performed moves. """
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
        cm = ListedColormap(['tab:blue', 'lightgray', 'tab:orange'])
        print('State of the board:')
        plt.figure(figsize=(2, 2))
        plt.imshow(self.state.T, cmap=cm)
        plt.axis('off')
        plt.show()
        print('Performed moves: {}'.format(self.depth))
|
||||
|
||||
|
||||
class TicTacToe(Game):
|
||||
def __init__(self, rng=None, depth=None):
    """ Creates a game; the start position is set up by play_randomly(rng, depth). """
    # the counter has to exist before play_randomly() is called
    self.n_expands = 0
    self.play_randomly(rng, depth)
|
||||
|
||||
def play_randomly(self, rng, depth):
    """ Initialises self.start_node to be either empty board, or board at given depth after random playing.

    The empty board (player 1 to move) is used if rng or depth is None, or
    if depth is 0. Otherwise random moves are played until 'depth' moves
    were made or a terminal position is reached, whichever comes first.
    Note that every expansion made here is counted in self.n_expands.
    """
    empty_board = np.zeros((3, 3), dtype=int)
    start_from_empty = TTTNode(None, empty_board, None, 1, 0)
    if rng is None or depth is None or depth == 0:
        self.start_node = start_from_empty
        return

    # proceed playing randomly until either 'depth' is reached,
    # or the node is a terminal node
    successors = [start_from_empty]
    while True:
        current = successors[rng.randint(0, len(successors))]
        if current.depth == depth:
            break

        terminal, _ = self.outcome(current)
        if terminal:
            break
        successors = self.successors(current)

    # the reached position becomes a fresh root: no parent, no action, depth 0
    self.start_node = TTTNode(None, current.state, None, current.player, 0)
|
||||
|
||||
def get_start_node(self):
    """ Returns the start node of this game (self.start_node). """
    return self.start_node
|
||||
|
||||
def outcome(self, node):
    """ Returns a tuple (terminal, winner); winner is None for a draw or an unfinished game. """
    grid = node.state
    main_diagonal = np.diag(grid)
    anti_diagonal = np.diag(np.rot90(grid))

    for candidate in (-1, 1):
        owned = grid == candidate
        # a full row or a full column
        if any(owned[k, :].all() or owned[:, k].all() for k in range(3)):
            return True, candidate
        # one of the two diagonals
        if (main_diagonal == candidate).all() or (anti_diagonal == candidate).all():
            return True, candidate

    # a full board without a completed line is a draw,
    # anything else means the game continues
    if (grid != 0).all():
        return True, None
    return False, None
|
||||
|
||||
def get_max_player(self):
    """ Returns identifier of MAX player used in this game (always 1). """
    return 1
|
||||
|
||||
def successor(self, node, action):
    """ Performs given action at given game node, and returns the resulting TTT node. """
    mover = node.player

    # the board of the given node is left untouched
    board_after = node.state.copy()
    board_after[action] = mover

    # players 1 and -1 take turns
    opponent = -1 if mover == 1 else 1
    return TTTNode(node, board_after, action, opponent, node.depth + 1)
|
||||
|
||||
def get_number_of_expanded_nodes(self):
    """ Returns the value of the counter self.n_expands. """
    return self.n_expands
|
||||
|
||||
def successors(self, node):
    """ Given a game node, returns all possible successor nodes; counts as one expansion. """
    self.n_expands += 1
    terminal, _ = self.outcome(node)
    if terminal:
        # nothing can be played in a finished game
        return []

    # the coordinates of every empty cell are the possible actions
    free_cells = zip(*np.nonzero(node.state == 0))
    return [self.successor(node, action) for action in free_cells]
|
||||
|
||||
def to_json(self):
    """ Serialises the start position of this TTT game into a JSON string and returns it. """
    start = self.start_node
    payload = {
        'type': self.__class__.__name__,
        'start_state': start.state.tolist(),
        'start_player': start.player,
    }
    return json.dumps(payload)
|
||||
|
||||
@staticmethod
def from_json(jsonstring):
    """ Parses the given JSON string and creates the game it describes. """
    # the parsed document has exactly the layout from_dict() expects
    return TicTacToe.from_dict(json.loads(jsonstring))
|
||||
|
||||
@staticmethod
def from_dict(data):
    """ Creates game whose start node is described by 'start_state' and 'start_player' in data. """
    game = TicTacToe()
    board = np.array(data['start_state'], dtype=int)
    player = data['start_player']
    # the loaded position is a root: no parent, no action, depth 0
    game.start_node = TTTNode(None, board, None, player, 0)
    return game
|
||||
|
||||
@staticmethod
def get_minimum_problem_size():
    """ Returns the smallest supported problem size. """
    # NOTE(review): presumably the problem size is the 'depth' of random
    # pre-play, 0 meaning the empty board -- confirm against the callers
    return 0
|
||||
|
||||
def visualize(self, move_sequence, show_possible=False, tree_name=''):
    """ Replays move_sequence from the start node and plots the visited nodes as a game tree.

    move_sequence is an iterable of (player, move) tuples. If show_possible
    is True, all successors of every visited node are shown, not only the
    one that was played. If the plotting libraries cannot be imported, the
    nodes are printed as text instead.
    Raises ValueError if show_possible is True and a move is not among the
    successors of the current position.
    """
    # replay on a copy, so the expansion counter of this game stays untouched
    game = deepcopy(self)
    current = game.get_start_node()
    nodes = [current]
    for player, move in move_sequence:
        if show_possible:
            successors = game.successors(current)
            # the played successor is part of this list, so it must not be
            # added a second time below
            nodes.extend(successors)
            current = next((succ for succ in successors if succ.action == move), None)
            if current is None:
                raise ValueError(
                    'move {} of player {} is not possible in the current position'.format(move, player)
                )
        else:
            current = game.successor(current, move)
            nodes.append(current)

    try:
        self.networkx_plot_game_tree(tree_name, nodes)
    except ImportError:
        print('#' * 30)
        print('#' * 30)
        print('starting position')
        print(self.get_start_node())
        print('#' * 30)
        print('#' * 30)
        print('-' * 30)
        print('sequence of nodes')
        for node in nodes:
            print('-' * 30)
            print(node)
            terminal, winner = game.outcome(node)
            print('terminal {}, winner {}'.format(terminal, winner))
|
||||
|
||||
def networkx_plot_game_tree(self, title, nodes, highlight=None):
    """ Plots the given nodes as a game tree, every node drawn as a small image of its board.

    'highlight' is an optional dictionary keyed by id(node); the items of
    highlight[id(node)] are added as 'key:value' text lines next to the node.
    Needs networkx (with the pydot/graphviz layout) and matplotlib.
    """
    # TODO: this needs some serious refactoring
    # use visitors for styling, for example, instead of cumbersome dicts
    import networkx as nx
    import matplotlib.pyplot as plt
    from networkx.drawing.nx_pydot import graphviz_layout
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox, HPacker, VPacker, TextArea

    fig, tree_ax = plt.subplots()
    tree_ax.set_title(title)
    G = nx.DiGraph(ordering='out')
    # per-node and per-edge information used for styling, keyed by id(node)
    # and by (id(parent), id(node)) respectively
    nodes_extra = dict()
    edges_extra = dict()

    def sort_key(node):
        # nodes without an action are sorted before all others
        if node.action is None:
            return (-1, -1)
        return node.action

    for node in sorted(nodes, key=sort_key):
        G.add_node(id(node), search_node=node)
        terminal, winner = self.outcome(node)
        nodes_extra[id(node)] = dict(
            board=node.state,
            player=node.player,
            depth=node.depth,
            terminal=terminal,
            winner=winner
        )

    for node in nodes:
        if node.parent is not None:
            edge = id(node.parent), id(node)
            G.add_edge(*edge, parent_node=node.parent)
            edges_extra[edge] = dict(
                label='{}'.format(node.action),
                parent_player=node.parent.player
            )

    node_size = 1000
    positions = graphviz_layout(G, prog='dot')

    from matplotlib.colors import Normalize, LinearSegmentedColormap

    # maps cell values -1 / 0 / 1 to blue / gray / orange
    blue_orange = LinearSegmentedColormap.from_list(
        'blue_orange',
        ['tab:blue', 'lightgray', 'tab:orange']
    )

    # the bounding box of all node positions is tracked while drawing
    inf = float('Inf')
    x_range = [inf, -inf]
    y_range = [inf, -inf]
    for id_node, pos in positions.items():
        x, y = pos
        x_range = [min(x, x_range[0]), max(x, x_range[1])]
        y_range = [min(y, y_range[0]), max(y, y_range[1])]

        player = nodes_extra[id_node]['player']
        text_player = 'p:{}'.format(player)
        text_depth = 'd:{}'.format(nodes_extra[id_node]['depth'])
        color_player = 'tab:blue' if player == -1 else 'tab:orange'

        # terminal nodes get a frame in the color of the winner (purple
        # for a draw), and show the winner instead of the player
        frameon = False
        bboxprops = None
        if nodes_extra[id_node]['terminal']:
            winner = nodes_extra[id_node]['winner']
            frameon = True
            if winner is None:
                edgecolor = 'tab:purple'
            else:
                edgecolor = 'tab:blue' if winner == -1 else 'tab:orange'
            bboxprops = dict(
                facecolor='none',
                edgecolor=edgecolor
            )
            color_player = 'k'
            text_player = 'w:{}'.format(winner)
            if winner is None:
                text_player = ''

        # needs to be transposed b/c image coordinates etc ...
        board = nodes_extra[id_node]['board'].T
        textbox_player = TextArea(text_player, textprops=dict(size=6, color=color_player))
        textbox_depth = TextArea(text_depth, textprops=dict(size=6))

        textbox_children = [textbox_player, textbox_depth]

        if highlight is not None:
            if id_node in highlight:
                if nodes_extra[id_node]['terminal']:
                    frameon = True
                    if nodes_extra[id_node]['winner'] is None:
                        edgecolor = 'tab:purple'
                    else:
                        edgecolor = 'tab:blue' if winner == -1 else 'tab:orange'

                    bboxprops = dict(
                        facecolor='none',
                        edgecolor=edgecolor
                    )

                # every highlight entry becomes an additional line of text
                if len(highlight[id_node]) > 0:
                    for key, value in highlight[id_node].items():
                        textbox_children.append(
                            TextArea('{}:{}'.format(key, value), textprops=dict(size=6))
                        )

        # board image on the left, text lines stacked on the right
        imagebox = OffsetImage(board, zoom=5, cmap=blue_orange, norm=Normalize(vmin=-1, vmax=1))
        packed = HPacker(
            align='center',
            children=[
                imagebox,
                VPacker(
                    align='center',
                    children=textbox_children,
                    sep=0.1, pad=0.1
                )
            ],
            sep=0.1, pad=0.1
        )

        ab = AnnotationBbox(packed, pos, xycoords='data', frameon=frameon, bboxprops=bboxprops)
        tree_ax.add_artist(ab)

    def min_dist(a, b):
        # widens the range [a, b]; a degenerate range gets a margin of 1
        if a == b:
            return [a - 1, b + 1]
        else:
            return [a - 0.9 * abs(a), b + 0.1 * abs(b)]

    x_range = min_dist(*x_range)
    y_range = min_dist(*y_range)
    tree_ax.set_xlim(x_range)
    tree_ax.set_ylim(y_range)

    # edges are colored by the player stored in the parent node
    orange_edges = []
    blue_edges = []

    for edge, extra in edges_extra.items():
        if extra['parent_player'] == -1:
            blue_edges.append(edge)
        else:
            orange_edges.append(edge)

    for color, edgelist in [('tab:orange', orange_edges), ('tab:blue', blue_edges)]:
        nx.draw_networkx_edges(
            G, positions,
            edgelist=edgelist,
            edge_color=color,
            arrowstyle='-|>',
            arrowsize=10,
            node_size=node_size,
            ax=tree_ax
        )
    edge_labels = {edge_id: edge['label'] for edge_id, edge in edges_extra.items()}
    nx.draw_networkx_edge_labels(G, positions, edge_labels, ax=tree_ax, font_size=6)

    tree_ax.axis('off')
    plt.tight_layout()
    plt.show()
|
||||
0
pig_lite/instance_generation/__init__.py
Normal file
0
pig_lite/instance_generation/__init__.py
Normal file
5
pig_lite/instance_generation/enc.py
Normal file
5
pig_lite/instance_generation/enc.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# this is the common encoding for different level tiles
# (integer codes that are stored in the 'field' array of a level)
WALL = 1
SPACE = 0
# NOTE(review): the meaning of the two codes below is not documented
# here -- confirm against the code that uses them
EXPOSED = -1
UNDETERMINED = -2
|
||||
96
pig_lite/instance_generation/problem_factory.py
Normal file
96
pig_lite/instance_generation/problem_factory.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
|
||||
from pig_lite.problem.simple_2d import Simple2DProblem, MazeLevel, TerrainLevel, RoomLevel
|
||||
from pig_lite.environment.gridworld import Gridworld
|
||||
from pig_lite.game.tictactoe import TicTacToe
|
||||
from pig_lite.decision_tree.training_set import TrainingSet
|
||||
|
||||
# this is the common encoding for different level tiles
# NOTE(review): same names and values as the constants in
# pig_lite/instance_generation/enc.py -- presumably kept for backwards
# compatibility, verify that it is still needed
encoding = {
    'WALL': 1,
    'SPACE': 0,
    'EXPOSED': -1,
    'UNDETERMINED': -2
}
|
||||
|
||||
class ProblemFactory():
    """ Collection of static helpers that generate problem instances, or load them from JSON files and dicts. """

    def __init__(self) -> None:
        pass

    @staticmethod
    def generate_problem(problem_type, problem_size, rng):
        """ Generates a random instance of the given problem type and size, using the random generator rng.

        Raises NotImplementedError for 'trainset', and ValueError for any unknown problem_type.
        """
        # all level based problems are wrapped into a Simple2DProblem in the same way
        level_types = (
            ('maze', MazeLevel),
            ('terrain', TerrainLevel),
            ('rooms', RoomLevel),
        )
        for name, level_type in level_types:
            if problem_type == name:
                level = level_type(rng, size=problem_size)
                return Simple2DProblem(level.get_field(),
                                       level.get_costs(),
                                       level.get_start(),
                                       level.get_end())

        if problem_type == 'tictactoe':
            return TicTacToe(rng, depth=problem_size)
        if problem_type == 'gridworld':
            return Gridworld.get_random_instance(rng, size=problem_size)
        if problem_type == 'trainset':
            raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
        raise ValueError(f'unknown problem_type {problem_type}')

    @staticmethod
    def create_problem_from_json(json_path):
        """ Loads the JSON file at json_path and creates the problem named by its 'type' entry.

        Raises ValueError if the type is unknown.
        """
        with open(json_path, 'r') as file:
            data = json.load(file)
        problem_type = data['type']

        problem_classes = (
            ('Simple2DProblem', Simple2DProblem),
            ('TicTacToe', TicTacToe),
            ('Gridworld', Gridworld),
            ('TrainingSet', TrainingSet),
        )
        for name, problem_class in problem_classes:
            if problem_type == name:
                return problem_class.from_dict(data)
        raise ValueError(f"Unknown problem type: {problem_type}")

    @staticmethod
    def create_problem_from_dict(data, problem_type='Simple2DProblem'):
        """ Validates the given data-dictionary, and creates a problem of the given type from it.

        Raises ValueError if required entries are missing or inconsistent,
        and NotImplementedError for problem types that are not supported.
        """
        import numpy as np

        def has_all(*keys):
            # True if every one of the given keys is present in data
            return all(key in data.keys() for key in keys)

        if problem_type == 'Simple2DProblem':
            if not has_all('board', 'costs', 'start_state', 'end_state'):
                raise ValueError('data dict must contain: "board", "costs", "start_state" and "end_state"')
            if np.array(data['board']).shape != np.array(data['costs']).shape:
                raise ValueError('data["board"] and data["costs"] must have same shape')
            return Simple2DProblem.from_dict(data)

        if problem_type == 'TicTacToe':
            if not has_all('start_state', 'start_player'):
                raise ValueError('data dict must contain: "start_state", "start_player"')
            return TicTacToe.from_dict(data)

        if problem_type == 'Gridworld':
            if not has_all('seed', 'dones', 'rewards', 'starts'):
                raise ValueError('data dict must contain: "seed", "dones", "rewards", "starts"')
            return Gridworld.from_dict(data)

        raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
|
||||
|
||||
92
pig_lite/problem/base.py
Normal file
92
pig_lite/problem/base.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import hashlib
|
||||
|
||||
class Node(object):
    """ Node of a search tree: a state, the action that led to it, and the cost and depth stored with it. """

    def __init__(self, parent, state, action, cost, depth):
        self.parent, self.state, self.action = parent, state, action
        self.cost, self.depth = cost, depth

    def key(self):
        # if state is composed of other stuff (dict, set, ...)
        # make it a tuple containing hashable datatypes
        # (this is supposed to be overridden by subclasses)
        return self.state

    def __hash__(self):
        return hash(self.key())

    def __eq__(self, other):
        # only nodes of exactly the same type can be compared
        if type(self) is not type(other):
            raise ValueError('cannot simply compare two different node types')
        return self.key() == other.key()

    def __str__(self):
        return repr(self)

    def __repr__(self):
        template = 'Node(id:{}, parent:{}, state:{}, action:{}, cost:{}, depth:{})'
        return template.format(
            id(self), id(self.parent), self.state, self.action, self.cost, self.depth
        )

    def get_action_sequence(self):
        """ Returns the actions on the path from the root to this node. """
        actions = []
        node = self
        # walk up to the root; the root itself contributes no action
        while node.parent is not None:
            actions.append(node.action)
            node = node.parent
        actions.reverse()
        return actions

    def get_action_sequence_hash(self):
        """ Returns the SHA-256 hex digest of the ','-joined action sequence. """
        joined = ','.join(str(action) for action in self.get_action_sequence())
        return hashlib.sha256(joined.encode('UTF-8')).hexdigest()

    def pretty_print(self):
        """ Prints state, action sequence, cost and depth of this node. """
        print(f"state {self.state} was reached following the sequence {self.get_action_sequence()} (cost: {self.cost}, depth: {self.depth})")
|
||||
|
||||
|
||||
class Problem(object):
    """ Interface of a search problem; subclasses have to implement everything that raises NotImplementedError. """

    def get_number_of_expanded_nodes(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_start_node(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_end_node(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def is_end(self, node):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def action_cost(self, state, action):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def successors(self, node):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def to_json(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def visualize(self, **kwargs):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_action_sequence(self, end: Node):
        """ Returns the action sequence leading to 'end', or an empty list if 'end' is None. """
        return list() if end is None else end.get_action_sequence()

    @staticmethod
    def from_json(jsonstring):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    @staticmethod
    def get_minimum_problem_size():
        """ To be implemented by subclasses. """
        raise NotImplementedError()
|
||||
529
pig_lite/problem/simple_2d.py
Normal file
529
pig_lite/problem/simple_2d.py
Normal file
@@ -0,0 +1,529 @@
|
||||
from pig_lite.problem.base import Problem, Node
|
||||
from pig_lite.instance_generation import enc
|
||||
import json
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
from matplotlib.colors import TABLEAU_COLORS, XKCD_COLORS
|
||||
|
||||
class BaseLevel():
    """ Base class of all levels; subclasses fill field, costs, start and end in initialize_level(). """

    def __init__(self, rng, size) -> None:
        self.rng, self.size = rng, size
        # placeholders, filled by initialize_level() of the subclass
        self.field = self.costs = None
        self.start = self.end = None

        self.initialize_level()

    def initialize_level(self):
        """ To be implemented by subclasses. """
        raise NotImplementedError()

    def get_field(self):
        """ Returns self.field. """
        return self.field

    def get_costs(self):
        """ Returns self.costs. """
        return self.costs

    def get_start(self):
        """ Returns self.start. """
        return self.start

    def get_end(self):
        """ Returns self.end. """
        return self.end
|
||||
|
||||
|
||||
class MazeLevel(BaseLevel):
|
||||
# this method generates a random maze according to prim's randomized
|
||||
# algorithm
|
||||
# http://en.wikipedia.org/wiki/Maze_generation_algorithm#Randomized_Prim.27s_algorithm
|
||||
|
||||
def __init__(self, rng, size):
|
||||
super().__init__(rng, size)
|
||||
|
||||
|
||||
def initialize_level(self):
|
||||
|
||||
self.field = np.full((self.size, self.size), enc.WALL, dtype=np.int8)
|
||||
self.costs = self.rng.randint(1, 5, self.field.shape, dtype=np.int8)
|
||||
|
||||
self.start = (0, 0)
|
||||
|
||||
self.deltas = [
|
||||
(0, 1),
|
||||
(0, -1),
|
||||
(1, 0),
|
||||
(-1, 0)
|
||||
]
|
||||
self.random_walk()
|
||||
end = np.where(self.field == enc.SPACE)
|
||||
self.end = (int(end[0][-1]), int(end[1][-1]))
|
||||
|
||||
self.replace_walls_with_high_cost_tiles()
|
||||
|
||||
def replace_walls_with_high_cost_tiles(self):
|
||||
# select only coordinates of walls
|
||||
walls = np.where(self.field == enc.WALL)
|
||||
|
||||
n_walls = len(walls[0])
|
||||
|
||||
# replace about a tenth of the walls...
|
||||
to_replace = self.rng.randint(0, n_walls, n_walls // 9)
|
||||
|
||||
# ... with space, but very *costly* space (it's trap!)
|
||||
for ri in to_replace:
|
||||
x, y = walls[0][ri], walls[1][ri]
|
||||
self.field[x, y] = enc.SPACE
|
||||
self.costs[x, y] = 9
|
||||
|
||||
def random_walk(self):
|
||||
frontier = list()
|
||||
|
||||
sx, sy = self.start
|
||||
self.field[sx, sy] = enc.SPACE
|
||||
frontier.extend(self.get_walls(self.start))
|
||||
|
||||
while len(frontier) > 0:
|
||||
current, opposing = frontier[self.rng.randint(len(frontier))]
|
||||
|
||||
cx, cy = current
|
||||
ox, oy = opposing
|
||||
if self.field[ox, oy] == enc.WALL:
|
||||
self.field[cx, cy] = enc.SPACE
|
||||
self.field[ox, oy] = enc.SPACE
|
||||
frontier.extend(self.get_walls(opposing))
|
||||
else:
|
||||
frontier.remove((current, opposing))
|
||||
|
||||
def in_bounds(self, position):
|
||||
x, y = position
|
||||
return x >= 0 and y >= 0 and x < self.size and y < self.size
|
||||
|
||||
def get_walls(self, position):
    """
    collect the walls around `position` that could be opened

    for every direction in `self.deltas` the neighbouring tile and the
    tile behind it are looked at. a (neighbour, opposing) pair is
    returned when the neighbour is on the board and is a wall, and the
    opposing tile is on the board as well
    """
    base_x, base_y = position
    found = []
    for step_x, step_y in self.deltas:
        near = (base_x + step_x, base_y + step_y)
        far = (base_x + 2 * step_x, base_y + 2 * step_y)

        # the bounds check has to come before the field lookup
        if not self.in_bounds(near):
            continue
        if self.field[near[0], near[1]] != enc.WALL:
            continue
        if not self.in_bounds(far):
            continue

        found.append((near, far))
    return found
|
||||
|
||||
|
||||
# this is code taken from
# https://github.com/dandrino/terrain-erosion-3-ways/blob/master/util.py
# Copyright (c) 2018 Daniel Andrino
# (project is MIT licensed)
def fbm(shape, p, lower=-np.inf, upper=np.inf):
    """
    generate 2D noise by shaping the spectrum of random phase noise

    the radial frequency of every FFT bin is raised to the power `p`,
    bins whose radial frequency is not strictly between `lower` and
    `upper` are zeroed, and the result is used as an envelope for the
    spectrum of uniformly random phase noise.

    shape: (rows, columns) of the array to generate (need not be square)
    p: exponent applied to the radial frequency (e.g. -2)
    lower, upper: exclusive band limits on the radial frequency

    returns a real-valued array of the given shape. the constant (zero
    frequency) component is removed. the noise is drawn from the global
    `np.random` state.
    """
    freqs = tuple(np.fft.fftfreq(n, d=1.0 / n) for n in shape)

    # indexing='ij' keeps the grid in the same (rows, columns) layout as
    # `shape`; the default 'xy' layout transposes it, which only works
    # for square shapes
    freq_radial = np.hypot(*np.meshgrid(*freqs, indexing='ij'))

    # the zero frequency is masked out of the power (p may be negative);
    # the explicit `out` makes sure the masked entry is 0 rather than
    # uninitialised memory
    envelope = np.power(
        freq_radial, p,
        out=np.zeros_like(freq_radial),
        where=freq_radial != 0
    )
    envelope *= (freq_radial > lower) & (freq_radial < upper)
    envelope[0][0] = 0.0

    phase_noise = np.exp(2j * np.pi * np.random.rand(*shape))
    return np.real(np.fft.ifft2(np.fft.fft2(phase_noise) * envelope))
|
||||
|
||||
|
||||
class TerrainLevel(BaseLevel):
    """
    open level with noise-based step costs: a wall runs along the
    anti-diagonal of the board and a few of its tiles are opened again
    """

    def __init__(self, rng, size):
        super().__init__(rng, size)

    def initialize_level(self):
        """
        create the field, the costs, the start/end positions and the
        diagonal wall with its openings
        """
        side = self.size
        self.field = np.full((side, side), enc.SPACE, dtype=np.int8)

        # shift the noise so its minimum is 0, scale its maximum to 1,
        # then map it to integer costs starting at 1
        noise = fbm(self.field.shape, -2)
        noise -= noise.min()
        noise /= noise.max()
        self.costs = (noise * 9 + 1).astype(int)

        self.start = (0, 0)
        self.end = (side - 1, side - 1)

        # wall along the anti-diagonal, from (0, side - 1) to (side - 1, 0)
        xs = np.arange(side)
        self.field[xs, side - 1 - xs] = enc.WALL

        self.replace_one_or_more_walls()

    def replace_one_or_more_walls(self):
        """
        open randomly chosen wall tiles again. indices are drawn with
        repetition, so the same tile may be picked more than once
        """
        # coordinates of every wall tile
        wall_xs, wall_ys = np.where(self.field == enc.WALL)
        wall_count = len(wall_xs)

        n_picks = self.rng.randint(1, max(2, wall_count // 5))
        picks = self.rng.randint(0, wall_count, n_picks)

        self.field[wall_xs[picks], wall_ys[picks]] = enc.SPACE
|
||||
|
||||
|
||||
class RoomLevel(BaseLevel):
    """
    level made of rooms: the board is split recursively by walls that
    each get a single door, and a few 'danger' spots raise the step
    costs of the tiles around them
    """

    def __init__(self, rng, size):
        super().__init__(rng, size)

    def initialize_level(self):
        """
        create the field, the costs and the start/end positions

        raises RuntimeError if start and end end up on the same tile
        """
        self.field = np.full((self.size, self.size), enc.SPACE, dtype=np.int8)
        self.costs = np.ones_like(self.field, dtype=np.float32)

        k = 1
        self.subdivide(self.field.view(), self.costs.view(), k, 0, 0)

        self._repair_dead_ends()

        spaces = np.where(self.field == enc.SPACE)

        self._add_danger_costs(spaces)
        self._rescale_costs()

        # first and last walkable tile in row-major order
        self.start = (int(spaces[0][0]), int(spaces[1][0]))
        self.end = (int(spaces[0][-1]), int(spaces[1][-1]))

        if self.start == self.end:
            raise RuntimeError('should never happen')

    def _repair_dead_ends(self):
        """
        open up dead ends: an inner walkable tile with three or more
        wall neighbours gets all four of its neighbours opened
        """
        # such a *crutch*!
        # the field is modified while it is scanned, so the scan order
        # matters and must stay as it is
        for x in range(1, self.size - 1):
            for y in range(1, self.size - 1):
                if self.field[x, y] != enc.SPACE:
                    continue

                neighbours = ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1))

                # count the walls explicitly instead of summing the raw
                # tile encodings, so this does not depend on the
                # numeric values behind enc.WALL / enc.SPACE
                n_walls = sum(
                    1 for nx, ny in neighbours
                    if self.field[nx, ny] == enc.WALL
                )
                if n_walls >= 3:
                    for nx, ny in neighbours:
                        self.field[nx, ny] = enc.SPACE

    def _add_danger_costs(self, spaces):
        """
        pick between 3 and 6 distinct walkable tiles as danger centres
        and add 1 / (const + distance to the centre) to the cost of
        every tile on the board, for each centre

        spaces: coordinate arrays of the walkable tiles, as returned by
        np.where
        """
        n_spaces = len(spaces[0])

        n_danger = self.rng.randint(3, 7)
        dangers = self.rng.choice(range(n_spaces), n_danger, replace=False)

        # xs[i, j] == i and ys[i, j] == j for every tile of the board
        xs, ys = np.indices(self.field.shape)

        for di in dangers:
            # `di` is an index into the list of walkable tiles, so the
            # centre has to be looked up there (it is *not* a flat
            # index into the board)
            rx, ry = spaces[0][di], spaces[1][di]
            const = max(1., self.rng.randint(self.size // 5, self.size // 2))
            distance = np.sqrt((xs - rx) ** 2 + (ys - ry) ** 2)
            self.costs += 1. / (const + distance)

    def _rescale_costs(self):
        """
        shift the costs so their minimum is 0, scale their maximum to 1,
        then map them to integer costs starting at 1
        """
        self.costs = self.costs - self.costs.min()
        self.costs = self.costs / self.costs.max()
        self.costs = self.costs * 9
        self.costs = self.costs + 1
        self.costs = self.costs.astype(int)

    def subdivide(self, current, costs, k, d, previous_door):
        """
        recursively split the view `current` into rooms

        a wall is drawn across `current` at a random first-axis index
        (different from `previous_door`), one random tile of that wall
        is opened as a door, and both halves are subdivided again. the
        halves are passed on transposed, so the split direction
        alternates with the recursion depth.

        current: view into the field that is modified in place
        costs: matching view into the costs (passed along, not used)
        k: margin kept free at the borders when picking split and door
        d: current recursion depth
        previous_door: door index of the parent split
        """
        w, h = current.shape

        # beyond depth 2 there is a 1 in 10 chance to stop early
        random_stop = self.rng.randint(0, 10) == 0 and d > 2
        if w <= 2 * k + 1 or h <= 2 * k + 1 or random_stop:
            return

        # never put the wall at the index of the parent's door
        split = previous_door
        while split == previous_door:
            split = self.rng.randint(k, w - k)

        current[split, :] = enc.WALL
        door = self.rng.randint(k, h - k)
        current[split, door] = enc.SPACE

        self.subdivide(
            current[:split, :].T,
            costs[:split, :].T,
            k,
            d + 1,
            door
        )
        self.subdivide(
            current[split + 1:, :].T,
            costs[split + 1:, :].T,
            k,
            d + 1,
            door
        )
|
||||
|
||||
|
||||
class Simple2DProblem(Problem):
    """
    the states are the positions on the board that the agent can walk on

    `board` and `costs` are 2D arrays that are indexed as [x, y], and a
    state is an (x, y) tuple. tiles equal to `enc.WALL` can not be
    entered; stepping onto a tile costs `costs[x, y]`.
    """

    # the available actions and the (dx, dy) each of them adds to a state
    ACTIONS_DELTA = OrderedDict([
        ('R', (+1, 0)),
        ('U', (0, -1)),
        ('D', (0, +1)),
        ('L', (-1, 0)),
    ])

    def __init__(self, board, costs, start, end):
        self.board = board
        self.costs = costs
        self.start_state = start
        self.end_state = end
        # number of calls to `successors` since the last `reset`
        self.n_expands = 0

    def get_start_node(self):
        """return a fresh root node for the start state"""
        return Node(None, self.start_state, None, 0, 0)

    def get_end_node(self):
        """return a fresh, parentless node for the end state"""
        return Node(None, self.end_state, None, 0, 0)

    def is_end(self, node):
        """tell whether `node` is at the end state"""
        return node.state == self.end_state

    def action_cost(self, state, action):
        """return the cost of taking `action` in `state`"""
        # for the MazeProblem, the cost of any action
        # is stored at the coordinates of the successor state,
        # and represents the cost of 'stepping onto' this
        # position on the board
        sx, sy = self.__delta_state(state, action)
        return self.costs[sx, sy]

    def successor(self, node, action):
        """
        return the node that is reached by taking `action` in `node`.
        for an impossible move the returned node has the same state as
        `node`.
        """
        # determine the next state
        successor_state = self.__delta_state(node.state, action)
        if successor_state is None:
            # the current __delta_state never returns None, this is
            # only a guard
            return None

        # determine what it would cost to take this action in this state
        cost = self.action_cost(node.state, action)

        return Node(
            node,
            successor_state,
            action,
            node.cost + cost,
            node.depth + 1
        )

    def get_number_of_expanded_nodes(self):
        """return how often `successors` was called since the last reset"""
        return self.n_expands

    def reset(self):
        """reset the counter of expanded nodes"""
        self.n_expands = 0

    def successors(self, node):
        """
        expand `node`: return the successor nodes for all actions,
        leaving out those that compare equal to `node` itself
        """
        self.n_expands += 1
        successor_nodes = []
        for action in self.ACTIONS_DELTA:
            succ = self.successor(node, action)
            if succ is not None and succ != node:
                successor_nodes.append(succ)
        return successor_nodes

    def to_json(self):
        """serialize the problem into a JSON string (see `from_json`)"""
        return json.dumps(dict(
            type=self.__class__.__name__,
            board=self.board.tolist(),
            costs=self.costs.tolist(),
            start_state=self.start_state,
            end_state=self.end_state
        ))

    @staticmethod
    def draw_nodes(fig, ax, name, node_collection, color, marker):
        """scatter-plot the states of all nodes in `node_collection`"""
        states = np.array([node.state for node in node_collection])
        if len(states) > 0:
            ax.scatter(states[:, 0], states[:, 1], color=color, label=name, marker=marker)

    @staticmethod
    def plot_nodes(fig, ax, nodes):
        """
        draw several node collections and add a legend

        nodes: list of (name, marker, node_collection) triples, which
        get their colors from TABLEAU_COLORS, or list of
        (name, marker, node_collection, color) entries
        """
        if len(nodes) > 0:
            if len(nodes[0]) == 3:
                for (name, marker, node_collection), color in zip(nodes, TABLEAU_COLORS):
                    if len(node_collection) > 0:
                        Simple2DProblem.draw_nodes(fig, ax, name, node_collection, color, marker)
            else:
                for name, marker, node_collection, color in nodes:
                    if len(node_collection) > 0:
                        Simple2DProblem.draw_nodes(fig, ax, name, node_collection, color, marker)

        ax.legend(
            bbox_to_anchor=(0.5, -0.03),
            loc='upper center',
        )

    def plot_sequences(self, fig, ax, sequences):
        """
        draw every (name, action_sequence) pair in `sequences` as a path
        that begins at the start state, and add a legend
        """
        start_node = self.get_start_node()
        for (name, action_sequence), color in zip(sequences, XKCD_COLORS):
            self.draw_path(fig, ax, name, start_node, action_sequence, color)

        ax.legend(
            bbox_to_anchor=(0.5, -0.03),
            loc='upper center',
        )

    def draw_path(self, fig, ax, name, start_node, action_sequence, color):
        """
        draw the path that results from applying `action_sequence`,
        beginning at `start_node`, as arrows. the legend label shows the
        name, the length of the sequence and the cost of the last node.

        returns the quiver artist
        """
        current = start_node

        # the first entry is a zero-length arrow at the start position
        xs = [current.state[0]]
        ys = [current.state[1]]
        us = [0]
        vs = [0]

        length = len(action_sequence)
        cost = 0
        for action in action_sequence:
            # every arrow begins at the state the action is taken in
            xs.append(current.state[0])
            ys.append(current.state[1])
            current = self.successor(current, action)
            dx, dy = self.ACTIONS_DELTA[action]
            us.append(dx)
            # dy is negated for the arrow's vertical component
            vs.append(-dy)
            cost = current.cost

        quiv = ax.quiver(
            xs, ys, us, vs,
            color=color,
            label='{} l:{} c:{}'.format(name, length, cost),
            scale_units='xy',
            units='xy',
            scale=1,
            headwidth=1,
            headlength=1,
            linewidth=1,
            picker=5
        )
        return quiv

    def plot_field_and_costs_aux(self, fig, show_coordinates, show_grid,
                                 field_ax=None, costs_ax=None):
        """
        draw the board and the costs as two images with colorbars

        fig: the figure to draw on
        show_coordinates: show ticks and tick labels at the top and left
        show_grid: draw grid lines between the tiles
        field_ax, costs_ax: axes to draw into; axes that are not given
        are created on `fig`

        returns (field_ax, costs_ax)
        """
        if field_ax is None:
            # create the axes on the figure that was passed in, and not
            # on whatever figure happens to be the current one
            ax = field_ax = fig.add_subplot(121)
        else:
            ax = field_ax

        ax.set_title('The field')
        # the arrays are indexed [x, y], images are drawn [row, column]
        im = ax.imshow(self.board.T, cmap='gray_r')

        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0)
        cbar = fig.colorbar(im, cax=cax, orientation='vertical')
        cbar.set_ticks([0, 1])
        cbar.set_ticklabels([0, 1])

        if costs_ax is None:
            ax = costs_ax = fig.add_subplot(122, sharex=ax, sharey=ax)
        else:
            ax = costs_ax

        ax.set_title('The costs (for stepping on a tile)')
        im = ax.imshow(self.costs.T, cmap='viridis')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0)
        cbar = fig.colorbar(im, cax=cax, orientation='vertical')
        ticks = np.arange(self.costs.min(), self.costs.max() + 1)
        cbar.set_ticks(ticks)
        cbar.set_ticklabels(ticks)

        # x runs along the first axis of the board, y along the second
        n_x, n_y = self.board.shape

        for ax in [field_ax, costs_ax]:
            ax.tick_params(
                top=show_coordinates,
                left=show_coordinates,
                labelleft=show_coordinates,
                labeltop=show_coordinates,
                right=False,
                bottom=False,
                labelbottom=False
            )

            # Major ticks
            ax.set_xticks(np.arange(0, n_x, 1))
            ax.set_yticks(np.arange(0, n_y, 1))

            # Minor ticks
            ax.set_xticks(np.arange(-.5, n_x, 1), minor=True)
            ax.set_yticks(np.arange(-.5, n_y, 1), minor=True)

        if show_grid:
            for color, ax in zip(['m', 'w'], [field_ax, costs_ax]):
                # Gridlines based on minor ticks
                ax.grid(which='minor', color=color, linestyle='-', linewidth=1)

        return field_ax, costs_ax

    def visualize(self, sequences=None, show_coordinates=False, show_grid=False, plot_filename=None):
        """
        plot the board and the costs, mark start and end, and draw the
        given action sequences on top

        sequences: optional list of (name, action_sequence) pairs
        plot_filename: if given, the plot is saved to this file and the
        figure is closed; otherwise the plot is shown
        """
        nodes = [
            ('start', 'o', [self.get_start_node()]),
            ('end', 'o', [self.get_end_node()])
        ]

        fig = plt.figure(figsize=(10, 7))
        field_ax, costs_ax = self.plot_field_and_costs_aux(fig, show_coordinates, show_grid)
        if sequences is not None and len(sequences) > 0:
            self.plot_sequences(fig, field_ax, sequences)
            self.plot_sequences(fig, costs_ax, sequences)

        if nodes is not None and len(nodes) > 0:
            Simple2DProblem.plot_nodes(fig, field_ax, nodes)

        plt.tight_layout()
        if plot_filename is not None:
            plt.savefig(plot_filename)
            plt.close(fig)
        else:
            plt.show()

    @staticmethod
    def from_json(jsonstring):
        """create a problem from a JSON string as written by `to_json`"""
        return Simple2DProblem.from_dict(json.loads(jsonstring))

    @staticmethod
    def from_dict(data):
        """
        create a problem from a dict with the keys 'board', 'costs',
        'start_state' and 'end_state'
        """
        return Simple2DProblem(
            np.array(data['board']),
            np.array(data['costs']),
            tuple(data['start_state']),
            tuple(data['end_state'])
        )

    def __delta_state(self, state, action):
        """
        return the state that is reached by taking `action` in `state`;
        this is `state` itself if the move is not possible
        """
        # the old state's coordinates
        x, y = state

        # the deltas for each coordinate
        dx, dy = self.ACTIONS_DELTA[action]

        # compute the coordinates of the next state
        sx = x + dx
        sy = y + dy

        if self.__on_board(sx, sy) and self.__walkable(sx, sy):
            # (sx, sy) is a *valid* state if it is on the board
            # and there is no wall where we want to go
            return sx, sy
        else:
            # an impossible move leaves the agent where it is (a
            # self-edge). up until assignment 1 this returned None,
            # which had no consequences on the correctness of the
            # algorithms, but the explanations and the self-edges
            # were wrong
            return x, y

    def __on_board(self, x, y):
        """tell whether (x, y) is inside the board"""
        # use the extent of each axis, so that boards that are not
        # quadratic are handled as well
        n_x, n_y = self.board.shape
        return x >= 0 and x < n_x and y >= 0 and y < n_y

    def __walkable(self, x, y):
        """tell whether the tile at (x, y) is anything but a wall"""
        return self.board[x, y] != enc.WALL
|
||||
15
shell.nix
Normal file
15
shell.nix
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
pkgs ? import <nixpkgs> { },
|
||||
}:
|
||||
|
||||
pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
python3
|
||||
python3Packages.notebook
|
||||
python3Packages.numpy
|
||||
python3Packages.matplotlib
|
||||
graphviz
|
||||
python3Packages.networkx
|
||||
python3Packages.pydot
|
||||
];
|
||||
}
|
||||
1797
uninformed_search.ipynb
Normal file
1797
uninformed_search.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user