# Source code for nndeploy.dag.run_json

# run_json.py

from __future__ import annotations
import json
import time
import logging
import traceback
from typing import Dict, Any, Tuple, List
import argparse
import shutil
from pathlib import Path

import nndeploy.base
import nndeploy.device
from .node import add_global_import_lib, import_global_import_lib
from .graph_runner import GraphRunner
     
       
def parse_args(argv: "List[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the JSON graph runner.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what a zero-argument
            call has always done. Passing an explicit list allows the
            parser to be used programmatically.

    Returns:
        The ``argparse.Namespace`` holding ``json_file``, ``name``,
        ``task_id``, ``node_param``, ``parallel_type``, ``dump``, ``debug``,
        ``plugin`` and ``resources``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json_file",
        type=str,
        required=True,
        help="Path to the graph JSON file to run.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="",
        help="Name passed to GraphRunner.run.",
    )
    parser.add_argument(
        "--task_id",
        type=str,
        default="",
        help="Task id passed to GraphRunner.run.",
    )
    parser.add_argument(
        "--node_param",
        "-np",
        type=str,
        nargs="*",
        default=[],
        help="Node parameters, each given as node_name:param_key:param_value.",
    )
    parser.add_argument(
        "--parallel_type",
        type=str,
        default="",
        help="Parallel type name, e.g. kParallelTypeSequential.",
    )
    parser.add_argument("--dump", action="store_true", help="Set the dump flag.")
    parser.add_argument("--debug", action="store_true", help="Set the debug flag.")
    parser.add_argument(
        "--plugin",
        type=str,
        nargs="*",
        default=[],
        help="Plugin paths to register and import before running.",
    )
    parser.add_argument(
        "--resources",
        type=str,
        default="",
        help="Resource directory to copy into the current working directory.",
    )
    return parser.parse_args(argv)
class GraphRunnerArgs:
    """Mutable container for the options of one graph run."""

    def __init__(self):
        """Initialise every option to its empty value."""
        # Scalar string options.
        self.json_file: str = ""
        self.name: str = ""
        self.task_id: str = ""
        self.parallel_type: str = ""
        self.resources: str = ""
        # Boolean switches.
        self.dump: bool = False
        self.debug: bool = False
        # Per-instance containers (a fresh object for every instance).
        self.input_path: dict = {}
        self.output_path: dict = {}
        self.node_param: List[str] = []
        self.plugin: List[str] = []

    def parse_args(self):
        """Fill this object from the command line.

        Calls the module-level ``parse_args()`` and copies its values over.
        ``node_param`` and ``plugin`` entries are appended to the lists
        already held by this object rather than replacing them.
        """
        parsed = parse_args()
        self.json_file = parsed.json_file
        self.name = parsed.name
        self.task_id = parsed.task_id
        self.node_param.extend(parsed.node_param)
        self.plugin.extend(parsed.plugin)
        self.parallel_type = parsed.parallel_type
        self.dump = parsed.dump
        self.debug = parsed.debug
        self.resources = parsed.resources
# def main(): # args = GraphRunnerArgs() # args.parse_args() # if args.plugin != []: # for plugin_path in args.plugin: # add_global_import_lib(plugin_path) # import_global_import_lib() # graph_json_str = "" # with open(args.json_file, "r") as f: # graph_json_str = f.read() # gr = GraphRunner() # time_profiler_map, results, _, _ = gr.run(graph_json_str, args.name, args.task_id, args) # # print(time_profiler_map) # # print(results)
def _copy_resources(resources: str, logger: logging.Logger) -> bool:
    """Copy the resource directory into the current working directory.

    The destination is ``Path.cwd() / <name of the resolved source>``.

    Args:
        resources: Path of the resource directory to copy.
        logger: Logger that receives progress, warning and error messages.

    Returns:
        ``True`` when the run may proceed (the directory was copied, or the
        destination already exists and the copy was skipped); ``False`` when
        the source is missing / not a directory or the copy failed.
    """
    src = Path(resources).resolve()
    dst = Path.cwd() / src.name
    logger.info("Preparing to copy resources from %s -> %s", src, dst)

    # is_dir() is already False for a path that does not exist.
    if not src.is_dir():
        logger.error("Resource directory not found or not a directory: %s", src)
        return False

    if dst.exists():
        # Skip the copy but keep going: the previous code returned from
        # main() here, so a second run with the same --resources (or a
        # resource directory that already lives in the working directory)
        # never executed the graph even though the message only announces
        # a skipped copy.
        logger.warning(
            "Target directory already exists: %s. "
            "Skipping copy. Please remove or rename it manually if you want to overwrite.",
            dst,
        )
        return True

    try:
        shutil.copytree(src, dst)
    except OSError as e:  # shutil.Error is a subclass of OSError
        logger.error("Failed to copy resources from %s to %s: %s", src, dst, e)
        return False
    logger.info("Resources copied successfully to %s", dst)
    return True


def main() -> None:
    """Command-line entry point: prepare resources, load plugins, run a graph.

    Steps, in order:
      1. If ``--resources`` is given, copy that directory into the current
         working directory (an already existing target is left untouched).
      2. If ``--plugin`` is given, register each entry with
         ``add_global_import_lib`` and then call ``import_global_import_lib``.
      3. Read ``--json_file`` and hand its text to ``GraphRunner.run``.

    Failures are logged, not raised; the function returns ``None`` on every
    path.
    """
    args = GraphRunnerArgs()
    args.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logger = logging.getLogger("run_json")

    # Step 1: copy resources directory
    if args.resources and not _copy_resources(args.resources, logger):
        return

    # Step 2: load plugins
    if args.plugin:
        for plugin_path in args.plugin:
            add_global_import_lib(plugin_path)
        import_global_import_lib()
        logger.info("Loaded plugins: %s", args.plugin)

    # Step 3: load JSON file and run
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(args.json_file, "r", encoding="utf-8") as f:
            graph_json_str = f.read()
    except (OSError, ValueError) as e:  # ValueError covers decode errors
        logger.error("Failed to read JSON file %s: %s", args.json_file, e)
        return

    gr = GraphRunner()
    try:
        # The values returned by run() are not used here.
        gr.run(graph_json_str, args.name, args.task_id, args)
    except Exception as e:  # top-level boundary: log with traceback, do not crash
        logger.exception("GraphRunner execution failed: %s", e)
    else:
        logger.info("GraphRunner execution completed successfully.")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()