mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
refactor(pipeline): Remove model card generation and streamline processor methods
- Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.
This commit is contained in:
@@ -409,23 +409,6 @@ class RobotProcessor(ModelHubMixin):
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
self.save_pretrained(destination_path, config_filename=config_filename)
|
||||
|
||||
def _generate_model_card(self, destination_path: str) -> None:
|
||||
"""Generate README.md from the RobotProcessor model card template."""
|
||||
# Read the template
|
||||
template_path = Path(__file__).parent.parent / "templates" / "robotprocessor_modelcard_template.md"
|
||||
|
||||
if not template_path.exists():
|
||||
# Fallback: if template doesn't exist, skip model card generation
|
||||
return
|
||||
|
||||
with open(template_path) as f:
|
||||
model_card_content = f.read()
|
||||
|
||||
# Write the README.md
|
||||
readme_path = os.path.join(destination_path, "README.md")
|
||||
with open(readme_path, "w") as f:
|
||||
f.write(model_card_content)
|
||||
|
||||
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
|
||||
"""Serialize the processor definition and parameters to *destination_path*.
|
||||
|
||||
@@ -500,9 +483,6 @@ class RobotProcessor(ModelHubMixin):
|
||||
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
# Generate README.md from template
|
||||
self._generate_model_card(destination_path)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
|
||||
@@ -910,6 +890,21 @@ class ObservationProcessor:
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_observation
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class ActionProcessor:
|
||||
"""Base class for processors that modify only the action component of a transition.
|
||||
@@ -952,6 +947,21 @@ class ActionProcessor:
|
||||
new_transition[TransitionKey.ACTION] = processed_action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class RewardProcessor:
|
||||
"""Base class for processors that modify only the reward component of a transition.
|
||||
@@ -993,6 +1003,21 @@ class RewardProcessor:
|
||||
new_transition[TransitionKey.REWARD] = processed_reward
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class DoneProcessor:
|
||||
"""Base class for processors that modify only the done flag of a transition.
|
||||
@@ -1039,6 +1064,21 @@ class DoneProcessor:
|
||||
new_transition[TransitionKey.DONE] = processed_done
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class TruncatedProcessor:
|
||||
"""Base class for processors that modify only the truncated flag of a transition.
|
||||
@@ -1081,6 +1121,21 @@ class TruncatedProcessor:
|
||||
new_transition[TransitionKey.TRUNCATED] = processed_truncated
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class InfoProcessor:
|
||||
"""Base class for processors that modify only the info dictionary of a transition.
|
||||
@@ -1128,6 +1183,21 @@ class InfoProcessor:
|
||||
new_transition[TransitionKey.INFO] = processed_info
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class ComplementaryDataProcessor:
|
||||
"""Base class for processors that modify only the complementary data of a transition.
|
||||
@@ -1156,6 +1226,21 @@ class ComplementaryDataProcessor:
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
class IdentityProcessor:
|
||||
"""Identity processor that does nothing."""
|
||||
|
||||
Reference in New Issue
Block a user