"""
Inference CTC class derived from HubertForCTC.
Author: Marcely Zanon Boito, 2024
"""
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import HubertModel, HubertPreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

class VanillaNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        """
        simple NN with ReLU activation (no norm)
        """
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.act_fn = nn.ReLU()

    def forward(self, hidden_states: torch.FloatTensor):
        hidden_states = self.linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        return hidden_states

class mHubertForCTC(HubertPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        super().__init__(config)

        self.hubert = HubertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        output_hidden_size = config.hidden_size
        self.has_interface = config.add_interface

        # NN layers on top of the trainable stack
        if config.add_interface:
            self.interface = nn.ModuleList(
                [VanillaNN(output_hidden_size, output_hidden_size) for _ in range(config.num_interface_layers)]
            )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        self.post_init()
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # the model config decides whether hidden states are returned, overriding the argument
        output_hidden_states = self.config.output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        # optional interface stack between the encoder and the CTC head
        if self.has_interface:
            for layer in self.interface:
                hidden_states = layer(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.ctc_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
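

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it shows how the
# class above might be instantiated for a quick smoke test. The HubertConfig
# sizes, the extra attribute values (`add_interface`, `num_interface_layers`,
# `ctc_token_id`), and the random inputs are illustrative assumptions only,
# not the settings of any released checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import HubertConfig

    config = HubertConfig(
        hidden_size=256,              # assumption: small encoder for a fast test
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
        vocab_size=64,
        final_dropout=0.1,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
    )
    # extra fields read by mHubertForCTC (names taken from the class above)
    config.add_interface = True
    config.num_interface_layers = 2
    config.ctc_token_id = 0           # assumption: index of the CTC blank token

    model = mHubertForCTC(config)
    model.eval()

    # one second of fake 16 kHz audio and a dummy label sequence (no blank ids)
    input_values = torch.randn(1, 16000)
    labels = torch.randint(1, config.vocab_size, (1, 12))

    with torch.no_grad():
        outputs = model(input_values, labels=labels)

    # greedy CTC decoding: argmax over the vocabulary at every frame
    predicted_ids = outputs.logits.argmax(dim=-1)
    print("loss:", outputs.loss.item())
    print("frame-level predictions:", predicted_ids.shape)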