app.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
import json
import os
import re
import shutil
import threading

import yaml
from flask import Flask, jsonify, request, send_file, send_from_directory

from tool import runtool
  7. app = Flask(__name__)
  8. with open('config.yml', 'r') as file:
  9. config = yaml.load(file, Loader=yaml.FullLoader)
  10. # env = os.getenv('FLASK_ENV', 'development')
  11. app.config.update(config['default'])
  12. # if env in config:
  13. # app.config.update(config[env])
  14. base_dir = os.path.dirname(os.path.abspath(__file__))
  15. task_output_dir = os.path.join(base_dir, "output")
  16. def get_task_list():
  17. '''
  18. 获取任务列表,
  19. 任务是task_output_dir下的第一层文件夹
  20. 同时检查是否超过了最大任务数,如果超过了,删除最早的任务直到任务数小于最大任务数
  21. '''
  22. task_list = []
  23. task_list = os.listdir(task_output_dir)
  24. # 文件夹的名称是按照task_<task_num>_...来的, 按照task_num排序
  25. task_list_dict = {}
  26. for task_dir in task_list:
  27. task_num = int(task_dir.split('_')[1])
  28. task_list_dict[task_num] = task_dir
  29. task_num_list = list(task_list_dict.keys())
  30. task_num_list.sort()
  31. return_list = []
  32. for task_num in task_num_list:
  33. return_list.append(task_list_dict[task_num])
  34. return return_list
  35. def get_max_task_id():
  36. task_list = []
  37. task_list = os.listdir(task_output_dir)
  38. if len(task_list) == 0:
  39. return 0
  40. # 文件夹的名称是按照task_<task_id>_...来的,取出最大的task_id
  41. task_id_to_task_dir = {}
  42. task_id_list = []
  43. for task_dir in task_list:
  44. task_id = int(task_dir.split('_')[1])
  45. task_id_list.append(task_id)
  46. task_id_to_task_dir[task_id] = task_dir
  47. max_task_id = max(task_id_list)
  48. # 检查是否超过了最大任务数,删除task_num最小的任务直到任务数小于最大任务数
  49. print(len(task_list))
  50. print(app.config['max_task_num'])
  51. if len(task_list) > app.config['max_task_num']:
  52. print("超过最大任务数,删除最早的任务直到任务数小于最大任务数")
  53. task_id_list.sort()
  54. over_num = len(task_list) - app.config['max_task_num']
  55. for i in range(over_num):
  56. task_id = task_id_list[i]
  57. task_dir = task_id_to_task_dir[task_id]
  58. task_dir_path = os.path.join(task_output_dir, task_dir)
  59. print("删除任务:", task_dir_path)
  60. os.system("rm -rf " + task_dir_path)
  61. return max_task_id
  62. def get_result(task_id:str):
  63. result_dir = os.path.join(task_output_dir, task_id)
  64. # 查看result_dir的下一级文件夹
  65. result_model_name = os.listdir(result_dir)[0]
  66. result_model_dir = os.path.join(result_dir,result_model_name)
  67. print(result_model_dir)
  68. mxnet_json = os.path.join(result_model_dir, "mxnet.json")
  69. with open(
  70. mxnet_json, "r"
  71. ) as file1: # tensorflow.json
  72. data1 = json.load(file1)
  73. tensorflow_json = os.path.join(result_model_dir, "tensorflow.json")
  74. with open(tensorflow_json, "r") as file2:
  75. data2 = json.load(file2)
  76. img_root_relative_path = "/model_accuracy_api/image/" + task_id + '/' + result_model_name + '/'
  77. combined_data = {
  78. "mxnet": data1,
  79. "tensorflow": data2,
  80. "img_path":{
  81. "mxnet_train":img_root_relative_path + 'mxnet_train.jpg',
  82. "tensorflow_train": img_root_relative_path + 'tensorflow_train.jpg',
  83. "accuracy": img_root_relative_path + 'accuracy.jpg',
  84. "losses": img_root_relative_path + 'losses.jpg',
  85. "memory": img_root_relative_path + 'memory.jpg'
  86. }
  87. }
  88. return combined_data
  89. @app.route('/')
  90. def hello_world(): # put application's code here
  91. return 'Hello World!'
  92. @app.route('/models', methods=['GET'])
  93. def getModels():
  94. models = ['lexnet-cifar10', 'lenet5-fashion-mnist', 'fashion2', 'svhn', 'lenet5-mnist',
  95. 'alexnet-cifar10', 'mobilenet.1.00.224-imagenet', 'vgg16-imagenet']
  96. return models
  97. import time
  98. @app.route('/run', methods=['POST'])
  99. def run():
  100. data = request.get_json()
  101. exp = data['exp']
  102. mutate_num = data['mutate_num']
  103. #创建task_id: task_ task_num exp mutate_num
  104. # task_id = 'task_' + time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) + '_' + str(len(get_task_list())+1)
  105. task_id = 'task_' + str(get_max_task_id()+1) + '_' + exp + '_' + str(mutate_num)
  106. # 开启一个线程执行,直接返回请求
  107. # data = runtool(exp, mutate_num,task_id)
  108. t = threading.Thread(target=runtool, args=(exp, mutate_num,task_id))
  109. t.start()
  110. return {'task_id': task_id}
  111. @app.route('/get_task_list', methods=['GET'])
  112. def get_task_list_api():
  113. return {'task_list': get_task_list()}
  114. @app.route('/get_task_result', methods=['POST'])
  115. def get_task_result():
  116. data = request.get_json()
  117. task_id = data['task_id']
  118. try:
  119. res = get_result(task_id)
  120. return res
  121. except Exception as e:
  122. return {'error:':'错误的task_id或任务未完成'}
  123. @app.route('/image/<path:filename>')
  124. def serve_image(filename):
  125. # 指定图片所在的根目录
  126. # 构建完整的图片路径
  127. image_path = os.path.join(task_output_dir, filename)
  128. # 使用send_file()函数返回图片
  129. return send_file(image_path, mimetype='image/jpeg')
  130. if __name__ == '__main__':
  131. # app.run()
  132. from waitress import serve
  133. serve(app, host="0.0.0.0", port=5000)