mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844)
This commit is contained in:
@@ -282,15 +282,16 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
new_transition = transition.copy()
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
# Extract images from observation
|
||||
images = {key: value for key, value in observation.items() if "image" in key}
|
||||
|
||||
if not images:
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
# Run reward classifier
|
||||
start_time = time.perf_counter()
|
||||
@@ -300,8 +301,8 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
# Calculate reward and termination
|
||||
reward = transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
reward = new_transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = new_transition.get(TransitionKey.DONE, False)
|
||||
|
||||
if math.isclose(success, 1, abs_tol=1e-2):
|
||||
reward = self.success_reward
|
||||
@@ -309,7 +310,6 @@ class RewardClassifierProcessor(ProcessorStep):
|
||||
terminated = True
|
||||
|
||||
# Update transition
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = reward
|
||||
new_transition[TransitionKey.DONE] = terminated
|
||||
|
||||
|
||||
Reference in New Issue
Block a user