model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config, args):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, args.n_classes)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class Model(nn.Module):
    def __init__(self, encoder, config, tokenizer, args):
        super(Model, self).__init__()
        self.encoder = encoder
        self.config = config
        self.tokenizer = tokenizer
        self.classifier = RobertaClassificationHead(config, args)
        self.args = args

    def forward(self, code_ids, attn_mask, position_idx, label=None):
        # position_idx encodes the input layout: 0 marks data-flow-graph
        # nodes, 1 marks padding, and values >= 2 mark real code tokens.
        nodes_mask = position_idx.eq(0)
        token_mask = position_idx.ge(2)
        inputs_embeddings = self.encoder.roberta.embeddings.word_embeddings(code_ids)

        # Initialize each node embedding as the average of the embeddings
        # of the code tokens it is connected to under the attention mask.
        nodes_to_token_mask = nodes_mask[:, :, None] & token_mask[:, None, :] & attn_mask
        nodes_to_token_mask = nodes_to_token_mask / (nodes_to_token_mask.sum(-1) + 1e-10)[:, :, None]
        avg_embeddings = torch.einsum("abc,acd->abd", nodes_to_token_mask, inputs_embeddings)
        inputs_embeddings = inputs_embeddings * (~nodes_mask)[:, :, None] + avg_embeddings * nodes_mask[:, :, None]

        outputs = self.encoder.roberta(inputs_embeds=inputs_embeddings,
                                       attention_mask=attn_mask,
                                       position_ids=position_idx,
                                       token_type_ids=position_idx.eq(-1).long())[0]
        logit = self.classifier(outputs)
        prob = F.softmax(logit, dim=-1)
        if label is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logit, label)
            return loss, prob
        else:
            return prob
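

# ---------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# encoder is a transformers RobertaForSequenceClassification, which
# exposes the .roberta submodule that forward() relies on, and that
# args only needs an n_classes field. The dummy position_idx follows
# the convention forward() expects (0 = data-flow node, 1 = padding,
# >= 2 = code token), and the square attention mask is a simplified
# stand-in for the real graph-guided mask.
if __name__ == "__main__":
    from argparse import Namespace
    from transformers import (RobertaConfig, RobertaTokenizer,
                              RobertaForSequenceClassification)

    config = RobertaConfig.from_pretrained("microsoft/graphcodebert-base")
    tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")
    encoder = RobertaForSequenceClassification.from_pretrained(
        "microsoft/graphcodebert-base", config=config)
    model = Model(encoder, config, tokenizer, Namespace(n_classes=2))

    batch_size, seq_len = 2, 8
    code_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    # four code tokens, two data-flow nodes, two padding slots per example
    position_idx = torch.tensor([[2, 3, 4, 5, 0, 0, 1, 1]] * batch_size)
    # simplified (batch, seq, seq) boolean mask: attend to all non-padding
    attn_mask = position_idx.ne(1)[:, None, :].expand(-1, seq_len, -1)
    labels = torch.tensor([0, 1])

    loss, prob = model(code_ids, attn_mask, position_idx, labels)
    print(loss.item(), prob.shape)  # scalar loss, (batch_size, n_classes)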