mcp_server: explicitly catch session_id for method/tool/notification/auth handlers; add session cleanup handlers

main
hkr04 2025-04-07 11:40:58 +08:00
parent 21dc5cb144
commit 88237f2eaa
6 changed files with 131 additions and 135 deletions

View File

@ -18,7 +18,7 @@
#include <algorithm>
// Tool handler for getting current time
mcp::json get_time_handler(const mcp::json& params) {
mcp::json get_time_handler(const mcp::json& params, const std::string& /* session_id */) {
auto now = std::chrono::system_clock::now();
auto time_t_now = std::chrono::system_clock::to_time_t(now);
@ -37,7 +37,7 @@ mcp::json get_time_handler(const mcp::json& params) {
}
// Echo tool handler
mcp::json echo_handler(const mcp::json& params) {
mcp::json echo_handler(const mcp::json& params, const std::string& /* session_id */) {
mcp::json result = params;
if (params.contains("text")) {
@ -63,7 +63,7 @@ mcp::json echo_handler(const mcp::json& params) {
}
// Calculator tool handler
mcp::json calculator_handler(const mcp::json& params) {
mcp::json calculator_handler(const mcp::json& params, const std::string& /* session_id */) {
if (!params.contains("operation")) {
throw mcp::mcp_exception(mcp::error_code::invalid_params, "Missing 'operation' parameter");
}
@ -107,7 +107,7 @@ mcp::json calculator_handler(const mcp::json& params) {
}
// Custom API endpoint handler
mcp::json hello_handler(const mcp::json& params) {
mcp::json hello_handler(const mcp::json& params, const std::string& /* session_id */) {
std::string name = params.contains("name") ? params["name"].get<std::string>() : "World";
return {
{

View File

@ -33,6 +33,12 @@
namespace mcp {
using method_handler = std::function<json(const json&, const std::string&)>;
using tool_handler = method_handler;
using notification_handler = std::function<void(const json&, const std::string&)>;
using auth_handler = std::function<bool(const std::string&, const std::string&)>;
using session_cleanup_handler = std::function<void(const std::string&)>;
class event_dispatcher {
public:
event_dispatcher() {
@ -223,14 +229,14 @@ public:
* @param method The method name
* @param handler The function to call when the method is invoked
*/
void register_method(const std::string& method, std::function<json(const json&)> handler);
void register_method(const std::string& method, method_handler handler);
/**
* @brief Register a notification handler
* @param method The notification method name
* @param handler The function to call when the notification is received
*/
void register_notification(const std::string& method, std::function<void(const json&)> handler);
void register_notification(const std::string& method, notification_handler handler);
/**
* @brief Register a resource
@ -246,6 +252,13 @@ public:
*/
void register_tool(const tool& tool, tool_handler handler);
/**
* @brief Register a session cleanup handler
* @param key Tool or resource name to be cleaned up
* @param handler The function to call when the session is closed
*/
void register_session_cleanup(const std::string& key, session_cleanup_handler handler);
/**
* @brief Get the list of available tools
* @return JSON array of available tools
@ -255,19 +268,24 @@ public:
/**
* @brief Set authentication handler
* @param handler Function that takes a token and returns true if valid
* @note The handler should return true if the token is valid, otherwise false
* @note Not used in the current implementation
*/
void set_auth_handler(std::function<bool(const std::string&)> handler);
void set_auth_handler(auth_handler handler);
/**
* @brief Send a request to a client
* @brief Send a request (or notification) to a client
* @param session_id The session ID of the client
* @param method The method to call
* @param params The parameters to pass
*
* This method will only send requests other than ping and logging
* after the client has sent the initialized notification.
* @param req The request to send
*/
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object());
void send_request(const std::string& session_id, const request& req);
/**
* @brief Set mount point for server
* @param path The path to mount the resource at
* @param root The root directory to mount
*/
void set_mount_point(const std::string& path, const std::string& root);
private:
std::string host_;
@ -296,10 +314,10 @@ private:
std::string msg_endpoint_;
// Method handlers
std::map<std::string, std::function<json(const json&)>> method_handlers_;
std::map<std::string, method_handler> method_handlers_;
// Notification handlers
std::map<std::string, std::function<void(const json&)>> notification_handlers_;
std::map<std::string, notification_handler> notification_handlers_;
// Resources map (path -> resource)
std::map<std::string, std::shared_ptr<resource>> resources_;
@ -308,7 +326,7 @@ private:
std::map<std::string, std::pair<tool, tool_handler>> tools_;
// Authentication handler
std::function<bool(const std::string&)> auth_handler_;
auth_handler auth_handler_;
// Mutex for thread safety
mutable std::mutex mutex_;
@ -328,6 +346,9 @@ private:
// Handle incoming JSON-RPC requests
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
// Send a JSON-RPC message to a client
void send_jsonrpc(const std::string& session_id, const json& message);
// Process a JSON-RPC request
json process_request(const request& req, const std::string& session_id);
@ -345,10 +366,10 @@ private:
// Auxiliary function to create an async handler from a regular handler
template<typename F>
std::function<std::future<json>(const json&)> make_async_handler(F&& handler) {
return [handler = std::forward<F>(handler)](const json& params) -> std::future<json> {
return std::async(std::launch::async, [handler, params]() -> json {
return handler(params);
std::function<std::future<json>(const json&, const std::string&)> make_async_handler(F&& handler) {
return [handler = std::forward<F>(handler)](const json& params, const std::string& session_id) -> std::future<json> {
return std::async(std::launch::async, [handler, params, session_id]() -> json {
return handler(params, session_id);
});
};
}
@ -369,6 +390,12 @@ private:
// Session management and maintenance
void check_inactive_sessions();
std::unique_ptr<std::thread> maintenance_thread_;
// Session cleanup handler
std::map<std::string, session_cleanup_handler> session_cleanup_handler_;
// Close session
void close_session(const std::string& session_id);
};
} // namespace mcp

View File

@ -19,9 +19,6 @@
namespace mcp {
// Tool handler function type
using tool_handler = std::function<json(const json&)>;
// MCP Tool definition
struct tool {
std::string name;

View File

@ -134,15 +134,9 @@ void server::stop() {
session_initialized_.clear();
}
// Close all dispatchers outside the lock
for (auto& dispatcher : dispatchers_to_close) {
if (dispatcher && !dispatcher->is_closed()) {
try {
dispatcher->close();
} catch (...) {
// Ignore exceptions
}
}
// Close all sessions
for (const auto& [session_id, _] : session_dispatchers_) {
close_session(session_id);
}
// Give threads some time to handle close events
@ -239,12 +233,12 @@ void server::set_capabilities(const json& capabilities) {
capabilities_ = capabilities;
}
void server::register_method(const std::string& method, std::function<json(const json&)> handler) {
void server::register_method(const std::string& method, method_handler handler) {
std::lock_guard<std::mutex> lock(mutex_);
method_handlers_[method] = handler;
}
void server::register_notification(const std::string& method, std::function<void(const json&)> handler) {
void server::register_notification(const std::string& method, notification_handler handler) {
std::lock_guard<std::mutex> lock(mutex_);
notification_handlers_[method] = handler;
}
@ -255,7 +249,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
// Register methods for resource access
if (method_handlers_.find("resources/read") == method_handlers_.end()) {
method_handlers_["resources/read"] = [this](const json& params) -> json {
method_handlers_["resources/read"] = [this](const json& params, const std::string& session_id) -> json {
if (!params.contains("uri")) {
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
}
@ -276,7 +270,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
}
if (method_handlers_.find("resources/list") == method_handlers_.end()) {
method_handlers_["resources/list"] = [this](const json& params) -> json {
method_handlers_["resources/list"] = [this](const json& params, const std::string& session_id) -> json {
json resources = json::array();
for (const auto& [uri, res] : resources_) {
@ -296,7 +290,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
}
if (method_handlers_.find("resources/subscribe") == method_handlers_.end()) {
method_handlers_["resources/subscribe"] = [this](const json& params) -> json {
method_handlers_["resources/subscribe"] = [this](const json& params, const std::string& session_id) -> json {
if (!params.contains("uri")) {
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
}
@ -312,7 +306,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
}
if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) {
method_handlers_["resources/templates/list"] = [this](const json& params) -> json {
method_handlers_["resources/templates/list"] = [this](const json& params, const std::string& session_id) -> json {
return json::array();
};
}
@ -324,7 +318,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
// Register methods for tool listing and calling
if (method_handlers_.find("tools/list") == method_handlers_.end()) {
method_handlers_["tools/list"] = [this](const json& params) -> json {
method_handlers_["tools/list"] = [this](const json& params, const std::string& session_id) -> json {
json tools_json = json::array();
for (const auto& [name, tool_pair] : tools_) {
tools_json.push_back(tool_pair.first.to_json());
@ -334,7 +328,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
}
if (method_handlers_.find("tools/call") == method_handlers_.end()) {
method_handlers_["tools/call"] = [this](const json& params) -> json {
method_handlers_["tools/call"] = [this](const json& params, const std::string& session_id) -> json {
if (!params.contains("name")) {
throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter");
}
@ -352,7 +346,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
};
try {
tool_result["content"] = it->second.second(tool_args);
tool_result["content"] = it->second.second(tool_args, session_id);
} catch (...) {
tool_result["isError"] = true;
}
@ -362,6 +356,11 @@ void server::register_tool(const tool& tool, tool_handler handler) {
}
}
void server::register_session_cleanup(const std::string& key, session_cleanup_handler handler) {
std::lock_guard<std::mutex> lock(mutex_);
session_cleanup_handler_[key] = handler;
}
std::vector<tool> server::get_tools() const {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<tool> tools;
@ -373,7 +372,7 @@ std::vector<tool> server::get_tools() const {
return tools;
}
void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
void server::set_auth_handler(auth_handler handler) {
std::lock_guard<std::mutex> lock(mutex_);
auth_handler_ = handler;
}
@ -442,47 +441,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what());
}
// Clean up resources safely
try {
// Copy resources to be processed
std::shared_ptr<event_dispatcher> dispatcher_to_close;
std::unique_ptr<std::thread> thread_to_release;
{
std::lock_guard<std::mutex> lock(mutex_);
// Get dispatcher pointer
auto dispatcher_it = session_dispatchers_.find(session_id);
if (dispatcher_it != session_dispatchers_.end()) {
dispatcher_to_close = dispatcher_it->second;
session_dispatchers_.erase(dispatcher_it);
}
// Get thread pointer
auto thread_it = sse_threads_.find(session_id);
if (thread_it != sse_threads_.end()) {
thread_to_release = std::move(thread_it->second);
sse_threads_.erase(thread_it);
}
// Clean up initialization status
session_initialized_.erase(session_id);
}
// Close dispatcher outside the lock
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
dispatcher_to_close->close();
}
// Release thread resources
if (thread_to_release) {
thread_to_release.release();
}
} catch (const std::exception& e) {
LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what());
} catch (...) {
LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id);
}
close_session(session_id);
});
// Store thread
@ -507,8 +466,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
if (!result) {
LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
// Close dispatcher directly, no need to lock
session_dispatcher->close();
close_session(session_id);
return false;
}
@ -520,8 +478,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
} catch (const std::exception& e) {
LOG_ERROR("SSE content provider exception: ", e.what());
// Close dispatcher directly, no need to lock
session_dispatcher->close();
close_session(session_id);
return false;
}
@ -673,7 +630,7 @@ json server::process_request(const request& req, const std::string& session_id)
}
// Find registered method handler
std::function<json(const json&)> handler;
method_handler handler;
{
std::lock_guard<std::mutex> lock(mutex_);
auto it = method_handlers_.find(req.method);
@ -685,8 +642,8 @@ json server::process_request(const request& req, const std::string& session_id)
if (handler) {
// Call handler
LOG_INFO("Calling method handler: ", req.method);
auto future = thread_pool_.enqueue([handler, params = req.params]() -> json {
return handler(params);
auto future = thread_pool_.enqueue([handler, params = req.params, session_id]() -> json {
return handler(params, session_id);
});
json result = future.get();
@ -791,25 +748,13 @@ json server::handle_initialize(const request& req, const std::string& session_id
return response::create_success(req.id, result).to_json();
}
void server::send_request(const std::string& session_id, const std::string& method, const json& params) {
void server::send_jsonrpc(const std::string& session_id, const json& message) {
// Check if session ID is valid
if (session_id.empty()) {
LOG_WARNING("Cannot send request to empty session_id");
LOG_WARNING("Cannot send message to empty session_id");
return;
}
// Check if the method is ping or logging
bool is_allowed_before_init = (method == "ping" || method == "logging");
// Check if client is initialized or if this is an allowed method
if (!is_allowed_before_init && !is_session_initialized(session_id)) {
LOG_WARNING("Cannot send ", method, " request to session ", session_id, " before it is initialized");
return;
}
// Create request
request req = request::create(method, params);
// Get session dispatcher
std::shared_ptr<event_dispatcher> dispatcher;
{
@ -828,16 +773,20 @@ void server::send_request(const std::string& session_id, const std::string& meth
return;
}
// Send request
// Send message
std::stringstream ss;
ss << "event: message\r\ndata: " << req.to_json().dump() << "\r\n\r\n";
ss << "event: message\r\ndata: " << message.dump() << "\r\n\r\n";
bool result = dispatcher->send_event(ss.str());
if (!result) {
LOG_ERROR("Failed to send request to session: ", session_id);
LOG_ERROR("Failed to send message to session: ", session_id);
}
}
void server::send_request(const std::string& session_id, const request& req) {
send_jsonrpc(session_id, req.to_json());
}
bool server::is_session_initialized(const std::string& session_id) const {
// Check if session ID is valid
if (session_id.empty()) {
@ -933,18 +882,59 @@ void server::check_inactive_sessions() {
for (const auto& session_id : sessions_to_close) {
LOG_INFO("Closing inactive session: ", session_id);
std::shared_ptr<event_dispatcher> dispatcher_to_close;
{
std::lock_guard<std::mutex> lock(mutex_);
auto disp_it = session_dispatchers_.find(session_id);
if (disp_it != session_dispatchers_.end()) {
dispatcher_to_close = disp_it->second;
}
close_session(session_id);
}
}
void server::set_mount_point(const std::string& path, const std::string& root) {
http_server_->set_mount_point(path.c_str(), root.c_str());
}
void server::close_session(const std::string& session_id) {
// Clean up resources safely
try {
for (const auto& [key, handler] : session_cleanup_handler_) {
handler(key);
}
if (dispatcher_to_close) {
// Copy resources to be processed
std::shared_ptr<event_dispatcher> dispatcher_to_close;
std::unique_ptr<std::thread> thread_to_release;
{
std::lock_guard<std::mutex> lock(mutex_);
// Get dispatcher pointer
auto dispatcher_it = session_dispatchers_.find(session_id);
if (dispatcher_it != session_dispatchers_.end()) {
dispatcher_to_close = dispatcher_it->second;
session_dispatchers_.erase(dispatcher_it);
}
// Get thread pointer
auto thread_it = sse_threads_.find(session_id);
if (thread_it != sse_threads_.end()) {
thread_to_release = std::move(thread_it->second);
sse_threads_.erase(thread_it);
}
// Clean up initialization status
session_initialized_.erase(session_id);
}
// Close dispatcher outside the lock
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
dispatcher_to_close->close();
}
// Release thread resources
if (thread_to_release) {
thread_to_release.release();
}
} catch (const std::exception& e) {
LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what());
} catch (...) {
LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id);
}
}

View File

@ -820,7 +820,7 @@ json stdio_client::send_jsonrpc(const request& req) {
}
// Wait for response, set timeout
const auto timeout = std::chrono::seconds(30);
const auto timeout = std::chrono::seconds(60);
auto status = response_future.wait_for(timeout);
if (status == std::future_status::ready) {

View File

@ -109,24 +109,6 @@ public:
};
server_->set_capabilities(server_capabilities);
// Register initialize method handler
server_->register_method("initialize", [server_capabilities](const json& params) -> json {
// Verify initialize request parameters
EXPECT_EQ(params["protocolVersion"], MCP_VERSION);
EXPECT_TRUE(params.contains("capabilities"));
EXPECT_TRUE(params.contains("clientInfo"));
// Return initialize response
return {
{"protocolVersion", MCP_VERSION},
{"capabilities", server_capabilities},
{"serverInfo", {
{"name", "TestServer"},
{"version", "1.0.0"}
}}
};
});
// Start server (non-blocking mode)
server_->start(false);
@ -545,7 +527,7 @@ public:
};
// Register tool
server_->register_tool(test_tool, [](const json& params) -> json {
server_->register_tool(test_tool, [](const json& params, const std::string& /* session_id */) -> json {
// Simple tool implementation
std::string location = params["location"];
return {
@ -560,7 +542,7 @@ public:
});
// Register tools list method
server_->register_method("tools/list", [](const json& params) -> json {
server_->register_method("tools/list", [](const json& params, const std::string& /* session_id */) -> json {
return {
{"tools", json::array({
{
@ -583,7 +565,7 @@ public:
});
// Register tools call method
server_->register_method("tools/call", [](const json& params) -> json {
server_->register_method("tools/call", [](const json& params, const std::string& /* session_id */) -> json {
// Verify parameters
EXPECT_EQ(params["name"], "get_weather");
EXPECT_EQ(params["arguments"]["location"], "New York");