This commit is contained in:
danaaubakirova
2025-07-09 14:22:34 +02:00
parent c8b51ef205
commit 67c8d27e9c
6 changed files with 1353 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_STATE, OBS_IMAGE_4, TASK, ROBOT
from lerobot.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_STATE, OBS_IMAGE_4, TASK, ROBOT_TYPE
IMAGES_ORDER = {
OBS_IMAGE: 0,
@@ -16,9 +16,9 @@ ROBOT_TYPE_KEYS_MAPPING = {
"lerobot/taco_play": "static_single_arm_7statedim",
}
TRAINING_FEATURES = {
0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE],
1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2],
2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
0: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE],
1: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE, OBS_IMAGE_2],
2: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
}
# Map to "observation.state", "action", "observation.image", etc.
FEATURE_KEYS_MAPPING = {

File diff suppressed because it is too large. [Load Diff]

View File

@@ -32,6 +32,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -74,6 +75,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "smolvla2":
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
return SmolVLA2Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -95,6 +100,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "smolvla2":
return SmolVLA2Config(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
else:

View File

@@ -14,8 +14,8 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig

View File

@@ -64,18 +64,18 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.common.constants import ACTION, OBS_STATE
from lerobot.common.policies.normalize import (
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.common.policies.utils import (
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
populate_queues,
)
from lerobot.common.utils.utils import get_safe_dtype
from lerobot.utils.utils import get_safe_dtype
from lerobot.datasets import IMAGES_ORDER
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker