From 0814f1a6a787b07d714a9d64881ce73aa6dcce48 Mon Sep 17 00:00:00 2001 From: hkr04 Date: Wed, 12 Mar 2025 19:43:34 +0800 Subject: [PATCH] a workable client --- examples/client_example.cpp | 2 +- include/mcp_client.h | 3 ++ include/mcp_tool.h | 2 +- src/mcp_client.cpp | 94 +++++++++++++++++++++++++++++++------ test/test_client.cpp | 61 ++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 17 deletions(-) create mode 100644 test/test_client.cpp diff --git a/examples/client_example.cpp b/examples/client_example.cpp index 62fbc77..a747245 100644 --- a/examples/client_example.cpp +++ b/examples/client_example.cpp @@ -12,7 +12,7 @@ int main() { // Create a client - mcp::client client("localhost", 8089); + mcp::client client("localhost", 8888); // Set capabilites mcp::json capabilities = { diff --git a/include/mcp_client.h b/include/mcp_client.h index dc673b4..cf69ad5 100644 --- a/include/mcp_client.h +++ b/include/mcp_client.h @@ -206,6 +206,9 @@ private: // HTTP客户端 std::unique_ptr http_client_; + // SSE专用HTTP客户端 + std::unique_ptr sse_client_; + // SSE线程 std::unique_ptr sse_thread_; diff --git a/include/mcp_tool.h b/include/mcp_tool.h index 6dcfe6a..2824e16 100644 --- a/include/mcp_tool.h +++ b/include/mcp_tool.h @@ -31,7 +31,7 @@ struct tool { return { {"name", name}, {"description", description}, - {"inputSchema", parameters_schema} + {"inputSchema", parameters_schema} // You may need 'parameters' instead of 'inputSchema' for OAI format }; } }; diff --git a/src/mcp_client.cpp b/src/mcp_client.cpp index 25f7129..f6f8ba5 100644 --- a/src/mcp_client.cpp +++ b/src/mcp_client.cpp @@ -30,23 +30,36 @@ client::~client() { } void client::init_client(const std::string& host, int port) { - // Create the HTTP client + // 创建两个独立的HTTP客户端实例 + // 一个用于SSE连接,一个用于JSON-RPC请求 http_client_ = std::make_unique(host.c_str(), port); + sse_client_ = std::make_unique(host.c_str(), port); - // Set timeout + // 设置超时 http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0); + + // SSE客户端需要更长的超时时间,因为它会保持长连接 + sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0); + sse_client_->set_write_timeout(timeout_seconds_, 0); } void client::init_client(const std::string& base_url) { - // Create the HTTP client + // 创建两个独立的HTTP客户端实例 + // 一个用于SSE连接,一个用于JSON-RPC请求 http_client_ = std::make_unique(base_url.c_str()); + sse_client_ = std::make_unique(base_url.c_str()); - // Set timeout + // 设置超时 http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0); + + // SSE客户端需要更长的超时时间,因为它会保持长连接 + sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0); + sse_client_->set_read_timeout(0, 0); // 无限读取超时,适合SSE长连接 + sse_client_->set_write_timeout(timeout_seconds_, 0); } bool client::initialize(const std::string& client_name, const std::string& client_version) { @@ -159,23 +172,37 @@ void client::set_auth_token(const std::string& token) { std::lock_guard lock(mutex_); auth_token_ = token; - // Add to default headers + // 为两个客户端都添加默认头 set_header("Authorization", "Bearer " + auth_token_); } void client::set_header(const std::string& key, const std::string& value) { std::lock_guard lock(mutex_); default_headers_[key] = value; + + // 确保两个客户端实例都设置了相同的头 + if (http_client_) { + http_client_->set_default_headers({{key, value}}); + } + if (sse_client_) { + sse_client_->set_default_headers({{key, value}}); + } } void client::set_timeout(int timeout_seconds) { std::lock_guard lock(mutex_); timeout_seconds_ = timeout_seconds; - // Update the client's timeout - http_client_->set_connection_timeout(timeout_seconds_, 0); - http_client_->set_read_timeout(timeout_seconds_, 0); - http_client_->set_write_timeout(timeout_seconds_, 0); + // 更新两个客户端的超时设置 + if (http_client_) { + http_client_->set_connection_timeout(timeout_seconds_, 0); + http_client_->set_write_timeout(timeout_seconds_, 0); + } + + if (sse_client_) { + sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0); + sse_client_->set_write_timeout(timeout_seconds_, 0); + } } void client::set_capabilities(const json& capabilities) { @@ -293,7 +320,8 @@ void client::open_sse_connection() { // 尝试建立SSE连接 std::cerr << "SSE线程: 尝试连接到 " << sse_endpoint_ << std::endl; - auto res = http_client_->Get(sse_endpoint_.c_str(), + // 使用专用的SSE客户端实例 + auto res = sse_client_->Get(sse_endpoint_.c_str(), [this](const char *data, size_t data_length) { // 解析SSE数据 std::cerr << "SSE线程: 收到数据,长度: " << data_length << std::endl; @@ -330,7 +358,13 @@ void client::open_sse_connection() { error_msg += " (连接超时)"; break; case httplib::Error::Canceled: - error_msg += " (请求被取消)"; + // 如果是由于sse_running_=false导致的取消,这是正常的关闭过程 + if (!sse_running_) { + std::cerr << "SSE连接已被主动关闭 (请求被取消)" << std::endl; + return; // 直接返回,不再重试 + } else { + error_msg += " (请求被取消)"; + } break; default: error_msg += " (未知错误)"; @@ -347,6 +381,12 @@ void client::open_sse_connection() { // 记录错误 std::cerr << "SSE连接错误: " << e.what() << std::endl; + // 如果sse_running_为false,说明是主动关闭,不需要重试 + if (!sse_running_) { + std::cerr << "SSE连接已被主动关闭,不再重试" << std::endl; + break; + } + // 如果已达到最大重试次数,停止尝试 if (++retry_count > max_retries) { std::cerr << "达到最大重试次数,停止SSE连接尝试" << std::endl; @@ -356,7 +396,18 @@ void client::open_sse_connection() { // 指数退避重试 int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay std::cerr << "将在 " << delay << " 毫秒后重试 (尝试 " << retry_count << "/" << max_retries << ")" << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(delay)); + + // 在等待期间定期检查sse_running_状态 + const int check_interval = 100; // 每100毫秒检查一次 + for (int waited = 0; waited < delay && sse_running_; waited += check_interval) { + std::this_thread::sleep_for(std::chrono::milliseconds(check_interval)); + } + + // 如果在等待期间sse_running_变为false,直接退出 + if (!sse_running_) { + std::cerr << "在等待重试期间检测到SSE连接已被主动关闭,不再重试" << std::endl; + break; + } } } @@ -461,6 +512,14 @@ bool client::parse_sse_data(const char* data, size_t length) { } void client::close_sse_connection() { + // 检查是否已经关闭 + if (!sse_running_) { + std::cerr << "SSE连接已经关闭,无需再次关闭" << std::endl; + return; + } + + std::cerr << "正在主动关闭SSE连接(正常退出流程)..." << std::endl; + // 设置标志,这将导致SSE回调函数返回false,从而关闭连接 sse_running_ = false; @@ -473,12 +532,15 @@ void client::close_sse_connection() { auto timeout = std::chrono::seconds(5); auto start = std::chrono::steady_clock::now(); + std::cerr << "等待SSE线程结束..." << std::endl; + // 尝试在超时前等待线程结束 while (sse_thread_->joinable() && std::chrono::steady_clock::now() - start < timeout) { try { // 尝试立即加入线程 sse_thread_->join(); + std::cerr << "SSE线程已成功结束" << std::endl; break; // 如果成功加入,跳出循环 } catch (const std::exception& e) { std::cerr << "等待SSE线程时出错: " << e.what() << std::endl; @@ -492,6 +554,8 @@ void client::close_sse_connection() { std::cerr << "警告: SSE线程未能在超时时间内结束,分离线程" << std::endl; sse_thread_->detach(); } + } else { + std::cerr << "SSE线程不存在或已经结束" << std::endl; } // 清空消息端点 @@ -503,7 +567,7 @@ void client::close_sse_connection() { endpoint_cv_.notify_all(); } - std::cerr << "SSE连接已关闭" << std::endl; + std::cerr << "SSE连接已成功关闭(正常退出流程)" << std::endl; } json client::send_jsonrpc(const request& req) { @@ -532,7 +596,7 @@ json client::send_jsonrpc(const request& req) { // Check if it's a notification (no response expected) if (req.is_notification()) { - // Send the request + // 使用主HTTP客户端发送请求 auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); if (!result) { @@ -575,7 +639,7 @@ json client::send_jsonrpc(const request& req) { pending_requests_[req.id] = std::move(response_promise); } - // Send the request + // 使用主HTTP客户端发送请求 auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); if (!result) { diff --git a/test/test_client.cpp b/test/test_client.cpp new file mode 100644 index 0000000..70f173d --- /dev/null +++ b/test/test_client.cpp @@ -0,0 +1,61 @@ +/** + * @file test_client.cpp + * @brief 测试MCP客户端 + */ + +#include "mcp_client.h" +#include +#include +#include + +int main() { + // 创建客户端 + mcp::client client("localhost", 8080); + + // 设置超时 + client.set_timeout(30); + + // 初始化客户端 + std::cout << "正在初始化客户端..." << std::endl; + bool success = client.initialize("TestClient", "1.0.0"); + + if (!success) { + std::cerr << "初始化失败" << std::endl; + return 1; + } + + std::cout << "初始化成功" << std::endl; + + // 获取服务器能力 + std::cout << "服务器能力: " << client.get_server_capabilities().dump(2) << std::endl; + + // 获取可用工具 + std::cout << "正在获取可用工具..." << std::endl; + auto tools = client.get_tools(); + std::cout << "可用工具数量: " << tools.size() << std::endl; + + for (const auto& tool : tools) { + std::cout << "工具: " << tool.name << " - " << tool.description << std::endl; + } + + // 发送ping请求 + std::cout << "正在发送ping请求..." << std::endl; + bool ping_result = client.ping(); + std::cout << "Ping结果: " << (ping_result ? "成功" : "失败") << std::endl; + + // 列出资源 + std::cout << "正在列出资源..." << std::endl; + auto resources = client.list_resources(); + std::cout << "资源: " << resources.dump(2) << std::endl; + + // 测试多个并发请求 + std::cout << "测试并发请求..." << std::endl; + for (int i = 0; i < 5; i++) { + std::cout << "请求 " << i << "..." << std::endl; + auto response = client.send_request("ping"); + std::cout << "响应 " << i << ": " << response.result.dump() << std::endl; + } + + std::cout << "测试完成" << std::endl; + return 0; +} \ No newline at end of file