removing swish in favor of silu

This commit is contained in:
Maximellerbach
2026-05-29 13:10:21 +02:00
parent 79f6756505
commit d8ed30d58c

View File

@@ -44,10 +44,6 @@ else:
from .configuration_vla_jepa import VLAJEPAConfig
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
@@ -78,7 +74,7 @@ class ActionEncoder(nn.Module):
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):