patch_hidden_output_extractor.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # -*-coding:UTF-8-*-
  2. """
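Part of the localization phase: on the backend given by --backend, load one
model from <output_dir>/<exp>/mut_model, feed it a single test input and
pickle its per-layer outputs to <output_dir>/<exp>/localize_tmp.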
  5. """
  6. import sys
  7. sys.path.append("../")
  8. import os
  9. import pickle
  10. import argparse
  11. from scripts.tools.utils import DataUtils,ModelUtils
  12. from scripts.logger.lemon_logger import Logger
  13. import configparser
  14. import warnings
  15. import traceback
  16. import numpy as np
  17. #np.random.seed(20200501)
  18. warnings.filterwarnings("ignore")
  19. def _get_hidden_output(test_data,backend,select_model,model_dir,data_index):
  20. """
  21. layers_output: list of ndarray which store outputs in each layer
  22. The result stored in redis like:
  23. (lenet5-mnist_origin0_theano,layers_output)
  24. """
  25. if 'svhn' in select_model or 'fashion2' in select_model:
  26. model_pathname = os.path.join(model_dir, "{}.hdf5".format(select_model))
  27. else:
  28. model_pathname = os.path.join(model_dir, "{}.h5".format(select_model))
  29. model = keras.models.load_model(model_pathname,custom_objects=ModelUtils.custom_objects())
  30. model_idntfr_backend = "{}_{}_{}".format(select_model, backend, data_index)
  31. select_data = np.expand_dims(test_data[data_index], axis=0)
  32. layers_output = ModelUtils.layers_output(model, select_data)
  33. with open(os.path.join(localize_output_dir,model_idntfr_backend),"wb") as fw:
  34. pickle.dump(layers_output,fw)
  35. if __name__ == "__main__":
  36. """Parser of command args"""
  37. parse = argparse.ArgumentParser()
  38. parse.add_argument("--backend", type=str, help="name of backends")
  39. parse.add_argument("--exp", type=str, help="experiments identifiers")
  40. parse.add_argument("--output_dir", type=str, help="relative path of output dir(from root dir)")
  41. parse.add_argument("--data_index", type=int, help="redis db port")
  42. parse.add_argument("--config_name", type=str, help="config name")
  43. parse.add_argument("--model_idntfr", type=str, help="redis db port")
  44. flags, unparsed = parse.parse_known_args(sys.argv[1:])
  45. mylogger = Logger()
  46. """Load Configuration"""
  47. warnings.filterwarnings("ignore")
  48. lemon_cfg = configparser.ConfigParser()
  49. # lemon_cfg.read(f"./config/{flags.config_name}")
  50. grandparent_directory = os.path.dirname(os.getcwd())
  51. conf_path = grandparent_directory + "/config/demo.conf"
  52. lemon_cfg.read(conf_path)
  53. parameters = lemon_cfg['parameters']
  54. gpu_ids = parameters['gpu_ids']
  55. gpu_list = parameters['gpu_ids'].split(",")
  56. """Init cuda"""
  57. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  58. os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
  59. warnings.filterwarnings("ignore")
  60. batch_size = 64
  61. """Switch backend"""
  62. bk_list = ['tensorflow', 'theano', 'cntk', 'mxnet']
  63. bk = flags.backend
  64. print('.........................',type(bk))
  65. os.environ['KERAS_BACKEND'] = bk
  66. os.environ['PYTHONHASHSEED'] = '0'
  67. if bk == 'tensorflow':
  68. os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2' # 只显示 warning 和 Error
  69. import tensorflow as tf
  70. mylogger.info(tf.__version__)
  71. if bk == 'theano':
  72. if len(gpu_list) == 2:
  73. os.environ[
  74. 'THEANO_FLAGS'] = f"device=cuda,contexts=dev{gpu_list[0]}->cuda{gpu_list[0]};dev{gpu_list[1]}->cuda{gpu_list[1]}," \
  75. f"force_device=True,floatX=float32,lib.cnmem=1"
  76. else:
  77. os.environ['THEANO_FLAGS'] = f"device=cuda,contexts=dev{gpu_list[0]}->cuda{gpu_list[0]}," \
  78. f"force_device=True,floatX=float32,lib.cnmem=1"
  79. batch_size = 32
  80. import theano as th
  81. mylogger.info(th.__version__)
  82. if bk == "cntk":
  83. batch_size = 32
  84. from cntk.device import try_set_default_device, gpu
  85. try_set_default_device(gpu(int(gpu_list[0])))
  86. import cntk as ck
  87. mylogger.info(ck.__version__)
  88. if bk == "mxnet":
  89. batch_size = 32
  90. import mxnet as mxnet
  91. mylogger.info(mxnet.__version__)
  92. from keras import backend as K
  93. import keras
  94. mylogger.logger.info("Using {} as backend for states extraction| {} is wanted".format(K.backend(),bk))
  95. """Get model hidden output on selected_index data on specific backend"""
  96. try:
  97. backend_input_dict = {}
  98. localize_output_dir = os.path.join(flags.output_dir,flags.exp,"localize_tmp")
  99. x, y = DataUtils.get_data_by_exp(flags.exp)
  100. mut_dir = os.path.join(flags.output_dir,flags.exp,"mut_model")
  101. _get_hidden_output(test_data=x, backend=bk,select_model=flags.model_idntfr,model_dir=mut_dir,data_index=flags.data_index)
  102. mylogger.logger.info("Hidden output extracting done!")
  103. except:
  104. traceback.print_exc()
  105. sys.exit(-1)