[Port HIL_SERL] Final fixes for the Reward Classifier (#598)

This commit is contained in:
Eugene Mironov
2025-01-06 17:34:00 +07:00
committed by Michel Aractingi
parent e5801f467f
commit d1d6ffd23c
10 changed files with 7780 additions and 15 deletions

View File

@@ -197,8 +197,14 @@ def record(
resume: bool = False,
local_files_only: bool = False,
run_compute_stats: bool = True,
assign_rewards: bool = False,
) -> LeRobotDataset:
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -211,7 +217,7 @@ def record(
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
# initialize listener before sim env
listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# create sim env
env = env()
@@ -251,6 +257,7 @@ def record(
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
@@ -302,6 +309,13 @@ def record(
"timestamp": env_timestamp,
}
# Overwrite environment reward with manually assigned reward
if assign_rewards:
frame["next.reward"] = events["next.reward"]
# Should success always be false to match what we do in control_utils?
frame["next.success"] = False
for key in image_keys:
if not key.startswith("observation.image"):
frame["observation.image." + key] = observation[key]
@@ -486,6 +500,13 @@ if __name__ == "__main__":
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"