app.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from flask import Flask, request,send_file
  2. from tool import runtool
  3. import yaml
  4. import os
  5. import json
  6. import threading
  7. app = Flask(__name__)
  8. with open('config.yml', 'r') as file:
  9. config = yaml.load(file, Loader=yaml.FullLoader)
  10. # env = os.getenv('FLASK_ENV', 'development')
  11. app.config.update(config['default'])
  12. # if env in config:
  13. # app.config.update(config[env])
  14. base_dir = os.path.dirname(os.path.abspath(__file__))
  15. task_output_dir = os.path.join(base_dir, "output")
  16. def get_task_list():
  17. '''
  18. 获取任务列表,
  19. 任务是task_output_dir下的第一层文件夹
  20. 同时检查是否超过了最大任务数,如果超过了,删除最早的任务直到任务数小于最大任务数
  21. '''
  22. task_list = []
  23. task_list = os.listdir(task_output_dir)
  24. # 文件夹的名称是按照task_<task_num>_...来的, 按照task_num排序
  25. task_list_dict = {}
  26. for task_dir in task_list:
  27. task_num = int(task_dir.split('_')[1])
  28. task_list_dict[task_num] = task_dir
  29. task_num_list = list(task_list_dict.keys())
  30. task_num_list.sort()
  31. return_list = []
  32. for task_num in task_num_list:
  33. return_list.append(task_list_dict[task_num])
  34. return return_list
  35. def get_max_task_id():
  36. task_list = []
  37. task_list = os.listdir(task_output_dir)
  38. # 文件夹的名称是按照task_<task_id>_...来的,取出最大的task_id
  39. task_id_to_task_dir = {}
  40. task_id_list = []
  41. for task_dir in task_list:
  42. task_id = int(task_dir.split('_')[1])
  43. task_id_list.append(task_id)
  44. task_id_to_task_dir[task_id] = task_dir
  45. max_task_id = max(task_id_list)
  46. # 检查是否超过了最大任务数,删除task_num最小的任务直到任务数小于最大任务数
  47. print(len(task_list))
  48. print(app.config['max_task_num'])
  49. if len(task_list) > app.config['max_task_num']:
  50. print("超过最大任务数,删除最早的任务直到任务数小于最大任务数")
  51. task_id_list.sort()
  52. over_num = len(task_list) - app.config['max_task_num']
  53. for i in range(over_num):
  54. task_id = task_id_list[i]
  55. task_dir = task_id_to_task_dir[task_id]
  56. task_dir_path = os.path.join(task_output_dir, task_dir)
  57. print("删除任务:", task_dir_path)
  58. os.system("rm -rf " + task_dir_path)
  59. return max_task_id
  60. def get_result(task_id:str):
  61. result_dir = os.path.join(task_output_dir, task_id)
  62. # 查看result_dir的下一级文件夹
  63. result_model_name = os.listdir(result_dir)[0]
  64. result_model_dir = os.path.join(result_dir,result_model_name)
  65. print(result_model_dir)
  66. mxnet_json = os.path.join(result_model_dir, "mxnet.json")
  67. with open(
  68. mxnet_json, "r"
  69. ) as file1: # tensorflow.json
  70. data1 = json.load(file1)
  71. tensorflow_json = os.path.join(result_model_dir, "tensorflow.json")
  72. with open(tensorflow_json, "r") as file2:
  73. data2 = json.load(file2)
  74. img_root_relative_path = task_id + '/' + result_model_name + '/'
  75. combined_data = {
  76. "mxnet": data1,
  77. "tensorflow": data2,
  78. "img_path":{
  79. "mxnet_train":img_root_relative_path + 'mxnet_train.jpg',
  80. "tensorflow_train": img_root_relative_path + 'tensorflow_train.jpg',
  81. "accuracy": img_root_relative_path + 'accuracy.jpg',
  82. "losses": img_root_relative_path + 'losses.jpg',
  83. "memory": img_root_relative_path + 'memory.jpg'
  84. }
  85. }
  86. return combined_data
  87. @app.route('/')
  88. def hello_world(): # put application's code here
  89. return 'Hello World!'
  90. @app.route('/models', methods=['GET'])
  91. def getModels():
  92. models = ['lexnet-cifar10', 'lenet5-fashion-mnist', 'fashion2', 'svhn', 'lenet5-mnist',
  93. 'alexnet-cifar10', 'mobilenet.1.00.224-imagenet', 'vgg16-imagenet']
  94. return models
  95. import time
  96. @app.route('/run', methods=['POST'])
  97. def run():
  98. data = request.get_json()
  99. exp = data['exp']
  100. mutate_num = data['mutate_num']
  101. #创建task_id: task_ task_num exp mutate_num
  102. # task_id = 'task_' + time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) + '_' + str(len(get_task_list())+1)
  103. task_id = 'task_' + str(get_max_task_id()+1) + '_' + exp + '_' + str(mutate_num)
  104. # 开启一个线程执行,直接返回请求
  105. # data = runtool(exp, mutate_num,task_id)
  106. t = threading.Thread(target=runtool, args=(exp, mutate_num,task_id))
  107. t.start()
  108. return {'task_id': task_id}
  109. @app.route('/get_task_list', methods=['GET'])
  110. def get_task_list_api():
  111. return {'task_list': get_task_list()}
  112. @app.route('/get_task_result', methods=['POST'])
  113. def get_task_result():
  114. data = request.get_json()
  115. task_id = data['task_id']
  116. try:
  117. res = get_result(task_id)
  118. return res
  119. except Exception as e:
  120. return {'error:':'错误的task_id或任务未完成'}
  121. @app.route('/image/<path:filename>')
  122. def serve_image(filename):
  123. # 指定图片所在的根目录
  124. # 构建完整的图片路径
  125. image_path = os.path.join(task_output_dir, filename)
  126. # 使用send_file()函数返回图片
  127. return send_file(image_path, mimetype='image/jpeg')
  128. if __name__ == '__main__':
  129. # app.run()
  130. from waitress import serve
  131. serve(app, host="0.0.0.0", port=5000)