"""
Stable Diffusion service module.

Wraps calls to the SD WebUI HTTP API: txt2img and img2img generation, with
optional ReActor face swapping.
"""
import base64
import io
import logging
import os

import requests
from PIL import Image

logger = logging.getLogger(__name__)

SD_TIMEOUT = 1800  # image generation can take a long time

# Default location of the saved face (avatar) image
FACE_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "my_face.png")

# ==================== Generation quality presets ====================
# Three parameter tiers tuned for JuggernautXL (SDXL).
# NOTE: the keys are user-facing preset names and are looked up verbatim.
SD_PRESETS = {
    "快速 (约30秒)": {
        "steps": 12,
        "cfg_scale": 5.0,
        "width": 768,
        "height": 1024,
        "sampler_name": "Euler a",
        "scheduler": "Normal",
        "batch_size": 2,
    },
    "标准 (约1分钟)": {
        "steps": 20,
        "cfg_scale": 5.5,
        "width": 832,
        "height": 1216,
        "sampler_name": "DPM++ 2M",
        "scheduler": "Karras",
        "batch_size": 2,
    },
    "精细 (约2-3分钟)": {
        "steps": 35,
        "cfg_scale": 6.0,
        "width": 832,
        "height": 1216,
        "sampler_name": "DPM++ 2M SDE",
        "scheduler": "Karras",
        "batch_size": 2,
    },
}

SD_PRESET_NAMES = list(SD_PRESETS.keys())

# Preset used whenever no (or an unknown) preset name is supplied
DEFAULT_PRESET_NAME = "标准 (约1分钟)"


def get_sd_preset(name: str | None) -> dict:
    """Return the generation preset called *name*.

    Falls back to the standard preset (``DEFAULT_PRESET_NAME``) when *name*
    is ``None`` or not a key of ``SD_PRESETS``.

    A copy is returned so that callers can tweak the result without
    modifying the shared module-level ``SD_PRESETS`` table.
    """
    return dict(SD_PRESETS.get(name, SD_PRESETS[DEFAULT_PRESET_NAME]))


# Default negative prompt (tuned for JuggernautXL / SDXL, biased towards
# East-Asian aesthetics)
DEFAULT_NEGATIVE = (
    "nsfw, nudity, lowres, bad anatomy, bad hands, text, error, missing fingers, "
    "extra digit, fewer digits, cropped, worst quality, low quality, normal quality, "
    "jpeg artifacts, signature, watermark, blurry, deformed, mutated, disfigured, "
    "ugly, duplicate, morbid, mutilated, poorly drawn face, poorly drawn hands, "
    "extra limbs, fused fingers, too many fingers, long neck, username, "
    "out of frame, distorted, oversaturated, underexposed, overexposed, "
    "western face, european face, caucasian, deep-set eyes, high nose bridge, "
    "blonde hair, red hair, blue eyes, green eyes, freckles, thick body hair"
)


