SSE connection (WIP)
parent
62eb1ad2e4
commit
0d359874b9
|
@ -21,6 +21,7 @@
|
|||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <functional>
|
||||
#include <atomic>
|
||||
|
||||
namespace mcp {
|
||||
|
||||
|
@ -38,14 +39,14 @@ public:
|
|||
* @param host The server host (e.g., "localhost", "example.com")
|
||||
* @param port The server port
|
||||
*/
|
||||
client(const std::string& host, int port = 8080, const json& capabilities = json::object());
|
||||
client(const std::string& host, int port = 8080, const json& capabilities = json::object(), const std::string& sse_endpoint = "/sse");
|
||||
|
||||
/**
|
||||
* @brief Constructor
|
||||
* @param base_url The base URL of the server (e.g., "http://localhost:8080")
|
||||
* @param capabilities The capabilities of the client
|
||||
*/
|
||||
client(const std::string& base_url, const json& capabilities = json::object());
|
||||
client(const std::string& base_url, const json& capabilities = json::object(), const std::string& sse_endpoint = "/sse");
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
|
@ -168,6 +169,8 @@ private:
|
|||
std::string base_url_;
|
||||
std::string host_;
|
||||
int port_;
|
||||
std::string sse_endpoint_;
|
||||
std::string msg_endpoint_;
|
||||
std::string auth_token_;
|
||||
int timeout_seconds_ = 30;
|
||||
json capabilities_;
|
||||
|
@ -181,10 +184,19 @@ private:
|
|||
|
||||
// Mutex for thread safety
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
// SSE connection
|
||||
std::unique_ptr<std::thread> sse_thread_;
|
||||
|
||||
// SSE连接状态
|
||||
std::atomic<bool> sse_running_{false};
|
||||
|
||||
// Initialize the client
|
||||
void init_client(const std::string& host, int port);
|
||||
void init_client(const std::string& base_url);
|
||||
void open_sse_connection();
|
||||
void close_sse_connection();
|
||||
bool parse_sse_data(const char* data, size_t length);
|
||||
|
||||
// Send a JSON-RPC request and get the response
|
||||
json send_jsonrpc(const request& req);
|
||||
|
|
|
@ -23,9 +23,38 @@
|
|||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
|
||||
|
||||
namespace mcp {
|
||||
|
||||
class event_dispatcher {
|
||||
public:
|
||||
event_dispatcher() = default;
|
||||
|
||||
void wait_event(httplib::DataSink* sink) {
|
||||
std::unique_lock<std::mutex> lk(m_);
|
||||
int id = id_;
|
||||
cv_.wait(lk, [&] { return cid_ == id; });
|
||||
sink->write(message_.data(), message_.size());
|
||||
}
|
||||
|
||||
void send_event(const std::string& message) {
|
||||
std::lock_guard<std::mutex> lk(m_);
|
||||
cid_ = id_++;
|
||||
message_ = message;
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex m_;
|
||||
std::condition_variable cv_;
|
||||
std::atomic<int> id_{0};
|
||||
std::atomic<int> cid_{-1};
|
||||
std::string message_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @class server
|
||||
* @brief Main MCP server class
|
||||
|
@ -40,7 +69,9 @@ public:
|
|||
* @param host The host to bind to (e.g., "localhost", "0.0.0.0")
|
||||
* @param port The port to listen on
|
||||
*/
|
||||
server(const std::string& host = "localhost", int port = 8080);
|
||||
server(const std::string& host = "localhost", int port = 8080,
|
||||
const std::string& sse_endpoint = "/sse",
|
||||
const std::string& msg_endpoint_prefix = "/message/");
|
||||
|
||||
/**
|
||||
* @brief Destructor
|
||||
|
@ -141,6 +172,19 @@ private:
|
|||
|
||||
// Server thread (for non-blocking mode)
|
||||
std::unique_ptr<std::thread> server_thread_;
|
||||
|
||||
// SSE thread
|
||||
std::map<std::string, std::unique_ptr<std::thread>> sse_threads_;
|
||||
|
||||
// Event dispatcher for server-sent events
|
||||
event_dispatcher sse_dispatcher_;
|
||||
|
||||
// Session-specific event dispatchers
|
||||
std::map<std::string, std::shared_ptr<event_dispatcher>> session_dispatchers_;
|
||||
|
||||
// Server-sent events endpoint
|
||||
std::string sse_endpoint_;
|
||||
std::string msg_endpoint_prefix_;
|
||||
|
||||
// Method handlers
|
||||
std::map<std::string, std::function<json(const json&)>> method_handlers_;
|
||||
|
@ -165,6 +209,9 @@ private:
|
|||
|
||||
// Map to track client initialization status (client_address -> initialized)
|
||||
std::map<std::string, bool> client_initialized_;
|
||||
|
||||
// Handle SSE requests
|
||||
void handle_sse(const httplib::Request& req, httplib::Response& res);
|
||||
|
||||
// Handle incoming JSON-RPC requests
|
||||
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
|
||||
|
@ -180,6 +227,9 @@ private:
|
|||
|
||||
// Set client initialization status
|
||||
void set_client_initialized(const std::string& client_address, bool initialized);
|
||||
|
||||
// Generate a random session ID
|
||||
std::string generate_session_id() const;
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -11,6 +11,10 @@ add_library(${TARGET} STATIC
|
|||
../include/mcp_server.h
|
||||
mcp_tool.cpp
|
||||
../include/mcp_tool.h
|
||||
mcp_sse_server.cpp
|
||||
../include/mcp_sse_server.h
|
||||
mcp_sse_client.cpp
|
||||
../include/mcp_sse_client.h
|
||||
)
|
||||
|
||||
target_link_libraries(${TARGET} PUBLIC
|
||||
|
|
|
@ -11,19 +11,22 @@
|
|||
|
||||
namespace mcp {
|
||||
|
||||
client::client(const std::string& host, int port, const json& capabilities)
|
||||
: host_(host), port_(port), capabilities_(capabilities) {
|
||||
client::client(const std::string& host, int port, const json& capabilities, const std::string& sse_endpoint)
|
||||
: host_(host), port_(port), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
|
||||
|
||||
init_client(host, port);
|
||||
}
|
||||
|
||||
client::client(const std::string& base_url, const json& capabilities)
|
||||
: base_url_(base_url), capabilities_(capabilities) {
|
||||
client::client(const std::string& base_url, const json& capabilities, const std::string& sse_endpoint)
|
||||
: base_url_(base_url), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
|
||||
init_client(base_url);
|
||||
}
|
||||
|
||||
client::~client() {
|
||||
// httplib::Client will be automatically destroyed
|
||||
// 关闭SSE连接
|
||||
close_sse_connection();
|
||||
|
||||
// httplib::Client将自动销毁
|
||||
}
|
||||
|
||||
void client::init_client(const std::string& host, int port) {
|
||||
|
@ -58,6 +61,8 @@ bool client::initialize(const std::string& client_name, const std::string& clien
|
|||
});
|
||||
|
||||
try {
|
||||
open_sse_connection();
|
||||
|
||||
// Send the request
|
||||
json result = send_jsonrpc(req);
|
||||
|
||||
|
@ -70,7 +75,9 @@ bool client::initialize(const std::string& client_name, const std::string& clien
|
|||
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
// Initialization failed
|
||||
// 初始化失败,关闭SSE连接
|
||||
std::cerr << "初始化失败: " << e.what() << std::endl;
|
||||
close_sse_connection();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -200,6 +207,101 @@ json client::list_resource_templates() {
|
|||
return send_request("resources/templates/list").result;
|
||||
}
|
||||
|
||||
void client::open_sse_connection() {
|
||||
// 设置SSE连接状态为运行中
|
||||
sse_running_ = true;
|
||||
|
||||
// 创建并启动SSE线程
|
||||
sse_thread_ = std::make_unique<std::thread>([this]() {
|
||||
int retry_count = 0;
|
||||
const int max_retries = 5;
|
||||
const int retry_delay_base = 1000; // 毫秒
|
||||
|
||||
while (sse_running_) {
|
||||
try {
|
||||
// 尝试建立SSE连接
|
||||
auto res = http_client_->Get(sse_endpoint_.c_str(),
|
||||
[this](const char *data, size_t data_length) {
|
||||
// 解析SSE数据
|
||||
if (!parse_sse_data(data, data_length)) {
|
||||
return false; // 解析失败,关闭连接
|
||||
}
|
||||
return sse_running_.load(); // 如果sse_running_为false,关闭连接
|
||||
});
|
||||
|
||||
// 检查连接是否成功
|
||||
if (!res) {
|
||||
throw std::runtime_error("SSE连接失败: " + std::to_string(static_cast<int>(res.error())));
|
||||
}
|
||||
|
||||
// 连接成功后重置重试计数
|
||||
retry_count = 0;
|
||||
} catch (const std::exception& e) {
|
||||
// 记录错误
|
||||
std::cerr << "SSE连接错误: " << e.what() << std::endl;
|
||||
|
||||
// 如果已达到最大重试次数,停止尝试
|
||||
if (++retry_count > max_retries) {
|
||||
std::cerr << "达到最大重试次数,停止SSE连接尝试" << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
// 指数退避重试
|
||||
int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(delay));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 新增方法:解析SSE数据
|
||||
bool client::parse_sse_data(const char* data, size_t length) {
|
||||
try {
|
||||
std::string sse_data(data, length);
|
||||
|
||||
// 查找"data:"标记
|
||||
auto data_pos = sse_data.find("data: ");
|
||||
if (data_pos == std::string::npos) {
|
||||
return true; // 不是数据事件,可能是注释或心跳,继续保持连接
|
||||
}
|
||||
|
||||
// 查找数据行结束位置
|
||||
auto newline_pos = sse_data.find("\n", data_pos);
|
||||
if (newline_pos == std::string::npos) {
|
||||
newline_pos = sse_data.length(); // 如果没有换行符,使用整个字符串
|
||||
}
|
||||
|
||||
// 提取数据内容
|
||||
std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6));
|
||||
|
||||
// 检查是否是心跳事件
|
||||
if (sse_data.find("event: heartbeat") != std::string::npos) {
|
||||
// 心跳事件,不需要处理数据
|
||||
return true;
|
||||
}
|
||||
|
||||
// 更新消息端点
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
msg_endpoint_ = data_content;
|
||||
}
|
||||
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "解析SSE数据错误: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 新增方法:关闭SSE连接
|
||||
void client::close_sse_connection() {
|
||||
sse_running_ = false;
|
||||
|
||||
if (sse_thread_ && sse_thread_->joinable()) {
|
||||
sse_thread_->join();
|
||||
}
|
||||
}
|
||||
|
||||
json client::send_jsonrpc(const request& req) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
|
@ -217,7 +319,7 @@ json client::send_jsonrpc(const request& req) {
|
|||
}
|
||||
|
||||
// Send the request
|
||||
auto result = http_client_->Post("/jsonrpc", headers, req_body, "application/json");
|
||||
auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
|
||||
|
||||
if (!result) {
|
||||
// Error occurred
|
||||
|
|
|
@ -10,8 +10,8 @@
|
|||
|
||||
namespace mcp {
|
||||
|
||||
server::server(const std::string& host, int port)
|
||||
: host_(host), port_(port), name_("MCP Server"), version_(MCP_VERSION) {
|
||||
server::server(const std::string& host, int port, const std::string& sse_endpoint, const std::string& msg_endpoint_prefix)
|
||||
: host_(host), port_(port), sse_endpoint_(sse_endpoint), msg_endpoint_prefix_(msg_endpoint_prefix), name_("MCP Server"), version_(MCP_VERSION) {
|
||||
|
||||
http_server_ = std::make_unique<httplib::Server>();
|
||||
|
||||
|
@ -34,8 +34,33 @@ bool server::start(bool blocking) {
|
|||
}
|
||||
|
||||
// Set up JSON-RPC endpoint
|
||||
http_server_->Post("/jsonrpc", [this](const httplib::Request& req, httplib::Response& res) {
|
||||
this->handle_jsonrpc(req, res);
|
||||
http_server_->Post(msg_endpoint_prefix_.c_str(), [this](const httplib::Request& req, httplib::Response& res) {
|
||||
// 从URL参数中获取session_id
|
||||
if (auto session_id = req.get_param_value("session_id"); !session_id.empty()) {
|
||||
// 检查session是否存在
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (session_dispatchers_.find(session_id) != session_dispatchers_.end()) {
|
||||
// session存在,处理JSON-RPC请求
|
||||
handle_jsonrpc(req, res);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// session不存在,返回错误
|
||||
json error_response = {
|
||||
{"jsonrpc", "2.0"},
|
||||
{"error", {
|
||||
{"code", static_cast<int>(error_code::invalid_request)},
|
||||
{"message", "Invalid or missing session ID. Initialize first to get a session ID"}
|
||||
}},
|
||||
{"id", nullptr}
|
||||
};
|
||||
res.set_content(error_response.dump(), "application/json");
|
||||
});
|
||||
|
||||
// Set up SSE endpoint
|
||||
http_server_->Get(sse_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) {
|
||||
this->handle_sse(req, res);
|
||||
});
|
||||
|
||||
// Start the server
|
||||
|
@ -74,6 +99,18 @@ void server::stop() {
|
|||
server_thread_->join();
|
||||
}
|
||||
|
||||
// 清理所有SSE线程
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (auto& [session_id, thread] : sse_threads_) {
|
||||
if (thread && thread->joinable()) {
|
||||
thread->join();
|
||||
}
|
||||
}
|
||||
sse_threads_.clear();
|
||||
session_dispatchers_.clear();
|
||||
}
|
||||
|
||||
running_ = false;
|
||||
}
|
||||
|
||||
|
@ -239,6 +276,71 @@ void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
|
|||
auth_handler_ = handler;
|
||||
}
|
||||
|
||||
void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||
// 生成会话ID并创建会话URI
|
||||
std::string session_id = generate_session_id();
|
||||
std::string session_uri = msg_endpoint_prefix_ + "?session_id=" + session_id;
|
||||
|
||||
// 创建一个共享的事件分发器,确保其生命周期
|
||||
auto session_dispatcher = std::make_shared<event_dispatcher>();
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// 存储会话信息
|
||||
session_dispatchers_[session_id] = session_dispatcher;
|
||||
}
|
||||
|
||||
// 创建并启动会话线程,使用共享指针和值捕获而不是引用捕获
|
||||
sse_threads_[session_id] = std::make_unique<std::thread>([this, session_id, session_uri, session_dispatcher]() {
|
||||
try {
|
||||
// 等待一段时间
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "data: " << session_uri << "\n\n";
|
||||
|
||||
// 使用会话特定的分发器发送事件
|
||||
session_dispatcher->send_event(ss.str());
|
||||
|
||||
// 设置定期心跳
|
||||
while (running_) {
|
||||
std::this_thread::sleep_for(std::chrono::seconds(30));
|
||||
|
||||
// 发送心跳事件
|
||||
std::stringstream heartbeat;
|
||||
heartbeat << "event: heartbeat\ndata: {}\n\n";
|
||||
session_dispatcher->send_event(heartbeat.str());
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
// 记录错误
|
||||
std::cerr << "SSE thread error for session " << session_id << ": " << e.what() << std::endl;
|
||||
}
|
||||
|
||||
// 线程结束时清理资源
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
session_dispatchers_.erase(session_id);
|
||||
}
|
||||
});
|
||||
|
||||
// 不再使用detach,而是在server析构函数中管理线程生命周期
|
||||
|
||||
// 设置分块内容提供者
|
||||
res.set_chunked_content_provider("text/event-stream", [session_dispatcher](size_t /* offset */, httplib::DataSink& sink) {
|
||||
// 使用会话特定的分发器等待事件
|
||||
session_dispatcher->wait_event(&sink);
|
||||
return true;
|
||||
});
|
||||
|
||||
// 注册会话特定的JSON-RPC端点
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
http_server_->Post(session_uri, [this](const httplib::Request& req, httplib::Response& res) {
|
||||
handle_jsonrpc(req, res);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) {
|
||||
// Set response headers
|
||||
res.set_header("Content-Type", "application/json");
|
||||
|
@ -330,16 +432,6 @@ json server::process_request(const request& req, const std::string& client_addre
|
|||
// The receiver MUST respond promptly with an empty response
|
||||
return response::create_success(req.id, {}).to_json();
|
||||
}
|
||||
|
||||
// Check if client is initialized
|
||||
if (!is_client_initialized(client_address)) {
|
||||
// Client not initialized
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::invalid_request,
|
||||
"Client not initialized"
|
||||
).to_json();
|
||||
}
|
||||
|
||||
// Look for registered method handler
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
@ -467,4 +559,20 @@ void server::set_client_initialized(const std::string& client_address, bool init
|
|||
client_initialized_[client_address] = initialized;
|
||||
}
|
||||
|
||||
// Generate a random session ID
|
||||
std::string server::generate_session_id() const {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(0, 15);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << std::hex;
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
ss << dis(gen);
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // namespace mcp
|
Loading…
Reference in New Issue