✨ feat(sd): 新增图片生成质量预设模式并优化换脸功能
- 新增三档生成质量预设【快速/标准/精细】,针对 SDXL 模型优化参数
- 新增 `SD_PRESETS` 配置字典和 `get_sd_preset` 工具函数
- 为 `generate_images` 函数和 `txt2img` 方法添加 `quality_mode` 参数支持
- 在 Gradio UI 中添加生成模式选择器,并实现参数联动预览
- 优化换脸头像处理逻辑,支持多种输入格式并增强日志记录
- 调整默认绘图参数以匹配预设,并更新相关函数调用
♻️ refactor(sd): 重构 ReActor 换脸 API 调用参数
- 更新 `_build_reactor_payload` 方法参数列表以匹配最新 API
- 将部分字符串参数(如日志级别、性别检测)调整为整数类型
- 优化参数默认值,如提高 CodeFormer 权重至 0.8
This commit is contained in:
parent
358b957f5d
commit
883082411a
60
main.py
60
main.py
@ -19,7 +19,7 @@ import matplotlib.pyplot as plt
|
||||
|
||||
from config_manager import ConfigManager, OUTPUT_DIR
|
||||
from llm_service import LLMService
|
||||
from sd_service import SDService, DEFAULT_NEGATIVE, FACE_IMAGE_PATH
|
||||
from sd_service import SDService, DEFAULT_NEGATIVE, FACE_IMAGE_PATH, SD_PRESET_NAMES, get_sd_preset
|
||||
from mcp_client import MCPClient, get_mcp_client
|
||||
|
||||
# ================= matplotlib 中文字体配置 =================
|
||||
@ -309,22 +309,37 @@ def generate_copy(model, topic, style):
|
||||
return "", "", "", "", f"❌ 生成失败: {e}"
|
||||
|
||||
|
||||
def generate_images(sd_url, prompt, neg_prompt, model, steps, cfg_scale, face_swap_on, face_img):
|
||||
"""生成图片(可选 ReActor 换脸)"""
|
||||
def generate_images(sd_url, prompt, neg_prompt, model, steps, cfg_scale, face_swap_on, face_img, quality_mode):
|
||||
"""生成图片(可选 ReActor 换脸,支持质量模式预设)"""
|
||||
if not model:
|
||||
return None, [], "❌ 未选择 SD 模型"
|
||||
try:
|
||||
svc = SDService(sd_url)
|
||||
# 判断是否启用换脸
|
||||
face_image = None
|
||||
if face_swap_on and face_img is not None:
|
||||
if face_swap_on:
|
||||
# Gradio 可能传 PIL.Image / numpy.ndarray / 文件路径 / None
|
||||
if face_img is not None:
|
||||
if isinstance(face_img, Image.Image):
|
||||
face_image = face_img
|
||||
elif isinstance(face_img, str) and os.path.isfile(face_img):
|
||||
face_image = Image.open(face_img).convert("RGB")
|
||||
if face_swap_on and face_image is None:
|
||||
# 尝试从默认路径加载
|
||||
else:
|
||||
# numpy array 等其他格式
|
||||
try:
|
||||
import numpy as np
|
||||
if isinstance(face_img, np.ndarray):
|
||||
face_image = Image.fromarray(face_img).convert("RGB")
|
||||
logger.info("头像从 numpy array 转换为 PIL Image")
|
||||
except Exception as e:
|
||||
logger.warning("头像格式转换失败 (%s): %s", type(face_img).__name__, e)
|
||||
# 如果 UI 没传有效头像,从本地文件加载
|
||||
if face_image is None:
|
||||
face_image = SDService.load_face_image()
|
||||
if face_image is not None:
|
||||
logger.info("换脸头像已就绪: %dx%d", face_image.width, face_image.height)
|
||||
else:
|
||||
logger.warning("换脸已启用但未找到有效头像")
|
||||
|
||||
images = svc.txt2img(
|
||||
prompt=prompt,
|
||||
@ -333,9 +348,11 @@ def generate_images(sd_url, prompt, neg_prompt, model, steps, cfg_scale, face_sw
|
||||
steps=int(steps),
|
||||
cfg_scale=float(cfg_scale),
|
||||
face_image=face_image,
|
||||
quality_mode=quality_mode,
|
||||
)
|
||||
preset = get_sd_preset(quality_mode)
|
||||
swap_hint = " (已换脸)" if face_image else ""
|
||||
return images, images, f"✅ 生成 {len(images)} 张图片{swap_hint}"
|
||||
return images, images, f"✅ 生成 {len(images)} 张图片{swap_hint} [{quality_mode}]"
|
||||
except Exception as e:
|
||||
logger.error("图片生成失败: %s", e)
|
||||
return None, [], f"❌ 绘图失败: {e}"
|
||||
@ -1976,7 +1993,9 @@ def auto_publish_once(topics_str, mcp_url, sd_url_val, sd_model_name, model, fac
|
||||
_auto_log_append("🎭 换脸已启用")
|
||||
else:
|
||||
_auto_log_append("⚠️ 换脸已启用但未找到头像,跳过换脸")
|
||||
images = sd_svc.txt2img(prompt=sd_prompt, model=sd_model_name, face_image=face_image)
|
||||
images = sd_svc.txt2img(prompt=sd_prompt, model=sd_model_name,
|
||||
face_image=face_image,
|
||||
quality_mode="快速 (约30秒)")
|
||||
if not images:
|
||||
_record_error()
|
||||
return "❌ 图片生成失败:没有返回图片"
|
||||
@ -2529,12 +2548,18 @@ with gr.Blocks(
|
||||
|
||||
gr.Markdown("---")
|
||||
gr.Markdown("### 🎨 绘图参数")
|
||||
with gr.Accordion("高级设置", open=False):
|
||||
quality_mode = gr.Radio(
|
||||
SD_PRESET_NAMES,
|
||||
label="生成模式",
|
||||
value="标准 (约1分钟)",
|
||||
info="快速≈30s 标准≈1min 精细≈2-3min (SDXL)",
|
||||
)
|
||||
with gr.Accordion("高级设置 (覆盖预设)", open=False):
|
||||
neg_prompt = gr.Textbox(
|
||||
label="反向提示词", value=DEFAULT_NEGATIVE, lines=2,
|
||||
)
|
||||
steps = gr.Slider(15, 50, value=25, step=1, label="步数")
|
||||
cfg_scale = gr.Slider(1, 15, value=7, step=0.5, label="CFG Scale")
|
||||
steps = gr.Slider(8, 50, value=20, step=1, label="步数")
|
||||
cfg_scale = gr.Slider(1, 15, value=5.5, step=0.5, label="CFG Scale")
|
||||
btn_gen_img = gr.Button("🎨 第二步:生成图片", variant="primary")
|
||||
|
||||
# 中栏:文案编辑
|
||||
@ -3093,10 +3118,21 @@ with gr.Blocks(
|
||||
outputs=[res_title, res_content, res_prompt, res_tags, status_bar],
|
||||
)
|
||||
|
||||
# Generation-mode switch -> keep the steps / CFG sliders in sync with the preset
def on_quality_mode_change(mode):
    """Return the ``(steps, cfg_scale)`` pair of the preset named *mode*.

    The lookup is delegated to ``get_sd_preset``; the two values are
    returned in the order steps first, CFG scale second.
    """
    selected = get_sd_preset(mode)
    step_count = selected["steps"]
    cfg_value = selected["cfg_scale"]
    return step_count, cfg_value
|
||||
|
||||
quality_mode.change(
|
||||
fn=on_quality_mode_change,
|
||||
inputs=[quality_mode],
|
||||
outputs=[steps, cfg_scale],
|
||||
)
|
||||
|
||||
btn_gen_img.click(
|
||||
fn=generate_images,
|
||||
inputs=[sd_url, res_prompt, neg_prompt, sd_model, steps, cfg_scale,
|
||||
face_swap_toggle, face_image_preview],
|
||||
face_swap_toggle, face_image_preview, quality_mode],
|
||||
outputs=[gallery, state_images, status_bar],
|
||||
)
|
||||
|
||||
|
||||
@ -16,6 +16,46 @@ SD_TIMEOUT = 1800 # 图片生成可能需要较长时间
|
||||
# 头像文件默认保存路径
|
||||
FACE_IMAGE_PATH = os.path.join(os.path.dirname(__file__), "my_face.png")
|
||||
|
||||
# ==================== Generation quality presets ====================
# Three parameter tiers tuned for JuggernautXL (SDXL).
# The keys are looked up verbatim by get_sd_preset(), so they must not be
# altered or translated.
SD_PRESETS = {
    "快速 (约30秒)": {
        "steps": 12,
        "cfg_scale": 5.0,
        "width": 768,
        "height": 1024,
        "sampler_name": "Euler a",
        "scheduler": "Normal",
        "batch_size": 2,
    },
    "标准 (约1分钟)": {
        "steps": 20,
        "cfg_scale": 5.5,
        "width": 832,
        "height": 1216,
        "sampler_name": "DPM++ 2M",
        "scheduler": "Karras",
        "batch_size": 2,
    },
    "精细 (约2-3分钟)": {
        "steps": 35,
        "cfg_scale": 6.0,
        "width": 832,
        "height": 1216,
        "sampler_name": "DPM++ 2M SDE",
        "scheduler": "Karras",
        "batch_size": 2,
    },
}

# Preset names in declaration order.
SD_PRESET_NAMES = list(SD_PRESETS.keys())

# Name of the preset used whenever the requested one is missing or invalid.
_DEFAULT_SD_PRESET_NAME = "标准 (约1分钟)"


def get_sd_preset(name: str) -> dict:
    """Return the generation parameters for the preset called *name*.

    Args:
        name: One of the keys of ``SD_PRESETS``. ``None``, an unknown name,
            or an unhashable value all select the standard preset.

    Returns:
        A new dict with the keys ``steps``, ``cfg_scale``, ``width``,
        ``height``, ``sampler_name``, ``scheduler`` and ``batch_size``.
        It is a copy, so callers may modify it without altering
        ``SD_PRESETS``.
    """
    try:
        preset = SD_PRESETS[name]
    except (KeyError, TypeError):
        # KeyError: unknown name or None. TypeError: unhashable value such
        # as a list. Both fall back to the standard preset.
        preset = SD_PRESETS[_DEFAULT_SD_PRESET_NAME]
    # Copy so a caller tweaking the result cannot corrupt the shared preset.
    return dict(preset)
|
||||
|
||||
|
||||
# 默认反向提示词(针对 JuggernautXL / SDXL 优化,偏向东方审美)
|
||||
DEFAULT_NEGATIVE = (
|
||||
"nsfw, nudity, lowres, bad anatomy, bad hands, text, error, missing fingers, "
|
||||
@ -97,24 +137,23 @@ class SDService:
|
||||
1, # 10: upscaler visibility
|
||||
False, # 11: swap in source
|
||||
True, # 12: swap in generated
|
||||
"Minimum", # 13: log level
|
||||
"No", # 14: gender detection (source)
|
||||
"No", # 15: gender detection (target)
|
||||
1, # 13: console log level (0=min, 1=med, 2=max)
|
||||
0, # 14: gender detection source (0=No)
|
||||
0, # 15: gender detection target (0=No)
|
||||
False, # 16: save original
|
||||
0.6, # 17: CodeFormer weight (fidelity)
|
||||
True, # 18: source hash check
|
||||
0.8, # 17: CodeFormer weight (0=max effect, 1=min)
|
||||
False, # 18: source hash check
|
||||
False, # 19: target hash check
|
||||
"CUDA", # 20: execution provider
|
||||
True, # 21: face mask correction
|
||||
"Image(s)", # 22: select source type
|
||||
"None", # 23: face model name
|
||||
"", # 24: source folder
|
||||
None, # 25: multiple source images
|
||||
0, # 22: select source (0=Image, 1=FaceModel, 2=Folder)
|
||||
"", # 23: face model filename (when #22=1)
|
||||
"", # 24: source folder path (when #22=2)
|
||||
None, # 25: skip for API
|
||||
False, # 26: random image
|
||||
False, # 27: force upscale
|
||||
0.5, # 28: detection threshold
|
||||
0, # 29: max faces (0 = no limit)
|
||||
"tab_single", # 30: tab
|
||||
0.6, # 28: face detection threshold
|
||||
2, # 29: max faces to detect (0=unlimited)
|
||||
],
|
||||
}
|
||||
}
|
||||
@ -164,36 +203,44 @@ class SDService:
|
||||
prompt: str,
|
||||
negative_prompt: str = DEFAULT_NEGATIVE,
|
||||
model: str = None,
|
||||
steps: int = 30,
|
||||
cfg_scale: float = 5.0,
|
||||
width: int = 832,
|
||||
height: int = 1216,
|
||||
batch_size: int = 2,
|
||||
steps: int = None,
|
||||
cfg_scale: float = None,
|
||||
width: int = None,
|
||||
height: int = None,
|
||||
batch_size: int = None,
|
||||
seed: int = -1,
|
||||
sampler_name: str = "DPM++ 2M",
|
||||
scheduler: str = "Karras",
|
||||
sampler_name: str = None,
|
||||
scheduler: str = None,
|
||||
face_image: Image.Image = None,
|
||||
quality_mode: str = None,
|
||||
) -> list[Image.Image]:
|
||||
"""文生图(参数针对 JuggernautXL 优化)
|
||||
|
||||
Args:
|
||||
face_image: 头像 PIL Image,传入后自动启用 ReActor 换脸
|
||||
quality_mode: 预设模式名,如 '快速 (约30秒)' / '标准 (约1分钟)' / '精细 (约2-3分钟)'
|
||||
传入后自动应用预设参数,其余参数可覆盖
|
||||
"""
|
||||
if model:
|
||||
self.switch_model(model)
|
||||
|
||||
# 加载预设作为基底,再用显式参数覆盖
|
||||
preset = get_sd_preset(quality_mode) if quality_mode else get_sd_preset("标准 (约1分钟)")
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"steps": steps,
|
||||
"cfg_scale": cfg_scale,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"batch_size": batch_size,
|
||||
"steps": steps if steps is not None else preset["steps"],
|
||||
"cfg_scale": cfg_scale if cfg_scale is not None else preset["cfg_scale"],
|
||||
"width": width if width is not None else preset["width"],
|
||||
"height": height if height is not None else preset["height"],
|
||||
"batch_size": batch_size if batch_size is not None else preset["batch_size"],
|
||||
"seed": seed,
|
||||
"sampler_name": sampler_name,
|
||||
"scheduler": scheduler,
|
||||
"sampler_name": sampler_name if sampler_name is not None else preset["sampler_name"],
|
||||
"scheduler": scheduler if scheduler is not None else preset["scheduler"],
|
||||
}
|
||||
logger.info("SD 生成参数: steps=%s, cfg=%.1f, %dx%d, sampler=%s",
|
||||
payload['steps'], payload['cfg_scale'],
|
||||
payload['width'], payload['height'], payload['sampler_name'])
|
||||
|
||||
# 如果提供了头像,通过 ReActor 换脸
|
||||
if face_image is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user