nndeploy.api_llm.llm_chat 源代码

# llm_chat_node.py
# -*- coding: utf-8 -*-

import os
import json
import time
import logging
import random
import typing as T
from dataclasses import dataclass

import requests  # Used for general HTTP requests (OpenAI-compatible /v1/chat/completions)

import nndeploy.dag
from nndeploy.base import Status


# ──────────────────────────────────────────────────────────────────────────────
# Parameter structure
# ──────────────────────────────────────────────────────────────────────────────

[文档]@dataclass class LLMParams: api_base: str = "https://api.openai.com" api_key_env: str = "OPENAI_API_KEY" model: str = "gpt-4o-mini" temperature: float = 0.2 max_tokens: int = 512 system_prompt: str = "You are a helpful assistant." timeout_s: float = 60.0 retry: int = 2 retry_backoff_s: float = 1.2 stream: bool = False response_json_mode: bool = False json_schema: T.Optional[dict] = None
# ────────────────────────────────────────────────────────────────────────────── # Utility functions # ────────────────────────────────────────────────────────────────────────────── def _read_api_key(env_name: str) -> str: key = os.getenv(env_name, "").strip() if not key: raise RuntimeError( f"[LLMChatNode] Missing API key. Please set environment variable: {env_name}" ) return key def _normalize_messages( user_input: T.Union[str, dict, list], system_prompt: str ) -> list: """ Normalize user input into an OpenAI-compatible messages array: - str: treated as one user message - list[dict]: if already in messages format (with role/content), use directly; otherwise convert to string - dict: if contains "messages" list, use directly; otherwise convert whole dict to string Ensure that there is always one system message (add if missing). """ sys_msg = {"role": "system", "content": system_prompt} if isinstance(user_input, str): return [sys_msg, {"role": "user", "content": user_input}] if isinstance(user_input, list): ok = all(isinstance(m, dict) and "role" in m and "content" in m for m in user_input) if ok: has_system = any(m.get("role") == "system" for m in user_input) return ([sys_msg] + user_input) if not has_system else user_input return [sys_msg, {"role": "user", "content": json.dumps(user_input, ensure_ascii=False)}] if isinstance(user_input, dict): if "messages" in user_input and isinstance(user_input["messages"], list): msgs = user_input["messages"] has_system = any(isinstance(m, dict) and m.get("role") == "system" for m in msgs) return ([sys_msg] + msgs) if not has_system else msgs return [sys_msg, {"role": "user", "content": json.dumps(user_input, ensure_ascii=False)}] # Fallback for other types return [sys_msg, {"role": "user", "content": str(user_input)}] def _calc_backoff(attempt: int, base: float = 1.2, jitter: float = 0.3) -> float: """Exponential backoff with jitter.""" return (base ** attempt) * (1.0 + random.uniform(-jitter, jitter)) def _should_stop(node) -> bool: """Respect NNDeploy interruption if available.""" try: return hasattr(node, "checkInterruptStatus") and node.checkInterruptStatus() except Exception: return False def _mask_key(key: str) -> str: """Return a masked preview like 'sk-****abcd' (no real key leakage).""" if not key: return "" tail = key[-4:] prefix = "sk-" if key.startswith("sk-") else "" return f"{prefix}****{tail}" # ────────────────────────────────────────────────────────────────────────────── # Core HTTP caller # ──────────────────────────────────────────────────────────────────────────────
[文档]def call_llm( params: LLMParams, messages: list, *, api_key_override: str = None, node_for_interrupt: T.Optional[nndeploy.dag.Node] = None, emit_usage_to_log: bool = True, stream_print: bool = False, # print streaming pieces to stdout proxies: T.Optional[dict] = None, # e.g. {"http": "...", "https": "..."} verify: T.Union[bool, str, None] = None # True/False or CA bundle path ) -> str: """ Call an OpenAI-compatible Chat Completions endpoint. - Supports non-stream and stream modes. - Respects 429 Retry-After; uses exponential backoff with jitter. - Logs usage if available. Returns the final concatenated content as a string. """ api_key = (api_key_override or _read_api_key(params.api_key_env)).strip() url = params.api_base.rstrip("/") + "/v1/chat/completions" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } payload = { "model": params.model, "messages": messages, "temperature": params.temperature, "max_tokens": params.max_tokens, "stream": bool(params.stream), } # JSON strict mode (if provider supports OpenAI response_format) if params.response_json_mode: payload["response_format"] = {"type": "json_object"} if isinstance(params.json_schema, dict): # Some providers may ignore 'json_schema'; keep best-effort payload["response_format"]["json_schema"] = params.json_schema # Proxies & TLS verification (when not provided by caller) if proxies is None: proxies = { "http": os.getenv("HTTP_PROXY") or os.getenv("http_proxy"), "https": os.getenv("HTTPS_PROXY") or os.getenv("https_proxy"), } if verify is None: verify = os.getenv("REQUESTS_CA_BUNDLE", True) # custom CA bundle path or bool # Helper to check interrupt between retries/stream chunks def _interrupted() -> bool: return _should_stop(node_for_interrupt) last_err = None for attempt in range(params.retry + 1): if _interrupted(): return "" # Return empty on interrupt try: if not payload["stream"]: # Non-streaming resp = requests.post( url, headers=headers, json=payload, timeout=params.timeout_s, proxies=proxies, verify=verify ) if resp.status_code == 429: retry_after = resp.headers.get("Retry-After") sleep_s = float(retry_after) if retry_after else _calc_backoff(attempt + 1, base=params.retry_backoff_s) logging.warning(f"[LLMChatNode] 429 Too Many Requests; sleeping {sleep_s:.2f}s") time.sleep(sleep_s) continue if resp.status_code >= 400: raise RuntimeError(f"[LLMChatNode] HTTP {resp.status_code}: {resp.text[:800]}") data = resp.json() # optional usage logging usage = data.get("usage", {}) if emit_usage_to_log and usage: logging.info(f"[LLMChatNode] usage={usage}") choices = data.get("choices", []) if not choices: raise RuntimeError(f"[LLMChatNode] Empty choices: {data}") msg = choices[0].get("message") or {} content = msg.get("content") if not isinstance(content, str): content = json.dumps(content, ensure_ascii=False) return content else: # Streaming chunks: list[str] = [] with requests.post( url, headers=headers, json=payload, timeout=params.timeout_s, stream=True, proxies=proxies, verify=verify ) as r: if r.status_code == 429: retry_after = r.headers.get("Retry-After") sleep_s = float(retry_after) if retry_after else _calc_backoff(attempt + 1, base=params.retry_backoff_s) logging.warning(f"[LLMChatNode] 429 (stream); sleeping {sleep_s:.2f}s") time.sleep(sleep_s) continue if r.status_code >= 400: raise RuntimeError(f"[LLMChatNode] HTTP {r.status_code}: {r.text[:800]}") for raw in r.iter_lines(decode_unicode=True): if _interrupted(): return "".join(chunks) if not raw: continue # SSE line format: "data: {...}" or "data: [DONE]" if not raw.startswith("data:"): continue data_line = raw[5:].strip() if data_line == "[DONE]": break try: delta = json.loads(data_line) choice0 = (delta.get("choices") or [{}])[0] piece = (choice0.get("delta") or {}).get("content") # Some providers may stream full 'message' frames (rare) if piece is None: piece = (choice0.get("message") or {}).get("content") if piece: chunks.append(piece) if stream_print: print(piece, end="", flush=True) except Exception: # Ignore malformed lines continue # No usage is guaranteed in streaming mode return "".join(chunks) except (requests.Timeout, requests.ConnectionError) as e: last_err = e if attempt < params.retry: sleep_s = _calc_backoff(attempt + 1, base=params.retry_backoff_s) logging.warning(f"[LLMChatNode] network issue; retry in {sleep_s:.2f}s: {e}") time.sleep(sleep_s) continue break except Exception as e: last_err = e if attempt < params.retry: sleep_s = _calc_backoff(attempt + 1, base=params.retry_backoff_s) logging.warning(f"[LLMChatNode] call failed; retry in {sleep_s:.2f}s: {e}") time.sleep(sleep_s) continue break raise RuntimeError(f"[LLMChatNode] call failed after retries: {last_err}")
# ────────────────────────────────────────────────────────────────────────────── # Node implementation # ──────────────────────────────────────────────────────────────────────────────
[文档]class LLMChatNode(nndeploy.dag.Node): """ LLM chat node (OpenAI-compatible) Input: str / list[messages] / dict({"messages":[...]}) Output: str (model generated text) """
[文档] def __init__(self, name, inputs: list[nndeploy.dag.Edge] = None, outputs: list[nndeploy.dag.Edge] = None): super().__init__(name, inputs, outputs) super().set_key("nndeploy.api_llm.LLMChatNode") super().set_desc("LLM chat node (OpenAI-compatible)") self.set_input_type(str) self.set_output_type(str) self.set_node_type(nndeploy.dag.NodeType.Intermediate) # Frontend editable parameters (exposed via serialize/deserialize) self.api_base = "https://api.openai.com" self.api_key_env = "OPENAI_API_KEY" # environment variable to read Key (optional) self.api_key = "" # allow frontend to provide Key inline (priority if allowed) self.allow_inline_api_key = True # security switch for inline key usage self.model = "gpt-4o-mini" self.temperature = 0.2 self.max_tokens = 512 self.system_prompt = "You are a helpful assistant." self.timeout_s = 60.0 self.retry = 2 self.retry_backoff_s = 1.2 # New features self.stream = False # streaming mode self.stream_print = False # print streaming pieces to stdout self.response_json_mode = False # enforce JSON object responses (if supported) self.json_schema: T.Optional[dict] = None self.emit_usage_to_log = True # log token usage if present self.verify_tls: T.Union[bool, str] = True # True/False or CA bundle path # NEW: Proxies as frontend parameters self.http_proxy: str = "" # e.g. "http://user:pass@host:port" self.https_proxy: str = "" # e.g. "https://user:pass@host:port" # Aggregate into a dict for frontend rendering/editing self.frontend_params = { "api_base": self.api_base, "api_key_env": self.api_key_env, "api_key": self.api_key, # UI should render this field; value won't be persisted "allow_inline_api_key": self.allow_inline_api_key, "model": self.model, "temperature": self.temperature, "max_tokens": self.max_tokens, "system_prompt": self.system_prompt, "timeout_s": self.timeout_s, "retry": self.retry, "retry_backoff_s": self.retry_backoff_s, "stream": self.stream, "stream_print": self.stream_print, "response_json_mode": self.response_json_mode, "json_schema": self.json_schema, "emit_usage_to_log": self.emit_usage_to_log, "verify_tls": self.verify_tls, # NEW: proxies exposed to UI "http_proxy": self.http_proxy, "https_proxy": self.https_proxy, }
# Sync frontend parameters back to instance variables def _sync_params_from_frontend(self): fp = self.frontend_params or {} self.api_base = str(fp.get("api_base", self.api_base)) self.api_key_env = str(fp.get("api_key_env", self.api_key_env)) self.api_key = str(fp.get("api_key", self.api_key)) # UI can set it each time self.allow_inline_api_key = bool(fp.get("allow_inline_api_key", self.allow_inline_api_key)) self.model = str(fp.get("model", self.model)) self.temperature = float(fp.get("temperature", self.temperature)) self.max_tokens = int(fp.get("max_tokens", self.max_tokens)) self.system_prompt = str(fp.get("system_prompt", self.system_prompt)) self.timeout_s = float(fp.get("timeout_s", self.timeout_s)) self.retry = int(fp.get("retry", self.retry)) self.retry_backoff_s = float(fp.get("retry_backoff_s", self.retry_backoff_s)) self.stream = bool(fp.get("stream", self.stream)) self.stream_print = bool(fp.get("stream_print", self.stream_print)) self.response_json_mode = bool(fp.get("response_json_mode", self.response_json_mode)) self.json_schema = fp.get("json_schema", self.json_schema) self.emit_usage_to_log = bool(fp.get("emit_usage_to_log", self.emit_usage_to_log)) self.verify_tls = fp.get("verify_tls", self.verify_tls) # NEW: proxies self.http_proxy = str(fp.get("http_proxy", self.http_proxy)) self.https_proxy = str(fp.get("https_proxy", self.https_proxy))
[文档] def run(self): # 0) Interrupt guard if _should_stop(self): return Status.ok() # 1) Get input input_edge = self.get_input(0) user_input = input_edge.get(self) # May be str / list / dict # 2) Sync parameters self._sync_params_from_frontend() # 3) API key selection (priority: inline key if allowed > environment variable) api_key = "" if self.allow_inline_api_key: api_key = (self.api_key or "").strip() if not api_key: api_key = os.getenv(self.api_key_env, "").strip() if not api_key: err = (f"[LLMChatNode] Missing API key. " f"Please set 'frontend_params.api_key' (if allowed) or environment variable '{self.api_key_env}'.") logging.error(err) output_edge = self.get_output(0) output_edge.set(err) return Status.ok() # 4) Build params object and messages params = LLMParams( api_base=self.api_base, api_key_env=self.api_key_env, model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, system_prompt=self.system_prompt, timeout_s=self.timeout_s, retry=self.retry, retry_backoff_s=self.retry_backoff_s, stream=self.stream, response_json_mode=self.response_json_mode, json_schema=self.json_schema if isinstance(self.json_schema, dict) else None, ) messages = _normalize_messages(user_input, params.system_prompt) # 5) Log target (mask scheme) safe_base = (self.api_base or "").replace("https://", "").replace("http://", "") logging.info(f"[LLMChatNode] base={safe_base} model={self.model} stream={self.stream}") # 6) Build proxies from frontend or environment proxies = { "http": self.http_proxy or os.getenv("HTTP_PROXY") or os.getenv("http_proxy"), "https": self.https_proxy or os.getenv("HTTPS_PROXY") or os.getenv("https_proxy"), } # 7) Call LLM try: text = call_llm( params, messages, api_key_override=api_key, node_for_interrupt=self, emit_usage_to_log=self.emit_usage_to_log, stream_print=self.stream_print, proxies=proxies, verify=self.verify_tls ) except Exception as e: logging.exception("[LLMChatNode] LLM call failed") text = f"[LLMChatNode ERROR] {e}" # 8) Write output to edge 0 output_edge = self.get_output(0) output_edge.set(text) return Status.ok()
# Serialization / Deserialization
[文档] def serialize(self): base = json.loads(super().serialize()) fp = dict(self.frontend_params) # Always include the api_key field for UI rendering, but NEVER return the real value fp["api_key"] = "" # Provide status & masked preview so UI can show "configured" state has_key = bool(self.api_key) fp["has_inline_api_key"] = has_key if has_key: fp["api_key_preview"] = _mask_key(self.api_key) # Proxies are returned as-is so UI can render & persist them # (If you don't want to persist proxies with credentials, mask or drop them here.) base["frontend_params"] = fp return json.dumps(base, ensure_ascii=False)
[文档] def deserialize(self, target: str): obj = json.loads(target) if "frontend_params" in obj: self.frontend_params = obj["frontend_params"] or self.frontend_params self._sync_params_from_frontend() return super().deserialize(target)
# ────────────────────────────────────────────────────────────────────────────── # Node creator & registration # ──────────────────────────────────────────────────────────────────────────────
[文档]class LLMChatNodeCreator(nndeploy.dag.NodeCreator):
[文档] def __init__(self): super().__init__() self.node: T.Optional[LLMChatNode] = None
[文档] def create_node(self, name: str, inputs: list[nndeploy.dag.Edge], outputs: list[nndeploy.dag.Edge]): self.node = LLMChatNode(name, inputs, outputs) return self.node
# Register into NNDeploy llm_chat_node_creator = LLMChatNodeCreator() nndeploy.dag.register_node("nndeploy.api_llm.LLMChatNode", llm_chat_node_creator)