Upgrade transformers version

#8
by ybabakhin - opened
Files changed (2)
  1. README.md +2 -2
  2. llama_bidirectional_model.py +36 -20
README.md CHANGED
@@ -67,10 +67,10 @@ We trained the model on public datasets described in the Dataset and Training se
 
 ### **Installation**
 
-The model requires transformers version 4.47.1.
+The model requires transformers version >=4.47.1.
 
 ```bash
-pip install transformers==4.47.1
+pip install transformers>=4.47.1
 ```
 
 ### **Usage**
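Note (editorial, not part of the diff): since the pin is now a lower bound rather than an exact version, a quick runtime check of the installed version can be useful. The sketch below assumes the `packaging` package is available and is only illustrative.

```python
# Minimal sketch: verify the installed transformers satisfies the relaxed pin.
# The packaging-based check and the message text are illustrative, not part of this PR.
import transformers
from packaging import version

assert version.parse(transformers.__version__) >= version.parse("4.47.1"), (
    f"Found transformers {transformers.__version__}; this model expects >=4.47.1"
)
```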
llama_bidirectional_model.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache, HybridCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     SequenceClassifierOutputWithPast,
@@ -21,6 +21,24 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 
+def create_bidirectional_attention_mask(
+    attn_implementation: str,
+    attention_mask: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+
+    if attn_implementation == "flash_attention_2":
+        if attention_mask is not None and (attention_mask == 0.0).any():
+            return attention_mask
+        return None
+    elif attn_implementation == "eager":
+        return _prepare_4d_attention_mask(attention_mask, dtype=dtype)
+    elif attn_implementation == "sdpa":
+        return _prepare_4d_attention_mask_for_sdpa(attention_mask, dtype=dtype)
+    else:
+        raise ValueError(f"Unsupported attention implementation: {attn_implementation}, only support flash_attention_2, eager or sdpa")
+
+
 def pool(last_hidden_states: Tensor, attention_mask: Tensor, pool_type: str) -> Tensor:
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
 
@@ -64,31 +82,17 @@ class LlamaBidirectionalModel(LlamaModel):
         super().__init__(config)
         for layer in self.layers:
             layer.self_attn.is_causal = False
-        self.config._attn_implementation = "eager"
-
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        # Generates bi-directional attention.
-        causal_mask = _prepare_4d_attention_mask(attention_mask, input_tensor.dtype)
-        return causal_mask
 
 
 class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification):
     config_class = LlamaBidirectionalConfig
 
     def __init__(self, config):
-        super().__init__(config)
-        # Releasing the parameters of LlamaModel
-        # created by parent LlamaForSequenceClassification
-        del self.model
+        LlamaPreTrainedModel.__init__(self, config)
 
+        self.num_labels = config.num_labels
         self.model = LlamaBidirectionalModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -105,6 +109,7 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -116,6 +121,16 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification
             return_dict if return_dict is not None else self.config.use_return_dict
         )
 
+        # Keep original 2D mask for pooling
+        attention_mask_2d = attention_mask
+
+        # Create 4D bidirectional attention mask
+        attention_mask = create_bidirectional_attention_mask(
+            attn_implementation=self.config._attn_implementation,
+            attention_mask=attention_mask,
+            dtype=self.config.torch_dtype,
+        )
+
         transformer_outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -126,12 +141,13 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            **kwargs,
         )
         hidden_states = transformer_outputs[0]
 
         pooled_hidden_states = pool(
             last_hidden_states=hidden_states,
-            attention_mask=attention_mask,
+            attention_mask=attention_mask_2d,
             pool_type=self.config.pooling,
         )
 
@@ -140,7 +156,7 @@ class LlamaBidirectionalForSequenceClassification(LlamaForSequenceClassification
 
         loss = None
         if labels is not None:
-            labels = labels.to(logits.device)
+            labels = labels.to(pooled_logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
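Note (editorial, not part of the diff): the new `create_bidirectional_attention_mask` helper only dispatches to the mask utilities imported at the top of the file. The sketch below shows what those utilities return for a toy padding mask, under the assumption that shapes and values follow the standard transformers mask conventions.

```python
# Minimal sketch (not part of this PR): how the two imported mask utilities
# expand a 2D padding mask into the 4D mask consumed by the attention layers.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
)

# Two sequences of length 4; the second one has two padded positions.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

# "eager" path: additive mask of shape (batch, 1, tgt_len, src_len) with the
# dtype's most negative value at padded key positions, so softmax zeroes them.
eager_mask = _prepare_4d_attention_mask(attention_mask, dtype=torch.float32)
print(eager_mask.shape)  # torch.Size([2, 1, 4, 4])

# "sdpa" path: same layout, but the helper may return None when nothing is
# masked so scaled_dot_product_attention can take its unmasked fast path.
sdpa_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, dtype=torch.float32)

# "flash_attention_2" path (handled directly in the new helper): flash attention
# consumes the raw 2D mask, or None when there is no padding at all.
```

With the hard-coded `self.config._attn_implementation = "eager"` override removed from `LlamaBidirectionalModel.__init__`, the backend chosen at load time (for example via `attn_implementation=...` in `from_pretrained`) is what selects between these three branches.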