first commit
pig_lite/environment/gridworld.py (new file, 360 lines)
@@ -0,0 +1,360 @@
import json

import numpy as np

from pig_lite.environment.base import Environment

DELTAS = [
    (-1, 0),
    (+1, 0),
    (0, -1),
    (0, +1)
]
NAMES = [
    'left',
    'right',
    'up',
    'down'
]

def sample(rng, elements):
    """ Samples an element of `elements` at random, weighted by each element's first entry (its probability). """
    csp = np.cumsum([elm[0] for elm in elements])
    idx = np.argmax(csp > rng.uniform(0, 1))
    return elements[idx]

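# Illustrative note (added for clarity, not part of the original API): `sample` draws one tuple
# from a list of (probability, ...) tuples via the cumulative distribution, e.g.
#   rng = np.random.RandomState(0)
#   sample(rng, [(0.8, 'stay'), (0.2, 'move')])  # -> (0.8, 'stay') with probability 0.8
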
class Gridworld(Environment):
    def __init__(self, seed, dones, rewards, starts):
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        self.dones = dones
        self.rewards = rewards
        self.starts = starts

        self.__compute_P()

    def reset(self):
        """ Resets this gridworld to a randomly sampled start state. """
        _, self.state = sample(self.rng, self.starts)
        return self.state

    def step(self, action):
        """ Performs the action; the next state, reward and done flag are sampled from the transition table self.P. """
        _, self.state, reward, done = sample(self.rng, self.P[self.state][action])
        return self.state, reward, done

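    # Illustrative rollout sketch (hypothetical usage, not part of the original file):
    #   env = Gridworld.get_random_instance(np.random.RandomState(0), size=6)
    #   state = env.reset()
    #   done = False
    #   while not done:
    #       action = 3  # 0: left, 1: right, 2: up, 3: down (see DELTAS/NAMES above)
    #       state, reward, done = env.step(action)
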
    def get_n_actions(self):
        """ Returns the number of actions available in this gridworld. """
        return 4

    def get_n_states(self):
        """ Returns the number of states available in this gridworld. """
        return np.prod(self.dones.shape)

    def get_gamma(self):
        """ Returns discount factor gamma for this gridworld. """
        return 0.99

    def __compute_P(self):
        """ Computes and stores the transitions for this gridworld. """
        w, h = self.dones.shape

        def inbounds(i, j):
            """ Checks whether coordinates i and j are within the grid. """
            return i >= 0 and j >= 0 and i < w and j < h

        self.P = dict()
        for i in range(0, w):
            for j in range(0, h):
                state = j * w + i
                self.P[state] = dict()

                if self.dones[i, j]:
                    for action in range(self.get_n_actions()):
                        # make it absorbing
                        self.P[state][action] = [(1, state, 0, True)]
                else:
                    for action, (dx, dy) in enumerate(DELTAS):
                        ortho_dir_probs = [
                            (0.8, dx, dy),
                            (0.1, dy, dx),
                            (0.1, -dy, -dx)
                        ]
                        transitions = []
                        for p, di, dj in ortho_dir_probs:
                            ni = i + di
                            nj = j + dj
                            if inbounds(ni, nj):
                                # we move
                                sprime = nj * w + ni
                                done = self.dones[ni, nj]
                                reward = self.rewards[ni, nj]
                                transitions.append((p, sprime, reward, done))
                            else:
                                # stay in the same state, because we bounced off the boundary
                                sprime = state
                                done = self.dones[i, j]
                                reward = self.rewards[i, j]
                                transitions.append((p, sprime, reward, done))

                        self.P[state][action] = transitions

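    # Note (descriptive, added for clarity): every entry of the table built above has the form
    #   self.P[state][action] == [(prob, next_state, reward, done), ...]
    # with 0.8 probability on the intended direction and 0.1 on each orthogonal slip;
    # terminal states are made absorbing with a single (1, state, 0, True) transition.
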
    def to_json(self):
        """ Serialises this gridworld to a JSON string. """
        return json.dumps(dict(
            type=self.__class__.__name__,
            seed=self.seed,
            dones=self.dones.tolist(),
            rewards=self.rewards.tolist(),
            starts=self.starts.tolist()
        ))

    @staticmethod
    def from_json(jsonstring):
        """ Parses the given JSON string and creates a gridworld from it. """
        data = json.loads(jsonstring)
        return Gridworld(
            data['seed'],
            np.array(data['dones']),
            np.array(data['rewards']),
            np.array(data['starts'], dtype=np.int64),
        )

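    # Illustrative round-trip sketch (hypothetical usage): serialisation and deserialisation
    # are inverses of each other, e.g.
    #   g  = Gridworld.get_random_instance(np.random.RandomState(0), 5)
    #   g2 = Gridworld.from_json(g.to_json())  # reconstructs an equivalent gridworld
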
    @staticmethod
    def from_dict(data):
        """ Creates a gridworld from the given data dictionary. """
        return Gridworld(
            data['seed'],
            np.array(data['dones']),
            np.array(data['rewards']),
            np.array(data['starts'], dtype=np.int64),
        )

    @staticmethod
    def get_random_instance(rng, size):
        """ Given random generator and problem size, generates Gridworld instance. """
        dones, rewards, starts = Gridworld.__generate(rng, size)
        return Gridworld(rng.randint(0, 2 ** 31), dones, rewards, starts)

    @staticmethod
    def __generate(rng, size):
        """ Helper function that generates dones, rewards and starts for Gridworld instance creation. """
        dones = np.full((size, size), False, dtype=bool)
        rewards = np.zeros((size, size), dtype=np.int8) - 1

        coordinates = []
        for i in range(1, size - 1):
            for j in range(1, size - 1):
                coordinates.append((i, j))
        indices = np.arange(len(coordinates))

        chosen = rng.choice(indices, max(1, len(indices) // 10), replace=False)

        for c in chosen:
            x, y = coordinates[c]
            dones[x, y] = True
            rewards[x, y] = -100

        starts = np.array([[1, 0]])
        dones[-1, -1] = True
        rewards[-1, -1] = 100

        return dones, rewards, starts

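    # Note (descriptive, added for clarity): the generated layout uses a step reward of -1
    # everywhere, turns roughly 10% of the interior cells into -100 terminal traps, places a
    # +100 terminal goal in the bottom-right corner, and starts every episode in state 0
    # (the single entry in `starts`, with probability 1).
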
    @staticmethod
    def get_minimum_problem_size():
        """ Returns the smallest supported problem (grid) size. """
        return 3

    def visualize(self, outcome, coords=None, grid=None):
        """ Visualisation function for the gridworld; plots environment, policy, V and Q. """
        policy = None
        Q = None
        V = None
        if outcome is not None:
            if outcome.policy is not None:
                policy = outcome.policy

            if outcome.V is not None:
                V = outcome.V

            if outcome.Q is not None:
                Q = outcome.Q

        self._plot_environment_and_policy(policy, V, Q, show_coordinates=coords, show_grid=grid)

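    # Note (assumed interface): `outcome` is expected to expose optional `policy`, `V` and `Q`
    # attributes (e.g. the result object of a planning or RL algorithm); whatever is present
    # is forwarded to the plotting helper below.
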
    def _plot_environment_and_policy(self, policy=None, V=None, Q=None, show_coordinates=False,
                                     show_grid=False, plot_filename=None, debug_info=False):
        """ Function that plots the environment and, if given, the policy, V and Q. """
        import matplotlib.pyplot as plt
        fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
        dones_ax = axes[0, 0]
        rewards_ax = axes[0, 1]
        V_ax = axes[1, 0]
        Q_ax = axes[1, 1]

        dones_ax.set_title('Terminal States and Policy')
        dones_ax.imshow(self.dones.T, cmap='gray_r', vmin=0, vmax=4)

        rewards_ax.set_title('Immediate Rewards')
        rewards_ax.imshow(self.rewards.T, cmap='RdBu_r', vmin=-25, vmax=25)

        if policy is not None and len(policy) > 0:
            self._plot_policy(dones_ax, policy)

        w, h = self.dones.shape
        V_ax.set_title('State Value Function $V(s)$')
        if V is not None:
            V_array = V.reshape(self.dones.shape).T
            r = max(1e-13, np.max(np.abs(V_array)))
            V_ax.imshow(V_array.T, cmap='RdBu_r', vmin=-r, vmax=r)

            if debug_info:
                for s in range(len(V)):
                    sy, sx = divmod(s, w)
                    V_ax.text(sx, sy, f'{sx},{sy}:{s}',
                              color='w', fontdict=dict(size=6),
                              horizontalalignment='center', verticalalignment='center')

        Q_ax.set_title('State Action Value Function $Q(s, a)$')
        if Q is not None:
            poly_patches_q_values = self._draw_Q(Q_ax, Q, debug_info)

            def format_coord(x, y):
                for poly_patch, q_value in poly_patches_q_values:
                    if poly_patch.contains_point(Q_ax.transData.transform((x, y))):
                        return f'x:{x:4.2f} y:{y:4.2f} {q_value}'
                return f'x:{x:4.2f} y:{y:4.2f}'

            Q_ax.format_coord = format_coord

        for ax in [dones_ax, rewards_ax, V_ax, Q_ax]:
            ax.tick_params(
                top=show_coordinates,
                left=show_coordinates,
                labelleft=show_coordinates,
                labeltop=show_coordinates,
                right=False,
                bottom=False,
                labelbottom=False
            )

            # Major ticks
            s = self.dones.shape[0]
            ax.set_xticks(np.arange(0, s, 1))
            ax.set_yticks(np.arange(0, s, 1))

            # Minor ticks
            ax.set_xticks(np.arange(-.5, s, 1), minor=True)
            ax.set_yticks(np.arange(-.5, s, 1), minor=True)

        if show_grid:
            for color, ax in zip(['m', 'w', 'w'], [dones_ax, rewards_ax, V_ax]):
                # Gridlines based on minor ticks
                ax.grid(which='minor', color=color, linestyle='-', linewidth=1)

        plt.tight_layout()
        if plot_filename is not None:
            plt.savefig(plot_filename)
            plt.close(fig)
        else:
            plt.show()

    def _plot_policy(self, ax, policy):
        """ Function that plots the policy as arrows. """
        w, h = self.dones.shape
        xs = np.arange(w)
        ys = np.arange(h)
        xx, yy = np.meshgrid(xs, ys)

        # we need a quiver for each of the four actions
        quivers = list()
        for a in range(self.get_n_actions()):
            quivers.append(list())

        # walk over all states and collect one arrow component per action
        for s in range(self.get_n_states()):
            y, x = divmod(s, w)
            if self.dones[x, y]:
                for a in range(self.get_n_actions()):
                    quivers[a].append((0., 0.))
            else:
                for a in range(self.get_n_actions()):
                    wdx, wdy = DELTAS[a]
                    corrected = np.array([wdx, -wdy])
                    quivers[a].append(corrected * policy[s, a])

        # plot each quiver
        for quiver in quivers:
            q = np.array(quiver)
            ax.quiver(xx, yy, q[:, 0], q[:, 1], units='xy', scale=1.5)

    def _draw_Q(self, ax, Q, debug_info):
        """ Function that draws Q. """
        pattern = np.zeros(self.dones.shape)
        ax.imshow(pattern, cmap='gray_r')
        import matplotlib.pyplot as plt
        from matplotlib.cm import ScalarMappable
        from matplotlib.colors import Normalize
        from matplotlib.patches import Rectangle, Polygon
        w, h = self.dones.shape

        r = max(1e-13, np.max(np.abs(Q)))
        norm = Normalize(vmin=-r, vmax=r)
        cmap = plt.get_cmap('RdBu_r')
        sm = ScalarMappable(norm, cmap)

        hover_polygons = []
        for state in range(len(Q)):
            qs = Q[state]
            # print('qs', qs)
            y, x = divmod(state, w)
            if self.dones[x, y]:
                continue
            y += 0.5
            x += 0.5

            dx = 1
            dy = 1

            ulx = (x - 1) * dx
            uly = (y - 1) * dy

            rect = Rectangle(
                xy=(ulx, uly),
                width=dx,
                height=dy,
                edgecolor='k',
                facecolor='none'
            )
            ax.add_artist(rect)

            mx = (x - 1) * dx + dx / 2.
            my = (y - 1) * dy + dy / 2.

            ul = ulx, uly
            ur = ulx + dx, uly
            ll = ulx, uly + dy
            lr = ulx + dx, uly + dy
            m = mx, my

            up = [ul, m, ur]
            left = [ul, m, ll]
            right = [ur, m, lr]
            down = [ll, m, lr]
            action_polys = [left, right, up, down]
            for a, poly in enumerate(action_polys):
                poly_patch = Polygon(
                    poly,
                    edgecolor='k',
                    linewidth=0.1,
                    facecolor=sm.to_rgba(qs[a])
                )
                if debug_info:
                    mmx = np.mean([x for x, y in poly])
                    mmy = np.mean([y for x, y in poly])
                    sss = '\n'.join(map(str, self.P[state][a]))
                    ax.text(mmx, mmy, f'{NAMES[a][0]}:{sss}',
                            fontdict=dict(size=5), horizontalalignment='center',
                            verticalalignment='center')

                hover_polygons.append((poly_patch, f'{NAMES[a]}:{qs[a]:4.2f}'))
                ax.add_artist(poly_patch)
        return hover_polygons

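# Minimal usage sketch (illustrative only, assumes `pig_lite` and matplotlib are installed):
#   rng = np.random.RandomState(42)
#   env = Gridworld.get_random_instance(rng, size=6)
#   state = env.reset()
#   state, reward, done = env.step(1)  # take one step to the 'right'
#   print(env.to_json())               # serialise the instance for later reuse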