mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
bb69cb3c8c
commit
85fe8a3f4e
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import random
|
||||
from typing import Any, Callable, Optional, Sequence, TypedDict
|
||||
|
||||
import io
|
||||
@@ -737,7 +736,6 @@ def concatenate_batch_transitions(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import numpy as np
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
# ===== Test 1: Create and use a synthetic ReplayBuffer =====
|
||||
@@ -1139,7 +1137,7 @@ if __name__ == "__main__":
|
||||
|
||||
savings_percent = (std_mem - opt_mem) / std_mem * 100
|
||||
|
||||
print(f"\nMemory optimization result:")
|
||||
print("\nMemory optimization result:")
|
||||
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
|
||||
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
|
||||
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")
|
||||
|
||||
@@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Crop rectangular ROIs from a LeRobot dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
@@ -247,7 +249,9 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
local_files_only = args.root is not None
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
@@ -256,7 +260,7 @@ if __name__ == "__main__":
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path, "r") as f:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# rois = {
|
||||
|
||||
@@ -31,7 +31,9 @@ def find_joint_bounds(
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.imshow(
|
||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||
)
|
||||
cv2.waitKey(1)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
@@ -57,7 +59,12 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
|
||||
parser.add_argument(
|
||||
"--control-time-s",
|
||||
type=float,
|
||||
default=20,
|
||||
help="Maximum episode length in seconds",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
|
||||
|
||||
def initialize_replay_buffer(
|
||||
cfg: DictConfig, logger: Logger, device: str, storage_device:str
|
||||
cfg: DictConfig, logger: Logger, device: str, storage_device: str
|
||||
) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
return ReplayBuffer(
|
||||
|
||||
@@ -10,7 +10,9 @@ from typing import Any
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
|
||||
def preprocess_maniskill_observation(
|
||||
observations: dict[str, np.ndarray],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
|
||||
new_action_space_shape = env.action_space.shape[-1]
|
||||
new_low = np.squeeze(env.action_space.low, axis=0)
|
||||
new_high = np.squeeze(env.action_space.high, axis=0)
|
||||
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=new_low, high=new_high, shape=(new_action_space_shape,)
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
@@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
|
||||
class ManiSkillActionWrapper(gym.ActionWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
|
||||
self.action_space = gym.spaces.Tuple(
|
||||
spaces=(env.action_space, gym.spaces.Discrete(2))
|
||||
)
|
||||
|
||||
def action(self, action):
|
||||
action, telop = action
|
||||
@@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
|
||||
action_space_agent: gym.spaces.Box = env.action_space[0]
|
||||
action_space_agent.low = action_space_agent.low * multiply_factor
|
||||
action_space_agent.high = action_space_agent.high * multiply_factor
|
||||
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
|
||||
self.action_space = gym.spaces.Tuple(
|
||||
spaces=(action_space_agent, gym.spaces.Discrete(2))
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
if isinstance(action, tuple):
|
||||
@@ -137,7 +145,9 @@ def make_maniskill(
|
||||
|
||||
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env._max_episode_steps = env.max_episode_steps = (
|
||||
50 # gym_utils.find_max_episode_steps_value(env)
|
||||
)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
env = ManiSkillCompat(env)
|
||||
env = ManiSkillActionWrapper(env)
|
||||
@@ -149,10 +159,11 @@ def make_maniskill(
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
|
||||
parser.add_argument(
|
||||
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize config
|
||||
|
||||
Reference in New Issue
Block a user