xhs_factory/services/connection.py
zhoujie 2ba87c8f6e
Some checks failed
CI / Lint (ruff) (push) Has been cancelled
CI / Import Check (push) Has been cancelled
📝 docs(project): 添加开源社区标准文档与 CI 工作流
- 新增 GitHub Issue 模板(Bug 报告、功能请求)和 Pull Request 模板
- 新增 Code of Conduct(贡献者行为准则)和 Security Policy(安全政策)
- 新增 CI 工作流(GitHub Actions),包含 ruff 代码检查和导入验证
- 新增开发依赖文件 requirements-dev.txt

📦 build(ci): 配置 GitHub Actions 持续集成

- 在 push 到 main 分支和 pull request 时自动触发 CI
- 添加 lint 任务执行 ruff 代码风格检查
- 添加 import-check 任务验证核心服务模块导入

♻️ refactor(structure): 重构项目目录结构

- 将根目录的 6 个服务模块迁移至 services/ 包
- 更新所有相关文件的导入语句(main.py、ui/、services/)
- 根目录仅保留 main.py 作为唯一 Python 入口文件

🔧 chore(config): 调整配置和资源文件路径

- 将 config.json 移至 config/ 目录,更新相关引用
- 将个人头像图片移至 assets/faces/ 目录,更新 .gitignore
- 更新 Dockerfile 和 docker-compose.yml 中的配置路径

📝 docs(readme): 完善 README 文档

- 添加项目状态徽章(Python 版本、License、CI)
- 更新项目结构图反映实际目录布局
- 修正使用指南中的 Tab 名称和操作路径
- 替换 your-username 占位符为格式提示

🗑️ chore(cleanup): 清理冗余文件

- 删除旧版备份文件、测试脚本、临时记录和运行日志
- 删除散落的个人图片文件(已归档至 assets/faces/)
2026-02-27 22:12:39 +08:00

255 lines
8.7 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
services/connection.py
LLM 提供商管理、SD 连接、MCP 连接、XHS 登录等服务函数
"""
import os
import re
import logging
import gradio as gr
from .config_manager import ConfigManager
from .llm_service import LLMService
from .sd_service import SDService, get_model_profile_info
from .mcp_client import get_mcp_client
logger = logging.getLogger("autobot")
cfg = ConfigManager()
def _get_llm_config() -> tuple[str, str, str]:
"""获取当前激活 LLM 的 (api_key, base_url, model)"""
p = cfg.get_active_llm()
if p:
return p["api_key"], p["base_url"], cfg.get("model", "")
return "", "", ""
def connect_llm(provider_name):
"""连接选中的 LLM 提供商并获取模型列表"""
if not provider_name:
return gr.update(choices=[], value=None), "⚠️ 请先选择或添加 LLM 提供商"
cfg.set_active_llm(provider_name)
p = cfg.get_active_llm()
if not p:
return gr.update(choices=[], value=None), "❌ 未找到该提供商配置"
try:
svc = LLMService(p["api_key"], p["base_url"])
models = svc.get_models()
if models:
return (
gr.update(choices=models, value=models[0]),
f"✅ 已连接「{provider_name}」,加载 {len(models)} 个模型",
)
else:
# API 无法获取模型列表,保留手动输入
current_model = cfg.get("model", "")
return (
gr.update(choices=[current_model] if current_model else [], value=current_model or None),
f"⚠️ 已连接「{provider_name}」,但未获取到模型列表,请手动输入模型名",
)
except Exception as e:
logger.error("LLM 连接失败: %s", e)
current_model = cfg.get("model", "")
return (
gr.update(choices=[current_model] if current_model else [], value=current_model or None),
f"❌ 连接「{provider_name}」失败: {e}",
)
def add_llm_provider(name, api_key, base_url):
"""添加新的 LLM 提供商"""
msg = cfg.add_llm_provider(name, api_key, base_url)
names = cfg.get_llm_provider_names()
active = cfg.get("active_llm", "")
return (
gr.update(choices=names, value=active),
msg,
)
def remove_llm_provider(provider_name):
"""删除 LLM 提供商"""
if not provider_name:
return gr.update(choices=cfg.get_llm_provider_names(), value=cfg.get("active_llm", "")), "⚠️ 请先选择要删除的提供商"
msg = cfg.remove_llm_provider(provider_name)
names = cfg.get_llm_provider_names()
active = cfg.get("active_llm", "")
return (
gr.update(choices=names, value=active),
msg,
)
def on_provider_selected(provider_name):
"""切换 LLM 提供商时更新显示信息"""
if not provider_name:
return "未选择提供商"
for p in cfg.get_llm_providers():
if p["name"] == provider_name:
cfg.set_active_llm(provider_name)
masked_key = p["api_key"][:8] + "***" if len(p["api_key"]) > 8 else "***"
return f"**{provider_name}** \nAPI Key: `{masked_key}` \nBase URL: `{p['base_url']}`"
return "未找到该提供商"
# ==================================================
# Tab 1: 内容创作
# ==================================================
def connect_sd(sd_url):
"""连接 SD 并获取模型列表"""
try:
svc = SDService(sd_url)
ok, msg = svc.check_connection()
if ok:
models = svc.get_models()
cfg.set("sd_url", sd_url)
first = models[0] if models else None
info = get_model_profile_info(first) if first else "未检测到模型"
return gr.update(choices=models, value=first), f"{msg}", info
return gr.update(choices=[]), f"{msg}", ""
except Exception as e:
logger.error("SD 连接失败: %s", e)
return gr.update(choices=[]), f"❌ SD 连接失败: {e}", ""
def on_sd_model_change(model_name):
"""SD 模型切换时显示模型档案信息"""
if not model_name:
return "未选择模型"
return get_model_profile_info(model_name)
def check_mcp_status(mcp_url):
"""检查 MCP 连接状态"""
try:
client = get_mcp_client(mcp_url)
ok, msg = client.check_connection()
if ok:
cfg.set("mcp_url", mcp_url)
return f"✅ MCP 服务正常 - {msg}"
return f"{msg}"
except Exception as e:
return f"❌ MCP 连接失败: {e}"
# ==================================================
# 小红书账号登录
# ==================================================
def get_login_qrcode(mcp_url):
"""获取小红书登录二维码"""
try:
client = get_mcp_client(mcp_url)
result = client.get_login_qrcode()
if "error" in result:
return None, f"❌ 获取二维码失败: {result['error']}"
qr_image = result.get("qr_image")
msg = result.get("text", "")
if qr_image:
return qr_image, f"✅ 二维码已生成,请用小红书 App 扫码\n{msg}"
return None, f"⚠️ 未获取到二维码图片MCP 返回:\n{msg}"
except Exception as e:
logger.error("获取登录二维码失败: %s", e)
return None, f"❌ 获取二维码失败: {e}"
def logout_xhs(mcp_url):
"""退出登录:清除 cookies 并重置本地 token"""
try:
client = get_mcp_client(mcp_url)
result = client.delete_cookies()
if "error" in result:
return f"❌ 退出失败: {result['error']}"
cfg.set("xsec_token", "")
client._reset()
return "✅ 已退出登录,可以重新扫码登录"
except Exception as e:
logger.error("退出登录失败: %s", e)
return f"❌ 退出失败: {e}"
def _auto_fetch_xsec_token(mcp_url) -> str:
"""从推荐列表自动获取一个有效的 xsec_token"""
try:
client = get_mcp_client(mcp_url)
entries = client.list_feeds_parsed()
for e in entries:
token = e.get("xsec_token", "")
if token:
return token
except Exception as e:
logger.warning("自动获取 xsec_token 失败: %s", e)
return ""
def check_login(mcp_url):
"""检查登录状态,登录成功后自动获取 xsec_token 并保存"""
try:
client = get_mcp_client(mcp_url)
result = client.check_login_status()
if "error" in result:
return f"{result['error']}", gr.update(), gr.update()
text = result.get("text", "")
if "未登录" in text:
return f"🔴 {text}", gr.update(), gr.update()
# 登录成功 → 自动获取 xsec_token
token = _auto_fetch_xsec_token(mcp_url)
if token:
cfg.set("xsec_token", token)
logger.info("自动获取 xsec_token 成功")
return (
f"🟢 {text}\n\n✅ xsec_token 已自动获取并保存",
gr.update(value=cfg.get("my_user_id", "")),
gr.update(value=token),
)
return f"🟢 {text}\n\n⚠️ 自动获取 xsec_token 失败,请手动刷新", gr.update(), gr.update()
except Exception as e:
return f"❌ 检查登录状态失败: {e}", gr.update(), gr.update()
def save_my_user_id(user_id_input):
"""保存用户 ID (验证 24 位十六进制格式)"""
uid = (user_id_input or "").strip()
if not uid:
cfg.set("my_user_id", "")
return "⚠️ 已清除用户 ID"
if not re.match(r'^[0-9a-fA-F]{24}$', uid):
return (
"❌ 格式错误!用户 ID 应为 24 位十六进制字符串\n"
f"你输入的: `{uid}` ({len(uid)} 位)\n\n"
"💡 如果你输入的是小红书号 (纯数字如 18688457507),那不是 userId。"
)
cfg.set("my_user_id", uid)
return f"✅ 用户 ID 已保存: `{uid}`"
# ================= 头像/换脸管理 =================
def upload_face_image(img):
"""上传并保存头像图片"""
if img is None:
return None, "❌ 请上传头像图片"
try:
if isinstance(img, str) and os.path.isfile(img):
img = Image.open(img).convert("RGB")
elif not isinstance(img, Image.Image):
return None, "❌ 无法识别图片格式"
path = SDService.save_face_image(img)
return img, f"✅ 头像已保存至 {os.path.basename(path)}"
except Exception as e:
return None, f"❌ 保存失败: {e}"
def load_saved_face_image():
"""加载已保存的头像"""
img = SDService.load_face_image()
if img:
return img, "✅ 已加载保存的头像"
return None, " 尚未设置头像"