app.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from flask import Flask, request
  2. from tool import runtool
  3. import yaml
  4. import os
  5. import json
  6. import threading
  7. app = Flask(__name__)
  8. with open('config.yml', 'r') as file:
  9. config = yaml.load(file, Loader=yaml.FullLoader)
  10. # env = os.getenv('FLASK_ENV', 'development')
  11. app.config.update(config['default'])
  12. # if env in config:
  13. # app.config.update(config[env])
  14. base_dir = os.path.dirname(os.path.abspath(__file__))
  15. task_output_dir = os.path.join(base_dir, "output")
  16. def get_task_list():
  17. '''
  18. 获取任务列表,
  19. 任务是task_output_dir下的第一层文件夹
  20. 同时检查是否超过了最大任务数,如果超过了,删除最早的任务直到任务数小于最大任务数
  21. '''
  22. task_list = []
  23. task_list = os.listdir(task_output_dir)
  24. # 按照时间排序
  25. task_list.sort()
  26. # 检查是否超过了最大任务数
  27. max_task_num = app.config['max_task_num']
  28. if len(task_list) > max_task_num:
  29. # 删除最早的任务
  30. task_list.sort()
  31. for i in range(len(task_list) - max_task_num):
  32. task_dir = os.path.join(task_output_dir, task_list[i])
  33. os.system("rm -rf " + task_dir)
  34. task_list = os.listdir(task_output_dir)
  35. task_list.sort()
  36. return task_list
  37. def get_result(task_id:str):
  38. result_dir = os.path.join(task_output_dir, task_id)
  39. # 查看result_dir的下一级文件夹
  40. result_model_dir = os.path.join(result_dir,os.listdir(result_dir)[0])
  41. print(result_model_dir)
  42. mxnet_json = os.path.join(result_model_dir, "mxnet.json")
  43. with open(
  44. mxnet_json, "r"
  45. ) as file1: # tensorflow.json
  46. data1 = json.load(file1)
  47. tensorflow_json = os.path.join(result_model_dir, "tensorflow.json")
  48. with open(tensorflow_json, "r") as file2:
  49. data2 = json.load(file2)
  50. combined_data = {
  51. "mxnet": data1,
  52. "tensorflow": data2
  53. }
  54. return combined_data
  55. @app.route('/')
  56. def hello_world(): # put application's code here
  57. return 'Hello World!'
  58. @app.route('/models', methods=['GET'])
  59. def getModels():
  60. models = ['lexnet-cifar10', 'lenet5-fashion-mnist', 'fashion2', 'svhn', 'lenet5-mnist',
  61. 'alexnet-cifar10', 'mobilenet.1.00.224-imagenet', 'vgg16-imagenet']
  62. return models
  63. import time
  64. @app.route('/run', methods=['POST'])
  65. def run():
  66. data = request.get_json()
  67. exp = data['exp']
  68. mutate_num = data['mutate_num']
  69. #创建task_id: task_2021-09-07T14:57:00_1
  70. task_id = 'task_' + time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) + '_' + str(len(get_task_list())+1)
  71. # 开启一个线程执行,直接返回请求
  72. # data = runtool(exp, mutate_num,task_id)
  73. t = threading.Thread(target=runtool, args=(exp, mutate_num,task_id))
  74. t.start()
  75. return {'task_id': task_id}
  76. @app.route('/get_task_list', methods=['GET'])
  77. def get_task_list_api():
  78. return {'task_list': get_task_list()}
  79. @app.route('/get_task_result', methods=['POST'])
  80. def get_task_result():
  81. data = request.get_json()
  82. task_id = data['task_id']
  83. try:
  84. res = get_result(task_id)
  85. return res
  86. except Exception as e:
  87. return {'error:':'错误的task_id或任务未完成'}
  88. if __name__ == '__main__':
  89. # app.run()
  90. from waitress import serve
  91. serve(app, host="0.0.0.0", port=5000)