""" services/connection.py LLM 提供商管理、SD 连接、MCP 连接、XHS 登录等服务函数 """ import os import re import logging import gradio as gr from config_manager import ConfigManager from llm_service import LLMService from sd_service import SDService, get_model_profile_info from mcp_client import get_mcp_client logger = logging.getLogger("autobot") cfg = ConfigManager() def _get_llm_config() -> tuple[str, str, str]: """获取当前激活 LLM 的 (api_key, base_url, model)""" p = cfg.get_active_llm() if p: return p["api_key"], p["base_url"], cfg.get("model", "") return "", "", "" def connect_llm(provider_name): """连接选中的 LLM 提供商并获取模型列表""" if not provider_name: return gr.update(choices=[], value=None), "⚠️ 请先选择或添加 LLM 提供商" cfg.set_active_llm(provider_name) p = cfg.get_active_llm() if not p: return gr.update(choices=[], value=None), "❌ 未找到该提供商配置" try: svc = LLMService(p["api_key"], p["base_url"]) models = svc.get_models() if models: return ( gr.update(choices=models, value=models[0]), f"✅ 已连接「{provider_name}」,加载 {len(models)} 个模型", ) else: # API 无法获取模型列表,保留手动输入 current_model = cfg.get("model", "") return ( gr.update(choices=[current_model] if current_model else [], value=current_model or None), f"⚠️ 已连接「{provider_name}」,但未获取到模型列表,请手动输入模型名", ) except Exception as e: logger.error("LLM 连接失败: %s", e) current_model = cfg.get("model", "") return ( gr.update(choices=[current_model] if current_model else [], value=current_model or None), f"❌ 连接「{provider_name}」失败: {e}", ) def add_llm_provider(name, api_key, base_url): """添加新的 LLM 提供商""" msg = cfg.add_llm_provider(name, api_key, base_url) names = cfg.get_llm_provider_names() active = cfg.get("active_llm", "") return ( gr.update(choices=names, value=active), msg, ) def remove_llm_provider(provider_name): """删除 LLM 提供商""" if not provider_name: return gr.update(choices=cfg.get_llm_provider_names(), value=cfg.get("active_llm", "")), "⚠️ 请先选择要删除的提供商" msg = cfg.remove_llm_provider(provider_name) names = cfg.get_llm_provider_names() active = cfg.get("active_llm", "") return ( gr.update(choices=names, value=active), msg, ) def on_provider_selected(provider_name): """切换 LLM 提供商时更新显示信息""" if not provider_name: return "未选择提供商" for p in cfg.get_llm_providers(): if p["name"] == provider_name: cfg.set_active_llm(provider_name) masked_key = p["api_key"][:8] + "***" if len(p["api_key"]) > 8 else "***" return f"**{provider_name}** \nAPI Key: `{masked_key}` \nBase URL: `{p['base_url']}`" return "未找到该提供商" # ================================================== # Tab 1: 内容创作 # ================================================== def connect_sd(sd_url): """连接 SD 并获取模型列表""" try: svc = SDService(sd_url) ok, msg = svc.check_connection() if ok: models = svc.get_models() cfg.set("sd_url", sd_url) first = models[0] if models else None info = get_model_profile_info(first) if first else "未检测到模型" return gr.update(choices=models, value=first), f"✅ {msg}", info return gr.update(choices=[]), f"❌ {msg}", "" except Exception as e: logger.error("SD 连接失败: %s", e) return gr.update(choices=[]), f"❌ SD 连接失败: {e}", "" def on_sd_model_change(model_name): """SD 模型切换时显示模型档案信息""" if not model_name: return "未选择模型" return get_model_profile_info(model_name) def check_mcp_status(mcp_url): """检查 MCP 连接状态""" try: client = get_mcp_client(mcp_url) ok, msg = client.check_connection() if ok: cfg.set("mcp_url", mcp_url) return f"✅ MCP 服务正常 - {msg}" return f"❌ {msg}" except Exception as e: return f"❌ MCP 连接失败: {e}" # ================================================== # 小红书账号登录 # ================================================== def get_login_qrcode(mcp_url): """获取小红书登录二维码""" try: client = get_mcp_client(mcp_url) result = client.get_login_qrcode() if "error" in result: return None, f"❌ 获取二维码失败: {result['error']}" qr_image = result.get("qr_image") msg = result.get("text", "") if qr_image: return qr_image, f"✅ 二维码已生成,请用小红书 App 扫码\n{msg}" return None, f"⚠️ 未获取到二维码图片,MCP 返回:\n{msg}" except Exception as e: logger.error("获取登录二维码失败: %s", e) return None, f"❌ 获取二维码失败: {e}" def logout_xhs(mcp_url): """退出登录:清除 cookies 并重置本地 token""" try: client = get_mcp_client(mcp_url) result = client.delete_cookies() if "error" in result: return f"❌ 退出失败: {result['error']}" cfg.set("xsec_token", "") client._reset() return "✅ 已退出登录,可以重新扫码登录" except Exception as e: logger.error("退出登录失败: %s", e) return f"❌ 退出失败: {e}" def _auto_fetch_xsec_token(mcp_url) -> str: """从推荐列表自动获取一个有效的 xsec_token""" try: client = get_mcp_client(mcp_url) entries = client.list_feeds_parsed() for e in entries: token = e.get("xsec_token", "") if token: return token except Exception as e: logger.warning("自动获取 xsec_token 失败: %s", e) return "" def check_login(mcp_url): """检查登录状态,登录成功后自动获取 xsec_token 并保存""" try: client = get_mcp_client(mcp_url) result = client.check_login_status() if "error" in result: return f"❌ {result['error']}", gr.update(), gr.update() text = result.get("text", "") if "未登录" in text: return f"🔴 {text}", gr.update(), gr.update() # 登录成功 → 自动获取 xsec_token token = _auto_fetch_xsec_token(mcp_url) if token: cfg.set("xsec_token", token) logger.info("自动获取 xsec_token 成功") return ( f"🟢 {text}\n\n✅ xsec_token 已自动获取并保存", gr.update(value=cfg.get("my_user_id", "")), gr.update(value=token), ) return f"🟢 {text}\n\n⚠️ 自动获取 xsec_token 失败,请手动刷新", gr.update(), gr.update() except Exception as e: return f"❌ 检查登录状态失败: {e}", gr.update(), gr.update() def save_my_user_id(user_id_input): """保存用户 ID (验证 24 位十六进制格式)""" uid = (user_id_input or "").strip() if not uid: cfg.set("my_user_id", "") return "⚠️ 已清除用户 ID" if not re.match(r'^[0-9a-fA-F]{24}$', uid): return ( "❌ 格式错误!用户 ID 应为 24 位十六进制字符串\n" f"你输入的: `{uid}` ({len(uid)} 位)\n\n" "💡 如果你输入的是小红书号 (纯数字如 18688457507),那不是 userId。" ) cfg.set("my_user_id", uid) return f"✅ 用户 ID 已保存: `{uid}`" # ================= 头像/换脸管理 ================= def upload_face_image(img): """上传并保存头像图片""" if img is None: return None, "❌ 请上传头像图片" try: if isinstance(img, str) and os.path.isfile(img): img = Image.open(img).convert("RGB") elif not isinstance(img, Image.Image): return None, "❌ 无法识别图片格式" path = SDService.save_face_image(img) return img, f"✅ 头像已保存至 {os.path.basename(path)}" except Exception as e: return None, f"❌ 保存失败: {e}" def load_saved_face_image(): """加载已保存的头像""" img = SDService.load_face_image() if img: return img, "✅ 已加载保存的头像" return None, "ℹ️ 尚未设置头像"