123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
import json
import os
import shutil
import threading

import yaml
from flask import Flask, request

from tool import runtool
app = Flask(__name__)

# Task output lives next to this script, independent of the working directory.
base_dir = os.path.dirname(os.path.abspath(__file__))
task_output_dir = os.path.join(base_dir, "output")
# Create the output directory up front so that listing it (get_task_list)
# does not fail with FileNotFoundError on a fresh checkout.
os.makedirs(task_output_dir, exist_ok=True)

# NOTE: config.yml is resolved relative to the current working directory
# (not base_dir), so the server must be started from the project directory.
# The config is plain data, so safe_load is sufficient and never constructs
# arbitrary Python objects.
with open('config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)
# Only the 'default' section is applied to the Flask config.
app.config.update(config['default'])
def _list_task_dirs():
    """Return the sorted names of the sub-directories of task_output_dir."""
    return sorted(
        name for name in os.listdir(task_output_dir)
        if os.path.isdir(os.path.join(task_output_dir, name))
    )


def get_task_list():
    """Return the sorted list of tasks, pruning the oldest ones over the limit.

    A task is a first-level sub-directory of ``task_output_dir``; plain files
    are ignored. Task ids start with ``task_<YYYY-mm-ddTHH:MM:SS>`` (see
    ``run``), so sorting by name orders tasks chronologically (ties within the
    same second fall back to comparing the unpadded numeric suffix as text).

    If there are more than ``app.config['max_task_num']`` tasks, the oldest
    ones are deleted until at most ``max_task_num`` remain.

    Returns:
        The sorted list of task directory names remaining after pruning.
    """
    # Create the directory on first use so os.listdir() cannot raise.
    os.makedirs(task_output_dir, exist_ok=True)
    task_list = _list_task_dirs()

    max_task_num = app.config['max_task_num']
    excess = len(task_list) - max_task_num
    if excess > 0:
        # The list is sorted oldest-first, so the first `excess` entries are
        # the oldest tasks. rmtree avoids building a shell command from a
        # path (quoting/injection issues, non-POSIX hosts); ignore_errors
        # keeps the previous best-effort semantics of `rm -rf`.
        for name in task_list[:excess]:
            shutil.rmtree(os.path.join(task_output_dir, name), ignore_errors=True)
        # Re-scan: a deletion may have failed or a new task may have appeared.
        task_list = _list_task_dirs()
    return task_list
def _load_json(path):
    """Read and parse the JSON file at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def get_result(task_id: str) -> dict:
    """Load the MXNet and TensorFlow result files of a task.

    The results are read from the first (in sorted order) sub-directory of
    ``task_output_dir/<task_id>``, which must contain ``mxnet.json`` and
    ``tensorflow.json``.

    Args:
        task_id: Name of the task directory. It comes from the HTTP request,
            so it must be a plain directory name (no path separators, not
            ``.``/``..``).

    Returns:
        ``{"mxnet": <mxnet.json contents>, "tensorflow": <tensorflow.json contents>}``

    Raises:
        ValueError: if ``task_id`` is not a plain name, if the task directory
            has no sub-directory yet, or if a result file is not valid JSON.
        OSError: if the task directory or a result file is missing/unreadable.
    """
    # task_id is client-supplied: reject anything that could escape
    # task_output_dir via os.path.join (separators, "..", absolute paths).
    if (not isinstance(task_id, str)
            or task_id in ('', '.', '..')
            or os.sep in task_id
            or (os.altsep and os.altsep in task_id)):
        raise ValueError(f"invalid task_id: {task_id!r}")

    result_dir = os.path.join(task_output_dir, task_id)
    # Results live in a sub-directory of the task directory; sort so the
    # choice is deterministic if there happen to be several.
    model_dirs = sorted(
        entry for entry in os.listdir(result_dir)
        if os.path.isdir(os.path.join(result_dir, entry))
    )
    if not model_dirs:
        raise ValueError(f"task {task_id!r} has no result directory yet")
    result_model_dir = os.path.join(result_dir, model_dirs[0])

    return {
        "mxnet": _load_json(os.path.join(result_model_dir, "mxnet.json")),
        "tensorflow": _load_json(os.path.join(result_model_dir, "tensorflow.json")),
    }
-
-
@app.route('/')
def hello_world():
    """Root endpoint: respond with a fixed greeting string."""
    greeting = 'Hello World!'
    return greeting
@app.route('/models', methods=['GET'])
def getModels():
    """Return the list of available model names.

    NOTE(review): a view returning a bare list relies on Flask >= 2.2, which
    serializes it to a JSON array.
    """
    model_names = [
        'lexnet-cifar10',
        'lenet5-fashion-mnist',
        'fashion2',
        'svhn',
        'lenet5-mnist',
        'alexnet-cifar10',
        'mobilenet.1.00.224-imagenet',
        'vgg16-imagenet',
    ]
    return model_names
- import time
@app.route('/run', methods=['POST'])
def run():
    """Start ``runtool`` for a new task in a background thread.

    Expects a JSON body containing ``exp`` and ``mutate_num``; both are passed
    through unchanged to ``runtool`` together with the generated task id.

    Returns:
        ``{'task_id': <id>}`` immediately, without waiting for the run, or
        ``({'error': ...}, 400)`` if the body is not a JSON object containing
        both fields.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # so it can be reported as a 400 below rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'exp' not in data or 'mutate_num' not in data:
        return {'error': "request body must be a JSON object with 'exp' and 'mutate_num'"}, 400
    exp = data['exp']
    mutate_num = data['mutate_num']

    # Task id, e.g. task_2021-09-07T14:57:00_1. The timestamp prefix makes the
    # ids sort chronologically; calling get_task_list() here also prunes the
    # oldest tasks beyond max_task_num.
    timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
    task_id = 'task_' + timestamp + '_' + str(len(get_task_list()) + 1)

    # Run in a background thread so the request returns right away.
    worker = threading.Thread(target=runtool, args=(exp, mutate_num, task_id))
    worker.start()

    return {'task_id': task_id}
@app.route('/get_task_list', methods=['GET'])
def get_task_list_api():
    """HTTP wrapper around get_task_list(), which may also prune old tasks."""
    tasks = get_task_list()
    payload = {'task_list': tasks}
    return payload
@app.route('/get_task_result', methods=['POST'])
def get_task_result():
    """Return the combined MXNet/TensorFlow results of a task.

    Expects a JSON body ``{"task_id": ...}``. On success returns the dict
    built by ``get_result``. If the body or task id is missing, the id is
    invalid/unknown, or the result files are not (yet) readable, returns an
    error payload instead.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    task_id = data.get('task_id') if isinstance(data, dict) else None

    try:
        return get_result(task_id)
    except (OSError, ValueError, IndexError, TypeError) as exc:
        # Expected failures: bad/unknown id, task still running, missing or
        # malformed result files. Anything else is a genuine bug and is left
        # to surface as a 500 instead of being silently swallowed.
        app.logger.warning("get_task_result(%r) failed: %s", task_id, exc)
        # NOTE: the response key really is 'error:' (with a trailing colon);
        # kept unchanged because clients may already look it up by that name.
        return {'error:': '错误的task_id或任务未完成'}
-
-
if __name__ == '__main__':
    # Serve through waitress instead of Flask's built-in app.run(),
    # listening on all interfaces, port 5000.
    from waitress import serve

    serve(app, port=5000, host="0.0.0.0")
|