Files
exercise-00/pig_lite/instance_generation/problem_factory.py
2025-10-07 18:22:35 +02:00

97 lines
4.0 KiB
Python

import json
from pig_lite.problem.simple_2d import Simple2DProblem, MazeLevel, TerrainLevel, RoomLevel
from pig_lite.environment.gridworld import Gridworld
from pig_lite.game.tictactoe import TicTacToe
from pig_lite.decision_tree.training_set import TrainingSet
# Common integer encoding for the different kinds of level tiles.
encoding = dict(
    WALL=1,
    SPACE=0,
    EXPOSED=-1,
    UNDETERMINED=-2,
)
class ProblemFactory:
    """Create problem instances, either freshly generated or rebuilt from saved data.

    Every method is a ``@staticmethod``; the class is used purely as a namespace.
    """

    @staticmethod
    def _simple_2d_from_level(level):
        """Wrap a generated level in a Simple2DProblem (field, costs, start, end)."""
        return Simple2DProblem(level.get_field(),
                               level.get_costs(),
                               level.get_start(),
                               level.get_end())

    @staticmethod
    def _require_keys(data, keys, message):
        """Raise ``ValueError(message)`` unless every name in ``keys`` is a key of ``data``."""
        if not all(key in data for key in keys):
            raise ValueError(message)

    @staticmethod
    def generate_problem(problem_type, problem_size, rng):
        """Generate a random problem instance of the requested kind.

        Args:
            problem_type: one of 'maze', 'terrain', 'rooms', 'tictactoe' or
                'gridworld' ('trainset' is recognised but not implemented).
            problem_size: forwarded as ``size`` to the level and gridworld
                generators, and as ``depth`` to TicTacToe.
            rng: random generator forwarded unchanged to the underlying
                generator (NOTE(review): concrete type not visible here —
                presumably a NumPy Generator; confirm against callers).

        Returns:
            The generated problem instance.

        Raises:
            NotImplementedError: for 'trainset'.
            ValueError: for any other unknown ``problem_type``.
        """
        # Maze, terrain and rooms levels share one construction path; they
        # differ only in which level class generates the grid.
        wrap = ProblemFactory._simple_2d_from_level
        if problem_type == 'maze':
            return wrap(MazeLevel(rng, size=problem_size))
        if problem_type == 'terrain':
            return wrap(TerrainLevel(rng, size=problem_size))
        if problem_type == 'rooms':
            return wrap(RoomLevel(rng, size=problem_size))
        if problem_type == 'tictactoe':
            return TicTacToe(rng, depth=problem_size)
        if problem_type == 'gridworld':
            return Gridworld.get_random_instance(rng, size=problem_size)
        if problem_type == 'trainset':
            raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')
        raise ValueError(f'unknown problem_type {problem_type}')

    @staticmethod
    def create_problem_from_json(json_path):
        """Load a problem from a JSON file.

        The decoded top-level object must have a ``'type'`` key naming the
        problem class ('Simple2DProblem', 'TicTacToe', 'Gridworld' or
        'TrainingSet'); the whole decoded object is passed to that class's
        ``from_dict``.

        Args:
            json_path: path of the JSON file to read.

        Returns:
            The problem instance produced by the matching ``from_dict``.

        Raises:
            KeyError: if the decoded data has no ``'type'`` key.
            ValueError: if ``'type'`` names an unsupported class.
        """
        # JSON text is UTF-8 (RFC 8259); pin it instead of relying on the
        # platform's locale-dependent default text encoding.
        with open(json_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        problem_type = data['type']
        if problem_type == 'Simple2DProblem':
            return Simple2DProblem.from_dict(data)
        if problem_type == 'TicTacToe':
            return TicTacToe.from_dict(data)
        if problem_type == 'Gridworld':
            return Gridworld.from_dict(data)
        if problem_type == 'TrainingSet':
            return TrainingSet.from_dict(data)
        raise ValueError(f"Unknown problem type: {problem_type}")

    @staticmethod
    def create_problem_from_dict(data, problem_type='Simple2DProblem'):
        """Validate ``data`` for ``problem_type`` and build the problem via ``from_dict``.

        Only the presence of the required keys is checked here. For
        Simple2DProblem it is also checked that ``data["board"]`` and
        ``data["costs"]`` convert to NumPy arrays of the same shape. Everything
        else is left to the class's ``from_dict``.

        Args:
            data: mapping holding the serialised problem.
            problem_type: 'Simple2DProblem' (default), 'TicTacToe' or 'Gridworld'.

        Returns:
            The problem instance produced by the matching ``from_dict``.

        Raises:
            ValueError: if a required key is missing, or the board/costs shapes differ.
            NotImplementedError: for any other ``problem_type``.
        """
        if problem_type == 'Simple2DProblem':
            import numpy as np  # only needed for the shape check below

            ProblemFactory._require_keys(
                data, ('board', 'costs', 'start_state', 'end_state'),
                'data dict must contain: "board", "costs", "start_state" and "end_state"')
            if np.array(data['board']).shape != np.array(data['costs']).shape:
                raise ValueError('data["board"] and data["costs"] must have same shape')
            return Simple2DProblem.from_dict(data)
        if problem_type == 'TicTacToe':
            ProblemFactory._require_keys(
                data, ('start_state', 'start_player'),
                'data dict must contain: "start_state", "start_player"')
            return TicTacToe.from_dict(data)
        if problem_type == 'Gridworld':
            ProblemFactory._require_keys(
                data, ('seed', 'dones', 'rewards', 'starts'),
                'data dict must contain: "seed", "dones", "rewards", "starts"')
            return Gridworld.from_dict(data)
        raise NotImplementedError(f'problem_type {problem_type} is not implemented yet')