/** * @file mcp_client.cpp * @brief Implementation of the MCP client * * This file implements the client-side functionality for the Model Context Protocol. * Follows the 2024-11-05 basic protocol specification. */ #include "mcp_client.h" #include "base64.hpp" namespace mcp { client::client(const std::string& host, int port, const json& capabilities, const std::string& sse_endpoint) : host_(host), port_(port), capabilities_(capabilities), sse_endpoint_(sse_endpoint) { init_client(host, port); } client::client(const std::string& base_url, const json& capabilities, const std::string& sse_endpoint) : base_url_(base_url), capabilities_(capabilities), sse_endpoint_(sse_endpoint) { init_client(base_url); } client::~client() { // 关闭SSE连接 close_sse_connection(); // httplib::Client将自动销毁 } void client::init_client(const std::string& host, int port) { // Create the HTTP client http_client_ = std::make_unique(host.c_str(), port); // Set timeout http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0); } void client::init_client(const std::string& base_url) { // Create the HTTP client http_client_ = std::make_unique(base_url.c_str()); // Set timeout http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0); } bool client::initialize(const std::string& client_name, const std::string& client_version) { std::cerr << "开始初始化MCP客户端..." << std::endl; // 检查服务器是否可访问 if (!check_server_accessible()) { std::cerr << "服务器不可访问,初始化失败" << std::endl; return false; } // Create initialization request request req = request::create("initialize", { {"protocolVersion", MCP_VERSION}, {"capabilities", capabilities_}, {"clientInfo", { {"name", client_name}, {"version", client_version} }} }); try { // 打开SSE连接 std::cerr << "正在打开SSE连接..." << std::endl; open_sse_connection(); // 等待SSE连接建立并获取消息端点 // 使用条件变量和超时机制 const auto timeout = std::chrono::milliseconds(5000); // 5秒超时 { std::unique_lock lock(mutex_); // 检查初始状态 if (!msg_endpoint_.empty()) { std::cerr << "消息端点已经设置: " << msg_endpoint_ << std::endl; } else { std::cerr << "等待条件变量..." << std::endl; } bool success = endpoint_cv_.wait_for(lock, timeout, [this]() { if (!sse_running_) { std::cerr << "SSE连接已关闭,停止等待" << std::endl; return true; } if (!msg_endpoint_.empty()) { std::cerr << "消息端点已设置,停止等待" << std::endl; return true; } return false; }); // 检查等待结果 if (!success) { std::cerr << "条件变量等待超时" << std::endl; } // 如果SSE连接已关闭或等待超时,抛出异常 if (!sse_running_) { throw std::runtime_error("SSE连接已关闭,未能获取消息端点"); } if (msg_endpoint_.empty()) { throw std::runtime_error("等待SSE连接超时,未能获取消息端点"); } std::cerr << "成功获取消息端点: " << msg_endpoint_ << std::endl; } // 发送初始化请求 json result = send_jsonrpc(req); // 存储服务器能力 server_capabilities_ = result["capabilities"]; // 发送已初始化通知 request notification = request::create_notification("initialized"); send_jsonrpc(notification); return true; } catch (const std::exception& e) { // 初始化失败,关闭SSE连接 std::cerr << "初始化失败: " << e.what() << std::endl; close_sse_connection(); return false; } } bool client::ping() { // Create ping request request req = request::create("ping", {}); try { // Send the request json result = send_jsonrpc(req); // The receiver MUST respond promptly with an empty response if (result.empty()) { return true; } else { return false; } } catch (const std::exception& e) { // Ping failed return false; } } void client::set_auth_token(const std::string& token) { std::lock_guard lock(mutex_); auth_token_ = token; // Add to default headers set_header("Authorization", "Bearer " + auth_token_); } void client::set_header(const std::string& key, const std::string& value) { std::lock_guard lock(mutex_); default_headers_[key] = value; } void client::set_timeout(int timeout_seconds) { std::lock_guard lock(mutex_); timeout_seconds_ = timeout_seconds; // Update the client's timeout http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0); } void client::set_capabilities(const json& capabilities) { std::lock_guard lock(mutex_); capabilities_ = capabilities; } response client::send_request(const std::string& method, const json& params) { request req = request::create(method, params); json result = send_jsonrpc(req); response res; res.jsonrpc = "2.0"; res.id = req.id; res.result = result; return res; } void client::send_notification(const std::string& method, const json& params) { request req = request::create_notification(method, params); send_jsonrpc(req); } json client::get_server_capabilities() { return server_capabilities_; } json client::call_tool(const std::string& tool_name, const json& arguments) { return send_request("tools/call", { {"name", tool_name}, {"arguments", arguments} }).result; } std::vector client::get_tools() { json tools_json = send_request("tools/list", {}).result; std::vector tools; if (tools_json.is_array()) { for (const auto& tool_json : tools_json) { tool t; t.name = tool_json["name"]; t.description = tool_json["description"]; if (tool_json.contains("inputSchema")) { t.parameters_schema = tool_json["inputSchema"]; } tools.push_back(t); } } return tools; } json client::get_capabilities() { return capabilities_; } json client::list_resources(const std::string& cursor) { json params = json::object(); if (!cursor.empty()) { params["cursor"] = cursor; } return send_request("resources/list", params).result; } json client::read_resource(const std::string& resource_uri) { return send_request("resources/read", { {"uri", resource_uri} }).result; } json client::subscribe_to_resource(const std::string& resource_uri) { return send_request("resources/subscribe", { {"uri", resource_uri} }).result; } json client::list_resource_templates() { return send_request("resources/templates/list").result; } void client::open_sse_connection() { // 设置SSE连接状态为运行中 sse_running_ = true; // 清空消息端点 { std::lock_guard lock(mutex_); msg_endpoint_.clear(); // 通知等待的线程(虽然消息端点为空,但可以让等待的线程检查sse_running_状态) endpoint_cv_.notify_all(); } // 打印连接信息(调试用) std::string connection_info; if (!base_url_.empty()) { connection_info = "Base URL: " + base_url_ + ", SSE Endpoint: " + sse_endpoint_; } else { connection_info = "Host: " + host_ + ", Port: " + std::to_string(port_) + ", SSE Endpoint: " + sse_endpoint_; } std::cerr << "尝试建立SSE连接: " << connection_info << std::endl; // 创建并启动SSE线程 sse_thread_ = std::make_unique([this]() { int retry_count = 0; const int max_retries = 5; const int retry_delay_base = 1000; // 毫秒 while (sse_running_) { try { // 尝试建立SSE连接 std::cerr << "SSE线程: 尝试连接到 " << sse_endpoint_ << std::endl; auto res = http_client_->Get(sse_endpoint_.c_str(), [this](const char *data, size_t data_length) { // 解析SSE数据 std::cerr << "SSE线程: 收到数据,长度: " << data_length << std::endl; if (!parse_sse_data(data, data_length)) { std::cerr << "SSE线程: 解析数据失败" << std::endl; return false; // 解析失败,关闭连接 } return sse_running_.load(); // 如果sse_running_为false,关闭连接 }); // 检查连接是否成功 if (!res) { std::string error_msg = "SSE连接失败: "; error_msg += "错误代码: " + std::to_string(static_cast(res.error())); // 添加更详细的错误信息 switch (res.error()) { case httplib::Error::Connection: error_msg += " (连接错误,服务器可能未运行或无法访问)"; break; case httplib::Error::Read: error_msg += " (读取错误,服务器可能关闭了连接或响应格式不正确)"; break; case httplib::Error::Write: error_msg += " (写入错误)"; break; case httplib::Error::ConnectionTimeout: error_msg += " (连接超时)"; break; case httplib::Error::Canceled: error_msg += " (请求被取消)"; break; default: error_msg += " (未知错误)"; break; } throw std::runtime_error(error_msg); } // 连接成功后重置重试计数 retry_count = 0; std::cerr << "SSE线程: 连接成功" << std::endl; } catch (const std::exception& e) { // 记录错误 std::cerr << "SSE连接错误: " << e.what() << std::endl; // 如果已达到最大重试次数,停止尝试 if (++retry_count > max_retries) { std::cerr << "达到最大重试次数,停止SSE连接尝试" << std::endl; break; } // 指数退避重试 int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay std::cerr << "将在 " << delay << " 毫秒后重试 (尝试 " << retry_count << "/" << max_retries << ")" << std::endl; std::this_thread::sleep_for(std::chrono::milliseconds(delay)); } } std::cerr << "SSE线程: 退出" << std::endl; }); } // 新增方法:解析SSE数据 bool client::parse_sse_data(const char* data, size_t length) { try { std::string sse_data(data, length); // 查找"data:"标记 auto data_pos = sse_data.find("data: "); if (data_pos == std::string::npos) { return true; // 不是数据事件,可能是注释或心跳,继续保持连接 } // 查找数据行结束位置 auto newline_pos = sse_data.find("\n", data_pos); if (newline_pos == std::string::npos) { newline_pos = sse_data.length(); // 如果没有换行符,使用整个字符串 } // 提取数据内容 std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6)); // 检查是否是心跳事件 if (sse_data.find("event: heartbeat") != std::string::npos) { // 心跳事件,不需要处理数据 return true; } // 更新消息端点 { std::lock_guard lock(mutex_); msg_endpoint_ = data_content; // 通知等待的线程 endpoint_cv_.notify_all(); } return true; } catch (const std::exception& e) { std::cerr << "解析SSE数据错误: " << e.what() << std::endl; return false; } } // 新增方法:关闭SSE连接 void client::close_sse_connection() { sse_running_ = false; if (sse_thread_ && sse_thread_->joinable()) { sse_thread_->join(); } // 清空消息端点 { std::lock_guard lock(mutex_); msg_endpoint_.clear(); // 通知等待的线程(虽然消息端点为空,但可以让等待的线程检查sse_running_状态) endpoint_cv_.notify_all(); } } json client::send_jsonrpc(const request& req) { std::lock_guard lock(mutex_); // 检查消息端点是否已设置 if (msg_endpoint_.empty()) { throw mcp_exception(error_code::internal_error, "消息端点未设置,SSE连接可能未建立"); } // 打印请求信息(调试用) std::cerr << "发送JSON-RPC请求: 方法=" << req.method << ", 端点=" << msg_endpoint_ << std::endl; // Convert request to JSON json req_json = req.to_json(); std::string req_body = req_json.dump(); // Prepare headers httplib::Headers headers; headers.emplace("Content-Type", "application/json"); // Add default headers for (const auto& [key, value] : default_headers_) { headers.emplace(key, value); } // Send the request auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); if (!result) { // Error occurred auto err = result.error(); std::string error_msg; switch (err) { case httplib::Error::Connection: error_msg = "连接错误,服务器可能未运行或无法访问"; break; case httplib::Error::Read: error_msg = "读取错误,服务器可能关闭了连接或响应格式不正确"; break; case httplib::Error::Write: error_msg = "写入错误"; break; case httplib::Error::ConnectionTimeout: error_msg = "连接超时"; break; default: error_msg = "HTTP客户端错误: " + std::to_string(static_cast(err)); break; } std::cerr << "JSON-RPC请求失败: " << error_msg << std::endl; throw mcp_exception(error_code::internal_error, error_msg); } // Check if it's a notification (no response expected) if (req.is_notification()) { return json::object(); } // Parse response try { json res_json = json::parse(result->body); // 打印响应信息(调试用) std::cerr << "收到JSON-RPC响应: " << res_json.dump() << std::endl; // Check for error if (res_json.contains("error")) { int code = res_json["error"]["code"]; std::string message = res_json["error"]["message"]; throw mcp_exception(static_cast(code), message); } // Return result if (res_json.contains("result")) { return res_json["result"]; } else { return json::object(); } } catch (const json::exception& e) { throw mcp_exception(error_code::parse_error, "Failed to parse JSON-RPC response: " + std::string(e.what())); } } bool client::check_server_accessible() { std::cerr << "检查服务器是否可访问..." << std::endl; try { // 尝试发送一个简单的GET请求到服务器 auto res = http_client_->Get("/"); if (res) { std::cerr << "服务器可访问,状态码: " << res->status << std::endl; return true; } else { std::string error_msg = "服务器不可访问,错误代码: " + std::to_string(static_cast(res.error())); // 添加更详细的错误信息 switch (res.error()) { case httplib::Error::Connection: error_msg += " (连接错误,服务器可能未运行或无法访问)"; break; case httplib::Error::Read: error_msg += " (读取错误)"; break; case httplib::Error::Write: error_msg += " (写入错误)"; break; case httplib::Error::ConnectionTimeout: error_msg += " (连接超时)"; break; default: error_msg += " (未知错误)"; break; } std::cerr << error_msg << std::endl; return false; } } catch (const std::exception& e) { std::cerr << "检查服务器可访问性时发生异常: " << e.what() << std::endl; return false; } } } // namespace mcp