predict.py

import argparse
import json
import logging
import multiprocessing
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler, TensorDataset
from tqdm import tqdm
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer)
from tree_sitter import Language, Parser

from apps.task.annote.Parser import DFG_python, DFG_go, DFG_javascript, remove_comments_and_docstrings, \
    tree_to_token_index, index_to_code_token
from apps.task.annote.Parser.utils import extract_dataflow
from apps.task.annote.datatype.extract import split_file_by_func
from apps.task.annote.purpose.model import Model
from apps.task.annote.utils import walk_files

cpu_cont = 1
logger = logging.getLogger(__name__)

# map each supported language to its data-flow-graph (DFG) extraction function
dfg_function = {
    'python': DFG_python,
}

# build a tree-sitter parser, paired with its DFG function, for every supported language
Parsers = {}
for lang in dfg_function:
    LANGUAGE = Language('apps/task/annote/Parser/my-languages.so', lang)
    parser = Parser()
    parser.set_language(LANGUAGE)
    parser = [parser, dfg_function[lang]]
    Parsers[lang] = parser


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


class InputFeatures(object):
    """A single set of training/test features for one example."""

    def __init__(self,
                 code_tokens,
                 code_ids,
                 position_idx,
                 dfg_to_code,
                 dfg_to_dfg,
                 file_path,
                 func_name,
                 label=None
                 ):
        self.label_list = ['Archive', 'Azure', 'File', 'Hash', 'Kafka', 'Other', 'Pseudonym', 'S3', 'Truncate',
                           'Visualize']
        self.code_tokens = code_tokens
        self.code_ids = code_ids
        self.position_idx = position_idx
        self.dfg_to_code = dfg_to_code
        self.dfg_to_dfg = dfg_to_dfg
        self.file_path = file_path
        self.func_name = func_name
        self.label = label

    def __str__(self):
        return self.file_path + "-" + self.func_name + " " + self.label


def convert_examples_to_features(item):
    """
    Convert one function into the model inputs: subword tokens, token ids,
    position indices, and data-flow-graph (DFG) links.

    :param item: tuple of (file_path, func_name, func_content, tokenizer, args)
    :return: InputFeatures
    """
    file_path, func_name, content, tokenizer, arg = item
    # extract code tokens and the data flow graph for the function
    par = Parsers[arg.lang]
    code_tokens, dfg = extract_dataflow(content, par, arg.lang)
    code_tokens = [tokenizer.tokenize('@ ' + x)[1:] if idx != 0 else tokenizer.tokenize(x) for idx, x in
                   enumerate(code_tokens)]
    # ori2cur_pos maps each original token index to its (start, end) span after subword tokenization
    ori2cur_pos = {-1: (0, 0)}
    for i in range(len(code_tokens)):
        ori2cur_pos[i] = (ori2cur_pos[i - 1][1], ori2cur_pos[i - 1][1] + len(code_tokens[i]))
    code_tokens = [y for x in code_tokens for y in x]
    # truncating
    code_tokens = code_tokens[:arg.code_length + arg.data_flow_length - 2 - min(len(dfg), arg.data_flow_length)]
    code_tokens = [tokenizer.cls_token] + code_tokens + [tokenizer.sep_token]
    code_ids = tokenizer.convert_tokens_to_ids(code_tokens)
    position_idx = [i + tokenizer.pad_token_id + 1 for i in range(len(code_tokens))]
    dfg = dfg[:arg.code_length + arg.data_flow_length - len(code_tokens)]
    code_tokens += [x[0] for x in dfg]
    position_idx += [0 for x in dfg]
    code_ids += [tokenizer.unk_token_id for x in dfg]
    padding_length = arg.code_length + arg.data_flow_length - len(code_ids)
    position_idx += [tokenizer.pad_token_id] * padding_length
    code_ids += [tokenizer.pad_token_id] * padding_length
    # reindex
    reverse_index = {}
    for idx, x in enumerate(dfg):
        reverse_index[x[1]] = idx
    for idx, x in enumerate(dfg):
        dfg[idx] = x[:-1] + ([reverse_index[i] for i in x[-1] if i in reverse_index],)
    dfg_to_dfg = [x[-1] for x in dfg]
    dfg_to_code = [ori2cur_pos[x[1]] for x in dfg]
    length = len([tokenizer.cls_token])
    dfg_to_code = [(x[0] + length, x[1] + length) for x in dfg_to_code]
    return InputFeatures(code_tokens, code_ids, position_idx, dfg_to_code, dfg_to_dfg, file_path, func_name)


class TextDataset(Dataset):
    def __init__(self, input_path, tokenizer, args, pool=None):
        self.args = args
        self.examples = []
        data = []
        # collect every function from every file under input_path
        for file_name in walk_files(input_path):
            func_content_dict = split_file_by_func(file_name)
            for func_name, func_content in func_content_dict.items():
                data.append((file_name, func_name, func_content[0], tokenizer, args))
        for d in data:
            self.examples.append(convert_examples_to_features(d))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        # calculate graph-guided masked function
        attn_mask = np.zeros((self.args.code_length + self.args.data_flow_length,
                              self.args.code_length + self.args.data_flow_length), dtype=bool)
        # calculate begin index of node and max length of input
        node_index = sum([i > 1 for i in self.examples[item].position_idx])
        max_length = sum([i != 1 for i in self.examples[item].position_idx])
        # sequence can attend to sequence
        attn_mask[:node_index, :node_index] = True
        # special tokens attend to all tokens
        for idx, i in enumerate(self.examples[item].code_ids):
            if i in [0, 2]:
                attn_mask[idx, :max_length] = True
        # nodes attend to the code tokens they are identified from
        for idx, (a, b) in enumerate(self.examples[item].dfg_to_code):
            if a < node_index and b < node_index:
                attn_mask[idx + node_index, a:b] = True
                attn_mask[a:b, idx + node_index] = True
        # nodes attend to adjacent nodes
        for idx, nodes in enumerate(self.examples[item].dfg_to_dfg):
            for a in nodes:
                if a + node_index < len(self.examples[item].position_idx):
                    attn_mask[idx + node_index, a + node_index] = True
        return (torch.tensor(self.examples[item].code_ids),
                torch.tensor(attn_mask),
                torch.tensor(self.examples[item].position_idx))


def predict(input_path, args):
    pool = multiprocessing.Pool(cpu_cont)
    # Setup CUDA, GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    args.n_gpu = 0
    args.device = device
    set_seed(args)
    # build config, tokenizer, and model
    config = RobertaConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
    config.num_labels = 1
    tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
    model = RobertaForSequenceClassification.from_pretrained(args.config_name, config=config)
    model = Model(model, config, tokenizer, args)
    # load the fine-tuned checkpoint
    output_dir = os.path.join(args.output_dir, '{}'.format(args.checkpoint_prefix))
    model.load_state_dict(torch.load(output_dir, map_location=device))
    model.to(args.device)
    # build dataloader
    eval_dataset = TextDataset(input_path, tokenizer, args, pool=pool)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=4)
    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # run inference and collect the logits for every function
    logit = []
    for batch in eval_dataloader:
        (code_ids, attn_mask, position_idx) = [x.to(args.device) for x in batch]
        with torch.no_grad():
            log = model(code_ids, attn_mask, position_idx)
            logit.append(log.cpu().numpy())
    logit = np.concatenate(logit, 0)
    y_pre = np.argmax(logit, 1).tolist()
    result = []
    for i in range(len(eval_dataset.examples)):
        # index 5 is the 'Other' class; only functions with a concrete purpose are reported
        if y_pre[i] != 5:
            eval_dataset.examples[i].label = eval_dataset.examples[i].label_list[y_pre[i]]
            result.append({
                'file_path': eval_dataset.examples[i].file_path.replace(input_path + '/', ''),
                'func_name': eval_dataset.examples[i].func_name,
                'purpose': eval_dataset.examples[i].label
            })
    return result
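

# ---------------------------------------------------------------------------
# Example usage: an illustrative sketch, not part of the original module.
# The attribute names on `args` mirror exactly what this module reads
# (lang, code_length, data_flow_length, config_name, model_name_or_path,
# tokenizer_name, output_dir, checkpoint_prefix, eval_batch_size, seed);
# the concrete values and the input path below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_args = argparse.Namespace(
        lang='python',                                      # key into the Parsers dict above
        code_length=256,                                    # max number of code tokens (assumed value)
        data_flow_length=64,                                # max number of DFG nodes (assumed value)
        config_name='microsoft/graphcodebert-base',         # assumed pretrained model name
        model_name_or_path='microsoft/graphcodebert-base',
        tokenizer_name='microsoft/graphcodebert-base',
        output_dir='saved_models',                          # hypothetical checkpoint directory
        checkpoint_prefix='checkpoint-best-f1/model.bin',   # hypothetical checkpoint file
        eval_batch_size=8,
        seed=42,
    )
    predictions = predict('path/to/project', example_args)
    print(json.dumps(predictions, indent=2))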