import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config, args):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, args.n_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Model(nn.Module):
    def __init__(self, encoder, config, tokenizer, args):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config, args)
        self.args = args

    def forward(self, code_ids, attn_mask, position_idx, label=None):
        # position_idx convention: 0 marks data-flow nodes, >=2 marks code tokens
        nodes_mask = position_idx.eq(0)
        token_mask = position_idx.ge(2)

        # Embed the input ids, then replace each node embedding with the average
        # of the token embeddings that node is connected to via the attention mask.
        inputs_embeddings = self.encoder.roberta.embeddings.word_embeddings(code_ids)
        nodes_to_token_mask = nodes_mask[:, :, None] & token_mask[:, None, :] & attn_mask
        nodes_to_token_mask = nodes_to_token_mask / (nodes_to_token_mask.sum(-1) + 1e-10)[:, :, None]
        avg_embeddings = torch.einsum("abc,acd->abd", nodes_to_token_mask, inputs_embeddings)
        inputs_embeddings = inputs_embeddings * (~nodes_mask)[:, :, None] + avg_embeddings * nodes_mask[:, :, None]

        # Encode with RoBERTa and classify from the first (<s>) position.
        outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings,
                                       attention_mask=attn_mask,
                                       position_ids=position_idx,
                                       token_type_ids=position_idx.eq(-1).long())[0]
        logit = self.classifier(outputs)
        prob = F.softmax(logit, dim=-1)

        if label is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logit, label)
            return loss, prob
        else:
            return prob
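
A minimal usage sketch follows, showing how this Model might be instantiated and called. It assumes a Hugging Face RobertaForSequenceClassification encoder (so that self.encoder.roberta resolves), an illustrative GraphCodeBERT-style checkpoint name, and a hypothetical args namespace carrying n_classes; the dummy inputs only exercise the tensor shapes, treating every position as a code token.

# Usage sketch (assumptions: `transformers` is installed; the checkpoint name
# and the `args` fields below are illustrative, not taken from this file).
import argparse
import torch
from transformers import RobertaConfig, RobertaForSequenceClassification

args = argparse.Namespace(n_classes=2)
config = RobertaConfig.from_pretrained("microsoft/graphcodebert-base")
encoder = RobertaForSequenceClassification.from_pretrained("microsoft/graphcodebert-base", config=config)
model = Model(encoder, config, tokenizer=None, args=args)  # tokenizer is stored but unused in forward

batch_size, seq_len = 2, 64
code_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
position_idx = torch.full((batch_size, seq_len), 2)                  # treat every position as a code token
attn_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
labels = torch.randint(0, args.n_classes, (batch_size,))

with torch.no_grad():
    loss, prob = model(code_ids, attn_mask, position_idx, labels)    # prob: [batch_size, n_classes]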