Files
pqAutomationApp/app/services/ai_image.py
2026-04-23 10:07:41 +08:00

364 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 图片生成服务:后端请求 + 本地缓存管理。
API 端点待接入,当前通过 ``set_api_caller`` 注入具体实现。
缓存目录:``settings/ai_image_cache/``,每张图片有同名的 ``.json`` 侧车记录。
"""
from __future__ import annotations

import datetime as _dt
import hashlib
import json
import mimetypes
import os
import shutil
import threading
from dataclasses import dataclass, asdict
from typing import Callable, List, Optional
from urllib.parse import unquote, urlparse
from urllib.request import Request, urlopen
# ---------- Constants ----------
# Cache folder, relative to the root chosen by ``get_cache_dir``.
_CACHE_DIRNAME = os.path.join("settings", "ai_image_cache")
# Suffix of the JSON sidecar that sits next to every cached image.
_META_SUFFIX = ".json"
# Lower-case image extensions recognised in the cache. Kept as a tuple so it
# can be handed straight to ``str.endswith``.
_SUPPORTED_IMG_EXT = (
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".webp",
)
# ---------- Data structures ----------
@dataclass
class AIImageRecord:
    """A single cached image plus the metadata kept in its JSON sidecar."""

    id: str  # record id; files created by this module use it as the image file stem
    prompt: str  # prompt text the image was generated or imported with
    image_path: str  # filesystem path of the image file
    created_at: str  # ISO 8601 timestamp string
    extra: Optional[dict] = None  # free-form metadata (API output, import source, ...)

    def to_json(self) -> str:
        """Serialise every field as indented JSON, leaving non-ASCII text unescaped."""
        payload = asdict(self)
        return json.dumps(payload, indent=2, ensure_ascii=False)
# ---------- API injection ----------
# The injected callable takes the prompt string. It may return raw image bytes,
# ``(image_bytes, image_ext)`` or ``(image_bytes, image_ext, extra)``.
# ``_normalize_api_result`` unpacks those shapes. ``image_ext`` looks like
# ``".png"`` and ``extra`` may be None.
_ApiCaller = Callable[[str], tuple]
_api_caller: Optional[_ApiCaller] = None


def set_api_caller(fn: Optional[_ApiCaller]) -> None:
    """Install the backend image-generation callable, or pass ``None`` to remove it.

    It may stay ``None`` until the real API is available; ``has_api`` then
    reports ``False``.
    """
    global _api_caller
    _api_caller = fn


def has_api() -> bool:
    """Report whether a backend callable is currently installed."""
    installed = _api_caller
    return installed is not None
# ---------- Cache path helpers ----------
def get_cache_dir(base_dir: Optional[str] = None) -> str:
    """Return the AI image cache directory and create it if it is missing.

    ``base_dir`` is the root under which the cache folder lives. When it is
    ``None`` or empty, the current working directory is used.
    """
    root = base_dir or os.getcwd()
    cache_path = os.path.join(root, _CACHE_DIRNAME)
    os.makedirs(cache_path, exist_ok=True)
    return cache_path
def _make_id(prompt: str) -> str:
    """Build a record id from the local time (to the microsecond) and the prompt.

    The format is ``YYYYmmdd_HHMMSS_ffffff_<first 8 hex chars of MD5(prompt)>``.
    """
    now = _dt.datetime.now()
    short_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8]
    timestamp = now.strftime("%Y%m%d_%H%M%S_%f")
    return f"{timestamp}_{short_hash}"
def _meta_path_for(image_path: str) -> str:
    """Return the sidecar path for ``image_path``: same stem, ``_META_SUFFIX`` extension."""
    stem, _old_ext = os.path.splitext(image_path)
    return stem + _META_SUFFIX
# ---------- Read / write ----------
def _read_sidecar(meta_path: str) -> dict:
    """Return the JSON object stored at ``meta_path``, or ``{}`` when it is unusable.

    A missing, unreadable, malformed or non-object sidecar yields ``{}``, so one
    bad file can never break the listing.
    """
    if not os.path.isfile(meta_path):
        return {}
    try:
        with open(meta_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return {}
    return data if isinstance(data, dict) else {}


def _sidecar_text(meta: dict, key: str) -> str:
    """Return ``meta[key]`` as text: missing/``None`` -> ``""``, other non-strings via ``str()``."""
    value = meta.get(key)
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
    """List every cached image, newest first (by ``created_at``).

    Each file in the cache directory with a supported image extension becomes
    one record. Metadata comes from the image's JSON sidecar when it is valid.
    Otherwise the id falls back to the file stem, the prompt to ``""`` and
    ``created_at`` to the file's modification time. If no image is found, a
    local placeholder is seeded via ``_seed_placeholder_record`` (when its
    asset exists) so the UI has something to show.

    All ``created_at`` values are strings, so the sort cannot fail on
    sidecars with unexpected value types.
    """
    cache_dir = get_cache_dir(base_dir)
    records: List[AIImageRecord] = []
    for name in os.listdir(cache_dir):
        full = os.path.join(cache_dir, name)
        if not (os.path.isfile(full) and name.lower().endswith(_SUPPORTED_IMG_EXT)):
            continue
        meta = _read_sidecar(_meta_path_for(full))
        created_at = _sidecar_text(meta, "created_at")
        if not created_at:
            # No usable timestamp in the sidecar: fall back to the file's mtime.
            try:
                created_at = _dt.datetime.fromtimestamp(os.path.getmtime(full)).isoformat()
            except (OSError, OverflowError, ValueError):
                created_at = ""
        records.append(
            AIImageRecord(
                id=_sidecar_text(meta, "id") or os.path.splitext(name)[0],
                prompt=_sidecar_text(meta, "prompt"),
                image_path=full,
                created_at=created_at,
                extra=meta.get("extra"),
            )
        )
    if not records:
        seeded = _seed_placeholder_record(cache_dir)
        if seeded is not None:
            records.append(seeded)
    # The ISO 8601 local timestamps written by this module order
    # chronologically as plain strings; empty ones sort last.
    records.sort(key=lambda r: r.created_at, reverse=True)
    return records
def _seed_placeholder_record(cache_dir: str) -> Optional[AIImageRecord]:
    """Copy a bundled placeholder image into an empty cache to aid front-end testing.

    The source is ``assets/entry_1.png`` two directories above this module's
    folder. Returns the new record, or ``None`` when that asset is missing or
    cannot be copied. If the image was copied but the JSON sidecar cannot be
    written, the record is still returned. This matches
    ``save_image_to_cache``: the image is already in the cache, and later
    listings fall back to file metadata.
    """
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    src = os.path.join(repo_root, "assets", "entry_1.png")
    if not os.path.isfile(src):
        return None
    # One timestamp for both the id and created_at so they always agree.
    now = _dt.datetime.now()
    rec_id = f"{now.strftime('%Y%m%d_%H%M%S')}_placeholder"
    image_path = os.path.join(cache_dir, f"{rec_id}.png")
    try:
        shutil.copyfile(src, image_path)
    except OSError:
        return None
    record = AIImageRecord(
        id=rec_id,
        prompt="本地测试占位图(后端未接入)",
        image_path=image_path,
        created_at=now.isoformat(timespec="seconds"),
        extra={"source": "local-placeholder"},
    )
    try:
        with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
            f.write(record.to_json())
    except OSError:
        pass  # best-effort: the copied image is listed even without its sidecar
    return record
def save_image_to_cache(
    prompt: str,
    image_bytes: bytes,
    image_ext: str = ".png",
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
) -> AIImageRecord:
    """Write image bytes into the cache and return the resulting record.

    Args:
        prompt: Prompt text stored in the record; it is also hashed into the id.
        image_bytes: Encoded image data, written verbatim.
        image_ext: File extension, with or without the leading dot, in any
            case. Empty, ``None`` or unsupported values fall back to ``.png``.
        extra: Optional JSON-serialisable metadata stored in the sidecar.
        base_dir: Cache root passed to ``get_cache_dir``.

    Returns:
        The new ``AIImageRecord``.

    Raises:
        OSError: if the image itself cannot be written. No partial image is
            left in the cache in that case.

    The image is written to a temporary ``.part`` file and then renamed into
    place, so ``list_records`` never picks up a half-written image. The
    sidecar is best-effort: if ``extra`` cannot be serialised or the write
    fails, the image is kept and the record is still returned.
    """
    ext = (image_ext or "").strip().lower()
    if not ext.startswith("."):
        ext = "." + ext
    if ext not in _SUPPORTED_IMG_EXT:
        ext = ".png"
    cache_dir = get_cache_dir(base_dir)
    rec_id = _make_id(prompt)
    image_path = os.path.join(cache_dir, f"{rec_id}{ext}")
    # ".part" is not a supported image extension, so list_records ignores it.
    tmp_path = image_path + ".part"
    try:
        with open(tmp_path, "wb") as f:
            f.write(image_bytes)
        os.replace(tmp_path, image_path)
    finally:
        # After a successful rename the temp file is gone. Otherwise remove the
        # partial write so it cannot linger in the cache.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass
    record = AIImageRecord(
        id=rec_id,
        prompt=prompt,
        image_path=image_path,
        created_at=_dt.datetime.now().isoformat(timespec="seconds"),
        extra=extra,
    )
    try:
        # Serialise before opening the file so a non-serialisable ``extra``
        # does not leave an empty sidecar behind.
        payload = record.to_json()
        with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
            f.write(payload)
    except (OSError, TypeError, ValueError):
        pass  # best-effort: the image is already safely in the cache
    return record
def import_image_from_url(
    image_url: str,
    prompt: Optional[str] = None,
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
    timeout: float = 20.0,
) -> AIImageRecord:
    """Download a remote image over HTTP(S) and store it in the cache.

    Args:
        image_url: Absolute ``http``/``https`` URL; surrounding whitespace is ignored.
        prompt: Prompt to record. When ``None``, empty or whitespace-only,
            a name derived from the URL's file name is used.
        extra: Extra metadata. It is merged with ``source``, ``source_url``
            and ``content_type`` keys describing the download.
        base_dir: Cache root passed through to ``save_image_to_cache``.
        timeout: Socket timeout in seconds for ``urlopen``.

    Returns:
        The cached ``AIImageRecord``.

    Raises:
        ValueError: if the URL is empty or not an http/https URL, or if the
            response body is empty. Without the scheme check, ``urlopen``
            would also read ``file://`` and other local schemes.
        Network and HTTP errors raised by ``urlopen`` propagate unchanged.
    """
    url = (image_url or "").strip()
    if not url:
        raise ValueError("图片地址不能为空")
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError("仅支持 http/https 图片地址")
    request = Request(
        url,
        headers={
            "User-Agent": "pqAutomationApp/1.0",
            "Accept": "image/*,*/*;q=0.8",
        },
    )
    with urlopen(request, timeout=timeout) as response:
        image_bytes = response.read()
        content_type = response.headers.get_content_type()
    if not image_bytes:
        raise ValueError("下载结果为空")
    image_ext = _guess_image_ext(image_url=url, content_type=content_type)
    merged_extra = dict(extra or {})
    merged_extra.update(
        {
            "source": "remote-url",
            "source_url": url,
            "content_type": content_type,
        }
    )
    # Strip first, so a whitespace-only prompt also falls back to the URL-derived name.
    record_prompt = (prompt or "").strip() or _default_prompt_from_url(url)
    return save_image_to_cache(
        prompt=record_prompt,
        image_bytes=image_bytes,
        image_ext=image_ext,
        extra=merged_extra,
        base_dir=base_dir,
    )
def delete_record(record: AIImageRecord) -> bool:
    """Remove a cached image together with its JSON sidecar.

    Files that are already absent are skipped. Both removals are always
    attempted. Returns ``False`` if either raised, otherwise ``True``.
    """
    image_file = record.image_path
    all_removed = True
    for target in [image_file, _meta_path_for(image_file)]:
        try:
            if not os.path.isfile(target):
                continue
            os.remove(target)
        except Exception:
            all_removed = False
    return all_removed
def export_record(record: AIImageRecord, dest_path: str) -> None:
    """Save a copy of ``record``'s cached image to ``dest_path``.

    Only the file contents are copied (``shutil.copyfile``). Errors such as a
    missing source file propagate to the caller.
    """
    source = record.image_path
    shutil.copyfile(source, dest_path)
# ---------- Background requests ----------
def request_image_async(
    prompt: str,
    on_success: Callable[[AIImageRecord], None],
    on_error: Callable[[Exception], None],
    base_dir: Optional[str] = None,
) -> threading.Thread:
    """Generate an image on a daemon thread: call the API, cache the result, then report.

    Exactly one of ``on_success(record)`` or ``on_error(exc)`` is invoked, and
    it runs on the **worker thread**. UI code that must touch widgets should
    switch back to the main thread itself, e.g. with ``root.after(0, ...)``.
    If ``on_success`` itself raises, the exception is not routed to
    ``on_error``; it propagates out of the worker thread instead.

    Returns:
        The already-started thread.
    """
    def _worker() -> None:
        try:
            # Read the global once so a concurrent set_api_caller(None) cannot
            # slip in between the check and the call.
            caller = _api_caller
            if caller is None:
                raise RuntimeError("AI 图片 API 尚未接入,请调用 set_api_caller 注入")
            image_bytes, image_ext, extra = _normalize_api_result(caller(prompt))
            record = save_image_to_cache(
                prompt=prompt,
                image_bytes=image_bytes,
                image_ext=image_ext,
                extra=extra,
                base_dir=base_dir,
            )
        except Exception as exc:
            on_error(exc)
        else:
            # Outside the try so a failing success callback does not also
            # trigger on_error.
            on_success(record)

    thread = threading.Thread(target=_worker, name="ai-image-request", daemon=True)
    thread.start()
    return thread
def import_image_from_url_async(
    image_url: str,
    on_success: Callable[[AIImageRecord], None],
    on_error: Callable[[Exception], None],
    prompt: Optional[str] = None,
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
    timeout: float = 20.0,
) -> threading.Thread:
    """Run ``import_image_from_url`` on a daemon thread and report the outcome.

    The arguments are forwarded unchanged. Exactly one of
    ``on_success(record)`` or ``on_error(exc)`` is called, on the worker
    thread. An exception raised by ``on_success`` itself is not routed to
    ``on_error``.

    Returns:
        The already-started thread.
    """
    def _worker() -> None:
        try:
            record = import_image_from_url(
                image_url=image_url,
                prompt=prompt,
                extra=extra,
                base_dir=base_dir,
                timeout=timeout,
            )
        except Exception as exc:
            on_error(exc)
        else:
            # Outside the try so a failing success callback does not also
            # trigger on_error.
            on_success(record)

    thread = threading.Thread(target=_worker, name="ai-image-import", daemon=True)
    thread.start()
    return thread
def _normalize_api_result(result):
    """Normalise an API caller's return value to ``(image_bytes, image_ext, extra)``.

    Accepted shapes:
        * raw ``bytes`` / ``bytearray`` / ``memoryview``; the extension
          defaults to ``".png"``;
        * a tuple or list ``(payload, ext)`` or ``(payload, ext, extra)``,
          where ``payload`` is bytes-like. A ``None`` ``ext`` becomes
          ``".png"``; other values go through ``str()``. ``extra`` is passed
          through unchanged.

    Raises:
        ValueError: for any other shape, including a payload that is not
            bytes-like. Without this check ``bytes(5)`` would silently produce
            five zero bytes.
    """
    bytes_like = (bytes, bytearray, memoryview)
    if isinstance(result, bytes_like):
        return bytes(result), ".png", None
    if isinstance(result, (tuple, list)) and len(result) in (2, 3):
        payload, ext = result[0], result[1]
        if isinstance(payload, bytes_like):
            extra = result[2] if len(result) == 3 else None
            return bytes(payload), (".png" if ext is None else str(ext)), extra
    raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])")
def is_remote_image_url(value: str) -> bool:
    """Tell whether ``value`` is an absolute ``http``/``https`` URL with a host.

    Surrounding whitespace is ignored, and ``None`` or empty input gives
    ``False``. Only the scheme and host are checked; the path is not required
    to look like an image.
    """
    candidate = (value or "").strip()
    if not candidate:
        return False
    parts = urlparse(candidate)
    if parts.scheme not in ("http", "https"):
        return False
    return bool(parts.netloc)
def _guess_image_ext(image_url: str, content_type: Optional[str]) -> str:
    """Pick the cache file extension for a downloaded image.

    Order of preference:
        1. the MIME type in ``content_type``: first an explicit table of
           common image types, then ``mimetypes``;
        2. the extension of the URL path;
        3. ``".png"``.

    ``mimetypes`` may consult platform data (e.g. the Windows registry) and
    return variants such as ``.jpe`` or ``.jfif`` for JPEG. The explicit table
    keeps the common types independent of the platform.
    """
    known = {
        "image/png": ".png",
        "image/jpeg": ".jpg",
        "image/jpg": ".jpg",
        "image/pjpeg": ".jpg",
        "image/bmp": ".bmp",
        "image/x-ms-bmp": ".bmp",
        "image/webp": ".webp",
    }
    if content_type:
        # Tolerate parameters ("; charset=...") and odd casing in the header value.
        mime = content_type.split(";", 1)[0].strip().lower()
        if mime in known:
            return known[mime]
        guessed = mimetypes.guess_extension(mime)
        if guessed:
            guessed = guessed.lower()
            if guessed in (".jpe", ".jfif"):
                guessed = ".jpg"
            if guessed in _SUPPORTED_IMG_EXT:
                return guessed
    url_ext = os.path.splitext(urlparse(image_url).path)[1].lower()
    if url_ext in _SUPPORTED_IMG_EXT:
        return url_ext
    return ".png"
def _default_prompt_from_url(image_url: str) -> str:
    """Derive a readable prompt from the file name in ``image_url``.

    The URL path is percent-decoded first, so ``.../%E5%9B%BE.png`` yields
    ``图`` rather than the escaped text. The query string is ignored. Falls
    back to ``"远程导入图片"`` when the path has no usable file name.
    """
    path = unquote(urlparse(image_url).path)
    stem = os.path.splitext(os.path.basename(path))[0].strip()
    return stem or "远程导入图片"