refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844)

This commit is contained in:
Steven Palma
2025-09-02 17:57:49 +02:00
committed by GitHub
parent 2914ae2a96
commit ebb464c255
2 changed files with 23 additions and 21 deletions

View File

@@ -282,15 +282,16 @@ class RewardClassifierProcessor(ProcessorStep):
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION)
new_transition = transition.copy()
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
return transition
return new_transition
# Extract images from observation
images = {key: value for key, value in observation.items() if "image" in key}
if not images:
return transition
return new_transition
# Run reward classifier
start_time = time.perf_counter()
@@ -300,8 +301,8 @@ class RewardClassifierProcessor(ProcessorStep):
classifier_frequency = 1 / (time.perf_counter() - start_time)
# Calculate reward and termination
reward = transition.get(TransitionKey.REWARD, 0.0)
terminated = transition.get(TransitionKey.DONE, False)
reward = new_transition.get(TransitionKey.REWARD, 0.0)
terminated = new_transition.get(TransitionKey.DONE, False)
if math.isclose(success, 1, abs_tol=1e-2):
reward = self.success_reward
@@ -309,7 +310,6 @@ class RewardClassifierProcessor(ProcessorStep):
terminated = True
# Update transition
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward
new_transition[TransitionKey.DONE] = terminated