From 9b47c5fac978624edd9744881c986ccba675d8ff Mon Sep 17 00:00:00 2001 From: Bryson Jones Date: Thu, 11 Dec 2025 09:04:38 -0800 Subject: [PATCH] move policy to top of file --- .../multi_task_dit/modeling_multi_task_dit.py | 270 +++++++++--------- 1 file changed, 135 insertions(+), 135 deletions(-) diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 5ae7d8480..dbe4910b8 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -44,6 +44,141 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import populate_queues from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE +# -- Policy -- + + +class MultiTaskDiTPolicy(PreTrainedPolicy): + config_class = MultiTaskDiTConfig + name = "multi_task_dit" + + def __init__(self, config: MultiTaskDiTConfig): + super().__init__(config) + config.validate_features() + self.config = config + + self._queues = None + + self.observation_encoder = ObservationEncoder(config) + conditioning_dim = self.observation_encoder.conditioning_dim + self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim) + + action_dim = config.action_feature.shape[0] + horizon = config.horizon + + if config.is_diffusion: + self.objective = DiffusionObjective( + config, + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + elif config.is_flow_matching: + self.objective = FlowMatchingObjective( + config, + action_dim=action_dim, + horizon=horizon, + do_mask_loss_for_padding=config.do_mask_loss_for_padding, + ) + else: + raise ValueError(f"Unsupported objective: {config.objective}") + + self.reset() + + def get_optim_params(self) -> list: + """Returns parameter groups with different learning rates for vision vs non-vision parameters""" + non_vision_params = [] + vision_encoder_params = [] + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + + if "observation_encoder.vision_encoder" in name: + vision_encoder_params.append(param) + else: + non_vision_params.append(param) + + return [ + {"params": non_vision_params}, + { + "params": vision_encoder_params, + "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier, + }, + ] + + def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + conditioning_vec = self.observation_encoder.encode(batch) + actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec) + + start_idx = n_obs_steps - 1 + end_idx = start_idx + self.config.n_action_steps + return actions[:, start_idx:end_idx] + + def reset(self): + """Clear observation and action queues. Should be called on `env.reset()`""" + self._queues = { + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.n_action_steps), + } + + if self.config.image_features: + self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + + self._queues["task"] = deque(maxlen=self.config.n_obs_steps) + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations""" + self.eval() + + original_batch_keys = set(batch.keys()) + new_batch = {} + for k in self._queues: + if k in original_batch_keys: + if self._queues[k] and isinstance(self._queues[k][-1][0], str): + new_batch[k] = self._queues[k][-1] + else: + queue_values = list(self._queues[k]) + new_batch[k] = torch.stack(queue_values, dim=1) + batch = new_batch + + actions = self._generate_actions(batch) + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations""" + if ACTION in batch: + batch.pop(ACTION) + + if self.config.image_features: + batch = dict(batch) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)) + + action = self._queues[ACTION].popleft() + return action + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """Run the batch through the model and compute the loss for training""" + if self.config.image_features: + batch = dict(batch) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + + conditioning_vec = self.observation_encoder.encode(batch) + loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) + + return loss, None + + # -- Observation Encoders -- @@ -670,138 +805,3 @@ class FlowMatchingObjective(BaseObjective): x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4) return x - - -# -- Policy -- - - -class MultiTaskDiTPolicy(PreTrainedPolicy): - config_class = MultiTaskDiTConfig - name = "multi_task_dit" - - def __init__(self, config: MultiTaskDiTConfig): - super().__init__(config) - config.validate_features() - self.config = config - - self._queues = None - - self.observation_encoder = ObservationEncoder(config) - conditioning_dim = self.observation_encoder.conditioning_dim - self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim) - - action_dim = config.action_feature.shape[0] - horizon = config.horizon - - if config.is_diffusion: - self.objective = DiffusionObjective( - config, - action_dim=action_dim, - horizon=horizon, - do_mask_loss_for_padding=config.do_mask_loss_for_padding, - ) - elif config.is_flow_matching: - self.objective = FlowMatchingObjective( - config, - action_dim=action_dim, - horizon=horizon, - do_mask_loss_for_padding=config.do_mask_loss_for_padding, - ) - else: - raise ValueError(f"Unsupported objective: {config.objective}") - - self.reset() - - def get_optim_params(self) -> list: - """Returns parameter groups with different learning rates for vision vs non-vision parameters""" - non_vision_params = [] - vision_encoder_params = [] - - for name, param in self.named_parameters(): - if not param.requires_grad: - continue - - if "observation_encoder.vision_encoder" in name: - vision_encoder_params.append(param) - else: - non_vision_params.append(param) - - return [ - {"params": non_vision_params}, - { - "params": vision_encoder_params, - "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier, - }, - ] - - def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor: - batch_size, n_obs_steps = batch["observation.state"].shape[:2] - assert n_obs_steps == self.config.n_obs_steps - - conditioning_vec = self.observation_encoder.encode(batch) - actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec) - - start_idx = n_obs_steps - 1 - end_idx = start_idx + self.config.n_action_steps - return actions[:, start_idx:end_idx] - - def reset(self): - """Clear observation and action queues. Should be called on `env.reset()`""" - self._queues = { - "observation.state": deque(maxlen=self.config.n_obs_steps), - "action": deque(maxlen=self.config.n_action_steps), - } - - if self.config.image_features: - self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) - - self._queues["task"] = deque(maxlen=self.config.n_obs_steps) - - @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """Predict a chunk of actions given environment observations""" - self.eval() - - original_batch_keys = set(batch.keys()) - new_batch = {} - for k in self._queues: - if k in original_batch_keys: - if self._queues[k] and isinstance(self._queues[k][-1][0], str): - new_batch[k] = self._queues[k][-1] - else: - queue_values = list(self._queues[k]) - new_batch[k] = torch.stack(queue_values, dim=1) - batch = new_batch - - actions = self._generate_actions(batch) - return actions - - @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """Select a single action given environment observations""" - if ACTION in batch: - batch.pop(ACTION) - - if self.config.image_features: - batch = dict(batch) - batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - - self._queues = populate_queues(self._queues, batch) - - if len(self._queues[ACTION]) == 0: - actions = self.predict_action_chunk(batch) - self._queues[ACTION].extend(actions.transpose(0, 1)) - - action = self._queues[ACTION].popleft() - return action - - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: - """Run the batch through the model and compute the loss for training""" - if self.config.image_features: - batch = dict(batch) - batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - - conditioning_vec = self.observation_encoder.encode(batch) - loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec) - - return loss, None