refactor(pipeline): Remove model card generation and streamline processor methods

- Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template.
- Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters.
- Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability.
This commit is contained in:
Adil Zouitine
2025-08-05 10:31:09 +02:00
parent 5595887fd0
commit 8077456c00
6 changed files with 109 additions and 2042 deletions

View File

@@ -409,23 +409,6 @@ class RobotProcessor(ModelHubMixin):
config_filename = kwargs.pop("config_filename", None)
self.save_pretrained(destination_path, config_filename=config_filename)
def _generate_model_card(self, destination_path: str) -> None:
"""Generate README.md from the RobotProcessor model card template."""
# Read the template
template_path = Path(__file__).parent.parent / "templates" / "robotprocessor_modelcard_template.md"
if not template_path.exists():
# Fallback: if template doesn't exist, skip model card generation
return
with open(template_path) as f:
model_card_content = f.read()
# Write the README.md
readme_path = os.path.join(destination_path, "README.md")
with open(readme_path, "w") as f:
f.write(model_card_content)
def save_pretrained(self, destination_path: str, config_filename: str | None = None, **kwargs):
"""Serialize the processor definition and parameters to *destination_path*.
@@ -500,9 +483,6 @@ class RobotProcessor(ModelHubMixin):
with open(os.path.join(destination_path, config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
# Generate README.md from template
self._generate_model_card(destination_path)
@classmethod
def from_pretrained(
cls, source: str, *, config_filename: str | None = None, overrides: dict[str, Any] | None = None
@@ -910,6 +890,21 @@ class ObservationProcessor:
new_transition[TransitionKey.OBSERVATION] = processed_observation
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class ActionProcessor:
"""Base class for processors that modify only the action component of a transition.
@@ -952,6 +947,21 @@ class ActionProcessor:
new_transition[TransitionKey.ACTION] = processed_action
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class RewardProcessor:
"""Base class for processors that modify only the reward component of a transition.
@@ -993,6 +1003,21 @@ class RewardProcessor:
new_transition[TransitionKey.REWARD] = processed_reward
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class DoneProcessor:
"""Base class for processors that modify only the done flag of a transition.
@@ -1039,6 +1064,21 @@ class DoneProcessor:
new_transition[TransitionKey.DONE] = processed_done
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class TruncatedProcessor:
"""Base class for processors that modify only the truncated flag of a transition.
@@ -1081,6 +1121,21 @@ class TruncatedProcessor:
new_transition[TransitionKey.TRUNCATED] = processed_truncated
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class InfoProcessor:
"""Base class for processors that modify only the info dictionary of a transition.
@@ -1128,6 +1183,21 @@ class InfoProcessor:
new_transition[TransitionKey.INFO] = processed_info
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class ComplementaryDataProcessor:
"""Base class for processors that modify only the complementary data of a transition.
@@ -1156,6 +1226,21 @@ class ComplementaryDataProcessor:
new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
pass
def reset(self) -> None:
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
class IdentityProcessor:
"""Identity processor that does nothing."""