Enhance processing architecture with new components

- Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility.
- Updated `__init__.py` to include `RenameProcessor` in module exports.
- Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling.
- Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness.
This commit is contained in:
Adil Zouitine
2025-07-04 10:53:40 +02:00
parent 8ebf79c494
commit 453e0a995f
5 changed files with 966 additions and 2 deletions

View File

@@ -151,7 +151,7 @@ class ObservationNormalizer:
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.split(".", 1)
key, stat_name = flat_key.rsplit(".", 1)
if key not in self._tensor_stats:
self._tensor_stats[key] = {}
self._tensor_stats[key][stat_name] = tensor
@@ -382,7 +382,7 @@ class NormalizationProcessor:
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.split(".", 1)
key, stat_name = flat_key.rsplit(".", 1)
if key not in self._tensor_stats:
self._tensor_stats[key] = {}
self._tensor_stats[key][stat_name] = tensor