From c1dd0c0338dbb2acc95ee506cf1907087311ca44 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 27 Oct 2024 14:24:19 -0700 Subject: [PATCH] minor QoL improvements in offline pipeline --- datasets/download_mt30.sh | 1 + datasets/download_mt80.sh | 1 + docker/environment.yaml | 2 +- tdmpc2/common/parser.py | 2 +- tdmpc2/trainer/offline_trainer.py | 10 +++++----- 5 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 datasets/download_mt30.sh create mode 100644 datasets/download_mt80.sh diff --git a/datasets/download_mt30.sh b/datasets/download_mt30.sh new file mode 100644 index 0000000..2073bcb --- /dev/null +++ b/datasets/download_mt30.sh @@ -0,0 +1 @@ +for i in {0..3}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done \ No newline at end of file diff --git a/datasets/download_mt80.sh b/datasets/download_mt80.sh new file mode 100644 index 0000000..01a7c46 --- /dev/null +++ b/datasets/download_mt80.sh @@ -0,0 +1 @@ +for i in {0..19}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done \ No newline at end of file diff --git a/docker/environment.yaml b/docker/environment.yaml index 857c81a..9425459 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -1,4 +1,4 @@ -name: graph +name: tdmpc2 channels: - pytorch-nightly - nvidia diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index ddce2b4..378ba4a 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -53,7 +53,7 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: if cfg.multitask: cfg.task_title = cfg.task.upper() # Account for slight inconsistency in task_dim for the mt30 experiments - cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64 + cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.get('model_size', 5) in {1, 317} else 64 else: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py index 1bace8e..a4289d9 100755 --- a/tdmpc2/trainer/offline_trainer.py +++ b/tdmpc2/trainer/offline_trainer.py @@ -44,13 +44,12 @@ class OfflineTrainer(Trainer): 'Offline training only supports multitask training with mt30 or mt80 task sets.' # Load data - assert self.cfg.task in self.cfg.data_dir, \ - f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \ - f'please double-check your config.' fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fps = sorted(glob(str(fp))) assert len(fps) > 0, f'No data found at {fp}' print(f'Found {len(fps)} files in {fp}') + assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \ + f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.' # Create buffer for sampling _cfg = deepcopy(self.cfg) @@ -65,8 +64,9 @@ class OfflineTrainer(Trainer): f'please double-check your config.' for i in range(len(td)): self.buffer.add(td[i]) - assert self.buffer.num_eps == self.buffer.capacity, \ - f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.' + expected_episodes = _cfg.buffer_size // _cfg.episode_length + assert self.buffer.num_eps == expected_episodes, \ + f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.' print(f'Training agent for {self.cfg.steps} iterations...') metrics = {}