cpp-mcp/test/mcp_test.cpp

569 lines
20 KiB
C++
Raw Normal View History

2025-03-13 00:04:18 +08:00
/**
 * @file mcp_test.cpp
 * @brief MCP protocol tests
 *
 * Covers MCP message formats, connection lifecycle, protocol versioning, ping and tools.
 */
#include <atomic>
#include <chrono>
#include <future>
#include <memory>
#include <string>
#include <thread>

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "mcp_message.h"
#include "mcp_client.h"
#include "mcp_server.h"
#include "mcp_tool.h"

using namespace mcp;
using json = nlohmann::ordered_json;
// Fixture for the JSON-RPC message-format tests.
// These tests build messages in memory only, so there is nothing to set up or clean up.
class MessageFormatTest : public ::testing::Test {
protected:
    void SetUp() override {}

    void TearDown() override {}
};
// A request serializes to a JSON-RPC 2.0 object that carries an id, the method name and the params.
TEST_F(MessageFormatTest, RequestMessageFormat) {
    json serialized = request::create("test_method", {{"key", "value"}}).to_json();

    EXPECT_EQ(serialized["jsonrpc"], "2.0");
    EXPECT_TRUE(serialized.contains("id"));
    EXPECT_EQ(serialized["method"], "test_method");
    EXPECT_EQ(serialized["params"]["key"], "value");
}
// A success response echoes the id, carries the payload under "result" and has no "error" member.
TEST_F(MessageFormatTest, ResponseMessageFormat) {
    json serialized = response::create_success("test_id", {{"key", "value"}}).to_json();

    EXPECT_EQ(serialized["jsonrpc"], "2.0");
    EXPECT_EQ(serialized["id"], "test_id");
    EXPECT_EQ(serialized["result"]["key"], "value");
    EXPECT_FALSE(serialized.contains("error"));
}
// An error response echoes the id, has no "result" member, and carries code/message/data under "error".
TEST_F(MessageFormatTest, ErrorResponseMessageFormat) {
    json serialized = response::create_error(
        "test_id",
        error_code::invalid_params,
        "Invalid parameters",
        {{"details", "Missing required field"}}).to_json();

    EXPECT_EQ(serialized["jsonrpc"], "2.0");
    EXPECT_EQ(serialized["id"], "test_id");
    EXPECT_FALSE(serialized.contains("result"));

    json& error_member = serialized["error"];
    EXPECT_EQ(error_member["code"], static_cast<int>(error_code::invalid_params));
    EXPECT_EQ(error_member["message"], "Invalid parameters");
    EXPECT_EQ(error_member["data"]["details"], "Missing required field");
}
// A notification has no id, and its method is serialized with a "notifications/" prefix.
TEST_F(MessageFormatTest, NotificationMessageFormat) {
    request note = request::create_notification("test_notification", {{"key", "value"}});
    json serialized = note.to_json();

    EXPECT_EQ(serialized["jsonrpc"], "2.0");
    EXPECT_FALSE(serialized.contains("id"));
    EXPECT_EQ(serialized["method"], "notifications/test_notification");
    EXPECT_EQ(serialized["params"]["key"], "value");

    // The request object itself must also report that it is a notification.
    EXPECT_TRUE(note.is_notification());
}
// Lifecycle tests: a server on localhost:8080 that advertises a fixed capability set and handles
// "initialize" itself, plus a client that declares its own capabilities.
class LifecycleTest : public ::testing::Test {
protected:
    void SetUp() override {
        server_ = std::make_unique<server>("localhost", 8080);
        server_->set_server_info("TestServer", "1.0.0");

        // Capabilities advertised by the server; the initialize handler below echoes them back.
        json server_capabilities = {
            {"logging", json::object()},
            {"prompts", {{"listChanged", true}}},
            {"resources", {{"subscribe", true}, {"listChanged", true}}},
            {"tools", {{"listChanged", true}}}
        };
        server_->set_capabilities(server_capabilities);

        // Custom initialize handler: checks the request shape and returns a fixed response.
        // Only the capability set is captured; the fixture itself (`this`) is not needed.
        server_->register_method("initialize", [server_capabilities](const json& params) -> json {
            EXPECT_EQ(params["protocolVersion"], MCP_VERSION);
            EXPECT_TRUE(params.contains("capabilities"));
            EXPECT_TRUE(params.contains("clientInfo"));

            return {
                {"protocolVersion", MCP_VERSION},
                {"capabilities", server_capabilities},
                {"serverInfo", {
                    {"name", "TestServer"},
                    {"version", "1.0.0"}
                }}
            };
        });

        // Start the server in non-blocking mode.
        server_->start(false);

        // Create the client and declare the capabilities it will send during initialize.
        json client_capabilities = {
            {"roots", {{"listChanged", true}}},
            {"sampling", json::object()}
        };
        client_ = std::make_unique<client>("localhost", 8080);
        client_->set_capabilities(client_capabilities);
    }

