- 新增 GitHub Issue 模板(Bug 报告、功能请求)和 Pull Request 模板 - 新增 Code of Conduct(贡献者行为准则)和 Security Policy(安全政策) - 新增 CI 工作流(GitHub Actions),包含 ruff 代码检查和导入验证 - 新增开发依赖文件 requirements-dev.txt 📦 build(ci): 配置 GitHub Actions 持续集成 - 在 push 到 main 分支和 pull request 时自动触发 CI - 添加 lint 任务执行 ruff 代码风格检查 - 添加 import-check 任务验证核心服务模块导入 ♻️ refactor(structure): 重构项目目录结构 - 将根目录的 6 个服务模块迁移至 services/ 包 - 更新所有相关文件的导入语句(main.py、ui/、services/) - 根目录仅保留 main.py 作为唯一 Python 入口文件 🔧 chore(config): 调整配置和资源文件路径 - 将 config.json 移至 config/ 目录,更新相关引用 - 将个人头像图片移至 assets/faces/ 目录,更新 .gitignore - 更新 Dockerfile 和 docker-compose.yml 中的配置路径 📝 docs(readme): 完善 README 文档 - 添加项目状态徽章(Python 版本、License、CI) - 更新项目结构图反映实际目录布局 - 修正使用指南中的 Tab 名称和操作路径 - 替换 your-username 占位符为格式提示 🗑️ chore(cleanup): 清理冗余文件 - 删除旧版备份文件、测试脚本、临时记录和运行日志 - 删除散落的个人图片文件(已归档至 assets/faces/)
255 lines
8.7 KiB
Python
255 lines
8.7 KiB
Python
"""
|
||
services/connection.py
|
||
LLM 提供商管理、SD 连接、MCP 连接、XHS 登录等服务函数
|
||
"""
|
||
import os
|
||
import re
|
||
import logging
|
||
|
||
import gradio as gr
|
||
|
||
from .config_manager import ConfigManager
|
||
from .llm_service import LLMService
|
||
from .sd_service import SDService, get_model_profile_info
|
||
from .mcp_client import get_mcp_client
|
||
|
||
logger = logging.getLogger("autobot")
|
||
cfg = ConfigManager()
|
||
|
||
def _get_llm_config() -> tuple[str, str, str]:
|
||
"""获取当前激活 LLM 的 (api_key, base_url, model)"""
|
||
p = cfg.get_active_llm()
|
||
if p:
|
||
return p["api_key"], p["base_url"], cfg.get("model", "")
|
||
return "", "", ""
|
||
|
||
|
||
def connect_llm(provider_name):
|
||
"""连接选中的 LLM 提供商并获取模型列表"""
|
||
if not provider_name:
|
||
return gr.update(choices=[], value=None), "⚠️ 请先选择或添加 LLM 提供商"
|
||
cfg.set_active_llm(provider_name)
|
||
p = cfg.get_active_llm()
|
||
if not p:
|
||
return gr.update(choices=[], value=None), "❌ 未找到该提供商配置"
|
||
try:
|
||
svc = LLMService(p["api_key"], p["base_url"])
|
||
models = svc.get_models()
|
||
if models:
|
||
return (
|
||
gr.update(choices=models, value=models[0]),
|
||
f"✅ 已连接「{provider_name}」,加载 {len(models)} 个模型",
|
||
)
|
||
else:
|
||
# API 无法获取模型列表,保留手动输入
|
||
current_model = cfg.get("model", "")
|
||
return (
|
||
gr.update(choices=[current_model] if current_model else [], value=current_model or None),
|
||
f"⚠️ 已连接「{provider_name}」,但未获取到模型列表,请手动输入模型名",
|
||
)
|
||
except Exception as e:
|
||
logger.error("LLM 连接失败: %s", e)
|
||
current_model = cfg.get("model", "")
|
||
return (
|
||
gr.update(choices=[current_model] if current_model else [], value=current_model or None),
|
||
f"❌ 连接「{provider_name}」失败: {e}",
|
||
)
|
||
|
||
|
||
def add_llm_provider(name, api_key, base_url):
|
||
"""添加新的 LLM 提供商"""
|
||
msg = cfg.add_llm_provider(name, api_key, base_url)
|
||
names = cfg.get_llm_provider_names()
|
||
active = cfg.get("active_llm", "")
|
||
return (
|
||
gr.update(choices=names, value=active),
|
||
msg,
|
||
)
|
||
|
||
|
||
def remove_llm_provider(provider_name):
|
||
"""删除 LLM 提供商"""
|
||
if not provider_name:
|
||
return gr.update(choices=cfg.get_llm_provider_names(), value=cfg.get("active_llm", "")), "⚠️ 请先选择要删除的提供商"
|
||
msg = cfg.remove_llm_provider(provider_name)
|
||
names = cfg.get_llm_provider_names()
|
||
active = cfg.get("active_llm", "")
|
||
return (
|
||
gr.update(choices=names, value=active),
|
||
msg,
|
||
)
|
||
|
||
|
||
def on_provider_selected(provider_name):
|
||
"""切换 LLM 提供商时更新显示信息"""
|
||
if not provider_name:
|
||
return "未选择提供商"
|
||
for p in cfg.get_llm_providers():
|
||
if p["name"] == provider_name:
|
||
cfg.set_active_llm(provider_name)
|
||
masked_key = p["api_key"][:8] + "***" if len(p["api_key"]) > 8 else "***"
|
||
return f"**{provider_name}** \nAPI Key: `{masked_key}` \nBase URL: `{p['base_url']}`"
|
||
return "未找到该提供商"
|
||
|
||
|
||
# ==================================================
|
||
# Tab 1: 内容创作
|
||
# ==================================================
|
||
|
||
|
||
def connect_sd(sd_url):
|
||
"""连接 SD 并获取模型列表"""
|
||
try:
|
||
svc = SDService(sd_url)
|
||
ok, msg = svc.check_connection()
|
||
if ok:
|
||
models = svc.get_models()
|
||
cfg.set("sd_url", sd_url)
|
||
first = models[0] if models else None
|
||
info = get_model_profile_info(first) if first else "未检测到模型"
|
||
return gr.update(choices=models, value=first), f"✅ {msg}", info
|
||
return gr.update(choices=[]), f"❌ {msg}", ""
|
||
except Exception as e:
|
||
logger.error("SD 连接失败: %s", e)
|
||
return gr.update(choices=[]), f"❌ SD 连接失败: {e}", ""
|
||
|
||
|
||
def on_sd_model_change(model_name):
|
||
"""SD 模型切换时显示模型档案信息"""
|
||
if not model_name:
|
||
return "未选择模型"
|
||
return get_model_profile_info(model_name)
|
||
|
||
|
||
def check_mcp_status(mcp_url):
|
||
"""检查 MCP 连接状态"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
ok, msg = client.check_connection()
|
||
if ok:
|
||
cfg.set("mcp_url", mcp_url)
|
||
return f"✅ MCP 服务正常 - {msg}"
|
||
return f"❌ {msg}"
|
||
except Exception as e:
|
||
return f"❌ MCP 连接失败: {e}"
|
||
|
||
|
||
# ==================================================
|
||
# 小红书账号登录
|
||
# ==================================================
|
||
|
||
|
||
def get_login_qrcode(mcp_url):
|
||
"""获取小红书登录二维码"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
result = client.get_login_qrcode()
|
||
if "error" in result:
|
||
return None, f"❌ 获取二维码失败: {result['error']}"
|
||
qr_image = result.get("qr_image")
|
||
msg = result.get("text", "")
|
||
if qr_image:
|
||
return qr_image, f"✅ 二维码已生成,请用小红书 App 扫码\n{msg}"
|
||
return None, f"⚠️ 未获取到二维码图片,MCP 返回:\n{msg}"
|
||
except Exception as e:
|
||
logger.error("获取登录二维码失败: %s", e)
|
||
return None, f"❌ 获取二维码失败: {e}"
|
||
|
||
|
||
def logout_xhs(mcp_url):
|
||
"""退出登录:清除 cookies 并重置本地 token"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
result = client.delete_cookies()
|
||
if "error" in result:
|
||
return f"❌ 退出失败: {result['error']}"
|
||
cfg.set("xsec_token", "")
|
||
client._reset()
|
||
return "✅ 已退出登录,可以重新扫码登录"
|
||
except Exception as e:
|
||
logger.error("退出登录失败: %s", e)
|
||
return f"❌ 退出失败: {e}"
|
||
|
||
|
||
def _auto_fetch_xsec_token(mcp_url) -> str:
|
||
"""从推荐列表自动获取一个有效的 xsec_token"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
entries = client.list_feeds_parsed()
|
||
for e in entries:
|
||
token = e.get("xsec_token", "")
|
||
if token:
|
||
return token
|
||
except Exception as e:
|
||
logger.warning("自动获取 xsec_token 失败: %s", e)
|
||
return ""
|
||
|
||
|
||
def check_login(mcp_url):
|
||
"""检查登录状态,登录成功后自动获取 xsec_token 并保存"""
|
||
try:
|
||
client = get_mcp_client(mcp_url)
|
||
result = client.check_login_status()
|
||
if "error" in result:
|
||
return f"❌ {result['error']}", gr.update(), gr.update()
|
||
text = result.get("text", "")
|
||
if "未登录" in text:
|
||
return f"🔴 {text}", gr.update(), gr.update()
|
||
|
||
# 登录成功 → 自动获取 xsec_token
|
||
token = _auto_fetch_xsec_token(mcp_url)
|
||
if token:
|
||
cfg.set("xsec_token", token)
|
||
logger.info("自动获取 xsec_token 成功")
|
||
return (
|
||
f"🟢 {text}\n\n✅ xsec_token 已自动获取并保存",
|
||
gr.update(value=cfg.get("my_user_id", "")),
|
||
gr.update(value=token),
|
||
)
|
||
return f"🟢 {text}\n\n⚠️ 自动获取 xsec_token 失败,请手动刷新", gr.update(), gr.update()
|
||
except Exception as e:
|
||
return f"❌ 检查登录状态失败: {e}", gr.update(), gr.update()
|
||
|
||
|
||
def save_my_user_id(user_id_input):
|
||
"""保存用户 ID (验证 24 位十六进制格式)"""
|
||
uid = (user_id_input or "").strip()
|
||
if not uid:
|
||
cfg.set("my_user_id", "")
|
||
return "⚠️ 已清除用户 ID"
|
||
if not re.match(r'^[0-9a-fA-F]{24}$', uid):
|
||
return (
|
||
"❌ 格式错误!用户 ID 应为 24 位十六进制字符串\n"
|
||
f"你输入的: `{uid}` ({len(uid)} 位)\n\n"
|
||
"💡 如果你输入的是小红书号 (纯数字如 18688457507),那不是 userId。"
|
||
)
|
||
cfg.set("my_user_id", uid)
|
||
return f"✅ 用户 ID 已保存: `{uid}`"
|
||
|
||
|
||
# ================= 头像/换脸管理 =================
|
||
|
||
def upload_face_image(img):
|
||
"""上传并保存头像图片"""
|
||
if img is None:
|
||
return None, "❌ 请上传头像图片"
|
||
try:
|
||
if isinstance(img, str) and os.path.isfile(img):
|
||
img = Image.open(img).convert("RGB")
|
||
elif not isinstance(img, Image.Image):
|
||
return None, "❌ 无法识别图片格式"
|
||
path = SDService.save_face_image(img)
|
||
return img, f"✅ 头像已保存至 {os.path.basename(path)}"
|
||
except Exception as e:
|
||
return None, f"❌ 保存失败: {e}"
|
||
|
||
|
||
def load_saved_face_image():
|
||
"""加载已保存的头像"""
|
||
img = SDService.load_face_image()
|
||
if img:
|
||
return img, "✅ 已加载保存的头像"
|
||
return None, "ℹ️ 尚未设置头像"
|
||
|
||
|