From 88237f2eaae1dc89b32d2a693ac2bd15fb6ad269 Mon Sep 17 00:00:00 2001 From: hkr04 Date: Mon, 7 Apr 2025 11:40:58 +0800 Subject: [PATCH] mcp_server: explicitly catch session_id for method/tool/notification/auth handlers; add session cleanup handlers --- examples/server_example.cpp | 8 +- include/mcp_server.h | 61 +++++++++---- include/mcp_tool.h | 3 - src/mcp_server.cpp | 168 +++++++++++++++++------------------- src/mcp_stdio_client.cpp | 2 +- test/mcp_test.cpp | 24 +----- 6 files changed, 131 insertions(+), 135 deletions(-) diff --git a/examples/server_example.cpp b/examples/server_example.cpp index 801b17c..0c09bfb 100644 --- a/examples/server_example.cpp +++ b/examples/server_example.cpp @@ -18,7 +18,7 @@ #include // Tool handler for getting current time -mcp::json get_time_handler(const mcp::json& params) { +mcp::json get_time_handler(const mcp::json& params, const std::string& /* session_id */) { auto now = std::chrono::system_clock::now(); auto time_t_now = std::chrono::system_clock::to_time_t(now); @@ -37,7 +37,7 @@ mcp::json get_time_handler(const mcp::json& params) { } // Echo tool handler -mcp::json echo_handler(const mcp::json& params) { +mcp::json echo_handler(const mcp::json& params, const std::string& /* session_id */) { mcp::json result = params; if (params.contains("text")) { @@ -63,7 +63,7 @@ mcp::json echo_handler(const mcp::json& params) { } // Calculator tool handler -mcp::json calculator_handler(const mcp::json& params) { +mcp::json calculator_handler(const mcp::json& params, const std::string& /* session_id */) { if (!params.contains("operation")) { throw mcp::mcp_exception(mcp::error_code::invalid_params, "Missing 'operation' parameter"); } @@ -107,7 +107,7 @@ mcp::json calculator_handler(const mcp::json& params) { } // Custom API endpoint handler -mcp::json hello_handler(const mcp::json& params) { +mcp::json hello_handler(const mcp::json& params, const std::string& /* session_id */) { std::string name = params.contains("name") ? params["name"].get() : "World"; return { { diff --git a/include/mcp_server.h b/include/mcp_server.h index 65bfe30..e89a6c4 100644 --- a/include/mcp_server.h +++ b/include/mcp_server.h @@ -33,6 +33,12 @@ namespace mcp { +using method_handler = std::function; +using tool_handler = method_handler; +using notification_handler = std::function; +using auth_handler = std::function; +using session_cleanup_handler = std::function; + class event_dispatcher { public: event_dispatcher() { @@ -223,14 +229,14 @@ public: * @param method The method name * @param handler The function to call when the method is invoked */ - void register_method(const std::string& method, std::function handler); + void register_method(const std::string& method, method_handler handler); /** * @brief Register a notification handler * @param method The notification method name * @param handler The function to call when the notification is received */ - void register_notification(const std::string& method, std::function handler); + void register_notification(const std::string& method, notification_handler handler); /** * @brief Register a resource @@ -245,6 +251,13 @@ public: * @param handler The function to call when the tool is invoked */ void register_tool(const tool& tool, tool_handler handler); + + /** + * @brief Register a session cleanup handler + * @param key Tool or resource name to be cleaned up + * @param handler The function to call when the session is closed + */ + void register_session_cleanup(const std::string& key, session_cleanup_handler handler); /** * @brief Get the list of available tools @@ -255,19 +268,24 @@ public: /** * @brief Set authentication handler * @param handler Function that takes a token and returns true if valid + * @note The handler should return true if the token is valid, otherwise false + * @note Not used in the current implementation */ - void set_auth_handler(std::function handler); + void set_auth_handler(auth_handler handler); /** - * @brief Send a request to a client + * @brief Send a request (or notification) to a client * @param session_id The session ID of the client - * @param method The method to call - * @param params The parameters to pass - * - * This method will only send requests other than ping and logging - * after the client has sent the initialized notification. + * @param req The request to send */ - void send_request(const std::string& session_id, const std::string& method, const json& params = json::object()); + void send_request(const std::string& session_id, const request& req); + + /** + * @brief Set mount point for server + * @param path The path to mount the resource at + * @param root The root directory to mount + */ + void set_mount_point(const std::string& path, const std::string& root); private: std::string host_; @@ -296,10 +314,10 @@ private: std::string msg_endpoint_; // Method handlers - std::map> method_handlers_; + std::map method_handlers_; // Notification handlers - std::map> notification_handlers_; + std::map notification_handlers_; // Resources map (path -> resource) std::map> resources_; @@ -308,7 +326,7 @@ private: std::map> tools_; // Authentication handler - std::function auth_handler_; + auth_handler auth_handler_; // Mutex for thread safety mutable std::mutex mutex_; @@ -327,6 +345,9 @@ private: // Handle incoming JSON-RPC requests void handle_jsonrpc(const httplib::Request& req, httplib::Response& res); + + // Send a JSON-RPC message to a client + void send_jsonrpc(const std::string& session_id, const json& message); // Process a JSON-RPC request json process_request(const request& req, const std::string& session_id); @@ -345,10 +366,10 @@ private: // Auxiliary function to create an async handler from a regular handler template - std::function(const json&)> make_async_handler(F&& handler) { - return [handler = std::forward(handler)](const json& params) -> std::future { - return std::async(std::launch::async, [handler, params]() -> json { - return handler(params); + std::function(const json&, const std::string&)> make_async_handler(F&& handler) { + return [handler = std::forward(handler)](const json& params, const std::string& session_id) -> std::future { + return std::async(std::launch::async, [handler, params, session_id]() -> json { + return handler(params, session_id); }); }; } @@ -369,6 +390,12 @@ private: // Session management and maintenance void check_inactive_sessions(); std::unique_ptr maintenance_thread_; + + // Session cleanup handler + std::map session_cleanup_handler_; + + // Close session + void close_session(const std::string& session_id); }; } // namespace mcp diff --git a/include/mcp_tool.h b/include/mcp_tool.h index 3bcbdde..ff44634 100644 --- a/include/mcp_tool.h +++ b/include/mcp_tool.h @@ -19,9 +19,6 @@ namespace mcp { -// Tool handler function type -using tool_handler = std::function; - // MCP Tool definition struct tool { std::string name; diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index 9b6b377..a7b801b 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -134,15 +134,9 @@ void server::stop() { session_initialized_.clear(); } - // Close all dispatchers outside the lock - for (auto& dispatcher : dispatchers_to_close) { - if (dispatcher && !dispatcher->is_closed()) { - try { - dispatcher->close(); - } catch (...) { - // Ignore exceptions - } - } + // Close all sessions + for (const auto& [session_id, _] : session_dispatchers_) { + close_session(session_id); } // Give threads some time to handle close events @@ -239,12 +233,12 @@ void server::set_capabilities(const json& capabilities) { capabilities_ = capabilities; } -void server::register_method(const std::string& method, std::function handler) { +void server::register_method(const std::string& method, method_handler handler) { std::lock_guard lock(mutex_); method_handlers_[method] = handler; } -void server::register_notification(const std::string& method, std::function handler) { +void server::register_notification(const std::string& method, notification_handler handler) { std::lock_guard lock(mutex_); notification_handlers_[method] = handler; } @@ -255,7 +249,7 @@ void server::register_resource(const std::string& path, std::shared_ptr json { + method_handlers_["resources/read"] = [this](const json& params, const std::string& session_id) -> json { if (!params.contains("uri")) { throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter"); } @@ -276,7 +270,7 @@ void server::register_resource(const std::string& path, std::shared_ptr json { + method_handlers_["resources/list"] = [this](const json& params, const std::string& session_id) -> json { json resources = json::array(); for (const auto& [uri, res] : resources_) { @@ -296,7 +290,7 @@ void server::register_resource(const std::string& path, std::shared_ptr json { + method_handlers_["resources/subscribe"] = [this](const json& params, const std::string& session_id) -> json { if (!params.contains("uri")) { throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter"); } @@ -312,7 +306,7 @@ void server::register_resource(const std::string& path, std::shared_ptr json { + method_handlers_["resources/templates/list"] = [this](const json& params, const std::string& session_id) -> json { return json::array(); }; } @@ -324,7 +318,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { // Register methods for tool listing and calling if (method_handlers_.find("tools/list") == method_handlers_.end()) { - method_handlers_["tools/list"] = [this](const json& params) -> json { + method_handlers_["tools/list"] = [this](const json& params, const std::string& session_id) -> json { json tools_json = json::array(); for (const auto& [name, tool_pair] : tools_) { tools_json.push_back(tool_pair.first.to_json()); @@ -334,7 +328,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { } if (method_handlers_.find("tools/call") == method_handlers_.end()) { - method_handlers_["tools/call"] = [this](const json& params) -> json { + method_handlers_["tools/call"] = [this](const json& params, const std::string& session_id) -> json { if (!params.contains("name")) { throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter"); } @@ -352,7 +346,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { }; try { - tool_result["content"] = it->second.second(tool_args); + tool_result["content"] = it->second.second(tool_args, session_id); } catch (...) { tool_result["isError"] = true; } @@ -362,6 +356,11 @@ void server::register_tool(const tool& tool, tool_handler handler) { } } +void server::register_session_cleanup(const std::string& key, session_cleanup_handler handler) { + std::lock_guard lock(mutex_); + session_cleanup_handler_[key] = handler; +} + std::vector server::get_tools() const { std::lock_guard lock(mutex_); std::vector tools; @@ -373,7 +372,7 @@ std::vector server::get_tools() const { return tools; } -void server::set_auth_handler(std::function handler) { +void server::set_auth_handler(auth_handler handler) { std::lock_guard lock(mutex_); auth_handler_ = handler; } @@ -442,47 +441,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what()); } - // Clean up resources safely - try { - // Copy resources to be processed - std::shared_ptr dispatcher_to_close; - std::unique_ptr thread_to_release; - - { - std::lock_guard lock(mutex_); - - // Get dispatcher pointer - auto dispatcher_it = session_dispatchers_.find(session_id); - if (dispatcher_it != session_dispatchers_.end()) { - dispatcher_to_close = dispatcher_it->second; - session_dispatchers_.erase(dispatcher_it); - } - - // Get thread pointer - auto thread_it = sse_threads_.find(session_id); - if (thread_it != sse_threads_.end()) { - thread_to_release = std::move(thread_it->second); - sse_threads_.erase(thread_it); - } - - // Clean up initialization status - session_initialized_.erase(session_id); - } - - // Close dispatcher outside the lock - if (dispatcher_to_close && !dispatcher_to_close->is_closed()) { - dispatcher_to_close->close(); - } - - // Release thread resources - if (thread_to_release) { - thread_to_release.release(); - } - } catch (const std::exception& e) { - LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what()); - } catch (...) { - LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id); - } + close_session(session_id); }); // Store thread @@ -507,8 +466,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { if (!result) { LOG_WARNING("Failed to wait for event, closing connection: ", session_id); - // Close dispatcher directly, no need to lock - session_dispatcher->close(); + close_session(session_id); return false; } @@ -520,8 +478,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { } catch (const std::exception& e) { LOG_ERROR("SSE content provider exception: ", e.what()); - // Close dispatcher directly, no need to lock - session_dispatcher->close(); + close_session(session_id); return false; } @@ -673,7 +630,7 @@ json server::process_request(const request& req, const std::string& session_id) } // Find registered method handler - std::function handler; + method_handler handler; { std::lock_guard lock(mutex_); auto it = method_handlers_.find(req.method); @@ -685,8 +642,8 @@ json server::process_request(const request& req, const std::string& session_id) if (handler) { // Call handler LOG_INFO("Calling method handler: ", req.method); - auto future = thread_pool_.enqueue([handler, params = req.params]() -> json { - return handler(params); + auto future = thread_pool_.enqueue([handler, params = req.params, session_id]() -> json { + return handler(params, session_id); }); json result = future.get(); @@ -791,25 +748,13 @@ json server::handle_initialize(const request& req, const std::string& session_id return response::create_success(req.id, result).to_json(); } -void server::send_request(const std::string& session_id, const std::string& method, const json& params) { +void server::send_jsonrpc(const std::string& session_id, const json& message) { // Check if session ID is valid if (session_id.empty()) { - LOG_WARNING("Cannot send request to empty session_id"); + LOG_WARNING("Cannot send message to empty session_id"); return; } - - // Check if the method is ping or logging - bool is_allowed_before_init = (method == "ping" || method == "logging"); - - // Check if client is initialized or if this is an allowed method - if (!is_allowed_before_init && !is_session_initialized(session_id)) { - LOG_WARNING("Cannot send ", method, " request to session ", session_id, " before it is initialized"); - return; - } - - // Create request - request req = request::create(method, params); - + // Get session dispatcher std::shared_ptr dispatcher; { @@ -828,16 +773,20 @@ void server::send_request(const std::string& session_id, const std::string& meth return; } - // Send request + // Send message std::stringstream ss; - ss << "event: message\r\ndata: " << req.to_json().dump() << "\r\n\r\n"; + ss << "event: message\r\ndata: " << message.dump() << "\r\n\r\n"; bool result = dispatcher->send_event(ss.str()); if (!result) { - LOG_ERROR("Failed to send request to session: ", session_id); + LOG_ERROR("Failed to send message to session: ", session_id); } } +void server::send_request(const std::string& session_id, const request& req) { + send_jsonrpc(session_id, req.to_json()); +} + bool server::is_session_initialized(const std::string& session_id) const { // Check if session ID is valid if (session_id.empty()) { @@ -933,18 +882,59 @@ void server::check_inactive_sessions() { for (const auto& session_id : sessions_to_close) { LOG_INFO("Closing inactive session: ", session_id); + close_session(session_id); + } +} + +void server::set_mount_point(const std::string& path, const std::string& root) { + http_server_->set_mount_point(path.c_str(), root.c_str()); +} + +void server::close_session(const std::string& session_id) { + // Clean up resources safely + try { + for (const auto& [key, handler] : session_cleanup_handler_) { + handler(key); + } + + // Copy resources to be processed std::shared_ptr dispatcher_to_close; + std::unique_ptr thread_to_release; + { std::lock_guard lock(mutex_); - auto disp_it = session_dispatchers_.find(session_id); - if (disp_it != session_dispatchers_.end()) { - dispatcher_to_close = disp_it->second; + + // Get dispatcher pointer + auto dispatcher_it = session_dispatchers_.find(session_id); + if (dispatcher_it != session_dispatchers_.end()) { + dispatcher_to_close = dispatcher_it->second; + session_dispatchers_.erase(dispatcher_it); } + + // Get thread pointer + auto thread_it = sse_threads_.find(session_id); + if (thread_it != sse_threads_.end()) { + thread_to_release = std::move(thread_it->second); + sse_threads_.erase(thread_it); + } + + // Clean up initialization status + session_initialized_.erase(session_id); } - if (dispatcher_to_close) { + // Close dispatcher outside the lock + if (dispatcher_to_close && !dispatcher_to_close->is_closed()) { dispatcher_to_close->close(); } + + // Release thread resources + if (thread_to_release) { + thread_to_release.release(); + } + } catch (const std::exception& e) { + LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what()); + } catch (...) { + LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id); } } diff --git a/src/mcp_stdio_client.cpp b/src/mcp_stdio_client.cpp index 27fa610..01c88aa 100644 --- a/src/mcp_stdio_client.cpp +++ b/src/mcp_stdio_client.cpp @@ -820,7 +820,7 @@ json stdio_client::send_jsonrpc(const request& req) { } // Wait for response, set timeout - const auto timeout = std::chrono::seconds(30); + const auto timeout = std::chrono::seconds(60); auto status = response_future.wait_for(timeout); if (status == std::future_status::ready) { diff --git a/test/mcp_test.cpp b/test/mcp_test.cpp index f9258f6..100599a 100644 --- a/test/mcp_test.cpp +++ b/test/mcp_test.cpp @@ -109,24 +109,6 @@ public: }; server_->set_capabilities(server_capabilities); - // Register initialize method handler - server_->register_method("initialize", [server_capabilities](const json& params) -> json { - // Verify initialize request parameters - EXPECT_EQ(params["protocolVersion"], MCP_VERSION); - EXPECT_TRUE(params.contains("capabilities")); - EXPECT_TRUE(params.contains("clientInfo")); - - // Return initialize response - return { - {"protocolVersion", MCP_VERSION}, - {"capabilities", server_capabilities}, - {"serverInfo", { - {"name", "TestServer"}, - {"version", "1.0.0"} - }} - }; - }); - // Start server (non-blocking mode) server_->start(false); @@ -545,7 +527,7 @@ public: }; // Register tool - server_->register_tool(test_tool, [](const json& params) -> json { + server_->register_tool(test_tool, [](const json& params, const std::string& /* session_id */) -> json { // Simple tool implementation std::string location = params["location"]; return { @@ -560,7 +542,7 @@ public: }); // Register tools list method - server_->register_method("tools/list", [](const json& params) -> json { + server_->register_method("tools/list", [](const json& params, const std::string& /* session_id */) -> json { return { {"tools", json::array({ { @@ -583,7 +565,7 @@ public: }); // Register tools call method - server_->register_method("tools/call", [](const json& params) -> json { + server_->register_method("tools/call", [](const json& params, const std::string& /* session_id */) -> json { // Verify parameters EXPECT_EQ(params["name"], "get_weather"); EXPECT_EQ(params["arguments"]["location"], "New York");