    void TearDown() override {
        // GoogleTest calls TearDown even when SetUp failed part-way, so server_ may still be null.
        if (server_) {
            server_->stop();
        }
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;
    std::unique_ptr<client> client_;
};
// A successful initialize exposes the capabilities the server advertised.
TEST_F(LifecycleTest, InitializeProcess) {
    // The capability checks are meaningless if the handshake failed, so stop the test right here.
    ASSERT_TRUE(client_->initialize("TestClient", "1.0.0"));

    json server_capabilities = client_->get_server_capabilities();
    EXPECT_TRUE(server_capabilities.contains("logging"));
    EXPECT_TRUE(server_capabilities.contains("prompts"));
    EXPECT_TRUE(server_capabilities.contains("resources"));
    EXPECT_TRUE(server_capabilities.contains("tools"));
}
// Versioning tests: a server on localhost:8081 using its default initialize handling,
// and a client pointed at it.
class VersioningTest : public ::testing::Test {
protected:
    void SetUp() override {
        server_ = std::make_unique<server>("localhost", 8081);
        server_->set_server_info("TestServer", "1.0.0");

        // Capabilities advertised by the server.
        json server_capabilities = {
            {"logging", json::object()},
            {"prompts", {{"listChanged", true}}},
            {"resources", {{"subscribe", true}, {"listChanged", true}}},
            {"tools", {{"listChanged", true}}}
        };
        server_->set_capabilities(server_capabilities);

        // Start the server in non-blocking mode.
        server_->start(false);

        client_ = std::make_unique<client>("localhost", 8081);
    }

    void TearDown() override {
        // GoogleTest calls TearDown even when SetUp failed part-way, so server_ may still be null.
        if (server_) {
            server_->stop();
        }
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;
    std::unique_ptr<client> client_;
};
// A default client handshake against the server must succeed.
TEST_F(VersioningTest, SupportedVersion) {
    const bool initialized = client_->initialize("TestClient", "1.0.0");
    EXPECT_TRUE(initialized);
}
// An initialize request with an unsupported protocol version must be answered with an
// invalid_params error.
//
// The test speaks raw HTTP + SSE through httplib:
// - A background thread holds the /sse stream open.
// - It captures the message endpoint from the first "endpoint" event.
// - It forwards the payload of the first "message" event (the JSON-RPC response) to the test
//   body through a promise.
TEST_F(VersioningTest, UnsupportedVersion) {
    httplib::Client sse_client("localhost", 8081);
    httplib::Client http_client("localhost", 8081);
    // Bound the blocking network calls so a misbehaving server fails the test instead of hanging it.
    sse_client.set_connection_timeout(5, 0);
    sse_client.set_read_timeout(5, 0);

    std::promise<std::string> endpoint_promise;
    std::promise<std::string> message_promise;
    std::future<std::string> endpoint_future = endpoint_promise.get_future();
    std::future<std::string> message_future = message_promise.get_future();

    std::atomic<bool> sse_running{true};
    // Both flags are touched only by the SSE thread. They ensure each promise is fulfilled at most once.
    std::atomic<bool> endpoint_received{false};
    std::atomic<bool> message_received{false};

    std::thread sse_thread([&]() {
        sse_client.Get("/sse", [&](const char* data, size_t len) {
            try {
                std::string chunk(data, len);
                size_t pos = chunk.find("data: ");
                if (pos != std::string::npos) {
                    std::string payload = chunk.substr(pos + 6);
                    payload = payload.substr(0, payload.find('\n'));
                    if (!endpoint_received.load() && chunk.find("endpoint") != std::string::npos) {
                        endpoint_received.store(true);
                        endpoint_promise.set_value(payload);
                    } else if (!message_received.load() && chunk.find("message") != std::string::npos) {
                        message_received.store(true);
                        message_promise.set_value(payload);
                    }
                }
            } catch (const std::exception& e) {
                GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
            }
            // Returning false makes httplib abort the streaming GET.
            return sse_running.load();
        });
    });

    // Runs on every exit path (normal end, failed ASSERT_*, thrown exception). It stops the stream
    // and joins the reader before the locals its lambda references are destroyed. A std::thread
    // that is still joinable when destroyed would call std::terminate.
    struct SseReaderGuard {
        std::atomic<bool>& running;
        httplib::Client& stream_client;
        std::thread& reader;
        ~SseReaderGuard() {
            running.store(false);
            stream_client.stop();  // unblocks a read that is still waiting on the server
            if (reader.joinable()) {
                reader.join();
            }
        }
    } reader_guard{sse_running, sse_client, sse_thread};

    ASSERT_TRUE(endpoint_future.wait_for(std::chrono::seconds(5)) == std::future_status::ready)
        << "no endpoint event received on /sse";
    std::string endpoint = endpoint_future.get();
    ASSERT_FALSE(endpoint.empty());

    // Send initialize with a protocol version the server does not support.
    json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
    auto res = http_client.Post(endpoint.c_str(), req.dump(), "application/json");
    ASSERT_TRUE(res != nullptr);
    EXPECT_EQ(res->status, 202);

    // The JSON-RPC response is delivered on the SSE stream, not in the POST body.
    ASSERT_TRUE(message_future.wait_for(std::chrono::seconds(5)) == std::future_status::ready)
        << "no JSON-RPC response received on /sse";
    json mcp_res = json::parse(message_future.get());
    ASSERT_TRUE(mcp_res.contains("error")) << mcp_res.dump();
    EXPECT_EQ(mcp_res["error"]["code"].get<int>(), static_cast<int>(error_code::invalid_params));
}
// Ping tests: a server on localhost:8082 with default settings, and a client pointed at it.
class PingTest : public ::testing::Test {
protected:
    void SetUp() override {
        server_ = std::make_unique<server>("localhost", 8082);
        // Start the server in non-blocking mode.
        server_->start(false);
        client_ = std::make_unique<client>("localhost", 8082);
    }

