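"""Dataset preparation utilities.

Walks directories of Python sources, splits each file into one file per
function (top-level functions and class methods), and serializes the labelled
snippets into train/test/eval JSON splits, where a snippet's label is the name
of the sub-directory it lives in.
"""
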
import ast
import json
import logging
import os
import random

logger = logging.getLogger(__name__)

def walk_files(path, endpoint='.py'):
    """Recursively collect all files under ``path`` whose names end with ``endpoint``."""
    file_list = []
    for root, dirs, files in os.walk(path):
        for file in files:
            file_path = os.path.join(root, file)
            if file_path.endswith(endpoint):
                file_list.append(file_path)
    return file_list

def write_to_json(file_path, output_path):
    """Label every snippet under ``file_path`` with the name of its sub-directory
    and write a 70/20/10 train/test/eval split to ``output_path``."""
    types = os.listdir(file_path)
    json_input = []
    for type_index, directory in enumerate(types):
        for file_name in walk_files(os.path.join(file_path, directory)):
            try:
                with open(file_name, 'r', encoding='utf-8') as file:
                    content = file.read()
            except Exception:
                logger.error("failed to read: %s", file_name)
                continue
            json_input.append({
                "func": content,
                "type": directory,
                "type_index": type_index,
                "path": file_name
            })
    random.shuffle(json_input)
    length = len(json_input)
    with open(output_path + "train.jsonl", 'w+') as train_out:
        train_out.write(json.dumps(json_input[:int(7 * length / 10)]))
    with open(output_path + "test.jsonl", 'w+') as test_out:
        test_out.write(json.dumps(json_input[int(7 * length / 10):int(9 * length / 10)]))
    with open(output_path + "eval.jsonl", 'w+') as eval_out:
        eval_out.write(json.dumps(json_input[int(9 * length / 10):]))

def write_content_to_file(content, file_path):
    """Write ``content`` to ``file_path``, creating parent directories as needed."""
    parent_dir = os.path.dirname(file_path)
    if parent_dir and not os.path.exists(parent_dir):
        os.makedirs(parent_dir)
    try:
        with open(file_path, 'w+', encoding='utf-8') as file_out:
            file_out.write(content)
    except Exception:
        logger.error("failed to write: %s", file_path)

def split_file(file_dir, output_dir, endpoint=".py"):
    """Split every ``endpoint`` file under ``file_dir`` into one file per function
    (top-level functions and class methods) under ``output_dir``."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    for file_path in walk_files(file_dir, endpoint):
        with open(file_path, 'r', encoding='utf8') as file:
            content = file.read()
        try:
            root = ast.parse(content)
        except Exception as e:
            print("Error details:", e.__class__.__name__, e, file_path)
            continue
        file_id = 1
        for node in root.body:
            if isinstance(node, ast.FunctionDef):
                new_file_name = file_path.replace(file_dir, output_dir).replace(".py", "_" + str(file_id) + ".py")
                write_content_to_file(ast.get_source_segment(content, node), new_file_name)
                file_id += 1
            elif isinstance(node, ast.ClassDef):
                for son in node.body:
                    if isinstance(son, ast.FunctionDef):
                        new_file_name = file_path.replace(file_dir, output_dir).replace(".py", "_" + str(file_id) + ".py")
                        write_content_to_file(ast.get_source_segment(content, son), new_file_name)
                        file_id += 1

def split_file_by_func(file_path):
    """
    :param file_path: path to a Python source file
    :return: dict mapping func_name to (func_source, ast.FunctionDef node);
             class methods are keyed as "ClassName.method_name"
    """
    with open(file_path, 'r', encoding='utf8') as file:
        content = file.read()
    func_content = {}
    try:
        root = ast.parse(content)
        for node in root.body:
            if isinstance(node, ast.FunctionDef):
                func_content[node.name] = (ast.get_source_segment(content, node), node)
            elif isinstance(node, ast.ClassDef):
                class_name = node.name
                for son in node.body:
                    if isinstance(son, ast.FunctionDef):
                        func_name = class_name + '.' + son.name
                        func_content[func_name] = (ast.get_source_segment(content, son), son)
    except Exception as e:
        logger.error(e.__class__.__name__ + " " + file_path)
    return func_content

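# Illustrative example (hypothetical names): for a file defining a top-level
# function "baz" and a class "Foo" with a method "bar", split_file_by_func
# returns {"baz": (<source>, <ast.FunctionDef>), "Foo.bar": (<source>, <ast.FunctionDef>)}.
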
def load_json(json_file):
    with open(json_file, 'r') as load_f:
        load_dict = json.load(load_f)
    return load_dict


def write_json(json_file, data):
    with open(json_file, 'w') as file:
        file.write(json.dumps(data))

if __name__ == '__main__':
    write_to_json("dataset/origin", "dataset/")
    # for directory in os.listdir("dataset/origin"):
    #     print(directory, len(walk_files("dataset/origin/" + directory)))
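    # Hedged usage sketch (assumed workflow, not part of the original script):
    # split_file("some_repo", "dataset/origin/some_label") would first explode
    # every function in "some_repo" into its own .py file under dataset/origin,
    # after which write_to_json labels the snippets and writes the splits.
    # The paths "some_repo" and "some_label" are illustrative only.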