import torch
import torch.nn as nn
from transformers import AutoModel


class BERTMultiLabel(nn.Module):
    """Multi-label classification head on top of a pretrained transformer encoder."""

    def __init__(self, model_name="microsoft/deberta-v3-base", num_labels=5):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first ([CLS]) token as the sequence representation.
        cls = outputs.last_hidden_state[:, 0]
        cls = self.dropout(cls)
        logits = self.classifier(cls)
        # Raw logits; pair with a sigmoid-based loss for multi-label training.
        return logits
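
# --- Usage sketch (an assumption, not part of the original source): a minimal
# forward pass plus a multi-label loss. The tokenizer name mirrors the default
# checkpoint above; the example texts, the label tensor, and the choice of
# BCEWithLogitsLoss (independent 0/1 targets per class) are illustrative only.
from transformers import AutoTokenizer

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
    model = BERTMultiLabel(num_labels=5)

    texts = ["example document one", "example document two"]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    logits = model(batch["input_ids"], batch["attention_mask"])  # shape (2, 5)

    # Hypothetical multi-hot targets for the two examples.
    targets = torch.zeros(2, 5)
    targets[0, 1] = 1.0
    targets[1, [0, 3]] = 1.0
    loss = nn.BCEWithLogitsLoss()(logits, targets)

    # At inference time, sigmoid + threshold yields per-label predictions.
    probs = torch.sigmoid(logits)
    preds = (probs > 0.5).long()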