/**
 * @file mcp_client.cpp
 * @brief Implementation of the MCP client
 *
 * This file implements the client-side functionality for the Model Context Protocol.
 * Follows the 2024-11-05 basic protocol specification (HTTP + SSE transport: the client
 * opens an SSE stream, receives the POST target in an `endpoint` event, and then sends
 * JSON-RPC messages to that endpoint).
 */

#include "mcp_client.h"

#include <algorithm>
#include <chrono>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "base64.hpp"

namespace mcp {

/// Construct a client for `host:port`; `sse_endpoint` is the path of the SSE stream.
client::client(const std::string& host, int port, const json& capabilities, const std::string& sse_endpoint)
    : host_(host), port_(port), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
    init_client(host, port);
}

/// Construct a client from a base URL (e.g. "http://localhost:8080").
client::client(const std::string& base_url, const json& capabilities, const std::string& sse_endpoint)
    : base_url_(base_url), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
    init_client(base_url);
}

client::~client() {
    // Stop and join the SSE thread before members it uses are destroyed.
    close_sse_connection();
    // http_client_ is released by its unique_ptr.
}

/// Create the HTTP client used for JSON-RPC POSTs and apply the current timeout.
void client::init_client(const std::string& host, int port) {
    http_client_ = std::make_unique<httplib::Client>(host.c_str(), port);

    http_client_->set_connection_timeout(timeout_seconds_, 0);
    http_client_->set_read_timeout(timeout_seconds_, 0);
    http_client_->set_write_timeout(timeout_seconds_, 0);
}

/// Create the HTTP client (from a base URL) used for JSON-RPC POSTs and apply the current timeout.
void client::init_client(const std::string& base_url) {
    http_client_ = std::make_unique<httplib::Client>(base_url.c_str());

    http_client_->set_connection_timeout(timeout_seconds_, 0);
    http_client_->set_read_timeout(timeout_seconds_, 0);
    http_client_->set_write_timeout(timeout_seconds_, 0);
}

/**
 * Run the MCP initialization handshake.
 *
 * Opens the SSE stream, waits (up to the configured timeout) for the server's message
 * endpoint, sends `initialize`, stores the server capabilities, and sends the
 * `notifications/initialized` notification.
 *
 * @return true on success; false on any failure (the SSE connection is closed again).
 */
bool client::initialize(const std::string& client_name, const std::string& client_version) {
    request req = request::create("initialize", {
        {"protocolVersion", MCP_VERSION},
        {"capabilities", capabilities_},
        {"clientInfo", {
            {"name", client_name},
            {"version", client_version}
        }}
    });

    try {
        open_sse_connection();

        // The POST target only becomes known once the SSE thread has parsed the server's
        // endpoint event; posting before that would race against the background thread.
        int wait_seconds = 0;
        {
            std::lock_guard lock(mutex_);
            wait_seconds = timeout_seconds_;
        }
        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(wait_seconds);
        for (;;) {
            {
                std::lock_guard lock(mutex_);
                if (!msg_endpoint_.empty()) {
                    break;
                }
            }
            if (!sse_running_ || std::chrono::steady_clock::now() >= deadline) {
                throw std::runtime_error("no message endpoint received from the SSE stream");
            }
            std::this_thread::sleep_for(std::chrono::milliseconds(50));
        }

        json result = send_jsonrpc(req);

        server_capabilities_ = result["capabilities"];

        // The 2024-11-05 lifecycle names this notification "notifications/initialized".
        request notification = request::create_notification("notifications/initialized");
        send_jsonrpc(notification);

        return true;
    } catch (const std::exception& e) {
        // Initialization failed: tear the SSE connection down again.
        std::cerr << "初始化失败: " << e.what() << std::endl;
        close_sse_connection();
        return false;
    }
}

/// Send a `ping`; returns true only if the server answered with an empty result.
bool client::ping() {
    request req = request::create("ping", {});

    try {
        json result = send_jsonrpc(req);
        // The receiver MUST respond promptly with an empty response.
        return result.empty();
    } catch (const std::exception&) {
        return false;
    }
}

