added state input capability

This commit is contained in:
NM512
2023-05-14 23:38:46 +09:00
parent 3ebb8ad617
commit b984e69b6e
8 changed files with 369 additions and 142 deletions

View File

@@ -27,7 +27,7 @@ to_np = lambda x: x.detach().cpu().numpy()
class Dreamer(nn.Module):
def __init__(self, config, logger, dataset):
def __init__(self, obs_space, act_space, config, logger, dataset):
super(Dreamer, self).__init__()
self._config = config
self._logger = logger
@@ -51,7 +51,7 @@ class Dreamer(nn.Module):
x, self._step
)
self._dataset = dataset
self._wm = models.WorldModel(self._step, config)
self._wm = models.WorldModel(obs_space, act_space, self._step, config)
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad
)
@@ -90,8 +90,9 @@ class Dreamer(nn.Module):
for name, values in self._metrics.items():
self._logger.scalar(name, float(np.mean(values)))
self._metrics[name] = []
openl = self._wm.video_pred(next(self._dataset))
self._logger.video("train_openl", to_np(openl))
if self._config.video_pred_log:
openl = self._wm.video_pred(next(self._dataset))
self._logger.video("train_openl", to_np(openl))
self._logger.write(fps=True)
policy_output, state = self._policy(obs, state, training)
@@ -296,8 +297,6 @@ def main(config):
config.eval_every //= config.action_repeat
config.log_every //= config.action_repeat
config.time_limit //= config.action_repeat
config.act = getattr(torch.nn, config.act)
config.norm = getattr(torch.nn, config.norm)
print("Logdir", logdir)
logdir.mkdir(parents=True, exist_ok=True)
@@ -350,7 +349,13 @@ def main(config):
print("Simulate agent.")
train_dataset = make_dataset(train_eps, config)
eval_dataset = make_dataset(eval_eps, config)
agent = Dreamer(config, logger, train_dataset).to(config.device)
agent = Dreamer(
train_envs[0].observation_space,
train_envs[0].action_space,
config,
logger,
train_dataset,
).to(config.device)
agent.requires_grad_(requires_grad=False)
if (logdir / "latest_model.pt").exists():
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
@@ -362,8 +367,9 @@ def main(config):
print("Start evaluation.")
eval_policy = functools.partial(agent, training=False)
tools.simulate(eval_policy, eval_envs, episodes=config.eval_episode_num)
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
if config.video_pred_log:
video_pred = agent._wm.video_pred(next(eval_dataset))
logger.video("eval_openl", to_np(video_pred))
print("Start training.")
state = tools.simulate(agent, train_envs, config.eval_every, state=state)
torch.save(agent.state_dict(), logdir / "latest_model.pt")
@@ -376,14 +382,23 @@ def main(config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+", required=True)
parser.add_argument("--configs", nargs="+")
args, remaining = parser.parse_known_args()
configs = yaml.safe_load(
(pathlib.Path(sys.argv[0]).parent / "configs.yaml").read_text()
)
def recursive_update(base, update):
for key, value in update.items():
if isinstance(value, dict) and key in base:
recursive_update(base[key], value)
else:
base[key] = value
name_list = ["defaults", *args.configs] if args.configs else ["defaults"]
defaults = {}
for name in args.configs:
defaults.update(configs[name])
for name in name_list:
recursive_update(defaults, configs[name])
parser = argparse.ArgumentParser()
for key, value in sorted(defaults.items(), key=lambda x: x[0]):
arg_type = tools.args_type(value)