mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-05 05:11:25 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -29,19 +29,18 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torchvision
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import (
|
||||
get_device_from_parameters,
|
||||
get_dtype_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from .configuration_diffusion import DiffusionConfig
|
||||
|
||||
|
||||
class DiffusionPolicy(PreTrainedPolicy):
|
||||
@@ -151,11 +150,17 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict):
|
||||
"""
|
||||
Factory for noise scheduler instances of the requested type. All kwargs are passed
|
||||
to the scheduler.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("diffusers", extra="diffusion")
|
||||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
if name == "DDPM":
|
||||
return DDPMScheduler(**kwargs)
|
||||
elif name == "DDIM":
|
||||
|
||||
Reference in New Issue
Block a user