diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index 16d7eadb8..d62953abf 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -44,10 +44,6 @@ else: from .configuration_vla_jepa import VLAJEPAConfig -def swish(x: torch.Tensor) -> torch.Tensor: - return x * torch.sigmoid(x) - - class SinusoidalPositionalEncoding(nn.Module): def __init__(self, embedding_dim: int): super().__init__() @@ -78,7 +74,7 @@ class ActionEncoder(nn.Module): timesteps = timesteps.unsqueeze(1).expand(-1, seq_len) action_emb = self.layer1(actions) time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype) - return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1)))) + return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1)))) class TimestepEncoder(nn.Module):