Update litemixer.py
Browse files- litemixer.py +6 -6
litemixer.py
CHANGED
|
@@ -24,14 +24,14 @@ class GatingUnit(nn.Module):
|
|
| 24 |
self.proj_1 = nn.Linear(dim,dim,bias=False)
|
| 25 |
self.proj_2 = nn.Linear(dim,dim,bias=False)
|
| 26 |
|
| 27 |
-
self.
|
| 28 |
|
| 29 |
|
| 30 |
def forward(self, x):
|
| 31 |
|
| 32 |
u, v = x, x
|
| 33 |
u = self.proj_1(u)
|
| 34 |
-
u = self.
|
| 35 |
v = self.proj_2(v)
|
| 36 |
g = u * v
|
| 37 |
|
|
@@ -45,14 +45,14 @@ class LiteMixerBlock(nn.Module):
|
|
| 45 |
self.norm = VectorDynamicTanh(dim)
|
| 46 |
|
| 47 |
|
| 48 |
-
self.
|
| 49 |
|
| 50 |
Rearrange('b n d -> b d n'),
|
| 51 |
GatingUnit(num_patch),
|
| 52 |
Rearrange('b d n -> b n d')
|
| 53 |
)
|
| 54 |
|
| 55 |
-
self.
|
| 56 |
|
| 57 |
|
| 58 |
def forward(self, x):
|
|
@@ -61,7 +61,7 @@ class LiteMixerBlock(nn.Module):
|
|
| 61 |
|
| 62 |
x = self.norm(x)
|
| 63 |
|
| 64 |
-
x = self.
|
| 65 |
|
| 66 |
x = x + residual
|
| 67 |
|
|
@@ -69,7 +69,7 @@ class LiteMixerBlock(nn.Module):
|
|
| 69 |
|
| 70 |
x = self.norm(x)
|
| 71 |
|
| 72 |
-
x = self.
|
| 73 |
|
| 74 |
x = x + residual
|
| 75 |
|
|
|
|
| 24 |
self.proj_1 = nn.Linear(dim,dim,bias=False)
|
| 25 |
self.proj_2 = nn.Linear(dim,dim,bias=False)
|
| 26 |
|
| 27 |
+
self.silu = nn.SiLU()
|
| 28 |
|
| 29 |
|
| 30 |
def forward(self, x):
|
| 31 |
|
| 32 |
u, v = x, x
|
| 33 |
u = self.proj_1(u)
|
| 34 |
+
u = self.silu(u)
|
| 35 |
v = self.proj_2(v)
|
| 36 |
g = u * v
|
| 37 |
|
|
|
|
| 45 |
self.norm = VectorDynamicTanh(dim)
|
| 46 |
|
| 47 |
|
| 48 |
+
self.context_process = nn.Sequential(
|
| 49 |
|
| 50 |
Rearrange('b n d -> b d n'),
|
| 51 |
GatingUnit(num_patch),
|
| 52 |
Rearrange('b d n -> b n d')
|
| 53 |
)
|
| 54 |
|
| 55 |
+
self.token_process = GatingUnit(dim)
|
| 56 |
|
| 57 |
|
| 58 |
def forward(self, x):
|
|
|
|
| 61 |
|
| 62 |
x = self.norm(x)
|
| 63 |
|
| 64 |
+
x = self.context_process(x)
|
| 65 |
|
| 66 |
x = x + residual
|
| 67 |
|
|
|
|
| 69 |
|
| 70 |
x = self.norm(x)
|
| 71 |
|
| 72 |
+
x = self.token_process(x)
|
| 73 |
|
| 74 |
x = x + residual
|
| 75 |
|