    void TearDown() override {
        // GoogleTest calls TearDown even when SetUp failed part-way, so server_ may still be null.
        if (server_) {
            server_->stop();
        }
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;
    std::unique_ptr<client> client_;
};
// After the client's handshake, client::ping() should report success.
TEST_F(PingTest, PingRequest) {
    // Short delay kept from the original test before initializing.
    // NOTE(review): presumably lets the server finish starting — confirm it is still needed.
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    client_->initialize("TestClient", "1.0.0");

    EXPECT_TRUE(client_->ping());
}
// A raw JSON-RPC "ping" posted to the advertised message endpoint is answered, over the SSE
// stream, with an empty result object. No initialize handshake is sent first, so this also checks
// that ping is accepted before initialization.
TEST_F(PingTest, DirectPing) {
    httplib::Client sse_client("localhost", 8082);
    httplib::Client http_client("localhost", 8082);
    // Bound the blocking network calls so a misbehaving server fails the test instead of hanging it.
    sse_client.set_connection_timeout(5, 0);
    sse_client.set_read_timeout(5, 0);

    std::promise<std::string> endpoint_promise;
    std::promise<std::string> message_promise;
    std::future<std::string> endpoint_future = endpoint_promise.get_future();
    std::future<std::string> message_future = message_promise.get_future();

    std::atomic<bool> sse_running{true};
    // Both flags are touched only by the SSE thread. They ensure each promise is fulfilled at most once.
    std::atomic<bool> endpoint_received{false};
    std::atomic<bool> message_received{false};

    std::thread sse_thread([&]() {
        sse_client.Get("/sse", [&](const char* data, size_t len) {
            try {
                std::string chunk(data, len);
                size_t pos = chunk.find("data: ");
                if (pos != std::string::npos) {
                    std::string payload = chunk.substr(pos + 6);
                    payload = payload.substr(0, payload.find('\n'));
                    if (!endpoint_received.load() && chunk.find("endpoint") != std::string::npos) {
                        endpoint_received.store(true);
                        endpoint_promise.set_value(payload);
                    } else if (!message_received.load() && chunk.find("message") != std::string::npos) {
                        message_received.store(true);
                        message_promise.set_value(payload);
                    }
                }
            } catch (const std::exception& e) {
                GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
            }
            // Returning false makes httplib abort the streaming GET.
            return sse_running.load();
        });
    });

    // Runs on every exit path (normal end, failed ASSERT_*, thrown exception). It stops the stream
    // and joins the reader before the locals its lambda references are destroyed. A std::thread
    // that is still joinable when destroyed would call std::terminate.
    struct SseReaderGuard {
        std::atomic<bool>& running;
        httplib::Client& stream_client;
        std::thread& reader;
        ~SseReaderGuard() {
            running.store(false);
            stream_client.stop();  // unblocks a read that is still waiting on the server
            if (reader.joinable()) {
                reader.join();
            }
        }
    } reader_guard{sse_running, sse_client, sse_thread};

    ASSERT_TRUE(endpoint_future.wait_for(std::chrono::seconds(5)) == std::future_status::ready)
        << "no endpoint event received on /sse";
    std::string endpoint = endpoint_future.get();
    ASSERT_FALSE(endpoint.empty());

    json ping_req = request::create("ping").to_json();
    auto ping_res = http_client.Post(endpoint.c_str(), ping_req.dump(), "application/json");
    ASSERT_TRUE(ping_res != nullptr);
    EXPECT_EQ(ping_res->status / 100, 2);

    // The JSON-RPC response is delivered on the SSE stream, not in the POST body.
    ASSERT_TRUE(message_future.wait_for(std::chrono::seconds(5)) == std::future_status::ready)
        << "no JSON-RPC response received on /sse";
    json mcp_res = json::parse(message_future.get());
    EXPECT_EQ(mcp_res["result"], json::object());
}
// Tools tests: a server on localhost:8083 that exposes a single "get_weather" tool and has
// explicit tools/list and tools/call handlers, plus a client that runs initialize in SetUp.
class ToolsTest : public ::testing::Test {
protected:
    void SetUp() override {
        server_ = std::make_unique<server>("localhost", 8083);

        // Define the test tool once; tools/list below reports this same name, description and schema.
        tool test_tool;
        test_tool.name = "get_weather";
        test_tool.description = "Get current weather information for a location";
        test_tool.parameters_schema = {
            {"type", "object"},
            {"properties", {
                {"location", {
                    {"type", "string"},
                    {"description", "City name or zip code"}
                }}
            }},
            {"required", json::array({"location"})}
        };

        // Tool implementation: canned weather text for the requested location.
        // at() throws on a missing "location". Indexing a const json with an absent key is undefined behavior.
        server_->register_tool(test_tool, [](const json& params) -> json {
            const std::string location = params.at("location").get<std::string>();
            return {
                {"content", json::array({
                    {
                        {"type", "text"},
                        {"text", "Current weather in " + location + ":\nTemperature: 72°F\nConditions: Partly cloudy"}
                    }
                })},
                {"isError", false}
            };
        });

        // tools/list: built from test_tool so the listed description/schema cannot drift from
        // the registered tool.
        server_->register_method("tools/list", [test_tool](const json& /*params*/) -> json {
            return {
                {"tools", json::array({
                    {
                        {"name", test_tool.name},
                        {"description", test_tool.description},
                        {"inputSchema", test_tool.parameters_schema}
                    }
                })},
                {"nextCursor", nullptr}
            };
        });

        // tools/call: checks the call the CallTool test makes and returns a fixed result.
        server_->register_method("tools/call", [](const json& params) -> json {
            EXPECT_EQ(params["name"], "get_weather");
            EXPECT_EQ(params["arguments"]["location"], "New York");
            return {
                {"content", json::array({
                    {
                        {"type", "text"},
                        {"text", "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"}
                    }
                })},
                {"isError", false}
            };
        });

        // Start the server in non-blocking mode.
        server_->start(false);

        // Create the client and run the handshake. The result is not checked here; each test
        // fails on its own if the session is unusable.
        client_ = std::make_unique<client>("localhost", 8083);
        client_->initialize("TestClient", "1.0.0");
    }

