Source code for nndeploy.diffusion.diffusers_info.extract_example_doc_string

#!/usr/bin/env python3
"""
Traverse all pipeline_*.py files under the given pipelines directory (default: src/diffusers/pipelines), extract the EXAMPLE_DOC_STRING from each file,
and write them into the given output module (default: diffusers_pipeline_md.py) in a more readable format, preserving directory order and indentation.
"""

import argparse
import os
import re
import textwrap

def find_pipeline_files(directory: str):
    """Find all ``pipeline_*.py`` files under *directory*, in a deterministic order.

    The tree is walked top-down. Sub-directory and file names are sorted
    alphabetically at every level, so the files of a directory come before
    those of its sub-directories, and siblings appear in name order.

    Args:
        directory: Root directory to search.

    Returns:
        A list of file paths, each starting with *directory*. They are relative
        to the current working directory when *directory* itself is relative.
    """
    pipeline_files = []
    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place also makes os.walk descend into
        # sub-directories in alphabetical order.
        dirs.sort()
        files.sort()
        pipeline_files.extend(
            os.path.join(root, name)
            for name in files
            if name.startswith("pipeline_") and name.endswith(".py")
        )
    return pipeline_files
def extract_example_doc_string(file_path: str):
    """Return the body of the ``EXAMPLE_DOC_STRING = \"\"\"...\"\"\"`` literal in *file_path*.

    The captured text is dedented (common leading whitespace removed) and then
    stripped. This keeps the indentation of every line consistent relative to
    the others.

    Args:
        file_path: Path to a Python source file, read as UTF-8.

    Returns:
        The cleaned docstring text. Returns None if the file has no such
        assignment, or if it cannot be read or decoded. A read error is
        reported on stdout rather than raised.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading file {file_path}: {e}")
        return None

    # Match EXAMPLE_DOC_STRING = """ ... """ (non-greedy, across newlines).
    pattern = r'EXAMPLE_DOC_STRING\s*=\s*"""(.*?)"""'
    match = re.search(pattern, content, re.DOTALL)
    if match is None:
        return None
    # Dedent before stripping. A bare strip() would remove only the first
    # line's indentation and leave the remaining lines shifted relative to it.
    return textwrap.dedent(match.group(1)).strip()
def write_dict_pretty(data: dict, output_file: str):
    """Write *data* to *output_file* as a Python module defining ``DIFFUSERS_PIPELINE_MD``.

    Each key is written as a double-quoted string literal. Each value is
    written as a triple-quoted string whose lines are indented by 8 spaces for
    readability. An empty value is written as ``''``. The whole text is
    rendered before the file is opened, so a failure while rendering does not
    leave a truncated output file behind.

    Args:
        data: Mapping of file path to example docstring text.
        output_file: Destination path; it is overwritten (UTF-8).
    """
    lines = [
        "# Automatically generated by extract_example_doc_string.py\n",
        "DIFFUSERS_PIPELINE_MD = {\n",
    ]
    for filename, docstring in data.items():
        # Escape backslashes and double quotes so Windows-style paths
        # (e.g. "pipelines\blip_diffusion") are not re-read as escape sequences.
        key = filename.replace("\\", "\\\\").replace('"', '\\"')
        lines.append(f'    "{key}": (\n')
        # Preserve the original indentation of docstring, and indent the whole block by 8 spaces
        doc_lines = docstring.splitlines()
        if doc_lines:
            lines.append("        '''\n")
            for line in doc_lines:
                # The text is copied verbatim from a non-raw string literal, so
                # its escapes keep their meaning. Only an embedded ''' would end
                # the generated literal early, so escape that sequence.
                safe = line.rstrip().replace("'''", "\\'\\'\\'")
                lines.append(f"        {safe}\n")
            lines.append("        '''\n")
        else:
            lines.append("        ''\n")
        lines.append("    ),\n")
    lines.append("}\n")

    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(lines)
def main():
    """Command-line entry point: collect EXAMPLE_DOC_STRINGs and write them out.

    Parses ``--pipelines_dir`` and ``--output_file``, then scans the directory
    for ``pipeline_*.py`` files. The non-empty example docstrings are stored,
    keyed by path relative to the current working directory, and written with
    ``write_dict_pretty``.

    Exits through ``argparse`` (status 2) if the pipelines directory does not
    exist. Exits with status 1 and a message on stderr if the output file
    cannot be written.
    """
    parser = argparse.ArgumentParser(description="Extract EXAMPLE_DOC_STRING from pipeline files.")
    parser.add_argument(
        "--pipelines_dir",
        type=str,
        default="src/diffusers/pipelines",
        help="Directory containing pipeline_*.py files.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="diffusers_pipeline_md.py",
        help="Output file to write the extracted docstrings.",
    )
    args = parser.parse_args()

    # os.walk silently yields nothing for a missing directory. That would
    # overwrite the output with an empty dictionary, so fail loudly instead.
    if not os.path.isdir(args.pipelines_dir):
        parser.error(f"pipelines directory not found: {args.pipelines_dir}")

    example_docstrings = {}
    for file_path in find_pipeline_files(args.pipelines_dir):
        docstring = extract_example_doc_string(file_path)
        if docstring:
            # Keys are paths relative to the current working directory.
            example_docstrings[os.path.relpath(file_path, start=".")] = docstring

    try:
        write_dict_pretty(example_docstrings, args.output_file)
    except OSError as e:
        # Report the failure and exit non-zero instead of returning status 0.
        raise SystemExit(f"Error writing file {args.output_file}: {e}") from e
    print(f"Wrote {len(example_docstrings)} EXAMPLE_DOC_STRING entries to {args.output_file}")
# Run the extractor only when executed as a script, not when imported.
if __name__ == "__main__":
    main()