remove the base classes since we don't need to be able to extend

This commit is contained in:
Bryson Jones
2025-12-11 09:20:25 -08:00
parent c398a146b3
commit f3823e8bcd

View File

@@ -26,7 +26,6 @@ References:
"""
import math
from abc import ABC, abstractmethod
from collections import deque
import einops
@@ -607,28 +606,14 @@ class DiffusionTransformer(nn.Module):
# -- Objectives --
class BaseObjective(ABC):
"""Base class for objectives used in Multi-Task DiT policy."""
def __init__(self, config, action_dim: int, horizon: int):
self.config = config
self.action_dim = action_dim
self.horizon = horizon
@abstractmethod
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
pass
@abstractmethod
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
pass
class DiffusionObjective(BaseObjective):
class DiffusionObjective(nn.Module):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
@@ -704,11 +689,14 @@ class DiffusionObjective(BaseObjective):
return sample
class FlowMatchingObjective(BaseObjective):
class FlowMatchingObjective(nn.Module):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
super().__init__()
self.config = config
self.action_dim = action_dim
self.horizon = horizon
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor: