fix bugs in flow

This commit is contained in:
Geoffrey19
2025-11-20 17:17:25 +08:00
parent 73a6f20e58
commit b185fa0f87

View File

@@ -119,7 +119,7 @@ class ActionHead(nn.Module):
torch.tensor(self.beta_beta, dtype=dtype, device=device)
)
sample = beta_dist.sample([batch_size])
time = (1 - sample) / self.s
time = (1 - sample) * self.s
return time
def forward(self, action_chunk, dof_mask=None):
@@ -420,7 +420,7 @@ class WallXPolicy(PreTrainedPolicy):
num_steps = self.config.num_inference_timesteps
dt = 1.0 / num_steps
for step_idx in range(num_steps):
for step_idx in range(num_steps + 1):
t = torch.tensor(step_idx * dt, device=device, dtype=dtype)
timestep = t.unsqueeze(0).repeat(batch_size)