import json

import numpy as np

from pig_lite.environment.base import Environment

# Action offsets in (di, dj) grid coordinates, in the same order as NAMES.
DELTAS = [
    (-1, 0),
    (+1, 0),
    (0, -1),
    (0, +1)
]
NAMES = [
    'left',
    'right',
    'up',
    'down'
]


def sample(rng, elements):
    """ Samples an element of `elements`, where each element's first entry is its probability. """
    # inverse-CDF sampling: pick the first index whose cumulative
    # probability exceeds a uniform draw from [0, 1)
    csp = np.cumsum([elm[0] for elm in elements])
    idx = np.argmax(csp > rng.uniform(0, 1))
    return elements[idx]
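
# For example, with elements = [(0.8, 'a'), (0.1, 'b'), (0.1, 'c')] the
# cumulative probabilities are [0.8, 0.9, 1.0]; a uniform draw of 0.85
# selects index 1, so the call returns (0.1, 'b').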


class Gridworld(Environment):
    """ A stochastic gridworld: actions slip sideways with probability 0.2, terminal states are absorbing. """

    def __init__(self, seed, dones, rewards, starts):
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        self.dones = dones
        self.rewards = rewards
        self.starts = starts

        self.__compute_P()

    def reset(self):
        """ Resets this gridworld to a randomly sampled start state. """
        _, self.state = sample(self.rng, self.starts)
        return self.state

    def step(self, action):
        """ Performs `action`; the next state, reward and done flag are sampled from the transition model `self.P`. """
        _, self.state, reward, done = sample(self.rng, self.P[self.state][action])
        return self.state, reward, done
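
    # Example interaction (actions are indices into DELTAS/NAMES):
    #   state = env.reset()
    #   state, reward, done = env.step(0)  # 0 == 'left'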

    def get_n_actions(self):
        """ Returns the number of actions available in this gridworld. """
        return 4

    def get_n_states(self):
        """ Returns the number of states available in this gridworld. """
        return np.prod(self.dones.shape)

    def get_gamma(self):
        """ Returns the discount factor gamma for this gridworld. """
        return 0.99

    def __compute_P(self):
        """ Computes and stores the transition model for this gridworld. """
        w, h = self.dones.shape

        def inbounds(i, j):
            """ Checks whether coordinates i and j are within the grid. """
            return i >= 0 and j >= 0 and i < w and j < h

        self.P = dict()
        for i in range(0, w):
            for j in range(0, h):
                state = j * w + i
                self.P[state] = dict()

                if self.dones[i, j]:
                    # make terminal states absorbing: every action loops back with reward 0
                    for action in range(self.get_n_actions()):
                        self.P[state][action] = [(1, state, 0, True)]
                else:
                    for action, (dx, dy) in enumerate(DELTAS):
                        # move in the intended direction with probability 0.8;
                        # slip into one of the two orthogonal directions with probability 0.1 each
                        ortho_dir_probs = [
                            (0.8, dx, dy),
                            (0.1, dy, dx),
                            (0.1, -dy, -dx)
                        ]
                        transitions = []
                        for p, di, dj in ortho_dir_probs:
                            ni = i + di
                            nj = j + dj
                            if inbounds(ni, nj):
                                # we move
                                sprime = nj * w + ni
                                done = self.dones[ni, nj]
                                reward = self.rewards[ni, nj]
                                transitions.append((p, sprime, reward, done))
                            else:
                                # we bounced off the wall, so we stay in the same state
                                sprime = state
                                done = self.dones[i, j]
                                reward = self.rewards[i, j]
                                transitions.append((p, sprime, reward, done))

                        self.P[state][action] = transitions
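
    # The resulting structure maps state -> action -> list of transitions,
    # each a (probability, next_state, reward, done) tuple. For an interior,
    # non-terminal state s and action 0 ('left') it looks like
    #   self.P[s][0] == [(0.8, s - 1, ...), (0.1, s - w, ...), (0.1, s + w, ...)]
    # where w is the grid width.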

    def to_json(self):
        """ Serialises this gridworld to a JSON string. """
        return json.dumps(dict(
            type=self.__class__.__name__,
            seed=self.seed,
            dones=self.dones.tolist(),
            rewards=self.rewards.tolist(),
            starts=self.starts.tolist()
        ))
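
    # `to_json` and `from_json` round-trip, e.g.:
    #   restored = Gridworld.from_json(env.to_json())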

    @staticmethod
    def from_json(jsonstring):
        """ Parses the given JSON string and creates a gridworld from it. """
        data = json.loads(jsonstring)
        return Gridworld(
            data['seed'],
            np.array(data['dones']),
            np.array(data['rewards']),
            np.array(data['starts'], dtype=np.int64),
        )

    @staticmethod
    def from_dict(data):
        """ Creates a gridworld from the given data dictionary. """
        return Gridworld(
            data['seed'],
            np.array(data['dones']),
            np.array(data['rewards']),
            np.array(data['starts'], dtype=np.int64),
        )

    @staticmethod
    def get_random_instance(rng, size):
        """ Given a random generator and problem size, generates a Gridworld instance. """
        dones, rewards, starts = Gridworld.__generate(rng, size)
        return Gridworld(rng.randint(0, 2 ** 31), dones, rewards, starts)
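
    # e.g. env = Gridworld.get_random_instance(np.random.RandomState(0), size=5)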

    @staticmethod
    def __generate(rng, size):
        """ Helper function that generates dones, rewards, and starts for a random Gridworld instance. """
        dones = np.full((size, size), False, dtype=bool)
        # every step costs -1 by default
        rewards = np.full((size, size), -1, dtype=np.int8)

        # collect all interior cells (the border is kept trap-free)
        coordinates = []
        for i in range(1, size - 1):
            for j in range(1, size - 1):
                coordinates.append((i, j))
        indices = np.arange(len(coordinates))

        # turn roughly 10% of the interior cells into terminal traps
        chosen = rng.choice(indices, max(1, len(indices) // 10), replace=False)

        for c in chosen:
            x, y = coordinates[c]
            dones[x, y] = True
            rewards[x, y] = -100

        # start deterministically (probability 1) in state 0; the goal sits in
        # the opposite corner and yields +100
        starts = np.array([[1, 0]])
        dones[-1, -1] = True
        rewards[-1, -1] = 100

        return dones, rewards, starts

    @staticmethod
    def get_minimum_problem_size():
        """ Returns the smallest grid size for which instances can be generated. """
        return 3

    def visualize(self, outcome, coords=False, grid=False):
        """ Visualisation function for this gridworld; plots environment, policy, V, and Q. """
        policy = None
        Q = None
        V = None
        if outcome is not None:
            if outcome.policy is not None:
                policy = outcome.policy

            if outcome.V is not None:
                V = outcome.V

            if outcome.Q is not None:
                Q = outcome.Q

        self._plot_environment_and_policy(policy, V, Q, show_coordinates=coords, show_grid=grid)
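
    # `outcome` is expected to expose .policy, .V and .Q attributes (any of
    # which may be None), e.g. the result object returned by a solver.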

    def _plot_environment_and_policy(self, policy=None, V=None, Q=None, show_coordinates=False,
                                     show_grid=False, plot_filename=None, debug_info=False):
        """ Plots the environment together with the policy and the value functions. """
        import matplotlib.pyplot as plt
        fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
        dones_ax = axes[0, 0]
        rewards_ax = axes[0, 1]
        V_ax = axes[1, 0]
        Q_ax = axes[1, 1]

        dones_ax.set_title('Terminal States and Policy')
        dones_ax.imshow(self.dones.T, cmap='gray_r', vmin=0, vmax=4)

        rewards_ax.set_title('Immediate Rewards')
        rewards_ax.imshow(self.rewards.T, cmap='RdBu_r', vmin=-25, vmax=25)

        if policy is not None and len(policy) > 0:
            self._plot_policy(dones_ax, policy)

        w, h = self.dones.shape
        V_ax.set_title('State Value Function $V(s)$')
        if V is not None:
            V_array = V.reshape(self.dones.shape).T
            # symmetric color range around zero, guarding against an all-zero V
            r = max(1e-13, np.max(np.abs(V_array)))
            V_ax.imshow(V_array.T, cmap='RdBu_r', vmin=-r, vmax=r)

            if debug_info:
                for s in range(len(V)):
                    sy, sx = divmod(s, w)
                    V_ax.text(sx, sy, f'{sx},{sy}:{s}',
                              color='w', fontdict=dict(size=6),
                              horizontalalignment='center', verticalalignment='center')

        Q_ax.set_title('State Action Value Function $Q(s, a)$')
        if Q is not None:
            poly_patches_q_values = self._draw_Q(Q_ax, Q, debug_info)

            def format_coord(x, y):
                """ Shows the Q-value of the action wedge under the cursor in the status bar. """
                for poly_patch, q_value in poly_patches_q_values:
                    if poly_patch.contains_point(Q_ax.transData.transform((x, y))):
                        return f'x:{x:4.2f} y:{y:4.2f} {q_value}'
                return f'x:{x:4.2f} y:{y:4.2f}'

            Q_ax.format_coord = format_coord

        for ax in [dones_ax, rewards_ax, V_ax, Q_ax]:
            ax.tick_params(
                top=show_coordinates,
                left=show_coordinates,
                labelleft=show_coordinates,
                labeltop=show_coordinates,
                right=False,
                bottom=False,
                labelbottom=False
            )

            # Major ticks
            s = self.dones.shape[0]
            ax.set_xticks(np.arange(0, s, 1))
            ax.set_yticks(np.arange(0, s, 1))

            # Minor ticks
            ax.set_xticks(np.arange(-.5, s, 1), minor=True)
            ax.set_yticks(np.arange(-.5, s, 1), minor=True)

        if show_grid:
            for color, ax in zip(['m', 'w', 'w'], [dones_ax, rewards_ax, V_ax]):
                # Gridlines based on minor ticks
                ax.grid(which='minor', color=color, linestyle='-', linewidth=1)

        plt.tight_layout()
        if plot_filename is not None:
            plt.savefig(plot_filename)
            plt.close(fig)
        else:
            plt.show()

    def _plot_policy(self, ax, policy):
        """ Plots the policy as one quiver of arrows per action. """
        w, h = self.dones.shape
        xs = np.arange(w)
        ys = np.arange(h)
        xx, yy = np.meshgrid(xs, ys)

        # we need a quiver for each of the four actions
        quivers = [list() for _ in range(self.get_n_actions())]

        # collect the arrow components for every state
        for s in range(self.get_n_states()):
            y, x = divmod(s, w)
            if self.dones[x, y]:
                # no arrows in terminal states
                for a in range(self.get_n_actions()):
                    quivers[a].append((0., 0.))
            else:
                for a in range(self.get_n_actions()):
                    wdx, wdy = DELTAS[a]
                    # negate dy because imshow puts the origin in the top-left corner
                    corrected = np.array([wdx, -wdy])
                    # scale each arrow by the probability the policy assigns to the action
                    quivers[a].append(corrected * policy[s, a])

        # plot each quiver
        for quiver in quivers:
            q = np.array(quiver)
            ax.quiver(xx, yy, q[:, 0], q[:, 1], units='xy', scale=1.5)
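
    # Note: `_plot_policy` expects `policy` with shape (n_states, n_actions),
    # holding per-state action probabilities; `_draw_Q` below expects `Q` with
    # the same shape, holding state-action values.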

    def _draw_Q(self, ax, Q, debug_info):
        """ Draws Q as four triangular wedges per cell, one per action. """
        import matplotlib.pyplot as plt
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import Normalize
        from matplotlib.patches import Rectangle, Polygon

        pattern = np.zeros(self.dones.shape)
        ax.imshow(pattern, cmap='gray_r')

        w, h = self.dones.shape

        # symmetric color range around zero, guarding against an all-zero Q
        r = max(1e-13, np.max(np.abs(Q)))
        norm = Normalize(vmin=-r, vmax=r)
        cmap = plt.get_cmap('RdBu_r')
        sm = ScalarMappable(norm, cmap)

        hover_polygons = []
        for state in range(len(Q)):
            qs = Q[state]
            y, x = divmod(state, w)
            if self.dones[x, y]:
                # terminal states get no wedges
                continue
            y += 0.5
            x += 0.5

            # cell extents and upper-left corner
            dx = 1
            dy = 1

            ulx = (x - 1) * dx
            uly = (y - 1) * dy

            # cell outline
            rect = Rectangle(
                xy=(ulx, uly),
                width=dx,
                height=dy,
                edgecolor='k',
                facecolor='none'
            )
            ax.add_artist(rect)

            # cell midpoint
            mx = (x - 1) * dx + dx / 2.
            my = (y - 1) * dy + dy / 2.

            # corners and midpoint of the cell
            ul = ulx, uly
            ur = ulx + dx, uly
            ll = ulx, uly + dy
            lr = ulx + dx, uly + dy
            m = mx, my

            # one triangular wedge per action, in DELTAS order
            up = [ul, m, ur]
            left = [ul, m, ll]
            right = [ur, m, lr]
            down = [ll, m, lr]
            action_polys = [left, right, up, down]
            for a, poly in enumerate(action_polys):
                poly_patch = Polygon(
                    poly,
                    edgecolor='k',
                    linewidth=0.1,
                    facecolor=sm.to_rgba(qs[a])
                )
                if debug_info:
                    mmx = np.mean([x for x, y in poly])
                    mmy = np.mean([y for x, y in poly])
                    sss = '\n'.join(map(str, self.P[state][a]))
                    ax.text(mmx, mmy, f'{NAMES[a][0]}:{sss}',
                            fontdict=dict(size=5), horizontalalignment='center',
                            verticalalignment='center')

                hover_polygons.append((poly_patch, f'{NAMES[a]}:{qs[a]:4.2f}'))
                ax.add_artist(poly_patch)
        return hover_polygons
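

if __name__ == '__main__':
    # Minimal usage sketch: roll out a short random episode on a freshly
    # generated instance and print the transitions.
    rng = np.random.RandomState(0)
    env = Gridworld.get_random_instance(rng, size=5)
    state = env.reset()
    for _ in range(20):
        action = rng.randint(env.get_n_actions())
        state, reward, done = env.step(action)
        print(f'state={state} reward={reward} done={done}')
        if done:
            break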