Files
pqAutomationApp/app/services/ai_image.py

545 lines
17 KiB
Python
Raw Normal View History

2026-04-21 14:06:48 +08:00
"""AI 图片生成服务:后端请求 + 本地缓存管理。
后端接口测试环境
POST {API_BASE_URL}{API_PATH}
body: {"user_message": str, "session_id": str}
resp: {"code": 200, "message": "", "data": {"imageUrl": "..."}}
2026-04-21 14:06:48 +08:00
缓存目录``settings/ai_image_cache/``每张图片有同名的 ``.json`` 侧车记录
"""
from __future__ import annotations
import datetime as _dt
import hashlib
import json
import logging
import mimetypes
2026-04-21 14:06:48 +08:00
import os
import shutil
import threading
import time
import uuid
from io import BytesIO
2026-04-21 14:06:48 +08:00
from dataclasses import dataclass, asdict
from typing import Callable, List, Optional
from urllib.parse import urlparse
from urllib.request import Request, urlopen
2026-04-21 14:06:48 +08:00
from PIL import Image
logger = logging.getLogger(__name__)
2026-04-21 14:06:48 +08:00
# ---------- 常量 ----------
_CACHE_DIRNAME = os.path.join("settings", "ai_image_cache")
_META_SUFFIX = ".json"
_SUPPORTED_IMG_EXT = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
# 测试环境后端
API_BASE_URL = "http://10.201.44.70:9018/ai-agent/"
API_PATH = "api/v1/pqtest/generate"
API_TIMEOUT = 90.0 # 后端最长 60s留余量
# 进程级会话 id多轮对话需保持一致可通过 ``reset_session`` 重置
_session_id: str = str(uuid.uuid4())
_session_lock = threading.Lock()
def get_session_id() -> str:
    """Return the process-wide session id, read while holding the session lock."""
    _session_lock.acquire()
    try:
        current = _session_id
    finally:
        _session_lock.release()
    return current
def set_session_id(session_id: str) -> str:
    """Make ``session_id`` the process-wide session id.

    The value is stripped first; ``None`` or a blank result raises
    ``ValueError``. The switch is logged with both ids masked, and the session
    id now in effect is returned.
    """
    global _session_id
    wanted = (session_id or "").strip()
    if wanted == "":
        raise ValueError("session_id 不能为空")
    with _session_lock:
        previous = _session_id
        _session_id = wanted
    logger.info("[AIImage] 会话切换 %s -> %s", _mask_sid(previous), _mask_sid(_session_id))
    return _session_id
def reset_session() -> str:
    """Start a new conversation by installing a fresh random (uuid4) session id.

    The switch is logged with both ids masked; the new id is returned.
    """
    global _session_id
    fresh = str(uuid.uuid4())
    with _session_lock:
        previous, _session_id = _session_id, fresh
    logger.info("[AIImage] 会话切换 %s -> %s", _mask_sid(previous), _mask_sid(_session_id))
    return _session_id
def _mask_sid(sid: str) -> str:
"""日志安全展示:仅保留前 8 位。"""
if not sid:
return "(none)"
return f"{sid[:8]}"
def _truncate(text: str, n: int = 80) -> str:
s = (text or "").replace("\n", " ").strip()
return s if len(s) <= n else s[:n] + ""
2026-04-21 14:06:48 +08:00
# ---------- 数据结构 ----------
@dataclass
class AIImageRecord:
    """One entry in the image cache.

    Fields:
    - ``id``: unique id, equal to the on-disk file name without extension,
      formatted ``{timestamp}_{first 8 hex chars of md5(prompt)}``.
    - ``prompt``: the user's original input, kept verbatim for tracing and
      debugging; it should not be rewritten.
    - ``image_path``: absolute path of the image inside the cache directory.
    - ``created_at``: ISO 8601 timestamp string.
    - ``extra``: other metadata; expected to contain at least ``source`` and
      ``session_id`` (the conversation round the image belongs to).
    - ``title``: user-defined display title written on rename. The UI prefers
      it; when blank, ``display_name`` falls back to the prompt's first line.
    """

    id: str
    prompt: str
    image_path: str
    created_at: str  # ISO 8601
    extra: Optional[dict] = None
    title: Optional[str] = None

    def to_json(self) -> str:
        """Serialise all fields as indented JSON, keeping non-ASCII text as-is."""
        return json.dumps(asdict(self), ensure_ascii=False, indent=2)

    @property
    def display_name(self) -> str:
        """Name shown in the UI.

        Returns the stripped ``title`` when it is non-blank, otherwise the first
        line of the stripped ``prompt``, otherwise ``"(未命名)"``.
        """
        title = str(self.title or "").strip()
        if title:
            return title
        lines = str(self.prompt or "").strip().splitlines()
        # A whitespace-only prompt produces no lines; indexing [0] would raise.
        first = lines[0] if lines else ""
        return first or "(未命名)"

    @property
    def session_id(self) -> str:
        """``extra["session_id"]`` as a string, or ``""`` when absent."""
        if isinstance(self.extra, dict):
            return str(self.extra.get("session_id") or "")
        return ""
