mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
finish locomotion loading code
This commit is contained in:
@@ -25,7 +25,7 @@ from ..config import RobotConfig
|
||||
@dataclass
|
||||
class UnitreeG1Config(RobotConfig):
|
||||
# id: str = "unitree_g1"
|
||||
simulation_mode: bool = True
|
||||
simulation_mode: bool = False
|
||||
kp_high = 40.0
|
||||
kd_high = 3.0
|
||||
kp_low = 80.0
|
||||
@@ -56,7 +56,7 @@ class UnitreeG1Config(RobotConfig):
|
||||
socket_port: int | None = None
|
||||
|
||||
# Locomotion control
|
||||
locomotion_control: bool = False
|
||||
locomotion_control: bool = True
|
||||
#policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt"
|
||||
policy_path: str = "src/lerobot/robots/unitree_g1/assets/g1/locomotion/GR00T-WholeBodyControl-Walk.onnx"
|
||||
|
||||
|
||||
@@ -128,37 +128,10 @@ class UnitreeG1(Robot):
|
||||
|
||||
self.calibrated = False
|
||||
|
||||
self.calibrate()
|
||||
|
||||
if self.config.socket_host is not None:
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
else:
|
||||
from unitree_sdk2py.core.channel import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
|
||||
if not self.config.simulation_mode:
|
||||
self.msc = MotionSwitcherClient()
|
||||
self.msc.SetTimeout(5.0)
|
||||
self.msc.Init()
|
||||
|
||||
status, result = self.msc.CheckMode()
|
||||
print(status, result)
|
||||
#check if result name first
|
||||
if result is not None and "name" in result:
|
||||
while result["name"]:
|
||||
self.msc.ReleaseMode()
|
||||
status, result = self.msc.CheckMode()
|
||||
print(status, result)
|
||||
time.sleep(1)
|
||||
|
||||
# initialize lowcmd nd lowstate subscriber
|
||||
if self.simulation_mode:
|
||||
ChannelFactoryInitialize(0, "lo")
|
||||
|
||||
logger.info("Launching MuJoCo simulation environment...")
|
||||
self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
|
||||
logger.info("MuJoCo environment launched successfully!")
|
||||
else:
|
||||
ChannelFactoryInitialize(0)
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import ChannelPublisher, ChannelSubscriber, ChannelFactoryInitialize # dds
|
||||
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
# Always use debug mode (direct motor control)
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
@@ -175,7 +148,7 @@ class UnitreeG1(Robot):
|
||||
while not self.lowstate_buffer.GetData():
|
||||
time.sleep(0.1)
|
||||
logger.warning("[UnitreeG1] Waiting to subscribe dds...")
|
||||
logger.info("[UnitreeG1] Subscribe dds ok.")
|
||||
logger.warning("[UnitreeG1] Subscribe dds ok.")
|
||||
|
||||
# initialize hg's lowcmd msg
|
||||
self.crc = CRC()
|
||||
@@ -185,6 +158,7 @@ class UnitreeG1(Robot):
|
||||
print(self.msg)
|
||||
|
||||
self.all_motor_q = self.get_current_motor_q()
|
||||
print(self.all_motor_q)
|
||||
logger.info(f"Current all body motor state q:\n{self.all_motor_q} \n")
|
||||
logger.info(f"Current two arms motor state q:\n{self.get_current_dual_arm_q()}\n")
|
||||
logger.info("Lock all joints except two arms...\n")
|
||||
@@ -209,7 +183,7 @@ class UnitreeG1(Robot):
|
||||
self.msg.motor_cmd[id].q = self.all_motor_q[id]
|
||||
#print current motor q, kp, kd
|
||||
|
||||
logger.info("Lock OK!\n") #motors are not locked x
|
||||
logger.warning("Lock OK!\n") #motors are not locked x
|
||||
# for i in range(10000):
|
||||
# print(self.get_current_motor_q())
|
||||
# time.sleep(0.05)
|
||||
@@ -230,26 +204,26 @@ class UnitreeG1(Robot):
|
||||
self.publish_thread = threading.Thread(target=self._ctrl_motor_state)
|
||||
self.publish_thread.daemon = True
|
||||
self.publish_thread.start()
|
||||
logger.info("Arm control publish thread started")
|
||||
logger.warning("Arm control publish thread started")
|
||||
|
||||
# Load locomotion policy if enabled
|
||||
self.policy = None
|
||||
self.policy_type = None # 'torchscript' or 'onnx'
|
||||
|
||||
print(config)
|
||||
if config.locomotion_control:
|
||||
if config.policy_path is None:
|
||||
raise ValueError("locomotion_control is True but policy_path is not set")
|
||||
|
||||
logger.info(f"Loading locomotion policy from {config.policy_path}")
|
||||
logger.warning(f"Loading locomotion policy from {config.policy_path}")
|
||||
|
||||
# Check file extension and load accordingly
|
||||
if config.policy_path.endswith('.pt'):
|
||||
logger.info("Detected TorchScript (.pt) policy")
|
||||
logger.warning("Detected TorchScript (.pt) policy")
|
||||
self.policy = torch.jit.load(config.policy_path)
|
||||
self.policy_type = 'torchscript'
|
||||
logger.info("TorchScript policy loaded successfully")
|
||||
elif config.policy_path.endswith('.onnx'):
|
||||
logger.info("Detected ONNX (.onnx) policy")
|
||||
logger.warning("Detected ONNX (.onnx) policy")
|
||||
|
||||
# For GR00T-style policies, load both Balance and Walk policies
|
||||
# Balance policy for standing (low velocity commands)
|
||||
|
||||
Reference in New Issue
Block a user