finish locomotion loading code

This commit is contained in:
Martino Russi
2025-11-26 16:12:53 +01:00
parent d07c65eb9a
commit 1bd91a04ce
2 changed files with 13 additions and 39 deletions

View File

@@ -25,7 +25,7 @@ from ..config import RobotConfig
@dataclass
class UnitreeG1Config(RobotConfig):
# id: str = "unitree_g1"
simulation_mode: bool = True
simulation_mode: bool = False
kp_high = 40.0
kd_high = 3.0
kp_low = 80.0
@@ -56,7 +56,7 @@ class UnitreeG1Config(RobotConfig):
socket_port: int | None = None
# Locomotion control
locomotion_control: bool = False
locomotion_control: bool = True
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx"

View File

@@ -128,37 +128,10 @@ class UnitreeG1(Robot):
self.calibrated = False
self.calibrate()
if self.config.socket_host is not None:
from lerobot.robots.unitree_g1.unitree_sdk2_socket import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
else:
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
if not self.config.simulation_mode:
self.msc = MotionSwitcherClient()
self.msc.SetTimeout(5.0)
self.msc.Init()
status, result = self.msc.CheckMode()
print(status, result)
#check if result name first
if result is not None and "name" in result:
while result["name"]:
self.msc.ReleaseMode()
status, result = self.msc.CheckMode()
print(status, result)
time.sleep(1)
# initialize lowcmd nd lowstate subscriber
if self.simulation_mode:
ChannelFactoryInitialize(0, "lo")
logger.info("Launching MuJoCo simulation environment...")
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
logger.info("MuJoCo environment launched successfully!")
else:
ChannelFactoryInitialize(0)
from lerobot.robots.unitree_g1.unitree_sdk2_socket import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
ChannelFactoryInitialize(0)
# Always use debug mode (direct motor control)
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
@@ -175,7 +148,7 @@ class UnitreeG1(Robot):
while not self.lowstate_buffer.GetData():
time.sleep(0.1)
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
logger.info("[UnitreeG1] Subscribe dds ok.")
logger.warning("[UnitreeG1] Subscribe dds ok.")
# initialize hg's lowcmd msg
self.crc = CRC()
@@ -185,6 +158,7 @@ class UnitreeG1(Robot):
print(self.msg)
self.all_motor_q = self.get_current_motor_q()
print(self.all_motor_q)
logger.info(f"Current all body motor state q:\n{self.all_motor_q} \n")
logger.info(f"Current two arms motor state q:\n{self.get_current_dual_arm_q()}\n")
logger.info("Lock all joints except two arms...\n")
@@ -209,7 +183,7 @@ class UnitreeG1(Robot):
self.msg.motor_cmd[id].q = self.all_motor_q[id]
#print current motor q, kp, kd
logger.info("Lock OK!\n") #motors are not locked x
logger.warning("Lock OK!\n") #motors are not locked x
# for i in range(10000):
# print(self.get_current_motor_q())
# time.sleep(0.05)
@@ -230,26 +204,26 @@ class UnitreeG1(Robot):
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
self.publish_thread.daemon = True
self.publish_thread.start()
logger.info("Arm control publish thread started")
logger.warning("Arm control publish thread started")
# Load locomotion policy if enabled
self.policy = None
self.policy_type = None # 'torchscript' or 'onnx'
print(config)
if config.locomotion_control:
if config.policy_path is None:
raise ValueError("locomotion_control is True but policy_path is not set")
logger.info(f"Loading locomotion policy from {config.policy_path}")
logger.warning(f"Loading locomotion policy from {config.policy_path}")
# Check file extension and load accordingly
if config.policy_path.endswith('.pt'):
logger.info("Detected TorchScript (.pt) policy")
logger.warning("Detected TorchScript (.pt) policy")
self.policy = torch.jit.load(config.policy_path)
self.policy_type = 'torchscript'
logger.info("TorchScript policy loaded successfully")
elif config.policy_path.endswith('.onnx'):
logger.info("Detected ONNX (.onnx) policy")
logger.warning("Detected ONNX (.onnx) policy")
# For GR00T-style policies, load both Balance and Walk policies
# Balance policy for standing (low velocity commands)