2026-04-21 14:06:48 +08:00
# ---------- 后端 API ----------
2026-04-21 14:06:48 +08:00
def _api_endpoint() -> str:
    """Return the full generate URL built from ``API_BASE_URL`` and ``API_PATH``.

    A trailing slash is added to the base when missing, and leading slashes
    are stripped from the path before the two are concatenated.
    """
    base = API_BASE_URL
    if not base.endswith("/"):
        base = f"{base}/"
    return f"{base}{API_PATH.lstrip('/')}"
2026-04-21 14:06:48 +08:00
def _extract_image_url(raw: bytes, session_id: str) -> str:
    """Parse a ``pqtest/generate`` response body and return its ``data.imageUrl``.

    Raises ``RuntimeError`` when the body is not UTF-8 JSON, is not a JSON
    object, reports a ``code`` other than 200, or has no non-blank string
    ``imageUrl``.
    """
    try:
        result = json.loads(raw.decode("utf-8"))
    except ValueError as exc:  # UnicodeDecodeError and JSONDecodeError are ValueErrors
        logger.error("[AIImage] 响应解析失败 sid=%s raw=%r", _mask_sid(session_id), raw[:200])
        raise RuntimeError(f"AI 接口返回非 JSON{raw[:200]!r}") from exc
    if not isinstance(result, dict):
        # e.g. a bare list/string: .get() below would raise AttributeError.
        logger.error("[AIImage] 响应格式异常 sid=%s raw=%r", _mask_sid(session_id), raw[:200])
        raise RuntimeError(f"AI 接口返回格式异常 {raw[:200]!r}")
    code = result.get("code")
    message = result.get("message") or ""
    data = result.get("data")
    url_value = data.get("imageUrl") if isinstance(data, dict) else None
    image_url = url_value.strip() if isinstance(url_value, str) else ""
    if code != 200 or not image_url:
        logger.warning(
            "[AIImage] 接口失败 sid=%s code=%s msg=%r",
            _mask_sid(session_id), code, message,
        )
        raise RuntimeError(f"AI 接口失败 code={code} msg={message or '生成失败'}")
    return image_url


def _call_pqtest_generate(user_message: str, session_id: str, timeout: float = API_TIMEOUT) -> str:
    """Call the backend ``api/v1/pqtest/generate`` endpoint and return the image URL.

    Args:
        user_message: prompt text, sent as ``user_message``.
        session_id: conversation id, sent as ``session_id``.
        timeout: timeout in seconds passed to ``urlopen``.

    Returns:
        The non-blank ``data.imageUrl`` from the response.

    Raises:
        Any exception from ``urlopen`` (logged, then re-raised unchanged).
        RuntimeError: malformed response, non-200 ``code``, or missing URL.
    """
    payload = json.dumps(
        {"user_message": user_message, "session_id": session_id},
        ensure_ascii=False,
    ).encode("utf-8")
    endpoint = _api_endpoint()
    logger.info(
        "[AIImage] 请求生成 sid=%s prompt_len=%d prompt=%r",
        _mask_sid(session_id), len(user_message or ""), _truncate(user_message),
    )
    logger.debug("[AIImage] POST %s timeout=%.1fs", endpoint, timeout)
    request = Request(
        endpoint,
        data=payload,
        method="POST",
        headers={
            "Content-Type": "application/json; charset=utf-8",
            "Accept": "application/json",
            "User-Agent": "pqAutomationApp/1.0",
        },
    )
    t0 = time.monotonic()
    try:
        with urlopen(request, timeout=timeout) as response:
            raw = response.read()
            http_status = response.status
    except Exception as exc:
        elapsed = time.monotonic() - t0
        logger.error(
            "[AIImage] 请求异常 sid=%s elapsed=%.2fs %s: %s",
            _mask_sid(session_id), elapsed, type(exc).__name__, exc,
        )
        raise
    elapsed = time.monotonic() - t0
    logger.debug(
        "[AIImage] HTTP %s 收到 %d bytes elapsed=%.2fs",
        http_status, len(raw), elapsed,
    )
    image_url = _extract_image_url(raw, session_id)
    logger.info(
        "[AIImage] 生成成功 sid=%s elapsed=%.2fs url=%s",
        _mask_sid(session_id), elapsed, image_url,
    )
    return image_url
