mutate_lemon.py

# -*- coding: UTF-8 -*-
import csv
from itertools import *
import keras
import json
import networkx as nx
import sys
# sys.path.append("../")
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from scripts.logger.lemon_logger import Logger
from scripts.tools.mutator_selection_logic import Roulette, MCMC
from scripts.mutation.model_mutation_generators import *
import argparse
import ast
import numpy as np
from scripts.mutation.mutation_utils import *
import pickle
from scripts.tools import utils
from scripts.tools.utils import ModelUtils
import shutil
import re
import datetime
import configparser
import warnings
import math

lines = 0
# np.random.seed(20200501)
warnings.filterwarnings("ignore")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import psutil


def partially_nan_or_inf(predictions, bk_num):  # check whether outputs contain NaN or Inf
    """
    Check if there is NaN in the result
    """

    def get_nan_num(nds):
        _nan_num = 0
        for nd in nds:
            if np.isnan(nd).any() or np.isinf(nd).any():
                _nan_num += 1
        return _nan_num

    if len(predictions) == bk_num:
        for input_predict in zip(*predictions):
            nan_num = get_nan_num(input_predict)
            if 0 < nan_num < bk_num:
                return True
            else:
                continue
        return False
    else:
        raise Exception("wrong backend amounts")
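
# Illustrative note (added; not executed): with two backends whose outputs for one
# input are
#     backend A -> np.array([0.1, 0.9])        # all finite
#     backend B -> np.array([np.nan, 0.7])     # contains NaN
# get_nan_num counts 1 NaN-containing output, 0 < 1 < 2 holds, and the function
# returns True, i.e. the NaN appears on only some backends. If every backend
# produced NaN/Inf for that input, the count would equal bk_num and the function
# would return False (that case is treated as "NaN on all backends" later).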


def get_selector_by_startegy_name(mutator_s, mutant_s):
    mutant_strategy_dict = {"ROULETTE": Roulette}
    mutator_strategy_dict = {"MCMC": MCMC}
    return mutator_strategy_dict[mutator_s], mutant_strategy_dict[mutant_s]
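
# Usage note (illustrative): get_selector_by_startegy_name("MCMC", "ROULETTE")
# returns the selector *classes* (MCMC, Roulette); _generate_and_predict then
# instantiates them with the operator list and the initial seed pool respectively.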


def save_mutate_history(selector, invalid_history: dict, mutant_history: list):
    mutator_history_path = os.path.join(experiment_dir, "mutator_history.csv")
    mutant_history_path = os.path.join(experiment_dir, "mutant_history.txt")
    with open(mutator_history_path, "w+") as fw:
        fw.write("Name,Success,Invalid,Total\n")
        for op in invalid_history.keys():
            mtrs = selector.mutators[op]
            invalid_cnt = invalid_history[op]
            fw.write(
                "{},{},{},{}\n".format(
                    op, mtrs.delta_bigger_than_zero, invalid_cnt, mtrs.total
                )
            )
    with open(mutant_history_path, "w+") as fw:
        for mutant in mutant_history:
            fw.write("{}\n".format(mutant))


def is_nan_or_inf(t):
    if math.isnan(t) or math.isinf(t):
        return True
    else:
        return False


def continue_checker(**run_stat):  # check whether the stop condition is met
    start_time = run_stat["start_time"]
    time_limitation = run_stat["time_limit"]
    cur_counters = run_stat["cur_counters"]
    counters_limit = run_stat["counters_limit"]
    s_mode = run_stat["stop_mode"]

    # time-based limit
    if s_mode == "TIMING":
        hours, minutes, seconds = utils.ToolUtils.get_HH_mm_ss(
            datetime.datetime.now() - start_time
        )
        total_minutes = hours * 60 + minutes
        mutate_logger.info(
            f"INFO: Mutation progress: {total_minutes}/{time_limitation} Minutes!"
        )
        if total_minutes < time_limitation:
            return True
        else:
            return False
    # counter-based limit: size(models) < N
    elif s_mode == "COUNTER":
        if cur_counters < counters_limit:
            mutate_logger.info(
                "INFO: Mutation progress {}/{}".format(cur_counters + 1, counters_limit)
            )
            return True
        else:
            return False
    else:
        raise Exception(f"Error! Stop Mode {s_mode} not Found!")
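
# Example call (hypothetical values, for illustration only). The mutation loop in
# _generate_and_predict builds run_stat like this and re-checks it each iteration:
#     run_stat = {
#         "start_time": datetime.datetime.now(),
#         "time_limit": 360,       # minutes; only used when stop_mode == "TIMING"
#         "cur_counters": 0,       # mutants accepted so far
#         "counters_limit": 100,   # only used when stop_mode == "COUNTER"
#         "stop_mode": "COUNTER",
#     }
#     continue_checker(**run_stat)  # -> True while the chosen limit is not reached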


def calc_inner_div(model):
    graph = nx.DiGraph()
    for layer in model.layers:
        graph.add_node(layer.name)
        for inbound_node in layer._inbound_nodes:
            if inbound_node.inbound_layers:
                for parent_layer in inbound_node.inbound_layers:
                    graph.add_edge(parent_layer.name, layer.name)
    longest_path = nx.dag_longest_path(graph)
    return len(longest_path) / len(graph)
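
# Note (added for clarity): the returned "inner diversity" is the length of the
# longest layer-to-layer path divided by the number of layers, so a purely
# sequential model gives a value near 1.0 while heavily branched models give
# smaller values. This helper is only referenced from commented-out coverage
# code further below.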


def _generate_and_predict(
    res_dict, filename, mutate_num, mutate_ops, test_size, exp, backends
):
    # main algorithm function
    """
    Generate models using mutate operators and store them
    """
    mutate_op_history = {k: 0 for k in mutate_ops}
    mutate_op_invalid_history = {k: 0 for k in mutate_ops}
    mutant_history = []
    # get mutator selection strategy
    if "svhn" in exp or "fashion2" in exp:
        origin_model_name = "{}_origin0.hdf5".format(exp)
    else:
        origin_model_name = "{}_origin0.h5".format(exp)
    # the initial seed pool Ms contains only this origin model
    root_dir = os.path.dirname(os.getcwd())
    origin_save_path = os.path.join(mut_dir, origin_model_name)
    mutator_selector_func, mutant_selector_func = get_selector_by_startegy_name(
        mutator_strategy, mutant_strategy
    )
    # [origin_model_name] means seed pool only contains initial model at beginning.
    mutator_selector, mutant_selector = mutator_selector_func(
        mutate_ops
    ), mutant_selector_func([origin_model_name], capacity=mutate_num + 1)
    # MCMC, Roulette
    shutil.copy(src=filename, dst=origin_save_path)
    origin_model_status, res_dict, accumulative_inconsistency, _ = get_model_prediction(
        res_dict, origin_save_path, origin_model_name, exp, test_size, backends
    )
    if not origin_model_status:
        mutate_logger.error(
            f"Origin model {exp} crashed on some backends! LEMON would skip it"
        )
        sys.exit(-1)
    last_used_mutator = None
    last_inconsistency = accumulative_inconsistency  # ACC
    mutant_counter = 0
    start_time = datetime.datetime.now()
    order_inconsistency_dict = {}
    run_stat = {
        "start_time": start_time,
        "time_limit": time_limit,
        "cur_counters": mutant_counter,
        "counters_limit": mutate_num,
        "stop_mode": stop_mode,
    }
    # keep looping while the stop condition is not met
    while continue_checker(**run_stat):
        global model_num
        if model_num == mutate_num:
            break
        picked_seed = utils.ToolUtils.select_mutant(
            mutant_selector
        )  # roulette selection of a seed model (pseudocode lines 3-14)
        selected_op = utils.ToolUtils.select_mutator(
            mutator_selector, last_used_mutator=last_used_mutator
        )  # MCMC selection of a mutation operator (pseudocode lines 15-20)
        mutate_op_history[selected_op] += 1
        last_used_mutator = selected_op
        mutator = mutator_selector.mutators[selected_op]  # mutation operator object
        mutant = mutant_selector.mutants[picked_seed]  # seed model object
        if "svhn" in picked_seed or "fashion2" in picked_seed:
            new_seed_name = "{}-{}{}.hdf5".format(
                picked_seed[:-5], selected_op, mutate_op_history[selected_op]
            )
        else:
            new_seed_name = "{}-{}{}.h5".format(
                picked_seed[:-3], selected_op, mutate_op_history[selected_op]
            )  # name of the newly generated model
        # seed name would not be duplicate
        if new_seed_name not in mutant_selector.mutants.keys():
            # pseudocode line 22: a mutant is named after its seed model and the
            # mutation operator, so a duplicate name means this model already exists
            new_seed_path = os.path.join(mut_dir, new_seed_name)
            picked_seed_path = os.path.join(mut_dir, picked_seed)
            mutate_st = datetime.datetime.now()
            model_mutation_generators = (
                root_dir + "/scripts/mutation/model_mutation_generators.py"
            )
            mutate_status = os.system(
                "{}/lemon/bin/python -u {} --model {} "
                "--mutate_op {} --save_path {} --mutate_ratio {}".format(
                    python_prefix,
                    model_mutation_generators,
                    picked_seed_path,
                    selected_op,
                    new_seed_path,
                    flags.mutate_ratio,
                )
            )
            # apply the mutation operator (pseudocode line 21)
            mutate_et = datetime.datetime.now()
            mutate_dt = mutate_et - mutate_st
            h, m, s = utils.ToolUtils.get_HH_mm_ss(mutate_dt)
            mutate_logger.info(
                "INFO:Mutate Time Used on {} : {}h, {}m, {}s".format(
                    selected_op, h, m, s
                )
            )
            # mutation status code is successful
            if mutate_status == 0:  # mutation finished successfully
                mutant.selected += 1
                mutator.total += 1
                # execute this model on all platforms
                predict_status, res_dict, accumulative_inconsistency, model_outputs = (
                    get_model_prediction(
                        res_dict, new_seed_path, new_seed_name, exp, test_size, backends
                    )
                )
                # compute ACC(m)
                if predict_status:
                    mutant_history.append(new_seed_name)
                    # pseudocode lines 23-25
                    print("type:", type(model_outputs))
                    print("model_outputs:", model_outputs)
                    if utils.ModelUtils.is_valid_model(
                        inputs_backends=model_outputs, backends_nums=len(backends)
                    ):
                        delta = (
                            accumulative_inconsistency - last_inconsistency
                        )  # i.e. ACC(m) - ACC(s)
                        # The two ifs below have little practical effect: the mutator
                        # dict only contains MCMC and the mutant dict only ROULETTE.
                        if mutator_strategy == "MCMC":
                            mutator.delta_bigger_than_zero = (
                                mutator.delta_bigger_than_zero + 1
                                if delta > 0
                                else mutator.delta_bigger_than_zero
                            )
                        if mutant_strategy == "ROULETTE" and delta > 0:
                            # when size >= capacity:
                            # random_mutant & Roulette would drop one and add new one
                            if mutant_selector.is_full():
                                mutant_selector.pop_one_mutant()
                            mutant_selector.add_mutant(
                                new_seed_name
                            )  # the mutant amplifies the inconsistency, i.e. ACC(m) > ACC(s), so add it to the seed pool
                            last_inconsistency = accumulative_inconsistency  # pseudocode line 29
                        mutate_logger.info(
                            "SUCCESS:{} pass testing!".format(new_seed_name)
                        )
                        mutant_counter += 1
                    else:
                        mutate_op_invalid_history[selected_op] += 1
                        mutate_logger.error("Invalid model Found!")
                else:
                    mutate_logger.error("Crashed or NaN model Found!")
            else:
                mutate_logger.error(
                    "Exception raised when mutate {} with {}".format(
                        picked_seed, selected_op
                    )
                )
            mutate_logger.info("Mutated op used history:")
            mutate_logger.info(mutate_op_history)
            mutate_logger.info("Invalid mutant generated history:")
            mutate_logger.info(mutate_op_invalid_history)
        run_stat["cur_counters"] = mutant_counter
    save_mutate_history(mutator_selector, mutate_op_invalid_history, mutant_history)

    # calc_cov = CoverageCalculatornew(all_json_path, api_config_pool_path)
    # lines = 0
    # for file in os.listdir(folder_path):
    #     if file == 'total.json': continue
    #     file_path = os.path.join(folder_path, file)
    #     calc_cov.load_json(file_path)
    #     with open(file_path, 'r') as sub_json:
    #         sub_info = json.load(sub_json)
    #     outer_div = len(tar_set - set(sub_info['layer_type']))
    # input_cov, config_cov, api_cov, op_type_cov, op_num_cov, edge_cov = calc_cov.cal_coverage()
    # with open(output_path, 'a+', newline='') as fi:
    #     writer = csv.writer(fi)
    #     head = ['Layer Input Coverage', 'Layer Parameter Diversity', 'Layer Sequence Diversity',
    #             'Operator Type Coverage', 'Operator Num Coverage', 'Edge Coverage', 'Accumulative inconsistency']
    #     if not lines:
    #         writer.writerow(head)
    #     lines += 1
    #     printlist = [input_cov, config_cov, api_cov, op_type_cov, op_num_cov, edge_cov,
    #                  acc[lines]]
    #     writer.writerow(printlist)
    return res_dict
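
# Note (added for clarity): _generate_and_predict relies on module-level names that
# are defined in the __main__ block below (mut_dir, flags, python_prefix,
# mutate_logger, time_limit, stop_mode, mutator_strategy, mutant_strategy), so it
# is only meant to be called after that configuration has been loaded.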


def generate_metrics_result(res_dict, predict_output, model_idntfr):  # compute ACC
    mutate_logger.info("Generating Metrics Result")
    accumulative_incons = 0
    backends_pairs_num = 0
    # Compare results pair by pair
    for pair in combinations(predict_output.items(), 2):  # every pair of backends
        backends_pairs_num += 1
        backend1, backend2 = pair
        bk_name1, prediction1 = backend1
        bk_name2, prediction2 = backend2
        bk_pair = "{}_{}".format(bk_name1, bk_name2)
        for metrics_name, metrics_result_dict in res_dict.items():
            metrics_func = utils.MetricsUtils.get_metrics_by_name(metrics_name)  # metric function
            # metrics_results in list type
            metrics_results = metrics_func(
                prediction1, prediction2, y_test[: flags.test_size]
            )
            # the test set has test_size inputs, so metrics_results is a list of
            # per-input results of length test_size
            # ACC -> float: The sum of all inputs under all backends
            accumulative_incons += sum(metrics_results)  # ACC = sum over pairs and inputs
            for input_idx, delta in enumerate(metrics_results):
                delta_key = "{}_{}_{}_input{}".format(
                    model_idntfr, bk_name1, bk_name2, input_idx
                )
                metrics_result_dict[delta_key] = delta
    mutate_logger.info(f"Accumulative Inconsistency: {accumulative_incons}")
    return res_dict, accumulative_incons
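
# Worked illustration (hypothetical numbers): with three backends there are
# C(3, 2) = 3 backend pairs. If the configured metric (e.g. D_MAD) returns
# [0.01, 0.02] for each pair on a 2-image test set, then
#     accumulative_incons = 3 * (0.01 + 0.02) = 0.09
# and res_dict additionally stores one per-input delta under keys such as
# "<model_idntfr>_<bk1>_<bk2>_input0".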


def generate_gini_result(predict_output, backends):
    gini_res = {bk: 0 for bk in backends}
    for pair in predict_output.items():
        bk_name, prediction = pair
        gini_res[bk_name] = utils.MetricsUtils.get_gini_mean(prediction)
    return gini_res


def generate_theta(predict_output, backends):
    theta_res = {bk: 0 for bk in backends}
    for pair in predict_output.items():
        bk_name, prediction = pair
        theta_res[bk_name] = utils.MetricsUtils.get_theta_mean(
            prediction, y_test[: flags.test_size]
        )
    return theta_res


SHAPE_SPACE = 5
model_num = 0


def get_model_prediction(res_dict, model_path, model_name, exp, test_size, backends):
    # compute ACC
    """
    Get model prediction on different backends and calculate distance by metrics
    """
    root_dir = model_path.split("origin_model")[0]
    # path where model predictions are saved; keep it identical to the path used in
    # patch_prediction_extractor.py (around line 44)
    npy_path = root_dir + "res.npy"
    predict_output = {b: [] for b in backends}
    model_idntfr = model_name[:-3]
    all_backends_predict_status = True
    for bk in backends:
        python_bin = f"{python_prefix}/{bk}/bin/python"
        predict_st = datetime.datetime.now()
        # run prediction with each backend library
        pre_status_bk = os.system(
            f"{python_bin} -u -m patch_prediction_extractor --backend {bk} "
            f"--exp {exp} --test_size {test_size} --model {model_path} "
            f"--config_name {flags.config_name}"
        )
        predict_et = datetime.datetime.now()
        predict_td = predict_et - predict_st
        h, m, s = utils.ToolUtils.get_HH_mm_ss(predict_td)
        mutate_logger.info(
            "Prediction Time Used on {} : {}h, {}m, {}s".format(bk, h, m, s)
        )
        # If no exception is thrown, save prediction result
        if pre_status_bk == 0:  # prediction succeeded; load and keep the result
            # data = pickle.loads(redis_conn.hget("prediction_{}".format(model_name), bk))
            data = np.load(npy_path)
            predict_output[bk] = data
            # print(data)
        # record the crashed backend
        else:
            all_backends_predict_status = False
            mutate_logger.error(
                "{} crash on backend {} when predicting ".format(model_name, bk)
            )
    status = False
    accumulative_incons = None
    # run ok on all platforms
    if (
        all_backends_predict_status
    ):  # every backend ran successfully and saved a result; now check those results
        predictions = list(predict_output.values())
        res_dict, accumulative_incons = generate_metrics_result(
            res_dict=res_dict, predict_output=predict_output, model_idntfr=model_idntfr
        )
        # compute ACC (measures how inconsistent the predictions are across backends)
        # gini_res = generate_gini_result(predict_output=predict_output, backends=backends)
        # theta = generate_theta(predict_output=predict_output, backends=backends)

        # import csv
        # csvfile = open(r"D:\lemon_outputs\result\mobilenet.1.00.224-imagenet\tensorflow\5.csv", 'a+', newline='')
        # write = csv.writer(csvfile)
        # write.writerow([accumulative_incons, gini_res['tensorflow'], theta['tensorflow']])
        # csvfile.close()
        #
        # csvfile = open(r"D:\lemon_outputs\result\mobilenet.1.00.224-imagenet\mxnet\5.csv", 'a+', newline='')
        # write = csv.writer(csvfile)
        # write.writerow([accumulative_incons, gini_res['mxnet'], theta['mxnet']])
        # csvfile.close()
        # compute gini

        # If all backends are working fine, check if there exists NAN or INF in the result
        # `accumulative_incons` is nan or inf --> NaN or INF in results
        if is_nan_or_inf(accumulative_incons):
            # has NaN on partial backends
            if partially_nan_or_inf(predictions, len(backends)):
                nan_model_path = os.path.join(nan_dir, f"{model_idntfr}_NaN_bug.h5")
                mutate_logger.error("Error: Found one NaN bug. move NAN model")
            # has NaN on all backends --> not a NaN bug
            else:
                nan_model_path = os.path.join(
                    nan_dir, f"{model_idntfr}_NaN_on_all_backends.h5"
                )
                mutate_logger.error(
                    "Error: Found one NaN Model on all libraries. move NAN model"
                )
            shutil.move(model_path, nan_model_path)
        else:  # no NaN or Inf on any backend
            print(model_path)
            for bk in backends:
                python_bin = f"{python_prefix}/{bk}/bin/python"
                os.system(
                    f"{python_bin} -u -m model_to_txt --backend {bk} --model_path {model_path} --root_dir {root_dir}"
                )
            # if 'svhn' in model_name or 'fashion2' in model_name:
            #     file_path = os.path.join(folder_path, model_path.split("\\")[-1][:-5] + '.json')
            # else:
            #     file_path = os.path.join(folder_path, model_path.split("\\")[-1][:-3] + '.json')
            # union_json(file_path, all_json_path)
            # model_now = keras.models.load_model(model_path, custom_objects=custom_objects())
            # inner_div[model_num] = calc_inner_div(model_now)
            # with open(file_path, 'r') as sub_json:
            #     sub_info = json.load(sub_json)
            # if len(set(sub_info['layer_type'])) > len(tar_set):
            #     tar_set = set(sub_info['layer_type'])
            mutate_logger.info("Saving prediction")
            with open(
                "{}/prediction_{}.pkl".format(inner_output_dir, model_idntfr), "wb+"
            ) as f:
                pickle.dump(predict_output, file=f)
            status = True
    # save crashed model
    else:
        mutate_logger.error("Error: move crash model")
        crash_model_path = os.path.join(crash_dir, model_name)
        shutil.move(model_path, crash_model_path)
    return status, res_dict, accumulative_incons, predict_output
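
# Return contract (as consumed by _generate_and_predict above): a 4-tuple
# (status, res_dict, accumulative_incons, predict_output). status stays False when
# any backend crashed or the accumulated inconsistency is NaN/Inf, in which case
# the model file has already been moved into the crash/ or nan/ directory.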


if __name__ == "__main__":
    starttime = datetime.datetime.now()
    """
    Parser of command args.
    It could make mutate_lemon.py run independently without relying on mutation_executor.py
    """
    parse = argparse.ArgumentParser()
    parse.add_argument(
        "--is_mutate",
        type=ast.literal_eval,
        default=False,
        help="parameter to determine mutation option",
    )
    parse.add_argument(
        "--mutate_op",
        type=str,
        nargs="+",
        choices=[
            "WS",
            "GF",
            "NEB",
            "NAI",
            "NS",
            "ARem",
            "ARep",
            "LA",
            "LC",
            "LR",
            "LS",
            "MLA",
        ],
        help="parameter to determine mutation option",
    )
    parse.add_argument(
        "--model", type=str, help="relative path of model file (from root dir)"
    )
    parse.add_argument(
        "--output_dir", type=str, help="relative path of output dir (from root dir)"
    )
    parse.add_argument("--backends", type=str, nargs="+", help="list of backends")
    parse.add_argument(
        "--mutate_num",
        type=int,
        help="number of variant models generated by each mutation operator",
    )
    parse.add_argument("--mutate_ratio", type=float, help="ratio of mutation")
    parse.add_argument("--exp", type=str, help="experiment identifier")
    parse.add_argument("--test_size", type=int, help="number of testing images")
    parse.add_argument("--config_name", type=str, help="config name")
    flags, unparsed = parse.parse_known_args(sys.argv[1:])

    warnings.filterwarnings("ignore")
    lemon_cfg = configparser.ConfigParser()
    # lemon_cfg.read(f".\config\{flags.config_name}")
    cfg_path = os.path.join(os.path.dirname(os.getcwd()), "config", flags.config_name)
    lemon_cfg.read(cfg_path)
    # lemon_cfg.read(f"config/demo.conf")
    time_limit = lemon_cfg["parameters"].getint("time_limit")
    mutator_strategy = lemon_cfg["parameters"].get("mutator_strategy").upper()
    mutant_strategy = lemon_cfg["parameters"].get("mutant_strategy").upper()
    stop_mode = lemon_cfg["parameters"].get("stop_mode").upper()
    alpha = lemon_cfg["parameters"].getfloat("alpha")
    mutate_logger = Logger()
    # pool = redis.ConnectionPool(host=lemon_cfg['redis']['host'], port=lemon_cfg['redis']['port'],
    #                             db=lemon_cfg['redis'].getint('redis_db'))
    # redis_conn = redis.Redis(connection_pool=pool)
    # for k in redis_conn.keys():
    #     if flags.exp in k.decode("utf-8"):
    #         redis_conn.delete(k)

    # exp: like lenet5-mnist
    experiment_dir = os.path.join(flags.output_dir, flags.exp)
    mut_dir = os.path.join(experiment_dir, "mut_model")
    crash_dir = os.path.join(experiment_dir, "crash")
    nan_dir = os.path.join(experiment_dir, "nan")
    inner_output_dir = os.path.join(experiment_dir, "inner_output")
    metrics_result_dir = os.path.join(experiment_dir, "metrics_result")

    # load the dataset specified in the config file and convert its format
    x, y = utils.DataUtils.get_data_by_exp(flags.exp)
    x_test, y_test = x[: flags.test_size], y[: flags.test_size]
    pool_size = lemon_cfg["parameters"].getint("pool_size")
    python_prefix = lemon_cfg["parameters"]["python_prefix"].rstrip("\\")

    try:  # run the main algorithm
        metrics_list = lemon_cfg["parameters"]["metrics"].split(" ")  # e.g. D_MAD
        lemon_results = {k: dict() for k in metrics_list}
        lemon_results = _generate_and_predict(
            lemon_results,
            flags.model,
            flags.mutate_num,
            flags.mutate_op,
            flags.test_size,
            flags.exp,
            flags.backends,
        )
        with open(
            "{}/{}_lemon_results.pkl".format(experiment_dir, flags.exp), "wb+"
        ) as f:
            pickle.dump(lemon_results, file=f)
        utils.MetricsUtils.generate_result_by_metrics(
            metrics_list, lemon_results, metrics_result_dir, flags.exp
        )
    except Exception as e:
        mutate_logger.exception(sys.exc_info())

    from keras import backend as K

    K.clear_session()

    endtime = datetime.datetime.now()
    time_delta = endtime - starttime
    h, m, s = utils.ToolUtils.get_HH_mm_ss(time_delta)
    mutate_logger.info(
        "Mutation process is done: Time used: {} hour,{} min,{} sec".format(h, m, s)
    )
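
# Example standalone invocation (hypothetical paths and values, for reference only;
# the flags match the argparse definitions above):
#   python -u mutate_lemon.py --mutate_op WS GF NAI --model origin_model/lenet5-mnist_origin0.h5 \
#       --output_dir /path/to/lemon_outputs --backends tensorflow mxnet \
#       --mutate_num 100 --mutate_ratio 0.3 --exp lenet5-mnist --test_size 10 \
#       --config_name demo.conf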