xhs_factory/sd_service.py
zhoujie 883082411a feat(sd): 新增图片生成质量预设模式并优化换脸功能
- 新增三档生成质量预设【快速/标准/精细】,针对 SDXL 模型优化参数
- 新增 `SD_PRESETS` 配置字典和 `get_sd_preset` 工具函数
- 为 `generate_images` 函数和 `txt2img` 方法添加 `quality_mode` 参数支持
- 在 Gradio UI 中添加生成模式选择器,并实现参数联动预览
- 优化换脸头像处理逻辑,支持多种输入格式并增强日志记录
- 调整默认绘图参数以匹配预设,并更新相关函数调用

♻️ refactor(sd): 重构 ReActor 换脸 API 调用参数

- 更新 `_build_reactor_payload` 方法参数列表以匹配最新 API
- 将部分字符串参数(如日志级别、性别检测)调整为整数类型
- 优化参数默认值,如提高 CodeFormer 权重至 0.8
2026-02-09 23:46:50 +08:00

314 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Stable Diffusion 服务模块
封装对 SD WebUI API 的调用,支持 txt2img 和 img2img支持 ReActor 换脸
"""
import requests
import base64
import io
import logging
import os
from PIL import Image
logger = logging.getLogger(__name__)

# Timeout (seconds) passed to the generation requests; image generation can
# take a long time, so allow 30 minutes.
SD_TIMEOUT = 30 * 60

# Default path of the saved face (avatar) image: "my_face.png" next to this module.
FACE_IMAGE_PATH = os.path.join(
    os.path.dirname(__file__),
    "my_face.png",
)
# ==================== Generation quality presets ====================
# Three parameter tiers tuned for JuggernautXL (SDXL). The keys are the
# Chinese preset labels: fast (~30 s), standard (~1 min), fine (~2-3 min).
SD_PRESETS = {
    "快速 (约30秒)": dict(
        steps=12,
        cfg_scale=5.0,
        width=768,
        height=1024,
        sampler_name="Euler a",
        scheduler="Normal",
        batch_size=2,
    ),
    "标准 (约1分钟)": dict(
        steps=20,
        cfg_scale=5.5,
        width=832,
        height=1216,
        sampler_name="DPM++ 2M",
        scheduler="Karras",
        batch_size=2,
    ),
    "精细 (约2-3分钟)": dict(
        steps=35,
        cfg_scale=6.0,
        width=832,
        height=1216,
        sampler_name="DPM++ 2M SDE",
        scheduler="Karras",
        batch_size=2,
    ),
}

# Preset labels in definition order.
SD_PRESET_NAMES = [*SD_PRESETS]
def get_sd_preset(name: str) -> dict:
    """Return the preset dict registered under *name*.

    Any name that is not a key of ``SD_PRESETS`` yields the standard
    ("标准 (约1分钟)") preset instead. The returned dict is the object stored
    in ``SD_PRESETS`` itself, not a copy.
    """
    # Resolve the fallback up front, exactly as the eager ``dict.get`` default did.
    fallback = SD_PRESETS["标准 (约1分钟)"]
    if name in SD_PRESETS:
        return SD_PRESETS[name]
    return fallback
# Default negative prompt (tuned for JuggernautXL / SDXL, biased toward an
# East Asian aesthetic). Terms are joined with ", " into one prompt string.
DEFAULT_NEGATIVE = ", ".join((
    # content / general quality
    "nsfw", "nudity", "lowres", "bad anatomy", "bad hands", "text", "error",
    "missing fingers", "extra digit", "fewer digits", "cropped",
    "worst quality", "low quality", "normal quality",
    # artifacts and deformations
    "jpeg artifacts", "signature", "watermark", "blurry", "deformed",
    "mutated", "disfigured", "ugly", "duplicate", "morbid", "mutilated",
    "poorly drawn face", "poorly drawn hands", "extra limbs", "fused fingers",
    "too many fingers", "long neck", "username",
    # framing and exposure
    "out of frame", "distorted", "oversaturated", "underexposed", "overexposed",
    # facial / hair traits to avoid
    "western face", "european face", "caucasian", "deep-set eyes",
    "high nose bridge", "blonde hair", "red hair", "blue eyes", "green eyes",
    "freckles", "thick body hair",
))
class SDService:
    """Client wrapper around the Stable Diffusion WebUI HTTP API.

    Provides txt2img / img2img generation, model and LoRA listing, model
    switching, and optional ReActor face swapping (sent through the
    ``alwayson_scripts`` field of the txt2img payload).
    """

    def __init__(self, sd_url: str = "http://127.0.0.1:7860"):
        """Remember the WebUI base URL, with any trailing slashes removed."""
        self.sd_url = sd_url.rstrip("/")

    # ---------- helper methods ----------

    @staticmethod
    def _image_to_base64(img: Image.Image) -> str:
        """Serialize a PIL image to PNG and return it as a base64 string."""
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    @staticmethod
    def _decode_images(resp: requests.Response) -> list[Image.Image]:
        """Decode each base64 entry under ``"images"`` in a JSON response.

        Returns an empty list when the response has no ``"images"`` key.
        """
        return [
            Image.open(io.BytesIO(base64.b64decode(img_b64)))
            for img_b64 in resp.json().get("images", [])
        ]

    @staticmethod
    def load_face_image(path: str | None = None) -> Image.Image | None:
        """Load the face image as RGB.

        Args:
            path: File to load; falls back to ``FACE_IMAGE_PATH`` when falsy.

        Returns:
            The image converted to RGB, or ``None`` when the file does not
            exist or cannot be opened (the failure is logged as a warning).
        """
        path = path or FACE_IMAGE_PATH
        if path and os.path.isfile(path):
            try:
                # convert() yields an independent image, so the source file
                # can be closed as soon as the conversion is done.
                with Image.open(path) as src:
                    return src.convert("RGB")
            except Exception as e:
                logger.warning("头像加载失败: %s", e)
        return None

    @staticmethod
    def save_face_image(img: Image.Image, path: str | None = None) -> str:
        """Save the face image as an RGB PNG and return the path written.

        Args:
            img: Image to store.
            path: Destination; falls back to ``FACE_IMAGE_PATH`` when falsy.
        """
        path = path or FACE_IMAGE_PATH
        img = img.convert("RGB")
        img.save(path, format="PNG")
        logger.info("头像已保存: %s", path)
        return path

    def _build_reactor_args(self, face_image: Image.Image) -> dict:
        """Build the ReActor face-swap arguments in ``alwayson_scripts`` form.

        The positional ``args`` list follows the reactor script-info order:

             0: source_image (base64)   1: enable              2: source_faces
             3: target_faces            4: model               5: restore_face
             6: restore_visibility      7: restore_first       8: upscaler
             9: scale                  10: upscaler_vis       11: swap_in_source
            12: swap_in_generated      13: log_level          14: gender_source
            15: gender_target          16: save_original      17: codeformer_weight
            18: source_hash_check      19: target_hash_check  20: exec_provider
            21: face_mask_correction   22: select_source      23: face_model
            24: source_folder          25: multiple_sources   26: random_image
            27: force_upscale          28: threshold          29: max_faces

        NOTE(review): the original index table also named ``30: tab_single``,
        but only indices 0-29 are sent here — confirm against the installed
        ReActor version before adding or removing entries.
        """
        face_b64 = self._image_to_base64(face_image)
        return {
            "reactor": {
                "args": [
                    face_b64,              # 0: source image (base64)
                    True,                  # 1: enable ReActor
                    "0",                   # 2: source face index
                    "0",                   # 3: target face index
                    "inswapper_128.onnx",  # 4: swap model
                    "CodeFormer",          # 5: restore face method
                    1,                     # 6: restore face visibility
                    True,                  # 7: restore face first, then upscale
                    "None",                # 8: upscaler
                    1,                     # 9: scale
                    1,                     # 10: upscaler visibility
                    False,                 # 11: swap in source
                    True,                  # 12: swap in generated
                    1,                     # 13: console log level (0=min, 1=med, 2=max)
                    0,                     # 14: gender detection source (0=No)
                    0,                     # 15: gender detection target (0=No)
                    False,                 # 16: save original
                    0.8,                   # 17: CodeFormer weight (0=max effect, 1=min)
                    False,                 # 18: source hash check
                    False,                 # 19: target hash check
                    "CUDA",                # 20: execution provider
                    True,                  # 21: face mask correction
                    0,                     # 22: select source (0=Image, 1=FaceModel, 2=Folder)
                    "",                    # 23: face model filename (when #22=1)
                    "",                    # 24: source folder path (when #22=2)
                    None,                  # 25: skip for API
                    False,                 # 26: random image
                    False,                 # 27: force upscale
                    0.6,                   # 28: face detection threshold
                    2,                     # 29: max faces to detect (0=unlimited)
                ],
            }
        }

    def has_reactor(self) -> bool:
        """Report whether the WebUI lists a script whose name contains "reactor".

        Best-effort probe: any request or parsing error yields ``False``.
        """
        try:
            resp = requests.get(f"{self.sd_url}/sdapi/v1/scripts", timeout=5)
            scripts = resp.json()
            all_scripts = scripts.get("txt2img", []) + scripts.get("img2img", [])
            return any("reactor" in s.lower() for s in all_scripts)
        except Exception:
            return False

    def check_connection(self) -> tuple[bool, str]:
        """Check whether the SD service is reachable.

        Returns:
            ``(ok, message)`` where *message* is a human-readable status text.
        """
        try:
            resp = requests.get(f"{self.sd_url}/sdapi/v1/sd-models", timeout=5)
            if resp.status_code == 200:
                count = len(resp.json())
                return True, f"SD 已连接,{count} 个模型可用"
            return False, f"SD 返回异常状态: {resp.status_code}"
        except requests.exceptions.ConnectionError:
            return False, "SD WebUI 未启动或端口错误"
        except Exception as e:
            return False, f"SD 连接失败: {e}"

    def get_models(self) -> list[str]:
        """Return the ``title`` of every model reported by the WebUI.

        Raises:
            requests.HTTPError: if the WebUI answers with an error status.
        """
        resp = requests.get(f"{self.sd_url}/sdapi/v1/sd-models", timeout=5)
        resp.raise_for_status()
        return [m["title"] for m in resp.json()]

    def switch_model(self, model_name: str) -> None:
        """Ask the WebUI to switch to *model_name* (best effort).

        Failures are logged as warnings and never raised to the caller.
        """
        try:
            resp = requests.post(
                f"{self.sd_url}/sdapi/v1/options",
                json={"sd_model_checkpoint": model_name},
                timeout=60,
            )
            # requests does not raise on HTTP error statuses by itself; without
            # this check a rejected switch would go completely unnoticed.
            resp.raise_for_status()
        except Exception as e:
            logger.warning("模型切换失败: %s", e)

    def txt2img(
        self,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE,
        model: str | None = None,
        steps: int | None = None,
        cfg_scale: float | None = None,
        width: int | None = None,
        height: int | None = None,
        batch_size: int | None = None,
        seed: int = -1,
        sampler_name: str | None = None,
        scheduler: str | None = None,
        face_image: Image.Image | None = None,
        quality_mode: str | None = None,
    ) -> list[Image.Image]:
        """Generate images from text (defaults tuned for JuggernautXL).

        Args:
            prompt: Positive prompt.
            negative_prompt: Negative prompt.
            model: When truthy, switch the WebUI to this checkpoint first.
            steps, cfg_scale, width, height, batch_size, sampler_name,
            scheduler: Explicit values override the preset; ``None`` means
                "use the preset value".
            seed: Seed sent to the WebUI (default -1).
            face_image: When given, ReActor face swapping is enabled with this
                image as the source face.
            quality_mode: Preset name, e.g. '快速 (约30秒)' / '标准 (约1分钟)' /
                '精细 (约2-3分钟)'. Missing or unknown names use the standard
                preset.

        Returns:
            The decoded images from the response's ``"images"`` list.

        Raises:
            requests.HTTPError: if the WebUI answers with an error status.
        """
        if model:
            self.switch_model(model)

        # The preset is the base layer; get_sd_preset already falls back to the
        # standard preset for a missing/unknown name, so no extra branch is needed.
        preset = get_sd_preset(quality_mode)

        def pick(explicit, key):
            """Return the explicit value unless it is None, else the preset's."""
            return explicit if explicit is not None else preset[key]

        payload = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "steps": pick(steps, "steps"),
            "cfg_scale": pick(cfg_scale, "cfg_scale"),
            "width": pick(width, "width"),
            "height": pick(height, "height"),
            "batch_size": pick(batch_size, "batch_size"),
            "seed": seed,
            "sampler_name": pick(sampler_name, "sampler_name"),
            "scheduler": pick(scheduler, "scheduler"),
        }
        logger.info("SD 生成参数: steps=%s, cfg=%.1f, %dx%d, sampler=%s",
                    payload['steps'], payload['cfg_scale'],
                    payload['width'], payload['height'], payload['sampler_name'])

        # A supplied face image turns on the ReActor face swap.
        if face_image is not None:
            payload["alwayson_scripts"] = self._build_reactor_args(face_image)
            logger.info("🎭 ReActor 换脸已启用")

        resp = requests.post(
            f"{self.sd_url}/sdapi/v1/txt2img",
            json=payload,
            timeout=SD_TIMEOUT,
        )
        resp.raise_for_status()
        return self._decode_images(resp)

    def img2img(
        self,
        init_image: Image.Image,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE,
        denoising_strength: float = 0.5,
        steps: int = 30,
        cfg_scale: float = 5.0,
        sampler_name: str = "DPM++ 2M",
        scheduler: str = "Karras",
    ) -> list[Image.Image]:
        """Generate images from an initial image (defaults tuned for JuggernautXL).

        The output size requested from the WebUI is the size of *init_image*.

        Returns:
            The decoded images from the response's ``"images"`` list.

        Raises:
            requests.HTTPError: if the WebUI answers with an error status.
        """
        payload = {
            # Reuse the shared PNG/base64 encoder instead of duplicating it.
            "init_images": [self._image_to_base64(init_image)],
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "denoising_strength": denoising_strength,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "width": init_image.width,
            "height": init_image.height,
            "sampler_name": sampler_name,
            "scheduler": scheduler,
        }
        resp = requests.post(
            f"{self.sd_url}/sdapi/v1/img2img",
            json=payload,
            timeout=SD_TIMEOUT,
        )
        resp.raise_for_status()
        return self._decode_images(resp)

    def get_lora_models(self) -> list[str]:
        """Return the names of the available LoRA models.

        Best effort: any request or parsing error yields an empty list.
        """
        try:
            resp = requests.get(f"{self.sd_url}/sdapi/v1/loras", timeout=5)
            resp.raise_for_status()
            return [lora["name"] for lora in resp.json()]
        except Exception:
            return []