import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config, args):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, args.n_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Model(nn.Module):
    def __init__(self, encoder, config, tokenizer, args):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config, args)
        self.args = args

    def forward(self, code_ids, attn_mask, position_idx, label=None):
        # Distinguish graph nodes (position 0) from code tokens (position >= 2).
        nodes_mask = position_idx.eq(0)
        token_mask = position_idx.ge(2)

        # Look up word embeddings for all input ids.
        inputs_embeddings = self.encoder.roberta.embeddings.word_embeddings(code_ids)

        # For each node, build a normalized mask over the code tokens it is
        # connected to (per attn_mask) and average their embeddings.
        nodes_to_token_mask = nodes_mask[:, :, None] & token_mask[:, None, :] & attn_mask
        nodes_to_token_mask = nodes_to_token_mask / (nodes_to_token_mask.sum(-1) + 1e-10)[:, :, None]
        avg_embeddings = torch.einsum("abc,acd->abd", nodes_to_token_mask, inputs_embeddings)

        # Replace node embeddings with the averaged token embeddings;
        # token embeddings are kept unchanged.
        inputs_embeddings = inputs_embeddings * (~nodes_mask)[:, :, None] + avg_embeddings * nodes_mask[:, :, None]

        # Encode with RoBERTa and classify from the first position.
        outputs = self.encoder.roberta(
            inputs_embeds=inputs_embeddings,
            attention_mask=attn_mask,
            position_ids=position_idx,
            token_type_ids=position_idx.eq(-1).long(),
        )[0]
        logit = self.classifier(outputs)
        prob = F.softmax(logit, dim=-1)

        if label is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logit, label)
            return loss, prob
        else:
            return prob
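

# Minimal usage sketch (not part of the original module): instantiate Model with a
# small, randomly initialised RoBERTa from Hugging Face transformers and run one
# forward pass on dummy tensors. The tensor shapes, the position convention
# (0 = graph node, >= 2 = code token, matching the masks in forward()), and
# args.n_classes are illustrative assumptions, not values from the original code.
if __name__ == "__main__":
    import argparse
    from transformers import RobertaConfig, RobertaForSequenceClassification

    config = RobertaConfig(
        vocab_size=1000, hidden_size=64, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=128, max_position_embeddings=514,
    )
    encoder = RobertaForSequenceClassification(config)  # exposes .roberta as used above
    args = argparse.Namespace(n_classes=2)
    model = Model(encoder, config, tokenizer=None, args=args)  # tokenizer is unused in forward()

    batch, seq_len, n_tokens = 2, 16, 10  # last 6 positions act as graph nodes
    code_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
    position_idx = torch.zeros(batch, seq_len, dtype=torch.long)
    position_idx[:, :n_tokens] = torch.arange(2, 2 + n_tokens)  # code tokens get positions >= 2
    attn_mask = torch.ones(batch, seq_len, seq_len, dtype=torch.bool)  # fully connected for the demo
    labels = torch.tensor([0, 1])

    loss, prob = model(code_ids, attn_mask, position_idx, label=labels)
    print(loss.item(), prob.shape)  # prob has shape (batch, n_classes)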