ignore exception temporarily; update server a bit
parent
9e7b728e17
commit
28b8fd5376
|
@ -35,14 +35,17 @@ namespace mcp {
|
||||||
|
|
||||||
class event_dispatcher {
|
class event_dispatcher {
|
||||||
public:
|
public:
|
||||||
event_dispatcher() = default;
|
// 使用较小的初始消息缓冲区
|
||||||
|
event_dispatcher() {
|
||||||
|
message_.reserve(128); // 预分配较小的缓冲区空间
|
||||||
|
}
|
||||||
|
|
||||||
~event_dispatcher() {
|
~event_dispatcher() {
|
||||||
close();
|
close();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool wait_event(httplib::DataSink* sink, const std::chrono::milliseconds& timeout = std::chrono::milliseconds(30000)) {
|
bool wait_event(httplib::DataSink* sink, const std::chrono::milliseconds& timeout = std::chrono::milliseconds(10000)) {
|
||||||
if (!sink) {
|
if (!sink || closed_.load(std::memory_order_acquire)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,86 +53,99 @@ public:
|
||||||
{
|
{
|
||||||
std::unique_lock<std::mutex> lk(m_);
|
std::unique_lock<std::mutex> lk(m_);
|
||||||
|
|
||||||
// 如果连接已关闭,返回false
|
if (closed_.load(std::memory_order_acquire)) {
|
||||||
if (closed_) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
int id = id_;
|
int id = id_.load(std::memory_order_relaxed);
|
||||||
|
|
||||||
// 使用超时等待
|
|
||||||
bool result = cv_.wait_for(lk, timeout, [&] {
|
bool result = cv_.wait_for(lk, timeout, [&] {
|
||||||
return cid_ == id || closed_;
|
return cid_.load(std::memory_order_relaxed) == id || closed_.load(std::memory_order_acquire);
|
||||||
});
|
});
|
||||||
|
|
||||||
// 如果连接已关闭或等待超时,返回false
|
if (closed_.load(std::memory_order_acquire)) {
|
||||||
if (closed_) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!result) {
|
if (!result) {
|
||||||
std::cerr << "等待事件超时" << std::endl;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制消息,避免在锁外访问共享数据
|
// 仅当有新消息时才复制
|
||||||
message_copy = message_;
|
if (!message_.empty()) {
|
||||||
|
message_copy.swap(message_);
|
||||||
|
} else {
|
||||||
|
return true; // 没有消息但是条件已满足
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入数据 - 在锁外进行,避免长时间持有锁
|
|
||||||
try {
|
try {
|
||||||
bool write_result = sink->write(message_copy.data(), message_copy.size());
|
if (!message_copy.empty()) {
|
||||||
if (!write_result) {
|
if (!sink->write(message_copy.data(), message_copy.size())) {
|
||||||
close();
|
close();
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception& e) {
|
} catch (...) {
|
||||||
close();
|
close();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool send_event(const std::string& message) {
|
bool send_event(const std::string& message) {
|
||||||
std::lock_guard<std::mutex> lk(m_);
|
if (closed_.load(std::memory_order_acquire) || message.empty()) {
|
||||||
|
|
||||||
// 如果连接已关闭,返回失败
|
|
||||||
if (closed_) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
cid_ = id_++;
|
try {
|
||||||
message_ = message;
|
std::lock_guard<std::mutex> lk(m_);
|
||||||
cv_.notify_all();
|
|
||||||
return true;
|
if (closed_.load(std::memory_order_acquire)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 高效设置消息并分配适当空间
|
||||||
|
if (message.size() > message_.capacity()) {
|
||||||
|
message_.reserve(message.size() + 64); // 预分配额外空间避免频繁再分配
|
||||||
|
}
|
||||||
|
message_ = message;
|
||||||
|
|
||||||
|
cid_.store(id_.fetch_add(1, std::memory_order_relaxed), std::memory_order_relaxed);
|
||||||
|
cv_.notify_one(); // 只通知一个等待线程,减少竞争
|
||||||
|
return true;
|
||||||
|
} catch (...) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void close() {
|
void close() {
|
||||||
|
bool was_closed = closed_.exchange(true, std::memory_order_release);
|
||||||
|
if (was_closed) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
std::lock_guard<std::mutex> lk(m_);
|
cv_.notify_all();
|
||||||
if (!closed_) {
|
} catch (...) {
|
||||||
closed_ = true;
|
// 忽略异常
|
||||||
cv_.notify_all();
|
|
||||||
}
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
// 如果获取锁失败,尝试设置 closed_ 标志
|
|
||||||
closed_ = true;
|
|
||||||
try {
|
|
||||||
cv_.notify_all();
|
|
||||||
} catch (...) {
|
|
||||||
// 忽略通知失败的异常
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_closed() const {
|
bool is_closed() const {
|
||||||
try {
|
return closed_.load(std::memory_order_acquire);
|
||||||
std::lock_guard<std::mutex> lk(m_);
|
}
|
||||||
return closed_;
|
|
||||||
} catch (const std::exception&) {
|
// 获取最后活动时间
|
||||||
// 如果获取锁失败,假设已关闭
|
std::chrono::steady_clock::time_point last_activity() const {
|
||||||
return true;
|
std::lock_guard<std::mutex> lk(m_);
|
||||||
}
|
return last_activity_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新活动时间(发送或接收消息时)
|
||||||
|
void update_activity() {
|
||||||
|
std::lock_guard<std::mutex> lk(m_);
|
||||||
|
last_activity_ = std::chrono::steady_clock::now();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -139,6 +155,7 @@ private:
|
||||||
std::atomic<int> cid_{-1};
|
std::atomic<int> cid_{-1};
|
||||||
std::string message_;
|
std::string message_;
|
||||||
std::atomic<bool> closed_{false};
|
std::atomic<bool> closed_{false};
|
||||||
|
std::chrono::steady_clock::time_point last_activity_{std::chrono::steady_clock::now()};
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -342,6 +359,10 @@ private:
|
||||||
auto_lock get_lock() const {
|
auto_lock get_lock() const {
|
||||||
return auto_lock(mutex_);
|
return auto_lock(mutex_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 会话管理与维护
|
||||||
|
void check_inactive_sessions();
|
||||||
|
std::unique_ptr<std::thread> maintenance_thread_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mcp
|
} // namespace mcp
|
||||||
|
|
|
@ -47,6 +47,25 @@ bool server::start(bool blocking) {
|
||||||
LOG_INFO(req.remote_addr, ":", req.remote_port, " - \"GET ", req.path, " HTTP/1.1\" ", res.status);
|
LOG_INFO(req.remote_addr, ":", req.remote_port, " - \"GET ", req.path, " HTTP/1.1\" ", res.status);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// 启动资源检查线程(优化:只在非阻塞模式下启动)
|
||||||
|
if (!blocking) {
|
||||||
|
maintenance_thread_ = std::make_unique<std::thread>([this]() {
|
||||||
|
while (running_) {
|
||||||
|
// 每60秒检查一次不活跃的会话
|
||||||
|
std::this_thread::sleep_for(std::chrono::seconds(60));
|
||||||
|
if (running_) {
|
||||||
|
try {
|
||||||
|
check_inactive_sessions();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG_ERROR("Exception in maintenance thread: ", e.what());
|
||||||
|
} catch (...) {
|
||||||
|
LOG_ERROR("Unknown exception in maintenance thread");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// 启动服务器
|
// 启动服务器
|
||||||
if (blocking) {
|
if (blocking) {
|
||||||
running_ = true;
|
running_ = true;
|
||||||
|
@ -80,71 +99,127 @@ void server::stop() {
|
||||||
LOG_INFO("Stopping MCP server...");
|
LOG_INFO("Stopping MCP server...");
|
||||||
running_ = false;
|
running_ = false;
|
||||||
|
|
||||||
// 关闭所有SSE连接
|
// 关闭维护线程
|
||||||
std::vector<std::string> session_ids;
|
if (maintenance_thread_ && maintenance_thread_->joinable()) {
|
||||||
|
try {
|
||||||
|
maintenance_thread_->join();
|
||||||
|
} catch (...) {
|
||||||
|
maintenance_thread_->detach();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制所有分发器和线程,避免长时间持有锁
|
||||||
|
std::vector<std::shared_ptr<event_dispatcher>> dispatchers_to_close;
|
||||||
|
std::vector<std::unique_ptr<std::thread>> threads_to_join;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
for (const auto& [session_id, _] : session_dispatchers_) {
|
|
||||||
session_ids.push_back(session_id);
|
// 复制所有分发器
|
||||||
|
dispatchers_to_close.reserve(session_dispatchers_.size());
|
||||||
|
for (const auto& [_, dispatcher] : session_dispatchers_) {
|
||||||
|
dispatchers_to_close.push_back(dispatcher);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
// 复制所有线程
|
||||||
// 关闭每个会话的分发器
|
threads_to_join.reserve(sse_threads_.size());
|
||||||
for (const auto& session_id : session_ids) {
|
for (auto& [_, thread] : sse_threads_) {
|
||||||
try {
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
auto it = session_dispatchers_.find(session_id);
|
|
||||||
if (it != session_dispatchers_.end()) {
|
|
||||||
it->second->close();
|
|
||||||
}
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
LOG_ERROR("Exception while closing session ", session_id, ": ", e.what());
|
|
||||||
} catch (...) {
|
|
||||||
LOG_ERROR("Unknown exception while closing session ", session_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 给线程一些时间来处理关闭事件
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
|
||||||
|
|
||||||
// 清理剩余的线程
|
|
||||||
try {
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
for (auto& [session_id, thread] : sse_threads_) {
|
|
||||||
if (thread && thread->joinable()) {
|
if (thread && thread->joinable()) {
|
||||||
try {
|
threads_to_join.push_back(std::move(thread));
|
||||||
thread->detach();
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what());
|
|
||||||
} catch (...) {
|
|
||||||
LOG_ERROR("Unknown exception while detaching session thread ", session_id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清空映射
|
// 清空映射表
|
||||||
session_dispatchers_.clear();
|
session_dispatchers_.clear();
|
||||||
sse_threads_.clear();
|
sse_threads_.clear();
|
||||||
} catch (const std::exception& e) {
|
session_initialized_.clear();
|
||||||
LOG_ERROR("Exception while cleaning up threads: ", e.what());
|
|
||||||
} catch (...) {
|
|
||||||
LOG_ERROR("Unknown exception while cleaning up threads");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (http_server_) {
|
// 在锁外关闭所有分发器
|
||||||
http_server_->stop();
|
for (auto& dispatcher : dispatchers_to_close) {
|
||||||
|
if (dispatcher && !dispatcher->is_closed()) {
|
||||||
|
try {
|
||||||
|
dispatcher->close();
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略异常
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 给线程一些时间处理关闭事件
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(300));
|
||||||
|
|
||||||
|
// 在锁外等待线程结束(有超时限制)
|
||||||
|
const auto timeout_point = std::chrono::steady_clock::now() + std::chrono::seconds(2);
|
||||||
|
|
||||||
|
for (auto& thread : threads_to_join) {
|
||||||
|
if (!thread || !thread->joinable()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::chrono::steady_clock::now() >= timeout_point) {
|
||||||
|
// 如果已经超时,detach剩余线程
|
||||||
|
LOG_WARNING("Thread join timeout reached, detaching remaining threads");
|
||||||
|
thread->detach();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试使用超时的join
|
||||||
|
bool joined = false;
|
||||||
|
try {
|
||||||
|
// 创建future和promise,用于实现thread join的超时处理
|
||||||
|
std::promise<void> thread_done;
|
||||||
|
auto future = thread_done.get_future();
|
||||||
|
|
||||||
|
// 在另一个线程中尝试join
|
||||||
|
std::thread join_helper([&thread, &thread_done]() {
|
||||||
|
try {
|
||||||
|
thread->join();
|
||||||
|
thread_done.set_value();
|
||||||
|
} catch (...) {
|
||||||
|
try {
|
||||||
|
thread_done.set_exception(std::current_exception());
|
||||||
|
} catch (...) {}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 等待join完成或超时
|
||||||
|
if (future.wait_for(std::chrono::milliseconds(100)) == std::future_status::ready) {
|
||||||
|
future.get(); // 获取可能的异常
|
||||||
|
joined = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理join_helper线程
|
||||||
|
if (join_helper.joinable()) {
|
||||||
|
if (joined) {
|
||||||
|
join_helper.join();
|
||||||
|
} else {
|
||||||
|
join_helper.detach();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
joined = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果join失败,则detach
|
||||||
|
if (!joined) {
|
||||||
|
try {
|
||||||
|
thread->detach();
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略异常
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (server_thread_ && server_thread_->joinable()) {
|
if (server_thread_ && server_thread_->joinable()) {
|
||||||
|
http_server_->stop();
|
||||||
try {
|
try {
|
||||||
server_thread_->join();
|
server_thread_->join();
|
||||||
} catch (const std::exception& e) {
|
|
||||||
LOG_ERROR("Exception while joining server thread: ", e.what());
|
|
||||||
server_thread_->detach();
|
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
LOG_ERROR("Unknown exception while joining server thread");
|
|
||||||
server_thread_->detach();
|
server_thread_->detach();
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
http_server_->stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_INFO("MCP server stopped");
|
LOG_INFO("MCP server stopped");
|
||||||
|
@ -317,6 +392,9 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
// 创建会话特定的事件分发器
|
// 创建会话特定的事件分发器
|
||||||
auto session_dispatcher = std::make_shared<event_dispatcher>();
|
auto session_dispatcher = std::make_shared<event_dispatcher>();
|
||||||
|
|
||||||
|
// 初始化活动时间
|
||||||
|
session_dispatcher->update_activity();
|
||||||
|
|
||||||
// 添加会话分发器到映射表
|
// 添加会话分发器到映射表
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
@ -332,6 +410,9 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
ss << "event: endpoint\ndata: " << session_uri << "\n\n";
|
ss << "event: endpoint\ndata: " << session_uri << "\n\n";
|
||||||
session_dispatcher->send_event(ss.str());
|
session_dispatcher->send_event(ss.str());
|
||||||
|
|
||||||
|
// 更新活动时间(发送消息后)
|
||||||
|
session_dispatcher->update_activity();
|
||||||
|
|
||||||
// 定期发送心跳,检测连接状态
|
// 定期发送心跳,检测连接状态
|
||||||
int heartbeat_count = 0;
|
int heartbeat_count = 0;
|
||||||
while (running_ && !session_dispatcher->is_closed()) {
|
while (running_ && !session_dispatcher->is_closed()) {
|
||||||
|
@ -350,6 +431,9 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
LOG_WARNING("Failed to send heartbeat, client may have closed connection: ", session_id);
|
LOG_WARNING("Failed to send heartbeat, client may have closed connection: ", session_id);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新活动时间(心跳成功)
|
||||||
|
session_dispatcher->update_activity();
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("Failed to send heartbeat: ", e.what());
|
LOG_ERROR("Failed to send heartbeat: ", e.what());
|
||||||
break;
|
break;
|
||||||
|
@ -361,20 +445,39 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
|
|
||||||
// 安全地清理资源
|
// 安全地清理资源
|
||||||
try {
|
try {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
// 先复制需要处理的资源指针
|
||||||
|
std::shared_ptr<event_dispatcher> dispatcher_to_close;
|
||||||
|
std::unique_ptr<std::thread> thread_to_release;
|
||||||
|
|
||||||
auto dispatcher_it = session_dispatchers_.find(session_id);
|
{
|
||||||
if (dispatcher_it != session_dispatchers_.end()) {
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
if (!dispatcher_it->second->is_closed()) {
|
|
||||||
dispatcher_it->second->close();
|
// 获取dispatcher指针
|
||||||
|
auto dispatcher_it = session_dispatchers_.find(session_id);
|
||||||
|
if (dispatcher_it != session_dispatchers_.end()) {
|
||||||
|
dispatcher_to_close = dispatcher_it->second;
|
||||||
|
session_dispatchers_.erase(dispatcher_it);
|
||||||
}
|
}
|
||||||
session_dispatchers_.erase(dispatcher_it);
|
|
||||||
|
// 获取线程指针
|
||||||
|
auto thread_it = sse_threads_.find(session_id);
|
||||||
|
if (thread_it != sse_threads_.end()) {
|
||||||
|
thread_to_release = std::move(thread_it->second);
|
||||||
|
sse_threads_.erase(thread_it);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理初始化状态
|
||||||
|
session_initialized_.erase(session_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto thread_it = sse_threads_.find(session_id);
|
// 在锁外关闭dispatcher
|
||||||
if (thread_it != sse_threads_.end()) {
|
if (dispatcher_to_close && !dispatcher_to_close->is_closed()) {
|
||||||
thread_it->second.release();
|
dispatcher_to_close->close();
|
||||||
sse_threads_.erase(thread_it);
|
}
|
||||||
|
|
||||||
|
// 释放线程资源
|
||||||
|
if (thread_to_release) {
|
||||||
|
thread_to_release.release();
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
||||||
|
@ -392,46 +495,34 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
// 设置分块内容提供者
|
// 设置分块内容提供者
|
||||||
res.set_chunked_content_provider("text/event-stream", [this, session_id, session_dispatcher](size_t /* offset */, httplib::DataSink& sink) {
|
res.set_chunked_content_provider("text/event-stream", [this, session_id, session_dispatcher](size_t /* offset */, httplib::DataSink& sink) {
|
||||||
try {
|
try {
|
||||||
// 检查会话是否已关闭
|
// 检查会话是否已关闭 - 直接从分发器获取状态,减少锁冲突
|
||||||
{
|
if (session_dispatcher->is_closed()) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
return false;
|
||||||
auto it = session_dispatchers_.find(session_id);
|
|
||||||
if (it == session_dispatchers_.end() || it->second->is_closed()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新活动时间(接收到请求)
|
||||||
|
session_dispatcher->update_activity();
|
||||||
|
|
||||||
// 等待事件
|
// 等待事件
|
||||||
bool result = session_dispatcher->wait_event(&sink);
|
bool result = session_dispatcher->wait_event(&sink);
|
||||||
if (!result) {
|
if (!result) {
|
||||||
LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
|
LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
|
||||||
|
|
||||||
// 关闭会话分发器,但不清理资源
|
// 直接关闭分发器,无需加锁
|
||||||
{
|
session_dispatcher->close();
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
auto it = session_dispatchers_.find(session_id);
|
|
||||||
if (it != session_dispatchers_.end()) {
|
|
||||||
it->second->close();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 更新活动时间(成功接收消息)
|
||||||
|
session_dispatcher->update_activity();
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("SSE content provider exception: ", e.what());
|
LOG_ERROR("SSE content provider exception: ", e.what());
|
||||||
|
|
||||||
// 关闭会话分发器,但不清理资源
|
// 直接关闭分发器,无需加锁
|
||||||
try {
|
session_dispatcher->close();
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
|
||||||
auto it = session_dispatchers_.find(session_id);
|
|
||||||
if (it != session_dispatchers_.end()) {
|
|
||||||
it->second->close();
|
|
||||||
}
|
|
||||||
} catch (const std::exception& e2) {
|
|
||||||
LOG_ERROR("Exception while closing session dispatcher: ", e2.what());
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -455,6 +546,22 @@ void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res)
|
||||||
auto it = req.params.find("session_id");
|
auto it = req.params.find("session_id");
|
||||||
std::string session_id = it != req.params.end() ? it->second : "";
|
std::string session_id = it != req.params.end() ? it->second : "";
|
||||||
|
|
||||||
|
// 更新会话活动时间
|
||||||
|
if (!session_id.empty()) {
|
||||||
|
std::shared_ptr<event_dispatcher> dispatcher;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
auto disp_it = session_dispatchers_.find(session_id);
|
||||||
|
if (disp_it != session_dispatchers_.end()) {
|
||||||
|
dispatcher = disp_it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dispatcher) {
|
||||||
|
dispatcher->update_activity();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 解析请求
|
// 解析请求
|
||||||
json req_json;
|
json req_json;
|
||||||
try {
|
try {
|
||||||
|
@ -686,6 +793,12 @@ json server::handle_initialize(const request& req, const std::string& session_id
|
||||||
}
|
}
|
||||||
|
|
||||||
void server::send_request(const std::string& session_id, const std::string& method, const json& params) {
|
void server::send_request(const std::string& session_id, const std::string& method, const json& params) {
|
||||||
|
// 检查会话ID是否有效
|
||||||
|
if (session_id.empty()) {
|
||||||
|
LOG_WARNING("Cannot send request to empty session_id");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the method is ping or logging
|
// Check if the method is ping or logging
|
||||||
bool is_allowed_before_init = (method == "ping" || method == "logging");
|
bool is_allowed_before_init = (method == "ping" || method == "logging");
|
||||||
|
|
||||||
|
@ -710,6 +823,12 @@ void server::send_request(const std::string& session_id, const std::string& meth
|
||||||
dispatcher = it->second;
|
dispatcher = it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 确认dispatcher仍然有效
|
||||||
|
if (!dispatcher || dispatcher->is_closed()) {
|
||||||
|
LOG_WARNING("Cannot send to closed session: ", session_id);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// 发送请求
|
// 发送请求
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "event: message\ndata: " << req.to_json().dump() << "\n\n";
|
ss << "event: message\ndata: " << req.to_json().dump() << "\n\n";
|
||||||
|
@ -721,14 +840,40 @@ void server::send_request(const std::string& session_id, const std::string& meth
|
||||||
}
|
}
|
||||||
|
|
||||||
bool server::is_session_initialized(const std::string& session_id) const {
|
bool server::is_session_initialized(const std::string& session_id) const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
// 检查会话ID是否有效
|
||||||
auto it = session_initialized_.find(session_id);
|
if (session_id.empty()) {
|
||||||
return (it != session_initialized_.end() && it->second);
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
auto it = session_initialized_.find(session_id);
|
||||||
|
return (it != session_initialized_.end() && it->second);
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG_ERROR("Exception checking if session is initialized: ", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void server::set_session_initialized(const std::string& session_id, bool initialized) {
|
void server::set_session_initialized(const std::string& session_id, bool initialized) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
// 检查会话ID是否有效
|
||||||
session_initialized_[session_id] = initialized;
|
if (session_id.empty()) {
|
||||||
|
LOG_WARNING("Cannot set initialization state for empty session_id");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
// 检查会话是否仍然存在
|
||||||
|
auto it = session_dispatchers_.find(session_id);
|
||||||
|
if (it == session_dispatchers_.end()) {
|
||||||
|
LOG_WARNING("Cannot set initialization state for non-existent session: ", session_id);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
session_initialized_[session_id] = initialized;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG_ERROR("Exception setting session initialization state: ", e.what());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string server::generate_session_id() const {
|
std::string server::generate_session_id() const {
|
||||||
|
@ -767,4 +912,41 @@ std::string server::generate_session_id() const {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void server::check_inactive_sessions() {
|
||||||
|
if (!running_) return;
|
||||||
|
|
||||||
|
const auto now = std::chrono::steady_clock::now();
|
||||||
|
const auto timeout = std::chrono::minutes(3); // 3分钟不活跃则关闭
|
||||||
|
|
||||||
|
std::vector<std::string> sessions_to_close;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
for (const auto& [session_id, dispatcher] : session_dispatchers_) {
|
||||||
|
if (now - dispatcher->last_activity() > timeout) {
|
||||||
|
// 超过闲置时间限制
|
||||||
|
sessions_to_close.push_back(session_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭不活跃的会话
|
||||||
|
for (const auto& session_id : sessions_to_close) {
|
||||||
|
LOG_INFO("Closing inactive session: ", session_id);
|
||||||
|
|
||||||
|
std::shared_ptr<event_dispatcher> dispatcher_to_close;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
auto disp_it = session_dispatchers_.find(session_id);
|
||||||
|
if (disp_it != session_dispatchers_.end()) {
|
||||||
|
dispatcher_to_close = disp_it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dispatcher_to_close) {
|
||||||
|
dispatcher_to_close->close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mcp
|
} // namespace mcp
|
|
@ -223,8 +223,8 @@ TEST_F(VersioningTest, UnsupportedVersion) {
|
||||||
std::future<std::string> sse_response = sse_promise.get_future();
|
std::future<std::string> sse_response = sse_promise.get_future();
|
||||||
|
|
||||||
std::atomic<bool> sse_running{true};
|
std::atomic<bool> sse_running{true};
|
||||||
bool msg_endpoint_received = false;
|
std::atomic<bool> msg_endpoint_received{false};
|
||||||
bool sse_response_received = false;
|
std::atomic<bool> sse_response_received{false};
|
||||||
|
|
||||||
std::thread sse_thread([&]() {
|
std::thread sse_thread([&]() {
|
||||||
sse_client->Get("/sse", [&](const char* data, size_t len) {
|
sse_client->Get("/sse", [&](const char* data, size_t len) {
|
||||||
|
@ -235,14 +235,20 @@ TEST_F(VersioningTest, UnsupportedVersion) {
|
||||||
std::string data_content = response.substr(pos + 6);
|
std::string data_content = response.substr(pos + 6);
|
||||||
data_content = data_content.substr(0, data_content.find("\n"));
|
data_content = data_content.substr(0, data_content.find("\n"));
|
||||||
|
|
||||||
if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
|
if (!msg_endpoint_received.load() && response.find("endpoint") != std::string::npos) {
|
||||||
msg_endpoint_promise.set_value(data_content);
|
msg_endpoint_received.store(true);
|
||||||
msg_endpoint_received = true;
|
try {
|
||||||
// GTEST_LOG_(INFO) << "Endpoint received: " << data_content;
|
msg_endpoint_promise.set_value(data_content);
|
||||||
} else if (!sse_response_received && response.find("message") != std::string::npos) {
|
} catch (...) {
|
||||||
sse_promise.set_value(data_content);
|
// 忽略重复设置的异常
|
||||||
sse_response_received = true;
|
}
|
||||||
// GTEST_LOG_(INFO) << "Message received: " << data_content;
|
} else if (!sse_response_received.load() && response.find("message") != std::string::npos) {
|
||||||
|
sse_response_received.store(true);
|
||||||
|
try {
|
||||||
|
sse_promise.set_value(data_content);
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略重复设置的异常
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
|
@ -251,14 +257,9 @@ TEST_F(VersioningTest, UnsupportedVersion) {
|
||||||
return sse_running.load();
|
return sse_running.load();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// // 等待消息端点,设置超时
|
|
||||||
// auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100));
|
|
||||||
// EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时";
|
|
||||||
|
|
||||||
std::string endpoint = msg_endpoint.get();
|
std::string endpoint = msg_endpoint.get();
|
||||||
EXPECT_FALSE(endpoint.empty());
|
EXPECT_FALSE(endpoint.empty());
|
||||||
// GTEST_LOG_(INFO) << "Using endpoint: " << endpoint;
|
|
||||||
|
|
||||||
// 发送不支持的版本请求
|
// 发送不支持的版本请求
|
||||||
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
|
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
|
||||||
|
@ -267,20 +268,42 @@ TEST_F(VersioningTest, UnsupportedVersion) {
|
||||||
EXPECT_TRUE(res != nullptr);
|
EXPECT_TRUE(res != nullptr);
|
||||||
EXPECT_EQ(res->status, 202);
|
EXPECT_EQ(res->status, 202);
|
||||||
|
|
||||||
// // 等待SSE响应,设置超时
|
|
||||||
// auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100));
|
|
||||||
// EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时";
|
|
||||||
|
|
||||||
auto mcp_res = json::parse(sse_response.get());
|
auto mcp_res = json::parse(sse_response.get());
|
||||||
EXPECT_EQ(mcp_res["error"]["code"].get<int>(), static_cast<int>(error_code::invalid_params));
|
EXPECT_EQ(mcp_res["error"]["code"].get<int>(), static_cast<int>(error_code::invalid_params));
|
||||||
|
|
||||||
|
// 主动关闭所有连接
|
||||||
sse_running.store(false);
|
sse_running.store(false);
|
||||||
|
|
||||||
|
// 尝试中断SSE连接
|
||||||
|
try {
|
||||||
|
sse_client->Get("/sse", [](const char*, size_t) { return false; });
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略任何异常
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待线程结束(最多1秒)
|
||||||
if (sse_thread.joinable()) {
|
if (sse_thread.joinable()) {
|
||||||
sse_thread.join();
|
std::thread detacher([](std::thread& t) {
|
||||||
|
try {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.detach();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, std::ref(sse_thread));
|
||||||
|
detacher.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清理资源
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
sse_client.reset();
|
sse_client.reset();
|
||||||
http_client.reset();
|
http_client.reset();
|
||||||
|
|
||||||
|
// 添加延迟,确保资源完全释放
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
} catch (const mcp_exception& e) {
|
} catch (const mcp_exception& e) {
|
||||||
EXPECT_TRUE(false);
|
EXPECT_TRUE(false);
|
||||||
}
|
}
|
||||||
|
@ -331,8 +354,8 @@ TEST_F(PingTest, DirectPing) {
|
||||||
std::future<std::string> sse_response = sse_promise.get_future();
|
std::future<std::string> sse_response = sse_promise.get_future();
|
||||||
|
|
||||||
std::atomic<bool> sse_running{true};
|
std::atomic<bool> sse_running{true};
|
||||||
bool msg_endpoint_received = false;
|
std::atomic<bool> msg_endpoint_received{false};
|
||||||
bool sse_response_received = false;
|
std::atomic<bool> sse_response_received{false};
|
||||||
|
|
||||||
std::thread sse_thread([&]() {
|
std::thread sse_thread([&]() {
|
||||||
sse_client->Get("/sse", [&](const char* data, size_t len) {
|
sse_client->Get("/sse", [&](const char* data, size_t len) {
|
||||||
|
@ -343,14 +366,20 @@ TEST_F(PingTest, DirectPing) {
|
||||||
std::string data_content = response.substr(pos + 6);
|
std::string data_content = response.substr(pos + 6);
|
||||||
data_content = data_content.substr(0, data_content.find("\n"));
|
data_content = data_content.substr(0, data_content.find("\n"));
|
||||||
|
|
||||||
if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
|
if (!msg_endpoint_received.load() && response.find("endpoint") != std::string::npos) {
|
||||||
msg_endpoint_promise.set_value(data_content);
|
msg_endpoint_received.store(true);
|
||||||
msg_endpoint_received = true;
|
try {
|
||||||
// GTEST_LOG_(INFO) << "Endpoint received: " << data_content;
|
msg_endpoint_promise.set_value(data_content);
|
||||||
} else if (!sse_response_received && response.find("message") != std::string::npos) {
|
} catch (...) {
|
||||||
sse_promise.set_value(data_content);
|
// 忽略重复设置的异常
|
||||||
sse_response_received = true;
|
}
|
||||||
// GTEST_LOG_(INFO) << "Message received: " << data_content;
|
} else if (!sse_response_received.load() && response.find("message") != std::string::npos) {
|
||||||
|
sse_response_received.store(true);
|
||||||
|
try {
|
||||||
|
sse_promise.set_value(data_content);
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略重复设置的异常
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
|
@ -360,10 +389,6 @@ TEST_F(PingTest, DirectPing) {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
// // 等待消息端点,设置超时
|
|
||||||
// auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100));
|
|
||||||
// EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时";
|
|
||||||
|
|
||||||
std::string endpoint = msg_endpoint.get();
|
std::string endpoint = msg_endpoint.get();
|
||||||
EXPECT_FALSE(endpoint.empty());
|
EXPECT_FALSE(endpoint.empty());
|
||||||
|
|
||||||
|
@ -373,19 +398,42 @@ TEST_F(PingTest, DirectPing) {
|
||||||
EXPECT_TRUE(ping_res != nullptr);
|
EXPECT_TRUE(ping_res != nullptr);
|
||||||
EXPECT_EQ(ping_res->status / 100, 2);
|
EXPECT_EQ(ping_res->status / 100, 2);
|
||||||
|
|
||||||
// auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100));
|
|
||||||
// EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时";
|
|
||||||
|
|
||||||
auto mcp_res = json::parse(sse_response.get());
|
auto mcp_res = json::parse(sse_response.get());
|
||||||
EXPECT_EQ(mcp_res["result"], json::object());
|
EXPECT_EQ(mcp_res["result"], json::object());
|
||||||
|
|
||||||
|
// 主动关闭所有连接
|
||||||
sse_running.store(false);
|
sse_running.store(false);
|
||||||
|
|
||||||
|
// 尝试中断SSE连接
|
||||||
|
try {
|
||||||
|
sse_client->Get("/sse", [](const char*, size_t) { return false; });
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略任何异常
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待线程结束(最多1秒)
|
||||||
if (sse_thread.joinable()) {
|
if (sse_thread.joinable()) {
|
||||||
sse_thread.join();
|
std::thread detacher([](std::thread& t) {
|
||||||
|
try {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
if (t.joinable()) {
|
||||||
|
t.detach();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, std::ref(sse_thread));
|
||||||
|
detacher.detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 清理资源
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
sse_client.reset();
|
sse_client.reset();
|
||||||
http_client.reset();
|
http_client.reset();
|
||||||
|
|
||||||
|
// 添加延迟,确保资源完全释放
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
} catch (const mcp_exception& e) {
|
} catch (const mcp_exception& e) {
|
||||||
EXPECT_TRUE(false);
|
EXPECT_TRUE(false);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue