lightupbutton.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import gymnasium as gym
  2. import numpy as np
  3. from gymnasium.envs.registration import register
  4. # a very simple game
  5. # one of <size> lights will light up
  6. # take the action of the lit up light
  7. # in <hard_mode>, you act differently based on the step number and need to track this
  8. class PressTheLightUpButton(gym.Env):
  9. metadata = {"render_modes": []}
  10. def __init__(self, render_mode=None, size=2, game_length=10, hard_mode=False):
  11. self.size, self.game_length = size, game_length
  12. self.observation_space = gym.spaces.Box(0, 1, shape=(self.size,), dtype=np.float32)
  13. self.action_space = gym.spaces.Discrete(self.size)
  14. self.step_num = 0
  15. self.done = True
  16. self.hard_mode = hard_mode
  17. def _get_obs(self):
  18. obs = [0]*self.size
  19. if self.step_num < len(self.state):
  20. obs[self.state[self.step_num]] = 1
  21. return np.array(obs, dtype=np.float32)
  22. def reset(self, seed=None, options=None):
  23. super().reset(seed=seed)
  24. self.state = np.random.randint(0, self.size, size=self.game_length)
  25. self.step_num = 0
  26. self.done = False
  27. return self._get_obs(), {}
  28. def step(self, action):
  29. target = ((action + self.step_num) % self.size) if self.hard_mode else action
  30. reward = int(target == self.state[self.step_num])
  31. self.step_num += 1
  32. if not reward:
  33. self.done = True
  34. return self._get_obs(), reward, self.done, self.step_num >= self.game_length, {}
  35. register(
  36. id="PressTheLightUpButton-v0",
  37. entry_point="examples.rl.lightupbutton:PressTheLightUpButton",
  38. max_episode_steps=None,
  39. )