perf: remove redundant CPU→GPU→CPU transition move in learner

This commit is contained in:
Khalil Meftah
2026-04-13 19:06:28 +02:00
parent ee0814ef60
commit a8838c081b

View File

@@ -97,7 +97,6 @@ from lerobot.utils.train_utils import (
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.transition import move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
init_logging,
@@ -382,7 +381,6 @@ def add_actor_information_and_train(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
offline_replay_buffer=offline_replay_buffer,
device=device,
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
@@ -906,7 +904,6 @@ def process_transitions(
transition_queue: Queue,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
device: str,
dataset_repo_id: str | None,
shutdown_event: any,
):
@@ -916,7 +913,6 @@ def process_transitions(
transition_queue: Queue for receiving transitions from the actor
replay_buffer: Replay buffer to add transitions to
offline_replay_buffer: Offline replay buffer to add transitions to
device: Device to move transitions to
dataset_repo_id: Repository ID for dataset
shutdown_event: Event to signal shutdown
"""
@@ -925,8 +921,6 @@ def process_transitions(
transition_list = bytes_to_transitions(buffer=transition_list)
for transition in transition_list:
transition = move_transition_to_device(transition=transition, device=device)
# Skip transitions with NaN values
if check_nan_in_transition(
observations=transition["state"],