xiangan committed on
Commit
bf5878a
·
verified ·
1 Parent(s): 4a1d1db

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_onevision_encoder.py +9 -8
modeling_onevision_encoder.py CHANGED
@@ -165,23 +165,24 @@ class VideoRotaryEmbeddingSplit466(nn.Module):
165
  Compute rotary position embeddings from explicit patch positions.
166
 
167
  Args:
168
- patch_positions: [seq_len, 3] tensor with [t, h, w] positions for each patch
169
 
170
  Returns:
171
- freqs: [seq_len, half] tensor of position frequencies
172
  """
173
  device = patch_positions.device
174
  inv_t = self.inv_freq_t.to(device=device)
175
  inv_h = self.inv_freq_h.to(device=device)
176
  inv_w = self.inv_freq_w.to(device=device)
177
 
178
- t_pos = patch_positions[:, 0].float()
179
- h_pos = patch_positions[:, 1].float()
180
- w_pos = patch_positions[:, 2].float()
181
 
182
- ft = torch.outer(t_pos, inv_t)
183
- fh = torch.outer(h_pos, inv_h)
184
- fw = torch.outer(w_pos, inv_w)
 
185
 
186
  return torch.cat([ft, fh, fw], dim=-1)
187
 
 
165
  Compute rotary position embeddings from explicit patch positions.
166
 
167
  Args:
168
+ patch_positions: [batch_size, seq_len, 3] tensor with [t, h, w] positions for each patch
169
 
170
  Returns:
171
+ freqs: [batch_size, seq_len, half] tensor of position frequencies
172
  """
173
  device = patch_positions.device
174
  inv_t = self.inv_freq_t.to(device=device)
175
  inv_h = self.inv_freq_h.to(device=device)
176
  inv_w = self.inv_freq_w.to(device=device)
177
 
178
+ t_pos = patch_positions[..., 0].float() # [batch_size, seq_len]
179
+ h_pos = patch_positions[..., 1].float() # [batch_size, seq_len]
180
+ w_pos = patch_positions[..., 2].float() # [batch_size, seq_len]
181
 
182
+ # Use einsum for batched outer product: [batch_size, seq_len] x [dim] -> [batch_size, seq_len, dim]
183
+ ft = torch.einsum("bs,d->bsd", t_pos, inv_t)
184
+ fh = torch.einsum("bs,d->bsd", h_pos, inv_h)
185
+ fw = torch.einsum("bs,d->bsd", w_pos, inv_w)
186
 
187
  return torch.cat([ft, fh, fw], dim=-1)
188