class SDService:
    """Thin client for the Stable Diffusion WebUI HTTP API."""

    def __init__(self, sd_url: str = "http://127.0.0.1:7860"):
        """Store the WebUI base URL (any trailing slash is stripped)."""
        self.sd_url = sd_url.rstrip("/")

    # ---------- Helpers ----------

    @staticmethod
    def _image_to_base64(img: Image.Image) -> str:
        """Encode a PIL image as a base64 string of its PNG bytes."""
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    @staticmethod
    def _decode_images(resp_json: dict) -> list[Image.Image]:
        """Decode the base64 entries under ``"images"`` into PIL images.

        A missing or empty ``"images"`` entry yields an empty list.
        """
        return [
            Image.open(io.BytesIO(base64.b64decode(img_b64)))
            for img_b64 in resp_json.get("images") or []
        ]

    @staticmethod
    def load_face_image(path: str | None = None) -> Image.Image | None:
        """Load the face image as RGB.

        Args:
            path: File to load; defaults to ``FACE_IMAGE_PATH``.

        Returns:
            The image converted to RGB, or ``None`` when the file does not
            exist or cannot be read (a warning is logged in the latter case).
        """
        path = path or FACE_IMAGE_PATH
        if not os.path.isfile(path):
            return None
        try:
            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the with-block exits.
            with Image.open(path) as img:
                return img.convert("RGB")
        except Exception as e:  # best-effort: a broken file must not crash callers
            logger.warning("头像加载失败: %s", e)
            return None

    @staticmethod
    def save_face_image(img: Image.Image, path: str | None = None) -> str:
        """Save *img* as an RGB PNG and return the path it was written to.

        Args:
            img: Image to save.
            path: Destination file; defaults to ``FACE_IMAGE_PATH``.
        """
        path = path or FACE_IMAGE_PATH
        parent = os.path.dirname(path)
        if parent:
            # Make sure the destination directory exists before writing.
            os.makedirs(parent, exist_ok=True)
        img.convert("RGB").save(path, format="PNG")
        logger.info("头像已保存: %s", path)
        return path

    def _build_reactor_args(self, face_image: Image.Image) -> dict:
        """Build the ReActor face-swap arguments (``alwayson_scripts`` format).

        Argument index reference (reactor script-info):
            0: source_image (base64)   1: enable
            2: source_faces            3: target_faces
            4: model                   5: restore_face
            6: restore_visibility      7: restore_first
            8: upscaler                9: scale
            10: upscaler_vis           11: swap_in_source
            12: swap_in_generated      13: log_level
            14: gender_source          15: gender_target
            16: save_original          17: codeformer_weight
            18: source_hash_check      19: target_hash_check
            20: exec_provider          21: face_mask_correction
            22: select_source          23: face_model
            24: source_folder          25: multiple_sources
            26: random_image           27: force_upscale
            28: threshold              29: max_faces
            30: tab_single

        NOTE(review): only indices 0-29 are supplied below; index 30
        (``tab_single``) is listed in the reference but never sent. Confirm
        against the installed ReActor version whether it is required.
        """
        face_b64 = self._image_to_base64(face_image)
        return {
            "reactor": {
                "args": [
                    face_b64,              # 0: source image (base64)
                    True,                  # 1: enable ReActor
                    "0",                   # 2: source face index
                    "0",                   # 3: target face index
                    "inswapper_128.onnx",  # 4: swap model
                    "CodeFormer",          # 5: restore face method
                    1,                     # 6: restore face visibility
                    True,                  # 7: restore face first, then upscale
                    "None",                # 8: upscaler
                    1,                     # 9: scale
                    1,                     # 10: upscaler visibility
                    False,                 # 11: swap in source
                    True,                  # 12: swap in generated
                    1,                     # 13: console log level (0=min, 1=med, 2=max)
                    0,                     # 14: gender detection source (0=No)
                    0,                     # 15: gender detection target (0=No)
                    False,                 # 16: save original
                    0.8,                   # 17: CodeFormer weight (0=max effect, 1=min)
                    False,                 # 18: source hash check
                    False,                 # 19: target hash check
                    "CUDA",                # 20: execution provider
                    True,                  # 21: face mask correction
                    0,                     # 22: select source (0=Image, 1=FaceModel, 2=Folder)
                    "",                    # 23: face model filename (when #22=1)
                    "",                    # 24: source folder path (when #22=2)
                    None,                  # 25: skip for API
                    False,                 # 26: random image
                    False,                 # 27: force upscale
                    0.6,                   # 28: face detection threshold
                    2,                     # 29: max faces to detect (0=unlimited)
                ],
            }
        }

    def has_reactor(self) -> bool:
        """Return True if the WebUI reports a script whose name contains "reactor".

        Any failure (connection error, bad status, unexpected payload) is
        treated as "not installed" and yields ``False``.
        """
        try:
            resp = requests.get(f"{self.sd_url}/sdapi/v1/scripts", timeout=5)
            resp.raise_for_status()
            scripts = resp.json()
            all_scripts = (scripts.get("txt2img") or []) + (scripts.get("img2img") or [])
            return any("reactor" in s.lower() for s in all_scripts)
        except Exception as e:  # best-effort probe
            logger.debug("ReActor probe failed: %s", e)
            return False

    def check_connection(self) -> tuple[bool, str]:
        """Check whether the SD service is reachable.

        Returns:
            ``(ok, message)`` where *message* is a human-readable status.
        """
        try:
            resp = requests.get(f"{self.sd_url}/sdapi/v1/sd-models", timeout=5)
            if resp.status_code == 200:
                count = len(resp.json())
                return True, f"SD 已连接,{count} 个模型可用"
            return False, f"SD 返回异常状态: {resp.status_code}"
        except requests.exceptions.ConnectionError:
            return False, "SD WebUI 未启动或端口错误"
        except Exception as e:
            return False, f"SD 连接失败: {e}"

    def get_models(self) -> list[str]:
        """Return the titles of the available SD checkpoints.

        Raises:
            requests.RequestException: on connection failure or a non-2xx
                response (unlike ``get_lora_models``, errors are not swallowed).
        """
        resp = requests.get(f"{self.sd_url}/sdapi/v1/sd-models", timeout=5)
        resp.raise_for_status()
        return [m["title"] for m in resp.json()]

    def switch_model(self, model_name: str):
        """Switch the active SD checkpoint (best-effort).

        Failures, including non-2xx responses, are logged as warnings and
        never raised.
        """
        try:
            resp = requests.post(
                f"{self.sd_url}/sdapi/v1/options",
                json={"sd_model_checkpoint": model_name},
                timeout=60,
            )
            # Without this an HTTP error response would pass silently.
            resp.raise_for_status()
        except Exception as e:
            logger.warning("模型切换失败: %s", e)

    def txt2img(
        self,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE,
        model: str | None = None,
        steps: int | None = None,
        cfg_scale: float | None = None,
        width: int | None = None,
        height: int | None = None,
        batch_size: int | None = None,
        seed: int = -1,
        sampler_name: str | None = None,
        scheduler: str | None = None,
        face_image: Image.Image | None = None,
        quality_mode: str | None = None,
    ) -> list[Image.Image]:
        """Text-to-image generation (parameters tuned for JuggernautXL).

        Args:
            prompt: Positive prompt.
            negative_prompt: Negative prompt.
            model: Checkpoint to switch to before generating, if given.
            steps, cfg_scale, width, height, batch_size, sampler_name,
            scheduler: Explicit overrides; any left as ``None`` takes its
                value from the selected preset.
            seed: Seed sent to the API (``-1`` by default).
            face_image: Face PIL image; when given, ReActor face swapping is
                enabled for the request.
            quality_mode: Preset name, e.g. '快速 (约30秒)' / '标准 (约1分钟)' /
                '精细 (约2-3分钟)'. Missing or unknown names use the standard
                preset.

        Returns:
            The decoded images from the response's ``"images"`` list.

        Raises:
            requests.RequestException: on connection failure or a non-2xx
                response from the txt2img endpoint.
        """
        if model:
            self.switch_model(model)

        # Start from the preset, then let explicit arguments override it.
        preset = get_sd_preset(quality_mode or DEFAULT_PRESET_NAME)
        overrides = {
            "steps": steps,
            "cfg_scale": cfg_scale,
            "width": width,
            "height": height,
            "batch_size": batch_size,
            "sampler_name": sampler_name,
            "scheduler": scheduler,
        }
        payload = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": seed,
        }
        for key, value in overrides.items():
            payload[key] = value if value is not None else preset[key]

        logger.info(
            "SD 生成参数: steps=%s, cfg=%.1f, %dx%d, sampler=%s",
            payload["steps"],
            payload["cfg_scale"],
            payload["width"],
            payload["height"],
            payload["sampler_name"],
        )

        # When a face image is supplied, swap faces through ReActor.
        if face_image is not None:
            payload["alwayson_scripts"] = self._build_reactor_args(face_image)
            logger.info("🎭 ReActor 换脸已启用")

        resp = requests.post(
            f"{self.sd_url}/sdapi/v1/txt2img",
            json=payload,
            timeout=SD_TIMEOUT,
        )
        resp.raise_for_status()
        return self._decode_images(resp.json())

    def img2img(
        self,
        init_image: Image.Image,
        prompt: str,
        negative_prompt: str = DEFAULT_NEGATIVE,
        denoising_strength: float = 0.5,
        steps: int = 30,
        cfg_scale: float = 5.0,
        sampler_name: str = "DPM++ 2M",
        scheduler: str = "Karras",
    ) -> list[Image.Image]:
        """Image-to-image generation (parameters tuned for JuggernautXL).

        The output size sent to the API is the size of *init_image*.

        Returns:
            The decoded images from the response's ``"images"`` list.

        Raises:
            requests.RequestException: on connection failure or a non-2xx
                response from the img2img endpoint.
        """
        payload = {
            "init_images": [self._image_to_base64(init_image)],
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "denoising_strength": denoising_strength,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "width": init_image.width,
            "height": init_image.height,
            "sampler_name": sampler_name,
            "scheduler": scheduler,
        }
        resp = requests.post(
            f"{self.sd_url}/sdapi/v1/img2img",
            json=payload,
            timeout=SD_TIMEOUT,
        )
        resp.raise_for_status()
        return self._decode_images(resp.json())

    def get_lora_models(self) -> list[str]:
        """Return the names of the available LoRA models.

        Best-effort: any failure yields an empty list.
        """
        try:
            resp = requests.get(f"{self.sd_url}/sdapi/v1/loras", timeout=5)
            resp.raise_for_status()
            return [lora["name"] for lora in resp.json()]
        except Exception as e:
            logger.debug("LoRA listing failed: %s", e)
            return []