localize_lemon.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # -*-coding:UTF-8-*-
  2. """
  3. # Part of localization phase
  4. """
  5. import argparse
  6. import sys
  7. import os
  8. import pickle
  9. import configparser
  10. from scripts.tools.utils import ModelUtils
  11. import keras
  12. from keras.engine.input_layer import InputLayer
  13. import warnings
  14. import datetime
  15. from scripts.logger.lemon_logger import Logger
  16. import shutil
  17. from itertools import combinations
  18. import keras.backend as K
  19. warnings.filterwarnings("ignore")
  20. def is_lstm_not_exists(exp_id,output_id):
  21. if exp_id in ['lstm0-sinewave','lstm2-price'] and output_id in ['experiment4','experiment5']:
  22. return True
  23. else:
  24. return False
  25. def get_HH_mm_ss(td):
  26. days,seconds = td.days,td.seconds
  27. hours = days * 24 + seconds // 3600
  28. minutes = (seconds % 3600)//60
  29. secs = seconds % 60
  30. return hours,minutes,secs
  31. def generate_report(localize_res,savepath):
  32. with open(savepath,"w+") as fw:
  33. for localize_header, value in localize_res.items():
  34. fw.write("current_layer, delta,Rl,previous_layer\n".format(localize_header))
  35. for layer_res in value:
  36. fw.write("{},{},{},{}\n".format(layer_res[0],layer_res[1],layer_res[2],layer_res[3]))
  37. def localize(mut_model_dir,select_idntfr, exp_name,localize_tmp_dir,report_dir,backends):
  38. """
  39. select_idntfrs: lenet5-mnist_origin0_input17
  40. """
  41. # get layer_output for all models coming from specific exp on all backends
  42. identifier_split = select_idntfr.split("_")
  43. data_index = int(identifier_split[-1][5:])
  44. model_idntfr = "{}_{}".format(identifier_split[0], identifier_split[1])
  45. if 'svhn' in model_idntfr or 'fashion2' in model_idntfr:
  46. model_path = "{}/{}.hdf5".format(mut_model_dir, model_idntfr)
  47. else:
  48. model_path = "{}/{}.h5".format(mut_model_dir, model_idntfr)
  49. #
  50. # # check if indntfr hasn't been localized
  51. # for bk1, bk2 in combinations(backends, 2):
  52. # report_path = os.path.join(report_dir, "{}_{}_{}_input{}.csv".format(model_idntfr, bk1, bk2, data_index))
  53. # # not exists; continue fo localize
  54. # if not os.path.exists(report_path):
  55. # break
  56. # # all file exist; return
  57. # else:
  58. # mylogger.logger.info(f"{select_idntfr} has been localized")
  59. # return
  60. for bk in backends:
  61. python_bin = f"{python_prefix}\{bk}\python"
  62. return_stats = os.system(
  63. f"{python_bin} -u -m run.patch_hidden_output_extractor --backend {bk} --output_dir {output_dir} --exp {exp_name}"
  64. f" --model_idntfr {model_idntfr} --data_index {data_index} --config_name {config_name}")
  65. # assert return_stats==0,"Getting hidden output failed!"
  66. if return_stats != 0:
  67. mylogger.logger.info("Getting hidden output failed!")
  68. failed_list.append(select_idntfr)
  69. return
  70. mylogger.logger.info("Getting localization for {}".format(select_idntfr))
  71. model = keras.models.load_model(model_path, custom_objects=ModelUtils.custom_objects())
  72. for bk1, bk2 in combinations(backends, 2):
  73. local_res = {}
  74. local_res = get_outputs_divation_onbackends(model=model, backends=[bk1, bk2],
  75. model_idntfr=model_idntfr, local_res=local_res,
  76. data_index=data_index, localize_tmp_dir=localize_tmp_dir)
  77. mylogger.logger.info("Generating localization report for {} on {}-{}!".format(model_idntfr,bk1,bk2))
  78. report_path = os.path.join(report_dir, "{}_{}_{}_input{}.csv".format(model_idntfr,bk1,bk2, data_index))
  79. generate_report(local_res, report_path)
  80. del model
  81. K.clear_session()
  82. def get_outputs_divation_onbackends(model,backends,model_idntfr,local_res,data_index,localize_tmp_dir):
  83. backend1 = backends[0]
  84. backend2 = backends[1]
  85. with open(os.path.join(localize_tmp_dir, "{}_{}_{}".format(model_idntfr, backend1,data_index)), "rb") as fr:
  86. model_layers_outputs_1 = pickle.load(fr)
  87. with open(os.path.join(localize_tmp_dir, "{}_{}_{}".format(model_idntfr, backend2,data_index)), "rb") as fr:
  88. model_layers_outputs_2 = pickle.load(fr)
  89. divations = ModelUtils.layers_divation(model, model_layers_outputs_1, model_layers_outputs_2)
  90. compare_res = []
  91. for i, layer in enumerate(model.layers):
  92. if isinstance(layer, InputLayer):
  93. continue
  94. delta, divation, inputlayers = divations[i]
  95. layer_compare_res = [layer.name, delta[0], divation[0],",".join(inputlayers)] # batch accepted default
  96. compare_res.append(layer_compare_res)
  97. identifier = "{}_{}_{}_input_{}".format(model_idntfr,backend1,backend2,data_index)
  98. idntfr_localize = "{}_localize".format(identifier)
  99. local_res[idntfr_localize] = compare_res
  100. return local_res
  101. if __name__ == "__main__":
  102. starttime = datetime.datetime.now()
  103. # get id of experiments
  104. config_name = sys.argv[1]
  105. lemon_cfg = configparser.ConfigParser()
  106. lemon_cfg.read(f"./config/{config_name}")
  107. parameters = lemon_cfg['parameters']
  108. output_dir = parameters['output_dir']
  109. output_dir = output_dir[:-1] if output_dir.endswith("/") else output_dir
  110. current_container = os.path.split(output_dir)[-1]
  111. python_prefix = parameters['python_prefix'].rstrip("/")
  112. """Initialization"""
  113. mylogger = Logger()
  114. backend_choices = [1,2,3]
  115. exps = parameters['exps'].lstrip().rstrip().split(" ")
  116. exps.sort(key=lambda x: x)
  117. all_model_inputs = {e:set() for e in exps}
  118. items_lists = list()
  119. for backend_choice in backend_choices:
  120. if backend_choice == 1:
  121. pre_backends = ['tensorflow', 'theano', 'cntk']
  122. elif backend_choice == 2:
  123. pre_backends = ['tensorflow', 'theano', 'mxnet']
  124. else:
  125. pre_backends = ['tensorflow', 'cntk', 'mxnet']
  126. backends_str = "-".join(pre_backends)
  127. backend_pairs = [f"{pair[0]}_{pair[1]}" for pair in combinations(pre_backends, 2)]
  128. with open(os.path.join(output_dir, f"localize_model_inputs-{backends_str}.pkl"), "rb") as fr:
  129. localize_model_inputs = pickle.load(fr)
  130. for exp_id,model_set in localize_model_inputs.items():
  131. if exp_id in exps:
  132. for mi in model_set:
  133. all_model_inputs[exp_id].add(mi)
  134. for exp,mi_set in all_model_inputs.items():
  135. print(exp,len(mi_set))
  136. failed_list = []
  137. """Print result of inconsistency distribution"""
  138. for exp_idntfr,model_inputs_set in all_model_inputs.items():
  139. if len(model_inputs_set) > 0:
  140. if exp_idntfr == 'inception.v3-imagenet' or exp_idntfr == 'densenet121-imagenet' or is_lstm_not_exists(exp_idntfr,current_container):
  141. # inception and densenet can't run on mxnet.
  142. # lstm can't run on mxnet before mxnet version 1.3.x
  143. backends = ['tensorflow', 'theano', 'cntk']
  144. else:
  145. backends = ['tensorflow', 'theano', 'cntk','mxnet']
  146. print("Localize for {} : {} left.".format(exp_idntfr,len(model_inputs_set)))
  147. mut_dir = os.path.join(output_dir,exp_idntfr, "mut_model")
  148. localization_dir = os.path.join(output_dir,exp_idntfr, "localization_result")
  149. localize_output_dir = os.path.join(output_dir,exp_idntfr, "localize_tmp")
  150. """make dir for hidden_output and localization dir """
  151. if not os.path.exists(localize_output_dir):
  152. os.makedirs(localize_output_dir)
  153. if not os.path.exists(localization_dir):
  154. os.makedirs(localization_dir)
  155. """Localization"""
  156. for idx,select_identifier in enumerate(model_inputs_set):
  157. print("{} of {} {}".format(idx,len(model_inputs_set),select_identifier))
  158. localize(mut_model_dir=mut_dir,select_idntfr=select_identifier,exp_name=exp_idntfr,
  159. localize_tmp_dir=localize_output_dir,report_dir=localization_dir
  160. ,backends=backends)
  161. shutil.rmtree(localize_output_dir)
  162. with open(os.path.join(output_dir, f"failed_idntfrs.txt"), "w") as fw:
  163. if len(failed_list) > 0:
  164. mylogger.logger.warning(f"{len(failed_list)} idntfrs fail to localize")
  165. lists = [f"{line} \n" for line in failed_list]
  166. fw.writelines(lists)
  167. else:
  168. mylogger.logger.info("all idntfrs localize successfully")
  169. endtime = datetime.datetime.now()
  170. time_delta = endtime - starttime
  171. h,m,s = get_HH_mm_ss(time_delta)
  172. mylogger.logger.info("Localization precess is done: Time used: {} hour,{} min,{} sec".format(h,m,s))