    void TearDown() override {
        // GoogleTest calls TearDown even when SetUp failed part-way, so server_ may still be null.
        if (server_) {
            server_->stop();
        }
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;
    std::unique_ptr<client> client_;
};
// tools/list must report exactly the get_weather tool registered in SetUp.
TEST_F(ToolsTest, ListTools) {
    // Short delay kept from the original test.
    // NOTE(review): presumably lets the client session settle after initialize — confirm it is still needed.
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    json tools_list = client_->send_request("tools/list").result;

    // Stop before indexing [0] if the list is missing or has the wrong size.
    ASSERT_TRUE(tools_list.contains("tools"));
    ASSERT_EQ(tools_list["tools"].size(), 1u);
    EXPECT_EQ(tools_list["tools"][0]["name"], "get_weather");
    EXPECT_EQ(tools_list["tools"][0]["description"], "Get current weather information for a location");
}
// Calling get_weather for New York returns a single text content item and no error flag.
TEST_F(ToolsTest, CallTool) {
    json tool_result = client_->call_tool("get_weather", {{"location", "New York"}});

    // Stop before indexing content[0] if the result has no non-empty content array.
    ASSERT_TRUE(tool_result.contains("content"));
    ASSERT_TRUE(tool_result["content"].is_array());
    ASSERT_FALSE(tool_result["content"].empty());

    // A missing "isError" counts as a failure instead of throwing on a null-to-bool conversion.
    EXPECT_FALSE(tool_result.value("isError", true));
    EXPECT_EQ(tool_result["content"][0]["type"], "text");
    EXPECT_EQ(tool_result["content"][0]["text"], "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy");
}
// Test entry point: hands the command line to GoogleTest and returns the result of running every registered test.
int main(int argc, char* argv[]) {
    ::testing::InitGoogleTest(&argc, argv);
    const int exit_status = RUN_ALL_TESTS();
    return exit_status;
}