mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Merge branch 'feat/add_relative_action_pi_models' into feat/mirror
This commit is contained in:
@@ -254,16 +254,20 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
from lerobot.processor.delta_action_processor import to_delta_actions
|
||||
|
||||
max_samples = min(100_000, len(dataset))
|
||||
indices = np.random.choice(len(dataset), max_samples, replace=False)
|
||||
indices = np.random.choice(len(dataset), max_samples, replace=False).tolist()
|
||||
logging.info(
|
||||
f"use_delta_actions is enabled — computing delta action stats from {max_samples} dataset chunks"
|
||||
)
|
||||
|
||||
# Read only action and state from parquet (no video decoding)
|
||||
hf = dataset.hf_dataset
|
||||
actions_raw = hf.select(indices)["action"]
|
||||
states_raw = hf.select(indices)["observation.state"]
|
||||
|
||||
all_delta_actions = []
|
||||
for i in indices:
|
||||
item = dataset[int(i)]
|
||||
action = item["action"]
|
||||
state = item["observation.state"]
|
||||
for action, state in zip(actions_raw, states_raw):
|
||||
action = torch.as_tensor(action).float()
|
||||
state = torch.as_tensor(state).float()
|
||||
if action.ndim == 1:
|
||||
action = action.unsqueeze(0)
|
||||
mask = [True] * action.shape[-1]
|
||||
|
||||
Reference in New Issue
Block a user