diff --git a/configs.yaml b/configs.yaml index 8168060..1ab9283 100644 --- a/configs.yaml +++ b/configs.yaml @@ -63,6 +63,7 @@ defaults: weight_decay: 0.0 unimix_ratio: 0.01 action_unimix_ratio: 0.01 + initial: 'learned' # Training batch_size: 16 diff --git a/dreamer.py b/dreamer.py index a41854d..ce6152c 100644 --- a/dreamer.py +++ b/dreamer.py @@ -110,9 +110,10 @@ class Dreamer(nn.Module): ) else: latent, action = state - embed = self._wm.encoder(self._wm.preprocess(obs)) + obs = self._wm.preprocess(obs) + embed = self._wm.encoder(obs) latent, _ = self._wm.dynamics.obs_step( - latent, action, embed, self._config.collect_dyn_sample + latent, action, embed, obs["is_first"], self._config.collect_dyn_sample ) if self._config.eval_state_mean: latent["stoch"] = latent["mean"] diff --git a/models.py b/models.py index a2448f1..7120234 100644 --- a/models.py +++ b/models.py @@ -66,6 +66,7 @@ class WorldModel(nn.Module): config.dyn_min_std, config.dyn_cell, config.unimix_ratio, + config.initial, config.num_actions, embed_size, config.device, @@ -95,6 +96,7 @@ class WorldModel(nn.Module): config.norm, dist=config.reward_head, outscale=0.0, + device=config.device, ) else: self.heads["reward"] = networks.DenseHead( @@ -106,6 +108,7 @@ class WorldModel(nn.Module): config.norm, dist=config.reward_head, outscale=0.0, + device=config.device, ) self.heads["cont"] = networks.DenseHead( feat_size, # pytorch version @@ -115,6 +118,7 @@ class WorldModel(nn.Module): config.act, config.norm, dist="binary", + device=config.device, ) for name in config.grad_heads: assert name in self.heads, name @@ -140,7 +144,9 @@ class WorldModel(nn.Module): with tools.RequiresGrad(self): with torch.cuda.amp.autocast(self._use_amp): embed = self.encoder(data) - post, prior = self.dynamics.observe(embed, data["action"]) + post, prior = self.dynamics.observe( + embed, data["action"], data["is_first"] + ) kl_free = tools.schedule(self._config.kl_free, self._step) dyn_scale = 
tools.schedule(self._config.dyn_scale, self._step) rep_scale = tools.schedule(self._config.rep_scale, self._step) @@ -204,7 +210,9 @@ class WorldModel(nn.Module): data = self.preprocess(data) embed = self.encoder(data) - states, _ = self.dynamics.observe(embed[:6, :5], data["action"][:6, :5]) + states, _ = self.dynamics.observe( + embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5] + ) recon = self.heads["image"](self.dynamics.get_feat(states)).mode()[:6] reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6] init = {k: v[:, -1] for k, v in states.items()} @@ -257,6 +265,7 @@ class ImagBehavior(nn.Module): config.norm, config.value_head, outscale=0.0, + device=config.device, ) else: self.value = networks.DenseHead( @@ -268,6 +277,7 @@ class ImagBehavior(nn.Module): config.norm, config.value_head, outscale=0.0, + device=config.device, ) if config.slow_value_target: self._slow_value = copy.deepcopy(self.value) diff --git a/networks.py b/networks.py index 3171afa..f52d1a8 100644 --- a/networks.py +++ b/networks.py @@ -28,6 +28,7 @@ class RSSM(nn.Module): min_std=0.1, cell="gru", unimix_ratio=0.01, + initial="learned", num_actions=None, embed=None, device=None, @@ -48,6 +49,7 @@ class RSSM(nn.Module): self._std_act = std_act self._temp_post = temp_post self._unimix_ratio = unimix_ratio + self._initial = initial self._embed = embed self._device = device @@ -112,6 +114,12 @@ class RSSM(nn.Module): self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch) self._obs_stat_layer.apply(tools.weight_init) + if self._initial == "learned": + self.W = torch.nn.Parameter( + torch.zeros((1, self._deter), device=torch.device(self._device)), + requires_grad=True, + ) + def initial(self, batch_size): deter = torch.zeros(batch_size, self._deter).to(self._device) if self._discrete: @@ -131,19 +139,27 @@ class RSSM(nn.Module): stoch=torch.zeros([batch_size, self._stoch]).to(self._device), deter=deter, ) - return state + if self._initial == 
"zeros": + return state + elif self._initial == "learned": + state["deter"] = torch.tanh(self.W).repeat(batch_size, 1) + state["stoch"] = self.get_stoch(state["deter"]) + return state + else: + raise NotImplementedError(self._initial) - def observe(self, embed, action, state=None): + def observe(self, embed, action, is_first, state=None): swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape)))) if state is None: state = self.initial(action.shape[0]) # (batch, time, ch) -> (time, batch, ch) - embed, action = swap(embed), swap(action) + embed, action, is_first = swap(embed), swap(action), swap(is_first) + # prev_state[0] means selecting posterior of return(posterior, prior) from obs_step post, prior = tools.static_scan( - lambda prev_state, prev_act, embed: self.obs_step( - prev_state[0], prev_act, embed + lambda prev_state, prev_act, embed, is_first: self.obs_step( + prev_state[0], prev_act, embed, is_first ), - (action, embed), + (action, embed, is_first), (state, state), ) @@ -184,10 +200,22 @@ class RSSM(nn.Module): ) return dist - def obs_step(self, prev_state, prev_action, embed, sample=True): + def obs_step(self, prev_state, prev_action, embed, is_first, sample=True): # if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer) # otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach() + + if torch.sum(is_first) > 0: + is_first = is_first[:, None] + prev_action *= 1.0 - is_first + init_state = self.initial(len(is_first)) + for key, val in prev_state.items(): + is_first_r = torch.reshape( + is_first, + is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)), + ) + prev_state[key] = val * (1.0 - is_first_r) + init_state[key] * is_first_r + prior = self.img_step(prev_state, prev_action, None, sample) if self._shared: post = self.img_step(prev_state, prev_action, embed, sample) @@ -242,6 +270,12 @@ class 
RSSM(nn.Module): prior = {"stoch": stoch, "deter": deter, **stats} return prior + def get_stoch(self, deter): + x = self._img_out_layers(deter) + stats = self._suff_stats_layer("ims", x) + dist = self.get_dist(stats) + return dist.mode() + def _suff_stats_layer(self, name, x): if self._discrete: if name == "ims": @@ -435,6 +469,7 @@ class DenseHead(nn.Module): dist="normal", std=1.0, outscale=1.0, + device="cuda", ): super(DenseHead, self).__init__() self._shape = (shape,) if isinstance(shape, int) else shape @@ -446,6 +481,7 @@ class DenseHead(nn.Module): self._norm = norm self._dist = dist self._std = std + self._device = device layers = [] for index in range(self._layers): @@ -491,7 +527,7 @@ class DenseHead(nn.Module): ) ) if self._dist == "twohot_symlog": - return tools.TwoHotDistSymlog(logits=mean) + return tools.TwoHotDistSymlog(logits=mean, device=self._device) raise NotImplementedError(self._dist)