mcp_server: explicitly catch session_id for method/tool/notification/auth handlers; add session cleanup handlers

main
hkr04 2025-04-07 11:40:58 +08:00
parent 21dc5cb144
commit 88237f2eaa
6 changed files with 131 additions and 135 deletions

View File

@ -18,7 +18,7 @@
#include <algorithm> #include <algorithm>
// Tool handler for getting current time // Tool handler for getting current time
mcp::json get_time_handler(const mcp::json& params) { mcp::json get_time_handler(const mcp::json& params, const std::string& /* session_id */) {
auto now = std::chrono::system_clock::now(); auto now = std::chrono::system_clock::now();
auto time_t_now = std::chrono::system_clock::to_time_t(now); auto time_t_now = std::chrono::system_clock::to_time_t(now);
@ -37,7 +37,7 @@ mcp::json get_time_handler(const mcp::json& params) {
} }
// Echo tool handler // Echo tool handler
mcp::json echo_handler(const mcp::json& params) { mcp::json echo_handler(const mcp::json& params, const std::string& /* session_id */) {
mcp::json result = params; mcp::json result = params;
if (params.contains("text")) { if (params.contains("text")) {
@ -63,7 +63,7 @@ mcp::json echo_handler(const mcp::json& params) {
} }
// Calculator tool handler // Calculator tool handler
mcp::json calculator_handler(const mcp::json& params) { mcp::json calculator_handler(const mcp::json& params, const std::string& /* session_id */) {
if (!params.contains("operation")) { if (!params.contains("operation")) {
throw mcp::mcp_exception(mcp::error_code::invalid_params, "Missing 'operation' parameter"); throw mcp::mcp_exception(mcp::error_code::invalid_params, "Missing 'operation' parameter");
} }
@ -107,7 +107,7 @@ mcp::json calculator_handler(const mcp::json& params) {
} }
// Custom API endpoint handler // Custom API endpoint handler
mcp::json hello_handler(const mcp::json& params) { mcp::json hello_handler(const mcp::json& params, const std::string& /* session_id */) {
std::string name = params.contains("name") ? params["name"].get<std::string>() : "World"; std::string name = params.contains("name") ? params["name"].get<std::string>() : "World";
return { return {
{ {

View File

@ -33,6 +33,12 @@
namespace mcp { namespace mcp {
using method_handler = std::function<json(const json&, const std::string&)>;
using tool_handler = method_handler;
using notification_handler = std::function<void(const json&, const std::string&)>;
using auth_handler = std::function<bool(const std::string&, const std::string&)>;
using session_cleanup_handler = std::function<void(const std::string&)>;
class event_dispatcher { class event_dispatcher {
public: public:
event_dispatcher() { event_dispatcher() {
@ -223,14 +229,14 @@ public:
* @param method The method name * @param method The method name
* @param handler The function to call when the method is invoked * @param handler The function to call when the method is invoked
*/ */
void register_method(const std::string& method, std::function<json(const json&)> handler); void register_method(const std::string& method, method_handler handler);
/** /**
* @brief Register a notification handler * @brief Register a notification handler
* @param method The notification method name * @param method The notification method name
* @param handler The function to call when the notification is received * @param handler The function to call when the notification is received
*/ */
void register_notification(const std::string& method, std::function<void(const json&)> handler); void register_notification(const std::string& method, notification_handler handler);
/** /**
* @brief Register a resource * @brief Register a resource
@ -245,6 +251,13 @@ public:
* @param handler The function to call when the tool is invoked * @param handler The function to call when the tool is invoked
*/ */
void register_tool(const tool& tool, tool_handler handler); void register_tool(const tool& tool, tool_handler handler);
/**
* @brief Register a session cleanup handler
* @param key Tool or resource name to be cleaned up
* @param handler The function to call when the session is closed
*/
void register_session_cleanup(const std::string& key, session_cleanup_handler handler);
/** /**
* @brief Get the list of available tools * @brief Get the list of available tools
@ -255,19 +268,24 @@ public:
/** /**
* @brief Set authentication handler * @brief Set authentication handler
* @param handler Function that takes a token and returns true if valid * @param handler Function that takes a token and returns true if valid
* @note The handler should return true if the token is valid, otherwise false
* @note Not used in the current implementation
*/ */
void set_auth_handler(std::function<bool(const std::string&)> handler); void set_auth_handler(auth_handler handler);
/** /**
* @brief Send a request to a client * @brief Send a request (or notification) to a client
* @param session_id The session ID of the client * @param session_id The session ID of the client
* @param method The method to call * @param req The request to send
* @param params The parameters to pass
*
* This method will only send requests other than ping and logging
* after the client has sent the initialized notification.
*/ */
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object()); void send_request(const std::string& session_id, const request& req);
/**
* @brief Set mount point for server
* @param path The path to mount the resource at
* @param root The root directory to mount
*/
void set_mount_point(const std::string& path, const std::string& root);
private: private:
std::string host_; std::string host_;
@ -296,10 +314,10 @@ private:
std::string msg_endpoint_; std::string msg_endpoint_;
// Method handlers // Method handlers
std::map<std::string, std::function<json(const json&)>> method_handlers_; std::map<std::string, method_handler> method_handlers_;
// Notification handlers // Notification handlers
std::map<std::string, std::function<void(const json&)>> notification_handlers_; std::map<std::string, notification_handler> notification_handlers_;
// Resources map (path -> resource) // Resources map (path -> resource)
std::map<std::string, std::shared_ptr<resource>> resources_; std::map<std::string, std::shared_ptr<resource>> resources_;
@ -308,7 +326,7 @@ private:
std::map<std::string, std::pair<tool, tool_handler>> tools_; std::map<std::string, std::pair<tool, tool_handler>> tools_;
// Authentication handler // Authentication handler
std::function<bool(const std::string&)> auth_handler_; auth_handler auth_handler_;
// Mutex for thread safety // Mutex for thread safety
mutable std::mutex mutex_; mutable std::mutex mutex_;
@ -327,6 +345,9 @@ private:
// Handle incoming JSON-RPC requests // Handle incoming JSON-RPC requests
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res); void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
// Send a JSON-RPC message to a client
void send_jsonrpc(const std::string& session_id, const json& message);
// Process a JSON-RPC request // Process a JSON-RPC request
json process_request(const request& req, const std::string& session_id); json process_request(const request& req, const std::string& session_id);
@ -345,10 +366,10 @@ private:
// Auxiliary function to create an async handler from a regular handler // Auxiliary function to create an async handler from a regular handler
template<typename F> template<typename F>
std::function<std::future<json>(const json&)> make_async_handler(F&& handler) { std::function<std::future<json>(const json&, const std::string&)> make_async_handler(F&& handler) {
return [handler = std::forward<F>(handler)](const json& params) -> std::future<json> { return [handler = std::forward<F>(handler)](const json& params, const std::string& session_id) -> std::future<json> {
return std::async(std::launch::async, [handler, params]() -> json { return std::async(std::launch::async, [handler, params, session_id]() -> json {
return handler(params); return handler(params, session_id);
}); });
}; };
} }
@ -369,6 +390,12 @@ private:
// Session management and maintenance // Session management and maintenance
void check_inactive_sessions(); void check_inactive_sessions();
std::unique_ptr<std::thread> maintenance_thread_; std::unique_ptr<std::thread> maintenance_thread_;
// Session cleanup handler
std::map<std::string, session_cleanup_handler> session_cleanup_handler_;
// Close session
void close_session(const std::string& session_id);
}; };
} // namespace mcp } // namespace mcp

View File

@ -19,9 +19,6 @@
namespace mcp { namespace mcp {
// Tool handler function type
using tool_handler = std::function<json(const json&)>;
// MCP Tool definition // MCP Tool definition
struct tool { struct tool {
std::string name; std::string name;

View File

@ -134,15 +134,9 @@ void server::stop() {
session_initialized_.clear(); session_initialized_.clear();
} }
// Close all dispatchers outside the lock // Close all sessions
for (auto& dispatcher : dispatchers_to_close) { for (const auto& [session_id, _] : session_dispatchers_) {
if (dispatcher && !dispatcher->is_closed()) { close_session(session_id);
try {
dispatcher->close();
} catch (...) {
// Ignore exceptions
}
}
} }
// Give threads some time to handle close events // Give threads some time to handle close events
@ -239,12 +233,12 @@ void server::set_capabilities(const json& capabilities) {
capabilities_ = capabilities; capabilities_ = capabilities;
} }
void server::register_method(const std::string& method, std::function<json(const json&)> handler) { void server::register_method(const std::string& method, method_handler handler) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
method_handlers_[method] = handler; method_handlers_[method] = handler;
} }
void server::register_notification(const std::string& method, std::function<void(const json&)> handler) { void server::register_notification(const std::string& method, notification_handler handler) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
notification_handlers_[method] = handler; notification_handlers_[method] = handler;
} }
@ -255,7 +249,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
// Register methods for resource access // Register methods for resource access
if (method_handlers_.find("resources/read") == method_handlers_.end()) { if (method_handlers_.find("resources/read") == method_handlers_.end()) {
method_handlers_["resources/read"] = [this](const json& params) -> json { method_handlers_["resources/read"] = [this](const json& params, const std::string& session_id) -> json {
if (!params.contains("uri")) { if (!params.contains("uri")) {
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter"); throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
} }
@ -276,7 +270,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
} }
if (method_handlers_.find("resources/list") == method_handlers_.end()) { if (method_handlers_.find("resources/list") == method_handlers_.end()) {
method_handlers_["resources/list"] = [this](const json& params) -> json { method_handlers_["resources/list"] = [this](const json& params, const std::string& session_id) -> json {
json resources = json::array(); json resources = json::array();
for (const auto& [uri, res] : resources_) { for (const auto& [uri, res] : resources_) {
@ -296,7 +290,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
} }
if (method_handlers_.find("resources/subscribe") == method_handlers_.end()) { if (method_handlers_.find("resources/subscribe") == method_handlers_.end()) {
method_handlers_["resources/subscribe"] = [this](const json& params) -> json { method_handlers_["resources/subscribe"] = [this](const json& params, const std::string& session_id) -> json {
if (!params.contains("uri")) { if (!params.contains("uri")) {
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter"); throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
} }
@ -312,7 +306,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
} }
if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) { if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) {
method_handlers_["resources/templates/list"] = [this](const json& params) -> json { method_handlers_["resources/templates/list"] = [this](const json& params, const std::string& session_id) -> json {
return json::array(); return json::array();
}; };
} }
@ -324,7 +318,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
// Register methods for tool listing and calling // Register methods for tool listing and calling
if (method_handlers_.find("tools/list") == method_handlers_.end()) { if (method_handlers_.find("tools/list") == method_handlers_.end()) {
method_handlers_["tools/list"] = [this](const json& params) -> json { method_handlers_["tools/list"] = [this](const json& params, const std::string& session_id) -> json {
json tools_json = json::array(); json tools_json = json::array();
for (const auto& [name, tool_pair] : tools_) { for (const auto& [name, tool_pair] : tools_) {
tools_json.push_back(tool_pair.first.to_json()); tools_json.push_back(tool_pair.first.to_json());
@ -334,7 +328,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
} }
if (method_handlers_.find("tools/call") == method_handlers_.end()) { if (method_handlers_.find("tools/call") == method_handlers_.end()) {
method_handlers_["tools/call"] = [this](const json& params) -> json { method_handlers_["tools/call"] = [this](const json& params, const std::string& session_id) -> json {
if (!params.contains("name")) { if (!params.contains("name")) {
throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter"); throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter");
} }
@ -352,7 +346,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
}; };
try { try {
tool_result["content"] = it->second.second(tool_args); tool_result["content"] = it->second.second(tool_args, session_id);
} catch (...) { } catch (...) {
tool_result["isError"] = true; tool_result["isError"] = true;
} }
@ -362,6 +356,11 @@ void server::register_tool(const tool& tool, tool_handler handler) {
} }
} }
void server::register_session_cleanup(const std::string& key, session_cleanup_handler handler) {
std::lock_guard<std::mutex> lock(mutex_);
session_cleanup_handler_[key] = handler;
}
std::vector<tool> server::get_tools() const { std::vector<tool> server::get_tools() const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
std::vector<tool> tools; std::vector<tool> tools;
@ -373,7 +372,7 @@ std::vector<tool> server::get_tools() const {
return tools; return tools;
} }
void server::set_auth_handler(std::function<bool(const std::string&)> handler) { void server::set_auth_handler(auth_handler handler) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auth_handler_ = handler; auth_handler_ = handler;
} }
@ -442,47 +441,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what()); LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what());
} }
// Clean up resources safely close_session(session_id);
try {
// Copy resources to be processed
std::shared_ptr<event_dispatcher> dispatcher_to_close;
std::unique_ptr<std::thread> thread_to_release;
{
std::lock_guard<std::mutex> lock(mutex_);
// Get dispatcher pointer
auto dispatcher_it = session_dispatchers_.find(session_id);
if (dispatcher_it != session_dispatchers_.end()) {
dispatcher_to_close = dispatcher_it->second;
session_dispatchers_.erase(dispatcher_it);
}
// Get thread pointer
auto thread_it = sse_threads_.find(session_id);
if (thread_it != sse_threads_.end()) {
thread_to_release = std::move(thread_it->second);
sse_threads_.erase(thread_it);
}
// Clean up initialization status
session_initialized_.erase(session_id);
}
// Close dispatcher outside the lock
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
dispatcher_to_close->close();
}
// Release thread resources
if (thread_to_release) {
thread_to_release.release();
}
} catch (const std::exception& e) {
LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what());
} catch (...) {
LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id);
}
}); });
// Store thread // Store thread
@ -507,8 +466,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
if (!result) { if (!result) {
LOG_WARNING("Failed to wait for event, closing connection: ", session_id); LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
// Close dispatcher directly, no need to lock close_session(session_id);
session_dispatcher->close();
return false; return false;
} }
@ -520,8 +478,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
} catch (const std::exception& e) { } catch (const std::exception& e) {
LOG_ERROR("SSE content provider exception: ", e.what()); LOG_ERROR("SSE content provider exception: ", e.what());
// Close dispatcher directly, no need to lock close_session(session_id);
session_dispatcher->close();
return false; return false;
} }
@ -673,7 +630,7 @@ json server::process_request(const request& req, const std::string& session_id)
} }
// Find registered method handler // Find registered method handler
std::function<json(const json&)> handler; method_handler handler;
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = method_handlers_.find(req.method); auto it = method_handlers_.find(req.method);
@ -685,8 +642,8 @@ json server::process_request(const request& req, const std::string& session_id)
if (handler) { if (handler) {
// Call handler // Call handler
LOG_INFO("Calling method handler: ", req.method); LOG_INFO("Calling method handler: ", req.method);
auto future = thread_pool_.enqueue([handler, params = req.params]() -> json { auto future = thread_pool_.enqueue([handler, params = req.params, session_id]() -> json {
return handler(params); return handler(params, session_id);
}); });
json result = future.get(); json result = future.get();
@ -791,25 +748,13 @@ json server::handle_initialize(const request& req, const std::string& session_id
return response::create_success(req.id, result).to_json(); return response::create_success(req.id, result).to_json();
} }
void server::send_request(const std::string& session_id, const std::string& method, const json& params) { void server::send_jsonrpc(const std::string& session_id, const json& message) {
// Check if session ID is valid // Check if session ID is valid
if (session_id.empty()) { if (session_id.empty()) {
LOG_WARNING("Cannot send request to empty session_id"); LOG_WARNING("Cannot send message to empty session_id");
return; return;
} }
// Check if the method is ping or logging
bool is_allowed_before_init = (method == "ping" || method == "logging");
// Check if client is initialized or if this is an allowed method
if (!is_allowed_before_init && !is_session_initialized(session_id)) {
LOG_WARNING("Cannot send ", method, " request to session ", session_id, " before it is initialized");
return;
}
// Create request
request req = request::create(method, params);
// Get session dispatcher // Get session dispatcher
std::shared_ptr<event_dispatcher> dispatcher; std::shared_ptr<event_dispatcher> dispatcher;
{ {
@ -828,16 +773,20 @@ void server::send_request(const std::string& session_id, const std::string& meth
return; return;
} }
// Send request // Send message
std::stringstream ss; std::stringstream ss;
ss << "event: message\r\ndata: " << req.to_json().dump() << "\r\n\r\n"; ss << "event: message\r\ndata: " << message.dump() << "\r\n\r\n";
bool result = dispatcher->send_event(ss.str()); bool result = dispatcher->send_event(ss.str());
if (!result) { if (!result) {
LOG_ERROR("Failed to send request to session: ", session_id); LOG_ERROR("Failed to send message to session: ", session_id);
} }
} }
void server::send_request(const std::string& session_id, const request& req) {
send_jsonrpc(session_id, req.to_json());
}
bool server::is_session_initialized(const std::string& session_id) const { bool server::is_session_initialized(const std::string& session_id) const {
// Check if session ID is valid // Check if session ID is valid
if (session_id.empty()) { if (session_id.empty()) {
@ -933,18 +882,59 @@ void server::check_inactive_sessions() {
for (const auto& session_id : sessions_to_close) { for (const auto& session_id : sessions_to_close) {
LOG_INFO("Closing inactive session: ", session_id); LOG_INFO("Closing inactive session: ", session_id);
close_session(session_id);
}
}
void server::set_mount_point(const std::string& path, const std::string& root) {
http_server_->set_mount_point(path.c_str(), root.c_str());
}
void server::close_session(const std::string& session_id) {
// Clean up resources safely
try {
for (const auto& [key, handler] : session_cleanup_handler_) {
handler(key);
}
// Copy resources to be processed
std::shared_ptr<event_dispatcher> dispatcher_to_close; std::shared_ptr<event_dispatcher> dispatcher_to_close;
std::unique_ptr<std::thread> thread_to_release;
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto disp_it = session_dispatchers_.find(session_id);
if (disp_it != session_dispatchers_.end()) { // Get dispatcher pointer
dispatcher_to_close = disp_it->second; auto dispatcher_it = session_dispatchers_.find(session_id);
if (dispatcher_it != session_dispatchers_.end()) {
dispatcher_to_close = dispatcher_it->second;
session_dispatchers_.erase(dispatcher_it);
} }
// Get thread pointer
auto thread_it = sse_threads_.find(session_id);
if (thread_it != sse_threads_.end()) {
thread_to_release = std::move(thread_it->second);
sse_threads_.erase(thread_it);
}
// Clean up initialization status
session_initialized_.erase(session_id);
} }
if (dispatcher_to_close) { // Close dispatcher outside the lock
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
dispatcher_to_close->close(); dispatcher_to_close->close();
} }
// Release thread resources
if (thread_to_release) {
thread_to_release.release();
}
} catch (const std::exception& e) {
LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what());
} catch (...) {
LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id);
} }
} }

View File

@ -820,7 +820,7 @@ json stdio_client::send_jsonrpc(const request& req) {
} }
// Wait for response, set timeout // Wait for response, set timeout
const auto timeout = std::chrono::seconds(30); const auto timeout = std::chrono::seconds(60);
auto status = response_future.wait_for(timeout); auto status = response_future.wait_for(timeout);
if (status == std::future_status::ready) { if (status == std::future_status::ready) {

View File

@ -109,24 +109,6 @@ public:
}; };
server_->set_capabilities(server_capabilities); server_->set_capabilities(server_capabilities);
// Register initialize method handler
server_->register_method("initialize", [server_capabilities](const json& params) -> json {
// Verify initialize request parameters
EXPECT_EQ(params["protocolVersion"], MCP_VERSION);
EXPECT_TRUE(params.contains("capabilities"));
EXPECT_TRUE(params.contains("clientInfo"));
// Return initialize response
return {
{"protocolVersion", MCP_VERSION},
{"capabilities", server_capabilities},
{"serverInfo", {
{"name", "TestServer"},
{"version", "1.0.0"}
}}
};
});
// Start server (non-blocking mode) // Start server (non-blocking mode)
server_->start(false); server_->start(false);
@ -545,7 +527,7 @@ public:
}; };
// Register tool // Register tool
server_->register_tool(test_tool, [](const json& params) -> json { server_->register_tool(test_tool, [](const json& params, const std::string& /* session_id */) -> json {
// Simple tool implementation // Simple tool implementation
std::string location = params["location"]; std::string location = params["location"];
return { return {
@ -560,7 +542,7 @@ public:
}); });
// Register tools list method // Register tools list method
server_->register_method("tools/list", [](const json& params) -> json { server_->register_method("tools/list", [](const json& params, const std::string& /* session_id */) -> json {
return { return {
{"tools", json::array({ {"tools", json::array({
{ {
@ -583,7 +565,7 @@ public:
}); });
// Register tools call method // Register tools call method
server_->register_method("tools/call", [](const json& params) -> json { server_->register_method("tools/call", [](const json& params, const std::string& /* session_id */) -> json {
// Verify parameters // Verify parameters
EXPECT_EQ(params["name"], "get_weather"); EXPECT_EQ(params["name"], "get_weather");
EXPECT_EQ(params["arguments"]["location"], "New York"); EXPECT_EQ(params["arguments"]["location"], "New York");