Files
exercise-00/pig_lite/environment/gridworld.py
2025-10-07 18:22:35 +02:00

361 lines
12 KiB
Python

import json
import numpy as np
from pig_lite.environment.base import Environment
# Per-action (di, dj) displacement on the grid, indexed by action id.
DELTAS = [(-1, 0), (1, 0), (0, -1), (0, 1)]
# Human-readable action names, aligned index-for-index with DELTAS.
NAMES = ['left', 'right', 'up', 'down']
def sample(rng, elements):
    """Sample one element of `elements`, weighted by each element's first entry.

    Parameters
    ----------
    rng : np.random.RandomState
        Source of randomness.
    elements : sequence of tuples
        Each tuple's first entry is its selection probability; the
        remaining entries are payload. Probabilities are expected to
        sum to 1 (up to floating-point rounding).

    Returns
    -------
    tuple
        The chosen element (the full tuple, probability included).
    """
    csp = np.cumsum([elm[0] for elm in elements])
    u = rng.uniform(0, 1)
    # First cumulative bin strictly above u. The clamp guards against u
    # landing at/past csp[-1] due to rounding (or degenerate weights):
    # the previous np.argmax(csp > u) silently returned index 0 in that
    # case, i.e. the WRONG end of the distribution.
    idx = min(int(np.searchsorted(csp, u, side='right')), len(elements) - 1)
    return elements[idx]
class Gridworld(Environment):
    """A small stochastic gridworld: slippery moves, per-cell rewards, absorbing terminals."""

    def __init__(self, seed, dones, rewards, starts):
        """Store the grid definition and precompute the transition table.

        Parameters
        ----------
        seed : int
            Seed for the internal RandomState.
        dones : 2D bool array
            Marks terminal (absorbing) cells.
        rewards : 2D array
            Immediate reward collected on entering each cell.
        starts : array
            Rows of (probability, state) sampled by reset().
        """
        self.seed = seed
        self.rng = np.random.RandomState(seed)
        self.dones = dones
        self.rewards = rewards
        self.starts = starts
        # Builds self.P[state][action] -> [(p, s', reward, done), ...]
        self.__compute_P()
def reset(self):
    """Begin a new episode: sample a start state, store it, and return it."""
    # Start entries are (probability, state) pairs; keep only the state.
    _, start_state = sample(self.rng, self.starts)
    self.state = start_state
    return self.state
def step(self, action):
    """Advance one step by sampling a transition from self.P.

    Returns the tuple (next_state, reward, done).
    """
    transition = sample(self.rng, self.P[self.state][action])
    _, next_state, reward, done = transition
    self.state = next_state
    return self.state, reward, done
def get_n_actions(self):
    """Return the size of the action space.

    Four compass moves, matching len(DELTAS) and len(NAMES).
    """
    return 4
def get_n_states(self):
    """Return the total number of states (one per grid cell)."""
    grid_shape = self.dones.shape
    return np.prod(grid_shape)
def get_gamma(self):
    """Return the discount factor applied to future rewards."""
    return 0.99
def __compute_P(self):
    """Build the transition table self.P.

    self.P[state][action] is a list of (prob, next_state, reward, done)
    tuples. Moves are slippery: probability 0.8 for the chosen
    direction and 0.1 for each of the two orthogonal directions.
    Hitting the boundary leaves the agent in its current cell.
    Terminal cells are absorbing with zero reward.
    """
    w, h = self.dones.shape

    def inbounds(ci, cj):
        # True when (ci, cj) is a valid grid coordinate.
        return 0 <= ci < w and 0 <= cj < h

    self.P = dict()
    for i in range(w):
        for j in range(h):
            state = j * w + i
            self.P[state] = dict()
            if self.dones[i, j]:
                # Absorbing terminal: every action self-loops at zero reward.
                for action in range(self.get_n_actions()):
                    self.P[state][action] = [(1, state, 0, True)]
                continue
            for action, (dx, dy) in enumerate(DELTAS):
                # Intended move plus the two orthogonal "slips".
                move_options = [
                    (0.8, dx, dy),
                    (0.1, dy, dx),
                    (0.1, -dy, -dx),
                ]
                transitions = []
                for prob, di, dj in move_options:
                    ni, nj = i + di, j + dj
                    if inbounds(ni, nj):
                        # Regular move into the neighboring cell.
                        target = nj * w + ni
                        transitions.append(
                            (prob, target, self.rewards[ni, nj], self.dones[ni, nj]))
                    else:
                        # Bounced off the wall: stay put, collect this cell's values.
                        transitions.append(
                            (prob, state, self.rewards[i, j], self.dones[i, j]))
                self.P[state][action] = transitions
def to_json(self):
    """Serialize this gridworld's defining data to a JSON string."""
    payload = {
        'type': self.__class__.__name__,
        'seed': self.seed,
        'dones': self.dones.tolist(),
        'rewards': self.rewards.tolist(),
        'starts': self.starts.tolist(),
    }
    return json.dumps(payload)
