#!/usr/bin/env python3
"""
Traverse all pipeline_*.py files under the specified src/diffusers/pipelines directory, extract the EXAMPLE_DOC_STRING from each file,
and write them into the specified diffusers_pipeline_md.py in a more readable format, preserving directory order and indentation.
"""
import os
import re
import argparse
def find_pipeline_files(directory: str) -> list:
    """Collect every ``pipeline_*.py`` file found under *directory*.

    The tree is walked top-down, visiting sub-directories and file names in
    sorted order so the result is deterministic.

    Args:
        directory: Root directory to search recursively.

    Returns:
        A list of paths to the matching files. Each path starts with
        *directory* exactly as it was passed in (it is neither made absolute
        nor made relative to *directory*). The list is empty when nothing
        matches or the directory cannot be walked.
    """
    pipeline_files = []
    for root, dirs, files in os.walk(directory):
        # os.walk reads `dirs` back after yielding, so sorting it in place
        # fixes the order in which sub-directories are descended into.
        dirs.sort()
        for name in sorted(files):
            if name.startswith("pipeline_") and name.endswith(".py"):
                # `root` already begins with `directory`, so joining it with
                # the file name gives the directory-prefixed path directly.
                pipeline_files.append(os.path.join(root, name))
    return pipeline_files
def write_dict_pretty(data: dict, output_file: str):
    """Write *data* to *output_file* as a readable Python module.

    The generated module defines one dictionary, ``DIFFUSERS_PIPELINE_MD``.
    Each key is written as a double-quoted string and each value as a
    parenthesised triple-single-quoted string whose lines keep their own
    indentation and are additionally indented by 8 spaces. Trailing
    whitespace on every docstring line is dropped.

    Args:
        data: Mapping of file name to docstring text. Iteration order is
            preserved in the output.
        output_file: Path of the file to create or overwrite (UTF-8).

    Raises:
        OSError: If the output file cannot be opened or written.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("# Automatically generated by extract_model_paths.py\n")
        f.write("DIFFUSERS_PIPELINE_MD = {\n")
        for filename, docstring in data.items():
            # Escape backslashes and double quotes so that keys such as
            # Windows paths remain valid, unchanged string literals.
            key = filename.replace("\\", "\\\\").replace('"', '\\"')
            f.write(f'    "{key}": (\n')
            # Preserve the original indentation of docstring, and indent the
            # whole block by 8 spaces.
            doc_lines = docstring.splitlines()
            if doc_lines:
                f.write("        '''\n")
                for line in doc_lines:
                    # Backslashes first, then any embedded ''' that would
                    # otherwise terminate the generated literal early.
                    escaped = line.rstrip().replace("\\", "\\\\")
                    escaped = escaped.replace("'''", "\\'\\'\\'")
                    f.write(f"        {escaped}\n")
                f.write("        '''\n")
            else:
                f.write("        ''\n")
            f.write("    ),\n")
        f.write("}\n")
def main():
    """Command-line entry point: extract docstrings and write the output module.

    Parses ``--pipelines_dir`` and ``--output_file``, runs
    ``extract_example_doc_string`` on every file returned by
    ``find_pipeline_files``, keeps the non-empty results keyed by the file's
    path relative to the current working directory, and writes them with
    ``write_dict_pretty``.

    Raises:
        SystemExit: With the error message (non-zero exit status) when the
            output file cannot be written.
    """
    parser = argparse.ArgumentParser(description="Extract EXAMPLE_DOC_STRING from pipeline files.")
    parser.add_argument(
        "--pipelines_dir",
        type=str,
        default="src/diffusers/pipelines",
        help="Directory containing pipeline_*.py files.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="diffusers_pipeline_md.py",
        help="Output file to write the extracted docstrings.",
    )
    args = parser.parse_args()

    example_docstrings = {}
    for file_path in find_pipeline_files(args.pipelines_dir):
        # NOTE(review): extract_example_doc_string is not defined in the part
        # of this file that is visible here — confirm it exists in the module.
        docstring = extract_example_doc_string(file_path)
        if docstring:
            rel_path = os.path.relpath(file_path, start=".")
            example_docstrings[rel_path] = docstring

    try:
        write_dict_pretty(example_docstrings, args.output_file)
    except OSError as e:
        # Exit non-zero so callers (shells, CI) can detect the failed write;
        # SystemExit prints the message to stderr.
        raise SystemExit(f"Error writing file {args.output_file}: {e}") from e
    else:
        print(f"Wrote {len(example_docstrings)} EXAMPLE_DOC_STRING entries to {args.output_file}")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()