mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
make fast work
This commit is contained in:
@@ -29,7 +29,9 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.fft import idct
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import (
|
||||
@@ -223,7 +225,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
raise ValueError("Task cannot be None")
|
||||
|
||||
# Tokenize the task (this will create CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
@@ -534,7 +535,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
Args:
|
||||
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
|
||||
action: The input action tensor to tokenize. Shape: (B, H, action_dim) or (H, action_dim,)
|
||||
|
||||
Returns:
|
||||
A tuple of (tokens, mask) where:
|
||||
@@ -576,7 +577,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# Flatten to 1D if needed
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
|
||||
tokens = torch.cat([self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)])
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -674,3 +675,510 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor_1")
class ActionDetokenizerProcessorStep1(ActionProcessorStep):
    """
    Processor step to detokenize action tokens back to continuous actions.

    This variant decodes the predicted PaliGemma token ids to text, strips the
    "Action:" prefix and everything from the first "|" onwards, re-encodes the
    remaining text, maps the resulting ids back into the action-token id space
    and hands them to the action tokenizer's own `decode` method.

    This is the inverse operation of ActionTokenizerProcessorStep and is typically used
    during inference to convert predicted tokens back to executable actions.

    Requires the `transformers` library to be installed.

    Attributes:
        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub
            (e.g., "physical-intelligence/fast").
        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizer by name.
        action_horizon: The number of timesteps for actions.
        action_dim: The dimensionality of each action.
        relaxed_decoding: If True, a failing decode yields an all-zero action chunk
            (with a logged warning) instead of raising.
        action_tokenizer: The internal tokenizer/processor instance, set in `__post_init__`.
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    action_horizon: int = 1
    action_dim: int = 7
    relaxed_decoding: bool = False

    # Internal tokenizer instances (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    # Skip last 128 tokens in PaliGemma vocab since they are special tokens
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the action detokenizer after the dataclass is created.

        Resolves the action tokenizer (from the provided object, else by name from the
        Hugging Face Hub) and always loads the "google/paligemma-3b-pt-224" text tokenizer.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.action_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
            add_eos_token=True,
            add_bos_token=False,
        )

    def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).

        The mapping mirrors the id around the top of the vocabulary, below the skipped
        special-token range: `vocab_size - 1 - _fast_skip_tokens - tokens`.
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens

    def decode_actions_with_fast(
        self,
        token_ids: list[int] | list[list[int]],
        time_horizon: int,
        action_dim: int,
        relaxed_decoding: bool = False,
    ) -> Any:
        """
        Decodes action token IDs back to continuous action values using the action tokenizer.

        Args:
            token_ids: Token IDs to decode, forwarded unchanged to `action_tokenizer.decode`.
            time_horizon: The number of timesteps for actions.
            action_dim: The dimensionality of each action.
            relaxed_decoding: If True, a decode failure is logged and an all-zero
                `time_horizon x action_dim` nested list is returned instead of raising.

        Returns:
            Whatever `action_tokenizer.decode` returns on success, or the zero-filled
            nested list on a tolerated failure.

        Raises:
            Exception: Re-raises the tokenizer's error when `relaxed_decoding` is False.
        """
        # Local import: the file-level import block is not visible from this block.
        import logging

        try:
            return self.action_tokenizer.decode(
                token_ids,
                time_horizon=time_horizon,
                action_dim=action_dim,
            )
        except Exception as e:
            if not relaxed_decoding:
                raise
            # The original retried with `token_ids[:len(token_ids)]`, i.e. the very same
            # input, so the retry could only fail again; fall back to zeros directly.
            logging.warning("Relaxed decoding: %s. Returning zero actions.", e)
            return [[0.0] * action_dim for _ in range(time_horizon)]

    def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the action tokenizer.

        Args:
            tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)

        Returns:
            The extracted actions, stacked over the batch. For a 1-D input the leading
            batch dimension is removed again.
        """
        # Handle single sample (add batch dimension)
        single_sample = tokens.dim() == 1
        if single_sample:
            tokens = tokens.unsqueeze(0)
        device = tokens.device

        # Decode predicted output tokens to text
        decoded_text = self._paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)

        # Drop the "Action:" prefix / stray colons and keep only the part before the first "|"
        cleaned_text = [
            text.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
            for text in decoded_text
        ]

        # Re-encode the cleaned text, then map the ids back into the action-token id space.
        # `return_tensors="pt"` is requested, so each entry is a tensor; `.tolist()` below
        # hands its nested-list form to the action tokenizer.
        action_tokens = [
            self._paligemma_tokens_to_act_tokens(
                self._paligemma_tokenizer.encode(text, return_tensors="pt", padding=False)
            )
            for text in cleaned_text
        ]

        # Decode each sample's tokens to continuous actions
        decoded_actions = []
        for tok in action_tokens:
            sample = torch.tensor(
                self.decode_actions_with_fast(
                    tok.tolist(),
                    time_horizon=self.action_horizon,
                    action_dim=self.action_dim,
                    relaxed_decoding=self.relaxed_decoding,
                ),
                device=device,
            )
            # Only strip a leading batch axis from 3-D results. The zero fallback is
            # already 2-D, and squeezing it would drop the horizon axis when
            # action_horizon == 1, breaking the stack below.
            if sample.dim() == 3:
                sample = sample.squeeze(0)
            decoded_actions.append(sample)

        # Stack into a batch
        result = torch.stack(decoded_actions, dim=0)

        # Remove batch dimension if input was single sample
        if single_sample:
            result = result.squeeze(0)

        return result

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Detokenizes action tokens back to continuous actions.

        Args:
            action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)

        Returns:
            The continuous action tensor produced by `extract_actions`.
        """
        return self.extract_actions(action)

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. `tokenizer_name` is only
        included when it was used to load the tokenizer (i.e. no tokenizer object was passed).

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "relaxed_decoding": self.relaxed_decoding,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Updates feature definitions to reflect detokenized actions.

        If an ACTION entry exists, it is replaced (in place, on the passed dictionary)
        by a single "action" feature of shape (action_horizon, action_dim).

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The same dictionary, updated.
        """
        if PipelineFeatureType.ACTION in features:
            # NOTE(review): the action feature is tagged FeatureType.STATE here; confirm
            # whether an action-specific FeatureType member was intended.
            features[PipelineFeatureType.ACTION] = {
                "action": PolicyFeature(
                    type=FeatureType.STATE,  # Continuous action
                    shape=(self.action_horizon, self.action_dim),
                )
            }

        return features
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor")
class ActionDetokenizerProcessorStep(ActionProcessorStep):
    """
    Processor step to detokenize action tokens back to continuous actions.

    This step takes tokenized actions (e.g., from model predictions), cuts each
    sequence at the first "|" token, maps the remaining PaliGemma token ids back into
    the action-token id space, decodes them with the action tokenizer's BPE tokenizer
    into DCT coefficients and applies an inverse DCT to obtain continuous actions.

    This is the inverse operation of ActionTokenizerProcessorStep and is typically used
    during inference to convert predicted tokens back to executable actions.

    Requires the `transformers` library to be installed.

    Attributes:
        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub
            (e.g., "physical-intelligence/fast").
        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
            It must expose `bpe_tokenizer`, `min_token` and `scale`, which decoding reads.
        trust_remote_code: Whether to trust remote code when loading the tokenizer by name.
        action_horizon: The number of timesteps for actions.
        action_dim: The dimensionality of each action.
        relaxed_decoding: Stored and serialized by `get_config`.
            NOTE(review): it is currently not forwarded to `decode_actions_with_fast`.
        action_tokenizer: The internal tokenizer/processor instance, set in `__post_init__`.
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    action_horizon: int = 1
    action_dim: int = 7
    relaxed_decoding: bool = False

    # Internal tokenizer instances (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    # Skip last 128 tokens in PaliGemma vocab since they are special tokens
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the action detokenizer after the dataclass is created.

        Resolves the action tokenizer (from the provided object, else by name from the
        Hugging Face Hub) and always loads the "google/paligemma-3b-pt-224" text tokenizer.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.action_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
            add_eos_token=True,
            add_bos_token=False,
        )

    def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).

        The mapping mirrors the id around the top of the vocabulary, below the skipped
        special-token range: `vocab_size - 1 - _fast_skip_tokens - tokens`.
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens

    def decode_actions_with_fast(
        self,
        token_ids: list[Any],
        time_horizon: int,
        action_dim: int,
        relaxed_decoding: bool = True,
    ) -> np.ndarray:
        """
        Decodes batches of action token IDs back to continuous action values.

        Each entry of `token_ids` is decoded with `action_tokenizer.bpe_tokenizer`; the
        code points of the decoded string, shifted by `action_tokenizer.min_token`, are
        taken as DCT coefficients, reshaped to (time_horizon, action_dim), divided by
        `action_tokenizer.scale` and passed through an orthonormal inverse DCT along
        the time axis.

        Args:
            token_ids: One token-id sequence (list or tensor) per sample.
            time_horizon: The number of timesteps for actions.
            action_dim: The dimensionality of each action.
            relaxed_decoding: If True, coefficient sequences are right-truncated or
                zero-padded to `time_horizon * action_dim` before reshaping.

        Returns:
            An array of shape (len(token_ids), time_horizon, action_dim). A sample that
            fails to decode contributes coefficients of all zeros (and a logged warning).
        """
        # Local import: the file-level import block is not visible from this block.
        import logging

        if len(token_ids) == 0:
            # np.stack rejects an empty list; return an empty batch instead.
            return np.zeros((0, time_horizon, action_dim))

        expected_seq_len = time_horizon * action_dim
        decoded_actions = []

        for token in token_ids:
            try:
                ids = token.tolist() if isinstance(token, torch.Tensor) else token
                decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(ids)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token

                if relaxed_decoding:
                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
                    if diff < 0:
                        # too long: truncate on the right
                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]
                    elif diff > 0:
                        # too short: zero-pad on the right
                        decoded_dct_coeff = np.pad(
                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
                        )

                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if decoded_dct_coeff.shape != (time_horizon, action_dim):
                    raise ValueError(
                        f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
                    )
            except Exception as e:
                # Deliberate best-effort: a bad sample must not abort the whole batch.
                logging.warning("Error decoding tokens: %s", e)
                logging.warning("Tokens: %s", token)
                decoded_dct_coeff = np.zeros((time_horizon, action_dim))

            decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))

        return np.stack(decoded_actions)

    def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the action tokenizer.

        Args:
            tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)

        Returns:
            The extracted actions as a tensor of shape (B, action_horizon, action_dim) or
            (action_horizon, action_dim) for a 1-D input.
        """
        # Handle single sample (add batch dimension)
        single_sample = tokens.dim() == 1
        if single_sample:
            tokens = tokens.unsqueeze(0)

        # Work on token strings so that the "|" delimiter can be located by value.
        decoded_tokens = [
            self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
            for seq in tokens
        ]

        # Keep only what precedes the first "|" in each sequence.
        cleaned_tokens = []
        for token_seq in decoded_tokens:
            if "|" in token_seq:
                token_seq = token_seq[: token_seq.index("|")]
            cleaned_tokens.append(token_seq)

        # Back to ids, then into the action-token id space.
        action_tokens = [
            self._paligemma_tokens_to_act_tokens(
                torch.tensor(
                    self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
                    dtype=torch.long,
                    device=tokens.device,
                )
            )
            for token_seq in cleaned_tokens
        ]

        # NOTE(review): `self.relaxed_decoding` is not forwarded, so the decoder's own
        # default (relaxed_decoding=True) always applies. Left as-is to keep current
        # behaviour; confirm whether the config field should be passed through.
        actions = self.decode_actions_with_fast(
            action_tokens,
            time_horizon=self.action_horizon,
            action_dim=self.action_dim,
        )

        result = torch.tensor(actions, device=tokens.device)

        # Remove batch dimension if input was single sample
        if single_sample:
            result = result.squeeze(0)

        return result

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Detokenizes action tokens back to continuous actions.

        Args:
            action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)

        Returns:
            The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
        """
        return self.extract_actions(action)

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. `tokenizer_name` is only
        included when it was used to load the tokenizer (i.e. no tokenizer object was passed).

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "relaxed_decoding": self.relaxed_decoding,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Updates feature definitions to reflect detokenized actions.

        If an ACTION entry exists, it is replaced (in place, on the passed dictionary)
        by a single "action" feature of shape (action_horizon, action_dim).

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The same dictionary, updated.
        """
        if PipelineFeatureType.ACTION in features:
            # NOTE(review): the action feature is tagged FeatureType.STATE here; confirm
            # whether an action-specific FeatureType member was intended.
            features[PipelineFeatureType.ACTION] = {
                "action": PolicyFeature(
                    type=FeatureType.STATE,  # Continuous action
                    shape=(self.action_horizon, self.action_dim),
                )
            }

        return features
|
||||
|
||||
Reference in New Issue
Block a user