Modified based on the author's implementation

This commit is contained in:
NM512
2023-03-18 08:38:23 +09:00
parent a678a509b9
commit 6273444394
6 changed files with 371 additions and 229 deletions

View File

@@ -8,9 +8,9 @@ defaults:
seed: 0
steps: 1e6
eval_every: 1e4
eval_episode_num: 10
log_every: 1e4
reset_every: 0
#gpu_growth: True
device: 'cuda:0'
precision: 16
debug: False
@@ -25,9 +25,6 @@ defaults:
grayscale: False
prefill: 2500
eval_noise: 0.0
reward_trans: 'symlog'
obs_trans: 'normalize'
critic_trans: 'symlog'
reward_EMA: True
# Model
@@ -36,8 +33,8 @@ defaults:
dyn_deter: 512
dyn_stoch: 32
dyn_discrete: 32
dyn_input_layers: 2
dyn_output_layers: 2
dyn_input_layers: 1
dyn_output_layers: 1
dyn_rec_depth: 1
dyn_shared: False
dyn_mean_act: 'none'
@@ -53,11 +50,10 @@ defaults:
act: 'SiLU'
norm: 'LayerNorm'
cnn_depth: 32
encoder_kernels: [3, 3, 3, 3]
decoder_kernels: [3, 3, 3, 3]
# changed here
value_head: 'twohot'
reward_head: 'twohot'
encoder_kernels: [4, 4, 4, 4]
decoder_kernels: [4, 4, 4, 4]
value_head: 'twohot_symlog'
reward_head: 'twohot_symlog'
kl_lscale: '0.1'
kl_rscale: '0.5'
kl_free: '1.0'
@@ -71,7 +67,7 @@ defaults:
# Training
batch_size: 16
batch_length: 64
train_every: 5
train_ratio: 512
train_steps: 1
pretrain: 100
model_lr: 1e-4
@@ -85,9 +81,8 @@ defaults:
dataset_size: 0
oversample_ends: False
slow_value_target: True
slow_actor_target: True
slow_target_update: 100
slow_target_fraction: 0.01
slow_target_update: 1
slow_target_fraction: 0.02
opt: 'adam'
# Behavior.
@@ -95,16 +90,15 @@ defaults:
discount_lambda: 0.95
imag_horizon: 15
imag_gradient: 'dynamics'
imag_gradient_mix: '0.1'
imag_gradient_mix: '0.0'
imag_sample: True
actor_dist: 'trunc_normal'
actor_dist: 'normal'
actor_entropy: '3e-4'
actor_state_entropy: 0.0
actor_init_std: 1.0
actor_min_std: 0.1
actor_disc: 5
actor_max_std: 1.0
actor_temp: 0.1
actor_outscale: 0.0
expl_amount: 0.0
eval_state_mean: False
collect_dyn_sample: True
@@ -134,3 +128,14 @@ debug:
batch_size: 10
batch_length: 20
cheetah:
task: 'dmc_cheetah_run'
pendulum:
task: 'dmc_pendulum_swingup'
cup:
task: 'dmc_cup_catch'
acrobot:
task: 'dmc_acrobot_swingup'