nndeploy.api_aigc.openai_plugin 源代码

import json
import logging
import base64
import requests
import numpy as np
from io import BytesIO
from PIL import Image
import os

import nndeploy.dag
import nndeploy.base


class OpenAIImageNode(nndeploy.dag.Node):
    """OpenAI image generation node (the only node kept in this module).

    Reads a text prompt from input edge 0 and POSTs it to an OpenAI-compatible
    ``{base_url}/images/generations`` endpoint. The first returned image
    (inline base64 or a download URL) is decoded into a NumPy array with
    Pillow and set on output edge 0. A JPEG copy is also written, best
    effort, to ``resources/images/result.openai.jpg``.
    """

    # Defaults shared by __init__ and deserialize so the two cannot drift apart.
    _DEFAULT_BASE_URL = "https://api.openai.com/v1"
    _DEFAULT_MODEL = "dall-e-3"
    _DEFAULT_SIZE = "1024x1024"
    _DEFAULT_QUALITY = "standard"
    _DEFAULT_STYLE = "vivid"

    def __init__(self, name, inputs: list[nndeploy.dag.Edge] = None, outputs: list[nndeploy.dag.Edge] = None):
        super().__init__(name, inputs, outputs)
        # Kept identical to the name passed to nndeploy.dag.register_node at
        # module level. The previous key ("nndeploy.openai...") matched no
        # registered creator.
        # NOTE(review): presumably graphs are re-created by looking this key
        # up in the registry; confirm against nndeploy.dag.
        super().set_key("nndeploy.api_aigc.OpenAIImageNode")
        super().set_desc("OpenAI图像生成节点 - 支持DALL-E图像生成")
        self._logger = logging.getLogger(__name__)

        # Input: text prompt; output: generated image as a NumPy array.
        self.set_input_type(str)
        self.set_output_type(np.ndarray)

        # Parameters configurable from the frontend.
        self.api_key = ""                          # API key (required by run)
        self.base_url = self._DEFAULT_BASE_URL     # may point at any compatible service
        self.model = self._DEFAULT_MODEL
        self.size = self._DEFAULT_SIZE
        self.quality = self._DEFAULT_QUALITY
        self.style = self._DEFAULT_STYLE

    def run(self):
        """Generate an image for the prompt on input edge 0.

        On success the decoded image is set on output edge 0 and
        ``nndeploy.base.Status.ok()`` is returned. Every failure is logged and
        reported as ``nndeploy.base.Status.error(...)``. Failures include:
        a missing API key, an empty prompt, a non-200 response, missing image
        data, a failed download, or any exception. No exception escapes this
        method.
        """
        try:
            prompt = self.get_input(0).get(self)
            self._logger.info(
                "[OpenAIImageNode] Received prompt: %s",
                (prompt[:80] + '...') if isinstance(prompt, str) and len(prompt) > 80 else prompt,
            )

            if not self.api_key:
                self._logger.error("[OpenAIImageNode] API key not set")
                return nndeploy.base.Status.error("OpenAI API密钥未设置")
            # An empty prompt cannot describe an image, so fail fast without
            # calling the API.
            if prompt is None or (isinstance(prompt, str) and not prompt.strip()):
                self._logger.error("[OpenAIImageNode] Empty prompt")
                return nndeploy.base.Status.error("图像描述为空")

            response = self._post_generation_request(prompt)
            self._logger.info("[OpenAIImageNode] Response status: %s", response.status_code)
            if response.status_code != 200:
                error_msg = f"图像生成API请求失败: {response.status_code} - {response.text}"
                self._logger.error("[OpenAIImageNode] %s", error_msg)
                return nndeploy.base.Status.error(error_msg)

            result = response.json()
            self._logger.debug("[OpenAIImageNode] Response JSON keys: %s", list(result.keys()))

            # Prefer inline b64_json and fall back to url, since compatible
            # services differ in which one they return.
            image_b64 = None
            image_url = None
            if isinstance(result.get("data"), list) and result["data"]:
                item = result["data"][0]
                image_b64 = item.get("b64_json")
                image_url = item.get("url")
            self._logger.debug("[OpenAIImageNode] Have b64:%s url:%s", bool(image_b64), bool(image_url))

            if image_b64:
                raw = base64.b64decode(image_b64)
            elif image_url:
                self._logger.info("[OpenAIImageNode] Downloading image: %s", image_url)
                img_response = requests.get(image_url, timeout=30)
                if img_response.status_code != 200:
                    self._logger.error("[OpenAIImageNode] Image download failed: %s", img_response.status_code)
                    return nndeploy.base.Status.error("图像下载失败")
                raw = img_response.content
            else:
                self._logger.error("[OpenAIImageNode] No image data in response")
                return nndeploy.base.Status.error("返回结果中未包含图像数据")

            img_array = self._decode_image(raw)
            self._save_image(img_array)

            self.get_output(0).set(img_array)
            self._logger.info(
                "[OpenAIImageNode] Output image set: shape=%s dtype=%s",
                getattr(img_array, 'shape', None), getattr(img_array, 'dtype', None),
            )
            return nndeploy.base.Status.ok()
        except Exception as e:
            self._logger.exception("[OpenAIImageNode] Exception during run")
            return nndeploy.base.Status.error(f"图像生成错误: {str(e)}")

    def _post_generation_request(self, prompt):
        """POST *prompt* to ``{base_url}/images/generations`` and return the raw response."""
        self._logger.debug(
            "[OpenAIImageNode] Building request | base_url=%s model=%s size=%s quality=%s style=%s",
            self.base_url, self.model, self.size, self.quality, self.style,
        )
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model,
            "prompt": prompt,
            "size": self.size,
            "quality": self.quality,
            "style": self.style,
            "n": 1,
        }
        url = f"{self.base_url.rstrip('/')}/images/generations"
        self._logger.info("[OpenAIImageNode] POST %s", url)
        return requests.post(url, headers=headers, json=data, timeout=60)

    @staticmethod
    def _decode_image(raw):
        """Decode encoded image bytes (PNG/JPEG/...) into a NumPy array.

        Palette images ("P"/"PA") are expanded to RGB, or to RGBA when they
        carry transparency. Without this, ``np.array`` would return palette
        indices instead of colours. All other modes are converted unchanged.
        """
        with Image.open(BytesIO(raw)) as img:
            if img.mode in ("P", "PA"):
                has_alpha = img.mode == "PA" or "transparency" in img.info
                return np.array(img.convert("RGBA" if has_alpha else "RGB"))
            return np.array(img)

    def _save_image(self, img_array):
        """Best-effort: write *img_array* as JPEG to resources/images/result.openai.jpg.

        Any failure is logged and swallowed, so a failed save never fails the node.
        """
        try:
            save_dir = os.path.join("resources", "images")
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, "result.openai.jpg")
            pil_to_save = Image.fromarray(img_array)
            # JPEG cannot store alpha channels, so flatten every mode other
            # than RGB / grayscale (e.g. RGBA, LA) to RGB first.
            if pil_to_save.mode not in ("RGB", "L"):
                pil_to_save = pil_to_save.convert("RGB")
            pil_to_save.save(save_path, format="JPEG")
            self._logger.info("[OpenAIImageNode] Image saved to %s", save_path)
        except Exception:
            self._logger.exception("[OpenAIImageNode] Failed to save image to file")

    def serialize(self):
        """Return the base-class JSON extended with this node's parameters.

        NOTE(review): api_key is written in plaintext into the serialized JSON.
        Anyone who can read a saved workflow can read the key.
        """
        json_obj = json.loads(super().serialize())
        json_obj.update({
            "api_key": self.api_key,
            "base_url": self.base_url,
            "model": self.model,
            "size": self.size,
            "quality": self.quality,
            "style": self.style,
        })
        return json.dumps(json_obj)

    def deserialize(self, target: str):
        """Restore this node's parameters from *target*, then delegate to the base class.

        Keys missing from *target* fall back to the same defaults used by __init__.
        """
        json_obj = json.loads(target)
        self.api_key = json_obj.get("api_key", "")
        self.base_url = json_obj.get("base_url", self._DEFAULT_BASE_URL)
        self.model = json_obj.get("model", self._DEFAULT_MODEL)
        self.size = json_obj.get("size", self._DEFAULT_SIZE)
        self.quality = json_obj.get("quality", self._DEFAULT_QUALITY)
        self.style = json_obj.get("style", self._DEFAULT_STYLE)
        return super().deserialize(target)
class OpenAIImageNodeCreator(nndeploy.dag.NodeCreator):
    """Creator that builds OpenAIImageNode instances on request."""

    def __init__(self):
        super().__init__()

    def create_node(self, name: str, inputs: list[nndeploy.dag.Edge], outputs: list[nndeploy.dag.Edge]):
        """Build an OpenAIImageNode wired to *inputs*/*outputs* and return it.

        The new node is also stored on ``self.node``.
        NOTE(review): presumably this keeps a Python reference alive after the
        node is handed to the framework; confirm before removing it.
        """
        created = OpenAIImageNode(name, inputs, outputs)
        self.node = created
        return created
# Register the node: a module-level creator instance is registered with
# nndeploy.dag under the name "nndeploy.api_aigc.OpenAIImageNode".
openai_image_node_creator = OpenAIImageNodeCreator()
nndeploy.dag.register_node(
    "nndeploy.api_aigc.OpenAIImageNode",
    openai_image_node_creator,
)