diff --git a/include/mcp_server.h b/include/mcp_server.h index 4a03d23..ed5a430 100644 --- a/include/mcp_server.h +++ b/include/mcp_server.h @@ -28,6 +28,7 @@ #include #include #include +#include namespace mcp { @@ -104,16 +105,31 @@ public: } void close() { - std::lock_guard lk(m_); - if (!closed_) { + try { + std::lock_guard lk(m_); + if (!closed_) { + closed_ = true; + cv_.notify_all(); + } + } catch (const std::exception& e) { + // 如果获取锁失败,尝试设置 closed_ 标志 closed_ = true; - cv_.notify_all(); + try { + cv_.notify_all(); + } catch (...) { + // 忽略通知失败的异常 + } } } bool is_closed() const { - std::lock_guard lk(m_); - return closed_; + try { + std::lock_guard lk(m_); + return closed_; + } catch (const std::exception&) { + // 如果获取锁失败,假设已关闭 + return true; + } } private: @@ -122,7 +138,7 @@ private: std::atomic id_{0}; std::atomic cid_{-1}; std::string message_; - bool closed_ = false; + std::atomic closed_{false}; }; /** diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index 71213d8..20b2601 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -99,11 +99,16 @@ void server::stop() { } } catch (const std::exception& e) { LOG_ERROR("Exception while closing session ", session_id, ": ", e.what()); + } catch (...) { + LOG_ERROR("Unknown exception while closing session ", session_id); } } + // 给线程一些时间来处理关闭事件 + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + // 清理剩余的线程 - { + try { std::lock_guard lock(mutex_); for (auto& [session_id, thread] : sse_threads_) { if (thread && thread->joinable()) { @@ -111,6 +116,8 @@ void server::stop() { thread->detach(); } catch (const std::exception& e) { LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what()); + } catch (...) { + LOG_ERROR("Unknown exception while detaching session thread ", session_id); } } } @@ -118,6 +125,10 @@ void server::stop() { // 清空映射 session_dispatchers_.clear(); sse_threads_.clear(); + } catch (const std::exception& e) { + LOG_ERROR("Exception while cleaning up threads: ", e.what()); + } catch (...) { + LOG_ERROR("Unknown exception while cleaning up threads"); } if (http_server_) { @@ -125,7 +136,15 @@ void server::stop() { } if (server_thread_ && server_thread_->joinable()) { - server_thread_->join(); + try { + server_thread_->join(); + } catch (const std::exception& e) { + LOG_ERROR("Exception while joining server thread: ", e.what()); + server_thread_->detach(); + } catch (...) { + LOG_ERROR("Unknown exception while joining server thread"); + server_thread_->detach(); + } } LOG_INFO("MCP server stopped"); @@ -359,6 +378,8 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { } } catch (const std::exception& e) { LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what()); + } catch (...) { + LOG_ERROR("Unknown exception while cleaning up session resources: ", session_id); } }); diff --git a/test/mcp_test.cpp b/test/mcp_test.cpp index 1248601..44c2554 100644 --- a/test/mcp_test.cpp +++ b/test/mcp_test.cpp @@ -183,44 +183,27 @@ protected: }; server_->set_capabilities(server_capabilities); - // 注册初始化方法处理器,检查版本 - server_->register_method("initialize", [this, server_capabilities](const json& params) -> json { - // 检查协议版本 - std::string requested_version = params["protocolVersion"]; - if (requested_version != MCP_VERSION) { - throw mcp_exception(error_code::invalid_params, "Unsupported protocol version"); - } - - return { - {"protocolVersion", MCP_VERSION}, - {"capabilities", server_capabilities}, - {"serverInfo", { - {"name", "TestServer"}, - {"version", "1.0.0"} - }} - }; - }); - // 启动服务器(非阻塞模式) server_->start(false); + + client_ = std::make_unique("localhost", 8081); } void TearDown() override { // 清理测试环境 server_->stop(); server_.reset(); + client_.reset(); } std::unique_ptr server_; + std::unique_ptr client_; }; // 测试支持的版本 TEST_F(VersioningTest, SupportedVersion) { - // 创建使用正确版本的客户端 - client client_correct("localhost", 8081); - // 执行初始化 - bool init_result = client_correct.initialize("TestClient", "1.0.0"); + bool init_result = client_->initialize("TestClient", "1.0.0"); // 验证初始化结果 EXPECT_TRUE(init_result); @@ -228,20 +211,76 @@ TEST_F(VersioningTest, SupportedVersion) { // 测试不支持的版本 TEST_F(VersioningTest, UnsupportedVersion) { - // Use httplib::Client to send a request with an unsupported version - // Note: Open SSE connection first - httplib::Client client("localhost", 8081); - auto sse_response = client.Get("/sse"); - // EXPECT_EQ(sse_response->status, 200); - - std::string msg_endpoint = sse_response->body; - - json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json(); - auto res = client.Post(msg_endpoint.c_str(), req.dump(), "application/json"); - EXPECT_EQ(res->status, 400); try { - auto mcp_res = response::from_json(json::parse(res->body)); - EXPECT_EQ(mcp_res.error["code"], error_code::invalid_params); + // 使用 httplib::Client 发送不支持的版本请求 + std::unique_ptr sse_client = std::make_unique("localhost", 8081); + std::unique_ptr http_client = std::make_unique("localhost", 8081); + + // 打开 SSE 连接 + std::promise msg_endpoint_promise; + std::promise sse_promise; + std::future msg_endpoint = msg_endpoint_promise.get_future(); + std::future sse_response = sse_promise.get_future(); + + std::atomic sse_running{true}; + bool msg_endpoint_received = false; + bool sse_response_received = false; + + std::thread sse_thread([&]() { + sse_client->Get("/sse", [&](const char* data, size_t len) { + try { + std::string response(data, len); + size_t pos = response.find("data: "); + if (pos != std::string::npos) { + std::string data_content = response.substr(pos + 6); + data_content = data_content.substr(0, data_content.find("\n")); + + if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) { + msg_endpoint_promise.set_value(data_content); + msg_endpoint_received = true; + // GTEST_LOG_(INFO) << "Endpoint received: " << data_content; + } else if (!sse_response_received && response.find("message") != std::string::npos) { + sse_promise.set_value(data_content); + sse_response_received = true; + // GTEST_LOG_(INFO) << "Message received: " << data_content; + } + } + } catch (const std::exception& e) { + GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what(); + } + return sse_running.load(); + }); + }); + + // // 等待消息端点,设置超时 + // auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100)); + // EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时"; + + std::string endpoint = msg_endpoint.get(); + EXPECT_FALSE(endpoint.empty()); + // GTEST_LOG_(INFO) << "Using endpoint: " << endpoint; + + // 发送不支持的版本请求 + json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json(); + auto res = http_client->Post(endpoint.c_str(), req.dump(), "application/json"); + + EXPECT_TRUE(res != nullptr); + EXPECT_EQ(res->status, 202); + + // // 等待SSE响应,设置超时 + // auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100)); + // EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时"; + + auto mcp_res = json::parse(sse_response.get()); + EXPECT_EQ(mcp_res["error"]["code"].get(), static_cast(error_code::invalid_params)); + + sse_running.store(false); + if (sse_thread.joinable()) { + sse_thread.join(); + } + + sse_client.reset(); + http_client.reset(); } catch (const mcp_exception& e) { EXPECT_TRUE(false); } @@ -254,11 +293,6 @@ protected: // 设置测试环境 server_ = std::make_unique("localhost", 8082); - // 注册ping方法处理器 - server_->register_method("ping", [](const json& params) -> json { - return json::object(); // 返回空对象 - }); - // 启动服务器(非阻塞模式) server_->start(false); @@ -279,13 +313,84 @@ protected: // 测试Ping请求 TEST_F(PingTest, PingRequest) { - // 发送ping请求 + client_->initialize("TestClient", "1.0.0"); bool ping_result = client_->ping(); - - // 验证ping结果 EXPECT_TRUE(ping_result); } +TEST_F(PingTest, DirectPing) { + try { + // 使用 httplib::Client 发送不支持的版本请求 + std::unique_ptr sse_client = std::make_unique("localhost", 8082); + std::unique_ptr http_client = std::make_unique("localhost", 8082); + + // 打开 SSE 连接 + std::promise msg_endpoint_promise; + std::promise sse_promise; + std::future msg_endpoint = msg_endpoint_promise.get_future(); + std::future sse_response = sse_promise.get_future(); + + std::atomic sse_running{true}; + bool msg_endpoint_received = false; + bool sse_response_received = false; + + std::thread sse_thread([&]() { + sse_client->Get("/sse", [&](const char* data, size_t len) { + try { + std::string response(data, len); + size_t pos = response.find("data: "); + if (pos != std::string::npos) { + std::string data_content = response.substr(pos + 6); + data_content = data_content.substr(0, data_content.find("\n")); + + if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) { + msg_endpoint_promise.set_value(data_content); + msg_endpoint_received = true; + // GTEST_LOG_(INFO) << "Endpoint received: " << data_content; + } else if (!sse_response_received && response.find("message") != std::string::npos) { + sse_promise.set_value(data_content); + sse_response_received = true; + // GTEST_LOG_(INFO) << "Message received: " << data_content; + } + } + } catch (const std::exception& e) { + GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what(); + } + return sse_running.load(); + }); + }); + + // // 等待消息端点,设置超时 + // auto endpoint_status = msg_endpoint.wait_for(std::chrono::milliseconds(100)); + // EXPECT_EQ(endpoint_status, std::future_status::ready) << "获取消息端点超时"; + + std::string endpoint = msg_endpoint.get(); + EXPECT_FALSE(endpoint.empty()); + + // 即使没有建立SSE连接,也可以发送ping请求 + json ping_req = request::create("ping").to_json(); + auto ping_res = http_client->Post(endpoint.c_str(), ping_req.dump(), "application/json"); + EXPECT_TRUE(ping_res != nullptr); + EXPECT_EQ(ping_res->status / 100, 2); + + // auto sse_status = sse_response.wait_for(std::chrono::milliseconds(100)); + // EXPECT_EQ(sse_status, std::future_status::ready) << "获取SSE响应超时"; + + auto mcp_res = json::parse(sse_response.get()); + EXPECT_EQ(mcp_res["result"], json::object()); + + sse_running.store(false); + if (sse_thread.joinable()) { + sse_thread.join(); + } + + sse_client.reset(); + http_client.reset(); + } catch (const mcp_exception& e) { + EXPECT_TRUE(false); + } +} + // 测试工具功能 class ToolsTest : public ::testing::Test { protected: