- main.py: 4360 → 146 lines (96.6% reduction), entry layer only - services/: rate_limiter, autostart, persona, connection, profile, hotspot, content, engagement, scheduler, queue_ops (10 business modules) - ui/app.py: all Gradio UI code extracted into build_app(cfg, analytics) - Fix: with gr.Blocks() indented inside build_app function - Fix: cfg.all property (not get_all method) - Fix: STATUS_LABELS, get_persona_keywords, fetch_proactive_notes imports - Fix: queue_ops module-level set_publish_callback moved into configure() - Fix: pub_queue.format_*() wrapped as queue_format_table/calendar helpers - All 14 files syntax-verified, build_app() runtime-verified - 58/58 tasks complete"
255 lines
8.7 KiB
Python
255 lines
8.7 KiB
Python
"""
|
||
services/connection.py
|
||
LLM 提供商管理、SD 连接、MCP 连接、XHS 登录等服务函数
|
||
"""
|
||
import os
|
||
import re
|
||
import logging
|
||
|
||
import gradio as gr
|
||
|
||
from config_manager import ConfigManager
|
||
from llm_service import LLMService
|
||
from sd_service import SDService, get_model_profile_info
|
||
from mcp_client import get_mcp_client
|
||
|
||
logger = logging.getLogger("autobot")
|
||
cfg = ConfigManager()
|
||
|
||
def _get_llm_config() -> tuple[str, str, str]:
|
||
"""获取当前激活 LLM 的 (api_key, base_url, model)"""
|
||
p = cfg.get_active_llm()
|
||
if p:
|
||
return p["api_key"], p["base_url"], cfg.get("model", "")
|
||
return "", "", ""
|
||
|
||
|
||
def connect_llm(provider_name):
|
||
"""连接选中的 LLM 提供商并获取模型列表"""
|
||
if not provider_name:
|
||
return gr.update(choices=[], value=None), "⚠️ 请先选择或添加 LLM 提供商"
|
||
cfg.set_active_llm(provider_name)
|
||
p = cfg.get_active_llm()
|
||
if not p:
|
||
return gr.update(choices=[], value=None), "❌ 未找到该提供商配置"
|
||
try:
|
||
svc = LLMService(p["api_key"], p["base_url"])
|
||
models = svc.get_models()
|
||
if models:
|
||
return (
|
||
gr.update(choices=models, value=models[0]),
|
||
f"✅ 已连接「{provider_name}」,加载 {len(models)} 个模型",
|
||
)
|
||
else:
|
||
# API 无法获取模型列表,保留手动输入
|
||
current_model = cfg.get("model", "")
|
||
return (
|
||
gr.update(choices=[current_model] if current_model else [], value=current_model or None),
|
||
f"⚠️ 已连接「{provider_name}」,但未获取到模型列表,请手动输入模型名",
|
||
)
|
||
except Exception as e:
|
||
logger.error("LLM 连接失败: %s", e)
|
||
current_model = cfg.get("model", "")
|
||
return (
|
||
gr.update(choices=[current_model] if current_model else [], value=current_model or None),
|
||
f"❌ 连接「{provider_name}」失败: {e}",
|
||
)
|
||
|
||
|
||
def add_llm_provider(name, api_key, base_url):
|
||
"""添加新的 LLM 提供商"""
|
||
msg = cfg.add_llm_provider(name, api_key, base_url)
|
||
names = cfg.get_llm_provider_names()
|
||
active = cfg.get("active_llm", "")
|
||
return (
|
||
gr.update(choices=names, value=active),
|
||
msg,
|
||
)
|
||
|
||
|
||
def remove_llm_provider(provider_name):
|
||
"""删除 LLM 提供商"""
|
||
if not provider_name:
|
||
return gr.update(choices=cfg.get_llm_provider_names(), value=cfg.get("active_llm", "")), "⚠️ 请先选择要删除的提供商"
|
||
msg = cfg.remove_llm_provider(provider_name)
|
||
names = cfg.get_llm_provider_names()
|
||
active = cfg.get("active_llm", "")
|
||
return (
|
||
gr.update(choices=names, value=active),
|
||
msg,
|
||
)
|
||
|
||
|
||
def on_provider_selected(provider_name):
|
||
"""切换 LLM 提供商时更新显示信息"""
|
||
if not provider_name:
|
||
return "未选择提供商"
|
||
for p in cfg.get_llm_providers():
|
||
if p["name"] == provider_name:
|
||
cfg.set_active_llm(provider_name)
|
||
masked_key = p["api_key"][:8] + "***" if len(p["api_key"]) > 8 else "***"
|
||
return f"**{provider_name}** \nAPI Key: `{masked_key}` \nBase URL: `{p['base_url']}`"
|
||
return "未找到该提供商"
|
||
|
||
|
||
# ==================================================
|
||
# Tab 1: 内容创作
|
||
# ==================================================
|
||
|
||
|
||
def connect_sd(sd_url):
|
||
"""连接 SD 并获取模型列表"""
|
||
try:
|
||
svc = SDService(sd_url)
|
||
ok, msg = svc.check_connection()
|
||
if ok:
|
||
models = svc.get_models()
|
||
cfg.set("sd_url", sd_url)
|
||
first = models[0] if models else None
|
||
info = get_model_profile_info(first) if first else "未检测到模型"
|
||
return gr.update(choices=models, value=first), f"✅ {msg}", info
|
||
return gr.update(choices=[]), f"❌ {msg}", ""
|
||
except Exception as e:
|
||
logger.error("SD 连接失败: %s", e)
|
||
return gr.update(choices=[]), f"❌ SD 连接失败: {e}", ""
|
||
|
||
|
||
def on_sd_model_change(model_name):
|
||
"""SD 模型切换时显示模型档案信息"""
|
||
if not model_name:
|
||
return "未选择模型"
|
||
return get_model_profile_info(model_name)
|
||
|
||
|
||
def check_mcp_status(mcp_url):
|
||
"""检查 MCP 连接状态"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
ok, msg = client.check_connection()
|
||
if ok:
|
||
cfg.set("mcp_url", mcp_url)
|
||
return f"✅ MCP 服务正常 - {msg}"
|
||
return f"❌ {msg}"
|
||
except Exception as e:
|
||
return f"❌ MCP 连接失败: {e}"
|
||
|
||
|
||
# ==================================================
|
||
# 小红书账号登录
|
||
# ==================================================
|
||
|
||
|
||
def get_login_qrcode(mcp_url):
|
||
"""获取小红书登录二维码"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
result = client.get_login_qrcode()
|
||
if "error" in result:
|
||
return None, f"❌ 获取二维码失败: {result['error']}"
|
||
qr_image = result.get("qr_image")
|
||
msg = result.get("text", "")
|
||
if qr_image:
|
||
return qr_image, f"✅ 二维码已生成,请用小红书 App 扫码\n{msg}"
|
||
return None, f"⚠️ 未获取到二维码图片,MCP 返回:\n{msg}"
|
||
except Exception as e:
|
||
logger.error("获取登录二维码失败: %s", e)
|
||
return None, f"❌ 获取二维码失败: {e}"
|
||
|
||
|
||
def logout_xhs(mcp_url):
|
||
"""退出登录:清除 cookies 并重置本地 token"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
result = client.delete_cookies()
|
||
if "error" in result:
|
||
return f"❌ 退出失败: {result['error']}"
|
||
cfg.set("xsec_token", "")
|
||
client._reset()
|
||
return "✅ 已退出登录,可以重新扫码登录"
|
||
except Exception as e:
|
||
logger.error("退出登录失败: %s", e)
|
||
return f"❌ 退出失败: {e}"
|
||
|
||
|
||
def _auto_fetch_xsec_token(mcp_url) -> str:
|
||
"""从推荐列表自动获取一个有效的 xsec_token"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
entries = client.list_feeds_parsed()
|
||
for e in entries:
|
||
token = e.get("xsec_token", "")
|
||
if token:
|
||
return token
|
||
except Exception as e:
|
||
logger.warning("自动获取 xsec_token 失败: %s", e)
|
||
return ""
|
||
|
||
|
||
def check_login(mcp_url):
|
||
"""检查登录状态,登录成功后自动获取 xsec_token 并保存"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
result = client.check_login_status()
|
||
if "error" in result:
|
||
return f"❌ {result['error']}", gr.update(), gr.update()
|
||
text = result.get("text", "")
|
||
if "未登录" in text:
|
||
return f"🔴 {text}", gr.update(), gr.update()
|
||
|
||
# 登录成功 → 自动获取 xsec_token
|
||
token = _auto_fetch_xsec_token(mcp_url)
|
||
if token:
|
||
cfg.set("xsec_token", token)
|
||
logger.info("自动获取 xsec_token 成功")
|
||
return (
|
||
f"🟢 {text}\n\n✅ xsec_token 已自动获取并保存",
|
||
gr.update(value=cfg.get("my_user_id", "")),
|
||
gr.update(value=token),
|
||
)
|
||
return f"🟢 {text}\n\n⚠️ 自动获取 xsec_token 失败,请手动刷新", gr.update(), gr.update()
|
||
except Exception as e:
|
||
return f"❌ 检查登录状态失败: {e}", gr.update(), gr.update()
|
||
|
||
|
||
def save_my_user_id(user_id_input):
|
||
"""保存用户 ID (验证 24 位十六进制格式)"""
|
||
uid = (user_id_input or "").strip()
|
||
if not uid:
|
||
cfg.set("my_user_id", "")
|
||
return "⚠️ 已清除用户 ID"
|
||
if not re.match(r'^[0-9a-fA-F]{24}$', uid):
|
||
return (
|
||
"❌ 格式错误!用户 ID 应为 24 位十六进制字符串\n"
|
||
f"你输入的: `{uid}` ({len(uid)} 位)\n\n"
|
||
"💡 如果你输入的是小红书号 (纯数字如 18688457507),那不是 userId。"
|
||
)
|
||
cfg.set("my_user_id", uid)
|
||
return f"✅ 用户 ID 已保存: `{uid}`"
|
||
|
||
|
||
# ================= 头像/换脸管理 =================
|
||
|
||
def upload_face_image(img):
|
||
"""上传并保存头像图片"""
|
||
if img is None:
|
||
return None, "❌ 请上传头像图片"
|
||
try:
|
||
if isinstance(img, str) and os.path.isfile(img):
|
||
img = Image.open(img).convert("RGB")
|
||
elif not isinstance(img, Image.Image):
|
||
return None, "❌ 无法识别图片格式"
|
||
path = SDService.save_face_image(img)
|
||
return img, f"✅ 头像已保存至 {os.path.basename(path)}"
|
||
except Exception as e:
|
||
return None, f"❌ 保存失败: {e}"
|
||
|
||
|
||
def load_saved_face_image():
|
||
"""加载已保存的头像"""
|
||
img = SDService.load_face_image()
|
||
if img:
|
||
return img, "✅ 已加载保存的头像"
|
||
return None, "ℹ️ 尚未设置头像"
|
||
|
||
|