Files
pqAutomationApp/app/services/ai_image.py
2026-04-21 14:06:48 +08:00

225 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 图片生成服务:后端请求 + 本地缓存管理。
API 端点待接入,当前通过 ``set_api_caller`` 注入具体实现。
缓存目录:``settings/ai_image_cache/``,每张图片有同名的 ``.json`` 侧车记录。
"""
from __future__ import annotations
import datetime as _dt
import hashlib
import json
import os
import shutil
import threading
from dataclasses import dataclass, asdict
from typing import Callable, List, Optional
# ---------- Constants ----------
# Location of the cache, relative to the root directory handed to get_cache_dir().
_CACHE_DIRNAME = os.path.join("settings", "ai_image_cache")
# Extension of the metadata sidecar written next to every cached image.
_META_SUFFIX = ".json"
# Image extensions recognised by the cache (callers compare them lower-cased).
_SUPPORTED_IMG_EXT = tuple(".png .jpg .jpeg .bmp .webp".split())
# ---------- Data structures ----------
@dataclass
class AIImageRecord:
    """One cached AI image plus the metadata kept in its JSON sidecar."""

    id: str  # record id; save_image_to_cache also uses it as the image file stem
    prompt: str  # prompt text the image was generated from
    image_path: str  # full path of the cached image file
    created_at: str  # ISO 8601 timestamp
    extra: Optional[dict] = None  # optional extra metadata supplied by the API caller

    def to_json(self) -> str:
        """Return the record as indented JSON, leaving non-ASCII text unescaped."""
        as_mapping = asdict(self)
        return json.dumps(as_mapping, indent=2, ensure_ascii=False)
# ---------- API injection ----------
# Expected signature of the injected caller:
#     fn(prompt: str) -> (image_bytes: bytes, image_ext: str, extra: dict | None)
# ``image_ext`` is e.g. ``".png"``; ``extra`` may be None.
# NOTE(review): the alias below only says ``tuple``, but _normalize_api_result also
# accepts a bare ``bytes`` return or a ``(bytes, ext)`` pair.
_ApiCaller = Callable[[str], tuple]
# The currently injected caller; stays None until set_api_caller() registers one.
_api_caller: Optional[_ApiCaller] = None
def set_api_caller(fn: Optional[_ApiCaller]) -> None:
    """Inject the real backend API call used by ``request_image_async``.

    It may stay ``None`` until the API is ready; passing ``None`` un-registers
    the current caller.

    Args:
        fn: ``fn(prompt)`` returning ``bytes``, ``(bytes, ext)`` or
            ``(bytes, ext, extra)``, or ``None``.

    Raises:
        TypeError: if ``fn`` is neither ``None`` nor callable. This is checked
            here so the wiring mistake surfaces at injection time. Otherwise it
            would only appear later, inside the background worker, as an opaque
            "object is not callable" error.
    """
    global _api_caller
    # Validate before assigning so a bad value never replaces a working caller.
    if fn is not None and not callable(fn):
        raise TypeError(
            f"set_api_caller expects a callable or None, got {type(fn).__name__}"
        )
    _api_caller = fn
def has_api() -> bool:
    """Report whether a backend API caller is currently injected."""
    injected = _api_caller
    return injected is not None
# ---------- Cache path helpers ----------
def get_cache_dir(base_dir: Optional[str] = None) -> str:
    """Return the AI image cache directory, creating it if it does not exist yet.

    The directory is ``<root>/settings/ai_image_cache``. ``root`` is
    ``base_dir``, or the current working directory when ``base_dir`` is empty
    or ``None``.
    """
    cache_path = os.path.join(base_dir or os.getcwd(), _CACHE_DIRNAME)
    os.makedirs(cache_path, exist_ok=True)
    return cache_path
def _make_id(prompt: str) -> str:
    """Build a record id from the current time and the prompt.

    The id is the local time down to microseconds
    (``YYYYmmdd_HHMMSS_ffffff``), an underscore, then the first 8 hex
    characters of the prompt's UTF-8 MD5 digest. The digest is only used as a
    short tag inside the id.
    """
    now_part = _dt.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    tag = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8]
    return "_".join((now_part, tag))
def _meta_path_for(image_path: str) -> str:
    """Return the sidecar path for ``image_path``: same path with its last extension swapped for ``.json``."""
    stem, _old_ext = os.path.splitext(image_path)
    return f"{stem}{_META_SUFFIX}"
# ---------- Read / write ----------
def _read_sidecar(meta_path: str) -> dict:
    """Load the JSON sidecar at ``meta_path``; return ``{}`` if it is missing, unreadable or not a JSON object."""
    try:
        with open(meta_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # Missing, half-written or undecodable sidecar (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses): degrade to
        # filename + mtime rather than hiding the image from the listing.
        return {}
    return data if isinstance(data, dict) else {}


def _str_or(value, default: str) -> str:
    """Return ``value`` if it is a non-empty ``str``, otherwise ``default``."""
    return value if isinstance(value, str) and value else default


def _mtime_iso(path: str) -> str:
    """Return the file's modification time as a local ISO 8601 string, or ``""`` if it can't be read."""
    try:
        return _dt.datetime.fromtimestamp(os.path.getmtime(path)).isoformat()
    except (OSError, OverflowError, ValueError):
        return ""


def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
    """List every cached image, newest first.

    Metadata comes from each image's ``.json`` sidecar when that file exists
    and parses to a JSON object. Sidecar fields that are missing, empty or not
    strings fall back as follows:

    * ``id``: the image's file stem
    * ``prompt``: ``""``
    * ``created_at``: the file's mtime

    Records are sorted by ``created_at`` descending, with ties broken by ``id``
    (also descending). Ids produced by ``_make_id`` start with a
    microsecond-resolution timestamp, so images saved within the same second
    still come out newest first. Ties also never depend on ``os.listdir``
    order.
    """
    cache_dir = get_cache_dir(base_dir)
    records: List[AIImageRecord] = []
    for name in os.listdir(cache_dir):
        full = os.path.join(cache_dir, name)
        # Cheap extension test first; sidecars (.json) and temp files are skipped here.
        if not (name.lower().endswith(_SUPPORTED_IMG_EXT) and os.path.isfile(full)):
            continue
        meta = _read_sidecar(_meta_path_for(full))
        records.append(
            AIImageRecord(
                id=_str_or(meta.get("id"), os.path.splitext(name)[0]),
                prompt=_str_or(meta.get("prompt"), ""),
                image_path=full,
                # mtime is only consulted when the sidecar gives no usable timestamp.
                created_at=_str_or(meta.get("created_at"), "") or _mtime_iso(full),
                extra=meta.get("extra"),
            )
        )
    # Both key parts are guaranteed str above, so a hand-edited sidecar with a
    # numeric or null value can no longer make sort() raise TypeError.
    records.sort(key=lambda r: (r.created_at, r.id), reverse=True)
    return records
def _write_bytes_atomic(path: str, data: bytes) -> None:
    """Write ``data`` to ``path`` via a temporary file plus ``os.replace``.

    A concurrent reader that filters by image extension never sees a
    half-written image: the ``.tmp`` name does not end in a supported
    extension, and ``os.replace`` within one directory is atomic on POSIX.
    """
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp name is gone; on failure clean it up.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def _sidecar_json(record: AIImageRecord) -> str:
    """Serialize ``record`` for its sidecar file.

    If ``extra`` cannot be JSON-encoded, ``extra`` is written as null so that
    the core fields, above all the prompt, are not lost.
    """
    try:
        return record.to_json()
    except (TypeError, ValueError, RecursionError):
        fallback = {
            "id": record.id,
            "prompt": record.prompt,
            "image_path": record.image_path,
            "created_at": record.created_at,
            "extra": None,
        }
        return json.dumps(fallback, ensure_ascii=False, indent=2)


def save_image_to_cache(
    prompt: str,
    image_bytes: bytes,
    image_ext: str = ".png",
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
) -> AIImageRecord:
    """Store generated image bytes in the cache and return the new record.

    Args:
        prompt: Prompt the image was generated from; also hashed into the record id.
        image_bytes: Encoded image data, written to disk unchanged.
        image_ext: File extension, with or without the leading dot. It is
            lower-cased; empty, ``None`` or unsupported values fall back to ``".png"``.
        extra: Optional metadata for the sidecar. If it is not JSON-encodable it
            is written as null in the sidecar (the returned record still carries it).
        base_dir: Root for the cache directory, as in ``get_cache_dir``.

    Returns:
        The newly created ``AIImageRecord``.

    Raises:
        OSError: if the cache directory or the image file cannot be written.
            Sidecar write failures are ignored (best effort), because by then
            the image is already stored.
    """
    ext = (image_ext or "").lower()
    if not ext.startswith("."):
        ext = "." + ext
    if ext not in _SUPPORTED_IMG_EXT:
        ext = ".png"

    cache_dir = get_cache_dir(base_dir)
    rec_id = _make_id(prompt)
    image_path = os.path.join(cache_dir, f"{rec_id}{ext}")
    _write_bytes_atomic(image_path, image_bytes)

    record = AIImageRecord(
        id=rec_id,
        prompt=prompt,
        image_path=image_path,
        created_at=_dt.datetime.now().isoformat(timespec="seconds"),
        extra=extra,
    )
    # Serialize *before* opening the sidecar. Opening with "w" first and then
    # failing inside json.dumps would leave an empty .json behind, which makes
    # list_records lose this image's prompt.
    payload = _sidecar_json(record)
    try:
        with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
            f.write(payload)
    except OSError:
        # Best effort: the image is saved; a missing sidecar only loses metadata.
        pass
    return record
def delete_record(record: AIImageRecord) -> bool:
    """Delete one cached record: its image file and its ``.json`` sidecar.

    Files that are already gone count as deleted. Returns ``True`` when no
    removal failed. Any other ``OSError`` (permissions, file in use, path is a
    directory, ...) is reported as ``False`` instead of being raised, and the
    remaining file is still attempted.
    """
    ok = True
    for path in (record.image_path, _meta_path_for(record.image_path)):
        # EAFP: removing directly avoids the isfile()/remove() race. It also
        # removes a dangling symlink, which isfile() would have reported as absent.
        try:
            os.remove(path)
        except FileNotFoundError:
            continue
        except OSError:
            ok = False
    return ok
def export_record(record: AIImageRecord, dest_path: str) -> None:
    """Copy the cached image of ``record`` to ``dest_path``, overwriting any existing file there.

    Errors from ``shutil.copyfile`` (e.g. a missing source file) propagate to the caller.
    """
    source = record.image_path
    shutil.copyfile(source, dest_path)
# ---------- Asynchronous request ----------
def request_image_async(
    prompt: str,
    on_success: Callable[[AIImageRecord], None],
    on_error: Callable[[Exception], None],
    base_dir: Optional[str] = None,
) -> threading.Thread:
    """Call the API in a background thread, cache the image, then invoke a callback.

    Exactly one of ``on_success`` / ``on_error`` is called, from the **worker
    thread**. UI code that must run on the main thread should switch back
    itself, e.g. with ``root.after(0, ...)`` inside the callback.

    An exception raised by ``on_success`` itself is *not* routed to
    ``on_error``. It propagates out of the worker thread to
    ``threading.excepthook``.

    Returns:
        The already-started daemon thread, so callers may ``join()`` it.
    """
    def _worker() -> None:
        try:
            # Read the injected caller once: a concurrent set_api_caller(None)
            # between a None-check and the call would otherwise raise a
            # confusing "'NoneType' object is not callable".
            caller = _api_caller
            if caller is None:
                raise RuntimeError("AI 图片 API 尚未接入,请调用 set_api_caller 注入")
            image_bytes, image_ext, extra = _normalize_api_result(caller(prompt))
            record = save_image_to_cache(
                prompt=prompt,
                image_bytes=image_bytes,
                image_ext=image_ext,
                extra=extra,
                base_dir=base_dir,
            )
        except Exception as exc:
            on_error(exc)
        else:
            # Outside the try: a failure in on_success must not also trigger on_error.
            on_success(record)

    t = threading.Thread(target=_worker, name="ai-image-request", daemon=True)
    t.start()
    return t
def _normalize_api_result(result):
    """Normalize the injected API caller's return value into ``(bytes, ext, extra)``.

    Accepted shapes:

    * a bytes-like object (``bytes``, ``bytearray`` or ``memoryview``)
    * ``(data, ext)``
    * ``(data, ext, extra)``

    A missing or ``None`` extension becomes ``".png"``. ``extra`` is passed
    through unchanged (``None`` when absent).

    Raises:
        TypeError: if the image payload is not bytes-like. A bare ``bytes(x)``
            would silently turn an int ``n`` into ``n`` zero bytes, and would
            fail unhelpfully on ``str``.
        ValueError: if ``result`` matches none of the accepted shapes.
    """
    bytes_like = (bytes, bytearray, memoryview)
    if isinstance(result, bytes_like):
        return bytes(result), ".png", None
    if isinstance(result, tuple) and len(result) in (2, 3):
        payload, raw_ext = result[0], result[1]
        extra = result[2] if len(result) == 3 else None
        if not isinstance(payload, bytes_like):
            raise TypeError(
                f"API 返回的图片数据需为 bytes,实际为 {type(payload).__name__}"
            )
        # str(None) would yield the bogus extension "None"; use the documented default.
        ext = ".png" if raw_ext is None else str(raw_ext)
        return bytes(payload), ext, extra
    raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])")