added state input capability
This commit is contained in:
39
dreamer.py
39
dreamer.py
@@ -27,7 +27,7 @@ to_np = lambda x: x.detach().cpu().numpy()
|
||||
|
||||
|
||||
class Dreamer(nn.Module):
|
||||
def __init__(self, config, logger, dataset):
|
||||
def __init__(self, obs_space, act_space, config, logger, dataset):
|
||||
super(Dreamer, self).__init__()
|
||||
self._config = config
|
||||
self._logger = logger
|
||||
@@ -51,7 +51,7 @@ class Dreamer(nn.Module):
|
||||
x, self._step
|
||||
)
|
||||
self._dataset = dataset
|
||||
self._wm = models.WorldModel(self._step, config)
|
||||
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
|
||||
self._task_behavior = models.ImagBehavior(
|
||||
config, self._wm, config.behavior_stop_grad
|
||||
)
|
||||
@@ -90,8 +90,9 @@ class Dreamer(nn.Module):
|
||||
for name, values in self._metrics.items():
|
||||
self._logger.scalar(name, float(np.mean(values)))
|
||||
self._metrics[name] = []
|
||||
openl = self._wm.video_pred(next(self._dataset))
|
||||
self._logger.video("train_openl", to_np(openl))
|
||||
if self._config.video_pred_log:
|
||||
openl = self._wm.video_pred(next(self._dataset))
|
||||
self._logger.video("train_openl", to_np(openl))
|
||||
self._logger.write(fps=True)
|
||||
|
||||
policy_output, state = self._policy(obs, state, training)
|
||||
@@ -296,8 +297,6 @@ def main(config):
|
||||
config.eval_every //= config.action_repeat
|
||||
config.log_every //= config.action_repeat
|
||||
config.time_limit //= config.action_repeat
|
||||
config.act = getattr(torch.nn, config.act)
|
||||
config.norm = getattr(torch.nn, config.norm)
|
||||
|
||||
print("Logdir", logdir)
|
||||
logdir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -350,7 +349,13 @@ def main(config):
|
||||
print("Simulate agent.")
|
||||
train_dataset = make_dataset(train_eps, config)
|
||||
eval_dataset = make_dataset(eval_eps, config)
|
||||
agent = Dreamer(config, logger, train_dataset).to(config.device)
|
||||
agent = Dreamer(
|
||||
train_envs[0].observation_space,
|
||||
train_envs[0].action_space,
|
||||
config,
|
||||
logger,
|
||||
train_dataset,
|
||||
).to(config.device)
|
||||
agent.requires_grad_(requires_grad=False)
|
||||
if (logdir / "latest_model.pt").exists():
|
||||
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
|
||||
@@ -362,8 +367,9 @@ def main(config):
|
||||
print("Start evaluation.")
|
||||
eval_policy = functools.partial(agent, training=False)
|
||||
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
logger.video("eval_openl", to_np(video_pred))
|
||||
if config.video_pred_log:
|
||||
video_pred = agent._wm.video_pred(next(eval_dataset))
|
||||
logger.video("eval_openl", to_np(video_pred))
|
||||
print("Start training.")
|
||||
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
|
||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
||||
@@ -376,14 +382,23 @@ def main(config):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--configs", nargs="+", required=True)
|
||||
parser.add_argument("--configs", nargs="+")
|
||||
args, remaining = parser.parse_known_args()
|
||||
configs = yaml.safe_load(
|
||||
(pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
|
||||
)
|
||||
|
||||
def recursive_update(base, update):
|
||||
for key, value in update.items():
|
||||
if isinstance(value, dict) and key in base:
|
||||
recursive_update(base[key], value)
|
||||
else:
|
||||
base[key] = value
|
||||
|
||||
name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
|
||||
defaults = {}
|
||||
for name in args.configs:
|
||||
defaults.update(configs[name])
|
||||
for name in name_list:
|
||||
recursive_update(defaults, configs[name])
|
||||
parser = argparse.ArgumentParser()
|
||||
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
|
||||
arg_type = tools.args_type(value)
|
||||
|
||||
Reference in New Issue
Block a user