Abdullah-Nazhat commited on
Commit
8239141
·
verified ·
1 Parent(s): 6b7a8ed

Update litemixer.py

Browse files
Files changed (1) hide show
  1. litemixer.py +6 -6
litemixer.py CHANGED
@@ -24,14 +24,14 @@ class GatingUnit(nn.Module):
24
  self.proj_1 = nn.Linear(dim,dim,bias=False)
25
  self.proj_2 = nn.Linear(dim,dim,bias=False)
26
 
27
- self.relu = nn.SiLU()
28
 
29
 
30
  def forward(self, x):
31
 
32
  u, v = x, x
33
  u = self.proj_1(u)
34
- u = self.relu(u)
35
  v = self.proj_2(v)
36
  g = u * v
37
 
@@ -45,14 +45,14 @@ class LiteMixerBlock(nn.Module):
45
  self.norm = VectorDynamicTanh(dim)
46
 
47
 
48
- self.context_interaction = nn.Sequential(
49
 
50
  Rearrange('b n d -> b d n'),
51
  GatingUnit(num_patch),
52
  Rearrange('b d n -> b n d')
53
  )
54
 
55
- self.token_interaction = GatingUnit(dim)
56
 
57
 
58
  def forward(self, x):
@@ -61,7 +61,7 @@ class LiteMixerBlock(nn.Module):
61
 
62
  x = self.norm(x)
63
 
64
- x = self.context_interaction(x)
65
 
66
  x = x + residual
67
 
@@ -69,7 +69,7 @@ class LiteMixerBlock(nn.Module):
69
 
70
  x = self.norm(x)
71
 
72
- x = self.token_interaction(x)
73
 
74
  x = x + residual
75
 
 
24
  self.proj_1 = nn.Linear(dim,dim,bias=False)
25
  self.proj_2 = nn.Linear(dim,dim,bias=False)
26
 
27
+ self.silu = nn.SiLU()
28
 
29
 
30
  def forward(self, x):
31
 
32
  u, v = x, x
33
  u = self.proj_1(u)
34
+ u = self.silu(u)
35
  v = self.proj_2(v)
36
  g = u * v
37
 
 
45
  self.norm = VectorDynamicTanh(dim)
46
 
47
 
48
+ self.context_process = nn.Sequential(
49
 
50
  Rearrange('b n d -> b d n'),
51
  GatingUnit(num_patch),
52
  Rearrange('b d n -> b n d')
53
  )
54
 
55
+ self.token_process = GatingUnit(dim)
56
 
57
 
58
  def forward(self, x):
 
61
 
62
  x = self.norm(x)
63
 
64
+ x = self.context_process(x)
65
 
66
  x = x + residual
67
 
 
69
 
70
  x = self.norm(x)
71
 
72
+ x = self.token_process(x)
73
 
74
  x = x + residual
75