annotate.py

import argparse
import json
from configparser import ConfigParser

from apps.task.annote.datatype.extract import extract_data_type
from apps.task.annote.purpose.predict import predict
from apps.task.annote.utils import load_json

parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--lang", default="python", type=str,
                    help="Language type, default is python.")
parser.add_argument("--tokenizer_name", default="microsoft/graphcodebert-base", type=str,
                    help="Optional pretrained tokenizer name or path if not the same as model_name_or_path.")
parser.add_argument("--eval_batch_size", default=8, type=int,
                    help="Batch size per GPU/CPU for prediction.")
parser.add_argument("--output_dir", default="purpose/saved_models", type=str,
                    help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--checkpoint_prefix", default="model.bin", type=str,
                    help="File name of the model checkpoint to load from output_dir.")
parser.add_argument("--config_name", default="microsoft/graphcodebert-base", type=str,
                    help="Optional pretrained config name or path if not the same as model_name_or_path.")
parser.add_argument("--code_length", default=256, type=int,
                    help="Optional code input sequence length after tokenization.")
parser.add_argument("--data_flow_length", default=64, type=int,
                    help="Optional data flow input sequence length after tokenization.")
parser.add_argument('--seed', type=int, default=42,
                    help="Random seed for initialization.")
parser.add_argument('--n_classes', type=int, default=10,
                    help="Number of purpose classes to predict.")
args, unknown = parser.parse_known_args()
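
# Example invocation (a sketch; flags not defined above are collected into
# `unknown` by parse_known_args rather than raising an error):
#     python annotate.py --lang python --eval_batch_size 16 --seed 1234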


def reload_params(train_params):
    """Override the parsed CLI arguments with values from a training-parameter dict."""
    args.n_classes = train_params['n_classes']
    args.do_train = train_params['do_train'] == 'True'
    args.do_test = train_params['do_test'] == 'True'
    args.train_batch_size = train_params['train_batch_size']
    args.eval_batch_size = train_params['eval_batch_size']
    args.epochs = train_params['epochs']
    args.lang = train_params['lang']
    args.output_dir = train_params['output_dir']
    args.code_length = train_params['code_length']
    args.data_flow_length = train_params['data_flow_length']
    args.seed = train_params['seed']
    args.train_data_file = train_params['train_data_file']
    args.test_data_file = train_params['test_data_file']
    args.model_path = train_params['model_path']
    args.gradient_accumulation_steps = train_params['gradient_accumulation_steps']
    args.learning_rate = train_params['learning_rate']
    args.weight_decay = train_params['weight_decay']
    args.adam_epsilon = train_params['adam_epsilon']
    args.max_grad_norm = train_params['max_grad_norm']
    args.max_steps = train_params['max_steps']
    args.warmup_steps = train_params['warmup_steps']
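
# Illustrative shape of apps/task/annote/params.json as consumed by
# reload_params (values are placeholders; note that do_train/do_test must be
# the strings "True"/"False", since they are compared against 'True' above):
#     {
#         "n_classes": 10, "do_train": "False", "do_test": "True",
#         "train_batch_size": 8, "eval_batch_size": 8, "epochs": 10,
#         "lang": "python", "output_dir": "purpose/saved_models",
#         "code_length": 256, "data_flow_length": 64, "seed": 42,
#         "train_data_file": "...", "test_data_file": "...",
#         "model_path": "...", "gradient_accumulation_steps": 1,
#         "learning_rate": 2e-5, "weight_decay": 0.0, "adam_epsilon": 1e-8,
#         "max_grad_norm": 1.0, "max_steps": -1, "warmup_steps": 0
#     }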


def load_params():
    cp = ConfigParser()
    cp.read('params.cfg', encoding='utf-8')
    args.n_classes = int(cp.get('params', 'n_classes'))
    args.eval_batch_size = int(cp.get('params', 'batch_size'))
    args.output_dir = cp.get('params', 'output_dir')
    args.code_length = int(cp.get('params', 'code_length'))
    args.data_flow_length = int(cp.get('params', 'data_flow_length'))
    args.seed = int(cp.get('params', 'seed'))
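
# A minimal params.cfg that load_params would accept (placeholder values;
# note the cfg key `batch_size` maps onto args.eval_batch_size):
#     [params]
#     n_classes = 10
#     batch_size = 8
#     output_dir = purpose/saved_models
#     code_length = 256
#     data_flow_length = 64
#     seed = 42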


def annotate(source, lattices, entire=False):
    """
    :param source: path to the source file or project directory to annotate
    :param lattices: annotation dictionary used for data_type extraction
    :param entire: if True, merge the two annotation lists into one dict per method
    :return: a dict keyed by "<file_path>-<func_name>" when entire is True,
             otherwise the pair (data_type_list, purpose_list)
    """
    params = load_json('apps/task/annote/params.json')
    reload_params(params)
    data_type_list = extract_data_type(source, lattices, args)
    purpose_list = predict(source, args)
    if entire:
        methods = dict()
        for data_type_single in data_type_list:
            func_key = data_type_single['file_path'] + "-" + data_type_single['func_name']
            if func_key in methods:
                methods[func_key]['data_type'].append(data_type_single)
            else:
                methods[func_key] = {'data_type': [data_type_single]}
        for purpose in purpose_list:
            func_key = purpose['file_path'] + "-" + purpose['func_name']
            if func_key not in methods:
                methods[func_key] = dict()
            methods[func_key]['purpose'] = purpose
        return methods
    else:
        return data_type_list, purpose_list
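
# With entire=True the result groups both annotation kinds per method, e.g.
# (hypothetical keys for illustration):
#     {
#         "pkg/views.py-get_profile": {
#             "data_type": [...],   # entries from extract_data_type
#             "purpose": {...}      # entry from predict
#         }
#     }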


if __name__ == '__main__':
    with open("datatype_dictionary.json", 'r', encoding='utf-8') as file:
        data_type = json.load(file)
    path = "/Users/liufan/Documents/实验室/隐私扫描项目/SAP检测项目/mini/Instagram_profile"
    methods = annotate(path, data_type, entire=False)