model_mutation_generators.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  import os
  import sys
  import warnings

  # TensorFlow reads TF_CPP_MIN_LOG_LEVEL once, when its native library is first
  # loaded, so it is set before any import that may pull TensorFlow in (the
  # project wildcard import below presumably does). "2" hides INFO and WARNING
  # messages from TensorFlow's C++ logging and keeps only ERROR output.
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
  # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  # os.environ["CUDA_VISIBLE_DEVICES"] = ""

  # NOTE(review): a relative sys.path entry is resolved against the current
  # working directory, not this file's location, so the `scripts.*` imports
  # below only resolve when the script is launched from the expected directory.
  sys.path.append("../")

  from scripts.mutation.model_mutation_operators import *
  import scripts.mutation.utils as utils
  import scripts.tools.utils as utils_tools
  import argparse

  # Silence Python-level warnings emitted from here on.
  warnings.filterwarnings("ignore")

  # Logger is not imported explicitly here; it presumably comes from the
  # wildcard import above - TODO confirm its source.
  mylogger = Logger()

  import tensorflow as tf
  import psutil
  def generate_model_by_model_mutation(model, operator, mutate_ratio=0.3):
      """
      Build a mutated model from ``model`` using the named mutation operator.

      :param model: model loaded by keras (tensorflow backend by default)
      :param operator: name of the mutation operator, e.g. "WS", "GF" or "NS"
      :param mutate_ratio: ratio of selected neurons; only forwarded to the
          WS, GF, NEB and NAI operators
      :return: whatever the matching ``*_mut`` function returns, or None
          (after logging a message) when ``operator`` is not recognised
      """
      if operator == "WS":
          # WS needs the indices of the weighted layers; they are computed
          # before the "Generating model" line is logged.
          weighted_layers = utils.ModelUtils.weighted_layer_indices(model)
          mylogger.info("Generating model using {}".format(operator))
          return WS_mut(
              model=model,
              mutation_ratio=mutate_ratio,
              mutated_layer_indices=weighted_layers,
          )

      # The remaining operators, tried in a fixed order. Each call is wrapped in
      # a lambda so only the mutator that is actually selected gets looked up
      # and invoked.
      candidates = (
          ("GF", lambda: GF_mut(model=model, mutation_ratio=mutate_ratio)),
          ("NEB", lambda: NEB_mut(model=model, mutation_ratio=mutate_ratio)),
          ("NAI", lambda: NAI_mut(model=model, mutation_ratio=mutate_ratio)),
          ("NS", lambda: NS_mut(model=model)),
          ("ARem", lambda: ARem_mut(model=model)),
          ("ARep", lambda: ARep_mut(model=model)),
          ("LA", lambda: LA_mut(model=model)),
          ("LC", lambda: LC_mut(model=model)),
          ("LR", lambda: LR_mut(model=model)),
          ("LS", lambda: LS_mut(model=model)),
          ("MLA", lambda: MLA_mut(model=model)),
      )
      for name, build in candidates:
          if operator == name:
              mylogger.info("Generating model using {}".format(operator))
              return build()

      # Unknown operator: report it and signal failure to the caller with None.
      mylogger.info("No such Mutation operator {}".format(operator))
      return None
  def all_mutate_ops():
      """Return a new list holding the name of every supported mutation operator.

      The names and their order match the operators dispatched by
      ``generate_model_by_model_mutation``. A fresh list is built on each call,
      so callers may mutate the result freely.
      """
      # Whitespace-separated operator names, split into a new list every call.
      names = "WS GF NEB NAI NS ARem ARep LA LC LR LS MLA"
      return names.split()
  if __name__ == "__main__":
      # Command-line interface: load a Keras model, mutate it with one operator,
      # and save the mutated model.
      parse = argparse.ArgumentParser()
      # model, operator and output path are mandatory: without them the script
      # can only fail later with an obscure error (load/save of None).
      parse.add_argument("--model", type=str, required=True, help="model path")
      parse.add_argument(
          "--mutate_op", type=str, required=True, help="model mutation operator"
      )
      parse.add_argument(
          "--save_path", type=str, required=True, help="model save path"
      )
      # Default to 0.3, the same default generate_model_by_model_mutation uses;
      # previously an omitted flag passed None and overrode that default.
      parse.add_argument(
          "--mutate_ratio", type=float, default=0.3, help="mutate ratio"
      )
      # parse_known_args tolerates extra, unrecognised arguments instead of
      # failing on them; they are collected in `unparsed` and ignored.
      flags, unparsed = parse.parse_known_args(sys.argv[1:])

      # keras is only needed when this file is run as a script.
      import keras

      model_path = flags.model
      mutate_ratio = flags.mutate_ratio
      print("Current {}; Mutate ratio {}".format(flags.mutate_op, mutate_ratio))
      origin_model = keras.models.load_model(
          model_path, custom_objects=utils.ModelUtils.custom_objects()
      )
      mutated_model = generate_model_by_model_mutation(
          model=origin_model, operator=flags.mutate_op, mutate_ratio=mutate_ratio
      )
      # None means no mutated model was produced: the operator name was not
      # recognised (or, presumably, the chosen mutator itself returned None).
      if mutated_model is None:
          raise Exception("Error: Model mutation using {} failed".format(flags.mutate_op))
      mutated_model.save(flags.save_path)