97 lines
4.0 KiB
Python
97 lines
4.0 KiB
Python
import json
|
|
|
|
from pig_lite.problem.simple_2d import Simple2DProblem, MazeLevel, TerrainLevel, RoomLevel
|
|
from pig_lite.environment.gridworld import Gridworld
|
|
from pig_lite.game.tictactoe import TicTacToe
|
|
from pig_lite.decision_tree.training_set import TrainingSet
|
|
|
|
# Common integer encoding shared by the different kinds of level tiles.
encoding = dict(
    WALL=1,
    SPACE=0,
    EXPOSED=-1,
    UNDETERMINED=-2,
)
|
|
|
|
class ProblemFactory():
    """Build problem instances by name, from a JSON file, or from a dict.

    All factory methods are static; the class itself carries no state.
    """

    def __init__(self) -> None:
        # Nothing to initialise; kept so ``ProblemFactory()`` stays constructible.
        pass

    @staticmethod
    def _problem_from_level(level):
        """Wrap a generated level's field, costs, start and end in a Simple2DProblem."""
        return Simple2DProblem(level.get_field(),
                               level.get_costs(),
                               level.get_start(),
                               level.get_end())

    @staticmethod
    def _require_keys(data, required, message):
        """Raise ``ValueError(message)`` unless every key in *required* is in *data*."""
        if not all(key in data for key in required):
            raise ValueError(message)

    @staticmethod
    def generate_problem(problem_type, problem_size, rng):
        """Generate a fresh problem of the requested kind.

        Args:
            problem_type: One of 'maze', 'terrain', 'rooms', 'tictactoe' or
                'gridworld'. 'trainset' is recognised but not implemented.
            problem_size: Forwarded as ``size=`` to the level classes and to
                ``Gridworld.get_random_instance``, and as ``depth=`` to
                ``TicTacToe``.
            rng: Random generator object handed through to the constructors.

        Returns:
            A ``Simple2DProblem`` for the three level types, a ``TicTacToe``
            for 'tictactoe', or whatever ``Gridworld.get_random_instance``
            returns for 'gridworld'.

        Raises:
            NotImplementedError: If ``problem_type`` is 'trainset'.
            ValueError: If ``problem_type`` is none of the known names.
        """
        if problem_type == 'maze':
            return ProblemFactory._problem_from_level(
                MazeLevel(rng, size=problem_size))
        elif problem_type == 'terrain':
            return ProblemFactory._problem_from_level(
                TerrainLevel(rng, size=problem_size))
        elif problem_type == 'rooms':
            return ProblemFactory._problem_from_level(
                RoomLevel(rng, size=problem_size))
        elif problem_type == 'tictactoe':
            return TicTacToe(rng, depth=problem_size)
        elif problem_type == 'gridworld':
            return Gridworld.get_random_instance(rng, size=problem_size)
        elif problem_type == 'trainset':
            raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
        else:
            raise ValueError(f'unknown problem_type {problem_type}')

    @staticmethod
    def create_problem_from_json(json_path):
        """Load a problem from a JSON file whose top-level ``"type"`` names its class.

        Args:
            json_path: Path of the JSON file to read.

        Returns:
            The result of ``from_dict(data)`` on ``Simple2DProblem``,
            ``TicTacToe``, ``Gridworld`` or ``TrainingSet``, chosen by
            ``data['type']``.

        Raises:
            KeyError: If the JSON object has no ``"type"`` entry.
            ValueError: If ``"type"`` is not one of the four known names.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        problem_type = data['type']

        if problem_type == 'Simple2DProblem':
            return Simple2DProblem.from_dict(data)
        elif problem_type == 'TicTacToe':
            return TicTacToe.from_dict(data)
        elif problem_type == 'Gridworld':
            return Gridworld.from_dict(data)
        elif problem_type == 'TrainingSet':
            return TrainingSet.from_dict(data)
        else:
            raise ValueError(f"Unknown problem type: {problem_type}")

    @staticmethod
    def create_problem_from_dict(data, problem_type='Simple2DProblem'):
        """Validate *data* and build a problem of ``problem_type`` from it.

        Args:
            data: Mapping with the entries the chosen problem type needs:
                'Simple2DProblem' -> "board", "costs", "start_state",
                "end_state" (board and costs must have the same shape);
                'TicTacToe' -> "start_state", "start_player";
                'Gridworld' -> "seed", "dones", "rewards", "starts".
            problem_type: 'Simple2DProblem' (default), 'TicTacToe' or
                'Gridworld'.

        Returns:
            The result of the matching class's ``from_dict(data)``.

        Raises:
            ValueError: If required keys are missing, or board and costs
                differ in shape.
            NotImplementedError: For any other ``problem_type``.
        """
        if problem_type == 'Simple2DProblem':
            # Only this branch needs numpy, so import it here.
            import numpy as np

            ProblemFactory._require_keys(
                data,
                ('board', 'costs', 'start_state', 'end_state'),
                'data dict must contain: "board", "costs", "start_state" and "end_state"')
            if np.asarray(data['board']).shape != np.asarray(data['costs']).shape:
                raise ValueError('data["board"] and data["costs"] must have same shape')
            return Simple2DProblem.from_dict(data)
        elif problem_type == 'TicTacToe':
            ProblemFactory._require_keys(
                data,
                ('start_state', 'start_player'),
                'data dict must contain: "start_state", "start_player"')
            return TicTacToe.from_dict(data)
        elif problem_type == 'Gridworld':
            ProblemFactory._require_keys(
                data,
                ('seed', 'dones', 'rewards', 'starts'),
                'data dict must contain: "seed", "dones", "rewards", "starts"')
            return Gridworld.from_dict(data)
        else:
            raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
|
|
|