minor QoL improvements in offline pipeline
This commit is contained in:
1
datasets/download_mt30.sh
Normal file
1
datasets/download_mt30.sh
Normal file
@@ -0,0 +1 @@
|
|||||||
|
for i in {0..3}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done
|
||||||
1
datasets/download_mt80.sh
Normal file
1
datasets/download_mt80.sh
Normal file
@@ -0,0 +1 @@
|
|||||||
|
for i in {0..19}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
name: graph
|
name: tdmpc2
|
||||||
channels:
|
channels:
|
||||||
- pytorch-nightly
|
- pytorch-nightly
|
||||||
- nvidia
|
- nvidia
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
|||||||
if cfg.multitask:
|
if cfg.multitask:
|
||||||
cfg.task_title = cfg.task.upper()
|
cfg.task_title = cfg.task.upper()
|
||||||
# Account for slight inconsistency in task_dim for the mt30 experiments
|
# Account for slight inconsistency in task_dim for the mt30 experiments
|
||||||
cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64
|
cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.get('model_size', 5) in {1, 317} else 64
|
||||||
else:
|
else:
|
||||||
cfg.task_dim = 0
|
cfg.task_dim = 0
|
||||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||||
|
|||||||
@@ -44,13 +44,12 @@ class OfflineTrainer(Trainer):
|
|||||||
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
||||||
|
|
||||||
# Load data
|
# Load data
|
||||||
assert self.cfg.task in self.cfg.data_dir, \
|
|
||||||
f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
|
|
||||||
f'please double-check your config.'
|
|
||||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||||
fps = sorted(glob(str(fp)))
|
fps = sorted(glob(str(fp)))
|
||||||
assert len(fps) > 0, f'No data found at {fp}'
|
assert len(fps) > 0, f'No data found at {fp}'
|
||||||
print(f'Found {len(fps)} files in {fp}')
|
print(f'Found {len(fps)} files in {fp}')
|
||||||
|
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
|
||||||
|
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
|
||||||
|
|
||||||
# Create buffer for sampling
|
# Create buffer for sampling
|
||||||
_cfg = deepcopy(self.cfg)
|
_cfg = deepcopy(self.cfg)
|
||||||
@@ -65,8 +64,9 @@ class OfflineTrainer(Trainer):
|
|||||||
f'please double-check your config.'
|
f'please double-check your config.'
|
||||||
for i in range(len(td)):
|
for i in range(len(td)):
|
||||||
self.buffer.add(td[i])
|
self.buffer.add(td[i])
|
||||||
assert self.buffer.num_eps == self.buffer.capacity, \
|
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
||||||
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
|
assert self.buffer.num_eps == expected_episodes, \
|
||||||
|
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
||||||
|
|
||||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user