modifications to gym_manipulator and buffer

This commit is contained in:
Michel Aractingi
2025-04-07 08:45:53 +02:00
parent ab2c2d39fb
commit f3cea2a3e5
8 changed files with 76 additions and 78 deletions

View File

@@ -380,6 +380,7 @@ def add_actor_information_and_train(
for _ in range(utd_ratio - 1):
# Sample from the iterators
batch = next(online_iterator)
# batch = replay_buffer.sample(batch_size)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
@@ -437,9 +438,11 @@ def add_actor_information_and_train(
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
# batch = replay_buffer.sample(batch_size)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
# batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
@@ -775,9 +778,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
params=policy.actor.parameters_to_optimize,
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters_to_optimize, lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_grasp_critic = torch.optim.Adam(
@@ -992,7 +993,6 @@ def initialize_offline_replay_buffer(
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
action_delta=cfg.env.wrapper.delta_action,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,