partial fix to loading checkpoints
This commit is contained in:
@@ -12,7 +12,7 @@ class Buffer():
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self._device = torch.get_default_device()
|
||||
self._device = torch.device('cuda:0')
|
||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||
self._sampler = SliceSampler(
|
||||
num_slices=self.cfg.batch_size,
|
||||
@@ -59,7 +59,7 @@ class Buffer():
|
||||
total_bytes = bytes_per_step*self._capacity
|
||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||
# Heuristic: decide whether to use CUDA or CPU memory
|
||||
storage_device = torch.get_default_device() if 2.5*total_bytes < mem_free else 'cpu'
|
||||
storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'
|
||||
print(f'Using {storage_device.upper()} memory for storage.')
|
||||
self._storage_device = torch.device(storage_device)
|
||||
return self._reserve_buffer(
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import torch
|
||||
from torch.nn import Buffer
|
||||
|
||||
|
||||
class RunningScale(torch.nn.Module):
|
||||
"""Running trimmed scale estimator."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.get_default_device()))
|
||||
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.get_default_device()))
|
||||
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda:0')))
|
||||
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda:0')))
|
||||
|
||||
def state_dict(self):
|
||||
return dict(value=self.value, percentiles=self._percentiles)
|
||||
|
||||
@@ -7,6 +7,7 @@ from common import layers, math, init
|
||||
from tensordict import TensorDict
|
||||
from tensordict.nn import TensorDictParams
|
||||
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
"""
|
||||
TD-MPC2 implicit world model architecture.
|
||||
|
||||
@@ -9,11 +9,8 @@ from envs.wrappers.tensor import TensorWrapper
|
||||
def missing_dependencies(task):
|
||||
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
||||
|
||||
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
|
||||
try:
|
||||
pass
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
except:
|
||||
make_dm_control_env = missing_dependencies
|
||||
try:
|
||||
@@ -67,8 +64,7 @@ def make_env(cfg):
|
||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||
try:
|
||||
env = fn(cfg)
|
||||
except ValueError as err:
|
||||
print(err)
|
||||
except ValueError:
|
||||
pass
|
||||
if env is None:
|
||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||
|
||||
@@ -29,7 +29,7 @@ def evaluate(cfg: dict):
|
||||
`eval_episodes`: number of episodes to evaluate on per task (default: 10)
|
||||
`save_video`: whether to save a video of the evaluation (default: True)
|
||||
`seed`: random seed (default: 1)
|
||||
|
||||
|
||||
See config.yaml for a full list of args.
|
||||
|
||||
Example usage:
|
||||
@@ -39,8 +39,7 @@ def evaluate(cfg: dict):
|
||||
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
|
||||
```
|
||||
"""
|
||||
if torch.get_default_device().type == "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
assert torch.cuda.is_available()
|
||||
assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.'
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
@@ -58,7 +57,7 @@ def evaluate(cfg: dict):
|
||||
agent = TDMPC2(cfg)
|
||||
assert os.path.exists(cfg.checkpoint), f'Checkpoint {cfg.checkpoint} not found! Must be a valid filepath.'
|
||||
agent.load(cfg.checkpoint)
|
||||
|
||||
|
||||
# Evaluate
|
||||
if cfg.multitask:
|
||||
print(colored(f'Evaluating agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -8,7 +6,6 @@ from common.scale import RunningScale
|
||||
from common.world_model import WorldModel
|
||||
from tensordict import TensorDict
|
||||
|
||||
torch.set_default_device(os.getenv("TDMPC2_DEFAULT_DEVICE", "cuda:0"))
|
||||
|
||||
class TDMPC2(torch.nn.Module):
|
||||
"""
|
||||
@@ -20,7 +17,7 @@ class TDMPC2(torch.nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.device = torch.get_default_device()
|
||||
self.device = torch.device('cuda:0')
|
||||
self.model = WorldModel(cfg).to(self.device)
|
||||
self.optim = torch.optim.Adam([
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
@@ -35,7 +32,7 @@ class TDMPC2(torch.nn.Module):
|
||||
self.scale = RunningScale(cfg)
|
||||
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=torch.get_default_device()
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
if cfg.compile:
|
||||
@@ -91,16 +88,16 @@ class TDMPC2(torch.nn.Module):
|
||||
name_map = [
|
||||
"weight", "bias", "ln.weight", "ln.bias",
|
||||
]
|
||||
print("Listing state dict keys (from disk)")
|
||||
for k in list(local_state_dict.keys()):
|
||||
print("\t", k)
|
||||
# print("Listing state dict keys (from disk)")
|
||||
# for k in list(local_state_dict.keys()):
|
||||
# print("\t", k)
|
||||
|
||||
sd = model.state_dict()
|
||||
print("Listing dest state dict keys")
|
||||
for k in list(sd.keys()):
|
||||
print("\t", k)
|
||||
# print("Listing dest state dict keys")
|
||||
# for k in list(sd.keys()):
|
||||
# print("\t", k)
|
||||
|
||||
print("Maps:")
|
||||
# print("Maps:")
|
||||
new_sd = dict(sd)
|
||||
for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"):
|
||||
for key, val in list(local_state_dict.items()):
|
||||
@@ -109,12 +106,12 @@ class TDMPC2(torch.nn.Module):
|
||||
num = key[len(cur_prefix + "params."):]
|
||||
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
||||
new_total_key = cur_prefix + 'params.' + new_key
|
||||
print("\t", key, '-->', new_total_key)
|
||||
# print("\t", key, '-->', new_total_key)
|
||||
del local_state_dict[key]
|
||||
new_sd[new_total_key] = val
|
||||
if not cur_prefix.startswith("_target"):
|
||||
new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key
|
||||
print("\t", 'DETACH', key, '-->', new_total_key)
|
||||
# print("\t", 'DETACH', key, '-->', new_total_key)
|
||||
new_sd[new_total_key] = val
|
||||
local_state_dict.update(new_sd)
|
||||
return local_state_dict
|
||||
|
||||
@@ -43,8 +43,7 @@ def train(cfg: dict):
|
||||
$ python train.py task=dog-run steps=7000000
|
||||
```
|
||||
"""
|
||||
if torch.get_default_device().type == 'cuda':
|
||||
assert torch.cuda.is_available()
|
||||
assert torch.cuda.is_available()
|
||||
assert cfg.steps > 0, 'Must train for at least 1 step.'
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
|
||||
Reference in New Issue
Block a user