Exception while cleaning up session resource: mutex lock failed: Invalid argument
parent
8e11b5dc6d
commit
9e7b728e17
|
@ -28,6 +28,7 @@
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include <future>
|
#include <future>
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
|
||||||
namespace mcp {
|
namespace mcp {
|
||||||
|
@ -104,16 +105,31 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
void close() {
|
void close() {
|
||||||
std::lock_guard<std::mutex> lk(m_);
|
try {
|
||||||
if (!closed_) {
|
std::lock_guard<std::mutex> lk(m_);
|
||||||
|
if (!closed_) {
|
||||||
|
closed_ = true;
|
||||||
|
cv_.notify_all();
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
// 如果获取锁失败,尝试设置 closed_ 标志
|
||||||
closed_ = true;
|
closed_ = true;
|
||||||
cv_.notify_all();
|
try {
|
||||||
|
cv_.notify_all();
|
||||||
|
} catch (...) {
|
||||||
|
// 忽略通知失败的异常
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_closed() const {
|
bool is_closed() const {
|
||||||
std::lock_guard<std::mutex> lk(m_);
|
try {
|
||||||
return closed_;
|
std::lock_guard<std::mutex> lk(m_);
|
||||||
|
return closed_;
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
// 如果获取锁失败,假设已关闭
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -122,7 +138,7 @@ private:
|
||||||
std::atomic<int> id_{0};
|
std::atomic<int> id_{0};
|
||||||
std::atomic<int> cid_{-1};
|
std::atomic<int> cid_{-1};
|
||||||
std::string message_;
|
std::string message_;
|
||||||
bool closed_ = false;
|
std::atomic<bool> closed_{false};
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -99,11 +99,16 @@ void server::stop() {
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("Exception while closing session ", session_id, ": ", e.what());
|
LOG_ERROR("Exception while closing session ", session_id, ": ", e.what());
|
||||||
|
} catch (...) {
|
||||||
|
LOG_ERROR("Unknown exception while closing session ", session_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 给线程一些时间来处理关闭事件
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||||
|
|
||||||
// 清理剩余的线程
|
// 清理剩余的线程
|
||||||
{
|
try {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
for (auto& [session_id, thread] : sse_threads_) {
|
for (auto& [session_id, thread] : sse_threads_) {
|
||||||
if (thread && thread->joinable()) {
|
if (thread && thread->joinable()) {
|
||||||
|
@ -111,6 +116,8 @@ void server::stop() {
|
||||||
thread->detach();
|
thread->detach();
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what());
|
LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what());
|
||||||
|
} catch (...) {
|
||||||
|
LOG_ERROR("Unknown exception while detaching session thread ", session_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,6 +125,10 @@ void server::stop() {
|
||||||
// 清空映射
|
// 清空映射
|
||||||
session_dispatchers_.clear();
|
session_dispatchers_.clear();
|
||||||
sse_threads_.clear();
|
sse_threads_.clear();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG_ERROR("Exception while cleaning up threads: ", e.what());
|
||||||
|
} catch (...) {
|
||||||
|
LOG_ERROR("Unknown exception while cleaning up threads");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (http_server_) {
|
if (http_server_) {
|
||||||
|
@ -125,7 +136,15 @@ void server::stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (server_thread_ && server_thread_->joinable()) {
|
if (server_thread_ && server_thread_->joinable()) {
|
||||||
server_thread_->join();
|
try {
|
||||||
|
server_thread_->join();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
LOG_ERROR("Exception while joining server thread: ", e.what());
|
||||||
|
server_thread_->detach();
|
||||||
|
} catch (...) {
|
||||||
|
LOG_ERROR("Unknown exception while joining server thread");
|
||||||
|
server_thread_->detach();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_INFO("MCP server stopped");
|
LOG_INFO("MCP server stopped");
|
||||||
|
@ -359,6 +378,8 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
||||||
|
} catch (...) {
|
||||||
|
LOG_ERROR("Unknown exception while cleaning up session resources: ", session_id);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -183,44 +183,27 @@ protected:
|
||||||
};
|
};
|
||||||
server_->set_capabilities(server_capabilities);
|
server_->set_capabilities(server_capabilities);
|
||||||
|
|
||||||
// 注册初始化方法处理器,检查版本
|
|
||||||
server_->register_method("initialize", [this, server_capabilities](const json& params) -> json {
|
|
||||||
// 检查协议版本
|
|
||||||
std::string requested_version = params["protocolVersion"];
|
|
||||||
if (requested_version != MCP_VERSION) {
|
|
||||||
throw mcp_exception(error_code::invalid_params, "Unsupported protocol version");
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
{"protocolVersion", MCP_VERSION},
|
|
||||||
{"capabilities", server_capabilities},
|
|
||||||
{"serverInfo", {
|
|
||||||
{"name", "TestServer"},
|
|
||||||
{"version", "1.0.0"}
|
|
||||||
}}
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
// 启动服务器(非阻塞模式)
|
// 启动服务器(非阻塞模式)
|
||||||
server_->start(false);
|
server_->start(false);
|
||||||
|
|
||||||
|
client_ = std::make_unique<client>("localhost", 8081);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TearDown() override {
|
void TearDown() override {
|
||||||
// 清理测试环境
|
// 清理测试环境
|
||||||
server_->stop();
|
server_->stop();
|
||||||
server_.reset();
|
server_.reset();
|
||||||
|
client_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<server> server_;
|
std::unique_ptr<server> server_;
|
||||||
|
std::unique_ptr<client> client_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// 测试支持的版本
|
// 测试支持的版本
|
||||||
TEST_F(VersioningTest, SupportedVersion) {
|
TEST_F(VersioningTest, SupportedVersion) {
|
||||||
// 创建使用正确版本的客户端
|
|
||||||
client client_correct("localhost", 8081);
|
|
||||||
|
|
||||||
// 执行初始化
|
// 执行初始化
|
||||||
bool init_result = client_correct.initialize("TestClient", "1.0.0");
|
bool init_result = client_->initialize("TestClient", "1.0.0");
|
||||||
|
|
||||||
// 验证初始化结果
|
// 验证初始化结果
|
||||||
EXPECT_TRUE(init_result);
|
EXPECT_TRUE(init_result);
|
||||||
|
@ -228,20 +211,76 @@ TEST_F(VersioningTest, SupportedVersion) {
|
||||||
|
|
||||||
// 测试不支持的版本
|
// 测试不支持的版本
|
||||||
TEST_F(VersioningTest, UnsupportedVersion) {
|
TEST_F(VersioningTest, UnsupportedVersion) {
|
||||||
// Use httplib::Client to send a request with an unsupported version
|
|
||||||
// Note: Open SSE connection first
|
|
||||||
httplib::Client client("localhost", 8081);
|
|
||||||
auto sse_response = client.Get("/sse");
|
|
||||||
// EXPECT_EQ(sse_response->status, 200);
|
|
||||||
|
|
||||||
std::string msg_endpoint = sse_response->body;
|
|
||||||
|
|
||||||
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
|
|
||||||
auto res = client.Post(msg_endpoint.c_str(), req.dump(), "application/json");
|
|
||||||
EXPECT_EQ(res->status, 400);
|
|
||||||
try {
|
try {
|
||||||
auto mcp_res = response::from_json(json::parse(res->body));
|
// 使用 httplib::Client 发送不支持的版本请求
|
||||||
EXPECT_EQ(mcp_res.error["code"], error_code::invalid_params);
|
std::unique_ptr<httplib::Client> sse_client = std::make_unique<httplib::Client>("localhost", 8081);
|
||||||
|
std::unique_ptr<httplib::Client> http_client = std::make_unique<httplib::Client>("localhost", 8081);
|
||||||
|
|
||||||
|
// 打开 SSE 连接
|
||||||
|
std::promise<std::string> msg_endpoint_promise;
|
||||||
|
std::promise<std::string> sse_promise;
|
||||||
|
std::future<std::string> msg_endpoint = msg_endpoint_promise.get_future();
|
||||||
|
std::future<std::string> sse_response = sse_promise.get_future();
|
||||||
|
|
||||||
|
std::atomic<bool> sse_running{true};
|
||||||
|
bool msg_endpoint_received = false;
|
||||||
|
bool sse_response_received = false;
|
||||||
|
|
||||||
|
std::thread sse_thread([&]() {
|
||||||
|
sse_client->Get("/sse", [&](const char* data, size_t len) {
|
||||||
|
try {
|
||||||
|
std::string response(data, len);
|
||||||
|
size_t pos = response.find("data: ");
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
std::string data_content = response.substr(pos + 6);
|
||||||
|
data_content = data_content.substr(0, data_content.find("\n"));
|
||||||
|
|
||||||
|
if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
|
||||||
|
msg_endpoint_promise.set_value(data_content);
|
||||||
|
msg_endpoint_received = true;
|
||||||
|
// GTEST_LOG_(INFO) << "Endpoint received: " << data_content;
|
||||||
|
} else if (!sse_response_received && response.find("message") != std::string::npos) {
|
||||||
|
sse_promise.set_value(data_content);
|
||||||
|
sse_response_received = true;
|
||||||
|
// GTEST_LOG_(INFO) << "Message received: " << data_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
|
||||||
|
}
|
||||||
|
return sse_running.load();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// // 等待消息端点,设置超时
|
||||||
|
// auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100));
|
||||||
|
// EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时";
|
||||||
|
|
||||||
|
std::string endpoint = msg_endpoint.get();
|
||||||
|
EXPECT_FALSE(endpoint.empty());
|
||||||
|
// GTEST_LOG_(INFO) << "Using endpoint: " << endpoint;
|
||||||
|
|
||||||
|
// 发送不支持的版本请求
|
||||||
|
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
|
||||||
|
auto res = http_client->Post(endpoint.c_str(), req.dump(), "application/json");
|
||||||
|
|
||||||
|
EXPECT_TRUE(res != nullptr);
|
||||||
|
EXPECT_EQ(res->status, 202);
|
||||||
|
|
||||||
|
// // 等待SSE响应,设置超时
|
||||||
|
// auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100));
|
||||||
|
// EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时";
|
||||||
|
|
||||||
|
auto mcp_res = json::parse(sse_response.get());
|
||||||
|
EXPECT_EQ(mcp_res["error"]["code"].get<int>(), static_cast<int>(error_code::invalid_params));
|
||||||
|
|
||||||
|
sse_running.store(false);
|
||||||
|
if (sse_thread.joinable()) {
|
||||||
|
sse_thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
sse_client.reset();
|
||||||
|
http_client.reset();
|
||||||
} catch (const mcp_exception& e) {
|
} catch (const mcp_exception& e) {
|
||||||
EXPECT_TRUE(false);
|
EXPECT_TRUE(false);
|
||||||
}
|
}
|
||||||
|
@ -254,11 +293,6 @@ protected:
|
||||||
// 设置测试环境
|
// 设置测试环境
|
||||||
server_ = std::make_unique<server>("localhost", 8082);
|
server_ = std::make_unique<server>("localhost", 8082);
|
||||||
|
|
||||||
// 注册ping方法处理器
|
|
||||||
server_->register_method("ping", [](const json& params) -> json {
|
|
||||||
return json::object(); // 返回空对象
|
|
||||||
});
|
|
||||||
|
|
||||||
// 启动服务器(非阻塞模式)
|
// 启动服务器(非阻塞模式)
|
||||||
server_->start(false);
|
server_->start(false);
|
||||||
|
|
||||||
|
@ -279,13 +313,84 @@ protected:
|
||||||
|
|
||||||
// 测试Ping请求
|
// 测试Ping请求
|
||||||
TEST_F(PingTest, PingRequest) {
|
TEST_F(PingTest, PingRequest) {
|
||||||
// 发送ping请求
|
client_->initialize("TestClient", "1.0.0");
|
||||||
bool ping_result = client_->ping();
|
bool ping_result = client_->ping();
|
||||||
|
|
||||||
// 验证ping结果
|
|
||||||
EXPECT_TRUE(ping_result);
|
EXPECT_TRUE(ping_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PingTest, DirectPing) {
|
||||||
|
try {
|
||||||
|
// 使用 httplib::Client 发送不支持的版本请求
|
||||||
|
std::unique_ptr<httplib::Client> sse_client = std::make_unique<httplib::Client>("localhost", 8082);
|
||||||
|
std::unique_ptr<httplib::Client> http_client = std::make_unique<httplib::Client>("localhost", 8082);
|
||||||
|
|
||||||
|
// 打开 SSE 连接
|
||||||
|
std::promise<std::string> msg_endpoint_promise;
|
||||||
|
std::promise<std::string> sse_promise;
|
||||||
|
std::future<std::string> msg_endpoint = msg_endpoint_promise.get_future();
|
||||||
|
std::future<std::string> sse_response = sse_promise.get_future();
|
||||||
|
|
||||||
|
std::atomic<bool> sse_running{true};
|
||||||
|
bool msg_endpoint_received = false;
|
||||||
|
bool sse_response_received = false;
|
||||||
|
|
||||||
|
std::thread sse_thread([&]() {
|
||||||
|
sse_client->Get("/sse", [&](const char* data, size_t len) {
|
||||||
|
try {
|
||||||
|
std::string response(data, len);
|
||||||
|
size_t pos = response.find("data: ");
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
std::string data_content = response.substr(pos + 6);
|
||||||
|
data_content = data_content.substr(0, data_content.find("\n"));
|
||||||
|
|
||||||
|
if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
|
||||||
|
msg_endpoint_promise.set_value(data_content);
|
||||||
|
msg_endpoint_received = true;
|
||||||
|
// GTEST_LOG_(INFO) << "Endpoint received: " << data_content;
|
||||||
|
} else if (!sse_response_received && response.find("message") != std::string::npos) {
|
||||||
|
sse_promise.set_value(data_content);
|
||||||
|
sse_response_received = true;
|
||||||
|
// GTEST_LOG_(INFO) << "Message received: " << data_content;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
|
||||||
|
}
|
||||||
|
return sse_running.load();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// // 等待消息端点,设置超时
|
||||||
|
// auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100));
|
||||||
|
// EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时";
|
||||||
|
|
||||||
|
std::string endpoint = msg_endpoint.get();
|
||||||
|
EXPECT_FALSE(endpoint.empty());
|
||||||
|
|
||||||
|
// 即使没有建立SSE连接,也可以发送ping请求
|
||||||
|
json ping_req = request::create("ping").to_json();
|
||||||
|
auto ping_res = http_client->Post(endpoint.c_str(), ping_req.dump(), "application/json");
|
||||||
|
EXPECT_TRUE(ping_res != nullptr);
|
||||||
|
EXPECT_EQ(ping_res->status / 100, 2);
|
||||||
|
|
||||||
|
// auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100));
|
||||||
|
// EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时";
|
||||||
|
|
||||||
|
auto mcp_res = json::parse(sse_response.get());
|
||||||
|
EXPECT_EQ(mcp_res["result"], json::object());
|
||||||
|
|
||||||
|
sse_running.store(false);
|
||||||
|
if (sse_thread.joinable()) {
|
||||||
|
sse_thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
sse_client.reset();
|
||||||
|
http_client.reset();
|
||||||
|
} catch (const mcp_exception& e) {
|
||||||
|
EXPECT_TRUE(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 测试工具功能
|
// 测试工具功能
|
||||||
class ToolsTest : public ::testing::Test {
|
class ToolsTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
|
|
Loading…
Reference in New Issue