cleaned up envs
This commit is contained in:
188
envs/wrappers.py
Normal file
188
envs/wrappers.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CollectDataset:
    """Environment wrapper that records each episode and passes it to callbacks.

    Every transition (observation, action, reward, discount) is appended to an
    internal list. When the wrapped environment reports `done`, the list is
    stacked into one array per key, stored under `info['episode']`, and handed
    to every callback.

    Args:
        env: Wrapped environment. `reset()` must return a dict observation and
            `step(action)` must return `(obs_dict, reward, done, info_dict)`.
        callbacks: Optional iterable of callables invoked with the finished
            episode dict. Defaults to no callbacks.
        precision: Bit width (16, 32 or 64) used for float and signed-integer
            arrays produced by `_convert`.
    """

    def __init__(self, env, callbacks=None, precision=32):
        self._env = env
        self._callbacks = callbacks or ()
        self._precision = precision
        # List of transition dicts for the running episode; set by reset().
        self._episode = None

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself.
        return getattr(self._env, name)

    def step(self, action):
        """Step the wrapped env, record the transition, and flush on `done`.

        Returns:
            The `(obs, reward, done, info)` tuple of the wrapped environment,
            with observation values converted by `_convert` and, on the final
            step, `info['episode']` holding the stacked episode.
        """
        obs, reward, done, info = self._env.step(action)
        obs = {k: self._convert(v) for k, v in obs.items()}
        transition = obs.copy()
        if isinstance(action, dict):
            transition.update(action)
        else:
            transition['action'] = action
        transition['reward'] = reward
        # Use the env-provided discount if present, else 0 on `done`, 1 otherwise.
        transition['discount'] = info.get('discount', np.array(1 - float(done)))
        self._episode.append(transition)
        if done:
            # The reset transition has no action keys; fill them with zeros
            # shaped like the first real transition so every key stacks.
            for key, value in self._episode[1].items():
                if key not in self._episode[0]:
                    self._episode[0][key] = 0 * value
            episode = {
                k: [t[k] for t in self._episode] for k in self._episode[0]}
            episode = {k: self._convert(v) for k, v in episode.items()}
            info['episode'] = episode
            for callback in self._callbacks:
                callback(episode)
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env and start a fresh episode buffer."""
        obs = self._env.reset()
        transition = obs.copy()
        # Missing keys will be filled with a zeroed out version of the first
        # transition, because we do not know what action information the agent
        # will pass yet.
        transition['reward'] = 0.0
        transition['discount'] = 1.0
        self._episode = [transition]
        return obs

    def _convert(self, value):
        """Return `value` as an array cast to the configured precision.

        Floats and signed integers are cast to the `precision`-bit type;
        uint8 and boolean arrays keep their dtype.

        Raises:
            NotImplementedError: For any other dtype.
        """
        value = np.array(value)
        if np.issubdtype(value.dtype, np.floating):
            dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision]
        elif np.issubdtype(value.dtype, np.signedinteger):
            dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
        elif np.issubdtype(value.dtype, np.uint8):
            dtype = np.uint8
        elif np.issubdtype(value.dtype, np.bool_):
            # `np.bool_` rather than the `np.bool` alias: the alias is missing
            # from NumPy 1.24-1.26 and raises AttributeError there.
            dtype = np.bool_
        else:
            raise NotImplementedError(value.dtype)
        return value.astype(dtype)
|
||||
|
||||
|
||||
class TimeLimit:
    """Wrapper that ends an episode after a fixed number of steps.

    Args:
        env: Wrapped environment exposing `reset()` and `step(action)`.
        duration: Number of steps after which `done` is forced to True.
    """

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        # Steps taken since the last reset; None means "reset required".
        self._step = None

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)

    def step(self, action):
        """Forward `action`; force `done` once the step budget is used up.

        When the limit is hit, `info['discount']` is set to a float32 1.0
        unless the env already supplied one (presumably so the cutoff is not
        treated as a true terminal state -- confirm with the consumer), and
        the counter is cleared so that `reset()` must be called next.
        """
        assert self._step is not None, 'Must reset environment.'
        obs, reward, done, info = self._env.step(action)
        self._step += 1
        if self._step < self._duration:
            return obs, reward, done, info
        if 'discount' not in info:
            info['discount'] = np.array(1.0, dtype=np.float32)
        self._step = None
        return obs, reward, True, info

    def reset(self):
        """Restart the step counter and reset the wrapped environment."""
        self._step = 0
        return self._env.reset()
|
||||
|
||||
|
||||
class NormalizeActions:
    """Wrapper exposing bounded action dimensions on a [-1, 1] scale.

    A dimension is treated as bounded when both its `low` and `high` in the
    wrapped env's `action_space` are finite. Other dimensions pass through
    unchanged.
    """

    def __init__(self, env):
        self._env = env
        space = env.action_space
        # True where both bounds are finite, i.e. where rescaling applies.
        self._mask = np.logical_and(
            np.isfinite(space.low), np.isfinite(space.high))
        # Placeholder bounds of -1/1 keep the arithmetic finite elsewhere.
        self._low = np.where(self._mask, space.low, -1)
        self._high = np.where(self._mask, space.high, 1)

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Box with bounds -1 and 1 on every masked (finite) dimension."""
        ones = np.ones_like(self._low)
        low = np.where(self._mask, -ones, self._low)
        high = np.where(self._mask, ones, self._high)
        return gym.spaces.Box(low, high, dtype=np.float32)

    def step(self, action):
        """Map `action` from [-1, 1] to the env's range and step the env."""
        span = self._high - self._low
        rescaled = (action + 1) / 2 * span + self._low
        # Unbounded dimensions keep the raw action value.
        return self._env.step(np.where(self._mask, rescaled, action))
