Exception while cleaning up session resource: mutex lock failed: Invalid argument

main
hkr04 2025-03-13 19:37:31 +08:00
parent 8e11b5dc6d
commit 9e7b728e17
3 changed files with 193 additions and 51 deletions

View File

@ -28,6 +28,7 @@
#include <chrono> #include <chrono>
#include <condition_variable> #include <condition_variable>
#include <future> #include <future>
#include <atomic>
namespace mcp { namespace mcp {
@ -104,16 +105,31 @@ public:
} }
void close() { void close() {
std::lock_guard<std::mutex> lk(m_); try {
if (!closed_) { std::lock_guard<std::mutex> lk(m_);
if (!closed_) {
closed_ = true;
cv_.notify_all();
}
} catch (const std::exception& e) {
// 如果获取锁失败,尝试设置 closed_ 标志
closed_ = true; closed_ = true;
cv_.notify_all(); try {
cv_.notify_all();
} catch (...) {
// 忽略通知失败的异常
}
} }
} }
bool is_closed() const { bool is_closed() const {
std::lock_guard<std::mutex> lk(m_); try {
return closed_; std::lock_guard<std::mutex> lk(m_);
return closed_;
} catch (const std::exception&) {
// 如果获取锁失败,假设已关闭
return true;
}
} }
private: private:
@ -122,7 +138,7 @@ private:
std::atomic<int> id_{0}; std::atomic<int> id_{0};
std::atomic<int> cid_{-1}; std::atomic<int> cid_{-1};
std::string message_; std::string message_;
bool closed_ = false; std::atomic<bool> closed_{false};
}; };
/** /**

View File

@ -99,11 +99,16 @@ void server::stop() {
} }
} catch (const std::exception& e) { } catch (const std::exception& e) {
LOG_ERROR("Exception while closing session ", session_id, ": ", e.what()); LOG_ERROR("Exception while closing session ", session_id, ": ", e.what());
} catch (...) {
LOG_ERROR("Unknown exception while closing session ", session_id);
} }
} }
// 给线程一些时间来处理关闭事件
std::this_thread::sleep_for(std::chrono::milliseconds(50));
// 清理剩余的线程 // 清理剩余的线程
{ try {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
for (auto& [session_id, thread] : sse_threads_) { for (auto& [session_id, thread] : sse_threads_) {
if (thread && thread->joinable()) { if (thread && thread->joinable()) {
@ -111,6 +116,8 @@ void server::stop() {
thread->detach(); thread->detach();
} catch (const std::exception& e) { } catch (const std::exception& e) {
LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what()); LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what());
} catch (...) {
LOG_ERROR("Unknown exception while detaching session thread ", session_id);
} }
} }
} }
@ -118,6 +125,10 @@ void server::stop() {
// 清空映射 // 清空映射
session_dispatchers_.clear(); session_dispatchers_.clear();
sse_threads_.clear(); sse_threads_.clear();
} catch (const std::exception& e) {
LOG_ERROR("Exception while cleaning up threads: ", e.what());
} catch (...) {
LOG_ERROR("Unknown exception while cleaning up threads");
} }
if (http_server_) { if (http_server_) {
@ -125,7 +136,15 @@ void server::stop() {
} }
if (server_thread_ && server_thread_->joinable()) { if (server_thread_ && server_thread_->joinable()) {
server_thread_->join(); try {
server_thread_->join();
} catch (const std::exception& e) {
LOG_ERROR("Exception while joining server thread: ", e.what());
server_thread_->detach();
} catch (...) {
LOG_ERROR("Unknown exception while joining server thread");
server_thread_->detach();
}
} }
LOG_INFO("MCP server stopped"); LOG_INFO("MCP server stopped");
@ -359,6 +378,8 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
} }
} catch (const std::exception& e) { } catch (const std::exception& e) {
LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what()); LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
} catch (...) {
LOG_ERROR("Unknown exception while cleaning up session resources: ", session_id);
} }
}); });

View File

@ -183,44 +183,27 @@ protected:
}; };
server_->set_capabilities(server_capabilities); server_->set_capabilities(server_capabilities);
// 注册初始化方法处理器,检查版本
server_->register_method("initialize", [this, server_capabilities](const json& params) -> json {
// 检查协议版本
std::string requested_version = params["protocolVersion"];
if (requested_version != MCP_VERSION) {
throw mcp_exception(error_code::invalid_params, "Unsupported protocol version");
}
return {
{"protocolVersion", MCP_VERSION},
{"capabilities", server_capabilities},
{"serverInfo", {
{"name", "TestServer"},
{"version", "1.0.0"}
}}
};
});
// 启动服务器(非阻塞模式) // 启动服务器(非阻塞模式)
server_->start(false); server_->start(false);
client_ = std::make_unique<client>("localhost", 8081);
} }
void TearDown() override { void TearDown() override {
// 清理测试环境 // 清理测试环境
server_->stop(); server_->stop();
server_.reset(); server_.reset();
client_.reset();
} }
std::unique_ptr<server> server_; std::unique_ptr<server> server_;
std::unique_ptr<client> client_;
}; };
// 测试支持的版本 // 测试支持的版本
TEST_F(VersioningTest, SupportedVersion) { TEST_F(VersioningTest, SupportedVersion) {
// 创建使用正确版本的客户端
client client_correct("localhost", 8081);
// 执行初始化 // 执行初始化
bool init_result = client_correct.initialize("TestClient", "1.0.0"); bool init_result = client_->initialize("TestClient", "1.0.0");
// 验证初始化结果 // 验证初始化结果
EXPECT_TRUE(init_result); EXPECT_TRUE(init_result);
@ -228,20 +211,76 @@ TEST_F(VersioningTest, SupportedVersion) {
// 测试不支持的版本 // 测试不支持的版本
TEST_F(VersioningTest, UnsupportedVersion) { TEST_F(VersioningTest, UnsupportedVersion) {
// Use httplib::Client to send a request with an unsupported version
// Note: Open SSE connection first
httplib::Client client("localhost", 8081);
auto sse_response = client.Get("/sse");
// EXPECT_EQ(sse_response->status, 200);
std::string msg_endpoint = sse_response->body;
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
auto res = client.Post(msg_endpoint.c_str(), req.dump(), "application/json");
EXPECT_EQ(res->status, 400);
try { try {
auto mcp_res = response::from_json(json::parse(res->body)); // 使用 httplib::Client 发送不支持的版本请求
EXPECT_EQ(mcp_res.error["code"], error_code::invalid_params); std::unique_ptr<httplib::Client> sse_client = std::make_unique<httplib::Client>("localhost", 8081);
std::unique_ptr<httplib::Client> http_client = std::make_unique<httplib::Client>("localhost", 8081);
// 打开 SSE 连接
std::promise<std::string> msg_endpoint_promise;
std::promise<std::string> sse_promise;
std::future<std::string> msg_endpoint = msg_endpoint_promise.get_future();
std::future<std::string> sse_response = sse_promise.get_future();
std::atomic<bool> sse_running{true};
bool msg_endpoint_received = false;
bool sse_response_received = false;
std::thread sse_thread([&]() {
sse_client->Get("/sse", [&](const char* data, size_t len) {
try {
std::string response(data, len);
size_t pos = response.find("data: ");
if (pos != std::string::npos) {
std::string data_content = response.substr(pos + 6);
data_content = data_content.substr(0, data_content.find("\n"));
if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
msg_endpoint_promise.set_value(data_content);
msg_endpoint_received = true;
// GTEST_LOG_(INFO) << "Endpoint received: " << data_content;
} else if (!sse_response_received && response.find("message") != std::string::npos) {
sse_promise.set_value(data_content);
sse_response_received = true;
// GTEST_LOG_(INFO) << "Message received: " << data_content;
}
}
} catch (const std::exception& e) {
GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
}
return sse_running.load();
});
});
// // 等待消息端点,设置超时
// auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100));
// EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时";
std::string endpoint = msg_endpoint.get();
EXPECT_FALSE(endpoint.empty());
// GTEST_LOG_(INFO) << "Using endpoint: " << endpoint;
// 发送不支持的版本请求
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
auto res = http_client->Post(endpoint.c_str(), req.dump(), "application/json");
EXPECT_TRUE(res != nullptr);
EXPECT_EQ(res->status, 202);
// // 等待SSE响应设置超时
// auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100));
// EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时";
auto mcp_res = json::parse(sse_response.get());
EXPECT_EQ(mcp_res["error"]["code"].get<int>(), static_cast<int>(error_code::invalid_params));
sse_running.store(false);
if (sse_thread.joinable()) {
sse_thread.join();
}
sse_client.reset();
http_client.reset();
} catch (const mcp_exception& e) { } catch (const mcp_exception& e) {
EXPECT_TRUE(false); EXPECT_TRUE(false);
} }
@ -254,11 +293,6 @@ protected:
// 设置测试环境 // 设置测试环境
server_ = std::make_unique<server>("localhost", 8082); server_ = std::make_unique<server>("localhost", 8082);
// 注册ping方法处理器
server_->register_method("ping", [](const json& params) -> json {
return json::object(); // 返回空对象
});
// 启动服务器(非阻塞模式) // 启动服务器(非阻塞模式)
server_->start(false); server_->start(false);
@ -279,13 +313,84 @@ protected:
// 测试Ping请求 // 测试Ping请求
TEST_F(PingTest, PingRequest) { TEST_F(PingTest, PingRequest) {
// 发送ping请求 client_->initialize("TestClient", "1.0.0");
bool ping_result = client_->ping(); bool ping_result = client_->ping();
// 验证ping结果
EXPECT_TRUE(ping_result); EXPECT_TRUE(ping_result);
} }
TEST_F(PingTest, DirectPing) {
try {
// 使用 httplib::Client 发送不支持的版本请求
std::unique_ptr<httplib::Client> sse_client = std::make_unique<httplib::Client>("localhost", 8082);
std::unique_ptr<httplib::Client> http_client = std::make_unique<httplib::Client>("localhost", 8082);
// 打开 SSE 连接
std::promise<std::string> msg_endpoint_promise;
std::promise<std::string> sse_promise;
std::future<std::string> msg_endpoint = msg_endpoint_promise.get_future();
std::future<std::string> sse_response = sse_promise.get_future();
std::atomic<bool> sse_running{true};
bool msg_endpoint_received = false;
bool sse_response_received = false;
std::thread sse_thread([&]() {
sse_client->Get("/sse", [&](const char* data, size_t len) {
try {
std::string response(data, len);
size_t pos = response.find("data: ");
if (pos != std::string::npos) {
std::string data_content = response.substr(pos + 6);
data_content = data_content.substr(0, data_content.find("\n"));
if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
msg_endpoint_promise.set_value(data_content);
msg_endpoint_received = true;
// GTEST_LOG_(INFO) << "Endpoint received: " << data_content;
} else if (!sse_response_received && response.find("message") != std::string::npos) {
sse_promise.set_value(data_content);
sse_response_received = true;
// GTEST_LOG_(INFO) << "Message received: " << data_content;
}
}
} catch (const std::exception& e) {
GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
}
return sse_running.load();
});
});
// // 等待消息端点,设置超时
// auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100));
// EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时";
std::string endpoint = msg_endpoint.get();
EXPECT_FALSE(endpoint.empty());
// 即使没有建立SSE连接也可以发送ping请求
json ping_req = request::create("ping").to_json();
auto ping_res = http_client->Post(endpoint.c_str(), ping_req.dump(), "application/json");
EXPECT_TRUE(ping_res != nullptr);
EXPECT_EQ(ping_res->status / 100, 2);
// auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100));
// EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时";
auto mcp_res = json::parse(sse_response.get());
EXPECT_EQ(mcp_res["result"], json::object());
sse_running.store(false);
if (sse_thread.joinable()) {
sse_thread.join();
}
sse_client.reset();
http_client.reset();
} catch (const mcp_exception& e) {
EXPECT_TRUE(false);
}
}
// 测试工具功能 // 测试工具功能
class ToolsTest : public ::testing::Test { class ToolsTest : public ::testing::Test {
protected: protected: