# Source code for nndeploy.server.template

import logging
import sys
import tempfile
import zipfile
import requests
from pathlib import Path
from typing import Optional, Tuple
import re
import time
from functools import cached_property
from dataclasses import dataclass
from urllib.parse import urlsplit, urlunsplit

# ---------------------------- Constants ----------------------------

# Timeout, in seconds, handed to every HTTP request made by this module.
REQUEST_TIMEOUT = 60

# Templates live under <cwd>/resources/template; the path is resolved once,
# when the module is imported.
TEMPLATE_ROOT = Path.cwd().joinpath("resources", "template")

# (owner, repo, tag) used when version.py's CONFIG does not name a provider/tag.
DEFAULT_PROVIDER = ("nndeploy", "nndeploy-workflow", "v1.0.0")

# Sentinel version string: "!" asks for owner/repo/tag to be resolved from
# version.py::CONFIG combined with nndeploy.__version__.
DEFAULT_VERSION_STRING = "!"

# Release-download hosts (Gitee URLs are built by swapping the GitHub host)
# and the root of the GitHub REST API.
GITHUB_HOST = "github.com"
GITEE_HOST = "gitee.com"
GITHUB_API_BASE = "https://api.github.com"


# ---------------------------- Provider ----------------------------

@dataclass
class TemplateProvider:
    """Minimal client for the GitHub Releases REST API of one repository.

    Attributes:
        owner: GitHub user/organisation owning the repository.
        repo: Repository name.
    """

    owner: str
    repo: str

    # Page size requested from the "list releases" endpoint (GitHub's maximum
    # is 100; its default of 30 would hide older releases). Not annotated, so
    # it is a plain class attribute, not a dataclass field.
    _PAGE_SIZE = 100

    @property
    def _base(self) -> str:
        """URL of the repository's ``/releases`` collection."""
        return f"{GITHUB_API_BASE}/repos/{self.owner}/{self.repo}/releases"

    def _iter_releases(self, max_pages: int = 10):
        """Yield release objects from the paginated "list releases" endpoint.

        Pages are fetched in order until a short page is returned or
        *max_pages* pages have been read.

        Raises:
            requests.HTTPError: if any page request fails.
        """
        for page in range(1, max_pages + 1):
            r = requests.get(
                self._base,
                params={"per_page": self._PAGE_SIZE, "page": page},
                timeout=REQUEST_TIMEOUT,
            )
            r.raise_for_status()
            batch = r.json()
            yield from batch
            if len(batch) < self._PAGE_SIZE:
                break

    @cached_property
    def latest(self) -> dict:
        """The newest release object (cached per instance).

        Uses ``/releases/latest``; if that endpoint answers 404, falls back to
        the first entry of the release list.

        Raises:
            RuntimeError: if the repository has no releases at all.
            requests.HTTPError: on any other HTTP failure.
        """
        r = requests.get(f"{self._base}/latest", timeout=REQUEST_TIMEOUT)
        if r.status_code == 404:
            first = next(self._iter_releases(max_pages=1), None)
            if first is None:
                raise RuntimeError("No releases found on GitHub")
            return first
        r.raise_for_status()
        return r.json()

    def by_tag(self, ver: str) -> dict:
        """Return the release object whose tag is *ver*.

        Tries ``/releases/tags/{ver}`` first. On 404, scans every release
        (across pages) comparing tags with any leading ``"v"`` removed, so
        ``"1.0.0"`` and ``"v1.0.0"`` find each other in both directions.

        Raises:
            RuntimeError: if no matching release exists.
            requests.HTTPError: on any other HTTP failure.
        """
        r = requests.get(f"{self._base}/tags/{ver}", timeout=REQUEST_TIMEOUT)
        if r.status_code == 404:
            wanted = ver.lstrip("v")
            for rel in self._iter_releases():
                if str(rel.get("tag_name", "")).lstrip("v") == wanted:
                    return rel
            raise RuntimeError(f"Release {ver} not found on GitHub")
        r.raise_for_status()
        return r.json()
# ---------------------------- Download helpers ----------------------------


def _swap_host(url: str, new_host: str) -> str:
    """Return *url* with its network location replaced by *new_host*.

    Scheme, path, query and fragment are preserved unchanged.
    """
    parts = urlsplit(url)
    return urlunsplit((parts.scheme, new_host, parts.path, parts.query, parts.fragment))


def _stream_download_with_retry(url: str, dest_dir: Path, retries: int = 3, delay: float = 2.0) -> None:
    """Download a zip archive from *url* and extract it into *dest_dir*.

    The body is streamed into an anonymous temporary file and extracted only
    after the whole download has completed. Failed attempts are logged and
    retried after *delay* seconds.

    Args:
        url: Direct URL of a ``.zip`` archive.
        dest_dir: Directory the archive is extracted into.
        retries: Total number of attempts; values below 1 are treated as 1.
        delay: Seconds to sleep between attempts.

    Raises:
        Exception: the error of the final attempt is re-raised unchanged.
    """
    # With retries <= 0 the loop would never run and the call would silently
    # "succeed" without downloading anything.
    attempts = max(1, retries)
    for attempt in range(1, attempts + 1):
        try:
            logging.info(f"[Attempt {attempt}] Downloading from {url}")
            with tempfile.TemporaryFile() as tmp:
                # Context-managed response: the streamed connection is released
                # even when the status check or a chunk read fails.
                with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as r:
                    r.raise_for_status()
                    for chunk in r.iter_content(8192):
                        if chunk:  # skip empty chunks
                            tmp.write(chunk)
                tmp.seek(0)
                with zipfile.ZipFile(tmp) as zf:
                    zf.extractall(dest_dir)
            logging.info(f"Template extracted successfully to {dest_dir}")
            return
        except Exception as e:
            logging.warning(f"Download failed on attempt {attempt}: {e}")
            if attempt < attempts:
                time.sleep(delay)
            else:
                logging.error(f"All {attempts} download attempts failed for {url}")
                raise


def _download_with_fallbacks(urls: list[str], dest: Path, retries_each: int = 3, delay: float = 2.0) -> None:
    """Try each URL in order until one downloads and extracts successfully.

    Each URL gets its own retry budget via ``_stream_download_with_retry``.

    Raises:
        RuntimeError: if every URL failed (or *urls* is empty); the last
            underlying error is chained as ``__cause__``.
    """
    last_err: Optional[Exception] = None
    for u in urls:
        try:
            _stream_download_with_retry(u, dest, retries=retries_each, delay=delay)
            return
        except Exception as e:
            last_err = e
            logging.warning(f"Download attempt failed for {u}: {e}")
    raise RuntimeError(f"All download attempts failed. Last error: {last_err}") from last_err


def _download_via_release(rel: dict, dest: Path, asset_name: str = "nndeploy-workflow.zip") -> None:
    """Download and extract a zip asset described by a GitHub release object.

    Asset priority: exact ``asset_name`` → the single ``.zip`` asset → error.
    The Gitee URL (same path, host swapped) is tried before the GitHub
    ``browser_download_url``.

    Raises:
        RuntimeError: if no suitable asset exists or every download failed.
    """
    assets = rel.get("assets", []) or []
    asset = next((a for a in assets if a.get("name") == asset_name), None)
    if asset is None:
        zips = [a for a in assets if str(a.get("name", "")).lower().endswith(".zip")]
        if len(zips) == 1:
            asset = zips[0]
        else:
            names = [a.get("name") for a in assets]
            raise RuntimeError(f"Asset '{asset_name}' not found; available assets: {names}")
    browser_url = asset["browser_download_url"]

    try_urls = []
    try:
        try_urls.append(_swap_host(browser_url, GITEE_HOST))
    except Exception as e:  # best effort: the Gitee URL is optional
        logging.debug(f"Could not build Gitee URL for {browser_url}: {e}")
    try_urls.append(browser_url)
    _download_with_fallbacks(try_urls, dest)


# ---------------------------- Config + version utils ----------------------------


def _load_versions_config() -> dict:
    """Load and validate ``CONFIG`` from the sibling ``version`` module.

    The package-relative ``.version`` is tried first, so an unrelated
    top-level ``version`` module on ``sys.path`` cannot shadow it when this
    file is imported as part of a package. The absolute ``version`` import
    serves script-style execution (where relative imports raise).

    Raises:
        RuntimeError: if neither import works, ``CONFIG`` is not a dict, or
            ``CONFIG["schema"] != 1``.
    """
    try:
        from .version import CONFIG as config  # type: ignore
    except Exception:
        try:
            from version import CONFIG as config  # type: ignore
        except Exception as e:
            raise RuntimeError("version.py not found or import failed") from e
    if not isinstance(config, dict):
        raise RuntimeError("CONFIG not found or not a dict in version.py")
    if config.get("schema") != 1:
        raise RuntimeError("version CONFIG schema unsupported or missing (expect 1)")
    return config


def _normalize_version(ver: Optional[str]) -> Optional[str]:
    """Normalize ``nndeploy.__version__`` into ``'X.Y.Z'``.

    ``'nndeploy 2.6.1'`` → ``'2.6.1'``; other strings are only stripped;
    empty/``None`` → ``None``.
    """
    if not ver:
        return None
    parts = ver.strip().split()
    if len(parts) > 1 and re.match(r"\d+\.\d+\.\d+", parts[-1]):
        return parts[-1]
    return ver.strip()


def _semver_tuple(v: str) -> Tuple[int, int, int]:
    """Parse ``'vX.Y.Z[-suffix|+build]'`` into ``(X, Y, Z)``.

    Missing or non-integer components become 0.
    """
    v = v.strip().lstrip("v")
    # maxsplit must be passed by keyword (positional use is deprecated in 3.13).
    core = re.split(r"[-+]", v, maxsplit=1)[0]
    parts = core.split(".")
    out = []
    for i in range(3):
        try:
            out.append(int(parts[i]))
        except (IndexError, ValueError):
            out.append(0)
    return tuple(out)  # type: ignore[return-value]


def _cmp_ver(a: str, b: str) -> int:
    """Three-way compare two versions: -1 if a < b, 0 if equal, 1 if a > b."""
    ta, tb = _semver_tuple(a), _semver_tuple(b)
    return (ta > tb) - (ta < tb)


# One constraint: optional operator, optional "v", then X.Y.Z.
_CONSTRAINT_RE = re.compile(r"(>=|<=|!=|==|>|<)?\s*v?(\d+\.\d+\.\d+)")


def _match_constraint(nndeploy_ver: str, expr: str) -> bool:
    """Check *nndeploy_ver* against a constraint such as ``'>=0.2.12,<0.3.0'``.

    Comma-separated conditions must all hold. Supported operators are
    ``==`` (the default when omitted), ``!=``, ``>=``, ``<=``, ``>`` and ``<``.
    An unparsable condition makes the whole expression fail.
    """
    for cond in (c.strip() for c in expr.split(",")):
        if not cond:
            continue
        m = _CONSTRAINT_RE.match(cond)
        if not m:
            return False
        op, ver = m.groups()
        cmpres = _cmp_ver(nndeploy_ver, ver)
        ok = {
            "==": cmpres == 0,
            "!=": cmpres != 0,
            ">=": cmpres >= 0,
            "<=": cmpres <= 0,
            ">": cmpres > 0,
            "<": cmpres < 0,
        }[op or "=="]
        if not ok:
            return False
    return True


def _normalize_templates_entry(x: object) -> dict:
    """Normalize a templates config entry into ``{'tag': ..., 'asset': ...}``.

    Supported inputs:
      - ``"v1.0.0"``                                    → ``{'tag': 'v1.0.0'}``
      - ``{"templates": "v1.0.0", "asset": "..."}``
      - ``{"templates": {"tag": "v1.0.0", "asset": "..."}}``
      - ``{"tag": "v1.0.0", "asset": "..."}``
    Anything else (including ``None``) yields ``{}``.
    """
    if x is None:
        return {}
    if isinstance(x, str):
        return {"tag": x}
    if isinstance(x, dict):
        inner = x.get("templates")
        if isinstance(inner, str):
            return {"tag": inner, "asset": x.get("asset")}
        if isinstance(inner, dict):
            return {
                "tag": inner.get("tag") or inner.get("templates"),
                "asset": inner.get("asset") or x.get("asset"),
            }
        if "tag" in x or "asset" in x:
            return {"tag": x.get("tag"), "asset": x.get("asset")}
    return {}


def _resolve_templates_from_config() -> tuple[str, str, str, str]:
    """Resolve ``(owner, repo, tag, asset)`` for the workflow templates bundle.

    The tag/asset entry is looked up in this order:
      1. ``CONFIG["versions"][<nndeploy version>]`` (exact match);
      2. the first ``CONFIG["ranges"]`` rule whose ``"nndeploy"`` constraint
         matches and yields a non-empty entry;
      3. ``CONFIG["fallback"]``.
    Steps 1-2 are skipped when the nndeploy version is unavailable. Missing
    values default to ``DEFAULT_PROVIDER`` and ``"nndeploy-workflow.zip"``.
    """
    cfg = _load_versions_config()

    # default_provider may be flat {"owner", "repo"} or sectioned under "templates".
    raw_prov = cfg.get("default_provider") or {}
    defprov = raw_prov.get("templates") or raw_prov
    owner = defprov.get("owner") or DEFAULT_PROVIDER[0]
    repo = defprov.get("repo") or DEFAULT_PROVIDER[1]

    nndeploy_ver: Optional[str] = None
    try:
        import nndeploy  # type: ignore

        nndeploy_ver = _normalize_version(getattr(nndeploy, "__version__", None))
    except Exception as e:
        # Best effort: without a version only the fallback entry is used.
        logging.debug(f"Could not determine nndeploy version: {e}")

    chosen: dict = {}
    if nndeploy_ver:
        # 1. Exact version table.
        hit = (cfg.get("versions", {}) or {}).get(nndeploy_ver)
        if hit is not None:
            te = hit.get("templates", hit) if isinstance(hit, dict) else hit
            chosen = _normalize_templates_entry(te)
        # 2. Range rules.
        if not chosen:
            for rule in cfg.get("ranges", []) or []:
                expr = rule.get("nndeploy")
                if expr and _match_constraint(nndeploy_ver, expr):
                    chosen = _normalize_templates_entry(rule.get("templates", rule))
                    if chosen:
                        break

    # 3. Fallback entry.
    if not chosen:
        fb = cfg.get("fallback") or {}
        chosen = _normalize_templates_entry(fb.get("templates", fb))

    tag = chosen.get("tag") or DEFAULT_PROVIDER[2]
    asset = chosen.get("asset") or "nndeploy-workflow.zip"
    return owner, repo, str(tag), str(asset)


# ---------------------------- Public manager ----------------------------
class WorkflowTemplateManager:
    """Download and cache the nndeploy workflow template bundle.

    Templates are extracted under ``TEMPLATE_ROOT``. The existence of
    ``TEMPLATE_ROOT / "nndeploy-workflow"`` is treated as a valid cache.
    """

    # "<owner>/<repo>@<tag>" where tag is "X.Y.Z", "vX.Y.Z" or "latest".
    VERSION_RE = re.compile(r"^([\w-]+)/([\w_.-]+)@(v?\d+\.\d+\.\d+|latest)$")

    # Sub-directory whose presence marks a completed extraction.
    _CACHE_MARKER = "nndeploy-workflow"

    @classmethod
    def init_templates(cls, version_string: str = DEFAULT_VERSION_STRING) -> Optional[str]:
        """Ensure the workflow templates are available locally.

        Args:
            version_string: ``"!"`` (default) to resolve owner/repo/tag from
                ``version.py``'s CONFIG, or an explicit
                ``"<owner>/<repo>@<tag>"`` string (tag may be ``latest``).

        Returns:
            The template root directory as a string, or ``None`` on any
            failure (the error is logged with its traceback).
        """
        try:
            # Inside the try so e.g. a permission error also follows the
            # "log and return None" contract instead of propagating.
            TEMPLATE_ROOT.mkdir(parents=True, exist_ok=True)
            return cls._impl(version_string)
        except Exception as exc:
            logging.error("Failed to initialize workflow templates: %s", exc, exc_info=True)
            return None

    @staticmethod
    def _discard_partial(dest_inner: Path) -> None:
        """Best-effort removal of a partially extracted template directory.

        Only the directory's existence is checked as a cache marker, so a
        leftover from a failed extraction would otherwise be reused as if it
        were complete.
        """
        import shutil

        if dest_inner.exists():
            shutil.rmtree(dest_inner, ignore_errors=True)

    @classmethod
    def _impl(cls, ver_str: str) -> str:
        """Resolve the release, download/extract if not cached, return the root.

        Pinned tags are first downloaded directly (Gitee, then GitHub). If that
        fails, or the tag is ``latest``, the GitHub Releases API is used.

        Raises:
            ValueError: if *ver_str* is neither ``"!"`` nor a valid
                ``owner/repo@tag`` string.
            RuntimeError / requests exceptions: if the API path fails.
        """
        # Resolve owner/repo/tag/asset from config ("!") or parse explicit string.
        if ver_str == DEFAULT_VERSION_STRING:
            owner, repo, tag, asset = _resolve_templates_from_config()
        else:
            m = cls.VERSION_RE.match(ver_str)
            if m is None:
                raise ValueError(f"Invalid version string format: {ver_str}")
            owner, repo, tag = m.groups()
            asset = "nndeploy-workflow.zip"

        dest = TEMPLATE_ROOT
        dest_inner = dest / cls._CACHE_MARKER

        # Direct download for pinned (non-latest) tags: no API call needed.
        if tag != "latest":
            if dest_inner.exists():
                logging.info(f"use cached templates at {dest_inner}")
                return str(dest)

            gitee_url = f"https://{GITEE_HOST}/{owner}/{repo}/releases/download/{tag}/{asset}"
            github_url = f"https://{GITHUB_HOST}/{owner}/{repo}/releases/download/{tag}/{asset}"
            try:
                logging.info(f"Attempting direct download (Gitee→GitHub): {gitee_url} / {github_url} -> {dest}")
                dest.mkdir(parents=True, exist_ok=True)
                _download_with_fallbacks([gitee_url, github_url], dest)
                return str(dest)
            except Exception as err:
                # dest_inner did not exist before this attempt, so anything
                # there now is a partial extraction.
                cls._discard_partial(dest_inner)
                logging.warning(f"Direct download failed. Falling back to GitHub API: {err}")

        # API fallback (latest, or when the direct download failed).
        provider = TemplateProvider(owner, repo)
        rel = provider.latest if tag == "latest" else provider.by_tag(tag)
        semver = rel["tag_name"].lstrip("v")

        if not dest_inner.exists():
            logging.info(f"Downloading templates via API: {owner}/{repo}@{semver} -> {dest}")
            dest.mkdir(parents=True, exist_ok=True)
            try:
                _download_via_release(rel, dest, asset_name=asset)
            except Exception:
                cls._discard_partial(dest_inner)
                raise
        return str(dest)