添加AI图像生成接口、修改相关界面

This commit is contained in:
xinzhu.yin
2026-04-29 15:25:58 +08:00
parent e243fc4e94
commit 377bba2a0b
4 changed files with 443 additions and 134 deletions

View File

@@ -1,6 +1,10 @@
"""AI 图片生成服务:后端请求 + 本地缓存管理。
API 端点待接入,当前通过 ``set_api_caller`` 注入具体实现。
后端接口(测试环境):
POST {API_BASE_URL}{API_PATH}
body: {"user_message": str, "session_id": str}
resp: {"code": 200, "message": "", "data": {"imageUrl": "..."}}
缓存目录:``settings/ai_image_cache/``,每张图片有同名的 ``.json`` 侧车记录。
"""
@@ -9,15 +13,24 @@ from __future__ import annotations
import datetime as _dt
import hashlib
import json
import logging
import mimetypes
import os
import shutil
import threading
import time
import uuid
from io import BytesIO
from dataclasses import dataclass, asdict
from typing import Callable, List, Optional
from urllib.parse import urlparse
from urllib.request import Request, urlopen
from PIL import Image
logger = logging.getLogger(__name__)
# ---------- 常量 ----------
@@ -25,41 +38,166 @@ _CACHE_DIRNAME = os.path.join("settings", "ai_image_cache")
_META_SUFFIX = ".json"
_SUPPORTED_IMG_EXT = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
# 测试环境后端
API_BASE_URL = "http://10.201.44.70:9018/ai-agent/"
API_PATH = "api/v1/pqtest/generate"
API_TIMEOUT = 90.0 # 后端最长 60s留余量
# 进程级会话 id多轮对话需保持一致可通过 ``reset_session`` 重置
_session_id: str = str(uuid.uuid4())
_session_lock = threading.Lock()
def get_session_id() -> str:
with _session_lock:
return _session_id
def set_session_id(session_id: str) -> str:
"""切换到指定会话。空值会抛错。"""
global _session_id
sid = (session_id or "").strip()
if not sid:
raise ValueError("session_id 不能为空")
with _session_lock:
old = _session_id
_session_id = sid
logger.info("[AIImage] 会话切换 %s -> %s", _mask_sid(old), _mask_sid(_session_id))
return _session_id
def reset_session() -> str:
"""开启新一轮会话,返回新的 session_id。"""
global _session_id
with _session_lock:
old = _session_id
_session_id = str(uuid.uuid4())
logger.info("[AIImage] 会话切换 %s -> %s", _mask_sid(old), _mask_sid(_session_id))
return _session_id
def _mask_sid(sid: str) -> str:
"""日志安全展示:仅保留前 8 位。"""
if not sid:
return "(none)"
return f"{sid[:8]}"
def _truncate(text: str, n: int = 80) -> str:
s = (text or "").replace("\n", " ").strip()
return s if len(s) <= n else s[:n] + ""
# ---------- 数据结构 ----------
@dataclass
class AIImageRecord:
"""一条缓存记录。"""
"""一条缓存记录。
字段说明:
- ``id``: 唯一 id等同于磁盘文件名不含扩展名格式 ``{时间戳}_{md5前8}``。
- ``prompt``: 用户原始输入(完整保留,用于回溯/调试,不应被改写)。
- ``title``: 用户自定义展示标题重命名时写入UI 优先使用,留空则回退 prompt 第一行截断。
- ``image_path``: 图片在缓存目录中的绝对路径。
- ``created_at``: ISO8601 时间字符串。
- ``extra``: 其它元数据,至少包含 ``source`` 与 ``session_id``(标识属于哪一轮对话)。
"""
id: str
prompt: str
image_path: str
created_at: str # ISO8601
extra: Optional[dict] = None
title: Optional[str] = None
def to_json(self) -> str:
return json.dumps(asdict(self), ensure_ascii=False, indent=2)
@property
def display_name(self) -> str:
"""UI 展示名title 优先,否则回退 prompt 第一行。"""
if self.title:
return self.title.strip()
first = (self.prompt or "").strip().splitlines()[0] if self.prompt else ""
return first or "(未命名)"
# ---------- API 注入 ----------
# 调用签名: ``fn(prompt: str) -> (image_bytes: bytes, image_ext: str, extra: dict|None)``
# ``image_ext`` 例如 ``".png"````extra`` 可为 None。
_ApiCaller = Callable[[str], tuple]
_api_caller: Optional[_ApiCaller] = None
@property
def session_id(self) -> str:
if isinstance(self.extra, dict):
return str(self.extra.get("session_id") or "")
return ""
def set_api_caller(fn: Optional[_ApiCaller]) -> None:
"""注入真实的后端 API 调用函数。在 API 就绪前可保持为 None。"""
global _api_caller
_api_caller = fn
# ---------- 后端 API ----------
def has_api() -> bool:
return _api_caller is not None
def _api_endpoint() -> str:
base = API_BASE_URL if API_BASE_URL.endswith("/") else API_BASE_URL + "/"
return base + API_PATH.lstrip("/")
def _call_pqtest_generate(user_message: str, session_id: str, timeout: float = API_TIMEOUT) -> str:
"""调用后端 ``api/v1/pqtest/generate``,返回 imageUrl。失败抛异常。"""
payload = json.dumps(
{"user_message": user_message,
"session_id": session_id},
ensure_ascii=False,
).encode("utf-8")
endpoint = _api_endpoint()
logger.info(
"[AIImage] 请求生成 sid=%s prompt_len=%d prompt=%r",
_mask_sid(session_id), len(user_message or ""), _truncate(user_message),
)
logger.debug("[AIImage] POST %s timeout=%.1fs", endpoint, timeout)
request = Request(
endpoint,
data=payload,
method="POST",
headers={
"Content-Type": "application/json; charset=utf-8",
"Accept": "application/json",
"User-Agent": "pqAutomationApp/1.0",
},
)
t0 = time.monotonic()
try:
with urlopen(request, timeout=timeout) as response:
raw = response.read()
http_status = response.status
except Exception as exc:
elapsed = time.monotonic() - t0
logger.error(
"[AIImage] 请求异常 sid=%s elapsed=%.2fs %s: %s",
_mask_sid(session_id), elapsed, type(exc).__name__, exc,
)
raise
elapsed = time.monotonic() - t0
logger.debug(
"[AIImage] HTTP %s 收到 %d bytes elapsed=%.2fs",
http_status, len(raw), elapsed,
)
try:
result = json.loads(raw.decode("utf-8"))
except Exception as exc:
logger.error("[AIImage] 响应解析失败 sid=%s raw=%r", _mask_sid(session_id), raw[:200])
raise RuntimeError(f"AI 接口返回非 JSON{raw[:200]!r}") from exc
code = result.get("code")
message = result.get("message") or ""
data = result.get("data") or {}
image_url = (data.get("imageUrl") or "").strip()
if code != 200 or not image_url:
logger.warning(
"[AIImage] 接口失败 sid=%s code=%s msg=%r",
_mask_sid(session_id), code, message,
)
raise RuntimeError(f"AI 接口失败 code={code} msg={message or '生成失败'}")
logger.info(
"[AIImage] 生成成功 sid=%s elapsed=%.2fs url=%s",
_mask_sid(session_id), elapsed, image_url,
)
return image_url
# ---------- 缓存路径工具 ----------
@@ -83,6 +221,28 @@ def _meta_path_for(image_path: str) -> str:
return os.path.splitext(image_path)[0] + _META_SUFFIX
def _sanitize_image_bytes(image_bytes: bytes, image_ext: str) -> bytes:
"""规范化图片字节,尽量去掉已知有问题的 PNG ICC profile。"""
ext = (image_ext or "").lower()
if ext not in {".png", ".jpg", ".jpeg", ".bmp", ".webp"}:
return image_bytes
try:
with Image.open(BytesIO(image_bytes)) as img:
img.load()
normalized = img.copy()
output = BytesIO()
save_kwargs = {}
if ext == ".png":
save_kwargs["icc_profile"] = None
normalized.save(output, format=normalized.format or ext.lstrip(".").upper(), **save_kwargs)
result = output.getvalue()
if result:
return result
except Exception as exc:
logger.warning("[AIImage] 图片规范化失败 ext=%s %s: %s", ext, type(exc).__name__, exc)
return image_bytes
# ---------- 读写 ----------
@@ -98,6 +258,7 @@ def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
prompt = ""
created_at = ""
extra = None
title = None
rec_id = os.path.splitext(name)[0]
if os.path.isfile(meta_path):
try:
@@ -106,6 +267,7 @@ def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
prompt = data.get("prompt", "")
created_at = data.get("created_at", "")
extra = data.get("extra")
title = data.get("title")
rec_id = data.get("id", rec_id)
except Exception:
pass
@@ -123,54 +285,27 @@ def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
image_path=full,
created_at=created_at,
extra=extra,
title=title,
)
)
if not records:
seeded = _seed_placeholder_record(cache_dir)
if seeded is not None:
records.append(seeded)
records.sort(key=lambda r: r.created_at, reverse=True)
return records
def _seed_placeholder_record(cache_dir: str) -> Optional[AIImageRecord]:
"""当缓存为空时,写入一张本地占位图,便于前端联调。"""
try:
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
src = os.path.join(repo_root, "assets", "entry_1.png")
if not os.path.isfile(src):
return None
rec_id = f"{_dt.datetime.now().strftime('%Y%m%d_%H%M%S')}_placeholder"
image_path = os.path.join(cache_dir, f"{rec_id}.png")
shutil.copyfile(src, image_path)
record = AIImageRecord(
id=rec_id,
prompt="本地测试占位图(后端未接入)",
image_path=image_path,
created_at=_dt.datetime.now().isoformat(timespec="seconds"),
extra={"source": "local-placeholder"},
)
with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
f.write(record.to_json())
return record
except Exception:
return None
def save_image_to_cache(
prompt: str,
image_bytes: bytes,
image_ext: str = ".png",
extra: Optional[dict] = None,
base_dir: Optional[str] = None,
title: Optional[str] = None,
) -> AIImageRecord:
"""把生成的图片字节写入缓存,返回记录。"""
if not image_ext.startswith("."):
image_ext = "." + image_ext
if image_ext.lower() not in _SUPPORTED_IMG_EXT:
image_ext = ".png"
image_bytes = _sanitize_image_bytes(image_bytes, image_ext)
cache_dir = get_cache_dir(base_dir)
rec_id = _make_id(prompt)
image_path = os.path.join(cache_dir, f"{rec_id}{image_ext}")
@@ -183,6 +318,7 @@ def save_image_to_cache(
image_path=image_path,
created_at=_dt.datetime.now().isoformat(timespec="seconds"),
extra=extra,
title=title,
)
try:
with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
@@ -256,6 +392,53 @@ def export_record(record: AIImageRecord, dest_path: str) -> None:
shutil.copyfile(record.image_path, dest_path)
def update_record_title(record: AIImageRecord, new_title: Optional[str]) -> bool:
"""更新记录的展示标题并写回侧车 JSON。空串/None 视为清除标题。"""
title = (new_title or "").strip() or None
meta_path = _meta_path_for(record.image_path)
try:
data: dict = {}
if os.path.isfile(meta_path):
with open(meta_path, "r", encoding="utf-8") as f:
data = json.load(f) or {}
if title is None:
data.pop("title", None)
else:
data["title"] = title
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
except Exception:
return False
record.title = title
return True
def group_records_by_session(records: List[AIImageRecord]) -> List[dict]:
"""按 ``session_id`` 分组。
返回元素:``{"session_id", "records", "started_at", "latest_at"}``。
会话按"最近使用时间"倒序,组内记录按时间倒序。
没有 session_id 的记录归到空串 ``""`` 组。
"""
buckets: dict = {}
for rec in records:
buckets.setdefault(rec.session_id, []).append(rec)
sessions = []
for sid, recs in buckets.items():
recs.sort(key=lambda r: r.created_at, reverse=True)
started_at = min((r.created_at for r in recs if r.created_at), default="")
sessions.append(
{
"session_id": sid,
"records": recs,
"started_at": started_at,
"latest_at": recs[0].created_at if recs else "",
}
)
sessions.sort(key=lambda s: s["latest_at"], reverse=True)
return sessions
# ---------- 异步请求 ----------
@@ -264,27 +447,37 @@ def request_image_async(
on_success: Callable[[AIImageRecord], None],
on_error: Callable[[Exception], None],
base_dir: Optional[str] = None,
session_id: Optional[str] = None,
) -> threading.Thread:
"""在后台线程请求 API → 写入缓存 → 回调。
"""在后台线程调用后端 API → 下载图片 → 写入缓存 → 回调。
``on_success`` / ``on_error`` 会在 **工作线程** 中被调用UI 侧若需
切回主线程,请在回调内部自行用 ``root.after(0, ...)``。
``session_id`` 留空则使用进程级会话 id保证多轮对话上下文
"""
sid = session_id or get_session_id()
def _worker():
try:
if _api_caller is None:
raise RuntimeError("AI 图片 API 尚未接入,请调用 set_api_caller 注入")
image_bytes, image_ext, extra = _normalize_api_result(_api_caller(prompt))
record = save_image_to_cache(
image_url = _call_pqtest_generate(prompt, sid)
record = import_image_from_url(
image_url=image_url,
prompt=prompt,
image_bytes=image_bytes,
image_ext=image_ext,
extra=extra,
extra={"source": "ai-api", "session_id": sid},
base_dir=base_dir,
)
logger.info(
"[AIImage] 已写入缓存 sid=%s id=%s path=%s",
_mask_sid(sid), record.id, record.image_path,
)
on_success(record)
except Exception as exc:
logger.error(
"[AIImage] 生成流程失败 sid=%s %s: %s",
_mask_sid(sid), type(exc).__name__, exc,
)
on_error(exc)
t = threading.Thread(target=_worker, daemon=True)
@@ -321,18 +514,6 @@ def import_image_from_url_async(
return t
def _normalize_api_result(result):
"""允许 API 返回 ``bytes`` 或 ``(bytes, ext)`` 或 ``(bytes, ext, extra)``。"""
if isinstance(result, (bytes, bytearray)):
return bytes(result), ".png", None
if isinstance(result, tuple):
if len(result) == 2:
return bytes(result[0]), str(result[1]), None
if len(result) == 3:
return bytes(result[0]), str(result[1]), result[2]
raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])")
def is_remote_image_url(value: str) -> bool:
"""判断输入是否为 http/https 图片地址。"""
url = (value or "").strip()