mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
refactor(policies): rename policies/sac → policies/gaussian_actor
This commit is contained in:
@@ -820,10 +820,10 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
4. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
|
||||
@@ -7,9 +7,9 @@ import torch
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -116,7 +116,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -264,14 +264,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
policy_cfg = GaussianActorConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
@@ -15,6 +15,10 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .gaussian_actor.reward_model.configuration_classifier import (
|
||||
RewardClassifierConfig as RewardClassifierConfig,
|
||||
)
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
@@ -22,8 +26,6 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .rtc import ActionInterpolator as ActionInterpolator
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -32,21 +34,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
||||
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
||||
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"RewardClassifierConfig",
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
|
||||
@@ -46,13 +46,13 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
@@ -89,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "reward_classifier", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -128,12 +128,12 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
return SACPolicy
|
||||
return GaussianActorPolicy
|
||||
elif name == "reward_classifier":
|
||||
from .sac.reward_model.modeling_classifier import Classifier
|
||||
from .gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "smolvla":
|
||||
@@ -172,7 +172,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
@@ -196,8 +196,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
@@ -370,16 +370,16 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from .sac.processor_sac import make_sac_pre_post_processors
|
||||
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
processors = make_gaussian_actor_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from .sac.reward_model.processor_classifier import make_classifier_processor
|
||||
from .gaussian_actor.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
|
||||
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .modeling_sac import SACPolicy
|
||||
from .processor_sac import make_sac_pre_post_processors
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
||||
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||
@@ -75,18 +75,19 @@ class PolicyConfig:
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
class GaussianActorConfig(PreTrainedConfig):
|
||||
"""Gaussian actor configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
CLI: ``--policy.type=gaussian_actor``.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
@@ -29,22 +29,29 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_sac import SACConfig, is_image_feature
|
||||
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
class GaussianActorPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
"""SAC policy."""
|
||||
"""Gaussian actor + observation encoder.
|
||||
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
Policy-side ``nn.Module`` used by SAC and related maximum-entropy continuous
|
||||
control algorithms. It owns the actor network (``Policy``) and the observation
|
||||
encoder (``GaussianActorObservationEncoder``); the critics, temperature, and
|
||||
Bellman-update logic live on the algorithm side
|
||||
(see ``lerobot.rl.algorithms.sac``).
|
||||
"""
|
||||
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
config: GaussianActorConfig | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -73,7 +80,9 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
raise NotImplementedError(
|
||||
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -133,9 +142,9 @@ class SACPolicy(
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
@@ -155,10 +164,10 @@ class SACPolicy(
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
class GaussianActorObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
def __init__(self, config: GaussianActorConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._init_image_layers()
|
||||
@@ -411,7 +420,7 @@ class DiscreteCritic(nn.Module):
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_min: float = -5,
|
||||
@@ -422,7 +431,7 @@ class Policy(nn.Module):
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.encoder: GaussianActorObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_min = std_min
|
||||
@@ -496,7 +505,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
@@ -542,12 +551,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
def make_gaussian_actor_pre_post_processors(
|
||||
config: GaussianActorConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
config: The configuration object for the tanh-Gaussian policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
|
||||
Returns:
|
||||
@@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
def __post_init__(self):
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
|
||||
@@ -251,7 +251,7 @@ def act_with_policy(
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### To avoid sending a policy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
|
||||
@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig, SACConfig
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||
CriticNetworkConfig,
|
||||
GaussianActorConfig,
|
||||
)
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -32,7 +35,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC algorithm hyperparameters."""
|
||||
|
||||
# Policy config
|
||||
sac_config: SACConfig
|
||||
sac_config: GaussianActorConfig
|
||||
|
||||
# Optimizer learning rates
|
||||
actor_lr: float = 3e-4
|
||||
@@ -59,7 +62,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: SACConfig) -> SACAlgorithmConfig:
|
||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||
"""Build an algorithm config by copying hyperparameters from the policy config."""
|
||||
return cls(
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
|
||||
@@ -26,12 +26,12 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.sac.modeling_sac import (
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
MLP,
|
||||
DiscreteCritic,
|
||||
SACObservationEncoder,
|
||||
SACPolicy,
|
||||
GaussianActorObservationEncoder,
|
||||
GaussianActorPolicy,
|
||||
orthogonal_init,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -50,7 +50,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SACPolicy,
|
||||
policy: GaussianActorPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.config = config
|
||||
@@ -100,7 +100,9 @@ class SACAlgorithm(RLAlgorithm):
|
||||
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
|
||||
self.policy.discrete_critic = self.discrete_critic
|
||||
|
||||
def _init_discrete_critics(self, encoder: SACObservationEncoder) -> tuple[DiscreteCritic, DiscreteCritic]:
|
||||
def _init_discrete_critics(
|
||||
self, encoder: GaussianActorObservationEncoder
|
||||
) -> tuple[DiscreteCritic, DiscreteCritic]:
|
||||
"""Build discrete critic ensemble and target networks."""
|
||||
discrete_critic = DiscreteCritic(
|
||||
encoder=encoder,
|
||||
@@ -557,7 +559,7 @@ class CriticEnsemble(nn.Module):
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
encoder (GaussianActorObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
@@ -566,7 +568,7 @@ class CriticEnsemble(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
|
||||
@@ -39,8 +39,8 @@ For more details, see the [Physical Intelligence π₀ blog post](https://www.ph
|
||||
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
|
||||
|
||||
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
|
||||
{% elif model_name == "sac" %}
|
||||
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
|
||||
{% elif model_name == "gaussian_actor" %}
|
||||
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
|
||||
{% elif model_name == "reward_classifier" %}
|
||||
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
|
||||
{% else %}
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_classifier_output():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
@@ -79,7 +79,7 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
@@ -118,7 +118,7 @@ def test_multiclass_classifier():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
@@ -130,7 +130,7 @@ def test_default_device():
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
@@ -17,19 +17,19 @@
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import (
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
CriticNetworkConfig,
|
||||
GaussianActorConfig,
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
config = SACConfig()
|
||||
def test_gaussian_actor_config_default_initialization():
|
||||
config = GaussianActorConfig()
|
||||
|
||||
assert config.normalization_mapping == {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
@@ -175,8 +175,8 @@ def test_concurrency_config():
|
||||
assert config.learner == "threads"
|
||||
|
||||
|
||||
def test_sac_config_custom_initialization():
|
||||
config = SACConfig(
|
||||
def test_gaussian_actor_config_custom_initialization():
|
||||
config = GaussianActorConfig(
|
||||
device="cpu",
|
||||
discount=0.95,
|
||||
temperature_init=0.5,
|
||||
@@ -190,7 +190,7 @@ def test_sac_config_custom_initialization():
|
||||
|
||||
|
||||
def test_validate_features():
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
@@ -198,7 +198,7 @@ def test_validate_features():
|
||||
|
||||
|
||||
def test_validate_features_missing_observation():
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
@@ -209,7 +209,7 @@ def test_validate_features_missing_observation():
|
||||
|
||||
|
||||
def test_validate_features_missing_action():
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
|
||||
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
|
||||
)
|
||||
@@ -22,8 +22,8 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
@@ -81,9 +81,9 @@ def test_mlp_with_custom_final_activation():
|
||||
assert (y >= -1).all() and (y <= 1).all()
|
||||
|
||||
|
||||
def test_sac_policy_with_default_args():
|
||||
def test_gaussian_actor_policy_with_default_args():
|
||||
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
|
||||
SACPolicy()
|
||||
GaussianActorPolicy()
|
||||
|
||||
|
||||
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
|
||||
@@ -142,12 +142,12 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
) -> GaussianActorConfig:
|
||||
action_dim = continuous_action_dim
|
||||
if has_discrete_action:
|
||||
action_dim += 1
|
||||
|
||||
config = SACConfig(
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
|
||||
dataset_stats={
|
||||
@@ -167,7 +167,7 @@ def create_default_config(
|
||||
|
||||
def create_config_with_visual_input(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
) -> GaussianActorConfig:
|
||||
config = create_default_config(
|
||||
state_dim=state_dim,
|
||||
continuous_action_dim=continuous_action_dim,
|
||||
@@ -186,9 +186,9 @@ def create_config_with_visual_input(
|
||||
return config
|
||||
|
||||
|
||||
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
|
||||
"""Helper to create policy + algorithm pair for tests that need critics."""
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
@@ -197,9 +197,9 @@ def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -209,11 +209,11 @@ def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: i
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
def test_sac_policy_select_action_with_discrete():
|
||||
def test_gaussian_actor_policy_select_action_with_discrete():
|
||||
"""select_action should return continuous + discrete actions."""
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -225,9 +225,9 @@ def test_sac_policy_select_action_with_discrete():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
@@ -307,7 +307,7 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_
|
||||
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
def test_gaussian_actor_policy_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
@@ -415,7 +415,7 @@ def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
config.num_discrete_actions = 5
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -3.5
|
||||
|
||||
@@ -425,7 +425,7 @@ def test_sac_algorithm_temperature():
|
||||
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
assert algorithm.temperature == pytest.approx(1.0)
|
||||
@@ -437,7 +437,7 @@ def test_sac_algorithm_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
for p in algorithm.critic_ensemble.parameters():
|
||||
@@ -472,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
def test_gaussian_actor_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded from pretrained."""
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
@@ -481,10 +481,10 @@ def test_sac_policy_save_and_load(tmp_path):
|
||||
batch_size = 2
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
@@ -503,7 +503,7 @@ def test_sac_policy_save_and_load(tmp_path):
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||
root = tmp_path / "test_sac_save_and_load_discrete"
|
||||
|
||||
@@ -512,11 +512,11 @@ def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy = GaussianActorPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert loaded_policy.discrete_critic is not None
|
||||
@@ -21,8 +21,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.gaussian_actor.reward_model.processor_classifier import make_classifier_processor
|
||||
from lerobot.processor import (
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
|
||||
@@ -21,8 +21,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
@@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SAC configuration for testing."""
|
||||
config = SACConfig()
|
||||
config = GaussianActorConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
}
|
||||
@@ -66,7 +66,7 @@ def test_make_sac_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -88,12 +88,12 @@ def test_make_sac_processor_basic():
|
||||
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
|
||||
|
||||
|
||||
def test_sac_processor_normalization_modes():
|
||||
def test_gaussian_actor_processor_normalization_modes():
|
||||
"""Test that SAC processor correctly handles different normalization modes."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_cuda():
|
||||
def test_gaussian_actor_processor_cuda():
|
||||
"""Test SAC processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -153,13 +153,13 @@ def test_sac_processor_cuda():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_accelerate_scenario():
|
||||
def test_gaussian_actor_processor_accelerate_scenario():
|
||||
"""Test SAC processor in simulated Accelerate scenario."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario():
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_sac_processor_multi_gpu():
|
||||
def test_gaussian_actor_processor_multi_gpu():
|
||||
"""Test SAC processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu():
|
||||
assert processed[TransitionKey.ACTION.value].device == device
|
||||
|
||||
|
||||
def test_sac_processor_without_stats():
|
||||
def test_gaussian_actor_processor_without_stats():
|
||||
"""Test SAC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -226,12 +226,12 @@ def test_sac_processor_without_stats():
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_sac_processor_save_and_load():
|
||||
def test_gaussian_actor_processor_save_and_load():
|
||||
"""Test saving and loading SAC processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -257,14 +257,14 @@ def test_sac_processor_save_and_load():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_mixed_precision():
|
||||
def test_gaussian_actor_processor_mixed_precision():
|
||||
"""Test SAC processor with mixed precision."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision():
|
||||
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
|
||||
|
||||
|
||||
def test_sac_processor_batch_data():
|
||||
def test_gaussian_actor_processor_batch_data():
|
||||
"""Test SAC processor with batched data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -329,12 +329,12 @@ def test_sac_processor_batch_data():
|
||||
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
|
||||
|
||||
|
||||
def test_sac_processor_edge_cases():
|
||||
def test_gaussian_actor_processor_edge_cases():
|
||||
"""Test SAC processor with edge cases."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -358,13 +358,13 @@ def test_sac_processor_edge_cases():
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_sac_processor_bfloat16_device_float32_normalizer():
|
||||
def test_gaussian_actor_processor_bfloat16_device_float32_normalizer():
|
||||
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, _ = make_sac_pre_post_processors(
|
||||
preprocessor, _ = make_gaussian_actor_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
)
|
||||
@@ -28,7 +28,7 @@ from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import skip_if_package_missing
|
||||
@@ -81,7 +81,7 @@ def cfg():
|
||||
|
||||
port = find_free_port()
|
||||
|
||||
policy_cfg = SACConfig()
|
||||
policy_cfg = GaussianActorConfig()
|
||||
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
|
||||
policy_cfg.actor_learner_config.learner_port = port
|
||||
policy_cfg.concurrency.actor = "threads"
|
||||
@@ -312,7 +312,7 @@ def test_learner_algorithm_wiring():
|
||||
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
||||
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
|
||||
get_weights() output is serializable."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.transport.utils import state_to_bytes
|
||||
@@ -320,7 +320,7 @@ def test_learner_algorithm_wiring():
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
sac_cfg = SACConfig(
|
||||
sac_cfg = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
@@ -331,7 +331,7 @@ def test_learner_algorithm_wiring():
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
@@ -399,13 +399,13 @@ def test_learner_algorithm_wiring():
|
||||
def test_initial_and_periodic_weight_push_consistency():
|
||||
"""Both initial and periodic weight pushes should use algorithm.get_weights()
|
||||
and produce identical structures."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = SACConfig(
|
||||
sac_cfg = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
@@ -416,7 +416,7 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
@@ -437,13 +437,13 @@ def test_initial_and_periodic_weight_push_consistency():
|
||||
|
||||
def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = SACConfig(
|
||||
sac_cfg = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
@@ -455,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
sac_cfg.validate_features()
|
||||
|
||||
# Actor side: no optimizers
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.eval()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
@@ -22,8 +22,8 @@ pytest.importorskip("grpc")
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
@@ -47,8 +47,8 @@ def _make_sac_config(
|
||||
utd_ratio: int = 1,
|
||||
policy_update_freq: int = 1,
|
||||
with_images: bool = False,
|
||||
) -> SACConfig:
|
||||
config = SACConfig(
|
||||
) -> GaussianActorConfig:
|
||||
config = GaussianActorConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
@@ -79,7 +79,7 @@ def _make_algorithm(
|
||||
policy_update_freq: int = 1,
|
||||
num_discrete_actions: int | None = None,
|
||||
with_images: bool = False,
|
||||
) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
) -> tuple[SACAlgorithm, GaussianActorPolicy]:
|
||||
sac_cfg = _make_sac_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
@@ -88,7 +88,7 @@ def _make_algorithm(
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
with_images=with_images,
|
||||
)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
@@ -349,7 +349,7 @@ def test_optimization_step_can_be_set_for_resume():
|
||||
|
||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
@@ -358,7 +358,7 @@ def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
def test_make_optimizers_creates_expected_keys():
|
||||
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
optimizers = algorithm.make_optimizers_and_scheduler()
|
||||
assert "actor" in optimizers
|
||||
@@ -371,7 +371,7 @@ def test_make_optimizers_creates_expected_keys():
|
||||
def test_actor_side_no_optimizers():
|
||||
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
@@ -379,7 +379,7 @@ def test_actor_side_no_optimizers():
|
||||
|
||||
def test_make_algorithm_copies_config_fields():
|
||||
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert algorithm.config.utd_ratio == 5
|
||||
assert algorithm.config.policy_update_freq == 3
|
||||
@@ -404,7 +404,7 @@ def test_load_weights_round_trip():
|
||||
algo_src.update(_batch_iterator())
|
||||
|
||||
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
|
||||
policy_dst = SACPolicy(config=sac_cfg)
|
||||
policy_dst = GaussianActorPolicy(config=sac_cfg)
|
||||
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
|
||||
|
||||
weights = algo_src.get_weights()
|
||||
@@ -423,7 +423,7 @@ def test_load_weights_round_trip_with_discrete_critic():
|
||||
algo_src.update(_batch_iterator(action_dim=7))
|
||||
|
||||
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
|
||||
policy_dst = SACPolicy(config=sac_cfg)
|
||||
policy_dst = GaussianActorPolicy(config=sac_cfg)
|
||||
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
|
||||
|
||||
weights = algo_src.get_weights()
|
||||
@@ -470,7 +470,7 @@ def test_build_algorithm_via_config():
|
||||
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=2)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
|
||||
algorithm = algo_config.build_algorithm(policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
@@ -480,6 +480,6 @@ def test_build_algorithm_via_config():
|
||||
def test_make_algorithm_uses_build_algorithm():
|
||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy = GaussianActorPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
Reference in New Issue
Block a user