/// Store a bearer token and send it as the `Authorization` header on subsequent requests.
void client::set_auth_token(const std::string& token) {
    std::lock_guard lock(mutex_);
    auth_token_ = token;
    // Write the header directly: delegating to set_header() would try to lock mutex_ a
    // second time on this thread, which self-deadlocks unless mutex_ is recursive.
    default_headers_["Authorization"] = "Bearer " + auth_token_;
}

/// Set (or replace) a header sent with every request.
void client::set_header(const std::string& key, const std::string& value) {
    std::lock_guard lock(mutex_);
    default_headers_[key] = value;
}

/**
 * Update the connection/read/write timeout of the POST client.
 * The SSE stream picks up the new value on its next (re)connect.
 */
void client::set_timeout(int timeout_seconds) {
    std::lock_guard lock(mutex_);
    timeout_seconds_ = timeout_seconds;

    http_client_->set_connection_timeout(timeout_seconds_, 0);
    http_client_->set_read_timeout(timeout_seconds_, 0);
    http_client_->set_write_timeout(timeout_seconds_, 0);
}

/// Replace the capabilities advertised in the next `initialize` request.
void client::set_capabilities(const json& capabilities) {
    std::lock_guard lock(mutex_);
    capabilities_ = capabilities;
}

/// Send a JSON-RPC request and wrap its result in a `response`. Throws on protocol/transport errors.
response client::send_request(const std::string& method, const json& params) {
    request req = request::create(method, params);
    json result = send_jsonrpc(req);

    response res;
    res.jsonrpc = "2.0";
    res.id = req.id;
    res.result = result;

    return res;
}

/// Send a JSON-RPC notification (no response is expected).
void client::send_notification(const std::string& method, const json& params) {
    request req = request::create_notification(method, params);
    send_jsonrpc(req);
}

/// Capabilities returned by the server during `initialize`.
json client::get_server_capabilities() {
    return server_capabilities_;
}

/// Invoke a server tool via `tools/call` and return its raw result.
json client::call_tool(const std::string& tool_name, const json& arguments) {
    return send_request("tools/call", {
        {"name", tool_name},
        {"arguments", arguments}
    }).result;
}

/**
 * List the server's tools via `tools/list`.
 *
 * The MCP result shape is `{"tools": [...]}`; a bare array is also accepted for
 * servers that return the list directly. Entries without a string `name` are skipped;
 * `description` is optional.
 */
std::vector<tool> client::get_tools() {
    json result = send_request("tools/list", {}).result;

    const json& tools_json = (result.is_object() && result.contains("tools")) ? result.at("tools") : result;

    std::vector<tool> tools;
    if (!tools_json.is_array()) {
        return tools;
    }
    tools.reserve(tools_json.size());

    for (const auto& tool_json : tools_json) {
        if (!tool_json.is_object() || !tool_json.contains("name") || !tool_json.at("name").is_string()) {
            continue;
        }

        tool t;
        t.name = tool_json.at("name").get<std::string>();
        if (auto it = tool_json.find("description"); it != tool_json.end() && it->is_string()) {
            t.description = it->get<std::string>();
        }
        if (tool_json.contains("inputSchema")) {
            t.parameters_schema = tool_json.at("inputSchema");
        }
        tools.push_back(std::move(t));
    }

    return tools;
}

/// Capabilities this client advertises.
json client::get_capabilities() {
    return capabilities_;
}

/// `resources/list`, optionally continuing from a pagination cursor.
json client::list_resources(const std::string& cursor) {
    json params = json::object();
    if (!cursor.empty()) {
        params["cursor"] = cursor;
    }
    return send_request("resources/list", params).result;
}

/// `resources/read` for a single URI.
json client::read_resource(const std::string& resource_uri) {
    return send_request("resources/read", {
        {"uri", resource_uri}
    }).result;
}

/// `resources/subscribe` for a single URI.
json client::subscribe_to_resource(const std::string& resource_uri) {
    return send_request("resources/subscribe", {
        {"uri", resource_uri}
    }).result;
}

/// `resources/templates/list`.
json client::list_resource_templates() {
    return send_request("resources/templates/list").result;
}

/**
 * Start the background thread that holds the SSE stream open.
 *
 * The stream uses its own httplib::Client: cpp-httplib serializes requests issued through
 * one Client instance, so a never-ending GET on http_client_ would block every POST.
 * Failed attempts are retried with exponential backoff (1s, 2s, 4s, ...) up to
 * `max_retries` times, after which the thread clears `sse_running_` and exits.
 */
void client::open_sse_connection() {
    // Assigning over a joinable std::thread calls std::terminate(); stop any previous one first.
    close_sse_connection();

    sse_running_ = true;

    sse_thread_ = std::make_unique<std::thread>([this]() {
        int retry_count = 0;
        const int max_retries = 5;
        const int retry_delay_base = 1000; // milliseconds

        // Sleep up to `delay_ms`, waking early once close_sse_connection() clears the flag.
        auto interruptible_sleep = [this](int delay_ms) {
            const int step_ms = 100;
            for (int waited = 0; waited < delay_ms && sse_running_; waited += step_ms) {
                std::this_thread::sleep_for(std::chrono::milliseconds(std::min(step_ms, delay_ms - waited)));
            }
        };

        while (sse_running_) {
            try {
                std::unique_ptr<httplib::Client> sse_client;
                httplib::Headers headers;
                {
                    std::lock_guard lock(mutex_);
                    if (base_url_.empty()) {
                        sse_client = std::make_unique<httplib::Client>(host_.c_str(), port_);
                    } else {
                        sse_client = std::make_unique<httplib::Client>(base_url_.c_str());
                    }
                    sse_client->set_connection_timeout(timeout_seconds_, 0);
                    sse_client->set_read_timeout(timeout_seconds_, 0);
                    sse_client->set_write_timeout(timeout_seconds_, 0);

                    // Send the same auth/default headers as the POST requests.
                    for (const auto& [key, value] : default_headers_) {
                        headers.emplace(key, value);
                    }
                }
                if (headers.find("Accept") == headers.end()) {
                    headers.emplace("Accept", "text/event-stream");
                }

                // Chunks may split or merge events; buffer until the blank-line terminator
                // and hand parse_sse_data() one complete event at a time.
                std::string pending;
                auto res = sse_client->Get(sse_endpoint_.c_str(), headers,
                    [this, &pending](const char* data, size_t data_length) {
                        pending.append(data, data_length);
                        // Normalize CRLF line endings so "\n\n" marks every event boundary.
                        pending.erase(std::remove(pending.begin(), pending.end(), '\r'), pending.end());

                        std::string::size_type boundary;
                        while ((boundary = pending.find("\n\n")) != std::string::npos) {
                            if (!parse_sse_data(pending.data(), boundary)) {
                                return false; // Parse failure: close the connection.
                            }
                            pending.erase(0, boundary + 2);
                        }
                        return sse_running_.load(); // Close once shutdown was requested.
                    });

                // A cancelled read after close_sse_connection() is a normal shutdown, not an error.
                if (!sse_running_) {
                    break;
                }

                if (!res) {
                    throw std::runtime_error("SSE连接失败: " + std::to_string(static_cast<int>(res.error())));
                }
                if (res->status != 200) {
                    throw std::runtime_error("SSE endpoint returned HTTP status " + std::to_string(res->status));
                }

                // The stream ended normally: reset the retry budget, but pause briefly so a
                // server that closes immediately is not hammered in a tight reconnect loop.
                retry_count = 0;
                interruptible_sleep(retry_delay_base);
            } catch (const std::exception& e) {
                std::cerr << "SSE连接错误: " << e.what() << std::endl;

                if (++retry_count > max_retries) {
                    std::cerr << "达到最大重试次数,停止SSE连接尝试" << std::endl;
                    // Publish that the stream is gone so waiters (e.g. initialize) stop waiting.
                    sse_running_ = false;
                    break;
                }

                // Exponential backoff: 2^(retry_count-1) * base delay.
                int delay = retry_delay_base * (1 << (retry_count - 1));
                interruptible_sleep(delay);
            }
        }
    });
}

/**
 * Interpret one SSE event block (lines separated by "\n" or "\r\n").
 *
 * Follows the SSE field rules: lines starting with ':' are comments, `field: value`
 * drops one optional space after the colon, and multiple `data` lines are joined with
 * "\n". The data of an `endpoint` event (or of an unnamed event) becomes the message
 * endpoint. Heartbeats and other named events (e.g. `message`) are ignored.
 *
 * @return false only if parsing threw, which makes the caller close the stream.
 */
bool client::parse_sse_data(const char* data, size_t length) {
    try {
        if (data == nullptr || length == 0) {
            return true;
        }

        const std::string block(data, length);
        std::string event_name;
        std::string payload;
        bool has_data = false;

        std::size_t pos = 0;
        while (pos <= block.size()) {
            std::size_t eol = block.find('\n', pos);
            if (eol == std::string::npos) {
                eol = block.size();
            }
            std::string line = block.substr(pos, eol - pos);
            pos = eol + 1;

            if (!line.empty() && line.back() == '\r') {
                line.pop_back();
            }
            if (line.empty() || line.front() == ':') {
                continue; // Blank line or comment (often used as a keep-alive).
            }

            const std::size_t colon = line.find(':');
            const std::string field = line.substr(0, colon);
            std::string value = (colon == std::string::npos) ? std::string() : line.substr(colon + 1);
            if (!value.empty() && value.front() == ' ') {
                value.erase(0, 1);
            }

            if (field == "event") {
                event_name = value;
            } else if (field == "data") {
                if (has_data) {
                    payload += '\n';
                }
                payload += value;
                has_data = true;
            }
        }

        if (!has_data || event_name == "heartbeat") {
            return true; // Nothing to act on; keep the connection open.
        }
        if (!event_name.empty() && event_name != "endpoint") {
            return true; // Other named events must not overwrite the endpoint.
        }
        if (payload.empty()) {
            return true;
        }

        {
            std::lock_guard lock(mutex_);
            msg_endpoint_ = payload;
        }
        return true;
    } catch (const std::exception& e) {
        std::cerr << "解析SSE数据错误: " << e.what() << std::endl;
        return false;
    }
}

/**
 * Ask the SSE thread to stop and join it.
 * The join may wait until the pending stream read returns (next event or read timeout).
 */
void client::close_sse_connection() {
    sse_running_ = false;

    if (sse_thread_ && sse_thread_->joinable()) {
        sse_thread_->join();
    }
}

/**
 * POST one JSON-RPC message to the server's message endpoint.
 *
 * @return the `result` member (or an empty object if absent); an empty object for notifications.
 * @throws mcp_exception on transport errors, a missing endpoint, a JSON-RPC `error`
 *         member, or an unparsable response.
 */
json client::send_jsonrpc(const request& req) {
    // Held for the whole exchange: serializes requests and keeps set_timeout() from
    // reconfiguring http_client_ mid-request.
    std::lock_guard lock(mutex_);

    if (msg_endpoint_.empty()) {
        throw mcp_exception(error_code::internal_error,
                            "No message endpoint: the server has not sent an SSE endpoint event");
    }

    json req_json = req.to_json();
    std::string req_body = req_json.dump();

    // Content-Type is supplied through Post()'s content_type argument; adding it here as
    // well would put the header on the request twice.
    httplib::Headers headers;
    for (const auto& [key, value] : default_headers_) {
        headers.emplace(key, value);
    }

    auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");

    if (!result) {
        auto err = result.error();
        switch (err) {
            case httplib::Error::Connection:
                throw mcp_exception(error_code::server_error_start, "Connection error");
            case httplib::Error::Read:
                throw mcp_exception(error_code::internal_error, "Read error");
            case httplib::Error::Write:
                throw mcp_exception(error_code::internal_error, "Write error");
            case httplib::Error::ConnectionTimeout:
                throw mcp_exception(error_code::server_error_start, "Timeout error");
            default:
                throw mcp_exception(error_code::internal_error,
                                    "HTTP client error: " + std::to_string(static_cast<int>(err)));
        }
    }

    // Notifications expect no response body.
    if (req.is_notification()) {
        return json::object();
    }

    try {
        json res_json = json::parse(result->body);

        if (res_json.contains("error")) {
            int code = res_json["error"]["code"];
            std::string message = res_json["error"]["message"];
            throw mcp_exception(static_cast<error_code>(code), message);
        }

        if (res_json.contains("result")) {
            return res_json["result"];
        }
        return json::object();
    } catch (const json::exception& e) {
        throw mcp_exception(error_code::parse_error,
                            "Failed to parse JSON-RPC response: " + std::string(e.what()));
    }
}

} // namespace mcp