initial commit
This commit is contained in:
576
.ipynb_checkpoints/introducing_pig-checkpoint.ipynb
Normal file
576
.ipynb_checkpoints/introducing_pig-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
17
boards/tiny0.json
Normal file
17
boards/tiny0.json
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
{"type": "Simple2DProblem",
|
||||||
|
"board": [
|
||||||
|
[0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 0, 1, 0],
|
||||||
|
[0, 1, 0, 0, 0]
|
||||||
|
],
|
||||||
|
"costs": [
|
||||||
|
[4, 4, 3, 1, 0],
|
||||||
|
[3, 4, 2, 2, 2],
|
||||||
|
[4, 3, 3, 4, 0],
|
||||||
|
[4, 4, 3, 3, 0],
|
||||||
|
[4, 3, 1, 1, 2]
|
||||||
|
],
|
||||||
|
"start_state": [0, 0],
|
||||||
|
"end_state": [4, 4]}
|
||||||
576
introducing_pig.ipynb
Normal file
576
introducing_pig.ipynb
Normal file
File diff suppressed because one or more lines are too long
1
pig_lite/.gitignore
vendored
Normal file
1
pig_lite/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__pycache__
|
||||||
8
pig_lite/.idea/.gitignore
generated
vendored
Normal file
8
pig_lite/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
24
pig_lite/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
24
pig_lite/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="11">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="jupyter" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="umap-learn" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="matplotlib" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="tqdm" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="seaborn" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="captum" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="upsilonconf" />
|
||||||
|
<item index="8" class="java.lang.String" itemvalue="pytorch" />
|
||||||
|
<item index="9" class="java.lang.String" itemvalue="torchvision" />
|
||||||
|
<item index="10" class="java.lang.String" itemvalue="scipy" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
6
pig_lite/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
pig_lite/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
pig_lite/.idea/misc.xml
generated
Normal file
7
pig_lite/.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.10" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
pig_lite/.idea/modules.xml
generated
Normal file
8
pig_lite/.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/pig_lite.iml" filepath="$PROJECT_DIR$/.idea/pig_lite.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
12
pig_lite/.idea/pig_lite.iml
generated
Normal file
12
pig_lite/.idea/pig_lite.iml
generated
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
6
pig_lite/.idea/vcs.xml
generated
Normal file
6
pig_lite/.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
3
pig_lite/README.md
Normal file
3
pig_lite/README.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# pig_lite
|
||||||
|
|
||||||
|
This is PIG (=Problem Instance Generator) Lite, a simplified and cleaned up version of the framework previously used for the AI assignments.
|
||||||
0
pig_lite/bayesian_net/__init__.py
Normal file
0
pig_lite/bayesian_net/__init__.py
Normal file
154
pig_lite/bayesian_net/bayesian_net.py
Normal file
154
pig_lite/bayesian_net/bayesian_net.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
#matplotlib.use('TkAgg')
|
||||||
|
|
||||||
|
|
||||||
|
class BayesianNode:
|
||||||
|
""" Building stone for BayesianNet class. Represents conditional probability distribution
|
||||||
|
for a boolean random variable, P(X | parents). """
|
||||||
|
def __init__(self, X: str, parents: str, cpt: dict = None):
|
||||||
|
"""
|
||||||
|
X: String describing variable name
|
||||||
|
|
||||||
|
parents: String containing parent variable names, separated with a whitespace
|
||||||
|
|
||||||
|
cpt: dict that contains the distribution P(X=true | parent1=v1, parent2=v2...).
|
||||||
|
Dict should be structured as follows: {(v1, v2, ...): p, ...}, and each key must have
|
||||||
|
as many values as there are parents. Values (v1, v2, ...) must be True/False.
|
||||||
|
"""
|
||||||
|
if not isinstance(X, str) or not isinstance(parents, str):
|
||||||
|
raise ValueError("Use valid arguments - X and parents have to be strings (but at least one is not)!")
|
||||||
|
self.rand_var = X
|
||||||
|
self.parents = parents.split()
|
||||||
|
self.children = []
|
||||||
|
|
||||||
|
# in case of 0 or 1 parent, fix tuples first
|
||||||
|
if cpt and isinstance(cpt, (float, int)):
|
||||||
|
cpt = {(): cpt}
|
||||||
|
elif cpt and isinstance(cpt, dict):
|
||||||
|
if isinstance(list(cpt.keys())[0], bool):
|
||||||
|
# only one parent
|
||||||
|
cpt = {(k, ): v for k, v in cpt.items()}
|
||||||
|
elif cpt:
|
||||||
|
raise ValueError("Define cpt with a valid data type (dict, or int).")
|
||||||
|
# check format of cpt dict
|
||||||
|
if cpt:
|
||||||
|
for val, p in cpt.items():
|
||||||
|
assert isinstance(val, tuple) and len(val) == len(self.parents)
|
||||||
|
assert all(isinstance(v, bool) for v in val)
|
||||||
|
assert 0 <= p <= 1
|
||||||
|
|
||||||
|
self.cpt = cpt
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
""" String representation of Bayesian Node. """
|
||||||
|
return repr((self.rand_var, ' '.join(["parent(s):"] + self.parents)))
|
||||||
|
|
||||||
|
def cond_probability(self, value: bool, event: dict):
|
||||||
|
"""
|
||||||
|
Returns conditional probability P(X=value | event) for an atomic event,
|
||||||
|
i.e. where each parent needs to be assigned a value.
|
||||||
|
value: bool (value of this random variable)
|
||||||
|
event: dict, assigning a value to each parent variable
|
||||||
|
"""
|
||||||
|
assert isinstance(value, bool)
|
||||||
|
if self.cpt:
|
||||||
|
prob_true = self.cpt[self.get_event_values(event)]
|
||||||
|
return prob_true if value else 1 - prob_true
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_event_values(self, event: dict):
|
||||||
|
""" Given an event (dict), returns tuple of values for all parents. """
|
||||||
|
return tuple(event[p] for p in self.parents)
|
||||||
|
|
||||||
|
|
||||||
|
class BayesianNet:
|
||||||
|
""" Bayesian Network class for boolean random variables. Consists of BayesianNode-s. """
|
||||||
|
def __init__(self, node_specs: list):
|
||||||
|
"""
|
||||||
|
Creates BayesianNet with given node_specs. Nodes should be in causal order (parents before children).
|
||||||
|
node_specs should be list of parameters for BayesianNode class.
|
||||||
|
"""
|
||||||
|
self.nodes = []
|
||||||
|
self.rand_vars = []
|
||||||
|
for spec in node_specs:
|
||||||
|
self.add_node(spec)
|
||||||
|
|
||||||
|
def add_node(self, node_spec):
|
||||||
|
""" Creates a BayesianNode and adds it to the net, if the variable does *not*, and the parents do exist. """
|
||||||
|
node = BayesianNode(*node_spec)
|
||||||
|
if node.rand_var in self.rand_vars:
|
||||||
|
raise ValueError("Variable {} already exists in network, cannot be defined twice!".format(node.rand_var))
|
||||||
|
if not all((parent in self.rand_vars) for parent in node.parents):
|
||||||
|
raise ValueError("Parents do not all exist yet! Make sure to first add all parent nodes.")
|
||||||
|
self.nodes.append(node)
|
||||||
|
self.rand_vars.append(node.rand_var)
|
||||||
|
for parent in node.parents:
|
||||||
|
self.get_node_for_name(parent).children.append(node)
|
||||||
|
|
||||||
|
def get_node_for_name(self, node_name):
|
||||||
|
""" Given the name of a random variable, returns the according BayesianNode of this network. """
|
||||||
|
for n in self.nodes:
|
||||||
|
if n.rand_var == node_name:
|
||||||
|
return n
|
||||||
|
|
||||||
|
raise ValueError("The variable {} does not exist in this network!".format(node_name))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
""" String representation of this Bayesian Network. """
|
||||||
|
return "BayesianNet:\n{0!r}".format(self.nodes)
|
||||||
|
|
||||||
|
def _get_depth(self, rand_var):
|
||||||
|
""" Given random variable, returns "depth" of node in graph for plotting. """
|
||||||
|
node = self.get_node_for_name(rand_var)
|
||||||
|
if len(node.parents) == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return max([self._get_depth(p) for p in node.parents]) + 1
|
||||||
|
|
||||||
|
def draw(self, title, save_path=None):
|
||||||
|
""" Draws the BN with networkx. Requires title for plot. """
|
||||||
|
plt.figure(figsize=(14, 8))
|
||||||
|
nx_bn = nx.DiGraph()
|
||||||
|
nx_bn.add_nodes_from(self.rand_vars)
|
||||||
|
pos = {rand_var: (10, 10) for rand_var in self.rand_vars}
|
||||||
|
for rand_var in self.rand_vars:
|
||||||
|
node = self.get_node_for_name(rand_var)
|
||||||
|
for c in node.children:
|
||||||
|
nx_bn.add_edge(rand_var, c.rand_var)
|
||||||
|
pos.update({c.rand_var: (pos[c.rand_var][0], pos[c.rand_var][1] - 3)})
|
||||||
|
|
||||||
|
depths = {rand_var: self._get_depth(rand_var) for rand_var in self.rand_vars}
|
||||||
|
_, counts = np.unique(list(depths.values()), return_counts=True)
|
||||||
|
xs = [list(np.linspace(6, 14, c)) if c > 1 else [10] for c in counts]
|
||||||
|
pos = {rand_var: (xs[depths[rand_var]].pop(), 10 - depths[rand_var] * 3) for rand_var in self.rand_vars}
|
||||||
|
|
||||||
|
nx.set_node_attributes(nx_bn, pos, 'pos')
|
||||||
|
nx.draw_networkx(nx_bn, arrows=True, pos=nx.get_node_attributes(nx_bn, "pos"),
|
||||||
|
node_shape="o", node_color="white", node_size=7000, edgecolors="gray")
|
||||||
|
plt.title(title)
|
||||||
|
plt.box(False)
|
||||||
|
plt.margins(0.3)
|
||||||
|
plt.tight_layout()
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=400)
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
T = True
|
||||||
|
F = False
|
||||||
|
bn = BayesianNet([
|
||||||
|
('Burglary', '', 0.001),
|
||||||
|
('Earthquake', '', {(): 0.002}),
|
||||||
|
('Alarm', 'Burglary Earthquake',
|
||||||
|
{(T, T): 0.95, (T, F): 0.94, (F, T): 0.29, (F, F): 0.001}),
|
||||||
|
('JohnCalls', 'Alarm', {T: 0.90, F: 0.05}),
|
||||||
|
('MaryCalls', 'Alarm', {T: 0.70, F: 0.01})
|
||||||
|
])
|
||||||
|
bn.draw("")
|
||||||
0
pig_lite/datastructures/__init__.py
Normal file
0
pig_lite/datastructures/__init__.py
Normal file
61
pig_lite/datastructures/priority_queue.py
Normal file
61
pig_lite/datastructures/priority_queue.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import heapq
|
||||||
|
from functools import total_ordering
|
||||||
|
|
||||||
|
|
||||||
|
# this annotation saves us some implementation work
|
||||||
|
@total_ordering
|
||||||
|
class Item(object):
|
||||||
|
def __init__(self, insertion, priority, value):
|
||||||
|
self.insertion = insertion
|
||||||
|
self.priority = priority
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
# if the decision "self < other" can be done
|
||||||
|
# based on the priority, do that
|
||||||
|
if self.priority < other.priority:
|
||||||
|
return True
|
||||||
|
elif self.priority == other.priority:
|
||||||
|
# in case the priorities are equal, we
|
||||||
|
# fall back on the insertion order,
|
||||||
|
# which establishes a total ordering
|
||||||
|
return self.insertion < other.insertion
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.priority == other.priority and self.insertion == other.insertion
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '({}, {}, {})'.format(self.priority, self.insertion, self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityQueue(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.insertion = 0
|
||||||
|
self.heap = []
|
||||||
|
|
||||||
|
def has_elements(self):
|
||||||
|
return len(self.heap) > 0
|
||||||
|
|
||||||
|
def put(self, priority, value):
|
||||||
|
heapq.heappush(self.heap, Item(self.insertion, priority, value))
|
||||||
|
self.insertion += 1
|
||||||
|
|
||||||
|
def get(self, include_priority=False):
|
||||||
|
item = heapq.heappop(self.heap)
|
||||||
|
if include_priority:
|
||||||
|
return item.priority, item.value
|
||||||
|
else:
|
||||||
|
return item.value
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter([item.value for item in self.heap])
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ('PriorityQueue [' + ','.join((str(item.value) for item in self.heap)) + ']')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.heap)
|
||||||
27
pig_lite/datastructures/queue.py
Normal file
27
pig_lite/datastructures/queue.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class Queue(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.d = deque()
|
||||||
|
|
||||||
|
def put(self, v):
|
||||||
|
self.d.append(v)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return self.d.popleft()
|
||||||
|
|
||||||
|
def has_elements(self):
|
||||||
|
return len(self.d) > 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.d)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ('Queue [' + ','.join((str(item) for item in self.d)) + ']')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.d)
|
||||||
21
pig_lite/datastructures/stack.py
Normal file
21
pig_lite/datastructures/stack.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class Stack(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.d = deque()
|
||||||
|
|
||||||
|
def put(self, v):
|
||||||
|
self.d.append(v)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return self.d.pop()
|
||||||
|
|
||||||
|
def has_elements(self):
|
||||||
|
return len(self.d) > 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.d)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return ('Stack [' + ','.join((str(item) for item in self.d)) + ']')
|
||||||
61
pig_lite/decision_tree/dt_base.py
Normal file
61
pig_lite/decision_tree/dt_base.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from pig_lite.decision_tree.dt_node import DecisionTreeNodeBase
|
||||||
|
import scipy.stats as stats
|
||||||
|
|
||||||
|
def entropy(y: list):
|
||||||
|
"""
|
||||||
|
Compute the entropy of a binary label distribution.
|
||||||
|
|
||||||
|
This function calculates the entropy of a binary classification label list `y` as a wrapper
|
||||||
|
around `scipy.stats.entropy`. It assumes the labels are binary (0 or 1) and computes the
|
||||||
|
proportion of positive labels (1s) to calculate the entropy.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
y : list
|
||||||
|
A list of binary labels (0 or 1).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
The entropy of the label distribution. If the list is empty, returns 0.0.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
- Entropy is calculated using the formula:
|
||||||
|
H = -p*log2(p) - (1-p)*log2(1-p)
|
||||||
|
where `p` is the proportion of positive labels (1s).
|
||||||
|
- If `y` is empty, entropy is defined as 0.0.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> entropy([0, 0, 1, 1])
|
||||||
|
1.0
|
||||||
|
|
||||||
|
>>> entropy([1, 1, 1, 1])
|
||||||
|
0.0
|
||||||
|
|
||||||
|
>>> entropy([])
|
||||||
|
0.0
|
||||||
|
"""
|
||||||
|
if len(y) == 0: return 0.0
|
||||||
|
positive = sum(y) / len(y)
|
||||||
|
return stats.entropy([positive, 1 - positive], base=2)
|
||||||
|
|
||||||
|
# these two dummy classes are only used so we can import them and load trees from a pickle file before they are implemented by the students
|
||||||
|
class DecisionTree():
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_height(self, node):
|
||||||
|
if node is None:
|
||||||
|
return 0
|
||||||
|
return max(self.get_height(node.left_child), self.get_height(node.right_child)) + 1
|
||||||
|
|
||||||
|
def print(self):
|
||||||
|
if self.root is not None:
|
||||||
|
height = self.get_height(self.root)
|
||||||
|
self.root.print_tree(height)
|
||||||
|
|
||||||
|
class DecisionTreeNode(DecisionTreeNodeBase):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
70
pig_lite/decision_tree/dt_node.py
Normal file
70
pig_lite/decision_tree/dt_node.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from pig_lite.datastructures.queue import Queue
|
||||||
|
|
||||||
|
class DecisionTreeNodeBase():
|
||||||
|
def __init__(self):
|
||||||
|
self.label = None
|
||||||
|
self.split_point = None
|
||||||
|
self.split_feature = None
|
||||||
|
self.left_child = None
|
||||||
|
self.right_child = None
|
||||||
|
|
||||||
|
def print_node(self, height, level=1):
|
||||||
|
node_width = 10
|
||||||
|
n_spaces = 2 ** (height - level - 1) * node_width - node_width // 2
|
||||||
|
if n_spaces > 0:
|
||||||
|
text = " " * n_spaces
|
||||||
|
else:
|
||||||
|
text = ""
|
||||||
|
|
||||||
|
if self.label is None and self.split_feature is None:
|
||||||
|
return f"{text} {text}"
|
||||||
|
|
||||||
|
if self.label is not None:
|
||||||
|
text = f"{text}( {self.label} ){text}"
|
||||||
|
elif self.split_feature is not None:
|
||||||
|
text_snippet = f"(x{self.split_feature}:{self.split_point:.2f})"
|
||||||
|
if len(text_snippet) != node_width:
|
||||||
|
text_snippet = f" {text_snippet}"
|
||||||
|
text = f"{text}{text_snippet}{text}"
|
||||||
|
return text
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self.label is not None: return f"({self.label})"
|
||||||
|
|
||||||
|
str_value = f"{self.split_feature}:{self.split_point:.2f}|{self.left_child}{self.right_child}"
|
||||||
|
return str_value
|
||||||
|
|
||||||
|
def print_tree(self, height):
|
||||||
|
visited = set()
|
||||||
|
frontier = Queue()
|
||||||
|
|
||||||
|
lines = ['']
|
||||||
|
|
||||||
|
previous_level = 1
|
||||||
|
frontier.put((self, 1))
|
||||||
|
|
||||||
|
while frontier.has_elements():
|
||||||
|
current, level = frontier.get()
|
||||||
|
if level > previous_level:
|
||||||
|
lines.append('')
|
||||||
|
previous_level = level
|
||||||
|
lines[-1] += current.print_node(height, level)
|
||||||
|
if current not in visited:
|
||||||
|
visited.add(current)
|
||||||
|
if current.left_child is not None:
|
||||||
|
frontier.put((current.left_child, level + 1))
|
||||||
|
else:
|
||||||
|
if level < height: frontier.put((DecisionTreeNodeBase(), level + 1))
|
||||||
|
if current.right_child is not None:
|
||||||
|
frontier.put((current.right_child, level + 1))
|
||||||
|
else:
|
||||||
|
if level < height: frontier.put((DecisionTreeNodeBase(), level + 1))
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
print(line)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def split():
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
84
pig_lite/decision_tree/training_set.py
Normal file
84
pig_lite/decision_tree/training_set.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
class TrainingSet():
|
||||||
|
def __init__(self, X, y):
|
||||||
|
self.X = X
|
||||||
|
self.y = y
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(dict(
|
||||||
|
type=self.__class__.__name__,
|
||||||
|
X=self.X.tolist(),
|
||||||
|
y=self.y.tolist()
|
||||||
|
))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
data = json.loads(jsonstring)
|
||||||
|
return TrainingSet.from_dict(data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(data):
|
||||||
|
return TrainingSet(
|
||||||
|
np.array(data['X']).squeeze(),
|
||||||
|
np.array(data['y'])
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot_node_boundaries(self, node, limit_left, limit_right, limit_top, limit_bottom, max_depth, level=1):
|
||||||
|
|
||||||
|
split_point = node.split_point
|
||||||
|
limit_left_updated = limit_left
|
||||||
|
limit_right_updated = limit_right
|
||||||
|
limit_top_updated = limit_top
|
||||||
|
limit_bottom_updated = limit_bottom
|
||||||
|
|
||||||
|
if node.split_feature == 0:
|
||||||
|
if limit_bottom == limit_top:
|
||||||
|
warnings.warn('limit_bottom equals limit_top; extending by 0.1')
|
||||||
|
plt.plot([split_point, split_point], [limit_bottom - 0.1, limit_top + 0.1], color="purple", alpha=1 / level)
|
||||||
|
else:
|
||||||
|
plt.plot([split_point, split_point], [limit_bottom, limit_top], color="purple", alpha=1 / level)
|
||||||
|
limit_left_updated = split_point
|
||||||
|
limit_right_updated = split_point
|
||||||
|
|
||||||
|
else:
|
||||||
|
if limit_left == limit_right:
|
||||||
|
warnings.warn('limit_left equals limit_right; extending by 0.1')
|
||||||
|
plt.plot([limit_left - 0.1, limit_right + 0.1], [split_point, split_point], color="purple", alpha=1 / level)
|
||||||
|
else:
|
||||||
|
plt.plot([limit_left, limit_right], [split_point, split_point], color="purple", alpha=1 / level)
|
||||||
|
limit_top_updated = split_point
|
||||||
|
limit_bottom_updated = split_point
|
||||||
|
|
||||||
|
if level == max_depth:
|
||||||
|
return
|
||||||
|
if node.left_child is not None: self.plot_node_boundaries(node.left_child, limit_left, limit_right_updated,
|
||||||
|
limit_top_updated, limit_bottom, max_depth, level + 1)
|
||||||
|
if node.right_child is not None: self.plot_node_boundaries(node.right_child, limit_left_updated, limit_right, limit_top,
|
||||||
|
limit_bottom_updated, max_depth, level + 1)
|
||||||
|
|
||||||
|
def visualize(self, tree=None, max_height=None):
|
||||||
|
symbols = [["x", "o"][index] for index in self.y]
|
||||||
|
for y in set(self.y):
|
||||||
|
X = self.X[self.y == y, :]
|
||||||
|
plt.scatter(X[:, 0], X[:, 1],
|
||||||
|
color=["red", "blue"][y],
|
||||||
|
marker=symbols[y],
|
||||||
|
label="class: {}".format(y))
|
||||||
|
|
||||||
|
if tree is not None:
|
||||||
|
tree_height = tree.get_height(tree.root)
|
||||||
|
if max_height is None or max_height > tree_height:
|
||||||
|
max_height = tree_height
|
||||||
|
self.plot_node_boundaries(tree.root,
|
||||||
|
limit_left=min(self.X[:, 0]),
|
||||||
|
limit_right=max(self.X[:, 0]),
|
||||||
|
limit_top=max(self.X[:, 1]),
|
||||||
|
limit_bottom=min(self.X[:, 1]),
|
||||||
|
max_depth=max_height) # TODO: make parameterizable
|
||||||
|
|
||||||
0
pig_lite/environment/__init__.py
Normal file
0
pig_lite/environment/__init__.py
Normal file
60
pig_lite/environment/base.py
Normal file
60
pig_lite/environment/base.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class Environment:
|
||||||
|
def step(self, action):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_n_actions(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_n_states(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_flat_policy(self, policy):
|
||||||
|
flat_policy = []
|
||||||
|
for state in range(self.get_n_states()):
|
||||||
|
for action in range(self.get_n_actions()):
|
||||||
|
flat_policy.append((state, action, policy[state, action]))
|
||||||
|
return flat_policy
|
||||||
|
|
||||||
|
def get_policy_hash(self, outcome):
|
||||||
|
flat_policy = self.get_flat_policy(outcome.policy)
|
||||||
|
flat_policy_as_str = ','.join(map(str, flat_policy))
|
||||||
|
flat_policy_hash = hashlib.sha256(flat_policy_as_str.encode('UTF-8')).hexdigest()
|
||||||
|
return flat_policy_hash
|
||||||
|
|
||||||
|
|
||||||
|
class Outcome:
|
||||||
|
def __init__(self, n_episodes, policy, V, Q):
|
||||||
|
self.n_episodes = n_episodes
|
||||||
|
self.policy = policy
|
||||||
|
self.V = V
|
||||||
|
self.Q = Q
|
||||||
|
|
||||||
|
def get_n_episodes(self):
|
||||||
|
return self.n_episodes
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(dict(
|
||||||
|
type=self.__class__.__name__,
|
||||||
|
n_episodes=self.n_episodes,
|
||||||
|
policy=self.policy.tolist(),
|
||||||
|
V=self.V.tolist(),
|
||||||
|
Q=self.Q.tolist(),
|
||||||
|
))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
data = json.loads(jsonstring)
|
||||||
|
return Outcome(
|
||||||
|
data['n_episodes'],
|
||||||
|
np.array(data['policy']),
|
||||||
|
np.array(data['V']),
|
||||||
|
np.array(data['Q'])
|
||||||
|
)
|
||||||
360
pig_lite/environment/gridworld.py
Normal file
360
pig_lite/environment/gridworld.py
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pig_lite.environment.base import Environment
|
||||||
|
|
||||||
|
DELTAS = [
|
||||||
|
(-1, 0),
|
||||||
|
(+1, 0),
|
||||||
|
(0, -1),
|
||||||
|
(0, +1)
|
||||||
|
]
|
||||||
|
NAMES = [
|
||||||
|
'left',
|
||||||
|
'right',
|
||||||
|
'up',
|
||||||
|
'down'
|
||||||
|
]
|
||||||
|
|
||||||
|
def sample(rng, elements):
|
||||||
|
""" Samples an element of `elements` randomly. """
|
||||||
|
csp = np.cumsum([elm[0] for elm in elements])
|
||||||
|
idx = np.argmax(csp > rng.uniform(0, 1))
|
||||||
|
return elements[idx]
|
||||||
|
|
||||||
|
|
||||||
|
class Gridworld(Environment):
|
||||||
|
def __init__(self, seed, dones, rewards, starts):
|
||||||
|
self.seed = seed
|
||||||
|
self.rng = np.random.RandomState(seed)
|
||||||
|
self.dones = dones
|
||||||
|
self.rewards = rewards
|
||||||
|
self.starts = starts
|
||||||
|
|
||||||
|
self.__compute_P()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
""" Resets the environment of this gridworld to a randomly sampled start state. """
|
||||||
|
_, self.state = sample(self.rng, self.starts)
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
""" Performs the action on the gridworld, where next state of environment is sampled based on self.P. """
|
||||||
|
_, self.state, reward, done = sample(self.rng, self.P[self.state][action])
|
||||||
|
return self.state, reward, done
|
||||||
|
|
||||||
|
def get_n_actions(self):
|
||||||
|
""" Returns the number of actions available in this gridworld. """
|
||||||
|
return 4
|
||||||
|
|
||||||
|
def get_n_states(self):
|
||||||
|
""" Returns the number of states available in this gridworld. """
|
||||||
|
return np.prod(self.dones.shape)
|
||||||
|
|
||||||
|
def get_gamma(self):
|
||||||
|
""" Returns discount factor gamma for this gridworld. """
|
||||||
|
return 0.99
|
||||||
|
|
||||||
|
def __compute_P(self):
|
||||||
|
""" Computes and stores the transitions for this gridworld. """
|
||||||
|
w, h = self.dones.shape
|
||||||
|
|
||||||
|
def inbounds(i, j):
|
||||||
|
""" Checks whether coordinates i and j are within the grid. """
|
||||||
|
return i >= 0 and j >= 0 and i < w and j < h
|
||||||
|
|
||||||
|
self.P = dict()
|
||||||
|
for i in range(0, w):
|
||||||
|
for j in range(0, h):
|
||||||
|
state = j * w + i
|
||||||
|
self.P[state] = dict()
|
||||||
|
|
||||||
|
if self.dones[i, j]:
|
||||||
|
for action in range(self.get_n_actions()):
|
||||||
|
# make it absorbing
|
||||||
|
self.P[state][action] = [(1, state, 0, True)]
|
||||||
|
else:
|
||||||
|
for action, (dx, dy) in enumerate(DELTAS):
|
||||||
|
ortho_dir_probs = [
|
||||||
|
(0.8, dx, dy),
|
||||||
|
(0.1, dy, dx),
|
||||||
|
(0.1, -dy, -dx)
|
||||||
|
]
|
||||||
|
transitions = []
|
||||||
|
for p, di, dj in ortho_dir_probs:
|
||||||
|
ni = i + di
|
||||||
|
nj = j + dj
|
||||||
|
if inbounds(ni, nj):
|
||||||
|
# we move
|
||||||
|
sprime = nj * w + ni
|
||||||
|
done = self.dones[ni, nj]
|
||||||
|
reward = self.rewards[ni, nj]
|
||||||
|
transitions.append((p, sprime, reward, done))
|
||||||
|
else:
|
||||||
|
# stay in the same state, b/c we bounced
|
||||||
|
sprime = state
|
||||||
|
done = self.dones[i, j]
|
||||||
|
reward = self.rewards[i, j]
|
||||||
|
transitions.append((p, sprime, reward, done))
|
||||||
|
|
||||||
|
self.P[state][action] = transitions
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
""" Converts and stores this gridworld to a JSON file. """
|
||||||
|
return json.dumps(dict(
|
||||||
|
type=self.__class__.__name__,
|
||||||
|
seed=self.seed,
|
||||||
|
dones=self.dones.tolist(),
|
||||||
|
rewards=self.rewards.tolist(),
|
||||||
|
starts=self.starts.tolist()
|
||||||
|
))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
""" Loads given JSON file, and creates gridworld with information. """
|
||||||
|
data = json.loads(jsonstring)
|
||||||
|
return Gridworld(
|
||||||
|
data['seed'],
|
||||||
|
np.array(data['dones']),
|
||||||
|
np.array(data['rewards']),
|
||||||
|
np.array(data['starts'], dtype=np.int64),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(data):
|
||||||
|
""" Creates gridworld with information in given data-dictionary. """
|
||||||
|
return Gridworld(
|
||||||
|
data['seed'],
|
||||||
|
np.array(data['dones']),
|
||||||
|
np.array(data['rewards']),
|
||||||
|
np.array(data['starts'], dtype=np.int64),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_random_instance(rng, size):
|
||||||
|
""" Given random generator and problem size, generates Gridworld instance. """
|
||||||
|
dones, rewards, starts = Gridworld.__generate(rng, size)
|
||||||
|
return Gridworld(rng.randint(0, 2 ** 31), dones, rewards, starts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __generate(rng, size):
|
||||||
|
""" Helper function that retrieves dones, rewards, starts for Gridworld instance generation. """
|
||||||
|
dones = np.full((size, size), False, dtype=bool)
|
||||||
|
rewards = np.zeros((size, size), dtype=np.int8) - 1
|
||||||
|
|
||||||
|
coordinates = []
|
||||||
|
for i in range(1, size - 1):
|
||||||
|
for j in range(1, size - 1):
|
||||||
|
coordinates.append((i, j))
|
||||||
|
indices = np.arange(len(coordinates))
|
||||||
|
|
||||||
|
chosen = rng.choice(indices, max(1, len(indices) // 10), replace=False)
|
||||||
|
|
||||||
|
for c in chosen:
|
||||||
|
x, y = coordinates[c]
|
||||||
|
dones[x, y] = True
|
||||||
|
rewards[x, y] = -100
|
||||||
|
|
||||||
|
starts = np.array([[1, 0]])
|
||||||
|
dones[-1, -1] = True
|
||||||
|
rewards[-1, -1] = 100
|
||||||
|
|
||||||
|
return dones, rewards, starts
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_minimum_problem_size():
|
||||||
|
return 3
|
||||||
|
|
||||||
|
def visualize(self, outcome, coords=None, grid=None):
|
||||||
|
""" Visualisation function for gridworld; plots environment, policy, Q. """
|
||||||
|
policy = None
|
||||||
|
Q = None
|
||||||
|
V = None
|
||||||
|
if outcome is not None:
|
||||||
|
if outcome.policy is not None:
|
||||||
|
policy = outcome.policy
|
||||||
|
|
||||||
|
if outcome.V is not None:
|
||||||
|
V = outcome.V
|
||||||
|
|
||||||
|
if outcome.Q is not None:
|
||||||
|
Q = outcome.Q
|
||||||
|
|
||||||
|
self._plot_environment_and_policy(policy, V, Q, show_coordinates=coords, show_grid=grid)
|
||||||
|
|
||||||
|
def _plot_environment_and_policy(self, policy=None,V=None, Q=None, show_coordinates=False,
|
||||||
|
show_grid=False, plot_filename=None, debug_info=False):
|
||||||
|
""" Function that plots environment and policy. """
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
|
||||||
|
dones_ax = axes[0, 0]
|
||||||
|
rewards_ax = axes[0, 1]
|
||||||
|
V_ax = axes[1, 0]
|
||||||
|
Q_ax = axes[1, 1]
|
||||||
|
|
||||||
|
dones_ax.set_title('Terminal States and Policy')
|
||||||
|
dones_ax.imshow(self.dones.T, cmap='gray_r', vmin=0, vmax=4)
|
||||||
|
|
||||||
|
rewards_ax.set_title('Immediate Rewards')
|
||||||
|
rewards_ax.imshow(self.rewards.T, cmap='RdBu_r', vmin=-25, vmax=25)
|
||||||
|
|
||||||
|
if len(policy) > 0:
|
||||||
|
self._plot_policy(dones_ax, policy)
|
||||||
|
|
||||||
|
w, h = self.dones.shape
|
||||||
|
V_array = V.reshape(self.dones.shape).T
|
||||||
|
V_ax.set_title('State Value Function $V(s)$')
|
||||||
|
r = max(1e-13, np.max(np.abs(V_array)))
|
||||||
|
V_ax.imshow(V_array.T, cmap='RdBu_r', vmin=-r, vmax=r)
|
||||||
|
|
||||||
|
if debug_info:
|
||||||
|
for s in range(len(V)):
|
||||||
|
sy, sx = divmod(s, w)
|
||||||
|
V_ax.text(sx, sy, f'{sx},{sy}:{s}',
|
||||||
|
color='w', fontdict=dict(size=6),
|
||||||
|
horizontalalignment='center', verticalalignment='center')
|
||||||
|
|
||||||
|
Q_ax.set_title('State Action Value Function $Q(s, a)$')
|
||||||
|
poly_patches_q_values = self._draw_Q(Q_ax, Q, debug_info)
|
||||||
|
|
||||||
|
def format_coord(x, y):
|
||||||
|
for poly_patch, q_value in poly_patches_q_values:
|
||||||
|
if poly_patch.contains_point(Q_ax.transData.transform((x, y))):
|
||||||
|
return f'x:{x:4.2f} y:{y:4.2f} {q_value}'
|
||||||
|
return f'x:{x:4.2f} y:{y:4.2f}'
|
||||||
|
|
||||||
|
Q_ax.format_coord = format_coord
|
||||||
|
|
||||||
|
for ax in [dones_ax, rewards_ax, V_ax, Q_ax]:
|
||||||
|
ax.tick_params(
|
||||||
|
top=show_coordinates,
|
||||||
|
left=show_coordinates,
|
||||||
|
labelleft=show_coordinates,
|
||||||
|
labeltop=show_coordinates,
|
||||||
|
right=False,
|
||||||
|
bottom=False,
|
||||||
|
labelbottom=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Major ticks
|
||||||
|
s = self.dones.shape[0]
|
||||||
|
ax.set_xticks(np.arange(0, s, 1))
|
||||||
|
ax.set_yticks(np.arange(0, s, 1))
|
||||||
|
|
||||||
|
# Minor ticks
|
||||||
|
ax.set_xticks(np.arange(-.5, s, 1), minor=True)
|
||||||
|
ax.set_yticks(np.arange(-.5, s, 1), minor=True)
|
||||||
|
|
||||||
|
if show_grid:
|
||||||
|
for color, ax in zip(['m', 'w', 'w'], [dones_ax, rewards_ax, V_ax]):
|
||||||
|
# Gridlines based on minor ticks
|
||||||
|
ax.grid(which='minor', color=color, linestyle='-', linewidth=1)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
if plot_filename is not None:
|
||||||
|
plt.savefig(plot_filename)
|
||||||
|
plt.close(fig)
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
def _plot_policy(self, ax, policy):
|
||||||
|
""" Function that plots policy. """
|
||||||
|
w, h = self.dones.shape
|
||||||
|
xs = np.arange(w)
|
||||||
|
ys = np.arange(h)
|
||||||
|
xx, yy = np.meshgrid(xs, ys)
|
||||||
|
|
||||||
|
# we need a quiver for each of the four action
|
||||||
|
quivers = list()
|
||||||
|
for a in range(self.get_n_actions()):
|
||||||
|
quivers.append(list())
|
||||||
|
|
||||||
|
# we parse the textual description of the lake
|
||||||
|
for s in range(self.get_n_states()):
|
||||||
|
y, x = divmod(s, w)
|
||||||
|
if self.dones[x, y]:
|
||||||
|
for a in range(self.get_n_actions()):
|
||||||
|
quivers[a].append((0., 0.))
|
||||||
|
else:
|
||||||
|
for a in range(self.get_n_actions()):
|
||||||
|
wdx, wdy = DELTAS[a]
|
||||||
|
corrected = np.array([wdx, -wdy])
|
||||||
|
quivers[a].append(corrected * policy[s, a])
|
||||||
|
|
||||||
|
# plot each quiver
|
||||||
|
for quiver in quivers:
|
||||||
|
q = np.array(quiver)
|
||||||
|
ax.quiver(xx, yy, q[:, 0], q[:, 1], units='xy', scale=1.5)
|
||||||
|
|
||||||
|
def _draw_Q(self, ax, Q, debug_info):
|
||||||
|
""" Function that draws Q. """
|
||||||
|
pattern = np.zeros(self.dones.shape)
|
||||||
|
ax.imshow(pattern, cmap='gray_r')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.cm import ScalarMappable
|
||||||
|
from matplotlib.colors import Normalize
|
||||||
|
from matplotlib.patches import Rectangle, Polygon
|
||||||
|
w, h = self.dones.shape
|
||||||
|
|
||||||
|
r = max(1e-13, np.max(np.abs(Q)))
|
||||||
|
norm = Normalize(vmin=-r, vmax=r)
|
||||||
|
cmap = plt.get_cmap('RdBu_r')
|
||||||
|
sm = ScalarMappable(norm, cmap)
|
||||||
|
|
||||||
|
hover_polygons = []
|
||||||
|
for state in range(len(Q)):
|
||||||
|
qs = Q[state]
|
||||||
|
# print('qs', qs)
|
||||||
|
y, x = divmod(state, w)
|
||||||
|
if self.dones[x, y]:
|
||||||
|
continue
|
||||||
|
y += 0.5
|
||||||
|
x += 0.5
|
||||||
|
|
||||||
|
dx = 1
|
||||||
|
dy = 1
|
||||||
|
|
||||||
|
ulx = (x - 1) * dx
|
||||||
|
uly = (y - 1) * dy
|
||||||
|
|
||||||
|
rect = Rectangle(
|
||||||
|
xy=(ulx, uly),
|
||||||
|
width=dx,
|
||||||
|
height=dy,
|
||||||
|
edgecolor='k',
|
||||||
|
facecolor='none'
|
||||||
|
)
|
||||||
|
ax.add_artist(rect)
|
||||||
|
|
||||||
|
mx = (x - 1) * dx + dx / 2.
|
||||||
|
my = (y - 1) * dy + dy / 2.
|
||||||
|
|
||||||
|
ul = ulx, uly
|
||||||
|
ur = ulx + dx, uly
|
||||||
|
ll = ulx, uly + dy
|
||||||
|
lr = ulx + dx, uly + dy
|
||||||
|
m = mx, my
|
||||||
|
|
||||||
|
up = [ul, m, ur]
|
||||||
|
left = [ul, m, ll]
|
||||||
|
right = [ur, m, lr]
|
||||||
|
down = [ll, m, lr]
|
||||||
|
action_polys = [left, right, up, down]
|
||||||
|
for a, poly in enumerate(action_polys):
|
||||||
|
poly_patch = Polygon(
|
||||||
|
poly,
|
||||||
|
edgecolor='k',
|
||||||
|
linewidth=0.1,
|
||||||
|
facecolor=sm.to_rgba(qs[a])
|
||||||
|
)
|
||||||
|
if debug_info:
|
||||||
|
mmx = np.mean([x for x, y in poly])
|
||||||
|
mmy = np.mean([y for x, y in poly])
|
||||||
|
sss = '\n'.join(map(str, self.P[state][a]))
|
||||||
|
ax.text(mmx, mmy, f'{NAMES[a][0]}:{sss}',
|
||||||
|
fontdict=dict(size=5), horizontalalignment='center',
|
||||||
|
verticalalignment='center')
|
||||||
|
|
||||||
|
hover_polygons.append((poly_patch, f'{NAMES[a]}:{qs[a]:4.2f}'))
|
||||||
|
ax.add_artist(poly_patch)
|
||||||
|
return hover_polygons
|
||||||
0
pig_lite/game/__init__.py
Normal file
0
pig_lite/game/__init__.py
Normal file
87
pig_lite/game/base.py
Normal file
87
pig_lite/game/base.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import hashlib
|
||||||
|
|
||||||
|
class Node(object):
|
||||||
|
def __init__(self, parent, state, action, player, depth):
|
||||||
|
self.parent = parent
|
||||||
|
self.state = state
|
||||||
|
self.action = action
|
||||||
|
self.player = player
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
def key(self):
|
||||||
|
# if state is composed of other stuff (dict, set, ...)
|
||||||
|
# make it a tuple containing hashable datatypes
|
||||||
|
# (this is supposed to be overridden by subclasses)
|
||||||
|
return tuple(self.state) + (self.player, )
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.key())
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if type(self) == type(other):
|
||||||
|
return self.key() == other.key()
|
||||||
|
raise ValueError('cannot simply compare two different node types')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Node(id:{}, parent:{}, state:{}, action:{}, player:{}, depth:{})'.format(
|
||||||
|
id(self),
|
||||||
|
id(self.parent),
|
||||||
|
self.state,
|
||||||
|
self.action,
|
||||||
|
self.player,
|
||||||
|
self.depth
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_move_sequence(self):
|
||||||
|
current = self
|
||||||
|
reverse_sequence = []
|
||||||
|
while current.parent is not None:
|
||||||
|
reverse_sequence.append((current.player, current.action))
|
||||||
|
current = current.parent
|
||||||
|
return list(reversed(reverse_sequence))
|
||||||
|
|
||||||
|
def get_move_sequence_hash(self):
|
||||||
|
move_sequence = self.get_move_sequence()
|
||||||
|
move_sequence_as_str = ';'.join(map(str, move_sequence))
|
||||||
|
move_sequence_hash = hashlib.sha256(move_sequence_as_str.encode('UTF-8')).hexdigest()
|
||||||
|
return move_sequence_hash
|
||||||
|
|
||||||
|
class Game(object):
|
||||||
|
def get_number_of_expanded_nodes(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_start_node(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def winner(self, node):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def successors(self, node):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_max_player(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_move_sequence(self, end: Node):
|
||||||
|
if end is None:
|
||||||
|
return list()
|
||||||
|
return end.get_move_sequence()
|
||||||
|
|
||||||
|
def get_move_sequence_hash(self, end: Node):
|
||||||
|
if end is None:
|
||||||
|
return ''
|
||||||
|
return end.get_move_sequence_hash()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_minimum_problem_size():
|
||||||
|
raise NotImplementedError()
|
||||||
371
pig_lite/game/tictactoe.py
Normal file
371
pig_lite/game/tictactoe.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from pig_lite.game.base import Node, Game
|
||||||
|
|
||||||
|
|
||||||
|
class TTTNode(Node):
|
||||||
|
def key(self):
|
||||||
|
return tuple(self.state.flatten().tolist() + [self.player])
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '"TTTNode(\nid:{}\nparent:{}\nboard:\n{}\nplayer:\n{}\naction:\n{}\ndepth:{})"'.format(
|
||||||
|
id(self),
|
||||||
|
id(self.parent),
|
||||||
|
# this needs to be printed transposed, so it fits together with
|
||||||
|
# how matplotlib's 'imshow' renders images
|
||||||
|
self.state.T,
|
||||||
|
self.player,
|
||||||
|
self.action,
|
||||||
|
self.depth
|
||||||
|
)
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.colors import ListedColormap
|
||||||
|
cm = ListedColormap(['tab:blue', 'lightgray', 'tab:orange'])
|
||||||
|
print('State of the board:')
|
||||||
|
plt.figure(figsize=(2, 2))
|
||||||
|
plt.imshow(self.state.T, cmap=cm)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.show()
|
||||||
|
print('Performed moves: {}'.format(self.depth))
|
||||||
|
|
||||||
|
|
||||||
|
class TicTacToe(Game):
|
||||||
|
def __init__(self, rng=None, depth=None):
|
||||||
|
self.n_expands = 0
|
||||||
|
self.play_randomly(rng, depth)
|
||||||
|
|
||||||
|
def play_randomly(self, rng, depth):
|
||||||
|
""" Initialises self.start_node to be either empty board, or board at given depth after random playing. """
|
||||||
|
empty_board = np.zeros((3, 3), dtype=int)
|
||||||
|
start_from_empty = TTTNode(None, empty_board, None, 1, 0)
|
||||||
|
if rng is None or depth is None or depth == 0:
|
||||||
|
self.start_node = start_from_empty
|
||||||
|
else:
|
||||||
|
# proceed playing randomly until either 'depth' is reached,
|
||||||
|
# or the node is a terminal node
|
||||||
|
nodes = []
|
||||||
|
successors = [start_from_empty]
|
||||||
|
while True:
|
||||||
|
index = rng.randint(0, len(successors))
|
||||||
|
current = successors[index]
|
||||||
|
|
||||||
|
if current.depth == depth:
|
||||||
|
break
|
||||||
|
|
||||||
|
nodes.append(current)
|
||||||
|
terminal, winner = self.outcome(current)
|
||||||
|
if terminal:
|
||||||
|
break
|
||||||
|
successors = self.successors(current)
|
||||||
|
|
||||||
|
for node in successors:
|
||||||
|
nodes.append(node)
|
||||||
|
|
||||||
|
self.start_node = TTTNode(None, current.state, None, current.player, 0)
|
||||||
|
|
||||||
|
def get_start_node(self):
|
||||||
|
""" Returns start node of this Game. """
|
||||||
|
return self.start_node
|
||||||
|
|
||||||
|
def outcome(self, node):
|
||||||
|
""" Returns tuple stating whether game is finished or not, and winner (or None otherwise). """
|
||||||
|
board = node.state
|
||||||
|
for player in [-1, 1]:
|
||||||
|
# checks rows and columns
|
||||||
|
for i in range(3):
|
||||||
|
if (board[i, :] == player).all() or (board[:, i] == player).all():
|
||||||
|
return True, player
|
||||||
|
|
||||||
|
# checks diagonals
|
||||||
|
if (np.diag(board) == player).all() or (np.diag(np.rot90(board)) == player).all():
|
||||||
|
return True, player
|
||||||
|
|
||||||
|
# if board is full, and none of the conditions above are true,
|
||||||
|
# nobody has won --- it's a draw
|
||||||
|
if (board != 0).all():
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
# else, continue
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
def get_max_player(self):
|
||||||
|
""" Returns identifier of MAX player used in this game. """
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def successor(self, node, action):
|
||||||
|
""" Performs given action at given game node, and returns successor TTT node. """
|
||||||
|
board = node.state
|
||||||
|
player = node.player
|
||||||
|
|
||||||
|
next_board = board.copy()
|
||||||
|
next_board[action] = player
|
||||||
|
|
||||||
|
if player == 1:
|
||||||
|
next_player = -1
|
||||||
|
else:
|
||||||
|
next_player = 1
|
||||||
|
|
||||||
|
return TTTNode(
|
||||||
|
node,
|
||||||
|
next_board,
|
||||||
|
action,
|
||||||
|
next_player,
|
||||||
|
node.depth + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_number_of_expanded_nodes(self):
|
||||||
|
return self.n_expands
|
||||||
|
|
||||||
|
def successors(self, node):
|
||||||
|
""" Given a game node, returns all possible successor nodes based on all actions that can be performed. """
|
||||||
|
self.n_expands += 1
|
||||||
|
terminal, winner = self.outcome(node)
|
||||||
|
|
||||||
|
if terminal:
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
successor_nodes = []
|
||||||
|
# iterate through all possible coordinates (==actions)
|
||||||
|
for action in zip(*np.nonzero(node.state == 0)):
|
||||||
|
successor_nodes.append(self.successor(node, action))
|
||||||
|
return successor_nodes
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
""" Converts and stores this TTT game to a JSON file. """
|
||||||
|
return json.dumps(dict(
|
||||||
|
type=self.__class__.__name__,
|
||||||
|
start_state=self.start_node.state.tolist(),
|
||||||
|
start_player=self.start_node.player
|
||||||
|
))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
""" Loads given JSON file, and creates game with information. """
|
||||||
|
data = json.loads(jsonstring)
|
||||||
|
|
||||||
|
ttt = TicTacToe()
|
||||||
|
ttt.start_node = TTTNode(
|
||||||
|
None,
|
||||||
|
np.array(data['start_state'], dtype=int),
|
||||||
|
None,
|
||||||
|
data['start_player'],
|
||||||
|
0
|
||||||
|
)
|
||||||
|
return ttt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(data):
|
||||||
|
""" Creates game with information in given data-dictionary. """
|
||||||
|
ttt = TicTacToe()
|
||||||
|
ttt.start_node = TTTNode(
|
||||||
|
None,
|
||||||
|
np.array(data['start_state'], dtype=int),
|
||||||
|
None,
|
||||||
|
data['start_player'],
|
||||||
|
0
|
||||||
|
)
|
||||||
|
return ttt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_minimum_problem_size():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def visualize(self, move_sequence, show_possible=False, tree_name=''):
|
||||||
|
game = deepcopy(self)
|
||||||
|
nodes = []
|
||||||
|
current = game.get_start_node()
|
||||||
|
nodes.append(current)
|
||||||
|
for player, move in move_sequence:
|
||||||
|
if show_possible:
|
||||||
|
successors = game.successors(current)
|
||||||
|
nodes.extend(successors)
|
||||||
|
current = None
|
||||||
|
for succ in successors:
|
||||||
|
if succ.action == move:
|
||||||
|
current = succ
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
current = game.successor(current, move)
|
||||||
|
nodes.append(current)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.networkx_plot_game_tree(tree_name, nodes)
|
||||||
|
except ImportError:
|
||||||
|
print('#' * 30)
|
||||||
|
print('#' * 30)
|
||||||
|
print('starting position')
|
||||||
|
print(self.get_start_node())
|
||||||
|
print('#' * 30)
|
||||||
|
print('#' * 30)
|
||||||
|
print('-' * 30)
|
||||||
|
print('sequence of nodes')
|
||||||
|
for node in nodes:
|
||||||
|
print('-' * 30)
|
||||||
|
print(node)
|
||||||
|
terminal, winner = game.outcome(node)
|
||||||
|
print('terminal {}, winner {}'.format(terminal, winner))
|
||||||
|
|
||||||
|
def networkx_plot_game_tree(self, title, nodes, highlight=None):
|
||||||
|
# TODO: this needs some serious refactoring
|
||||||
|
# use visitors for styling, for example, instead of cumbersome dicts
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from networkx.drawing.nx_pydot import graphviz_layout
|
||||||
|
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, HPacker, VPacker, TextArea
|
||||||
|
|
||||||
|
fig, tree_ax = plt.subplots()
|
||||||
|
tree_ax.set_title(title)
|
||||||
|
G = nx.DiGraph(ordering='out')
|
||||||
|
nodes_extra = dict()
|
||||||
|
edges_extra = dict()
|
||||||
|
|
||||||
|
def sort_key(node):
|
||||||
|
if node.action is None:
|
||||||
|
return (-1, -1)
|
||||||
|
return node.action
|
||||||
|
|
||||||
|
for node in sorted(nodes, key=sort_key):
|
||||||
|
G.add_node(id(node), search_node=node)
|
||||||
|
terminal, winner = self.outcome(node)
|
||||||
|
nodes_extra[id(node)] = dict(
|
||||||
|
board=node.state,
|
||||||
|
player=node.player,
|
||||||
|
depth=node.depth,
|
||||||
|
terminal=terminal,
|
||||||
|
winner=winner
|
||||||
|
)
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if node.parent is not None:
|
||||||
|
edge = id(node.parent), id(node)
|
||||||
|
G.add_edge(*edge, parent_node=node.parent)
|
||||||
|
edges_extra[edge] = dict(
|
||||||
|
label='{}'.format(node.action),
|
||||||
|
parent_player=node.parent.player
|
||||||
|
)
|
||||||
|
|
||||||
|
node_size = 1000
|
||||||
|
positions = graphviz_layout(G, prog='dot')
|
||||||
|
|
||||||
|
from matplotlib.colors import Normalize, LinearSegmentedColormap
|
||||||
|
|
||||||
|
blue_orange = LinearSegmentedColormap.from_list(
|
||||||
|
'blue_orange',
|
||||||
|
['tab:blue', 'lightgray', 'tab:orange']
|
||||||
|
)
|
||||||
|
|
||||||
|
inf = float('Inf')
|
||||||
|
x_range = [inf, -inf]
|
||||||
|
y_range = [inf, -inf]
|
||||||
|
for id_node, pos in positions.items():
|
||||||
|
x, y = pos
|
||||||
|
x_range = [min(x, x_range[0]), max(x, x_range[1])]
|
||||||
|
y_range = [min(y, y_range[0]), max(y, y_range[1])]
|
||||||
|
|
||||||
|
player = nodes_extra[id_node]['player']
|
||||||
|
text_player = 'p:{}'.format(player)
|
||||||
|
text_depth = 'd:{}'.format(nodes_extra[id_node]['depth'])
|
||||||
|
color_player = 'tab:blue' if player == -1 else 'tab:orange'
|
||||||
|
|
||||||
|
frameon = False
|
||||||
|
bboxprops = None
|
||||||
|
if nodes_extra[id_node]['terminal']:
|
||||||
|
winner = nodes_extra[id_node]['winner']
|
||||||
|
frameon = True
|
||||||
|
if winner is None:
|
||||||
|
edgecolor = 'tab:purple'
|
||||||
|
else:
|
||||||
|
edgecolor = 'tab:blue' if winner == -1 else 'tab:orange'
|
||||||
|
bboxprops = dict(
|
||||||
|
facecolor='none',
|
||||||
|
edgecolor=edgecolor
|
||||||
|
)
|
||||||
|
color_player = 'k'
|
||||||
|
text_player = 'w:{}'.format(winner)
|
||||||
|
if winner is None:
|
||||||
|
text_player = ''
|
||||||
|
|
||||||
|
# needs to be transposed b/c image coordinates etc ...
|
||||||
|
board = nodes_extra[id_node]['board'].T
|
||||||
|
textbox_player = TextArea(text_player, textprops=dict(size=6, color=color_player))
|
||||||
|
textbox_depth = TextArea(text_depth, textprops=dict(size=6))
|
||||||
|
|
||||||
|
textbox_children = [textbox_player, textbox_depth]
|
||||||
|
|
||||||
|
if highlight is not None:
|
||||||
|
if id_node in highlight:
|
||||||
|
if nodes_extra[id_node]['terminal']:
|
||||||
|
frameon = True
|
||||||
|
if nodes_extra[id_node]['winner'] is None:
|
||||||
|
edgecolor = 'tab:purple'
|
||||||
|
else:
|
||||||
|
edgecolor = 'tab:blue' if winner == -1 else 'tab:orange'
|
||||||
|
|
||||||
|
bboxprops = dict(
|
||||||
|
facecolor='none',
|
||||||
|
edgecolor=edgecolor
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(highlight[id_node]) > 0:
|
||||||
|
for key, value in highlight[id_node].items():
|
||||||
|
textbox_children.append(
|
||||||
|
TextArea('{}:{}'.format(key, value), textprops=dict(size=6))
|
||||||
|
)
|
||||||
|
|
||||||
|
imagebox = OffsetImage(board, zoom=5, cmap=blue_orange, norm=Normalize(vmin=-1, vmax=1))
|
||||||
|
packed = HPacker(
|
||||||
|
align='center',
|
||||||
|
children=[
|
||||||
|
imagebox,
|
||||||
|
VPacker(
|
||||||
|
align='center',
|
||||||
|
children=textbox_children,
|
||||||
|
sep=0.1, pad=0.1
|
||||||
|
)
|
||||||
|
],
|
||||||
|
sep=0.1, pad=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
ab = AnnotationBbox(packed, pos, xycoords='data', frameon=frameon, bboxprops=bboxprops)
|
||||||
|
tree_ax.add_artist(ab)
|
||||||
|
|
||||||
|
def min_dist(a, b):
|
||||||
|
if a == b:
|
||||||
|
return [a - 1, b + 1]
|
||||||
|
else:
|
||||||
|
return [a - 0.9 * abs(a), b + 0.1 * abs(b)]
|
||||||
|
|
||||||
|
x_range = min_dist(*x_range)
|
||||||
|
y_range = min_dist(*y_range)
|
||||||
|
tree_ax.set_xlim(x_range)
|
||||||
|
tree_ax.set_ylim(y_range)
|
||||||
|
|
||||||
|
orange_edges = []
|
||||||
|
blue_edges = []
|
||||||
|
|
||||||
|
for edge, extra in edges_extra.items():
|
||||||
|
if extra['parent_player'] == -1:
|
||||||
|
blue_edges.append(edge)
|
||||||
|
else:
|
||||||
|
orange_edges.append(edge)
|
||||||
|
|
||||||
|
for color, edgelist in [('tab:orange', orange_edges), ('tab:blue', blue_edges)]:
|
||||||
|
nx.draw_networkx_edges(
|
||||||
|
G, positions,
|
||||||
|
edgelist=edgelist,
|
||||||
|
edge_color=color,
|
||||||
|
arrowstyle='-|>',
|
||||||
|
arrowsize=10,
|
||||||
|
node_size=node_size,
|
||||||
|
ax=tree_ax
|
||||||
|
)
|
||||||
|
edge_labels = {edge_id: edge['label'] for edge_id, edge in edges_extra.items()}
|
||||||
|
nx.draw_networkx_edge_labels(G, positions, edge_labels, ax=tree_ax, font_size=6)
|
||||||
|
|
||||||
|
tree_ax.axis('off')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
0
pig_lite/instance_generation/__init__.py
Normal file
0
pig_lite/instance_generation/__init__.py
Normal file
5
pig_lite/instance_generation/enc.py
Normal file
5
pig_lite/instance_generation/enc.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# this is the common encoding for different level tiles
|
||||||
|
WALL = 1
|
||||||
|
SPACE = 0
|
||||||
|
EXPOSED = -1
|
||||||
|
UNDETERMINED = -2
|
||||||
96
pig_lite/instance_generation/problem_factory.py
Normal file
96
pig_lite/instance_generation/problem_factory.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from pig_lite.problem.simple_2d import Simple2DProblem, MazeLevel, TerrainLevel, RoomLevel
|
||||||
|
from pig_lite.environment.gridworld import Gridworld
|
||||||
|
from pig_lite.game.tictactoe import TicTacToe
|
||||||
|
from pig_lite.decision_tree.training_set import TrainingSet
|
||||||
|
|
||||||
|
# this is the common encoding for different level tiles
|
||||||
|
encoding = {
|
||||||
|
'WALL': 1,
|
||||||
|
'SPACE': 0,
|
||||||
|
'EXPOSED': -1,
|
||||||
|
'UNDETERMINED': -2
|
||||||
|
}
|
||||||
|
|
||||||
|
class ProblemFactory():
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_problem(problem_type, problem_size, rng):
|
||||||
|
if problem_type == 'maze':
|
||||||
|
level = MazeLevel(rng, size=problem_size)
|
||||||
|
return Simple2DProblem(level.get_field(),
|
||||||
|
level.get_costs(),
|
||||||
|
level.get_start(),
|
||||||
|
level.get_end())
|
||||||
|
elif problem_type == 'terrain':
|
||||||
|
level = TerrainLevel(rng, size=problem_size)
|
||||||
|
return Simple2DProblem(level.get_field(),
|
||||||
|
level.get_costs(),
|
||||||
|
level.get_start(),
|
||||||
|
level.get_end())
|
||||||
|
elif problem_type == 'rooms':
|
||||||
|
level = RoomLevel(rng, size=problem_size)
|
||||||
|
return Simple2DProblem(level.get_field(),
|
||||||
|
level.get_costs(),
|
||||||
|
level.get_start(),
|
||||||
|
level.get_end())
|
||||||
|
elif problem_type == 'tictactoe':
|
||||||
|
return TicTacToe(rng, depth=problem_size)
|
||||||
|
elif problem_type == 'gridworld':
|
||||||
|
return Gridworld.get_random_instance(rng, size=problem_size)
|
||||||
|
elif problem_type =='trainset':
|
||||||
|
raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
|
||||||
|
else:
|
||||||
|
raise ValueError(f'unknown problem_type {problem_type}')
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_problem_from_json(json_path):
|
||||||
|
with open(json_path, 'r') as file:
|
||||||
|
data = json.load(file)
|
||||||
|
problem_type = data['type']
|
||||||
|
|
||||||
|
if problem_type == 'Simple2DProblem':
|
||||||
|
problem = Simple2DProblem.from_dict(data)
|
||||||
|
return problem
|
||||||
|
elif problem_type == 'TicTacToe':
|
||||||
|
problem = TicTacToe.from_dict(data)
|
||||||
|
return problem
|
||||||
|
elif problem_type == 'Gridworld':
|
||||||
|
problem = Gridworld.from_dict(data)
|
||||||
|
return problem
|
||||||
|
elif problem_type == 'TrainingSet':
|
||||||
|
problem = TrainingSet.from_dict(data)
|
||||||
|
return problem
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown problem type: {problem_type}")
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_problem_from_dict(data, problem_type='Simple2DProblem'):
|
||||||
|
import numpy as np
|
||||||
|
if problem_type == 'Simple2DProblem':
|
||||||
|
if not ('board' in data.keys() and 'costs' in data.keys()
|
||||||
|
and 'start_state' in data.keys() and 'end_state' in data.keys()):
|
||||||
|
raise ValueError('data dict must contain: "board", "costs", "start_state" and "end_state"')
|
||||||
|
if np.array(data['board']).shape != np.array(data['costs']).shape:
|
||||||
|
raise ValueError('data["board"] and data["costs"] must have same shape')
|
||||||
|
problem = Simple2DProblem.from_dict(data)
|
||||||
|
return problem
|
||||||
|
if problem_type == 'TicTacToe':
|
||||||
|
if not ('start_state' in data.keys() and 'start_player' in data.keys()):
|
||||||
|
raise ValueError('data dict must contain: "start_state", "start_player"')
|
||||||
|
problem = TicTacToe.from_dict(data)
|
||||||
|
return problem
|
||||||
|
if problem_type == 'Gridworld':
|
||||||
|
if not ('seed' in data.keys() and 'dones' in data.keys()
|
||||||
|
and 'rewards' in data.keys() and 'starts' in data.keys()):
|
||||||
|
raise ValueError('data dict must contain: "seed", "dones", "rewards", "starts"')
|
||||||
|
problem = Gridworld.from_dict(data)
|
||||||
|
return problem
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
|
||||||
|
|
||||||
529
pig_lite/problem/.ipynb_checkpoints/simple_2d-checkpoint.py
Normal file
529
pig_lite/problem/.ipynb_checkpoints/simple_2d-checkpoint.py
Normal file
@@ -0,0 +1,529 @@
|
|||||||
|
from pig_lite.problem.base import Problem, Node
|
||||||
|
from pig_lite.instance_generation import enc
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from collections import OrderedDict
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||||
|
from matplotlib.colors import TABLEAU_COLORS, XKCD_COLORS
|
||||||
|
|
||||||
|
class BaseLevel():
|
||||||
|
def __init__(self, rng, size) -> None:
|
||||||
|
self.rng = rng
|
||||||
|
self.size = size
|
||||||
|
self.field = None
|
||||||
|
self.costs = None
|
||||||
|
self.start = None
|
||||||
|
self.end = None
|
||||||
|
|
||||||
|
self.initialize_level()
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_field(self):
|
||||||
|
return self.field
|
||||||
|
|
||||||
|
def get_costs(self):
|
||||||
|
return self.costs
|
||||||
|
|
||||||
|
def get_start(self):
|
||||||
|
return self.start
|
||||||
|
|
||||||
|
def get_end(self):
|
||||||
|
return self.end
|
||||||
|
|
||||||
|
|
||||||
|
class MazeLevel(BaseLevel):
|
||||||
|
# this method generates a random maze according to prim's randomized
|
||||||
|
# algorithm
|
||||||
|
# http://en.wikipedia.org/wiki/Maze_generation_algorithm#Randomized_Prim.27s_algorithm
|
||||||
|
|
||||||
|
def __init__(self, rng, size):
|
||||||
|
super().__init__(rng, size)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
|
||||||
|
self.field = np.full((self.size, self.size), enc.WALL, dtype=np.int8)
|
||||||
|
self.costs = self.rng.randint(1, 5, self.field.shape, dtype=np.int8)
|
||||||
|
|
||||||
|
self.start = (0, 0)
|
||||||
|
|
||||||
|
self.deltas = [
|
||||||
|
(0, 1),
|
||||||
|
(0, -1),
|
||||||
|
(1, 0),
|
||||||
|
(-1, 0)
|
||||||
|
]
|
||||||
|
self.random_walk()
|
||||||
|
end = np.where(self.field == enc.SPACE)
|
||||||
|
self.end = (int(end[0][-1]), int(end[1][-1]))
|
||||||
|
|
||||||
|
self.replace_walls_with_high_cost_tiles()
|
||||||
|
|
||||||
|
def replace_walls_with_high_cost_tiles(self):
|
||||||
|
# select only coordinates of walls
|
||||||
|
walls = np.where(self.field == enc.WALL)
|
||||||
|
|
||||||
|
n_walls = len(walls[0])
|
||||||
|
|
||||||
|
# replace about a tenth of the walls...
|
||||||
|
to_replace = self.rng.randint(0, n_walls, n_walls // 9)
|
||||||
|
|
||||||
|
# ... with space, but very *costly* space (it's trap!)
|
||||||
|
for ri in to_replace:
|
||||||
|
x, y = walls[0][ri], walls[1][ri]
|
||||||
|
self.field[x, y] = enc.SPACE
|
||||||
|
self.costs[x, y] = 9
|
||||||
|
|
||||||
|
def random_walk(self):
|
||||||
|
frontier = list()
|
||||||
|
|
||||||
|
sx, sy = self.start
|
||||||
|
self.field[sx, sy] = enc.SPACE
|
||||||
|
frontier.extend(self.get_walls(self.start))
|
||||||
|
|
||||||
|
while len(frontier) > 0:
|
||||||
|
current, opposing = frontier[self.rng.randint(len(frontier))]
|
||||||
|
|
||||||
|
cx, cy = current
|
||||||
|
ox, oy = opposing
|
||||||
|
if self.field[ox, oy] == enc.WALL:
|
||||||
|
self.field[cx, cy] = enc.SPACE
|
||||||
|
self.field[ox, oy] = enc.SPACE
|
||||||
|
frontier.extend(self.get_walls(opposing))
|
||||||
|
else:
|
||||||
|
frontier.remove((current, opposing))
|
||||||
|
|
||||||
|
def in_bounds(self, position):
|
||||||
|
x, y = position
|
||||||
|
return x >= 0 and y >= 0 and x < self.size and y < self.size
|
||||||
|
|
||||||
|
def get_walls(self, position):
|
||||||
|
walls = []
|
||||||
|
px, py = position
|
||||||
|
for dx, dy in self.deltas:
|
||||||
|
cx = px + dx
|
||||||
|
cy = py + dy
|
||||||
|
current = (cx, cy)
|
||||||
|
|
||||||
|
ox = px + 2 * dx
|
||||||
|
oy = py + 2 * dy
|
||||||
|
opposing = (ox, oy)
|
||||||
|
|
||||||
|
if (self.in_bounds(current) and self.field[cx, cy] == enc.WALL and self.in_bounds(opposing)):
|
||||||
|
walls.append((current, opposing))
|
||||||
|
return walls
|
||||||
|
|
||||||
|
|
||||||
|
# this is code taken from
|
||||||
|
# https://github.com/dandrino/terrain-erosion-3-ways/blob/master/util.py
|
||||||
|
# Copyright (c) 2018 Daniel Andrino
|
||||||
|
# (project is MIT licensed)
|
||||||
|
def fbm(shape, p, lower=-np.inf, upper=np.inf):
|
||||||
|
freqs = tuple(np.fft.fftfreq(n, d=1.0 / n) for n in shape)
|
||||||
|
freq_radial = np.hypot(*np.meshgrid(*freqs))
|
||||||
|
envelope = (np.power(freq_radial, p, where=freq_radial != 0) *
|
||||||
|
(freq_radial > lower) * (freq_radial < upper))
|
||||||
|
envelope[0][0] = 0.0
|
||||||
|
phase_noise = np.exp(2j * np.pi * np.random.rand(*shape))
|
||||||
|
return np.real(np.fft.ifft2(np.fft.fft2(phase_noise) * envelope))
|
||||||
|
|
||||||
|
|
||||||
|
class TerrainLevel(BaseLevel):
|
||||||
|
def __init__(self, rng, size):
|
||||||
|
super().__init__(rng, size)
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
|
||||||
|
self.field = np.full((self.size, self.size), enc.SPACE, dtype=np.int8)
|
||||||
|
|
||||||
|
self.costs = fbm(self.field.shape, -2)
|
||||||
|
self.costs -= self.costs.min()
|
||||||
|
self.costs /= self.costs.max()
|
||||||
|
self.costs *= 9
|
||||||
|
self.costs += 1
|
||||||
|
self.costs = self.costs.astype(int)
|
||||||
|
|
||||||
|
self.start = (0, 0)
|
||||||
|
self.end = (self.size - 1, self.size - 1)
|
||||||
|
|
||||||
|
x = 0
|
||||||
|
y = self.size - 1
|
||||||
|
for i in range(0, self.size):
|
||||||
|
self.field[x, y] = enc.WALL
|
||||||
|
x += 1
|
||||||
|
y -= 1
|
||||||
|
|
||||||
|
self.replace_one_or_more_walls()
|
||||||
|
|
||||||
|
def replace_one_or_more_walls(self):
|
||||||
|
# select only coordinates of walls
|
||||||
|
walls = np.where(self.field == enc.WALL)
|
||||||
|
n_walls = len(walls[0])
|
||||||
|
n_replace = self.rng.randint(1, max(2, n_walls // 5))
|
||||||
|
to_replace = self.rng.randint(0, n_walls, n_replace)
|
||||||
|
|
||||||
|
for ri in to_replace:
|
||||||
|
x, y = walls[0][ri], walls[1][ri]
|
||||||
|
self.field[x, y] = enc.SPACE
|
||||||
|
|
||||||
|
|
||||||
|
class RoomLevel(BaseLevel):
|
||||||
|
def __init__(self, rng, size):
|
||||||
|
super().__init__(rng, size)
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
self.field = np.full((self.size, self.size), enc.SPACE, dtype=np.int8)
|
||||||
|
self.costs = np.ones_like(self.field, dtype=np.float32)
|
||||||
|
|
||||||
|
k = 1
|
||||||
|
self.subdivide(self.field.view(), self.costs.view(), k, 0, 0)
|
||||||
|
|
||||||
|
# such a *crutch*!
|
||||||
|
# this 'repairs' dead ends. horrible stuff.
|
||||||
|
for x in range(1, self.size - 1):
|
||||||
|
for y in range(1, self.size - 1):
|
||||||
|
s = 0
|
||||||
|
s += self.field[x - 1, y]
|
||||||
|
s += self.field[x + 1, y]
|
||||||
|
s += self.field[x, y - 1]
|
||||||
|
s += self.field[x, y + 1]
|
||||||
|
if self.field[x, y] == enc.SPACE and s >= 3:
|
||||||
|
self.field[x - 1, y] = enc.SPACE
|
||||||
|
self.field[x + 1, y] = enc.SPACE
|
||||||
|
self.field[x, y - 1] = enc.SPACE
|
||||||
|
self.field[x, y + 1] = enc.SPACE
|
||||||
|
|
||||||
|
spaces = np.where(self.field == enc.SPACE)
|
||||||
|
n_spaces = len(spaces[0])
|
||||||
|
|
||||||
|
n_danger = self.rng.randint(3, 7)
|
||||||
|
dangers = self.rng.choice(range(n_spaces), n_danger, replace=False)
|
||||||
|
for di in dangers:
|
||||||
|
rx, ry = np.unravel_index(di, (self.size, self.size))
|
||||||
|
const = max(1., self.rng.randint(self.size // 5, self.size // 2))
|
||||||
|
for x in range(self.size):
|
||||||
|
for y in range(self.size):
|
||||||
|
distance = np.sqrt((rx - x) ** 2 + (ry - y) ** 2)
|
||||||
|
self.costs[x, y] = self.costs[x, y] + (1. / (const + distance))
|
||||||
|
|
||||||
|
self.costs = self.costs - self.costs.min()
|
||||||
|
self.costs = self.costs / self.costs.max()
|
||||||
|
self.costs = self.costs * 9
|
||||||
|
self.costs = self.costs + 1
|
||||||
|
self.costs = self.costs.astype(int)
|
||||||
|
|
||||||
|
start_choice = 0
|
||||||
|
end_choice = -1
|
||||||
|
|
||||||
|
self.start = (int(spaces[0][start_choice]), int(spaces[1][start_choice]))
|
||||||
|
self.end = (int(spaces[0][end_choice]), int(spaces[1][end_choice]))
|
||||||
|
|
||||||
|
if self.start == self.end:
|
||||||
|
raise RuntimeError('should never happen')
|
||||||
|
|
||||||
|
def subdivide(self, current, costs, k, d, previous_door):
|
||||||
|
w, h = current.shape
|
||||||
|
random_stop = self.rng.randint(0, 10) == 0 and d > 2
|
||||||
|
if w <= 2 * k + 1 or h <= 2 * k + 1 or random_stop:
|
||||||
|
return
|
||||||
|
|
||||||
|
split = previous_door
|
||||||
|
while split == previous_door:
|
||||||
|
split = self.rng.randint(k, w - k)
|
||||||
|
current[split, :] = enc.WALL
|
||||||
|
door = self.rng.randint(k, h - k)
|
||||||
|
current[split, door] = enc.SPACE
|
||||||
|
|
||||||
|
self.subdivide(
|
||||||
|
current[:split, :].T,
|
||||||
|
costs[:split, :].T,
|
||||||
|
k,
|
||||||
|
d + 1,
|
||||||
|
door
|
||||||
|
)
|
||||||
|
self.subdivide(
|
||||||
|
current[split + 1:, :].T,
|
||||||
|
costs[split + 1:, :].T,
|
||||||
|
k,
|
||||||
|
d + 1,
|
||||||
|
door
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Simple2DProblem(Problem):
|
||||||
|
"""
|
||||||
|
the states are the positions on the board that the agent can walk on
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACTIONS_DELTA = OrderedDict([
|
||||||
|
('R', (+1, 0)),
|
||||||
|
('U', (0, -1)),
|
||||||
|
('D', (0, +1)),
|
||||||
|
('L', (-1, 0)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __init__(self, board, costs, start, end):
|
||||||
|
self.board = board
|
||||||
|
self.costs = costs
|
||||||
|
self.start_state = start
|
||||||
|
self.end_state = end
|
||||||
|
self.n_expands = 0
|
||||||
|
|
||||||
|
def get_start_node(self):
|
||||||
|
return Node(None, self.start_state, None, 0, 0)
|
||||||
|
|
||||||
|
def get_end_node(self):
|
||||||
|
return Node(None, self.end_state, None, 0, 0)
|
||||||
|
|
||||||
|
def is_end(self, node):
|
||||||
|
return node.state == self.end_state
|
||||||
|
|
||||||
|
def action_cost(self, state, action):
|
||||||
|
# for the MazeProblem, the cost of any action
|
||||||
|
# is stored at the coordinates of the successor state,
|
||||||
|
# and represents the cost of 'stepping onto' this
|
||||||
|
# position on the board
|
||||||
|
sx, sy = self.__delta_state(state, action)
|
||||||
|
return self.costs[sx, sy]
|
||||||
|
|
||||||
|
def successor(self, node, action):
|
||||||
|
# determine the next state
|
||||||
|
successor_state = self.__delta_state(node.state, action)
|
||||||
|
if successor_state is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# determine what it would cost to take this action in this state
|
||||||
|
cost = self.action_cost(node.state, action)
|
||||||
|
|
||||||
|
# add the next state to the list of successor nodes
|
||||||
|
return Node(
|
||||||
|
node,
|
||||||
|
successor_state,
|
||||||
|
action,
|
||||||
|
node.cost + cost,
|
||||||
|
node.depth + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_number_of_expanded_nodes(self):
|
||||||
|
return self.n_expands
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.n_expands = 0
|
||||||
|
|
||||||
|
def successors(self, node):
|
||||||
|
self.n_expands += 1
|
||||||
|
successor_nodes = []
|
||||||
|
for action in self.ACTIONS_DELTA.keys():
|
||||||
|
succ = self.successor(node, action)
|
||||||
|
if succ is not None and succ != node:
|
||||||
|
successor_nodes.append(succ)
|
||||||
|
return successor_nodes
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(dict(
|
||||||
|
type=self.__class__.__name__,
|
||||||
|
board=self.board.tolist(),
|
||||||
|
costs=self.costs.tolist(),
|
||||||
|
start_state=self.start_state,
|
||||||
|
end_state=self.end_state
|
||||||
|
))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def draw_nodes(fig, ax, name, node_collection, color, marker):
|
||||||
|
states = np.array([node.state for node in node_collection])
|
||||||
|
if len(states) > 0:
|
||||||
|
ax.scatter(states[:, 0], states[:, 1], color=color, label=name, marker=marker)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def plot_nodes(fig, ax, nodes):
|
||||||
|
if len(nodes) > 0:
|
||||||
|
if len(nodes[0]) == 3:
|
||||||
|
for (name, marker, node_collection), color in zip(nodes, TABLEAU_COLORS):
|
||||||
|
if len(node_collection) > 0:
|
||||||
|
Simple2DProblem.draw_nodes(fig, ax, name, node_collection, color, marker)
|
||||||
|
else:
|
||||||
|
for name, marker, node_collection, color in nodes:
|
||||||
|
if len(node_collection) > 0:
|
||||||
|
Simple2DProblem.draw_nodes(fig, ax, name, node_collection, color, marker)
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(0.5, -0.03),
|
||||||
|
loc='upper center',
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot_sequences(self, fig, ax, sequences):
|
||||||
|
start_node = self.get_start_node()
|
||||||
|
for (name, action_sequence), color in zip(sequences, XKCD_COLORS):
|
||||||
|
self.draw_path(fig, ax, name, start_node, action_sequence, color)
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(0.5, -0.03),
|
||||||
|
loc='upper center',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_path(self, fig, ax, name, start_node, action_sequence, color):
|
||||||
|
current = start_node
|
||||||
|
xs = [current.state[0]]
|
||||||
|
ys = [current.state[1]]
|
||||||
|
us = [0]
|
||||||
|
vs = [0]
|
||||||
|
|
||||||
|
length = len(action_sequence)
|
||||||
|
cost = 0
|
||||||
|
costs = [0] * length
|
||||||
|
for i, action in enumerate(action_sequence):
|
||||||
|
costs[i] = current.cost
|
||||||
|
xs.append(current.state[0])
|
||||||
|
ys.append(current.state[1])
|
||||||
|
current = self.successor(current, action)
|
||||||
|
dx, dy = self.ACTIONS_DELTA[action]
|
||||||
|
us.append(dx)
|
||||||
|
vs.append(-dy)
|
||||||
|
cost = current.cost
|
||||||
|
|
||||||
|
quiv = ax.quiver(
|
||||||
|
xs, ys, us, vs,
|
||||||
|
color=color,
|
||||||
|
label='{} l:{} c:{}'.format(name, length, cost),
|
||||||
|
scale_units='xy',
|
||||||
|
units='xy',
|
||||||
|
scale=1,
|
||||||
|
headwidth=1,
|
||||||
|
headlength=1,
|
||||||
|
linewidth=1,
|
||||||
|
picker=5
|
||||||
|
)
|
||||||
|
return quiv
|
||||||
|
|
||||||
|
def plot_field_and_costs_aux(self, fig, show_coordinates, show_grid,
|
||||||
|
field_ax=None, costs_ax=None):
|
||||||
|
|
||||||
|
if field_ax is None:
|
||||||
|
ax = field_ax = plt.subplot(121)
|
||||||
|
else:
|
||||||
|
ax = field_ax
|
||||||
|
|
||||||
|
ax.set_title('The field')
|
||||||
|
im = ax.imshow(self.board.T, cmap='gray_r')
|
||||||
|
|
||||||
|
divider = make_axes_locatable(ax)
|
||||||
|
cax = divider.append_axes('right', size='5%', pad=0)
|
||||||
|
cbar = fig.colorbar(im, cax=cax, orientation='vertical')
|
||||||
|
cbar.set_ticks([0, 1])
|
||||||
|
cbar.set_ticklabels([0, 1])
|
||||||
|
|
||||||
|
if costs_ax is None:
|
||||||
|
ax = costs_ax = plt.subplot(122, sharex=ax, sharey=ax)
|
||||||
|
else:
|
||||||
|
ax = costs_ax
|
||||||
|
|
||||||
|
ax.set_title('The costs (for stepping on a tile)')
|
||||||
|
im = ax.imshow(self.costs.T, cmap='viridis')
|
||||||
|
divider = make_axes_locatable(ax)
|
||||||
|
cax = divider.append_axes('right', size='5%', pad=0)
|
||||||
|
cbar = fig.colorbar(im, cax=cax, orientation='vertical')
|
||||||
|
ticks = np.arange(self.costs.min(), self.costs.max() + 1)
|
||||||
|
cbar.set_ticks(ticks)
|
||||||
|
cbar.set_ticklabels(ticks)
|
||||||
|
|
||||||
|
for ax in [field_ax, costs_ax]:
|
||||||
|
ax.tick_params(
|
||||||
|
top=show_coordinates,
|
||||||
|
left=show_coordinates,
|
||||||
|
labelleft=show_coordinates,
|
||||||
|
labeltop=show_coordinates,
|
||||||
|
right=False,
|
||||||
|
bottom=False,
|
||||||
|
labelbottom=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Major ticks
|
||||||
|
s = self.board.shape[0]
|
||||||
|
ax.set_xticks(np.arange(0, s, 1))
|
||||||
|
ax.set_yticks(np.arange(0, s, 1))
|
||||||
|
|
||||||
|
# Minor ticks
|
||||||
|
ax.set_xticks(np.arange(-.5, s, 1), minor=True)
|
||||||
|
ax.set_yticks(np.arange(-.5, s, 1), minor=True)
|
||||||
|
|
||||||
|
if show_grid:
|
||||||
|
for color, ax in zip(['m', 'w'], [field_ax, costs_ax]):
|
||||||
|
# Gridlines based on minor ticks
|
||||||
|
ax.grid(which='minor', color=color, linestyle='-', linewidth=1)
|
||||||
|
|
||||||
|
return field_ax, costs_ax
|
||||||
|
|
||||||
|
def visualize(self, sequences=None, show_coordinates=False, show_grid=False, plot_filename=None):
|
||||||
|
|
||||||
|
nodes = [
|
||||||
|
('start', 'o', [self.get_start_node()]),
|
||||||
|
('end', 'o', [self.get_end_node()])
|
||||||
|
]
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(10, 7))
|
||||||
|
field_ax, costs_ax = self.plot_field_and_costs_aux(fig, show_coordinates, show_grid)
|
||||||
|
if sequences is not None and len(sequences) > 0:
|
||||||
|
self.plot_sequences(fig, field_ax, sequences)
|
||||||
|
self.plot_sequences(fig, costs_ax, sequences)
|
||||||
|
|
||||||
|
if nodes is not None and len(nodes) > 0:
|
||||||
|
Simple2DProblem.plot_nodes(fig, field_ax, nodes)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
if plot_filename is not None:
|
||||||
|
plt.savefig(plot_filename)
|
||||||
|
plt.close(fig)
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
data = json.loads(jsonstring)
|
||||||
|
return Simple2DProblem(
|
||||||
|
np.array(data['board']),
|
||||||
|
np.array(data['costs']),
|
||||||
|
tuple(data['start_state']),
|
||||||
|
tuple(data['end_state'])
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(data):
|
||||||
|
return Simple2DProblem(
|
||||||
|
np.array(data['board']),
|
||||||
|
np.array(data['costs']),
|
||||||
|
tuple(data['start_state']),
|
||||||
|
tuple(data['end_state'])
|
||||||
|
)
|
||||||
|
|
||||||
|
def __delta_state(self, state, action):
|
||||||
|
# the old state's coordinates
|
||||||
|
x, y = state
|
||||||
|
|
||||||
|
# the deltas for each coordinates
|
||||||
|
dx, dy = self.ACTIONS_DELTA[action]
|
||||||
|
|
||||||
|
# compute the coordinates of the next state
|
||||||
|
sx = x + dx
|
||||||
|
sy = y + dy
|
||||||
|
|
||||||
|
if self.__on_board(sx, sy) and self.__walkable(sx, sy):
|
||||||
|
# (sx, sy) is a *valid* state if it is on the board
|
||||||
|
# and there is no wall where we want to go
|
||||||
|
return sx, sy
|
||||||
|
else:
|
||||||
|
# EIEIEIEIEI. up until assignment 1, this returned None :/
|
||||||
|
# this had no consequences on the correctness of the algorithms,
|
||||||
|
# but the explanations, and the self-edges were wrong
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
def __on_board(self, x, y):
|
||||||
|
size = len(self.board) # all boards are quadratic
|
||||||
|
return x >= 0 and x < size and y >= 0 and y < size
|
||||||
|
|
||||||
|
def __walkable(self, x, y):
|
||||||
|
return self.board[x, y] != enc.WALL
|
||||||
92
pig_lite/problem/base.py
Normal file
92
pig_lite/problem/base.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import hashlib
|
||||||
|
|
||||||
|
class Node(object):
|
||||||
|
def __init__(self, parent, state, action, cost, depth):
|
||||||
|
self.parent = parent
|
||||||
|
self.state = state
|
||||||
|
self.action = action
|
||||||
|
self.cost = cost
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
def key(self):
|
||||||
|
# if state is composed of other stuff (dict, set, ...)
|
||||||
|
# make it a tuple containing hashable datatypes
|
||||||
|
# (this is supposed to be overridden by subclasses)
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.key())
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if type(self) == type(other):
|
||||||
|
return self.key() == other.key()
|
||||||
|
raise ValueError('cannot simply compare two different node types')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Node(id:{}, parent:{}, state:{}, action:{}, cost:{}, depth:{})'.format(
|
||||||
|
id(self),
|
||||||
|
id(self.parent),
|
||||||
|
self.state,
|
||||||
|
self.action,
|
||||||
|
self.cost,
|
||||||
|
self.depth
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_action_sequence(self):
|
||||||
|
current = self
|
||||||
|
reverse_sequence = []
|
||||||
|
while current.parent is not None:
|
||||||
|
reverse_sequence.append(current.action)
|
||||||
|
current = current.parent
|
||||||
|
return list(reversed(reverse_sequence))
|
||||||
|
|
||||||
|
def get_action_sequence_hash(self):
|
||||||
|
action_sequence = self.get_action_sequence()
|
||||||
|
action_sequence_as_str = ','.join(map(str, action_sequence))
|
||||||
|
action_sequence_hash = hashlib.sha256(action_sequence_as_str.encode('UTF-8')).hexdigest() # should solution node return hashcode?
|
||||||
|
return action_sequence_hash
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
print(f"state {self.state} was reached following the sequence {self.get_action_sequence()} (cost: {self.cost}, depth: {self.depth})")
|
||||||
|
|
||||||
|
|
||||||
|
class Problem(object):
|
||||||
|
def get_number_of_expanded_nodes(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_start_node(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_end_node(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def is_end(self, node):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def action_cost(self, state, action):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def successors(self, node):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def visualize(self, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_action_sequence(self, end: Node):
|
||||||
|
if end is None:
|
||||||
|
return list()
|
||||||
|
return end.get_action_sequence()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_minimum_problem_size():
|
||||||
|
raise NotImplementedError()
|
||||||
529
pig_lite/problem/simple_2d.py
Normal file
529
pig_lite/problem/simple_2d.py
Normal file
@@ -0,0 +1,529 @@
|
|||||||
|
from pig_lite.problem.base import Problem, Node
|
||||||
|
from pig_lite.instance_generation import enc
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from collections import OrderedDict
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||||
|
from matplotlib.colors import TABLEAU_COLORS, XKCD_COLORS
|
||||||
|
|
||||||
|
class BaseLevel():
|
||||||
|
def __init__(self, rng, size) -> None:
|
||||||
|
self.rng = rng
|
||||||
|
self.size = size
|
||||||
|
self.field = None
|
||||||
|
self.costs = None
|
||||||
|
self.start = None
|
||||||
|
self.end = None
|
||||||
|
|
||||||
|
self.initialize_level()
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_field(self):
|
||||||
|
return self.field
|
||||||
|
|
||||||
|
def get_costs(self):
|
||||||
|
return self.costs
|
||||||
|
|
||||||
|
def get_start(self):
|
||||||
|
return self.start
|
||||||
|
|
||||||
|
def get_end(self):
|
||||||
|
return self.end
|
||||||
|
|
||||||
|
|
||||||
|
class MazeLevel(BaseLevel):
|
||||||
|
# this method generates a random maze according to prim's randomized
|
||||||
|
# algorithm
|
||||||
|
# http://en.wikipedia.org/wiki/Maze_generation_algorithm#Randomized_Prim.27s_algorithm
|
||||||
|
|
||||||
|
def __init__(self, rng, size):
|
||||||
|
super().__init__(rng, size)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
|
||||||
|
self.field = np.full((self.size, self.size), enc.WALL, dtype=np.int8)
|
||||||
|
self.costs = self.rng.randint(1, 5, self.field.shape, dtype=np.int8)
|
||||||
|
|
||||||
|
self.start = (0, 0)
|
||||||
|
|
||||||
|
self.deltas = [
|
||||||
|
(0, 1),
|
||||||
|
(0, -1),
|
||||||
|
(1, 0),
|
||||||
|
(-1, 0)
|
||||||
|
]
|
||||||
|
self.random_walk()
|
||||||
|
end = np.where(self.field == enc.SPACE)
|
||||||
|
self.end = (int(end[0][-1]), int(end[1][-1]))
|
||||||
|
|
||||||
|
self.replace_walls_with_high_cost_tiles()
|
||||||
|
|
||||||
|
def replace_walls_with_high_cost_tiles(self):
|
||||||
|
# select only coordinates of walls
|
||||||
|
walls = np.where(self.field == enc.WALL)
|
||||||
|
|
||||||
|
n_walls = len(walls[0])
|
||||||
|
|
||||||
|
# replace about a tenth of the walls...
|
||||||
|
to_replace = self.rng.randint(0, n_walls, n_walls // 9)
|
||||||
|
|
||||||
|
# ... with space, but very *costly* space (it's trap!)
|
||||||
|
for ri in to_replace:
|
||||||
|
x, y = walls[0][ri], walls[1][ri]
|
||||||
|
self.field[x, y] = enc.SPACE
|
||||||
|
self.costs[x, y] = 9
|
||||||
|
|
||||||
|
def random_walk(self):
|
||||||
|
frontier = list()
|
||||||
|
|
||||||
|
sx, sy = self.start
|
||||||
|
self.field[sx, sy] = enc.SPACE
|
||||||
|
frontier.extend(self.get_walls(self.start))
|
||||||
|
|
||||||
|
while len(frontier) > 0:
|
||||||
|
current, opposing = frontier[self.rng.randint(len(frontier))]
|
||||||
|
|
||||||
|
cx, cy = current
|
||||||
|
ox, oy = opposing
|
||||||
|
if self.field[ox, oy] == enc.WALL:
|
||||||
|
self.field[cx, cy] = enc.SPACE
|
||||||
|
self.field[ox, oy] = enc.SPACE
|
||||||
|
frontier.extend(self.get_walls(opposing))
|
||||||
|
else:
|
||||||
|
frontier.remove((current, opposing))
|
||||||
|
|
||||||
|
def in_bounds(self, position):
|
||||||
|
x, y = position
|
||||||
|
return x >= 0 and y >= 0 and x < self.size and y < self.size
|
||||||
|
|
||||||
|
def get_walls(self, position):
|
||||||
|
walls = []
|
||||||
|
px, py = position
|
||||||
|
for dx, dy in self.deltas:
|
||||||
|
cx = px + dx
|
||||||
|
cy = py + dy
|
||||||
|
current = (cx, cy)
|
||||||
|
|
||||||
|
ox = px + 2 * dx
|
||||||
|
oy = py + 2 * dy
|
||||||
|
opposing = (ox, oy)
|
||||||
|
|
||||||
|
if (self.in_bounds(current) and self.field[cx, cy] == enc.WALL and self.in_bounds(opposing)):
|
||||||
|
walls.append((current, opposing))
|
||||||
|
return walls
|
||||||
|
|
||||||
|
|
||||||
|
# this is code taken from
|
||||||
|
# https://github.com/dandrino/terrain-erosion-3-ways/blob/master/util.py
|
||||||
|
# Copyright (c) 2018 Daniel Andrino
|
||||||
|
# (project is MIT licensed)
|
||||||
|
def fbm(shape, p, lower=-np.inf, upper=np.inf):
|
||||||
|
freqs = tuple(np.fft.fftfreq(n, d=1.0 / n) for n in shape)
|
||||||
|
freq_radial = np.hypot(*np.meshgrid(*freqs))
|
||||||
|
envelope = (np.power(freq_radial, p, where=freq_radial != 0) *
|
||||||
|
(freq_radial > lower) * (freq_radial < upper))
|
||||||
|
envelope[0][0] = 0.0
|
||||||
|
phase_noise = np.exp(2j * np.pi * np.random.rand(*shape))
|
||||||
|
return np.real(np.fft.ifft2(np.fft.fft2(phase_noise) * envelope))
|
||||||
|
|
||||||
|
|
||||||
|
class TerrainLevel(BaseLevel):
|
||||||
|
def __init__(self, rng, size):
|
||||||
|
super().__init__(rng, size)
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
|
||||||
|
self.field = np.full((self.size, self.size), enc.SPACE, dtype=np.int8)
|
||||||
|
|
||||||
|
self.costs = fbm(self.field.shape, -2)
|
||||||
|
self.costs -= self.costs.min()
|
||||||
|
self.costs /= self.costs.max()
|
||||||
|
self.costs *= 9
|
||||||
|
self.costs += 1
|
||||||
|
self.costs = self.costs.astype(int)
|
||||||
|
|
||||||
|
self.start = (0, 0)
|
||||||
|
self.end = (self.size - 1, self.size - 1)
|
||||||
|
|
||||||
|
x = 0
|
||||||
|
y = self.size - 1
|
||||||
|
for i in range(0, self.size):
|
||||||
|
self.field[x, y] = enc.WALL
|
||||||
|
x += 1
|
||||||
|
y -= 1
|
||||||
|
|
||||||
|
self.replace_one_or_more_walls()
|
||||||
|
|
||||||
|
def replace_one_or_more_walls(self):
|
||||||
|
# select only coordinates of walls
|
||||||
|
walls = np.where(self.field == enc.WALL)
|
||||||
|
n_walls = len(walls[0])
|
||||||
|
n_replace = self.rng.randint(1, max(2, n_walls // 5))
|
||||||
|
to_replace = self.rng.randint(0, n_walls, n_replace)
|
||||||
|
|
||||||
|
for ri in to_replace:
|
||||||
|
x, y = walls[0][ri], walls[1][ri]
|
||||||
|
self.field[x, y] = enc.SPACE
|
||||||
|
|
||||||
|
|
||||||
|
class RoomLevel(BaseLevel):
|
||||||
|
def __init__(self, rng, size):
|
||||||
|
super().__init__(rng, size)
|
||||||
|
|
||||||
|
def initialize_level(self):
|
||||||
|
self.field = np.full((self.size, self.size), enc.SPACE, dtype=np.int8)
|
||||||
|
self.costs = np.ones_like(self.field, dtype=np.float32)
|
||||||
|
|
||||||
|
k = 1
|
||||||
|
self.subdivide(self.field.view(), self.costs.view(), k, 0, 0)
|
||||||
|
|
||||||
|
# such a *crutch*!
|
||||||
|
# this 'repairs' dead ends. horrible stuff.
|
||||||
|
for x in range(1, self.size - 1):
|
||||||
|
for y in range(1, self.size - 1):
|
||||||
|
s = 0
|
||||||
|
s += self.field[x - 1, y]
|
||||||
|
s += self.field[x + 1, y]
|
||||||
|
s += self.field[x, y - 1]
|
||||||
|
s += self.field[x, y + 1]
|
||||||
|
if self.field[x, y] == enc.SPACE and s >= 3:
|
||||||
|
self.field[x - 1, y] = enc.SPACE
|
||||||
|
self.field[x + 1, y] = enc.SPACE
|
||||||
|
self.field[x, y - 1] = enc.SPACE
|
||||||
|
self.field[x, y + 1] = enc.SPACE
|
||||||
|
|
||||||
|
spaces = np.where(self.field == enc.SPACE)
|
||||||
|
n_spaces = len(spaces[0])
|
||||||
|
|
||||||
|
n_danger = self.rng.randint(3, 7)
|
||||||
|
dangers = self.rng.choice(range(n_spaces), n_danger, replace=False)
|
||||||
|
for di in dangers:
|
||||||
|
rx, ry = np.unravel_index(di, (self.size, self.size))
|
||||||
|
const = max(1., self.rng.randint(self.size // 5, self.size // 2))
|
||||||
|
for x in range(self.size):
|
||||||
|
for y in range(self.size):
|
||||||
|
distance = np.sqrt((rx - x) ** 2 + (ry - y) ** 2)
|
||||||
|
self.costs[x, y] = self.costs[x, y] + (1. / (const + distance))
|
||||||
|
|
||||||
|
self.costs = self.costs - self.costs.min()
|
||||||
|
self.costs = self.costs / self.costs.max()
|
||||||
|
self.costs = self.costs * 9
|
||||||
|
self.costs = self.costs + 1
|
||||||
|
self.costs = self.costs.astype(int)
|
||||||
|
|
||||||
|
start_choice = 0
|
||||||
|
end_choice = -1
|
||||||
|
|
||||||
|
self.start = (int(spaces[0][start_choice]), int(spaces[1][start_choice]))
|
||||||
|
self.end = (int(spaces[0][end_choice]), int(spaces[1][end_choice]))
|
||||||
|
|
||||||
|
if self.start == self.end:
|
||||||
|
raise RuntimeError('should never happen')
|
||||||
|
|
||||||
|
def subdivide(self, current, costs, k, d, previous_door):
|
||||||
|
w, h = current.shape
|
||||||
|
random_stop = self.rng.randint(0, 10) == 0 and d > 2
|
||||||
|
if w <= 2 * k + 1 or h <= 2 * k + 1 or random_stop:
|
||||||
|
return
|
||||||
|
|
||||||
|
split = previous_door
|
||||||
|
while split == previous_door:
|
||||||
|
split = self.rng.randint(k, w - k)
|
||||||
|
current[split, :] = enc.WALL
|
||||||
|
door = self.rng.randint(k, h - k)
|
||||||
|
current[split, door] = enc.SPACE
|
||||||
|
|
||||||
|
self.subdivide(
|
||||||
|
current[:split, :].T,
|
||||||
|
costs[:split, :].T,
|
||||||
|
k,
|
||||||
|
d + 1,
|
||||||
|
door
|
||||||
|
)
|
||||||
|
self.subdivide(
|
||||||
|
current[split + 1:, :].T,
|
||||||
|
costs[split + 1:, :].T,
|
||||||
|
k,
|
||||||
|
d + 1,
|
||||||
|
door
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Simple2DProblem(Problem):
|
||||||
|
"""
|
||||||
|
the states are the positions on the board that the agent can walk on
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACTIONS_DELTA = OrderedDict([
|
||||||
|
('R', (+1, 0)),
|
||||||
|
('U', (0, -1)),
|
||||||
|
('D', (0, +1)),
|
||||||
|
('L', (-1, 0)),
|
||||||
|
])
|
||||||
|
|
||||||
|
def __init__(self, board, costs, start, end):
|
||||||
|
self.board = board
|
||||||
|
self.costs = costs
|
||||||
|
self.start_state = start
|
||||||
|
self.end_state = end
|
||||||
|
self.n_expands = 0
|
||||||
|
|
||||||
|
def get_start_node(self):
|
||||||
|
return Node(None, self.start_state, None, 0, 0)
|
||||||
|
|
||||||
|
def get_end_node(self):
|
||||||
|
return Node(None, self.end_state, None, 0, 0)
|
||||||
|
|
||||||
|
def is_end(self, node):
|
||||||
|
return node.state == self.end_state
|
||||||
|
|
||||||
|
def action_cost(self, state, action):
|
||||||
|
# for the MazeProblem, the cost of any action
|
||||||
|
# is stored at the coordinates of the successor state,
|
||||||
|
# and represents the cost of 'stepping onto' this
|
||||||
|
# position on the board
|
||||||
|
sx, sy = self.__delta_state(state, action)
|
||||||
|
return self.costs[sx, sy]
|
||||||
|
|
||||||
|
def successor(self, node, action):
|
||||||
|
# determine the next state
|
||||||
|
successor_state = self.__delta_state(node.state, action)
|
||||||
|
if successor_state is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# determine what it would cost to take this action in this state
|
||||||
|
cost = self.action_cost(node.state, action)
|
||||||
|
|
||||||
|
# add the next state to the list of successor nodes
|
||||||
|
return Node(
|
||||||
|
node,
|
||||||
|
successor_state,
|
||||||
|
action,
|
||||||
|
node.cost + cost,
|
||||||
|
node.depth + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_number_of_expanded_nodes(self):
|
||||||
|
return self.n_expands
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.n_expands = 0
|
||||||
|
|
||||||
|
def successors(self, node):
|
||||||
|
self.n_expands += 1
|
||||||
|
successor_nodes = []
|
||||||
|
for action in self.ACTIONS_DELTA.keys():
|
||||||
|
succ = self.successor(node, action)
|
||||||
|
if succ is not None and succ != node:
|
||||||
|
successor_nodes.append(succ)
|
||||||
|
return successor_nodes
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return json.dumps(dict(
|
||||||
|
type=self.__class__.__name__,
|
||||||
|
board=self.board.tolist(),
|
||||||
|
costs=self.costs.tolist(),
|
||||||
|
start_state=self.start_state,
|
||||||
|
end_state=self.end_state
|
||||||
|
))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def draw_nodes(fig, ax, name, node_collection, color, marker):
|
||||||
|
states = np.array([node.state for node in node_collection])
|
||||||
|
if len(states) > 0:
|
||||||
|
ax.scatter(states[:, 0], states[:, 1], color=color, label=name, marker=marker)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def plot_nodes(fig, ax, nodes):
|
||||||
|
if len(nodes) > 0:
|
||||||
|
if len(nodes[0]) == 3:
|
||||||
|
for (name, marker, node_collection), color in zip(nodes, TABLEAU_COLORS):
|
||||||
|
if len(node_collection) > 0:
|
||||||
|
Simple2DProblem.draw_nodes(fig, ax, name, node_collection, color, marker)
|
||||||
|
else:
|
||||||
|
for name, marker, node_collection, color in nodes:
|
||||||
|
if len(node_collection) > 0:
|
||||||
|
Simple2DProblem.draw_nodes(fig, ax, name, node_collection, color, marker)
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(0.5, -0.03),
|
||||||
|
loc='upper center',
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot_sequences(self, fig, ax, sequences):
|
||||||
|
start_node = self.get_start_node()
|
||||||
|
for (name, action_sequence), color in zip(sequences, XKCD_COLORS):
|
||||||
|
self.draw_path(fig, ax, name, start_node, action_sequence, color)
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(0.5, -0.03),
|
||||||
|
loc='upper center',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def draw_path(self, fig, ax, name, start_node, action_sequence, color):
|
||||||
|
current = start_node
|
||||||
|
xs = [current.state[0]]
|
||||||
|
ys = [current.state[1]]
|
||||||
|
us = [0]
|
||||||
|
vs = [0]
|
||||||
|
|
||||||
|
length = len(action_sequence)
|
||||||
|
cost = 0
|
||||||
|
costs = [0] * length
|
||||||
|
for i, action in enumerate(action_sequence):
|
||||||
|
costs[i] = current.cost
|
||||||
|
xs.append(current.state[0])
|
||||||
|
ys.append(current.state[1])
|
||||||
|
current = self.successor(current, action)
|
||||||
|
dx, dy = self.ACTIONS_DELTA[action]
|
||||||
|
us.append(dx)
|
||||||
|
vs.append(-dy)
|
||||||
|
cost = current.cost
|
||||||
|
|
||||||
|
quiv = ax.quiver(
|
||||||
|
xs, ys, us, vs,
|
||||||
|
color=color,
|
||||||
|
label='{} l:{} c:{}'.format(name, length, cost),
|
||||||
|
scale_units='xy',
|
||||||
|
units='xy',
|
||||||
|
scale=1,
|
||||||
|
headwidth=1,
|
||||||
|
headlength=1,
|
||||||
|
linewidth=1,
|
||||||
|
picker=5
|
||||||
|
)
|
||||||
|
return quiv
|
||||||
|
|
||||||
|
def plot_field_and_costs_aux(self, fig, show_coordinates, show_grid,
|
||||||
|
field_ax=None, costs_ax=None):
|
||||||
|
|
||||||
|
if field_ax is None:
|
||||||
|
ax = field_ax = plt.subplot(121)
|
||||||
|
else:
|
||||||
|
ax = field_ax
|
||||||
|
|
||||||
|
ax.set_title('The field')
|
||||||
|
im = ax.imshow(self.board.T, cmap='gray_r')
|
||||||
|
|
||||||
|
divider = make_axes_locatable(ax)
|
||||||
|
cax = divider.append_axes('right', size='5%', pad=0)
|
||||||
|
cbar = fig.colorbar(im, cax=cax, orientation='vertical')
|
||||||
|
cbar.set_ticks([0, 1])
|
||||||
|
cbar.set_ticklabels([0, 1])
|
||||||
|
|
||||||
|
if costs_ax is None:
|
||||||
|
ax = costs_ax = plt.subplot(122, sharex=ax, sharey=ax)
|
||||||
|
else:
|
||||||
|
ax = costs_ax
|
||||||
|
|
||||||
|
ax.set_title('The costs (for stepping on a tile)')
|
||||||
|
im = ax.imshow(self.costs.T, cmap='viridis')
|
||||||
|
divider = make_axes_locatable(ax)
|
||||||
|
cax = divider.append_axes('right', size='5%', pad=0)
|
||||||
|
cbar = fig.colorbar(im, cax=cax, orientation='vertical')
|
||||||
|
ticks = np.arange(self.costs.min(), self.costs.max() + 1)
|
||||||
|
cbar.set_ticks(ticks)
|
||||||
|
cbar.set_ticklabels(ticks)
|
||||||
|
|
||||||
|
for ax in [field_ax, costs_ax]:
|
||||||
|
ax.tick_params(
|
||||||
|
top=show_coordinates,
|
||||||
|
left=show_coordinates,
|
||||||
|
labelleft=show_coordinates,
|
||||||
|
labeltop=show_coordinates,
|
||||||
|
right=False,
|
||||||
|
bottom=False,
|
||||||
|
labelbottom=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Major ticks
|
||||||
|
s = self.board.shape[0]
|
||||||
|
ax.set_xticks(np.arange(0, s, 1))
|
||||||
|
ax.set_yticks(np.arange(0, s, 1))
|
||||||
|
|
||||||
|
# Minor ticks
|
||||||
|
ax.set_xticks(np.arange(-.5, s, 1), minor=True)
|
||||||
|
ax.set_yticks(np.arange(-.5, s, 1), minor=True)
|
||||||
|
|
||||||
|
if show_grid:
|
||||||
|
for color, ax in zip(['m', 'w'], [field_ax, costs_ax]):
|
||||||
|
# Gridlines based on minor ticks
|
||||||
|
ax.grid(which='minor', color=color, linestyle='-', linewidth=1)
|
||||||
|
|
||||||
|
return field_ax, costs_ax
|
||||||
|
|
||||||
|
def visualize(self, sequences=None, show_coordinates=False, show_grid=False, plot_filename=None):
|
||||||
|
|
||||||
|
nodes = [
|
||||||
|
('start', 'o', [self.get_start_node()]),
|
||||||
|
('end', 'o', [self.get_end_node()])
|
||||||
|
]
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(10, 7))
|
||||||
|
field_ax, costs_ax = self.plot_field_and_costs_aux(fig, show_coordinates, show_grid)
|
||||||
|
if sequences is not None and len(sequences) > 0:
|
||||||
|
self.plot_sequences(fig, field_ax, sequences)
|
||||||
|
self.plot_sequences(fig, costs_ax, sequences)
|
||||||
|
|
||||||
|
if nodes is not None and len(nodes) > 0:
|
||||||
|
Simple2DProblem.plot_nodes(fig, field_ax, nodes)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
if plot_filename is not None:
|
||||||
|
plt.savefig(plot_filename)
|
||||||
|
plt.close(fig)
|
||||||
|
else:
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(jsonstring):
|
||||||
|
data = json.loads(jsonstring)
|
||||||
|
return Simple2DProblem(
|
||||||
|
np.array(data['board']),
|
||||||
|
np.array(data['costs']),
|
||||||
|
tuple(data['start_state']),
|
||||||
|
tuple(data['end_state'])
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(data):
|
||||||
|
return Simple2DProblem(
|
||||||
|
np.array(data['board']),
|
||||||
|
np.array(data['costs']),
|
||||||
|
tuple(data['start_state']),
|
||||||
|
tuple(data['end_state'])
|
||||||
|
)
|
||||||
|
|
||||||
|
def __delta_state(self, state, action):
|
||||||
|
# the old state's coordinates
|
||||||
|
x, y = state
|
||||||
|
|
||||||
|
# the deltas for each coordinates
|
||||||
|
dx, dy = self.ACTIONS_DELTA[action]
|
||||||
|
|
||||||
|
# compute the coordinates of the next state
|
||||||
|
sx = x + dx
|
||||||
|
sy = y + dy
|
||||||
|
|
||||||
|
if self.__on_board(sx, sy) and self.__walkable(sx, sy):
|
||||||
|
# (sx, sy) is a *valid* state if it is on the board
|
||||||
|
# and there is no wall where we want to go
|
||||||
|
return sx, sy
|
||||||
|
else:
|
||||||
|
# EIEIEIEIEI. up until assignment 1, this returned None :/
|
||||||
|
# this had no consequences on the correctness of the algorithms,
|
||||||
|
# but the explanations, and the self-edges were wrong
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
def __on_board(self, x, y):
|
||||||
|
size = len(self.board) # all boards are quadratic
|
||||||
|
return x >= 0 and x < size and y >= 0 and y < size
|
||||||
|
|
||||||
|
def __walkable(self, x, y):
|
||||||
|
return self.board[x, y] != enc.WALL
|
||||||
Reference in New Issue
Block a user