2026-04-21 14:06:48 +08:00
# ---------- 缓存路径工具 ----------
def get_cache_dir(base_dir: Optional[str] = None) -> str:
    """Return the image cache directory under ``base_dir``, creating it if needed.

    An omitted or empty ``base_dir`` means the current working directory.
    """
    cache_path = os.path.join(base_dir or os.getcwd(), _CACHE_DIRNAME)
    os.makedirs(cache_path, exist_ok=True)
    return cache_path
def _make_id(prompt: str) -> str:
stamp = _dt.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
digest = hashlib.md5(prompt.encode("utf-8")).hexdigest()[:8]
return f"{stamp}_{digest}"
def _meta_path_for(image_path: str) -> str:
    """Return the JSON sidecar path that pairs with ``image_path`` (same stem)."""
    stem, _unused_ext = os.path.splitext(image_path)
    return f"{stem}{_META_SUFFIX}"
def _sanitize_image_bytes(image_bytes: bytes, image_ext: str) -> bytes:
    """Re-encode image bytes with Pillow, dropping the ICC profile from PNGs.

    The image is decoded and saved again in the format implied by
    ``image_ext`` (case-insensitive). For ``.png`` the ICC profile is omitted,
    which is the purpose of this step: some embedded PNG ICC profiles are known
    to be problematic. Unsupported extensions are returned untouched. Decode or
    encode failures are logged and the original bytes are returned, as they
    are when re-encoding yields nothing.
    """
    # Pillow save-format names keyed by extension. Pillow registers "JPEG" but
    # no "JPG", so deriving the name by upper-casing ".jpg" made every .jpg fail.
    pil_formats = {
        ".png": "PNG",
        ".jpg": "JPEG",
        ".jpeg": "JPEG",
        ".bmp": "BMP",
        ".webp": "WEBP",
    }
    ext = (image_ext or "").lower()
    target_format = pil_formats.get(ext)
    if target_format is None:
        return image_bytes
    try:
        with Image.open(BytesIO(image_bytes)) as img:
            img.load()
            # The copy has no ``format`` of its own, so the format used for
            # saving comes from the extension map above.
            normalized = img.copy()
        save_kwargs = {}
        if ext == ".png":
            save_kwargs["icc_profile"] = None
        # NOTE(review): for JPEG/WebP this is a lossy re-encode with Pillow's
        # default quality settings - confirm that is acceptable.
        output = BytesIO()
        normalized.save(output, format=target_format, **save_kwargs)
        result = output.getvalue()
        if result:
            return result
    except Exception as exc:
        logger.warning("[AIImage] 图片规范化失败 ext=%s %s: %s", ext, type(exc).__name__, exc)
    return image_bytes
2026-04-21 14:06:48 +08:00
# ---------- 读写 ----------
def _read_record_meta(meta_path: str) -> dict:
    """Load one JSON sidecar; return ``{}`` when it is missing or unusable.

    Unreadable, non-JSON, or non-object sidecars are logged and ignored, so a
    bad sidecar never hides the image it belongs to.
    """
    if not os.path.isfile(meta_path):
        return {}
    try:
        with open(meta_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        logger.warning("[AIImage] 侧车读取失败 path=%s %s: %s", meta_path, type(exc).__name__, exc)
        return {}
    if not isinstance(data, dict):
        logger.warning("[AIImage] 侧车格式无效 path=%s", meta_path)
        return {}
    return data


def list_records(base_dir: Optional[str] = None) -> List[AIImageRecord]:
    """List every cached image as an ``AIImageRecord``, newest first.

    Fields come from each image's JSON sidecar when present. ``id`` falls back
    to the file stem, and ``created_at`` falls back to the file's mtime (or
    ``""`` if that cannot be read). Sorting is by the ``created_at`` string,
    descending.
    """
    cache_dir = get_cache_dir(base_dir)
    records: List[AIImageRecord] = []
    for name in os.listdir(cache_dir):
        full = os.path.join(cache_dir, name)
        # Cheap extension test first; sidecars and other files are skipped.
        if not (name.lower().endswith(_SUPPORTED_IMG_EXT) and os.path.isfile(full)):
            continue
        meta = _read_record_meta(_meta_path_for(full))
        # Coerce to str so a malformed sidecar value cannot break the sort below.
        created_at = str(meta.get("created_at") or "")
        if not created_at:
            try:
                mtime = os.path.getmtime(full)
                created_at = _dt.datetime.fromtimestamp(mtime).isoformat()
            except (OSError, OverflowError, ValueError):
                created_at = ""
        records.append(
            AIImageRecord(
                id=meta.get("id", os.path.splitext(name)[0]),
                prompt=meta.get("prompt", ""),
                image_path=full,
                created_at=created_at,
                extra=meta.get("extra"),
                title=meta.get("title"),
            )
        )
    records.sort(key=lambda r: r.created_at, reverse=True)
    return records
def save_image_to_cache(
    prompt: str,
    image_bytes: bytes,
    image_ext: str = ".png",
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
    title: Optional[str] = None,
) -> AIImageRecord:
    """Write image bytes into the cache and return the new record.

    Args:
        prompt: original user input; stored verbatim and hashed into the id.
        image_bytes: encoded image data, normalised by ``_sanitize_image_bytes``.
        image_ext: extension with or without the leading dot. It is lower-cased;
            empty or unsupported values fall back to ``.png``.
        extra: optional metadata stored in the sidecar.
        base_dir: root directory passed to ``get_cache_dir``.
        title: optional display title.

    Returns:
        The record. Writing the JSON sidecar is best-effort: if it cannot be
        serialised or written, the failure is logged and the record is still
        returned.

    Raises:
        OSError: if the image file itself cannot be written.
    """
    ext = (image_ext or ".png").lower()
    if not ext.startswith("."):
        ext = "." + ext
    if ext not in _SUPPORTED_IMG_EXT:
        ext = ".png"
    image_bytes = _sanitize_image_bytes(image_bytes, ext)

    cache_dir = get_cache_dir(base_dir)
    rec_id = _make_id(prompt)
    image_path = os.path.join(cache_dir, f"{rec_id}{ext}")
    with open(image_path, "wb") as f:
        f.write(image_bytes)
    record = AIImageRecord(
        id=rec_id,
        prompt=prompt,
        image_path=image_path,
        created_at=_dt.datetime.now().isoformat(timespec="seconds"),
        extra=extra,
        title=title,
    )
    meta_path = _meta_path_for(image_path)
    try:
        # Serialise before opening so an unserialisable ``extra`` never leaves
        # an empty sidecar file behind.
        meta_json = record.to_json()
        with open(meta_path, "w", encoding="utf-8") as f:
            f.write(meta_json)
    except Exception as exc:
        logger.warning("[AIImage] 侧车写入失败 path=%s %s: %s", meta_path, type(exc).__name__, exc)
    return record
def import_image_from_url(
    image_url: str,
    prompt: Optional[str] = None,
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
    timeout: float = 20.0,
) -> AIImageRecord:
    """Download a remote http/https image and store it in the cache.

    Args:
        image_url: absolute http/https URL; surrounding whitespace is ignored.
        prompt: prompt stored with the record. A blank value falls back to the
            URL's file name.
        extra: extra metadata. Its ``source`` key is kept when given (default
            ``"remote-url"``); ``source_url`` and ``content_type`` are always
            set from the download.
        base_dir: root directory passed to ``get_cache_dir``.
        timeout: timeout in seconds passed to ``urlopen``.

    Returns:
        The cached ``AIImageRecord``.

    Raises:
        ValueError: empty URL, non-http(s) URL, or empty download.
        Exceptions raised by ``urlopen`` propagate unchanged.
    """
    url = (image_url or "").strip()
    if not url:
        raise ValueError("图片地址不能为空")
    # urlopen also serves file:// and other schemes; only fetch real remote URLs.
    if not is_remote_image_url(url):
        raise ValueError(f"仅支持 http/https 图片地址: {url}")
    request = Request(
        url,
        headers={
            "User-Agent": "pqAutomationApp/1.0",
            "Accept": "image/*,*/*;q=0.8",
        },
    )
    with urlopen(request, timeout=timeout) as response:
        image_bytes = response.read()
        content_type = response.headers.get_content_type()
    if not image_bytes:
        raise ValueError("下载结果为空")
    image_ext = _guess_image_ext(image_url=url, content_type=content_type)
    # A caller-supplied "source" (e.g. "ai-api") must not be clobbered.
    merged_extra = {"source": "remote-url"}
    merged_extra.update(extra or {})
    merged_extra["source_url"] = url
    merged_extra["content_type"] = content_type
    record_prompt = (prompt or "").strip() or _default_prompt_from_url(url).strip()
    return save_image_to_cache(
        prompt=record_prompt,
        image_bytes=image_bytes,
        image_ext=image_ext,
        extra=merged_extra,
        base_dir=base_dir,
    )
2026-04-21 14:06:48 +08:00
def delete_record(record: AIImageRecord) -> bool:
    """Delete a cached image and its JSON sidecar.

    Paths that are not existing files are skipped. Each removal is attempted
    independently; an ``OSError`` is logged and makes the result ``False``.

    Returns:
        ``True`` when no removal failed.
    """
    ok = True
    for path in (record.image_path, _meta_path_for(record.image_path)):
        if not os.path.isfile(path):
            continue
        try:
            os.remove(path)
        except FileNotFoundError:
            # Removed by someone else between the check and the call: already gone.
            continue
        except OSError as exc:
            ok = False
            logger.warning("[AIImage] 删除失败 path=%s %s: %s", path, type(exc).__name__, exc)
    return ok
def export_record(record: AIImageRecord, dest_path: str) -> None:
    """Copy the cached image file of ``record`` to ``dest_path``."""
    source_path = record.image_path
    shutil.copyfile(source_path, dest_path)
def update_record_title(record: AIImageRecord, new_title: Optional[str]) -> bool:
    """Set or clear ``record``'s display title and persist it to the sidecar JSON.

    A blank or ``None`` title removes the ``title`` key. Other keys of an
    existing sidecar are preserved; a missing sidecar is created holding only
    the title. The JSON is written to a temporary file and swapped in with
    ``os.replace``, so a failed write cannot leave a truncated sidecar.

    Returns:
        ``True`` on success, in which case ``record.title`` is updated too.
        ``False`` if the sidecar could not be read, was not a JSON object, or
        could not be written; the failure is logged.
    """
    title = (new_title or "").strip() or None
    meta_path = _meta_path_for(record.image_path)
    tmp_path = meta_path + ".tmp"
    try:
        data: dict = {}
        if os.path.isfile(meta_path):
            with open(meta_path, "r", encoding="utf-8") as f:
                data = json.load(f) or {}
        if not isinstance(data, dict):
            raise ValueError("sidecar is not a JSON object")
        if title is None:
            data.pop("title", None)
        else:
            data["title"] = title
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, meta_path)
    except Exception as exc:
        logger.warning("[AIImage] 更新标题失败 path=%s %s: %s", meta_path, type(exc).__name__, exc)
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup of the temp file; the failure is already logged
        return False
    record.title = title
    return True
def group_records_by_session(records: List[AIImageRecord]) -> List[dict]:
    """Group records by their ``session_id``.

    Each returned dict has the keys ``session_id``, ``records``, ``started_at``
    (earliest non-empty ``created_at`` in the group, else ``""``) and
    ``latest_at`` (``created_at`` of the group's newest record). Records inside
    a group are newest first, and groups are ordered by ``latest_at``, newest
    first. Records without a session id share the ``""`` group.
    """
    grouped: dict = {}
    for record in records:
        sid = record.session_id
        if sid not in grouped:
            grouped[sid] = []
        grouped[sid].append(record)

    def _summarise(sid, members):
        members.sort(key=lambda r: r.created_at, reverse=True)
        stamps = [m.created_at for m in members if m.created_at]
        return {
            "session_id": sid,
            "records": members,
            "started_at": min(stamps) if stamps else "",
            "latest_at": members[0].created_at if members else "",
        }

    sessions = [_summarise(sid, members) for sid, members in grouped.items()]
    sessions.sort(key=lambda s: s["latest_at"], reverse=True)
    return sessions
2026-04-21 14:06:48 +08:00
# ---------- 异步请求 ----------
def request_image_async(
    prompt: str,
    on_success: Callable[[AIImageRecord], None],
    on_error: Callable[[Exception], None],
    base_dir: Optional[str] = None,
    session_id: Optional[str] = None,
) -> threading.Thread:
    """Generate an image on a background daemon thread.

    Pipeline: backend API → image download → cache write → callback.

    ``on_success(record)`` / ``on_error(exc)`` run on the **worker thread**;
    UI code that must touch widgets should marshal back to its main thread
    inside the callback (e.g. ``root.after(0, ...)``). An empty ``session_id``
    means the process-wide session id, captured when this function is called,
    so multi-turn context is kept.

    Exactly one callback runs. If the pipeline fails, ``on_error`` receives the
    exception. An exception raised by ``on_success`` itself is logged and is
    not reported through ``on_error``.

    Returns:
        The already-started thread.
    """
    sid = session_id or get_session_id()

    def _worker():
        try:
            image_url = _call_pqtest_generate(prompt, sid)
            record = import_image_from_url(
                image_url=image_url,
                prompt=prompt,
                extra={"source": "ai-api", "session_id": sid},
                base_dir=base_dir,
            )
        except Exception as exc:
            logger.error(
                "[AIImage] 生成流程失败 sid=%s %s: %s",
                _mask_sid(sid), type(exc).__name__, exc,
            )
            on_error(exc)
            return
        logger.info(
            "[AIImage] 已写入缓存 sid=%s id=%s path=%s",
            _mask_sid(sid), record.id, record.image_path,
        )
        # Kept outside the pipeline's try so a failing success callback is not
        # mistaken for a generation failure.
        try:
            on_success(record)
        except Exception:
            logger.exception("[AIImage] 成功回调异常 sid=%s", _mask_sid(sid))

    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    return t
def import_image_from_url_async(
    image_url: str,
    on_success: Callable[[AIImageRecord], None],
    on_error: Callable[[Exception], None],
    prompt: Optional[str] = None,
    extra: Optional[dict] = None,
    base_dir: Optional[str] = None,
    timeout: float = 20.0,
) -> threading.Thread:
    """Download a remote image into the cache on a background daemon thread.

    All arguments except the callbacks are forwarded to ``import_image_from_url``.
    The callbacks run on the worker thread. If the import fails, the failure is
    logged and ``on_error(exc)`` is called. Otherwise ``on_success(record)`` is
    called, and an exception raised by that callback is logged rather than
    forwarded to ``on_error``.

    Returns:
        The already-started thread.
    """
    def _worker():
        try:
            record = import_image_from_url(
                image_url=image_url,
                prompt=prompt,
                extra=extra,
                base_dir=base_dir,
                timeout=timeout,
            )
        except Exception as exc:
            logger.error("[AIImage] 导入失败 url=%s %s: %s", image_url, type(exc).__name__, exc)
            on_error(exc)
            return
        try:
            on_success(record)
        except Exception:
            logger.exception("[AIImage] 成功回调异常 url=%s", image_url)

    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    return t
def is_remote_image_url(value: str) -> bool:
    """Return ``True`` when ``value`` (stripped) is an http/https URL with a host.

    Only the scheme and host are checked; the path is not required to look
    like an image file.
    """
    candidate = (value or "").strip()
    if candidate == "":
        return False
    parts = urlparse(candidate)
    if parts.scheme not in ("http", "https"):
        return False
    return parts.netloc != ""
def _guess_image_ext(image_url: str, content_type: Optional[str]) -> str:
    """Pick a supported, lower-case file extension for a downloaded image.

    The MIME ``content_type`` is tried first, with ``.jpe`` mapped to ``.jpg``.
    Then the extension of the URL path is tried, and ``.png`` is the final
    default.
    """
    from_mime = mimetypes.guess_extension(content_type) if content_type else None
    if from_mime == ".jpe":
        from_mime = ".jpg"
    if from_mime and from_mime.lower() in _SUPPORTED_IMG_EXT:
        return from_mime.lower()
    from_path = os.path.splitext(urlparse(image_url).path)[1].lower()
    return from_path if from_path in _SUPPORTED_IMG_EXT else ".png"
def _default_prompt_from_url(image_url: str) -> str:
path = urlparse(image_url).path
name = os.path.splitext(os.path.basename(path))[0].strip()
return name or "远程导入图片"