diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 757da7c17..39a17aa07 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -26,7 +26,6 @@ References: """ import math -from abc import ABC, abstractmethod from collections import deque import einops @@ -607,28 +606,14 @@ class DiffusionTransformer(nn.Module): # -- Objectives -- -class BaseObjective(ABC): - """Base class for objectives used in Multi-Task DiT policy.""" - - def __init__(self, config, action_dim: int, horizon: int): - self.config = config - self.action_dim = action_dim - self.horizon = horizon - - @abstractmethod - def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor: - pass - - @abstractmethod - def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor: - pass - - -class DiffusionObjective(BaseObjective): +class DiffusionObjective(nn.Module): """Standard diffusion (DDPM/DDIM) objective implementation.""" def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): - super().__init__(config, action_dim, horizon) + super().__init__() + self.config = config + self.action_dim = action_dim + self.horizon = horizon self.do_mask_loss_for_padding = do_mask_loss_for_padding scheduler_kwargs = { @@ -704,11 +689,14 @@ class DiffusionObjective(BaseObjective): return sample -class FlowMatchingObjective(BaseObjective): +class FlowMatchingObjective(nn.Module): """Flow matching objective: trains a model to predict velocity fields.""" def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False): - super().__init__(config, action_dim, horizon) + super().__init__() + self.config = config + self.action_dim = action_dim + self.horizon = horizon self.do_mask_loss_for_padding = do_mask_loss_for_padding def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: