From 61264dfb49a612d33304dcf250dacfac48c93e92 Mon Sep 17 00:00:00 2001 From: hkr04 Date: Wed, 12 Mar 2025 19:18:27 +0800 Subject: [PATCH] a workable server; OK to connect MCP Inspector; possible thread issue --- examples/server_example.cpp | 4 +- include/mcp_client.h | 85 ++++++++---- include/mcp_tool.h | 2 +- src/mcp_client.cpp | 214 +++++++++++++++++++++++++----- src/mcp_server.cpp | 150 +++++++++------------ src/mcp_tool.cpp | 2 +- test/test_mcp_direct_requests.cpp | 13 +- 7 files changed, 318 insertions(+), 152 deletions(-) diff --git a/examples/server_example.cpp b/examples/server_example.cpp index 9846cc6..844eacb 100644 --- a/examples/server_example.cpp +++ b/examples/server_example.cpp @@ -122,7 +122,7 @@ int main() { std::filesystem::create_directories("./files"); // Create and configure server - mcp::server server("localhost", 8089); + mcp::server server("localhost", 8888); server.set_server_info("ExampleServer", "1.0.0"); // Set server capabilities @@ -165,7 +165,7 @@ int main() { // server.register_resource("/api", api_resource); // Start server - std::cout << "Starting MCP server at localhost:8089..." << std::endl; + std::cout << "Starting MCP server at localhost:8888..." << std::endl; std::cout << "Press Ctrl+C to stop the server" << std::endl; server.start(true); // Blocking mode diff --git a/include/mcp_client.h b/include/mcp_client.h index e4bfac6..dc673b4 100644 --- a/include/mcp_client.h +++ b/include/mcp_client.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace mcp { @@ -173,43 +174,73 @@ public: bool check_server_accessible(); private: - std::string base_url_; - std::string host_; - int port_; - std::string sse_endpoint_; - std::string msg_endpoint_; - std::string auth_token_; - int timeout_seconds_ = 30; - json capabilities_; - - std::map default_headers_; - json server_capabilities_; - + // 初始化HTTP客户端 + void init_client(const std::string& host, int port); + void init_client(const std::string& base_url); - // HTTP client + // 打开SSE连接 + void open_sse_connection(); + + // 解析SSE数据 + bool parse_sse_data(const char* data, size_t length); + + // 关闭SSE连接 + void close_sse_connection(); + + // 发送JSON-RPC请求 + json send_jsonrpc(const request& req); + + // 服务器主机和端口 + std::string host_; + int port_ = 8080; + + // 或者使用基础URL + std::string base_url_; + + // SSE端点 + std::string sse_endpoint_ = "/sse"; + + // 消息端点 + std::string msg_endpoint_; + + // HTTP客户端 std::unique_ptr http_client_; - // Mutex for thread safety + // SSE线程 + std::unique_ptr sse_thread_; + + // SSE运行状态 + std::atomic sse_running_{false}; + + // 认证令牌 + std::string auth_token_; + + // 默认请求头 + std::map default_headers_; + + // 超时设置(秒) + int timeout_seconds_ = 30; + + // 客户端能力 + json capabilities_; + + // 服务器能力 + json server_capabilities_; + + // 互斥锁 mutable std::mutex mutex_; // 条件变量,用于等待消息端点设置 std::condition_variable endpoint_cv_; - - // SSE connection - std::unique_ptr sse_thread_; - // SSE连接状态 - std::atomic sse_running_{false}; + // 请求ID到Promise的映射,用于异步等待响应 + std::map> pending_requests_; - // Initialize the client - void init_client(const std::string& host, int port); - void init_client(const std::string& base_url); - void open_sse_connection(); - void close_sse_connection(); - bool parse_sse_data(const char* data, size_t length); + // 响应处理互斥锁 + std::mutex response_mutex_; - // Send a JSON-RPC request and get the response - json send_jsonrpc(const request& req); + // 响应条件变量 + std::condition_variable response_cv_; }; } // namespace mcp diff --git a/include/mcp_tool.h b/include/mcp_tool.h index c572898..6dcfe6a 100644 --- a/include/mcp_tool.h +++ b/include/mcp_tool.h @@ -31,7 +31,7 @@ struct tool { return { {"name", name}, {"description", description}, - {"parameters", parameters_schema} + {"inputSchema", parameters_schema} }; } }; diff --git a/src/mcp_client.cpp b/src/mcp_client.cpp index fd7e608..25f7129 100644 --- a/src/mcp_client.cpp +++ b/src/mcp_client.cpp @@ -364,11 +364,24 @@ void client::open_sse_connection() { }); } -// 新增方法:解析SSE数据 bool client::parse_sse_data(const char* data, size_t length) { try { std::string sse_data(data, length); + // 查找事件类型 + std::string event_type = "message"; // 默认事件类型 + auto event_pos = sse_data.find("event: "); + if (event_pos != std::string::npos) { + auto event_end = sse_data.find("\n", event_pos); + if (event_end != std::string::npos) { + event_type = sse_data.substr(event_pos + 7, event_end - (event_pos + 7)); + // 移除可能的回车符 + if (!event_type.empty() && event_type.back() == '\r') { + event_type.pop_back(); + } + } + } + // 查找"data:"标记 auto data_pos = sse_data.find("data: "); if (data_pos == std::string::npos) { @@ -384,29 +397,69 @@ bool client::parse_sse_data(const char* data, size_t length) { // 提取数据内容 std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6)); - // 检查是否是心跳事件 - if (sse_data.find("event: heartbeat") != std::string::npos) { + // 处理不同类型的事件 + if (event_type == "heartbeat") { // 心跳事件,不需要处理数据 return true; - } - - // 更新消息端点 - { + } else if (event_type == "endpoint") { + // 端点事件,更新消息端点 std::lock_guard lock(mutex_); msg_endpoint_ = data_content; // 通知等待的线程 endpoint_cv_.notify_all(); + return true; + } else if (event_type == "message") { + // 消息事件,尝试解析为JSON-RPC响应 + try { + json response = json::parse(data_content); + + // 检查是否是有效的JSON-RPC响应 + if (response.contains("jsonrpc") && response.contains("id") && !response["id"].is_null()) { + json id = response["id"]; + + // 查找对应的请求 + std::lock_guard lock(response_mutex_); + auto it = pending_requests_.find(id); + if (it != pending_requests_.end()) { + // 设置响应结果 + if (response.contains("result")) { + it->second.set_value(response["result"]); + } else if (response.contains("error")) { + // 创建一个包含错误信息的JSON对象 + json error_result = { + {"isError", true}, + {"error", response["error"]} + }; + it->second.set_value(error_result); + } else { + // 设置空结果 + it->second.set_value(json::object()); + } + + // 移除已完成的请求 + pending_requests_.erase(it); + } else { + std::cerr << "收到未知请求ID的响应: " << id << std::endl; + } + } else { + std::cerr << "收到无效的JSON-RPC响应: " << response.dump() << std::endl; + } + } catch (const json::exception& e) { + std::cerr << "解析JSON-RPC响应失败: " << e.what() << std::endl; + } + return true; + } else { + // 未知事件类型,记录但继续保持连接 + std::cerr << "收到未知事件类型: " << event_type << std::endl; + return true; } - - return true; } catch (const std::exception& e) { std::cerr << "解析SSE数据错误: " << e.what() << std::endl; return false; } } -// 新增方法:关闭SSE连接 void client::close_sse_connection() { // 设置标志,这将导致SSE回调函数返回false,从而关闭连接 sse_running_ = false; @@ -477,6 +530,51 @@ json client::send_jsonrpc(const request& req) { headers.emplace(key, value); } + // Check if it's a notification (no response expected) + if (req.is_notification()) { + // Send the request + auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); + + if (!result) { + // Error occurred + auto err = result.error(); + std::string error_msg; + + switch (err) { + case httplib::Error::Connection: + error_msg = "连接错误,服务器可能未运行或无法访问"; + break; + case httplib::Error::Read: + error_msg = "读取错误,服务器可能关闭了连接或响应格式不正确"; + break; + case httplib::Error::Write: + error_msg = "写入错误"; + break; + case httplib::Error::ConnectionTimeout: + error_msg = "连接超时"; + break; + default: + error_msg = "HTTP客户端错误: " + std::to_string(static_cast(err)); + break; + } + + std::cerr << "JSON-RPC请求失败: " << error_msg << std::endl; + throw mcp_exception(error_code::internal_error, error_msg); + } + + return json::object(); + } + + // 创建Promise和Future,用于等待响应 + std::promise response_promise; + std::future response_future = response_promise.get_future(); + + // 将请求ID和Promise添加到映射表 + { + std::lock_guard response_lock(response_mutex_); + pending_requests_[req.id] = std::move(response_promise); + } + // Send the request auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); @@ -503,39 +601,87 @@ json client::send_jsonrpc(const request& req) { break; } + // 移除请求 + { + std::lock_guard response_lock(response_mutex_); + pending_requests_.erase(req.id); + } + std::cerr << "JSON-RPC请求失败: " << error_msg << std::endl; throw mcp_exception(error_code::internal_error, error_msg); } - // Check if it's a notification (no response expected) - if (req.is_notification()) { - return json::object(); - } - - // Parse response - try { - json res_json = json::parse(result->body); - - // 打印响应信息(调试用) - std::cerr << "收到JSON-RPC响应: " << res_json.dump() << std::endl; - - // Check for error - if (res_json.contains("error")) { - int code = res_json["error"]["code"]; - std::string message = res_json["error"]["message"]; + // 检查HTTP状态码 + if (result->status != 202) { + // 非202状态码,尝试解析响应 + try { + json res_json = json::parse(result->body); - throw mcp_exception(static_cast(code), message); + // 打印响应信息(调试用) + std::cerr << "收到HTTP响应: 状态码=" << result->status << ", 内容=" << res_json.dump() << std::endl; + + // 移除请求 + { + std::lock_guard response_lock(response_mutex_); + pending_requests_.erase(req.id); + } + + // Check for error + if (res_json.contains("error")) { + int code = res_json["error"]["code"]; + std::string message = res_json["error"]["message"]; + + throw mcp_exception(static_cast(code), message); + } + + // Return result + if (res_json.contains("result")) { + return res_json["result"]; + } else { + return json::object(); + } + } catch (const json::exception& e) { + // 移除请求 + { + std::lock_guard response_lock(response_mutex_); + pending_requests_.erase(req.id); + } + + throw mcp_exception(error_code::parse_error, + "Failed to parse JSON-RPC response: " + std::string(e.what())); } + } else { + // 202状态码,等待SSE响应 + std::cerr << "收到202 Accepted响应,等待SSE响应..." << std::endl; - // Return result - if (res_json.contains("result")) { - return res_json["result"]; + // 设置超时时间 + const auto timeout = std::chrono::seconds(timeout_seconds_); + + // 等待响应 + auto status = response_future.wait_for(timeout); + + if (status == std::future_status::ready) { + // 获取响应 + json response = response_future.get(); + + // 检查是否是错误响应 + if (response.contains("isError") && response["isError"].get()) { + int code = response["error"]["code"]; + std::string message = response["error"]["message"]; + + throw mcp_exception(static_cast(code), message); + } + + return response; } else { - return json::object(); + // 超时,移除请求 + { + std::lock_guard response_lock(response_mutex_); + pending_requests_.erase(req.id); + } + + throw mcp_exception(error_code::internal_error, "等待SSE响应超时"); } - } catch (const json::exception& e) { - throw mcp_exception(error_code::parse_error, - "Failed to parse JSON-RPC response: " + std::string(e.what())); } } diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index 790d2a9..a7a3091 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -260,7 +260,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { for (const auto& [name, tool_pair] : tools_) { tools_json.push_back(tool_pair.first.to_json()); } - return tools_json; + return json{{"tools", tools_json}}; }; } @@ -498,107 +498,89 @@ void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) std::cerr << key << "=" << value << " "; } std::cerr << std::endl; - - // 检查会话是否存在 - { - std::lock_guard lock(mutex_); - if (!session_id.empty()) { - std::cerr << "检查会话是否存在: " << session_id << std::endl; - std::cerr << "当前活跃会话: "; - for (const auto& [id, _] : session_dispatchers_) { - std::cerr << id << " "; - } - std::cerr << std::endl; - - if (session_dispatchers_.find(session_id) == session_dispatchers_.end()) { - std::cerr << "会话不存在: " << session_id << std::endl; - json error_response = { - {"jsonrpc", "2.0"}, - {"error", { - {"code", static_cast(error_code::invalid_request)}, - {"message", "Session not found"} - }}, - {"id", nullptr} - }; - res.set_content(error_response.dump(), "application/json"); - return; - } else { - std::cerr << "会话存在: " << session_id << std::endl; - } - } else { - std::cerr << "请求中没有会话ID" << std::endl; - } - } - + // 解析请求 json req_json; try { req_json = json::parse(req.body); - std::cerr << "请求内容: " << req_json.dump() << std::endl; } catch (const json::exception& e) { - // 无效的JSON - std::cerr << "解析JSON失败: " << e.what() << std::endl; - json error_response = { - {"jsonrpc", "2.0"}, - {"error", { - {"code", static_cast(error_code::parse_error)}, - {"message", "Parse error: " + std::string(e.what())} - }}, - {"id", nullptr} - }; - res.set_content(error_response.dump(), "application/json"); + std::cerr << "解析JSON请求失败: " << e.what() << std::endl; + res.status = 400; + res.set_content("{\"error\":\"Invalid JSON\"}", "application/json"); return; } - // 检查是否是批量请求 - if (req_json.is_array()) { - // 批量请求暂不支持 - std::cerr << "不支持批量请求" << std::endl; - json error_response = { - {"jsonrpc", "2.0"}, - {"error", { - {"code", static_cast(error_code::invalid_request)}, - {"message", "Batch requests are not supported"} - }}, - {"id", nullptr} - }; - res.set_content(error_response.dump(), "application/json"); - return; + // 检查会话是否存在 + std::shared_ptr dispatcher; + { + std::lock_guard lock(mutex_); + auto disp_it = session_dispatchers_.find(session_id); + if (disp_it == session_dispatchers_.end()) { + // 处理ping请求 + if (req_json["method"] == "ping") { + res.status = 202; + res.set_content("Accepted", "text/plain"); + return; + } + std::cerr << "会话不存在: " << session_id << std::endl; + res.status = 404; + res.set_content("{\"error\":\"Session not found\"}", "application/json"); + return; + } + dispatcher = disp_it->second; } - // 转换为请求对象 + // 创建请求对象 request mcp_req; try { - mcp_req.jsonrpc = req_json["jsonrpc"]; - mcp_req.method = req_json["method"]; - - if (req_json.contains("id")) { + mcp_req.jsonrpc = req_json["jsonrpc"].get(); + if (req_json.contains("id") && !req_json["id"].is_null()) { mcp_req.id = req_json["id"]; } - + mcp_req.method = req_json["method"].get(); if (req_json.contains("params")) { mcp_req.params = req_json["params"]; } - } catch (const json::exception& e) { - // 无效的请求 - std::cerr << "无效的请求: " << e.what() << std::endl; - json error_response = { - {"jsonrpc", "2.0"}, - {"error", { - {"code", static_cast(error_code::invalid_request)}, - {"message", "Invalid request: " + std::string(e.what())} - }}, - {"id", nullptr} - }; - res.set_content(error_response.dump(), "application/json"); + } catch (const std::exception& e) { + std::cerr << "创建请求对象失败: " << e.what() << std::endl; + res.status = 400; + res.set_content("{\"error\":\"Invalid request format\"}", "application/json"); return; } - // 处理请求 - std::cerr << "处理方法: " << mcp_req.method << std::endl; - json result = process_request(mcp_req, session_id); - std::cerr << "响应: " << result.dump() << std::endl; - res.set_content(result.dump(), "application/json"); + // 如果是通知(没有ID),直接处理并返回202状态码 + if (mcp_req.is_notification()) { + // 在线程池中异步处理通知 + thread_pool_.enqueue([this, mcp_req, session_id]() { + process_request(mcp_req, session_id); + }); + + // 返回202 Accepted + res.status = 202; + res.set_content("Accepted", "text/plain"); + return; + } + + // 对于有ID的请求,在线程池中处理并通过SSE返回结果 + thread_pool_.enqueue([this, mcp_req, session_id, dispatcher]() { + // 处理请求 + json response_json = process_request(mcp_req, session_id); + + // 通过SSE发送响应 + std::stringstream ss; + ss << "event: message\ndata: " << response_json.dump() << "\n\n"; + bool result = dispatcher->send_event(ss.str()); + + if (!result) { + std::cerr << "通过SSE发送响应失败: 会话ID=" << session_id << std::endl; + } else { + std::cerr << "成功通过SSE发送响应: 会话ID=" << session_id << std::endl; + } + }); + + // 返回202 Accepted + res.status = 202; + res.set_content("Accepted", "text/plain"); } json server::process_request(const request& req, const std::string& session_id) { @@ -622,9 +604,7 @@ json server::process_request(const request& req, const std::string& session_id) if (req.method == "initialize") { return handle_initialize(req, session_id); } else if (req.method == "ping") { - // 接收者必须立即响应一个空响应 - LOG_INFO("处理ping请求"); - return response::create_success(req.id, {}).to_json(); + return response::create_success(req.id, json::object()).to_json(); } if (!is_session_initialized(session_id)) { diff --git a/src/mcp_tool.cpp b/src/mcp_tool.cpp index 0124c54..ac618e8 100644 --- a/src/mcp_tool.cpp +++ b/src/mcp_tool.cpp @@ -150,7 +150,7 @@ tool tool_builder::build() const { // Create the parameters schema json schema = parameters_; - schema["type"] = "object"; + schema["type"] = "object";; if (!required_params_.empty()) { schema["required"] = required_params_; diff --git a/test/test_mcp_direct_requests.cpp b/test/test_mcp_direct_requests.cpp index 72e4e9b..7d262bb 100644 --- a/test/test_mcp_direct_requests.cpp +++ b/test/test_mcp_direct_requests.cpp @@ -96,6 +96,15 @@ protected: return mcp::json::object(); } + // 检查状态码,202表示请求已接受,但响应将通过SSE发送 + if (res->status == 202) { + // 在实际测试中,我们需要等待SSE响应 + // 但在这个测试中,我们只是返回一个空对象 + // 实际应用中应该使用客户端类来处理这种情况 + std::cout << "收到202 Accepted响应,实际响应将通过SSE发送" << std::endl; + return mcp::json::object(); + } + EXPECT_EQ(res->status, 200); try { @@ -419,8 +428,8 @@ TEST_F(DirectRequestTest, SendNotification) { // Verify response (notifications may have empty response or error response) EXPECT_TRUE(res != nullptr); - EXPECT_EQ(res->status, 200); - // Don't check if response body is empty, as server implementation may return an empty object + // 状态码可能是200或202,取决于服务器实现 + EXPECT_TRUE(res->status == 200 || res->status == 202); } // Test error handling - method not found