From 0d359874b9aef6f39fe3dbcd65654b6d680d0943 Mon Sep 17 00:00:00 2001 From: hkr04 Date: Tue, 11 Mar 2025 23:29:38 +0800 Subject: [PATCH] SSE connection (WIP) --- include/mcp_client.h | 16 ++++- include/mcp_server.h | 52 ++++++++++++++++- src/CMakeLists.txt | 4 ++ src/mcp_client.cpp | 116 +++++++++++++++++++++++++++++++++--- src/mcp_server.cpp | 136 ++++++++++++++++++++++++++++++++++++++----- 5 files changed, 300 insertions(+), 24 deletions(-) diff --git a/include/mcp_client.h b/include/mcp_client.h index 1820c06..b14a9b0 100644 --- a/include/mcp_client.h +++ b/include/mcp_client.h @@ -21,6 +21,7 @@ #include #include #include +#include namespace mcp { @@ -38,14 +39,14 @@ public: * @param host The server host (e.g., "localhost", "example.com") * @param port The server port */ - client(const std::string& host, int port = 8080, const json& capabilities = json::object()); + client(const std::string& host, int port = 8080, const json& capabilities = json::object(), const std::string& sse_endpoint = "/sse"); /** * @brief Constructor * @param base_url The base URL of the server (e.g., "http://localhost:8080") * @param capabilities The capabilities of the client */ - client(const std::string& base_url, const json& capabilities = json::object()); + client(const std::string& base_url, const json& capabilities = json::object(), const std::string& sse_endpoint = "/sse"); /** * @brief Destructor @@ -168,6 +169,8 @@ private: std::string base_url_; std::string host_; int port_; + std::string sse_endpoint_; + std::string msg_endpoint_; std::string auth_token_; int timeout_seconds_ = 30; json capabilities_; @@ -181,10 +184,19 @@ private: // Mutex for thread safety mutable std::mutex mutex_; + + // SSE connection + std::unique_ptr sse_thread_; + + // SSE连接状态 + std::atomic sse_running_{false}; // Initialize the client void init_client(const std::string& host, int port); void init_client(const std::string& base_url); + void open_sse_connection(); + void close_sse_connection(); + bool parse_sse_data(const char* data, size_t length); // Send a JSON-RPC request and get the response json send_jsonrpc(const request& req); diff --git a/include/mcp_server.h b/include/mcp_server.h index d9a530e..2c15528 100644 --- a/include/mcp_server.h +++ b/include/mcp_server.h @@ -23,9 +23,38 @@ #include #include #include +#include +#include + namespace mcp { +class event_dispatcher { +public: + event_dispatcher() = default; + + void wait_event(httplib::DataSink* sink) { + std::unique_lock lk(m_); + int id = id_; + cv_.wait(lk, [&] { return cid_ == id; }); + sink->write(message_.data(), message_.size()); + } + + void send_event(const std::string& message) { + std::lock_guard lk(m_); + cid_ = id_++; + message_ = message; + cv_.notify_all(); + } + +private: + std::mutex m_; + std::condition_variable cv_; + std::atomic id_{0}; + std::atomic cid_{-1}; + std::string message_; +}; + /** * @class server * @brief Main MCP server class @@ -40,7 +69,9 @@ public: * @param host The host to bind to (e.g., "localhost", "0.0.0.0") * @param port The port to listen on */ - server(const std::string& host = "localhost", int port = 8080); + server(const std::string& host = "localhost", int port = 8080, + const std::string& sse_endpoint = "/sse", + const std::string& msg_endpoint_prefix = "/message/"); /** * @brief Destructor @@ -141,6 +172,19 @@ private: // Server thread (for non-blocking mode) std::unique_ptr server_thread_; + + // SSE thread + std::map> sse_threads_; + + // Event dispatcher for server-sent events + event_dispatcher sse_dispatcher_; + + // Session-specific event dispatchers + std::map> session_dispatchers_; + + // Server-sent events endpoint + std::string sse_endpoint_; + std::string msg_endpoint_prefix_; // Method handlers std::map> method_handlers_; @@ -165,6 +209,9 @@ private: // Map to track client initialization status (client_address -> initialized) std::map client_initialized_; + + // Handle SSE requests + void handle_sse(const httplib::Request& req, httplib::Response& res); // Handle incoming JSON-RPC requests void handle_jsonrpc(const httplib::Request& req, httplib::Response& res); @@ -180,6 +227,9 @@ private: // Set client initialization status void set_client_initialized(const std::string& client_address, bool initialized); + + // Generate a random session ID + std::string generate_session_id() const; }; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b860160..ab7f9da 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -11,6 +11,10 @@ add_library(${TARGET} STATIC ../include/mcp_server.h mcp_tool.cpp ../include/mcp_tool.h + mcp_sse_server.cpp + ../include/mcp_sse_server.h + mcp_sse_client.cpp + ../include/mcp_sse_client.h ) target_link_libraries(${TARGET} PUBLIC diff --git a/src/mcp_client.cpp b/src/mcp_client.cpp index cea46bc..06839b3 100644 --- a/src/mcp_client.cpp +++ b/src/mcp_client.cpp @@ -11,19 +11,22 @@ namespace mcp { -client::client(const std::string& host, int port, const json& capabilities) - : host_(host), port_(port), capabilities_(capabilities) { +client::client(const std::string& host, int port, const json& capabilities, const std::string& sse_endpoint) + : host_(host), port_(port), capabilities_(capabilities), sse_endpoint_(sse_endpoint) { init_client(host, port); } -client::client(const std::string& base_url, const json& capabilities) - : base_url_(base_url), capabilities_(capabilities) { +client::client(const std::string& base_url, const json& capabilities, const std::string& sse_endpoint) + : base_url_(base_url), capabilities_(capabilities), sse_endpoint_(sse_endpoint) { init_client(base_url); } client::~client() { - // httplib::Client will be automatically destroyed + // 关闭SSE连接 + close_sse_connection(); + + // httplib::Client将自动销毁 } void client::init_client(const std::string& host, int port) { @@ -58,6 +61,8 @@ bool client::initialize(const std::string& client_name, const std::string& clien }); try { + open_sse_connection(); + // Send the request json result = send_jsonrpc(req); @@ -70,7 +75,9 @@ bool client::initialize(const std::string& client_name, const std::string& clien return true; } catch (const std::exception& e) { - // Initialization failed + // 初始化失败,关闭SSE连接 + std::cerr << "初始化失败: " << e.what() << std::endl; + close_sse_connection(); return false; } } @@ -200,6 +207,101 @@ json client::list_resource_templates() { return send_request("resources/templates/list").result; } +void client::open_sse_connection() { + // 设置SSE连接状态为运行中 + sse_running_ = true; + + // 创建并启动SSE线程 + sse_thread_ = std::make_unique([this]() { + int retry_count = 0; + const int max_retries = 5; + const int retry_delay_base = 1000; // 毫秒 + + while (sse_running_) { + try { + // 尝试建立SSE连接 + auto res = http_client_->Get(sse_endpoint_.c_str(), + [this](const char *data, size_t data_length) { + // 解析SSE数据 + if (!parse_sse_data(data, data_length)) { + return false; // 解析失败,关闭连接 + } + return sse_running_.load(); // 如果sse_running_为false,关闭连接 + }); + + // 检查连接是否成功 + if (!res) { + throw std::runtime_error("SSE连接失败: " + std::to_string(static_cast(res.error()))); + } + + // 连接成功后重置重试计数 + retry_count = 0; + } catch (const std::exception& e) { + // 记录错误 + std::cerr << "SSE连接错误: " << e.what() << std::endl; + + // 如果已达到最大重试次数,停止尝试 + if (++retry_count > max_retries) { + std::cerr << "达到最大重试次数,停止SSE连接尝试" << std::endl; + break; + } + + // 指数退避重试 + int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay + std::this_thread::sleep_for(std::chrono::milliseconds(delay)); + } + } + }); +} + +// 新增方法:解析SSE数据 +bool client::parse_sse_data(const char* data, size_t length) { + try { + std::string sse_data(data, length); + + // 查找"data:"标记 + auto data_pos = sse_data.find("data: "); + if (data_pos == std::string::npos) { + return true; // 不是数据事件,可能是注释或心跳,继续保持连接 + } + + // 查找数据行结束位置 + auto newline_pos = sse_data.find("\n", data_pos); + if (newline_pos == std::string::npos) { + newline_pos = sse_data.length(); // 如果没有换行符,使用整个字符串 + } + + // 提取数据内容 + std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6)); + + // 检查是否是心跳事件 + if (sse_data.find("event: heartbeat") != std::string::npos) { + // 心跳事件,不需要处理数据 + return true; + } + + // 更新消息端点 + { + std::lock_guard lock(mutex_); + msg_endpoint_ = data_content; + } + + return true; + } catch (const std::exception& e) { + std::cerr << "解析SSE数据错误: " << e.what() << std::endl; + return false; + } +} + +// 新增方法:关闭SSE连接 +void client::close_sse_connection() { + sse_running_ = false; + + if (sse_thread_ && sse_thread_->joinable()) { + sse_thread_->join(); + } +} + json client::send_jsonrpc(const request& req) { std::lock_guard lock(mutex_); @@ -217,7 +319,7 @@ json client::send_jsonrpc(const request& req) { } // Send the request - auto result = http_client_->Post("/jsonrpc", headers, req_body, "application/json"); + auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); if (!result) { // Error occurred diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index 3b6d39c..4475211 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -10,8 +10,8 @@ namespace mcp { -server::server(const std::string& host, int port) - : host_(host), port_(port), name_("MCP Server"), version_(MCP_VERSION) { +server::server(const std::string& host, int port, const std::string& sse_endpoint, const std::string& msg_endpoint_prefix) + : host_(host), port_(port), sse_endpoint_(sse_endpoint), msg_endpoint_prefix_(msg_endpoint_prefix), name_("MCP Server"), version_(MCP_VERSION) { http_server_ = std::make_unique(); @@ -34,8 +34,33 @@ bool server::start(bool blocking) { } // Set up JSON-RPC endpoint - http_server_->Post("/jsonrpc", [this](const httplib::Request& req, httplib::Response& res) { - this->handle_jsonrpc(req, res); + http_server_->Post(msg_endpoint_prefix_.c_str(), [this](const httplib::Request& req, httplib::Response& res) { + // 从URL参数中获取session_id + if (auto session_id = req.get_param_value("session_id"); !session_id.empty()) { + // 检查session是否存在 + std::lock_guard lock(mutex_); + if (session_dispatchers_.find(session_id) != session_dispatchers_.end()) { + // session存在,处理JSON-RPC请求 + handle_jsonrpc(req, res); + return; + } + } + + // session不存在,返回错误 + json error_response = { + {"jsonrpc", "2.0"}, + {"error", { + {"code", static_cast(error_code::invalid_request)}, + {"message", "Invalid or missing session ID. Initialize first to get a session ID"} + }}, + {"id", nullptr} + }; + res.set_content(error_response.dump(), "application/json"); + }); + + // Set up SSE endpoint + http_server_->Get(sse_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) { + this->handle_sse(req, res); }); // Start the server @@ -74,6 +99,18 @@ void server::stop() { server_thread_->join(); } + // 清理所有SSE线程 + { + std::lock_guard lock(mutex_); + for (auto& [session_id, thread] : sse_threads_) { + if (thread && thread->joinable()) { + thread->join(); + } + } + sse_threads_.clear(); + session_dispatchers_.clear(); + } + running_ = false; } @@ -239,6 +276,71 @@ void server::set_auth_handler(std::function handler) { auth_handler_ = handler; } +void server::handle_sse(const httplib::Request& req, httplib::Response& res) { + // 生成会话ID并创建会话URI + std::string session_id = generate_session_id(); + std::string session_uri = msg_endpoint_prefix_ + "?session_id=" + session_id; + + // 创建一个共享的事件分发器,确保其生命周期 + auto session_dispatcher = std::make_shared(); + + { + std::lock_guard lock(mutex_); + // 存储会话信息 + session_dispatchers_[session_id] = session_dispatcher; + } + + // 创建并启动会话线程,使用共享指针和值捕获而不是引用捕获 + sse_threads_[session_id] = std::make_unique([this, session_id, session_uri, session_dispatcher]() { + try { + // 等待一段时间 + std::this_thread::sleep_for(std::chrono::seconds(1)); + + std::stringstream ss; + ss << "data: " << session_uri << "\n\n"; + + // 使用会话特定的分发器发送事件 + session_dispatcher->send_event(ss.str()); + + // 设置定期心跳 + while (running_) { + std::this_thread::sleep_for(std::chrono::seconds(30)); + + // 发送心跳事件 + std::stringstream heartbeat; + heartbeat << "event: heartbeat\ndata: {}\n\n"; + session_dispatcher->send_event(heartbeat.str()); + } + } catch (const std::exception& e) { + // 记录错误 + std::cerr << "SSE thread error for session " << session_id << ": " << e.what() << std::endl; + } + + // 线程结束时清理资源 + { + std::lock_guard lock(mutex_); + session_dispatchers_.erase(session_id); + } + }); + + // 不再使用detach,而是在server析构函数中管理线程生命周期 + + // 设置分块内容提供者 + res.set_chunked_content_provider("text/event-stream", [session_dispatcher](size_t /* offset */, httplib::DataSink& sink) { + // 使用会话特定的分发器等待事件 + session_dispatcher->wait_event(&sink); + return true; + }); + + // 注册会话特定的JSON-RPC端点 + { + std::lock_guard lock(mutex_); + http_server_->Post(session_uri, [this](const httplib::Request& req, httplib::Response& res) { + handle_jsonrpc(req, res); + }); + } +} + void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) { // Set response headers res.set_header("Content-Type", "application/json"); @@ -330,16 +432,6 @@ json server::process_request(const request& req, const std::string& client_addre // The receiver MUST respond promptly with an empty response return response::create_success(req.id, {}).to_json(); } - - // Check if client is initialized - if (!is_client_initialized(client_address)) { - // Client not initialized - return response::create_error( - req.id, - error_code::invalid_request, - "Client not initialized" - ).to_json(); - } // Look for registered method handler std::lock_guard lock(mutex_); @@ -467,4 +559,20 @@ void server::set_client_initialized(const std::string& client_address, bool init client_initialized_[client_address] = initialized; } +// Generate a random session ID +std::string server::generate_session_id() const { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, 15); + + std::stringstream ss; + ss << std::hex; + + for (int i = 0; i < 32; ++i) { + ss << dis(gen); + } + + return ss.str(); +} + } // namespace mcp \ No newline at end of file