"""AI 图片生成服务:后端请求 + 本地缓存管理。 API 端点待接入,当前通过 ``set_api_caller`` 注入具体实现。 缓存目录:``settings/ai_image_cache/``,每张图片有同名的 ``.json`` 侧车记录。 """ from __future__ import annotations import datetime as _dt import hashlib import json import os import shutil import threading from dataclasses import dataclass, asdict from typing import Callable, List, Optional # ---------- 常量 ---------- _CACHE_DIRNAME = os.path.join("settings", "ai_image_cache") _META_SUFFIX = ".json" _SUPPORTED_IMG_EXT = (".png", ".jpg", ".jpeg", ".bmp", ".webp") # ---------- 数据结构 ---------- @dataclass class AIImageRecord: """一条缓存记录。""" id: str prompt: str image_path: str created_at: str # ISO8601 extra: Optional[dict] = None def to_json(self) -> str: return json.dumps(asdict(self), ensure_ascii=False, indent=2) # ---------- API 注入 ---------- # 调用签名: ``fn(prompt: str) -> (image_bytes: bytes, image_ext: str, extra: dict|None)`` # ``image_ext`` 例如 ``".png"``;``extra`` 可为 None。 _ApiCaller = Callable[[str], tuple] _api_caller: Optional[_ApiCaller] = None def set_api_caller(fn: Optional[_ApiCaller]) -> None: """注入真实的后端 API 调用函数。在 API 就绪前可保持为 None。""" global _api_caller _api_caller = fn def has_api() -> bool: return _api_caller is not None # ---------- 缓存路径工具 ---------- def get_cache_dir(base_dir: Optional[str] = None) -> str: """返回缓存目录,如不存在则创建。``base_dir`` 默认使用当前工作目录。""" root = base_dir if base_dir else os.getcwd() path = os.path.join(root, _CACHE_DIRNAME) os.makedirs(path, exist_ok=True) return path def _make_id(prompt: str) -> str: stamp = _dt.datetime.now().strftime("%Y%m%d_%H%M%S_%f") digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8] return f"{stamp}_{digest}" def _meta_path_for(image_path: str) -> str: return os.path.splitext(image_path)[0] + _META_SUFFIX # ---------- 读写 ---------- def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]: """列出缓存目录下的所有记录,按创建时间倒序(最新在前)。""" cache_dir = get_cache_dir(base_dir) records: List[AIImageRecord] = [] for name in os.listdir(cache_dir): full = os.path.join(cache_dir, name) if not (os.path.isfile(full) and name.lower().endswith(_SUPPORTED_IMG_EXT)): continue meta_path = _meta_path_for(full) prompt = "" created_at = "" extra = None rec_id = os.path.splitext(name)[0] if os.path.isfile(meta_path): try: with open(meta_path, "r", encoding="utf-8") as f: data = json.load(f) prompt = data.get("prompt", "") created_at = data.get("created_at", "") extra = data.get("extra") rec_id = data.get("id", rec_id) except Exception: pass if not created_at: # fallback 到文件 mtime try: mtime = os.path.getmtime(full) created_at = _dt.datetime.fromtimestamp(mtime).isoformat() except Exception: created_at = "" records.append( AIImageRecord( id=rec_id, prompt=prompt, image_path=full, created_at=created_at, extra=extra, ) ) records.sort(key=lambda r: r.created_at, reverse=True) return records def save_image_to_cache( prompt: str, image_bytes: bytes, image_ext: str = ".png", extra: Optional[dict] = None, base_dir: Optional[str] = None, ) -> AIImageRecord: """把生成的图片字节写入缓存,返回记录。""" if not image_ext.startswith("."): image_ext = "." + image_ext if image_ext.lower() not in _SUPPORTED_IMG_EXT: image_ext = ".png" cache_dir = get_cache_dir(base_dir) rec_id = _make_id(prompt) image_path = os.path.join(cache_dir, f"{rec_id}{image_ext}") with open(image_path, "wb") as f: f.write(image_bytes) record = AIImageRecord( id=rec_id, prompt=prompt, image_path=image_path, created_at=_dt.datetime.now().isoformat(timespec="seconds"), extra=extra, ) try: with open(_meta_path_for(image_path), "w", encoding="utf-8") as f: f.write(record.to_json()) except Exception: pass return record def delete_record(record: AIImageRecord) -> bool: """删除一条缓存记录(图片 + 侧车)。返回是否成功。""" ok = True for p in (record.image_path, _meta_path_for(record.image_path)): try: if os.path.isfile(p): os.remove(p) except Exception: ok = False return ok def export_record(record: AIImageRecord, dest_path: str) -> None: """把缓存中的图片另存到 ``dest_path``。""" shutil.copyfile(record.image_path, dest_path) # ---------- 异步请求 ---------- def request_image_async( prompt: str, on_success: Callable[[AIImageRecord], None], on_error: Callable[[Exception], None], base_dir: Optional[str] = None, ) -> threading.Thread: """在后台线程请求 API → 写入缓存 → 回调。 ``on_success`` / ``on_error`` 会在 **工作线程** 中被调用;UI 侧若需 切回主线程,请在回调内部自行用 ``root.after(0, ...)``。 """ def _worker(): try: if _api_caller is None: raise RuntimeError("AI 图片 API 尚未接入,请调用 set_api_caller 注入") image_bytes, image_ext, extra = _normalize_api_result(_api_caller(prompt)) record = save_image_to_cache( prompt=prompt, image_bytes=image_bytes, image_ext=image_ext, extra=extra, base_dir=base_dir, ) on_success(record) except Exception as exc: on_error(exc) t = threading.Thread(target=_worker, daemon=True) t.start() return t def _normalize_api_result(result): """允许 API 返回 ``bytes`` 或 ``(bytes, ext)`` 或 ``(bytes, ext, extra)``。""" if isinstance(result, (bytes, bytearray)): return bytes(result), ".png", None if isinstance(result, tuple): if len(result) == 2: return bytes(result[0]), str(result[1]), None if len(result) == 3: return bytes(result[0]), str(result[1]), result[2] raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])")