mcp_server: explicitly catch session_id for method/tool/notification/auth handlers; add session cleanup handlers
parent
21dc5cb144
commit
88237f2eaa
|
@ -18,7 +18,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
// Tool handler for getting current time
|
// Tool handler for getting current time
|
||||||
mcp::json get_time_handler(const mcp::json& params) {
|
mcp::json get_time_handler(const mcp::json& params, const std::string& /* session_id */) {
|
||||||
auto now = std::chrono::system_clock::now();
|
auto now = std::chrono::system_clock::now();
|
||||||
auto time_t_now = std::chrono::system_clock::to_time_t(now);
|
auto time_t_now = std::chrono::system_clock::to_time_t(now);
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ mcp::json get_time_handler(const mcp::json& params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Echo tool handler
|
// Echo tool handler
|
||||||
mcp::json echo_handler(const mcp::json& params) {
|
mcp::json echo_handler(const mcp::json& params, const std::string& /* session_id */) {
|
||||||
mcp::json result = params;
|
mcp::json result = params;
|
||||||
|
|
||||||
if (params.contains("text")) {
|
if (params.contains("text")) {
|
||||||
|
@ -63,7 +63,7 @@ mcp::json echo_handler(const mcp::json& params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculator tool handler
|
// Calculator tool handler
|
||||||
mcp::json calculator_handler(const mcp::json& params) {
|
mcp::json calculator_handler(const mcp::json& params, const std::string& /* session_id */) {
|
||||||
if (!params.contains("operation")) {
|
if (!params.contains("operation")) {
|
||||||
throw mcp::mcp_exception(mcp::error_code::invalid_params, "Missing 'operation' parameter");
|
throw mcp::mcp_exception(mcp::error_code::invalid_params, "Missing 'operation' parameter");
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ mcp::json calculator_handler(const mcp::json& params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom API endpoint handler
|
// Custom API endpoint handler
|
||||||
mcp::json hello_handler(const mcp::json& params) {
|
mcp::json hello_handler(const mcp::json& params, const std::string& /* session_id */) {
|
||||||
std::string name = params.contains("name") ? params["name"].get<std::string>() : "World";
|
std::string name = params.contains("name") ? params["name"].get<std::string>() : "World";
|
||||||
return {
|
return {
|
||||||
{
|
{
|
||||||
|
|
|
@ -33,6 +33,12 @@
|
||||||
|
|
||||||
namespace mcp {
|
namespace mcp {
|
||||||
|
|
||||||
|
using method_handler = std::function<json(const json&, const std::string&)>;
|
||||||
|
using tool_handler = method_handler;
|
||||||
|
using notification_handler = std::function<void(const json&, const std::string&)>;
|
||||||
|
using auth_handler = std::function<bool(const std::string&, const std::string&)>;
|
||||||
|
using session_cleanup_handler = std::function<void(const std::string&)>;
|
||||||
|
|
||||||
class event_dispatcher {
|
class event_dispatcher {
|
||||||
public:
|
public:
|
||||||
event_dispatcher() {
|
event_dispatcher() {
|
||||||
|
@ -223,14 +229,14 @@ public:
|
||||||
* @param method The method name
|
* @param method The method name
|
||||||
* @param handler The function to call when the method is invoked
|
* @param handler The function to call when the method is invoked
|
||||||
*/
|
*/
|
||||||
void register_method(const std::string& method, std::function<json(const json&)> handler);
|
void register_method(const std::string& method, method_handler handler);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Register a notification handler
|
* @brief Register a notification handler
|
||||||
* @param method The notification method name
|
* @param method The notification method name
|
||||||
* @param handler The function to call when the notification is received
|
* @param handler The function to call when the notification is received
|
||||||
*/
|
*/
|
||||||
void register_notification(const std::string& method, std::function<void(const json&)> handler);
|
void register_notification(const std::string& method, notification_handler handler);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Register a resource
|
* @brief Register a resource
|
||||||
|
@ -245,6 +251,13 @@ public:
|
||||||
* @param handler The function to call when the tool is invoked
|
* @param handler The function to call when the tool is invoked
|
||||||
*/
|
*/
|
||||||
void register_tool(const tool& tool, tool_handler handler);
|
void register_tool(const tool& tool, tool_handler handler);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Register a session cleanup handler
|
||||||
|
* @param key Tool or resource name to be cleaned up
|
||||||
|
* @param handler The function to call when the session is closed
|
||||||
|
*/
|
||||||
|
void register_session_cleanup(const std::string& key, session_cleanup_handler handler);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get the list of available tools
|
* @brief Get the list of available tools
|
||||||
|
@ -255,19 +268,24 @@ public:
|
||||||
/**
|
/**
|
||||||
* @brief Set authentication handler
|
* @brief Set authentication handler
|
||||||
* @param handler Function that takes a token and returns true if valid
|
* @param handler Function that takes a token and returns true if valid
|
||||||
|
* @note The handler should return true if the token is valid, otherwise false
|
||||||
|
* @note Not used in the current implementation
|
||||||
*/
|
*/
|
||||||
void set_auth_handler(std::function<bool(const std::string&)> handler);
|
void set_auth_handler(auth_handler handler);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Send a request to a client
|
* @brief Send a request (or notification) to a client
|
||||||
* @param session_id The session ID of the client
|
* @param session_id The session ID of the client
|
||||||
* @param method The method to call
|
* @param req The request to send
|
||||||
* @param params The parameters to pass
|
|
||||||
*
|
|
||||||
* This method will only send requests other than ping and logging
|
|
||||||
* after the client has sent the initialized notification.
|
|
||||||
*/
|
*/
|
||||||
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object());
|
void send_request(const std::string& session_id, const request& req);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set mount point for server
|
||||||
|
* @param path The path to mount the resource at
|
||||||
|
* @param root The root directory to mount
|
||||||
|
*/
|
||||||
|
void set_mount_point(const std::string& path, const std::string& root);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string host_;
|
std::string host_;
|
||||||
|
@ -296,10 +314,10 @@ private:
|
||||||
std::string msg_endpoint_;
|
std::string msg_endpoint_;
|
||||||
|
|
||||||
// Method handlers
|
// Method handlers
|
||||||
std::map<std::string, std::function<json(const json&)>> method_handlers_;
|
std::map<std::string, method_handler> method_handlers_;
|
||||||
|
|
||||||
// Notification handlers
|
// Notification handlers
|
||||||
std::map<std::string, std::function<void(const json&)>> notification_handlers_;
|
std::map<std::string, notification_handler> notification_handlers_;
|
||||||
|
|
||||||
// Resources map (path -> resource)
|
// Resources map (path -> resource)
|
||||||
std::map<std::string, std::shared_ptr<resource>> resources_;
|
std::map<std::string, std::shared_ptr<resource>> resources_;
|
||||||
|
@ -308,7 +326,7 @@ private:
|
||||||
std::map<std::string, std::pair<tool, tool_handler>> tools_;
|
std::map<std::string, std::pair<tool, tool_handler>> tools_;
|
||||||
|
|
||||||
// Authentication handler
|
// Authentication handler
|
||||||
std::function<bool(const std::string&)> auth_handler_;
|
auth_handler auth_handler_;
|
||||||
|
|
||||||
// Mutex for thread safety
|
// Mutex for thread safety
|
||||||
mutable std::mutex mutex_;
|
mutable std::mutex mutex_;
|
||||||
|
@ -327,6 +345,9 @@ private:
|
||||||
|
|
||||||
// Handle incoming JSON-RPC requests
|
// Handle incoming JSON-RPC requests
|
||||||
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
|
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
|
||||||
|
|
||||||
|
// Send a JSON-RPC message to a client
|
||||||
|
void send_jsonrpc(const std::string& session_id, const json& message);
|
||||||
|
|
||||||
// Process a JSON-RPC request
|
// Process a JSON-RPC request
|
||||||
json process_request(const request& req, const std::string& session_id);
|
json process_request(const request& req, const std::string& session_id);
|
||||||
|
@ -345,10 +366,10 @@ private:
|
||||||
|
|
||||||
// Auxiliary function to create an async handler from a regular handler
|
// Auxiliary function to create an async handler from a regular handler
|
||||||
template<typename F>
|
template<typename F>
|
||||||
std::function<std::future<json>(const json&)> make_async_handler(F&& handler) {
|
std::function<std::future<json>(const json&, const std::string&)> make_async_handler(F&& handler) {
|
||||||
return [handler = std::forward<F>(handler)](const json& params) -> std::future<json> {
|
return [handler = std::forward<F>(handler)](const json& params, const std::string& session_id) -> std::future<json> {
|
||||||
return std::async(std::launch::async, [handler, params]() -> json {
|
return std::async(std::launch::async, [handler, params, session_id]() -> json {
|
||||||
return handler(params);
|
return handler(params, session_id);
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -369,6 +390,12 @@ private:
|
||||||
// Session management and maintenance
|
// Session management and maintenance
|
||||||
void check_inactive_sessions();
|
void check_inactive_sessions();
|
||||||
std::unique_ptr<std::thread> maintenance_thread_;
|
std::unique_ptr<std::thread> maintenance_thread_;
|
||||||
|
|
||||||
|
// Session cleanup handler
|
||||||
|
std::map<std::string, session_cleanup_handler> session_cleanup_handler_;
|
||||||
|
|
||||||
|
// Close session
|
||||||
|
void close_session(const std::string& session_id);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mcp
|
} // namespace mcp
|
||||||
|
|
|
@ -19,9 +19,6 @@
|
||||||
|
|
||||||
namespace mcp {
|
namespace mcp {
|
||||||
|
|
||||||
// Tool handler function type
|
|
||||||
using tool_handler = std::function<json(const json&)>;
|
|
||||||
|
|
||||||
// MCP Tool definition
|
// MCP Tool definition
|
||||||
struct tool {
|
struct tool {
|
||||||
std::string name;
|
std::string name;
|
||||||
|
|
|
@ -134,15 +134,9 @@ void server::stop() {
|
||||||
session_initialized_.clear();
|
session_initialized_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close all dispatchers outside the lock
|
// Close all sessions
|
||||||
for (auto& dispatcher : dispatchers_to_close) {
|
for (const auto& [session_id, _] : session_dispatchers_) {
|
||||||
if (dispatcher && !dispatcher->is_closed()) {
|
close_session(session_id);
|
||||||
try {
|
|
||||||
dispatcher->close();
|
|
||||||
} catch (...) {
|
|
||||||
// Ignore exceptions
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Give threads some time to handle close events
|
// Give threads some time to handle close events
|
||||||
|
@ -239,12 +233,12 @@ void server::set_capabilities(const json& capabilities) {
|
||||||
capabilities_ = capabilities;
|
capabilities_ = capabilities;
|
||||||
}
|
}
|
||||||
|
|
||||||
void server::register_method(const std::string& method, std::function<json(const json&)> handler) {
|
void server::register_method(const std::string& method, method_handler handler) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
method_handlers_[method] = handler;
|
method_handlers_[method] = handler;
|
||||||
}
|
}
|
||||||
|
|
||||||
void server::register_notification(const std::string& method, std::function<void(const json&)> handler) {
|
void server::register_notification(const std::string& method, notification_handler handler) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
notification_handlers_[method] = handler;
|
notification_handlers_[method] = handler;
|
||||||
}
|
}
|
||||||
|
@ -255,7 +249,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
|
||||||
|
|
||||||
// Register methods for resource access
|
// Register methods for resource access
|
||||||
if (method_handlers_.find("resources/read") == method_handlers_.end()) {
|
if (method_handlers_.find("resources/read") == method_handlers_.end()) {
|
||||||
method_handlers_["resources/read"] = [this](const json& params) -> json {
|
method_handlers_["resources/read"] = [this](const json& params, const std::string& session_id) -> json {
|
||||||
if (!params.contains("uri")) {
|
if (!params.contains("uri")) {
|
||||||
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
|
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
|
||||||
}
|
}
|
||||||
|
@ -276,7 +270,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
|
||||||
}
|
}
|
||||||
|
|
||||||
if (method_handlers_.find("resources/list") == method_handlers_.end()) {
|
if (method_handlers_.find("resources/list") == method_handlers_.end()) {
|
||||||
method_handlers_["resources/list"] = [this](const json& params) -> json {
|
method_handlers_["resources/list"] = [this](const json& params, const std::string& session_id) -> json {
|
||||||
json resources = json::array();
|
json resources = json::array();
|
||||||
|
|
||||||
for (const auto& [uri, res] : resources_) {
|
for (const auto& [uri, res] : resources_) {
|
||||||
|
@ -296,7 +290,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
|
||||||
}
|
}
|
||||||
|
|
||||||
if (method_handlers_.find("resources/subscribe") == method_handlers_.end()) {
|
if (method_handlers_.find("resources/subscribe") == method_handlers_.end()) {
|
||||||
method_handlers_["resources/subscribe"] = [this](const json& params) -> json {
|
method_handlers_["resources/subscribe"] = [this](const json& params, const std::string& session_id) -> json {
|
||||||
if (!params.contains("uri")) {
|
if (!params.contains("uri")) {
|
||||||
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
|
throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter");
|
||||||
}
|
}
|
||||||
|
@ -312,7 +306,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
|
||||||
}
|
}
|
||||||
|
|
||||||
if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) {
|
if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) {
|
||||||
method_handlers_["resources/templates/list"] = [this](const json& params) -> json {
|
method_handlers_["resources/templates/list"] = [this](const json& params, const std::string& session_id) -> json {
|
||||||
return json::array();
|
return json::array();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -324,7 +318,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
|
||||||
|
|
||||||
// Register methods for tool listing and calling
|
// Register methods for tool listing and calling
|
||||||
if (method_handlers_.find("tools/list") == method_handlers_.end()) {
|
if (method_handlers_.find("tools/list") == method_handlers_.end()) {
|
||||||
method_handlers_["tools/list"] = [this](const json& params) -> json {
|
method_handlers_["tools/list"] = [this](const json& params, const std::string& session_id) -> json {
|
||||||
json tools_json = json::array();
|
json tools_json = json::array();
|
||||||
for (const auto& [name, tool_pair] : tools_) {
|
for (const auto& [name, tool_pair] : tools_) {
|
||||||
tools_json.push_back(tool_pair.first.to_json());
|
tools_json.push_back(tool_pair.first.to_json());
|
||||||
|
@ -334,7 +328,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (method_handlers_.find("tools/call") == method_handlers_.end()) {
|
if (method_handlers_.find("tools/call") == method_handlers_.end()) {
|
||||||
method_handlers_["tools/call"] = [this](const json& params) -> json {
|
method_handlers_["tools/call"] = [this](const json& params, const std::string& session_id) -> json {
|
||||||
if (!params.contains("name")) {
|
if (!params.contains("name")) {
|
||||||
throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter");
|
throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter");
|
||||||
}
|
}
|
||||||
|
@ -352,7 +346,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
try {
|
||||||
tool_result["content"] = it->second.second(tool_args);
|
tool_result["content"] = it->second.second(tool_args, session_id);
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
tool_result["isError"] = true;
|
tool_result["isError"] = true;
|
||||||
}
|
}
|
||||||
|
@ -362,6 +356,11 @@ void server::register_tool(const tool& tool, tool_handler handler) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void server::register_session_cleanup(const std::string& key, session_cleanup_handler handler) {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
session_cleanup_handler_[key] = handler;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<tool> server::get_tools() const {
|
std::vector<tool> server::get_tools() const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
std::vector<tool> tools;
|
std::vector<tool> tools;
|
||||||
|
@ -373,7 +372,7 @@ std::vector<tool> server::get_tools() const {
|
||||||
return tools;
|
return tools;
|
||||||
}
|
}
|
||||||
|
|
||||||
void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
|
void server::set_auth_handler(auth_handler handler) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
auth_handler_ = handler;
|
auth_handler_ = handler;
|
||||||
}
|
}
|
||||||
|
@ -442,47 +441,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what());
|
LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up resources safely
|
close_session(session_id);
|
||||||
try {
|
|
||||||
// Copy resources to be processed
|
|
||||||
std::shared_ptr<event_dispatcher> dispatcher_to_close;
|
|
||||||
std::unique_ptr<std::thread> thread_to_release;
|
|
||||||
|
|
||||||
{
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
|
|
||||||
// Get dispatcher pointer
|
|
||||||
auto dispatcher_it = session_dispatchers_.find(session_id);
|
|
||||||
if (dispatcher_it != session_dispatchers_.end()) {
|
|
||||||
dispatcher_to_close = dispatcher_it->second;
|
|
||||||
session_dispatchers_.erase(dispatcher_it);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get thread pointer
|
|
||||||
auto thread_it = sse_threads_.find(session_id);
|
|
||||||
if (thread_it != sse_threads_.end()) {
|
|
||||||
thread_to_release = std::move(thread_it->second);
|
|
||||||
sse_threads_.erase(thread_it);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up initialization status
|
|
||||||
session_initialized_.erase(session_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close dispatcher outside the lock
|
|
||||||
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
|
|
||||||
dispatcher_to_close->close();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release thread resources
|
|
||||||
if (thread_to_release) {
|
|
||||||
thread_to_release.release();
|
|
||||||
}
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
|
||||||
} catch (...) {
|
|
||||||
LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Store thread
|
// Store thread
|
||||||
|
@ -507,8 +466,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
if (!result) {
|
if (!result) {
|
||||||
LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
|
LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
|
||||||
|
|
||||||
// Close dispatcher directly, no need to lock
|
close_session(session_id);
|
||||||
session_dispatcher->close();
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -520,8 +478,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("SSE content provider exception: ", e.what());
|
LOG_ERROR("SSE content provider exception: ", e.what());
|
||||||
|
|
||||||
// Close dispatcher directly, no need to lock
|
close_session(session_id);
|
||||||
session_dispatcher->close();
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -673,7 +630,7 @@ json server::process_request(const request& req, const std::string& session_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find registered method handler
|
// Find registered method handler
|
||||||
std::function<json(const json&)> handler;
|
method_handler handler;
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
auto it = method_handlers_.find(req.method);
|
auto it = method_handlers_.find(req.method);
|
||||||
|
@ -685,8 +642,8 @@ json server::process_request(const request& req, const std::string& session_id)
|
||||||
if (handler) {
|
if (handler) {
|
||||||
// Call handler
|
// Call handler
|
||||||
LOG_INFO("Calling method handler: ", req.method);
|
LOG_INFO("Calling method handler: ", req.method);
|
||||||
auto future = thread_pool_.enqueue([handler, params = req.params]() -> json {
|
auto future = thread_pool_.enqueue([handler, params = req.params, session_id]() -> json {
|
||||||
return handler(params);
|
return handler(params, session_id);
|
||||||
});
|
});
|
||||||
json result = future.get();
|
json result = future.get();
|
||||||
|
|
||||||
|
@ -791,25 +748,13 @@ json server::handle_initialize(const request& req, const std::string& session_id
|
||||||
return response::create_success(req.id, result).to_json();
|
return response::create_success(req.id, result).to_json();
|
||||||
}
|
}
|
||||||
|
|
||||||
void server::send_request(const std::string& session_id, const std::string& method, const json& params) {
|
void server::send_jsonrpc(const std::string& session_id, const json& message) {
|
||||||
// Check if session ID is valid
|
// Check if session ID is valid
|
||||||
if (session_id.empty()) {
|
if (session_id.empty()) {
|
||||||
LOG_WARNING("Cannot send request to empty session_id");
|
LOG_WARNING("Cannot send message to empty session_id");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the method is ping or logging
|
|
||||||
bool is_allowed_before_init = (method == "ping" || method == "logging");
|
|
||||||
|
|
||||||
// Check if client is initialized or if this is an allowed method
|
|
||||||
if (!is_allowed_before_init && !is_session_initialized(session_id)) {
|
|
||||||
LOG_WARNING("Cannot send ", method, " request to session ", session_id, " before it is initialized");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create request
|
|
||||||
request req = request::create(method, params);
|
|
||||||
|
|
||||||
// Get session dispatcher
|
// Get session dispatcher
|
||||||
std::shared_ptr<event_dispatcher> dispatcher;
|
std::shared_ptr<event_dispatcher> dispatcher;
|
||||||
{
|
{
|
||||||
|
@ -828,16 +773,20 @@ void server::send_request(const std::string& session_id, const std::string& meth
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send request
|
// Send message
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "event: message\r\ndata: " << req.to_json().dump() << "\r\n\r\n";
|
ss << "event: message\r\ndata: " << message.dump() << "\r\n\r\n";
|
||||||
bool result = dispatcher->send_event(ss.str());
|
bool result = dispatcher->send_event(ss.str());
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
LOG_ERROR("Failed to send request to session: ", session_id);
|
LOG_ERROR("Failed to send message to session: ", session_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void server::send_request(const std::string& session_id, const request& req) {
|
||||||
|
send_jsonrpc(session_id, req.to_json());
|
||||||
|
}
|
||||||
|
|
||||||
bool server::is_session_initialized(const std::string& session_id) const {
|
bool server::is_session_initialized(const std::string& session_id) const {
|
||||||
// Check if session ID is valid
|
// Check if session ID is valid
|
||||||
if (session_id.empty()) {
|
if (session_id.empty()) {
|
||||||
|
@ -933,18 +882,59 @@ void server::check_inactive_sessions() {
|
||||||
for (const auto& session_id : sessions_to_close) {
|
for (const auto& session_id : sessions_to_close) {
|
||||||
LOG_INFO("Closing inactive session: ", session_id);
|
LOG_INFO("Closing inactive session: ", session_id);
|
||||||
|
|
||||||
|
close_session(session_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void server::set_mount_point(const std::string& path, const std::string& root) {
|
||||||
|
http_server_->set_mount_point(path.c_str(), root.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
void server::close_session(const std::string& session_id) {
|
||||||
|
// Clean up resources safely
|
||||||
|
try {
|
||||||
|
for (const auto& [key, handler] : session_cleanup_handler_) {
|
||||||
|
handler(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy resources to be processed
|
||||||
std::shared_ptr<event_dispatcher> dispatcher_to_close;
|
std::shared_ptr<event_dispatcher> dispatcher_to_close;
|
||||||
|
std::unique_ptr<std::thread> thread_to_release;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
auto disp_it = session_dispatchers_.find(session_id);
|
|
||||||
if (disp_it != session_dispatchers_.end()) {
|
// Get dispatcher pointer
|
||||||
dispatcher_to_close = disp_it->second;
|
auto dispatcher_it = session_dispatchers_.find(session_id);
|
||||||
|
if (dispatcher_it != session_dispatchers_.end()) {
|
||||||
|
dispatcher_to_close = dispatcher_it->second;
|
||||||
|
session_dispatchers_.erase(dispatcher_it);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get thread pointer
|
||||||
|
auto thread_it = sse_threads_.find(session_id);
|
||||||
|
if (thread_it != sse_threads_.end()) {
|
||||||
|
thread_to_release = std::move(thread_it->second);
|
||||||
|
sse_threads_.erase(thread_it);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up initialization status
|
||||||
|
session_initialized_.erase(session_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dispatcher_to_close) {
|
// Close dispatcher outside the lock
|
||||||
|
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
|
||||||
dispatcher_to_close->close();
|
dispatcher_to_close->close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Release thread resources
|
||||||
|
if (thread_to_release) {
|
||||||
|
thread_to_release.release();
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG_WARNING("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
||||||
|
} catch (...) {
|
||||||
|
LOG_WARNING("Unknown exception while cleaning up session resources: ", session_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -820,7 +820,7 @@ json stdio_client::send_jsonrpc(const request& req) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for response, set timeout
|
// Wait for response, set timeout
|
||||||
const auto timeout = std::chrono::seconds(30);
|
const auto timeout = std::chrono::seconds(60);
|
||||||
auto status = response_future.wait_for(timeout);
|
auto status = response_future.wait_for(timeout);
|
||||||
|
|
||||||
if (status == std::future_status::ready) {
|
if (status == std::future_status::ready) {
|
||||||
|
|
|
@ -109,24 +109,6 @@ public:
|
||||||
};
|
};
|
||||||
server_->set_capabilities(server_capabilities);
|
server_->set_capabilities(server_capabilities);
|
||||||
|
|
||||||
// Register initialize method handler
|
|
||||||
server_->register_method("initialize", [server_capabilities](const json& params) -> json {
|
|
||||||
// Verify initialize request parameters
|
|
||||||
EXPECT_EQ(params["protocolVersion"], MCP_VERSION);
|
|
||||||
EXPECT_TRUE(params.contains("capabilities"));
|
|
||||||
EXPECT_TRUE(params.contains("clientInfo"));
|
|
||||||
|
|
||||||
// Return initialize response
|
|
||||||
return {
|
|
||||||
{"protocolVersion", MCP_VERSION},
|
|
||||||
{"capabilities", server_capabilities},
|
|
||||||
{"serverInfo", {
|
|
||||||
{"name", "TestServer"},
|
|
||||||
{"version", "1.0.0"}
|
|
||||||
}}
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
// Start server (non-blocking mode)
|
// Start server (non-blocking mode)
|
||||||
server_->start(false);
|
server_->start(false);
|
||||||
|
|
||||||
|
@ -545,7 +527,7 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register tool
|
// Register tool
|
||||||
server_->register_tool(test_tool, [](const json& params) -> json {
|
server_->register_tool(test_tool, [](const json& params, const std::string& /* session_id */) -> json {
|
||||||
// Simple tool implementation
|
// Simple tool implementation
|
||||||
std::string location = params["location"];
|
std::string location = params["location"];
|
||||||
return {
|
return {
|
||||||
|
@ -560,7 +542,7 @@ public:
|
||||||
});
|
});
|
||||||
|
|
||||||
// Register tools list method
|
// Register tools list method
|
||||||
server_->register_method("tools/list", [](const json& params) -> json {
|
server_->register_method("tools/list", [](const json& params, const std::string& /* session_id */) -> json {
|
||||||
return {
|
return {
|
||||||
{"tools", json::array({
|
{"tools", json::array({
|
||||||
{
|
{
|
||||||
|
@ -583,7 +565,7 @@ public:
|
||||||
});
|
});
|
||||||
|
|
||||||
// Register tools call method
|
// Register tools call method
|
||||||
server_->register_method("tools/call", [](const json& params) -> json {
|
server_->register_method("tools/call", [](const json& params, const std::string& /* session_id */) -> json {
|
||||||
// Verify parameters
|
// Verify parameters
|
||||||
EXPECT_EQ(params["name"], "get_weather");
|
EXPECT_EQ(params["name"], "get_weather");
|
||||||
EXPECT_EQ(params["arguments"]["location"], "New York");
|
EXPECT_EQ(params["arguments"]["location"], "New York");
|
||||||
|
|
Loading…
Reference in New Issue