from flask import Flask, request
from tool import runtool
import yaml
import os
import json
import shutil
import threading

app = Flask(__name__)

# config.yml is read relative to the current working directory. It is a local
# file, but safe_load is still preferred: it only builds plain Python types and
# never instantiates objects from YAML tags.
with open('config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# env = os.getenv('FLASK_ENV', 'development')
app.config.update(config['default'])
# if env in config:
#     app.config.update(config[env])

base_dir = os.path.dirname(os.path.abspath(__file__))
# Results of a task are read from task_output_dir/<task_id> (see get_result).
task_output_dir = os.path.join(base_dir, "output")


def _remove_task_path(path):
    """Delete *path* (a directory tree or a single file), best effort.

    Failures are logged and ignored. The previous ``rm -rf`` shell call never
    raised, and a failed cleanup must not break the task-list endpoint.
    """
    try:
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
    except OSError as err:
        app.logger.warning("failed to remove old task %s: %s", path, err)


def get_task_list():
    """Return the sorted task names, pruning the oldest beyond the limit.

    A task is a first-level entry of ``task_output_dir``. Task names start with
    ``task_`` followed by a ``%Y-%m-%dT%H:%M:%S`` timestamp, so a plain
    lexicographic sort orders them from oldest to newest.

    If there are more than ``app.config['max_task_num']`` tasks, the oldest
    ones are deleted until exactly that many remain.

    Returns:
        list[str]: the remaining task names, oldest first.
    """
    # On a fresh checkout the output folder may not exist yet.
    os.makedirs(task_output_dir, exist_ok=True)
    task_list = sorted(os.listdir(task_output_dir))

    max_task_num = app.config['max_task_num']
    excess = len(task_list) - max_task_num
    if excess > 0:
        # Oldest tasks come first after sorting.
        for name in task_list[:excess]:
            _remove_task_path(os.path.join(task_output_dir, name))
        task_list = sorted(os.listdir(task_output_dir))
    return task_list


def get_result(task_id: str) -> dict:
    """Load the MXNet and TensorFlow result JSON files of a task.

    The results are read from the first sub-folder (in sorted order) of
    ``task_output_dir/<task_id>``. That folder must contain ``mxnet.json``
    and ``tensorflow.json``.

    Args:
        task_id: name of a first-level entry of ``task_output_dir``. It comes
            directly from the HTTP request and is therefore validated.

    Returns:
        ``{"mxnet": <mxnet.json>, "tensorflow": <tensorflow.json>}``

    Raises:
        ValueError: *task_id* does not name a direct child of
            ``task_output_dir`` (e.g. it contains ``..`` or is absolute).
        FileNotFoundError: the task folder, its result sub-folder or a JSON
            file does not exist (unknown or unfinished task).
        json.JSONDecodeError: a result file is incomplete or invalid.
        TypeError: *task_id* is not a string.
    """
    output_root = os.path.normpath(task_output_dir)
    result_dir = os.path.normpath(os.path.join(output_root, task_id))
    # Refuse anything that would escape the output folder ("../..", "/etc", "").
    if os.path.dirname(result_dir) != output_root:
        raise ValueError("invalid task_id: %r" % (task_id,))

    # Only look at sub-folders; sorting makes the choice deterministic.
    sub_dirs = [
        name for name in sorted(os.listdir(result_dir))
        if os.path.isdir(os.path.join(result_dir, name))
    ]
    if not sub_dirs:
        raise FileNotFoundError("no result folder in %s" % result_dir)
    result_model_dir = os.path.join(result_dir, sub_dirs[0])
    app.logger.debug("reading results from %s", result_model_dir)

    combined_data = {}
    for framework in ("mxnet", "tensorflow"):
        json_path = os.path.join(result_model_dir, framework + ".json")
        with open(json_path, "r", encoding="utf-8") as fh:
            combined_data[framework] = json.load(fh)
    return combined_data


@app.route('/')
def hello_world():
    """Root endpoint; returns a fixed greeting."""
    return 'Hello World!'
import collections
import time

from flask import jsonify


@app.route('/models', methods=['GET'])
def getModels():
    """Return the names of the supported models as a JSON array."""
    models = ['lexnet-cifar10', 'lenet5-fashion-mnist', 'fashion2', 'svhn',
              'lenet5-mnist', 'alexnet-cifar10', 'mobilenet.1.00.224-imagenet',
              'vgg16-imagenet']
    # jsonify() works on every Flask version; returning a bare list from a
    # view is only supported from Flask 2.2 on.
    return jsonify(models)


# Serialises task-id allocation: waitress may serve requests on several
# threads, and runtool creates the task's output only later, in its thread.
_task_id_lock = threading.Lock()
# Ids handed out recently whose output directory may not exist yet.
_recent_task_ids = collections.deque(maxlen=100)


def _new_task_id():
    """Allocate a task id such as ``task_2021-09-07T14:57:00_1``.

    The suffix starts at (number of existing tasks + 1). Once get_task_list()
    prunes old tasks that number stops growing, so the suffix is bumped until
    the id is neither on disk nor recently issued.
    """
    with _task_id_lock:
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
        seq = len(get_task_list()) + 1
        task_id = 'task_' + timestamp + '_' + str(seq)
        while (task_id in _recent_task_ids
               or os.path.exists(os.path.join(task_output_dir, task_id))):
            seq += 1
            task_id = 'task_' + timestamp + '_' + str(seq)
        _recent_task_ids.append(task_id)
        return task_id


@app.route('/run', methods=['POST'])
def run():
    """Start ``runtool`` in a background thread and return the new task id.

    Expects a JSON object with ``exp`` and ``mutate_num``. Both are passed to
    ``runtool`` unchanged, together with the generated task id. The response
    is sent immediately, without waiting for the thread.

    Responses:
        200 ``{"task_id": ...}`` on success.
        400 ``{"error": ...}`` if the body is not a JSON object or a field is
        missing.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'exp' not in data or 'mutate_num' not in data:
        return {'error': "request body must be a JSON object with 'exp' and 'mutate_num'"}, 400
    exp = data['exp']
    mutate_num = data['mutate_num']

    task_id = _new_task_id()
    # Run the job asynchronously so the request returns right away.
    t = threading.Thread(target=runtool, args=(exp, mutate_num, task_id))
    t.start()
    return {'task_id': task_id}


@app.route('/get_task_list', methods=['GET'])
def get_task_list_api():
    """Return ``{"task_list": [...]}`` with the known task names, oldest first."""
    return {'task_list': get_task_list()}


@app.route('/get_task_result', methods=['POST'])
def get_task_result():
    """Return the combined results of the task named by ``task_id`` in the body.

    Any failure (missing or invalid task_id, unknown or unfinished task) is
    reported as ``{"error:": ...}`` with status 200, as before.
    """
    data = request.get_json(silent=True)
    task_id = data.get('task_id') if isinstance(data, dict) else None
    try:
        return get_result(task_id)
    except Exception as err:
        # Clients poll this endpoint until the task finishes, so log quietly.
        app.logger.info("no result for task %r: %s", task_id, err)
        # NOTE: the key really is 'error:' (with a colon). It is kept as-is
        # because existing clients may match on it.
        # Message: "wrong task_id or task not finished".
        return {'error:': '错误的task_id或任务未完成'}


if __name__ == '__main__':
    # app.run()
    from waitress import serve
    serve(app, host="0.0.0.0", port=5000)