添加AI图像生成接口、修改相关界面
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
"""AI 图片生成服务:后端请求 + 本地缓存管理。
|
||||
|
||||
API 端点待接入,当前通过 ``set_api_caller`` 注入具体实现。
|
||||
后端接口(测试环境):
|
||||
POST {API_BASE_URL}{API_PATH}
|
||||
body: {"user_message": str, "session_id": str}
|
||||
resp: {"code": 200, "message": "", "data": {"imageUrl": "..."}}
|
||||
|
||||
缓存目录:``settings/ai_image_cache/``,每张图片有同名的 ``.json`` 侧车记录。
|
||||
"""
|
||||
|
||||
@@ -9,15 +13,24 @@ from __future__ import annotations
|
||||
import datetime as _dt
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Callable, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------- 常量 ----------
|
||||
|
||||
@@ -25,41 +38,166 @@ _CACHE_DIRNAME = os.path.join("settings", "ai_image_cache")
|
||||
_META_SUFFIX = ".json"
|
||||
_SUPPORTED_IMG_EXT = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
|
||||
|
||||
# 测试环境后端
|
||||
API_BASE_URL = "http://10.201.44.70:9018/ai-agent/"
|
||||
API_PATH = "api/v1/pqtest/generate"
|
||||
API_TIMEOUT = 90.0 # 后端最长 60s,留余量
|
||||
|
||||
# 进程级会话 id(多轮对话需保持一致),可通过 ``reset_session`` 重置
|
||||
_session_id: str = str(uuid.uuid4())
|
||||
_session_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_session_id() -> str:
|
||||
with _session_lock:
|
||||
return _session_id
|
||||
|
||||
|
||||
def set_session_id(session_id: str) -> str:
|
||||
"""切换到指定会话。空值会抛错。"""
|
||||
global _session_id
|
||||
sid = (session_id or "").strip()
|
||||
if not sid:
|
||||
raise ValueError("session_id 不能为空")
|
||||
with _session_lock:
|
||||
old = _session_id
|
||||
_session_id = sid
|
||||
logger.info("[AIImage] 会话切换 %s -> %s", _mask_sid(old), _mask_sid(_session_id))
|
||||
return _session_id
|
||||
|
||||
|
||||
def reset_session() -> str:
|
||||
"""开启新一轮会话,返回新的 session_id。"""
|
||||
global _session_id
|
||||
with _session_lock:
|
||||
old = _session_id
|
||||
_session_id = str(uuid.uuid4())
|
||||
logger.info("[AIImage] 会话切换 %s -> %s", _mask_sid(old), _mask_sid(_session_id))
|
||||
return _session_id
|
||||
|
||||
|
||||
def _mask_sid(sid: str) -> str:
|
||||
"""日志安全展示:仅保留前 8 位。"""
|
||||
if not sid:
|
||||
return "(none)"
|
||||
return f"{sid[:8]}…"
|
||||
|
||||
|
||||
def _truncate(text: str, n: int = 80) -> str:
|
||||
s = (text or "").replace("\n", " ").strip()
|
||||
return s if len(s) <= n else s[:n] + "…"
|
||||
|
||||
|
||||
# ---------- 数据结构 ----------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIImageRecord:
|
||||
"""一条缓存记录。"""
|
||||
"""一条缓存记录。
|
||||
|
||||
字段说明:
|
||||
- ``id``: 唯一 id,等同于磁盘文件名(不含扩展名),格式 ``{时间戳}_{md5前8}``。
|
||||
- ``prompt``: 用户原始输入(完整保留,用于回溯/调试,不应被改写)。
|
||||
- ``title``: 用户自定义展示标题(重命名时写入),UI 优先使用,留空则回退 prompt 第一行截断。
|
||||
- ``image_path``: 图片在缓存目录中的绝对路径。
|
||||
- ``created_at``: ISO8601 时间字符串。
|
||||
- ``extra``: 其它元数据,至少包含 ``source`` 与 ``session_id``(标识属于哪一轮对话)。
|
||||
"""
|
||||
|
||||
id: str
|
||||
prompt: str
|
||||
image_path: str
|
||||
created_at: str # ISO8601
|
||||
extra: Optional[dict] = None
|
||||
title: Optional[str] = None
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self), ensure_ascii=False, indent=2)
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""UI 展示名:title 优先,否则回退 prompt 第一行。"""
|
||||
if self.title:
|
||||
return self.title.strip()
|
||||
first = (self.prompt or "").strip().splitlines()[0] if self.prompt else ""
|
||||
return first or "(未命名)"
|
||||
|
||||
# ---------- API 注入 ----------
|
||||
|
||||
# 调用签名: ``fn(prompt: str) -> (image_bytes: bytes, image_ext: str, extra: dict|None)``
|
||||
# ``image_ext`` 例如 ``".png"``;``extra`` 可为 None。
|
||||
_ApiCaller = Callable[[str], tuple]
|
||||
|
||||
_api_caller: Optional[_ApiCaller] = None
|
||||
@property
|
||||
def session_id(self) -> str:
|
||||
if isinstance(self.extra, dict):
|
||||
return str(self.extra.get("session_id") or "")
|
||||
return ""
|
||||
|
||||
|
||||
def set_api_caller(fn: Optional[_ApiCaller]) -> None:
|
||||
"""注入真实的后端 API 调用函数。在 API 就绪前可保持为 None。"""
|
||||
global _api_caller
|
||||
_api_caller = fn
|
||||
# ---------- 后端 API ----------
|
||||
|
||||
|
||||
def has_api() -> bool:
|
||||
return _api_caller is not None
|
||||
def _api_endpoint() -> str:
|
||||
base = API_BASE_URL if API_BASE_URL.endswith("/") else API_BASE_URL + "/"
|
||||
return base + API_PATH.lstrip("/")
|
||||
|
||||
|
||||
def _call_pqtest_generate(user_message: str, session_id: str, timeout: float = API_TIMEOUT) -> str:
|
||||
"""调用后端 ``api/v1/pqtest/generate``,返回 imageUrl。失败抛异常。"""
|
||||
payload = json.dumps(
|
||||
{"user_message": user_message,
|
||||
"session_id": session_id},
|
||||
ensure_ascii=False,
|
||||
).encode("utf-8")
|
||||
endpoint = _api_endpoint()
|
||||
logger.info(
|
||||
"[AIImage] 请求生成 sid=%s prompt_len=%d prompt=%r",
|
||||
_mask_sid(session_id), len(user_message or ""), _truncate(user_message),
|
||||
)
|
||||
logger.debug("[AIImage] POST %s timeout=%.1fs", endpoint, timeout)
|
||||
request = Request(
|
||||
endpoint,
|
||||
data=payload,
|
||||
method="POST",
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "pqAutomationApp/1.0",
|
||||
},
|
||||
)
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
with urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read()
|
||||
http_status = response.status
|
||||
except Exception as exc:
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.error(
|
||||
"[AIImage] 请求异常 sid=%s elapsed=%.2fs %s: %s",
|
||||
_mask_sid(session_id), elapsed, type(exc).__name__, exc,
|
||||
)
|
||||
raise
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.debug(
|
||||
"[AIImage] HTTP %s 收到 %d bytes elapsed=%.2fs",
|
||||
http_status, len(raw), elapsed,
|
||||
)
|
||||
try:
|
||||
result = json.loads(raw.decode("utf-8"))
|
||||
except Exception as exc:
|
||||
logger.error("[AIImage] 响应解析失败 sid=%s raw=%r", _mask_sid(session_id), raw[:200])
|
||||
raise RuntimeError(f"AI 接口返回非 JSON:{raw[:200]!r}") from exc
|
||||
|
||||
code = result.get("code")
|
||||
message = result.get("message") or ""
|
||||
data = result.get("data") or {}
|
||||
image_url = (data.get("imageUrl") or "").strip()
|
||||
if code != 200 or not image_url:
|
||||
logger.warning(
|
||||
"[AIImage] 接口失败 sid=%s code=%s msg=%r",
|
||||
_mask_sid(session_id), code, message,
|
||||
)
|
||||
raise RuntimeError(f"AI 接口失败 code={code} msg={message or '生成失败'}")
|
||||
logger.info(
|
||||
"[AIImage] 生成成功 sid=%s elapsed=%.2fs url=%s",
|
||||
_mask_sid(session_id), elapsed, image_url,
|
||||
)
|
||||
return image_url
|
||||
|
||||
|
||||
# ---------- 缓存路径工具 ----------
|
||||
@@ -83,6 +221,28 @@ def _meta_path_for(image_path: str) -> str:
|
||||
return os.path.splitext(image_path)[0] + _META_SUFFIX
|
||||
|
||||
|
||||
def _sanitize_image_bytes(image_bytes: bytes, image_ext: str) -> bytes:
|
||||
"""规范化图片字节,尽量去掉已知有问题的 PNG ICC profile。"""
|
||||
ext = (image_ext or "").lower()
|
||||
if ext not in {".png", ".jpg", ".jpeg", ".bmp", ".webp"}:
|
||||
return image_bytes
|
||||
try:
|
||||
with Image.open(BytesIO(image_bytes)) as img:
|
||||
img.load()
|
||||
normalized = img.copy()
|
||||
output = BytesIO()
|
||||
save_kwargs = {}
|
||||
if ext == ".png":
|
||||
save_kwargs["icc_profile"] = None
|
||||
normalized.save(output, format=normalized.format or ext.lstrip(".").upper(), **save_kwargs)
|
||||
result = output.getvalue()
|
||||
if result:
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("[AIImage] 图片规范化失败 ext=%s %s: %s", ext, type(exc).__name__, exc)
|
||||
return image_bytes
|
||||
|
||||
|
||||
# ---------- 读写 ----------
|
||||
|
||||
|
||||
@@ -98,6 +258,7 @@ def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
|
||||
prompt = ""
|
||||
created_at = ""
|
||||
extra = None
|
||||
title = None
|
||||
rec_id = os.path.splitext(name)[0]
|
||||
if os.path.isfile(meta_path):
|
||||
try:
|
||||
@@ -106,6 +267,7 @@ def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
|
||||
prompt = data.get("prompt", "")
|
||||
created_at = data.get("created_at", "")
|
||||
extra = data.get("extra")
|
||||
title = data.get("title")
|
||||
rec_id = data.get("id", rec_id)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -123,54 +285,27 @@ def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
|
||||
image_path=full,
|
||||
created_at=created_at,
|
||||
extra=extra,
|
||||
title=title,
|
||||
)
|
||||
)
|
||||
if not records:
|
||||
seeded = _seed_placeholder_record(cache_dir)
|
||||
if seeded is not None:
|
||||
records.append(seeded)
|
||||
records.sort(key=lambda r: r.created_at, reverse=True)
|
||||
return records
|
||||
|
||||
|
||||
def _seed_placeholder_record(cache_dir: str) -> Optional[AIImageRecord]:
|
||||
"""当缓存为空时,写入一张本地占位图,便于前端联调。"""
|
||||
try:
|
||||
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
src = os.path.join(repo_root, "assets", "entry_1.png")
|
||||
if not os.path.isfile(src):
|
||||
return None
|
||||
|
||||
rec_id = f"{_dt.datetime.now().strftime('%Y%m%d_%H%M%S')}_placeholder"
|
||||
image_path = os.path.join(cache_dir, f"{rec_id}.png")
|
||||
shutil.copyfile(src, image_path)
|
||||
|
||||
record = AIImageRecord(
|
||||
id=rec_id,
|
||||
prompt="本地测试占位图(后端未接入)",
|
||||
image_path=image_path,
|
||||
created_at=_dt.datetime.now().isoformat(timespec="seconds"),
|
||||
extra={"source": "local-placeholder"},
|
||||
)
|
||||
with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
|
||||
f.write(record.to_json())
|
||||
return record
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_image_to_cache(
|
||||
prompt: str,
|
||||
image_bytes: bytes,
|
||||
image_ext: str = ".png",
|
||||
extra: Optional[dict] = None,
|
||||
base_dir: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
) -> AIImageRecord:
|
||||
"""把生成的图片字节写入缓存,返回记录。"""
|
||||
if not image_ext.startswith("."):
|
||||
image_ext = "." + image_ext
|
||||
if image_ext.lower() not in _SUPPORTED_IMG_EXT:
|
||||
image_ext = ".png"
|
||||
image_bytes = _sanitize_image_bytes(image_bytes, image_ext)
|
||||
cache_dir = get_cache_dir(base_dir)
|
||||
rec_id = _make_id(prompt)
|
||||
image_path = os.path.join(cache_dir, f"{rec_id}{image_ext}")
|
||||
@@ -183,6 +318,7 @@ def save_image_to_cache(
|
||||
image_path=image_path,
|
||||
created_at=_dt.datetime.now().isoformat(timespec="seconds"),
|
||||
extra=extra,
|
||||
title=title,
|
||||
)
|
||||
try:
|
||||
with open(_meta_path_for(image_path), "w", encoding="utf-8") as f:
|
||||
@@ -256,6 +392,53 @@ def export_record(record: AIImageRecord, dest_path: str) -> None:
|
||||
shutil.copyfile(record.image_path, dest_path)
|
||||
|
||||
|
||||
def update_record_title(record: AIImageRecord, new_title: Optional[str]) -> bool:
|
||||
"""更新记录的展示标题并写回侧车 JSON。空串/None 视为清除标题。"""
|
||||
title = (new_title or "").strip() or None
|
||||
meta_path = _meta_path_for(record.image_path)
|
||||
try:
|
||||
data: dict = {}
|
||||
if os.path.isfile(meta_path):
|
||||
with open(meta_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f) or {}
|
||||
if title is None:
|
||||
data.pop("title", None)
|
||||
else:
|
||||
data["title"] = title
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
return False
|
||||
record.title = title
|
||||
return True
|
||||
|
||||
|
||||
def group_records_by_session(records: List[AIImageRecord]) -> List[dict]:
|
||||
"""按 ``session_id`` 分组。
|
||||
|
||||
返回元素:``{"session_id", "records", "started_at", "latest_at"}``。
|
||||
会话按"最近使用时间"倒序,组内记录按时间倒序。
|
||||
没有 session_id 的记录归到空串 ``""`` 组。
|
||||
"""
|
||||
buckets: dict = {}
|
||||
for rec in records:
|
||||
buckets.setdefault(rec.session_id, []).append(rec)
|
||||
sessions = []
|
||||
for sid, recs in buckets.items():
|
||||
recs.sort(key=lambda r: r.created_at, reverse=True)
|
||||
started_at = min((r.created_at for r in recs if r.created_at), default="")
|
||||
sessions.append(
|
||||
{
|
||||
"session_id": sid,
|
||||
"records": recs,
|
||||
"started_at": started_at,
|
||||
"latest_at": recs[0].created_at if recs else "",
|
||||
}
|
||||
)
|
||||
sessions.sort(key=lambda s: s["latest_at"], reverse=True)
|
||||
return sessions
|
||||
|
||||
|
||||
# ---------- 异步请求 ----------
|
||||
|
||||
|
||||
@@ -264,27 +447,37 @@ def request_image_async(
|
||||
on_success: Callable[[AIImageRecord], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
base_dir: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> threading.Thread:
|
||||
"""在后台线程请求 API → 写入缓存 → 回调。
|
||||
"""在后台线程调用后端 API → 下载图片 → 写入缓存 → 回调。
|
||||
|
||||
``on_success`` / ``on_error`` 会在 **工作线程** 中被调用;UI 侧若需
|
||||
切回主线程,请在回调内部自行用 ``root.after(0, ...)``。
|
||||
|
||||
``session_id`` 留空则使用进程级会话 id(保证多轮对话上下文)。
|
||||
"""
|
||||
|
||||
sid = session_id or get_session_id()
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
if _api_caller is None:
|
||||
raise RuntimeError("AI 图片 API 尚未接入,请调用 set_api_caller 注入")
|
||||
image_bytes, image_ext, extra = _normalize_api_result(_api_caller(prompt))
|
||||
record = save_image_to_cache(
|
||||
image_url = _call_pqtest_generate(prompt, sid)
|
||||
record = import_image_from_url(
|
||||
image_url=image_url,
|
||||
prompt=prompt,
|
||||
image_bytes=image_bytes,
|
||||
image_ext=image_ext,
|
||||
extra=extra,
|
||||
extra={"source": "ai-api", "session_id": sid},
|
||||
base_dir=base_dir,
|
||||
)
|
||||
logger.info(
|
||||
"[AIImage] 已写入缓存 sid=%s id=%s path=%s",
|
||||
_mask_sid(sid), record.id, record.image_path,
|
||||
)
|
||||
on_success(record)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[AIImage] 生成流程失败 sid=%s %s: %s",
|
||||
_mask_sid(sid), type(exc).__name__, exc,
|
||||
)
|
||||
on_error(exc)
|
||||
|
||||
t = threading.Thread(target=_worker, daemon=True)
|
||||
@@ -321,18 +514,6 @@ def import_image_from_url_async(
|
||||
return t
|
||||
|
||||
|
||||
def _normalize_api_result(result):
|
||||
"""允许 API 返回 ``bytes`` 或 ``(bytes, ext)`` 或 ``(bytes, ext, extra)``。"""
|
||||
if isinstance(result, (bytes, bytearray)):
|
||||
return bytes(result), ".png", None
|
||||
if isinstance(result, tuple):
|
||||
if len(result) == 2:
|
||||
return bytes(result[0]), str(result[1]), None
|
||||
if len(result) == 3:
|
||||
return bytes(result[0]), str(result[1]), result[2]
|
||||
raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])")
|
||||
|
||||
|
||||
def is_remote_image_url(value: str) -> bool:
|
||||
"""判断输入是否为 http/https 图片地址。"""
|
||||
url = (value or "").strip()
|
||||
|
||||
Reference in New Issue
Block a user