minor QoL improvements in offline pipeline

This commit is contained in:
Nicklas Hansen
2024-10-27 14:24:19 -07:00
parent 836547d76f
commit c1dd0c0338
5 changed files with 9 additions and 7 deletions

View File

@@ -0,0 +1 @@
# Download the four mt30 dataset chunks (chunk_0.pt .. chunk_3.pt) into the current directory.
# The URL is quoted so the shell never treats `?` as a glob character, and -O fixes the
# local file name instead of relying on the name wget derives from the URL.
# Each chunk is written to a temporary .part file and renamed only after a successful
# download, so an interrupted transfer never leaves a truncated chunk_N.pt behind.
for i in {0..3}; do
	wget -O "chunk_${i}.pt.part" "https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_${i}.pt?download=true" \
		&& mv "chunk_${i}.pt.part" "chunk_${i}.pt"
done

View File

@@ -0,0 +1 @@
# Download the twenty mt80 dataset chunks (chunk_0.pt .. chunk_19.pt) into the current directory.
# The URL is quoted so the shell never treats `?` as a glob character, and -O fixes the
# local file name instead of relying on the name wget derives from the URL.
# Each chunk is written to a temporary .part file and renamed only after a successful
# download, so an interrupted transfer never leaves a truncated chunk_N.pt behind.
for i in {0..19}; do
	wget -O "chunk_${i}.pt.part" "https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_${i}.pt?download=true" \
		&& mv "chunk_${i}.pt.part" "chunk_${i}.pt"
done

View File

@@ -1,4 +1,4 @@
name: graph
name: tdmpc2
channels:
- pytorch-nightly
- nvidia

View File

@@ -53,7 +53,7 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
if cfg.multitask:
cfg.task_title = cfg.task.upper()
# Account for slight inconsistency in task_dim for the mt30 experiments
cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64
cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.get('model_size', 5) in {1, 317} else 64
else:
cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])

View File

@@ -44,13 +44,12 @@ class OfflineTrainer(Trainer):
'Offline training only supports multitask training with mt30 or mt80 task sets.'
# Load data
assert self.cfg.task in self.cfg.data_dir, \
f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
f'please double-check your config.'
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}'
print(f'Found {len(fps)} files in {fp}')
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
# Create buffer for sampling
_cfg = deepcopy(self.cfg)
@@ -65,8 +64,9 @@ class OfflineTrainer(Trainer):
f'please double-check your config.'
for i in range(len(td)):
self.buffer.add(td[i])
assert self.buffer.num_eps == self.buffer.capacity, \
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
expected_episodes = _cfg.buffer_size // _cfg.episode_length
assert self.buffer.num_eps == expected_episodes, \
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {}