添加后台线程下载远程图片并写入缓存

This commit is contained in:
xinzhu.yin
2026-04-23 10:07:41 +08:00
parent 4073a6e999
commit b26a3c398d
2 changed files with 130 additions and 8 deletions

View File

@@ -9,11 +9,14 @@ from __future__ import annotations
import datetime as _dt
import hashlib
import json
import mimetypes
import os
import shutil
import threading
from dataclasses import dataclass, asdict
from typing import Callable, List, Optional
from urllib.parse import urlparse
from urllib.request import Request, urlopen
# ---------- 常量 ----------
@@ -189,6 +192,53 @@ def save_image_to_cache(
return record
def import_image_from_url(
image_url: str,
prompt: Optional[str] = None,
extra: Optional[dict] = None,
base_dir: Optional[str] = None,
timeout: float = 20.0,
) -> AIImageRecord:
"""下载远程图片并写入缓存。"""
url = (image_url or "").strip()
if not url:
raise ValueError("图片地址不能为空")
request = Request(
url,
headers={
"User-Agent": "pqAutomationApp/1.0",
"Accept": "image/*,*/*;q=0.8",
},
)
with urlopen(request, timeout=timeout) as response:
image_bytes = response.read()
if not image_bytes:
raise ValueError("下载结果为空")
image_ext = _guess_image_ext(
image_url=url,
content_type=response.headers.get_content_type(),
)
merged_extra = dict(extra or {})
merged_extra.update(
{
"source": "remote-url",
"source_url": url,
"content_type": response.headers.get_content_type(),
}
)
record_prompt = (prompt or _default_prompt_from_url(url)).strip()
return save_image_to_cache(
prompt=record_prompt,
image_bytes=image_bytes,
image_ext=image_ext,
extra=merged_extra,
base_dir=base_dir,
)
def delete_record(record: AIImageRecord) -> bool:
"""删除一条缓存记录(图片 + 侧车)。返回是否成功。"""
ok = True
@@ -242,6 +292,35 @@ def request_image_async(
return t
def import_image_from_url_async(
image_url: str,
on_success: Callable[[AIImageRecord], None],
on_error: Callable[[Exception], None],
prompt: Optional[str] = None,
extra: Optional[dict] = None,
base_dir: Optional[str] = None,
timeout: float = 20.0,
) -> threading.Thread:
"""在后台线程下载远程图片并写入缓存"""
def _worker():
try:
record = import_image_from_url(
image_url=image_url,
prompt=prompt,
extra=extra,
base_dir=base_dir,
timeout=timeout,
)
on_success(record)
except Exception as exc:
on_error(exc)
t = threading.Thread(target=_worker, daemon=True)
t.start()
return t
def _normalize_api_result(result):
"""允许 API 返回 ``bytes`` 或 ``(bytes, ext)`` 或 ``(bytes, ext, extra)``。"""
if isinstance(result, (bytes, bytearray)):
@@ -252,3 +331,33 @@ def _normalize_api_result(result):
if len(result) == 3:
return bytes(result[0]), str(result[1]), result[2]
raise ValueError("API 返回格式不支持,需为 bytes 或 (bytes, ext[, extra])")
def is_remote_image_url(value: str) -> bool:
"""判断输入是否为 http/https 图片地址。"""
url = (value or "").strip()
if not url:
return False
parsed = urlparse(url)
return parsed.scheme in {"http", "https"} and bool(parsed.netloc)
def _guess_image_ext(image_url: str, content_type: Optional[str]) -> str:
if content_type:
guessed = mimetypes.guess_extension(content_type)
if guessed == ".jpe":
guessed = ".jpg"
if guessed and guessed.lower() in _SUPPORTED_IMG_EXT:
return guessed.lower()
url_path = urlparse(image_url).path
ext = os.path.splitext(url_path)[1].lower()
if ext in _SUPPORTED_IMG_EXT:
return ext
return ".png"
def _default_prompt_from_url(image_url: str) -> str:
path = urlparse(image_url).path
name = os.path.splitext(os.path.basename(path))[0].strip()
return name or "远程导入图片"