mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
remove the base classes since we don't need to be able to extend
This commit is contained in:
@@ -26,7 +26,6 @@ References:
|
||||
"""
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
|
||||
import einops
|
||||
@@ -607,28 +606,14 @@ class DiffusionTransformer(nn.Module):
|
||||
# -- Objectives --
|
||||
|
||||
|
||||
class BaseObjective(ABC):
|
||||
"""Base class for objectives used in Multi-Task DiT policy."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int):
|
||||
self.config = config
|
||||
self.action_dim = action_dim
|
||||
self.horizon = horizon
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class DiffusionObjective(BaseObjective):
|
||||
class DiffusionObjective(nn.Module):
|
||||
"""Standard diffusion (DDPM/DDIM) objective implementation."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__(config, action_dim, horizon)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.action_dim = action_dim
|
||||
self.horizon = horizon
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
scheduler_kwargs = {
|
||||
@@ -704,11 +689,14 @@ class DiffusionObjective(BaseObjective):
|
||||
return sample
|
||||
|
||||
|
||||
class FlowMatchingObjective(BaseObjective):
|
||||
class FlowMatchingObjective(nn.Module):
|
||||
"""Flow matching objective: trains a model to predict velocity fields."""
|
||||
|
||||
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
|
||||
super().__init__(config, action_dim, horizon)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.action_dim = action_dim
|
||||
self.horizon = horizon
|
||||
self.do_mask_loss_for_padding = do_mask_loss_for_padding
|
||||
|
||||
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user