mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
fix bugs in flow
This commit is contained in:
@@ -119,7 +119,7 @@ class ActionHead(nn.Module):
|
||||
torch.tensor(self.beta_beta, dtype=dtype, device=device)
|
||||
)
|
||||
sample = beta_dist.sample([batch_size])
|
||||
time = (1 - sample) / self.s
|
||||
time = (1 - sample) * self.s
|
||||
return time
|
||||
|
||||
def forward(self, action_chunk, dof_mask=None):
|
||||
@@ -420,7 +420,7 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
num_steps = self.config.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
for step_idx in range(num_steps):
|
||||
for step_idx in range(num_steps + 1):
|
||||
t = torch.tensor(step_idx * dt, device=device, dtype=dtype)
|
||||
timestep = t.unsqueeze(0).repeat(batch_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user