{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "6a7ea65e2f19d811f1a48145be4a29dd",
|
|
"grade": false,
|
|
"grade_id": "cell-84617f606b66d110",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"# Artificial Intelligence UE\n",
|
|
"## Exercises 3 - Game Playing\n",
|
|
"\n",
|
|
"In this series of exercises, you will look at game playing, more precisely at the Minimax algorithm, Alpha-Beta pruning, and Q-Learning.\n",
|
|
"\n",
|
|
"The algorithms were explained in the lecture (VO), and the exercise (UE) gave you some additional information. Please refer to the lecture slides (VO) for the pseudocode and to the exercise slides (UE) for additional hints.\n",
|
|
"\n",
|
|
"<div class=\"alert alert-warning\">\n",
|
|
"\n",
|
|
"<p><strong>Practical hints:</strong></p>\n",
|
|
"<ul>\n",
|
|
"<li>Replace the placeholders <code># YOUR CODE HERE</code> and <code>raise NotImplementedError()</code> with your code.</li>\n",
|
|
"<li>Do not rename any of the already existing variables (this might lead to tests failing / not working).</li>\n",
|
|
"<li>If you need a number smaller than all others, you may use <code>float('-Inf')</code>.</li>\n",
|
|
"<li>If you need a number larger than all others, you may use <code>float('Inf')</code>.</li>\n",
|
|
"</ul>\n",
|
|
"</div>\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "8f6aa7b610ced6df3c25d748b625c832",
|
|
"grade": false,
|
|
"grade_id": "cell-9f190755dfdee1fc",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import stuff\n",
|
|
"from pig_lite.game.base import Game, Node\n",
|
|
"from pig_lite.environment.base import Environment, Outcome\n",
|
|
"from pig_lite.instance_generation.problem_factory import ProblemFactory\n",
|
|
"\n",
|
|
"import math\n",
|
|
"import random\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "bde8a4d40df318658aa6813eced457d2",
|
|
"grade": false,
|
|
"grade_id": "cell-2f0104814be2be96",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"## A Short Intro to the World of TicTacToe"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# you can generate a new game board of TicTacToe as follows; note that the problem_size here describes the depth of the board in a game tree\n",
|
|
"rng = np.random.RandomState(seed=123)\n",
|
|
"game = ProblemFactory().generate_problem('tictactoe', problem_size=3, rng=rng)\n",
|
|
"\n",
|
|
"# or, you can load an existing one from a .json file like so:\n",
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
|
|
"\n",
|
|
"# if we use Minimax / Alphabeta pruning to derive a move sequence, we can visualise it as follows:\n",
|
|
"move_sequence = [(-1, (0, 2)), (1, (1, 1)), (-1, (1, 2)), (1, (0, 0)), (-1, (2, 0))] # arbitrary move sequence for demonstration purposes\n",
|
|
"game.visualize(move_sequence, show_possible=False, tree_name='Arbitrary Tree 1')\n",
|
|
"# if we set show_possible to True, the function shows all possible moves from a state in the path\n",
|
|
"game.visualize(move_sequence, show_possible=True, tree_name='Arbitrary Tree 2')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "185560b1b20aa70e61fad1611159680a",
|
|
"grade": false,
|
|
"grade_id": "cell-8ca2076d79b78bdb",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"## Minimax\n",
|
|
"\n",
|
|
"Now, let us implement the Minimax algorithm!\n",
|
|
"\n",
|
|
"**NOTE**: If multiple paths lead to the same outcome for these algorithms, choose the first expanded / leftmost path."
|
|
]
|
|
},
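{
"cell_type": "markdown",
"metadata": {},
"source": [
"The note above describes a tie-breaking rule. The following minimal sketch illustrates it with minimax on a small, hypothetical game tree written as nested Python lists. An inner list is a decision point, and a number is the utility of a terminal state from MAX's point of view. The sketch does not use the `pig_lite` API; the tree format and the name `minimax_toy` are assumptions made only for this illustration. The current best is only replaced by a *strictly* better value (`>` for MAX, `<` for MIN), so the leftmost of several equally good children is kept.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def minimax_toy(tree, maximizing=True, path=()):\n",
"    # leaves are plain numbers: return their utility and the path of child indices\n",
"    if not isinstance(tree, list):\n",
"        return tree, path\n",
"    best_value = -math.inf if maximizing else math.inf\n",
"    best_path = None\n",
"    for i, child in enumerate(tree):\n",
"        value, child_path = minimax_toy(child, not maximizing, path + (i,))\n",
"        # strict comparison: a later child only wins if it is strictly better\n",
"        if (maximizing and value > best_value) or (not maximizing and value < best_value):\n",
"            best_value, best_path = value, child_path\n",
"    return best_value, best_path\n",
"\n",
"# MAX picks one of three branches, then MIN picks a leaf\n",
"tree = [[0, 1], [-1, 1], [0, 0]]\n",
"print(minimax_toy(tree))  # (0, (0, 0)): branch 2 is also worth 0, but branch 0 comes first\n",
"```\n",
"\n",
"The returned path of child indices plays the same role as the terminal node returned by `Minimax.minimax` in the exercise: it identifies the sequence of moves that leads to the chosen outcome."
]
},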
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "3b4382686aacd28991f7d0ee2a8d0ad2",
|
|
"grade": false,
|
|
"grade_id": "cell-c5c2a2df427bd111",
|
|
"locked": false,
|
|
"schema_version": 3,
|
|
"solution": true,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Minimax():\n",
"    def play(self, game: Game):\n",
"        \"\"\" Starts game playing, and returns found terminal node according to minimax. \"\"\"\n",
"        start = game.get_start_node()\n",
"        # 'game.get_max_player()' asks the game how it identifies the MAX player internally\n",
"        value, terminal_node = self.minimax(game, start, game.get_max_player())\n",
"        return terminal_node\n",
"\n",
"    def minimax(self, game, node, max_player):\n",
"        \"\"\" Performs minimax algorithm (recursively). \"\"\"\n",
"        # here we check if the current node 'node' is a terminal node\n",
"        terminal, winner = game.outcome(node)\n",
"\n",
"        # if it is a terminal node, determine who won, and return\n",
"        # a) the utility value (-1, 0, 1)\n",
"        # and b) the terminal node itself, to be able to determine the path of moves/plies that led to this terminal node\n",
"        if terminal:\n",
"            if winner is None:\n",
"                return 0, node\n",
"            elif winner == max_player:\n",
"                return 1, node\n",
"            else:\n",
"                return -1, node\n",
"\n",
"        # TODO: implement the minimax algorithm recursively here\n",
"        if node.player == max_player:\n",
"            # you have to remember the best value *and* the best node for the MAX player (TODO: initialise appropriately)\n",
"            best_value, best_node = None, None\n",
"            # YOUR CODE HERE\n",
"            raise NotImplementedError()\n",
"            return best_value, best_node\n",
"        else:\n",
"            # you have to remember the best value *and* the best node for the MIN player (TODO: initialise appropriately)\n",
"            best_value, best_node = None, None\n",
"            # YOUR CODE HERE\n",
"            raise NotImplementedError()\n",
"            return best_value, best_node\n",
"\n",
"game = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
"outcome = Minimax().play(game)\n",
"minimax_nodes = game.get_number_of_expanded_nodes()\n",
"\n",
"if outcome is not None:\n",
"    terminated, winner = game.outcome(outcome)\n",
"    print('Game terminated: {}, winner is: {} (1: Max, -1: Min); nr of expanded nodes: {}'.format(terminated, winner, minimax_nodes))\n",
"    outcome.pretty_print()\n",
"    game.visualize(game.get_move_sequence(outcome), False, 'Minimax Tree')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Minimax Checks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "2e346b2a42e7c2d3c69aac9c2a351ade",
|
|
"grade": true,
|
|
"grade_id": "cell-69d1bdefa9930127",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# check found path here \n",
|
|
"assert(outcome is not None), 'Minimax returned None, something is wrong with the implementation'\n",
|
|
"# this check tests whether you really chose the left-most path\n",
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
|
|
"outcome = Minimax().play(game)\n",
|
|
"terminated, winner = game.outcome(outcome)\n",
|
|
"assert(terminated == True), 'Minimax did not return a terminal node, so likely wrong node was returned'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) != 'b68a7dd18dfc7694da1b25e0290903b4d2fedba66c8d090c0d11923f5e2c8d22'), 'Minimax did not find correct move sequence, likely due to not taking the first expanded optimal path for MAX player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) != 'fc8ecbd22f45983e14e43ef4cdd120252913898e521f2712cd9d07b8913cade0'), 'Minimax did not find correct move sequence, likely due to not taking the first expanded optimal path for MIN player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) != 'c263685af9af51da439db5aa6714bb39845ad19f2d9f7c4279e9b9c086559e46'), 'Minimax did not find correct move sequence, likely due to not taking the first expanded optimal path for MAX+MIN player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == 'bfb38fb43f84847e4b001d09dcb22cfe573f41efac370e118f3b6630fd0f259a'), 'Minimax did not find correct move sequence to terminal state for the provided problem instance'\n",
|
|
"print('Minimax found correct terminal node and move sequence leading to it for provided problem instance')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "19a815a4bd0fe5cade1bb6080253418a",
|
|
"grade": true,
|
|
"grade_id": "cell-c3fb982c8f3e0663",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe1.json')\n",
|
|
"outcome = Minimax().play(game)\n",
|
|
"assert(outcome is not None), 'Minimax returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert(np.all(outcome.state.T == np.array([[1, -1, -1], [-1, 1, 1], [1, -1, -1]]))), 'Minimax did not find correct terminal node for private instance 1, maybe wrong node was returned or minimax has an error'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == '4e49e7612322365aceaccd3a80dca6c4836c0d7196593c785d9e3da7960fe2ea'), 'Minimax did not find correct move sequence to terminal state of private problem instance 1'\n",
|
|
"print('Minimax found correct terminal node and move sequence leading to it for private problem instance 1')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "47365356cc5162c4ab15e9ba3c0e473d",
|
|
"grade": true,
|
|
"grade_id": "cell-d955df46dc4a948b",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe2.json')\n",
|
|
"outcome = Minimax().play(game)\n",
|
|
"assert(outcome is not None), 'Minimax returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert not (game.get_move_sequence_hash(outcome) == 'bba1199f0c4f21f58aac340189660bc99a8b50ded6a227f7f0b2039572b627c4' or\n",
|
|
" np.all(outcome.state.T == np.array([[-1, 0, 1], [0, 0, 1], [0, -1, 1]]))), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MAX player'\n",
|
|
"assert not (game.get_move_sequence_hash(outcome) == '45940174f9875ce563c45439c8d86b4d1eed0ebaa6947a4587f93bb28bf44476' or\n",
|
|
" np.all(outcome.state.T == np.array([[1, 0, 1], [1, 0, -1], [1, -1, -1]]))), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MIN player'\n",
|
|
"assert not (game.get_move_sequence_hash(outcome) == '1961082c8650fc498c973b099e4b0c12edba87a04eaf278d5a4e6389364c694e' or\n",
|
|
" np.all(outcome.state.T == np.array([[0, -1, 1], [0, 1, -1], [1, -1, 1]]))), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MAX+MIN player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == '79e76d8da84c876fcd145e25d88601e2997d9557004c0b969b4402ddb9185f78' and\n",
|
|
" np.all(outcome.state.T == np.array([[1, -1, 1], [-1, 1, 0], [1, -1, 0]]))), 'Minimax did not find correct move sequence to terminal state or terminal node for the provided problem instance'\n",
|
|
"print('Minimax found correct terminal node and move sequence leading to it for private problem instance 2')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "c0a98cc7ae08e34d29b193cf19453896",
|
|
"grade": true,
|
|
"grade_id": "cell-7d888625c8c74fc2",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe3.json')\n",
|
|
"outcome = Minimax().play(game)\n",
|
|
"assert(outcome is not None), 'Minimax returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert not (game.get_move_sequence_hash(outcome) == '1b585860ffb92cd318887950c257f7f12d9e0555ae849ffd2b2fad8caf66e8e6' or\n",
|
|
" np.all(outcome.state.T == np.array([[-1, 0, 1],[-1, 1, 1],[1, -1, 0]]))), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MAX player'\n",
|
|
"assert not (game.get_move_sequence_hash(outcome) == 'b4b1d966698346de24c319905e52877e0afefe1a2b9e60f9b4f314fbdc25d9d5' or\n",
|
|
" np.all(outcome.state.T == np.array([[1, 0, 0], [1, 1, -1], [1, -1, -1]]))), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MIN player'\n",
|
|
"assert not (game.get_move_sequence_hash(outcome) == 'fdb559d7f1b35a9f0a01e2b857c2c592cba1c691292aef491db2a1b89164ed1b' or\n",
|
|
" np.all(outcome.state.T == np.array([[0, 0, -1], [1, 1, 1], [1, -1, -1]]))), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MAX+MIN player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == '20576ffcccc1fa1d6081497f283229761012385ffd5c34f86547801cba3ffead' and\n",
|
|
" np.all(outcome.state.T == np.array([[-1, -1, 1], [1, 1, 0], [1, -1, 0]]))), 'Minimax did not find correct move sequence to terminal state or terminal node for the provided problem instance'\n",
|
|
"print('Minimax found correct terminal node and move sequence leading to it for private problem instance 3')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "b38fd276ca771307aa4bb62c5cb95b74",
|
|
"grade": true,
|
|
"grade_id": "cell-be5f3456e9709954",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe4.json')\n",
|
|
"outcome = Minimax().play(game)\n",
|
|
"assert(outcome is not None), 'Minimax returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert(np.all(outcome.state.T == np.array([[1, -1, 1], [1, -1, 1], [-1, 1, -1]]))), 'Minimax did not find correct terminal node for private instance 4, maybe wrong node was returned or minimax has an error'\n",
|
|
"assert (game.get_move_sequence_hash(outcome) != '1dfb09398dc0726119c331fe7134cf22085cb8b7e9fb32ff8358a5675c2e40a0'), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MAX player'\n",
|
|
"assert (game.get_move_sequence_hash(outcome) != '26c69f8c65d0f1cdac09aafcc00d4ba0ce03fd9336d1977928100481f7570ee7'), 'Minimax did not find correct move sequence or terminal node, likely due to not taking the first expanded optimal path for MIN or MAX+MIN player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == '49e861565ea87b3b2e555c75dc1835b8c030ac0bf55da27c3f4efd2f3585e6ca'), 'Minimax did not find correct move sequence to terminal state or terminal node for the provided problem instance'\n",
|
|
"print('Minimax found correct terminal node and move sequence leading to it for private problem instance 4')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "78a753186db6517a0564a3a4195ecc73",
|
|
"grade": false,
|
|
"grade_id": "cell-592b6c2a2d4b5c05",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
},
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"## Alpha-Beta Pruning\n",
|
|
"\n",
|
|
"Here, let us implement Alpha-Beta pruning. \n",
|
|
"\n",
|
|
"**NOTE**: If multiple paths lead to the same outcome for these algorithms, choose the first expanded / leftmost path."
|
|
]
|
|
},
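{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following minimal sketch shows where the savings of alpha-beta pruning come from. It runs on a small, hypothetical two-ply game tree written as nested Python lists (an inner list is a decision point, a number is a leaf utility from MAX's point of view). The name `alphabeta_toy` and the leaf counter `stats` are assumptions made for this example only. Unlike the exercise, it returns only the value, not the terminal node. A subtree is cut off as soon as `alpha >= beta`, i.e. once the remaining siblings can no longer change the decision higher up in the tree.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def alphabeta_toy(tree, alpha=-math.inf, beta=math.inf, maximizing=True, stats=None):\n",
"    # stats counts how many leaves were actually evaluated\n",
"    if stats is None:\n",
"        stats = {'leaves': 0}\n",
"    if not isinstance(tree, list):\n",
"        stats['leaves'] += 1\n",
"        return tree, stats\n",
"    best = -math.inf if maximizing else math.inf\n",
"    for child in tree:\n",
"        value, _ = alphabeta_toy(child, alpha, beta, not maximizing, stats)\n",
"        if maximizing:\n",
"            best = max(best, value)\n",
"            alpha = max(alpha, best)\n",
"        else:\n",
"            best = min(best, value)\n",
"            beta = min(beta, best)\n",
"        # prune: the remaining siblings cannot influence the decision above\n",
"        if alpha >= beta:\n",
"            break\n",
"    return best, stats\n",
"\n",
"tree = [[3, 12, 8], [2, 4, 6], [14, 5, 2]]\n",
"print(alphabeta_toy(tree))  # (3, {'leaves': 7}); plain minimax would evaluate all 9 leaves\n",
"```\n",
"\n",
"Here, the second MIN node is abandoned after its first leaf (2). MAX already has a guaranteed value of 3 from the first branch, so nothing in the second branch can improve on it."
]
},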
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": true,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "1968a996abb1b95f3e682fbbdbb1aaf9",
|
|
"grade": false,
|
|
"grade_id": "cell-d2df9b0e3d90cf00",
|
|
"locked": false,
|
|
"schema_version": 3,
|
|
"solution": true,
|
|
"task": false
|
|
},
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"class AlphaBeta(object):\n",
"    def play(self, game: Game):\n",
"        \"\"\" Starts game playing, and returns found terminal node according to alpha-beta pruning. \"\"\"\n",
"        start = game.get_start_node()\n",
"        alpha = float('-Inf')\n",
"        beta = float('Inf')\n",
"        value, terminal_node = self.alphabeta(game, start, alpha, beta, game.get_max_player())\n",
"        return terminal_node\n",
"\n",
"    def alphabeta(self, game, node, alpha, beta, max_player):\n",
"        \"\"\" Performs alpha-beta pruning algorithm (recursively). \"\"\"\n",
"        # here we check if the current node 'node' is a terminal node\n",
"        terminal, winner = game.outcome(node)\n",
"        # if it is a terminal node, determine who won, and return the utility value (-1, 0, 1) and the terminal node itself\n",
"        if terminal:\n",
"            if winner is None:\n",
"                return 0, node\n",
"            elif winner == max_player:\n",
"                return 1, node\n",
"            else:\n",
"                return -1, node\n",
"\n",
"        # TODO: implement the alpha-beta pruning algorithm recursively here\n",
"        # the structure should be almost the same as for minimax\n",
"        # YOUR CODE HERE\n",
"        raise NotImplementedError()\n",
"\n",
"game = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
"outcome = AlphaBeta().play(game)\n",
"alphabeta_nodes = game.get_number_of_expanded_nodes()\n",
"\n",
"if outcome is not None:\n",
"    terminated, winner = game.outcome(outcome)\n",
"    print('Game terminated: {}, winner is: {} (1: Max, -1: Min); nr of expanded nodes: {}'.format(terminated, winner, alphabeta_nodes))\n",
"    outcome.pretty_print()\n",
"    game.visualize(game.get_move_sequence(outcome), False, 'Alpha-Beta Tree')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "458736acf00f1d309bbdb6dd2d60ae36",
|
|
"grade": false,
|
|
"grade_id": "cell-c8aa7d319fc1f6cb",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"### Alpha-Beta Pruning Checks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "dda53e8b045f8f7f22eb209071c6681f",
|
|
"grade": true,
|
|
"grade_id": "cell-2db5203948b220c7",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# check found path here \n",
|
|
"assert(outcome is not None), 'Alpha-beta pruning returned None, something is wrong with the implementation'\n",
|
|
"\n",
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
|
|
"outcome = AlphaBeta().play(game)\n",
|
|
"terminated, winner = game.outcome(outcome)\n",
|
|
"assert(terminated == True), 'Alpha-beta pruning did not return a terminal node, so likely wrong node was returned'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) != 'f224894cb163f1fd5530927610c3c2667f6c5141c2c026f6fd9e7f6ace1e4bb8'), 'Alpha-beta pruning did not find correct move sequence, likely due to not taking the first expanded optimal path for MAX(+MIN) player'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) != 'fc8ecbd22f45983e14e43ef4cdd120252913898e521f2712cd9d07b8913cade0'), 'Alpha-beta pruning did not find correct move sequence, likely due to not taking the first expanded optimal path for MIN player'\n",
|
|
"assert(np.all(outcome.state.T == np.array([[1, 1, 1], [-1, 1, 0], [-1, -1, 0]]))), 'Alpha-beta did not find correct terminal node, maybe wrong best node was stored'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == 'bfb38fb43f84847e4b001d09dcb22cfe573f41efac370e118f3b6630fd0f259a'), 'Alpha-beta did not find correct move sequence to terminal state'\n",
|
|
"print('Alpha-beta pruning found correct terminal node and move sequence leading to it for provided problem instance')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "8be67fa34c9c3decc3ad39324ebb01a2",
|
|
"grade": true,
|
|
"grade_id": "cell-5e7587f893fa567c",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
},
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe1.json')\n",
|
|
"outcome = AlphaBeta().play(game)\n",
|
|
"assert(outcome is not None), 'Alpha-beta pruning returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert(np.all(outcome.state.T == np.array([[1, -1, -1], [-1, 1, 1], [1, -1, -1]]))), 'Alpha-beta pruning did not find the correct terminal node for private instance 1; maybe the wrong node was returned or alpha-beta pruning has an error'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == '4e49e7612322365aceaccd3a80dca6c4836c0d7196593c785d9e3da7960fe2ea'), 'Alpha-beta did not find correct move sequence to terminal state for private problem instance 1'\n",
|
|
"print('Alpha-beta pruning found correct terminal node and move sequence leading to it for private problem instance 1')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "38c9cd2ff8767c68b1ca36ff59203471",
|
|
"grade": true,
|
|
"grade_id": "cell-127304dc5089e06c",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe_draw_5x5.json')\n",
|
|
"outcome = AlphaBeta().play(game)\n",
|
|
"assert(outcome is not None), 'Alpha-beta pruning returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) != '7b8a98f36acbd985721d447bd59a9dacc987e31d5d8d2ccafaabe6891cc3de39'), 'Alpha-beta pruning did not find correct move sequence, likely due to not taking the first expanded optimal path for MAX(+MIN) player' \n",
|
|
"assert(game.get_move_sequence_hash(outcome) != 'f728a654955355bc35f3f4ca86a0b409a0febf600899828db999bbe4118a2f65'), 'Alpha-beta pruning did not find correct move sequence, likely due to not taking the first expanded optimal path for MIN player' \n",
|
|
"assert(game.get_move_sequence_hash(outcome) == 'f472a4657161dd8c11256060a52fca0cd9ef075cb0bdd6f809319162cd26bba2'), 'Alpha-beta did not find correct move sequence to terminal state for private problem instance 2'\n",
|
|
"print('Alpha-beta pruning found correct terminal node and move sequence leading to it for private problem instance 2')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "903231436c1055baefa17ae5ac4a15c1",
|
|
"grade": true,
|
|
"grade_id": "cell-20ca8862d1e2a714",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = ProblemFactory().create_problem_from_json(json_path='boards/tictactoe_win_5x5.json')\n",
|
|
"outcome = AlphaBeta().play(game)\n",
|
|
"assert(outcome is not None), 'Alpha-beta pruning returned None. This can have various reasons, e.g., you might not have updated the best value / node (or updated it incorrectly), or the update check was wrong (val > best_value for MAX, val < best_value for MIN)'\n",
|
|
"assert(game.get_move_sequence_hash(outcome) == 'da4f6dd7b0fa7fbd53f6191111e7fdbf1082d6c8a41704b981d36bbb016f039d' or\n",
|
|
" game.get_move_sequence_hash(outcome) == '491ace3d81d89d79836ef6919b3d77c9c9a081cdf51de6941403a2ae07ca0dd3'), 'Alpha-beta did not find correct move sequence to terminal state for private problem instance 3'\n",
|
|
"print('Alpha-beta pruning found correct terminal node and move sequence leading to it for private problem instance 3')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "b769d093d7e11e54009a364a4900089f",
|
|
"grade": true,
|
|
"grade_id": "cell-a5135148f16e856c",
|
|
"locked": true,
|
|
"points": 1,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# check expanded nodes here (whether we actually save something compared to minimax)\n",
|
|
"assert(alphabeta_nodes <= minimax_nodes), 'Alpha-beta pruning took more node expansions than minimax - something must be off here...'\n",
|
|
"game_ab = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
|
|
"game_mm = ProblemFactory().create_problem_from_json(json_path='boards/game.json')\n",
|
|
"out_ab = AlphaBeta().play(game_ab)\n",
|
|
"out_mm = Minimax().play(game_mm)\n",
|
|
"assert(not(game_ab.get_number_of_expanded_nodes() == game_mm.get_number_of_expanded_nodes() and game_ab.get_move_sequence_hash(out_ab) == game_mm.get_move_sequence_hash(out_mm))), 'Seems like alpha-beta pruning behaves as minimax, nothing was pruned for the provided test instance'\n",
|
|
"assert(not(game_ab.get_number_of_expanded_nodes() == 354 or game_ab.get_number_of_expanded_nodes() == 360 or game_ab.get_number_of_expanded_nodes() == 640 or\n",
|
|
" game_ab.get_number_of_expanded_nodes() == 839 or game_ab.get_number_of_expanded_nodes() == 3079)), 'Did not expand the correct number of nodes for alpha-beta pruning; likely, pruning was used only for MIN or only for MAX, or < / > was used instead of <= / >='\n",
|
|
"assert(game_ab.get_number_of_expanded_nodes() == 266), 'Did not expand correct number of nodes for alpha-beta pruning, likely due to wrong pruning condition (should be alpha >= beta for MAX, beta <= alpha for MIN)'\n",
|
|
"print('Pruning seems okay')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "7c02e622c196e587436f0154ece2cb7e",
|
|
"grade": false,
|
|
"grade_id": "cell-568544c7786aea22",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"## A Short Intro to the Gridworld\n",
|
|
"\n",
|
|
"<div class=\"alert alert-info\">\n",
|
|
"For Q-Learning, we require a different problem type: a stochastic gridworld.\n",
|
|
"</div>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# you can generate a new gridworld as follows\n",
|
|
"rng = np.random.RandomState(seed=123)\n",
|
|
"env = ProblemFactory().generate_problem('gridworld', problem_size=3, rng=rng)\n",
|
|
"\n",
|
|
"# or, you can load an existing one from a .json file like so:\n",
|
|
"env_json = ProblemFactory().create_problem_from_json(json_path='boards/environment.json')\n",
|
|
"\n",
|
|
"# if we use Q-Learning to learn the Q-function, we can visualise its results as follows:\n",
|
|
"rand_policy = np.array([[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [1, 0, 0, 0]])\n",
|
|
"outcome = Outcome(n_episodes=1, policy=rand_policy, V=np.random.randn(env.get_n_states()), # arbitrary outcome for demonstration purposes\n",
|
|
" Q=np.random.randn(env.get_n_states(), env.get_n_actions())) \n",
|
|
"env.visualize(outcome)"
|
|
]
|
|
},
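{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both `rand_policy` above and the policy computed at the end of `QLearning.train` below are one-hot matrices. Each has one row per state and a single 1 in the column of the chosen action. The following sketch uses made-up Q-values for 3 states and 4 actions. It shows how such a matrix can be built from a Q-table with NumPy fancy indexing and read back into action indices. Which movement direction each action index stands for is defined by the environment and is not assumed here.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# made-up Q-values: 3 states (rows) x 4 actions (columns)\n",
"Q = np.array([[0.1, 0.5, 0.5, 0.0],\n",
"              [0.0, 0.0, 0.0, 0.0],\n",
"              [-1.0, 0.2, 0.1, 0.3]])\n",
"\n",
"greedy_actions = np.argmax(Q, axis=1)  # on ties, the lowest action index wins\n",
"policy = np.zeros(Q.shape, dtype=np.int64)\n",
"policy[np.arange(len(policy)), greedy_actions] = 1  # exactly one 1 per row\n",
"\n",
"print(greedy_actions)             # [1 0 3]\n",
"print(np.argmax(policy, axis=1))  # [1 0 3], the same actions recovered from the one-hot rows\n",
"```\n",
"\n",
"Note that `np.argmax` returns the first index among tied maxima. States whose Q-values are all equal (for example, states that were never updated) therefore map to action 0."
]
},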
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "1135727bbef232f40df175d667718c44",
|
|
"grade": false,
|
|
"grade_id": "cell-233204deb4eb05b0",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"## Q-Learning\n",
|
|
" \n",
|
|
"<strong>Remember: to interact with the (Q-Learning) environment, you need</strong>\n",
|
|
"<ul>\n",
|
|
"<li><code>state = env.reset()</code> to reset the environment at the start of an episode</li>\n",
|
|
"<li><code>state, reward, done = env.step(action)</code> to tell the environment that your agent decided to take <code>action</code>. The environment then tells you which state you actually ended up in (<code>state</code>), what the immediate reward was (<code>reward</code>), and whether or not the episode ended (<code>done</code>).</li>\n",
|
|
"<li>If the tests take significantly longer than about 20 seconds, you probably have an error in the update step of your state-action value function.</li>\n",
|
|
" \n",
|
|
"</ul>"
|
|
]
|
|
},
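{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder, tabular Q-learning moves the estimate for the action just taken towards a one-step target, $Q(s,a) \\leftarrow Q(s,a) + \\alpha \\left[r + \\gamma \\max_{a'} Q(s',a') - Q(s,a)\\right]$. The $\\gamma \\max_{a'} Q(s',a')$ term is dropped if $s'$ is terminal. The following sketch combines this update with epsilon-greedy action selection. It runs on a hypothetical four-state corridor that only mimics the `reset()` / `step(action)` interface listed above. The class `ChainEnv`, its rewards, the function names and all hyperparameter values are assumptions for illustration. They are not the gridworld or the settings used by the tests.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"class ChainEnv:\n",
"    # hypothetical corridor with states 0..3: action 1 moves right, action 0 moves left;\n",
"    # reaching state 3 ends the episode with reward 1, every other step gives reward 0\n",
"    n_states, n_actions, goal = 4, 2, 3\n",
"\n",
"    def reset(self):\n",
"        self.s = 0\n",
"        return self.s\n",
"\n",
"    def step(self, action):\n",
"        self.s = min(self.s + 1, self.goal) if action == 1 else max(self.s - 1, 0)\n",
"        done = self.s == self.goal\n",
"        return self.s, (1.0 if done else 0.0), done\n",
"\n",
"def eps_greedy_sketch(rng, qs, epsilon):\n",
"    if rng.uniform(0, 1) < epsilon:\n",
"        return rng.randint(len(qs))  # explore: uniformly random action\n",
"    return int(np.argmax(qs))        # exploit: first action with the highest estimate\n",
"\n",
"def q_learning_sketch(env, n_episodes=500, alpha=0.2, gamma=0.9, epsilon=0.3, seed=0):\n",
"    rng = np.random.RandomState(seed)\n",
"    Q = np.zeros((env.n_states, env.n_actions))\n",
"    for _ in range(n_episodes):\n",
"        state, done = env.reset(), False\n",
"        while not done:\n",
"            action = eps_greedy_sketch(rng, Q[state], epsilon)\n",
"            next_state, reward, done = env.step(action)\n",
"            # no future value after a transition into a terminal state\n",
"            target = reward if done else reward + gamma * np.max(Q[next_state])\n",
"            Q[state, action] += alpha * (target - Q[state, action])\n",
"            state = next_state\n",
"    return Q\n",
"\n",
"Q = q_learning_sketch(ChainEnv())\n",
"print(np.round(Q, 2))\n",
"```\n",
"\n",
"The corridor only rewards reaching state 3. After training, the greedy action (the argmax of each row) is therefore expected to be action 1 (move right) in states 0 to 2. The row of the terminal state 3 stays at zero, because an episode ends as soon as that state is reached, so it is never updated."
]
},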
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "f37386df9aa6102c0348b1ffe1f5822c",
|
|
"grade": false,
|
|
"grade_id": "cell-1bbd4598c8d0500f",
|
|
"locked": false,
|
|
"schema_version": 3,
|
|
"solution": true,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def eps_greedy(rng, qs, epsilon):\n",
|
|
"    \"\"\" Makes an epsilon-greedy decision between exploration (trying out a new option)\n",
|
|
"    and exploitation (choosing the best option so far). \"\"\"\n",
|
|
" if rng.uniform(0, 1) < epsilon:\n",
|
|
"        # with probability p == epsilon, an action is chosen uniformly at random (use the passed rng)\n",
|
|
" # YOUR CODE HERE\n",
|
|
" raise NotImplementedError()\n",
|
|
" else:\n",
|
|
" # with probability p == 1 - epsilon, the action having the currently largest q-value estimate is chosen\n",
|
|
" # YOUR CODE HERE\n",
|
|
" raise NotImplementedError()\n",
|
|
" \n",
|
|
" # this is to avoid errors if there is no implementation yet - you can remove it if you want\n",
|
|
" return -1\n",
|
|
"\n",
|
|
"class QLearning():\n",
|
|
" def train(self, env: Environment, n_episodes=10000, alpha=0.2):\n",
|
|
" \"\"\" Performs Q-Learning for given environment. \"\"\"\n",
|
|
" # leave untouched for the sake of reproducibility (tests below rely on these fixed values)\n",
|
|
" self.rng = np.random.RandomState(1234)\n",
|
|
" self.epsilon = 0.3\n",
|
|
" self.gamma = env.get_gamma()\n",
|
|
"\n",
|
|
" # initialize the Q-'table'\n",
|
|
" Q = np.zeros((env.get_n_states(), env.get_n_actions()))\n",
|
|
"\n",
|
|
" for episode in range(1, n_episodes + 1):\n",
|
|
"            # implement the Q-learning update here: generate an episode by interacting with the environment via env.reset() and env.step(action)\n",
|
|
" # YOUR CODE HERE\n",
|
|
" raise NotImplementedError()\n",
|
|
"\n",
|
|
" # compute a deterministic policy from the Q value function\n",
|
|
" policy = np.zeros((env.get_n_states(), env.get_n_actions()), dtype=np.int64)\n",
|
|
" policy[np.arange(len(policy)), np.argmax(Q, axis=1)] = 1\n",
|
|
" # finally, compute the state value function V here\n",
|
|
"        # it can be computed easily from Q by taking, in each state, the action that leads to the maximum future reward\n",
|
|
" V = None\n",
|
|
" # YOUR CODE HERE\n",
|
|
" raise NotImplementedError()\n",
|
|
"\n",
|
|
" return Outcome(n_episodes, policy, V=V, Q=Q)\n",
|
|
"\n",
|
|
"environment = ProblemFactory().create_problem_from_json(json_path='boards/environment.json')\n",
|
|
"qlearn = QLearning()\n",
|
|
"outcome = qlearn.train(environment)\n",
|
|
"\n",
|
|
"if outcome is not None:\n",
|
|
" environment.visualize(outcome)"
|
|
]
|
|
},
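{
"cell_type": "markdown",
"metadata": {},
"source": [
"The deterministic policy at the end of `train` is built from `Q` with NumPy fancy indexing: for every state (row), the column of the action with the largest Q-value is set to 1. The sketch below only illustrates this construction on a small, hypothetical Q-table (the values are made up and do not come from any environment in this notebook). It also shows that `np.argmax` returns the *first* maximal index when several actions tie, so ties are always broken towards the lower action index.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# hypothetical Q-table: 3 states x 4 actions\n",
"Q = np.array([[0.0, 1.5, 0.2, 1.5],  # tie between actions 1 and 3\n",
"              [-1.0, -2.0, -0.5, -3.0],\n",
"              [0.0, 0.0, 0.0, 0.0]])  # all actions tie\n",
"\n",
"# same construction as in QLearning.train: one one-hot row per state\n",
"policy = np.zeros(Q.shape, dtype=np.int64)\n",
"policy[np.arange(len(policy)), np.argmax(Q, axis=1)] = 1\n",
"\n",
"print(np.argmax(Q, axis=1))  # [1 2 0]\n",
"print(policy)\n",
"```"
]
},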
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "markdown",
|
|
"checksum": "9b840bfe6c132101d4d814e0dc258945",
|
|
"grade": false,
|
|
"grade_id": "cell-ac953a8ab55a2747",
|
|
"locked": true,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"source": [
|
|
"### Q-Learning Checks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "37f495bfb5d0eb45b136f5fd547f0306",
|
|
"grade": true,
|
|
"grade_id": "cell-724a9b7bd5e49b48",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# here we check whether default variables were modified\n",
|
|
"assert(qlearn.epsilon == 0.3), 'Epsilon was changed for Q-Learning'\n",
|
|
"assert(qlearn.gamma == environment.get_gamma()), 'Gamma was changed for Q-Learning'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "08335d2c6e8f0d1aa5877a81f1c27227",
|
|
"grade": true,
|
|
"grade_id": "cell-321985c0232e9f93",
|
|
"locked": true,
|
|
"points": 1,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# here we check a few test instances for their resulting policy (encoded as a hash value)\n",
|
|
"# the first two hashes were obtained using the Q-learning update with alpha, the second two without alpha\n",
|
|
"assert(environment.get_policy_hash(outcome) != 'f35de3a82ae52dc2c67cacc7942b85785a35ba188a20454ccbdf77627b6695c8'), 'Make sure to use both the current and the next state in your Q function update'\n",
|
|
"assert(environment.get_policy_hash(outcome) == 'a138e26bebdd61e38fc045f03a37ee77bc3343dc36cb3f1cf415707a9b5e08ad' or\n",
|
|
" environment.get_policy_hash(outcome) == '6c8ec07e309222af5c0839f8a6fb58597135356f451dc61c624a1ebea86735fe' or\n",
|
|
" environment.get_policy_hash(outcome) == '7e730a7445950f7a9c8125c96d0d45066f995833f08e4ea585b38ce93a6313e0' or\n",
|
|
"       environment.get_policy_hash(outcome) == 'af0dc1e412f6985a939691b3f589f89a3fb5d696a3cde67d33438024d14223e9'), 'Your algorithm did not find the same optimal policy as ours, so something is probably off'\n",
|
|
"\n",
|
|
"env = ProblemFactory().create_problem_from_json(json_path='boards/gridworld1.json')\n",
|
|
"out = QLearning().train(env)\n",
|
|
"assert(env.get_policy_hash(out) == 'ca2e78d44dacc2645876957f1e10393b77c933a94c562f845025d31b3dc062c4' or\n",
|
|
" env.get_policy_hash(out) == 'c14762db6d5f332ae89cb4209cf460fa9eecffdee197383fa043b2c6ef00b76c' or\n",
|
|
"       env.get_policy_hash(out) == '0ceffffae674a851c603e971597718b70d2dce2ed6e38ffc925fbc5a81d8aebf'), 'Q-Learning did not find the same policy as ours; this could be due to an incorrect update function (or other aspects, see the remaining comments regarding tests)'\n",
|
|
"print('The Q-learning implementation found the same policy as ours for a new test instance!')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "bb39b0cc7f75b224d0a9f9ac75a4ba36",
|
|
"grade": true,
|
|
"grade_id": "cell-096b0e848bc26030",
|
|
"locked": true,
|
|
"points": 2,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# here we check whether eps_greedy was implemented (correctly)\n",
|
|
"assert(eps_greedy(qlearn.rng, outcome.Q[environment.reset()], qlearn.epsilon) != -1), 'eps_greedy does not appear to be implemented (correctly) yet'\n",
|
|
"\n",
|
|
"\n",
|
|
"rng = np.random.RandomState(12)\n",
|
|
"print('Test of eps_greedy sampling, random sampling with epsilon = 0.5 and random state = 12:')\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 2), 'first step should be a random action; eps_greedy does not work as expected, make sure to use the passed rng to select randomly'\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 3), 'second step should be the argmax of Q; eps_greedy does not work as expected'\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 3), 'third step should be a random action; eps_greedy does not work as expected, make sure to use the passed rng to select randomly'\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 2), 'fourth step should be a random action; eps_greedy does not work as expected, make sure to use the passed rng to select randomly'\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 1), 'fifth step should be a random action; eps_greedy does not work as expected, make sure to use the passed rng to select randomly'\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 2), 'sixth step should be a random action; eps_greedy does not work as expected, make sure to use the passed rng to select randomly'\n",
|
|
"assert(eps_greedy(rng, [59.23768459, 59.86590956, 59.15508216, 60.56009166], 0.5) == 3), 'seventh step should be the argmax of Q; eps_greedy does not work as expected'\n",
|
|
"print('Seems to work fine!')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "9f46d2c303ee7af69eb42353c81ad2cf",
|
|
"grade": true,
|
|
"grade_id": "cell-3c0c3a8168ade782",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# here we check whether eps_greedy was called in your QLearning implementation\n",
|
|
"import ast\n",
|
|
"import inspect\n",
|
|
"call_names = [c.func.id for c in ast.walk(ast.parse(inspect.getsource(QLearning.train).lstrip()))\n",
|
|
" if isinstance(c, ast.Call) and not isinstance(c.func, ast.Attribute)] \n",
|
|
"assert('eps_greedy' in call_names), 'eps_greedy function was not used during training to sample the next action'"
|
|
]
|
|
},
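{
"cell_type": "markdown",
"metadata": {},
"source": [
"The check above parses the source code of `QLearning.train` (the `.lstrip()` removes the indentation of the `def` line, which is enough for the method source to parse on its own) and collects the names of all *plain* function calls, i.e. calls whose callee is not an attribute access. A call such as `eps_greedy(...)` is therefore collected, whereas `env.step(...)` or `self.rng.uniform(...)` are not, so `eps_greedy` has to be called directly by that name inside `train`. The sketch below applies the same expression to a small source string; the function `toy` is hypothetical and exists only for illustration.\n",
"\n",
"```python\n",
"import ast\n",
"\n",
"# hypothetical source code standing in for QLearning.train\n",
"source = '''\n",
"def toy(env, rng, q):\n",
"    state = env.reset()               # attribute call -> ignored\n",
"    action = eps_greedy(rng, q, 0.3)  # plain call -> collected\n",
"    return max(q)                     # a builtin is a plain call as well\n",
"'''\n",
"\n",
"# same filter as in the check above: keep calls whose callee is a bare name\n",
"call_names = [c.func.id for c in ast.walk(ast.parse(source))\n",
"              if isinstance(c, ast.Call) and not isinstance(c.func, ast.Attribute)]\n",
"print(sorted(call_names))  # ['eps_greedy', 'max']\n",
"```"
]
},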
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "674a5cb9e18b0339310ed3fe125672a5",
|
|
"grade": true,
|
|
"grade_id": "cell-7e0012c0d1af9a1c",
|
|
"locked": true,
|
|
"points": 0.5,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# here we check whether V was computed (correctly)\n",
|
|
"assert(outcome.V is not None), 'V was not computed (correctly)'\n",
|
|
"assert(np.allclose(outcome.V, [60.56009166229519, 62.0862976417879, 59.69673685231367, 57.2998914308495, 57.16247167797613, 63.72475795338036, 58.681549989922765, 36.160081000906104, -0.7346191191046638, -7.957286438234731, \n",
|
|
" 62.429613717496785, 64.27021333475943, 61.8489623068024, 56.68931070289786, 64.7086260941021, 68.96740198280624, 64.72192080370876, 45.42333812841379, -7.70660099982711, -6.3540660505033255, \n",
|
|
" 63.78874659772162, 66.28022856039745, 64.26691771606903, 0.0, 69.00293688866753, 72.11660041265043, 59.602399420640545, 38.030496836206275, 14.299440372492398, -1.636243819995724, \n",
|
|
" 65.60578892848311, 68.06355849549294, 70.76043474466302, 73.52559642373782, 74.829413699787, 74.7244884571062, 0.0, 0.0, -14.372316463840185, 30.32359198821047, \n",
|
|
" 67.45442443720407, 69.98858830119075, 73.64652924428802, 76.45223284842137, 78.6024777398158, 80.19432277767568, 81.97087573052612, 81.87190792575387, 0.0, 47.6790177531604, \n",
|
|
" 65.25444937464933, 67.06361233729368, 70.09899795480531, 74.8512985464834, 80.99319093176972, 82.70891686090964, 84.1063933238421, 87.10476457339571, 89.55648769695577, 92.02018276648413, \n",
|
|
" 63.98216126752015, 64.78481366263227, 63.37994143735609, 0.0, 83.06026652598523, 84.14552727908097, 86.80838516636535, 89.91640697830174, 92.07232887476961, 94.00906431697199, \n",
|
|
" 61.66891093157388, 63.80128024721642, 59.39816181823294, 45.550034604437634, 84.31066865743784, 86.72480155341584, 89.00776722138832, 91.72422981845112, 94.69912447457426, 97.00665363583768, \n",
|
|
" 17.818059604934152, 53.73008677587408, 38.607997303469254, 0.0, 86.4088656035926, 88.63000118406423, 90.8302457235381, 93.48189534830975, 96.7290118179945, 99.76915525815585, \n",
|
|
" 2.7869296811856383, 1.9107575717432566, 28.481804693589865, 79.17530638335536, 88.17997020406862, 89.31582272957053, 92.06318449933205, 95.62679684531966, 98.34285013960046, 0.0]) or \n",
|
|
" np.allclose(outcome.V, [65.23372476711732, 54.00862916103103, 55.56427187982933, 57.13562816144377, 57.13562816144377, 60.32611790780917, 58.72285672873108, 60.32611790780917, 60.32611790780917, 73.74916255379564, \n",
|
|
" 66.9027522900175, 68.58863867678535, 54.00862916103103, 58.72285672873108, 63.58138751944615, 58.72285672873108, 70.29155421897511, 65.23372476711732, 63.58138751944615, 86.41306958139799, \n",
|
|
" 66.9027522900175, 70.29155421897511, 54.00862916103103, 0.0, 65.23372476711732, 77.27697434322585, 72.01167092825769, 66.9027522900175, 63.58138751944615, 84.54893888558401, \n",
|
|
" 70.29155421897511, 68.58863867678535, 77.27697434322585, 75.50420459979358, 80.87641500176089, 79.06765085174328, 0.0, 0.0, -100.0, 86.41306958139799, \n",
|
|
" 72.01167092825769, 73.74916255379564, 75.50420459979358, 77.27697434322585, 82.70344949672817, 80.87641500176089, 82.70344949672817, 84.54893888558401, 0.0, 86.41306958139799, \n",
|
|
" 73.74916255379564, 73.74916255379564, 73.74916255379564, 75.50420459979358, 84.54893888558401, 86.41306958139799, 80.87641500176089, 90.19800998, 92.119202, 94.0598, \n",
|
|
" 68.58863867678535, 70.29155421897511, 75.50420459979358, 0.0, 80.87641500176089, 86.41306958139799, 82.70344949672817, 88.2960298802, 94.0598, 96.02, \n",
|
|
" 66.9027522900175, 58.72285672873108, 60.32611790780917, 61.94557364425169, 84.54893888558401, 86.41306958139799, 92.119202, 90.19800998, 94.0598, 98.0, \n",
|
|
" 60.32611790780917, 58.72285672873108, 60.32611790780917, 0.0, 82.70344949672817, 88.2960298802, 92.119202, 96.02, 94.0598, 100.0, \n",
|
|
" 57.13562816144377, 58.72285672873108, 58.72285672873108, 58.72285672873108, 92.119202, 82.70344949672817, 96.02, 98.0, 100.0, 0.0])), 'Your computed V does not match our solution'\n",
|
|
"print('V seems to be computed correctly!')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"deletable": false,
|
|
"editable": false,
|
|
"nbgrader": {
|
|
"cell_type": "code",
|
|
"checksum": "a66094a9e41dab1ccd79ec4ca4ded973",
|
|
"grade": true,
|
|
"grade_id": "cell-aecdeb8567515f69",
|
|
"locked": true,
|
|
"points": 2,
|
|
"schema_version": 3,
|
|
"solution": false,
|
|
"task": false
|
|
},
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# here we check whether the update function is correctly implemented\n",
|
|
"rng = np.random.RandomState(12)\n",
|
|
"env = ProblemFactory().generate_problem('gridworld', problem_size=3, rng=rng)\n",
|
|
"q_approx = QLearning().train(env, 1).Q\n",
|
|
"assert(np.allclose(q_approx, np.array([[-0.36, -0.2, -0.3996, -0.2], [-0.2, 0., 0., 0.], [0., 0., 0., 0.,], [-0.2, -20., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.,], [0., 0., 0., 0.], [0., 0., 0., 0.]])) or\n",
|
|
"       np.allclose(q_approx, np.array([[-1., -1., -1.99, -1.], [-1., 0., 0., 0.], [0., 0., 0., 0.,], [-1., -100., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.,], [0., 0., 0., 0.], [0., 0., 0., 0.]]))), 'Q-values are wrong for this instance; the update function is probably implemented incorrectly'\n",
|
|
"env = ProblemFactory().generate_problem('gridworld', problem_size=3, rng=rng)\n",
|
|
"q_approx = QLearning().train(env, 2).Q\n",
|
|
"assert(np.allclose(q_approx, np.array([[-0.3996, -0.488, -0.3996, -0.488], [-0.2, -0.2, -0.39168, -0.27128], [-0.2, 0., 0., 0.], [-0.3996, -20., -0.2396, -0.3996], [0., 0., 0., 0.], [0., 0., 0., 0.], [-0.3996, -0.36, -0.2396, -0.2], [-0.2, 20., 0., 0.], [0., 0., 0., 0.]])) or\n",
|
|
"       np.allclose(q_approx, np.array([[-1.99, -1.99, -1., -1.], [-1.99, -1., -1.99, -100.], [-1., -1., -1., -1.], [0., 0., -1.99, 0.], [0., 0., 0., 0.], [-100., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.], [0., 0., 0., 0.]]))), 'Q-values are wrong for this instance; the update function is probably implemented incorrectly'\n",
|
|
"print('Q-update looks correct as well')"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|