@staticmethod
def from_json(jsonstring):
    """Parse a JSON string (as produced by to_json) into a Gridworld.

    Delegates to `from_dict` so the field handling — including the
    int64 cast for `starts` — lives in exactly one place instead of
    being duplicated here.
    """
    return Gridworld.from_dict(json.loads(jsonstring))
@staticmethod
def from_dict(data):
    """Construct a Gridworld from a dictionary of raw fields."""
    seed = data['seed']
    dones = np.array(data['dones'])
    rewards = np.array(data['rewards'])
    # starts holds (probability, state) rows; states must be integers.
    starts = np.array(data['starts'], dtype=np.int64)
    return Gridworld(seed, dones, rewards, starts)
@staticmethod
def get_random_instance(rng, size):
    """Generate a random `size` x `size` Gridworld instance using `rng`."""
    dones, rewards, starts = Gridworld.__generate(rng, size)
    # Seed drawn AFTER __generate, preserving the rng consumption order.
    instance_seed = rng.randint(0, 2 ** 31)
    return Gridworld(instance_seed, dones, rewards, starts)
@staticmethod
def __generate(rng, size):
    """Randomly lay out traps, rewards, and the start state for a new grid.

    Returns (dones, rewards, starts): the boolean terminal mask, the
    per-cell immediate rewards, and the weighted start-state list.
    """
    dones = np.full((size, size), False, dtype=bool)
    # Default step cost of -1 everywhere.
    rewards = np.zeros((size, size), dtype=np.int8) - 1
    # Trap candidates are interior cells only (the border stays safe).
    coordinates = [(i, j)
                   for i in range(1, size - 1)
                   for j in range(1, size - 1)]
    indices = np.arange(len(coordinates))
    # Roughly 10% of the interior becomes traps (at least one).
    chosen = rng.choice(indices, max(1, len(indices) // 10), replace=False)
    for c in chosen:
        trap_i, trap_j = coordinates[c]
        dones[trap_i, trap_j] = True
        rewards[trap_i, trap_j] = -100
    # Single deterministic start: state 0 with probability 1.
    starts = np.array([[1, 0]])
    # Goal in the bottom-right corner.
    dones[-1, -1] = True
    rewards[-1, -1] = 100
    return dones, rewards, starts
@staticmethod
def get_minimum_problem_size():
return 3
def visualize(self, outcome, coords=None, grid=None):
    """Plot the environment plus whatever policy/V/Q `outcome` carries.

    `outcome` is expected to expose .policy, .V and .Q attributes (each
    possibly None); when `outcome` itself is None, all three stay None.
    `coords` and `grid` toggle coordinate labels and grid lines.
    """
    # The originals default to None, so plain assignment is equivalent
    # to the conditional copies.
    policy, V, Q = None, None, None
    if outcome is not None:
        policy = outcome.policy
        V = outcome.V
        Q = outcome.Q
    self._plot_environment_and_policy(policy, V, Q,
                                      show_coordinates=coords, show_grid=grid)
def _plot_environment_and_policy(self, policy=None, V=None, Q=None, show_coordinates=False,
                                 show_grid=False, plot_filename=None, debug_info=False):
    """Plot the environment on a 2x2 figure: terminals+policy, rewards, V, Q.

    Parameters
    ----------
    policy : optional array of per-state action weights, or None to skip.
    V : optional state-value vector (reshapeable to the grid), or None.
    Q : optional per-state action-value array, or None.
    show_coordinates, show_grid : axis decoration toggles.
    plot_filename : when given, save the figure there instead of showing it.
    debug_info : overlay state indices/coordinates on the V panel.
    """
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
    dones_ax = axes[0, 0]
    rewards_ax = axes[0, 1]
    V_ax = axes[1, 0]
    Q_ax = axes[1, 1]
    dones_ax.set_title('Terminal States and Policy')
    dones_ax.imshow(self.dones.T, cmap='gray_r', vmin=0, vmax=4)
    rewards_ax.set_title('Immediate Rewards')
    rewards_ax.imshow(self.rewards.T, cmap='RdBu_r', vmin=-25, vmax=25)
    # BUG FIX: policy defaults to None; len(None) raised TypeError before.
    if policy is not None and len(policy) > 0:
        self._plot_policy(dones_ax, policy)
    w = self.dones.shape[0]
    V_ax.set_title('State Value Function $V(s)$')
    # BUG FIX: V defaults to None; None.reshape raised AttributeError before.
    if V is not None:
        V_array = V.reshape(self.dones.shape).T
        r = max(1e-13, np.max(np.abs(V_array)))
        # NOTE(review): V_array.T here undoes the .T above, unlike the other
        # panels which show the transposed grid — confirm intended orientation.
        V_ax.imshow(V_array.T, cmap='RdBu_r', vmin=-r, vmax=r)
        if debug_info:
            for s in range(len(V)):
                sy, sx = divmod(s, w)
                V_ax.text(sx, sy, f'{sx},{sy}:{s}',
                          color='w', fontdict=dict(size=6),
                          horizontalalignment='center', verticalalignment='center')
    Q_ax.set_title('State Action Value Function $Q(s, a)$')
    # BUG FIX: Q defaults to None; only draw the Q panel when it is provided.
    if Q is not None:
        poly_patches_q_values = self._draw_Q(Q_ax, Q, debug_info)

        def format_coord(x, y):
            """Status-bar formatter: append the hovered triangle's Q-value."""
            for poly_patch, q_value in poly_patches_q_values:
                if poly_patch.contains_point(Q_ax.transData.transform((x, y))):
                    return f'x:{x:4.2f} y:{y:4.2f} {q_value}'
            return f'x:{x:4.2f} y:{y:4.2f}'
        Q_ax.format_coord = format_coord
    for ax in [dones_ax, rewards_ax, V_ax, Q_ax]:
        ax.tick_params(
            top=show_coordinates,
            left=show_coordinates,
            labelleft=show_coordinates,
            labeltop=show_coordinates,
            right=False,
            bottom=False,
            labelbottom=False
        )
        # Major ticks: one per cell center.
        s = self.dones.shape[0]
        ax.set_xticks(np.arange(0, s, 1))
        ax.set_yticks(np.arange(0, s, 1))
        # Minor ticks: cell boundaries (used for grid lines below).
        ax.set_xticks(np.arange(-.5, s, 1), minor=True)
        ax.set_yticks(np.arange(-.5, s, 1), minor=True)
    if show_grid:
        for color, ax in zip(['m', 'w', 'w'], [dones_ax, rewards_ax, V_ax]):
            # Gridlines based on minor ticks.
            ax.grid(which='minor', color=color, linestyle='-', linewidth=1)
    plt.tight_layout()
    if plot_filename is not None:
        plt.savefig(plot_filename)
        plt.close(fig)
    else:
        plt.show()
def _plot_policy(self, ax, policy):
    """Draw the policy as one quiver per action; arrow length scales with policy[s, a]."""
    w, h = self.dones.shape
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
    # One list of (dx, dy) arrow vectors per action.
    arrows = [[] for _ in range(self.get_n_actions())]
    for s in range(self.get_n_states()):
        y, x = divmod(s, w)
        if self.dones[x, y]:
            # Terminal cells get zero-length arrows for every action.
            for a in range(self.get_n_actions()):
                arrows[a].append((0., 0.))
        else:
            for a, (wdx, wdy) in enumerate(DELTAS):
                # The y component is flipped before scaling by the
                # policy weight (presumably to match image coordinates
                # that grow downwards — TODO confirm).
                arrows[a].append(np.array([wdx, -wdy]) * policy[s, a])
    # Render each action's arrow field as its own quiver.
    for arrow_set in arrows:
        vecs = np.array(arrow_set)
        ax.quiver(grid_x, grid_y, vecs[:, 0], vecs[:, 1], units='xy', scale=1.5)
def _draw_Q(self, ax, Q, debug_info):
    """Draw Q(s, a) as four colored triangles per non-terminal cell.

    Each cell is split into left/right/up/down triangles meeting at the
    cell center, colored by the corresponding Q-value. Returns a list of
    (polygon_patch, label) pairs the caller uses for hover lookup.
    """
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize
    from matplotlib.patches import Rectangle, Polygon
    # Blank background so the triangle patches sit on a neutral canvas.
    ax.imshow(np.zeros(self.dones.shape), cmap='gray_r')
    w, h = self.dones.shape
    # Symmetric color range around zero; epsilon keeps Normalize well-defined.
    span = max(1e-13, np.max(np.abs(Q)))
    mappable = ScalarMappable(Normalize(vmin=-span, vmax=span), plt.get_cmap('RdBu_r'))
    hover_polygons = []
    for state in range(len(Q)):
        q_values = Q[state]
        y, x = divmod(state, w)
        if self.dones[x, y]:
            # Terminal cells stay blank.
            continue
        # Cell spans [x-0.5, x+0.5] x [y-0.5, y+0.5] in data coordinates.
        left = x - 0.5
        top = y - 0.5
        ax.add_artist(Rectangle(xy=(left, top), width=1, height=1,
                                edgecolor='k', facecolor='none'))
        center = (float(x), float(y))
        ul = (left, top)
        ur = (left + 1, top)
        ll = (left, top + 1)
        lr = (left + 1, top + 1)
        # One triangle per action, ordered to match DELTAS/NAMES indexing.
        action_polys = [
            [ul, center, ll],   # left
            [ur, center, lr],   # right
            [ul, center, ur],   # up
            [ll, center, lr],   # down
        ]
        for a, triangle in enumerate(action_polys):
            patch = Polygon(
                triangle,
                edgecolor='k',
                linewidth=0.1,
                facecolor=mappable.to_rgba(q_values[a])
            )
            if debug_info:
                # Centroid of the triangle for the debug annotation.
                text_x = np.mean([px for px, py in triangle])
                text_y = np.mean([py for px, py in triangle])
                transitions = '\n'.join(map(str, self.P[state][a]))
                ax.text(text_x, text_y, f'{NAMES[a][0]}:{transitions}',
                        fontdict=dict(size=5), horizontalalignment='center',
                        verticalalignment='center')
            hover_polygons.append((patch, f'{NAMES[a]}:{q_values[a]:4.2f}'))
            ax.add_artist(patch)
    return hover_polygons