"""AI 图片生成服务:后端请求 + 本地缓存管理。 API 端点待接入,当前通过 ``set_api_caller`` 注入具体实现。 缓存目录:``settings/ai_image_cache/``,每张图片有同名的 ``.json`` 侧车记录。 """ from __future__ import annotations import datetime as _dt import hashlib import json import mimetypes import os import shutil import threading from dataclasses import dataclass, asdict from typing import Callable, List, Optional from urllib.parse import urlparse from urllib.request import Request, urlopen # ---------- 常量 ---------- _CACHE_DIRNAME = os.path.join("settings", "ai_image_cache") _META_SUFFIX = ".json" _SUPPORTED_IMG_EXT = (".png", ".jpg", ".jpeg", ".bmp", ".webp") # ---------- 数据结构 ---------- @dataclass class AIImageRecord: """一条缓存记录。""" id: str prompt: str image_path: str created_at: str # ISO8601 extra: Optional[dict] = None def to_json(self) -> str: return json.dumps(asdict(self), ensure_ascii=False, indent=2) # ---------- API 注入 ---------- # 调用签名: ``fn(prompt: str) -> (image_bytes: bytes, image_ext: str, extra: dict|None)`` # ``image_ext`` 例如 ``".png"``;``extra`` 可为 None。 _ApiCaller = Callable[[str], tuple] _api_caller: Optional[_ApiCaller] = None def set_api_caller(fn: Optional[_ApiCaller]) -> None: """注入真实的后端 API 调用函数。在 API 就绪前可保持为 None。""" global _api_caller _api_caller = fn def has_api() -> bool: return _api_caller is not None # ---------- 缓存路径工具 ---------- def get_cache_dir(base_dir: Optional[str] = None) -> str: """返回缓存目录,如不存在则创建。``base_dir`` 默认使用当前工作目录。""" root = base_dir if base_dir else os.getcwd() path = os.path.join(root, _CACHE_DIRNAME) os.makedirs(path, exist_ok=True) return path def _make_id(prompt: str) -> str: stamp = _dt.datetime.now().strftime("%Y%m%d_%H%M%S_%f") digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8] return f"{stamp}_{digest}" def _meta_path_for(image_path: str) -> str: return os.path.splitext(image_path)[0] + _META_SUFFIX # ---------- 读写 ---------- def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]: """列出缓存目录下的所有记录,按创建时间倒序(最新在前)。""" cache_dir = get_cache_dir(base_dir) records: List[AIImageRecord] = [] for name in os.listdir(cache_dir): full = os.path.join(cache_dir, name) if not (os.path.isfile(full) and name.lower().endswith(_SUPPORTED_IMG_EXT)): continue meta_path = _meta_path_for(full) prompt = "" created_at = "" extra = None rec_id = os.path.splitext(name)[0] if os.path.isfile(meta_path): try: with open(meta_path, "r", encoding="utf-8") as f: data = json.load(f) prompt = data.get("prompt", "") created_at = data.get("created_at", "") extra = data.get("extra") rec_id = data.get("id", rec_id) except Exception: pass if not created_at: # fallback 到文件 mtime try: mtime = os.path.getmtime(full) created_at = _dt.datetime.fromtimestamp(mtime).isoformat() except Exception: created_at = "" records.append( AIImageRecord( id=rec_id, prompt=prompt, image_path=full, created_at=created_at, extra=extra, ) ) if not records: seeded = _seed_placeholder_record(cache_dir) if seeded is not None: records.append(seeded) records.sort(key=lambda r: r.created_at, reverse=True) return records def _seed_placeholder_record(cache_dir: str) -> Optional[AIImageRecord]: """当缓存为空时,写入一张本地占位图,便于前端联调。""" try: repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) src = os.path.join(repo_root, "assets", "entry_1.png") if not os.path.isfile(src): return None rec_id = f"{_dt.datetime.now().strftime('%Y%m%d_%H%M%S')}_placeholder" image_path = os.path.join(cache_dir, f"{rec_id}.png") shutil.copyfile(src, image_path) record = AIImageRecord( id=rec_id, prompt="本地测试占位图(后端未接入)", image_path=image_path, created_at=_dt.datetime.now().isoformat(timespec="seconds"), extra={"source": "local-placeholder"}, ) with open(_meta_path_for(image_path), "w", encoding="utf-8") as f: f.write(record.to_json()) return record except Exception: return None def save_image_to_cache( prompt: str, image_bytes: bytes, image_ext: str = ".png", extra: Optional[dict] = None, base_dir: Optional[str] = None, ) -> AIImageRecord: """把生成的图片字节写入缓存,返回记录。""" if not image_ext.startswith("."): image_ext = "." + image_ext if image_ext.lower() not in _SUPPORTED_IMG_EXT: image_ext = ".png" cache_dir = get_cache_dir(base_dir) rec_id = _make_id(prompt) image_path = os.path.join(cache_dir, f"{rec_id}{image_ext}") with open(image_path, "wb") as f: f.write(image_bytes) record = AIImageRecord( id=rec_id, prompt=prompt, image_path=image_path, created_at=_dt.datetime.now().isoformat(timespec="seconds"), extra=extra, ) try: with open(_meta_path_for(image_path), "w", encoding="utf-8") as f: f.write(record.to_json()) except Exception: pass return record def import_image_from_url( image_url: str, prompt: Optional[str] = None, extra: Optional[dict] = None, base_dir: Optional[str] = None, timeout: float = 20.0, ) -> AIImageRecord: """下载远程图片并写入缓存。""" url = (image_url or "").strip() if not url: raise ValueError("图片地址不能为空") request = Request( url, headers={ "User-Agent": "pqAutomationApp/1.0", "Accept": "image/*,*/*;q=0.8", }, ) with urlopen(request, timeout=timeout) as response: image_bytes = response.read() if not image_bytes: raise ValueError("下载结果为空") image_ext = _guess_image_ext( image_url=url, content_type=response.headers.get_content_type(), ) merged_extra = dict(extra or {}) merged_extra.update( { "source": "remote-url", "source_url": url, "content_type": response.headers.get_content_type(), } ) record_prompt = (prompt or _default_prompt_from_url(url)).strip() return save_image_to_cache( prompt=record_prompt, image_bytes=image_bytes, image_ext=image_ext, extra=merged_extra, base_dir=base_dir, ) def delete_record(record: AIImageRecord) -> bool: """删除一条缓存记录(图片 + 侧车)。返回是否成功。""" ok = True for p in (record.image_path, _meta_path_for(record.image_path)): try: if os.path.isfile(p): os.remove(p) except Exception: ok = False return ok def export_record(record: AIImageRecord, dest_path: str) -> None: """把缓存中的图片另存到 ``dest_path``。""" shutil.copyfile(record.image_path, dest_path) # ---------- 异步请求 ---------- def request_image_async( prompt: str, on_success: Callable[[AIImageRecord], None], on_error: Callable[[Exception], None], base_dir: Optional[str] = None, ) -> threading.Thread: """在后台线程请求 API → 写入缓存 → 回调。 ``on_success`` / ``on_error`` 会在 **工作线程** 中被调用;UI 侧若需 切回主线程,请在回调内部自行用 ``root.after(0, ...)``。 """ def _worker(): try: if _api_caller is None: raise RuntimeError("AI 图片 API 尚未接入,请调用 set_api_caller 注入") image_bytes, image_ext, extra = _normalize_api_result(_api_caller(prompt)) record = save_image_to_cache( prompt=prompt, image_bytes=image_bytes, image_ext=image_ext, extra=extra, base_dir=base_dir, ) on_success(record) except Exception as exc: on_error(exc) t = threading.Thread(target=_worker, daemon=True) t.start() return t def import_image_from_url_async( image_url: str, on_success: Callable[[AIImageRecord], None], on_error: Callable[[Exception], None], prompt: Optional[str] = None, extra: Optional[dict] = None, base_dir: Optional[str] = None, timeout: float = 20.0, ) -> threading.Thread: """在后台线程下载远程图片并写入缓存""" def _worker(): try: record = import_image_from_url( image_url=image_url, prompt=prompt, extra=extra, base_dir=base_dir, timeout=timeout, ) on_success(record) except Exception as exc: on_error(exc) t = threading.Thread(target=_worker, daemon=True) t.start() return t def _normalize_api_result(result): """允许 API 返回 ``bytes`` 或 ``(bytes, ext)`` 或 ``(bytes, ext, extra)``。""" if isinstance(result, (bytes, bytearray)): return bytes(result), ".png", None if isinstance(result, tuple): if len(result) == 2: return bytes(result[0]), str(result[1]), None if len(result) == 3: return bytes(result[0]), str(result[1]), result[2] raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])") def is_remote_image_url(value: str) -> bool: """判断输入是否为 http/https 图片地址。""" url = (value or "").strip() if not url: return False parsed = urlparse(url) return parsed.scheme in {"http", "https"} and bool(parsed.netloc) def _guess_image_ext(image_url: str, content_type: Optional[str]) -> str: if content_type: guessed = mimetypes.guess_extension(content_type) if guessed == ".jpe": guessed = ".jpg" if guessed and guessed.lower() in _SUPPORTED_IMG_EXT: return guessed.lower() url_path = urlparse(image_url).path ext = os.path.splitext(url_path)[1].lower() if ext in _SUPPORTED_IMG_EXT: return ext return ".png" def _default_prompt_from_url(image_url: str) -> str: path = urlparse(image_url).path name = os.path.splitext(os.path.basename(path))[0].strip() return name or "远程导入图片"