mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
add
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_STATE, OBS_IMAGE_4, TASK, ROBOT
|
||||
from lerobot.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_STATE, OBS_IMAGE_4, TASK, ROBOT_TYPE
|
||||
|
||||
IMAGES_ORDER = {
|
||||
OBS_IMAGE: 0,
|
||||
@@ -16,9 +16,9 @@ ROBOT_TYPE_KEYS_MAPPING = {
|
||||
"lerobot/taco_play": "static_single_arm_7statedim",
|
||||
}
|
||||
TRAINING_FEATURES = {
|
||||
0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE],
|
||||
1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2],
|
||||
2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
|
||||
0: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE],
|
||||
1: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE, OBS_IMAGE_2],
|
||||
2: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3],
|
||||
}
|
||||
# Map to "observation.state", "action", "observation.image", etc.
|
||||
FEATURE_KEYS_MAPPING = {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -32,6 +32,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
||||
@@ -74,6 +75,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolvla2":
|
||||
from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy
|
||||
|
||||
return SmolVLA2Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -95,6 +100,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "smolvla2":
|
||||
return SmolVLA2Config(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
else:
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig
|
||||
from lerobot.common.optim.schedulers import (
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -64,18 +64,18 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_STATE
|
||||
from lerobot.common.policies.normalize import (
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.common.policies.utils import (
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
from lerobot.datasets import IMAGES_ORDER
|
||||
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
Reference in New Issue
Block a user