|
||||
|
||||
|
||||
class OneHotAction:
    """Wrapper that lets a discrete-action env accept one-hot action vectors."""

    def __init__(self, env):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        self._env = env
        # Private random source used only for sampling actions.
        self._random = np.random.RandomState()

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)

    @property
    def action_space(self):
        """Box of shape `(n,)` in [0, 1] whose `sample` yields one-hot vectors."""
        space = gym.spaces.Box(
            low=0, high=1, shape=(self._env.action_space.n,), dtype=np.float32)
        space.sample = self._sample_action
        space.discrete = True
        return space

    def step(self, action):
        """Check that `action` is one-hot and step the env with its index.

        Raises:
            ValueError: If `action` is not close to a one-hot vector.
        """
        index = np.argmax(action).astype(int)
        # Rebuild the exact one-hot vector for the arg-max and compare.
        expected = np.zeros_like(action)
        expected[index] = 1
        if np.allclose(expected, action):
            return self._env.step(index)
        raise ValueError(f'Invalid one-hot action:\n{action}')

    def reset(self):
        """Reset the wrapped environment."""
        return self._env.reset()

    def _sample_action(self):
        """Return a random float32 one-hot vector of length `n`."""
        count = self._env.action_space.n
        chosen = self._random.randint(0, count)
        one_hot = np.zeros(count, dtype=np.float32)
        one_hot[chosen] = 1.0
        return one_hot
|
||||
|
||||
|
||||
class RewardObs:
    """Wrapper that adds the reward to the observation dict under `'reward'`."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)

    @property
    def observation_space(self):
        """Observation space of the wrapped env plus a `'reward'` entry."""
        # Work on a shallow copy: inserting into the wrapped env's own mapping
        # would change its space and make the assert below fail on the next
        # read of this property. `.copy()` keeps the mapping type and order.
        spaces = self._env.observation_space.spaces.copy()
        assert 'reward' not in spaces
        spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32)
        return gym.spaces.Dict(spaces)

    def step(self, action):
        """Step the wrapped env and store the reward in the observation."""
        obs, reward, done, info = self._env.step(action)
        obs['reward'] = reward
        return obs, reward, done, info

    def reset(self):
        """Reset the wrapped env; the initial observation gets reward 0.0."""
        obs = self._env.reset()
        obs['reward'] = 0.0
        return obs
|
||||
|
||||
|
||||
class SelectAction:
    """Wrapper that forwards a single entry of a mapping-style action.

    Args:
        env: Wrapped environment exposing `step(action)`.
        key: Key used to index the action passed to `step`.
    """

    def __init__(self, env, key):
        self._env = env
        self._key = key

    def __getattr__(self, name):
        # Anything not defined on the wrapper is looked up on the wrapped env.
        return getattr(self._env, name)

    def step(self, action):
        """Step the wrapped env with `action[key]`."""
        selected = action[self._key]
        return self._env.step(selected)
|
||||
Reference in New Issue
Block a user