cpp-mcp/test/mcp_test.cpp

518 lines
18 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

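/**
* @file mcp_test.cpp
* @brief Tests for the basic functionality of the MCP framework.
*
* This file contains tests for the MCP framework's message format, lifecycle, versioning, ping and tool features.
*/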
#include <atomic>
#include <chrono>
#include <exception>
#include <future>
#include <memory>
#include <string>
#include <thread>

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "mcp_message.h"
#include "mcp_client.h"
#include "mcp_server.h"
#include "mcp_tool.h"
using namespace mcp;
using json = nlohmann::ordered_json;
/// Fixture shared by the message-format tests.
///
/// The tests need no shared state, so both hooks are intentionally empty.
class MessageFormatTest : public ::testing::Test {
protected:
    /// Nothing to prepare before a message-format test.
    void SetUp() override {}

    /// Nothing to release after a message-format test.
    void TearDown() override {}
};
/// A request must serialize to a JSON-RPC 2.0 envelope that carries an id,
/// the method name and the supplied params.
TEST_F(MessageFormatTest, RequestMessageFormat) {
    // Build the request and serialize it in a single expression.
    json payload = request::create("test_method", {{"key", "value"}}).to_json();

    // Envelope fields.
    EXPECT_EQ(payload["jsonrpc"], "2.0");
    EXPECT_TRUE(payload.contains("id"));

    // Method and parameters are passed through unchanged.
    EXPECT_EQ(payload["method"], "test_method");
    EXPECT_EQ(payload["params"]["key"], "value");
}
/// A success response must echo the request id, carry the result object and
/// must not contain an "error" member.
TEST_F(MessageFormatTest, ResponseMessageFormat) {
    // Build the success response and serialize it in a single expression.
    json payload = response::create_success("test_id", {{"key", "value"}}).to_json();

    // Envelope fields.
    EXPECT_EQ(payload["jsonrpc"], "2.0");
    EXPECT_EQ(payload["id"], "test_id");

    // The result is present; the error member is absent.
    EXPECT_EQ(payload["result"]["key"], "value");
    EXPECT_FALSE(payload.contains("error"));
}
/// An error response must echo the request id, omit "result" and expose the
/// error code, message and data under "error".
TEST_F(MessageFormatTest, ErrorResponseMessageFormat) {
    // Build the error response.
    response failure = response::create_error(
        "test_id",
        error_code::invalid_params,
        "Invalid parameters",
        {{"details", "Missing required field"}});

    // Serialize it.
    json payload = failure.to_json();

    // Envelope fields; a failed call carries no result.
    EXPECT_EQ(payload["jsonrpc"], "2.0");
    EXPECT_EQ(payload["id"], "test_id");
    EXPECT_FALSE(payload.contains("result"));

    // Error object contents.
    EXPECT_EQ(payload["error"]["code"], static_cast<int>(error_code::invalid_params));
    EXPECT_EQ(payload["error"]["message"], "Invalid parameters");
    EXPECT_EQ(payload["error"]["data"]["details"], "Missing required field");
}
/// A notification must serialize without an id, with its method name under
/// the "notifications/" prefix, and must report itself as a notification.
TEST_F(MessageFormatTest, NotificationMessageFormat) {
    // Build the notification.
    request note = request::create_notification("test_notification", {{"key", "value"}});

    // Serialize it.
    json payload = note.to_json();

    // Envelope fields: version tag present, id absent.
    EXPECT_EQ(payload["jsonrpc"], "2.0");
    EXPECT_FALSE(payload.contains("id"));

    // Method name and parameters.
    EXPECT_EQ(payload["method"], "notifications/test_notification");
    EXPECT_EQ(payload["params"]["key"], "value");

    // The object itself must classify as a notification.
    EXPECT_TRUE(note.is_notification());
}
/// Fixture for the lifecycle tests.
///
/// SetUp() starts a server on localhost:8080 with a custom "initialize"
/// handler that checks the incoming request, and prepares a client pointed
/// at the same address. TearDown() stops the server and releases both.
class LifecycleTest : public ::testing::Test {
protected:
    void SetUp() override {
        // Create and describe the server.
        server_ = std::make_unique<server>("localhost", 8080);
        server_->set_server_info("TestServer", "1.0.0");

        // Capabilities advertised by the server.
        json server_capabilities = {
            {"logging", json::object()},
            {"prompts", {{"listChanged", true}}},
            {"resources", {{"subscribe", true}, {"listChanged", true}}},
            {"tools", {{"listChanged", true}}}
        };
        server_->set_capabilities(server_capabilities);

        // Handler for "initialize". It only needs its own copy of the
        // capabilities, so the fixture itself (`this`) is not captured.
        server_->register_method("initialize", [server_capabilities](const json& params) -> json {
            // Validate the initialize request parameters.
            EXPECT_EQ(params["protocolVersion"], MCP_VERSION);
            EXPECT_TRUE(params.contains("capabilities"));
            EXPECT_TRUE(params.contains("clientInfo"));

            // Reply with the initialize response.
            return {
                {"protocolVersion", MCP_VERSION},
                {"capabilities", server_capabilities},
                {"serverInfo", {
                    {"name", "TestServer"},
                    {"version", "1.0.0"}
                }}
            };
        });

        // Start the server in non-blocking mode.
        server_->start(false);

        // Create the client and declare its capabilities.
        json client_capabilities = {
            {"roots", {{"listChanged", true}}},
            {"sampling", json::object()}
        };
        client_ = std::make_unique<client>("localhost", 8080);
        client_->set_capabilities(client_capabilities);
    }

    void TearDown() override {
        // Stop the server, then release server and client.
        server_->stop();
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;  ///< Server under test.
    std::unique_ptr<client> client_;  ///< Client connected to server_.
};
/// The client must initialize successfully and afterwards expose the
/// capabilities the server advertised.
TEST_F(LifecycleTest, InitializeProcess) {
    // Run the handshake and check that it succeeded.
    bool handshake_ok = client_->initialize("TestClient", "1.0.0");
    EXPECT_TRUE(handshake_ok);

    // Every capability group configured on the server must be visible.
    json advertised = client_->get_server_capabilities();
    EXPECT_TRUE(advertised.contains("logging"));
    EXPECT_TRUE(advertised.contains("prompts"));
    EXPECT_TRUE(advertised.contains("resources"));
    EXPECT_TRUE(advertised.contains("tools"));
}
/// Fixture for the protocol-versioning tests.
///
/// SetUp() starts a server on localhost:8081 with a fixed capability set and
/// creates a client for the same address. TearDown() stops the server and
/// releases both objects.
class VersioningTest : public ::testing::Test {
protected:
    void SetUp() override {
        // Create and describe the server.
        server_ = std::make_unique<server>("localhost", 8081);
        server_->set_server_info("TestServer", "1.0.0");

        // Capabilities advertised by the server.
        json advertised = {
            {"logging", json::object()},
            {"prompts", {{"listChanged", true}}},
            {"resources", {{"subscribe", true}, {"listChanged", true}}},
            {"tools", {{"listChanged", true}}}
        };
        server_->set_capabilities(advertised);

        // Start the server in non-blocking mode.
        server_->start(false);

        client_ = std::make_unique<client>("localhost", 8081);
    }

    void TearDown() override {
        // Stop the server, then release server and client.
        server_->stop();
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;  ///< Server under test.
    std::unique_ptr<client> client_;  ///< Client connected to server_.
};
/// Initialization through the regular client must succeed.
TEST_F(VersioningTest, SupportedVersion) {
    // Run the handshake and check its outcome.
    bool handshake_ok = client_->initialize("TestClient", "1.0.0");
    EXPECT_TRUE(handshake_ok);
}
/// An "initialize" request carrying an unsupported protocol version must be
/// accepted over HTTP (202) and answered on the SSE stream with an
/// invalid_params error.
///
/// The test talks to the server with raw httplib clients: one keeps the SSE
/// stream open on a background thread, the other posts the request to the
/// message endpoint announced on that stream.
TEST_F(VersioningTest, UnsupportedVersion) {
    // Upper bound for each wait on the SSE stream, so that a silent server
    // fails the test instead of hanging it.
    const auto wait_limit = std::chrono::seconds(5);

    // Joins the SSE thread on every exit path (normal return, fatal assertion,
    // exception). Destroying a still-joinable std::thread calls std::terminate.
    struct sse_thread_joiner {
        std::atomic<bool>& running;
        httplib::Client& stream;
        std::thread& worker;

        ~sse_thread_joiner() {
            running.store(false);  // make the content receiver stop on its next chunk
            stream.stop();         // interrupt a read that is waiting for the next chunk
            if (worker.joinable()) {
                worker.join();
            }
        }
    };

    try {
        // Raw HTTP clients: one for the SSE stream, one for posting requests.
        std::unique_ptr<httplib::Client> sse_client = std::make_unique<httplib::Client>("localhost", 8081);
        std::unique_ptr<httplib::Client> http_client = std::make_unique<httplib::Client>("localhost", 8081);

        // Channels from the SSE thread to the test body.
        std::promise<std::string> msg_endpoint_promise;
        std::promise<std::string> sse_promise;
        std::future<std::string> msg_endpoint = msg_endpoint_promise.get_future();
        std::future<std::string> sse_response = sse_promise.get_future();
        std::atomic<bool> sse_running{true};

        // These two flags are read and written by the SSE thread only; they
        // ensure each promise is fulfilled at most once.
        bool msg_endpoint_received = false;
        bool sse_response_received = false;

        // Open the SSE connection on a background thread.
        std::thread sse_thread([&]() {
            sse_client->Get("/sse", [&](const char* data, size_t len) {
                try {
                    std::string response(data, len);
                    size_t pos = response.find("data: ");
                    if (pos != std::string::npos) {
                        // Payload is the text after "data: " up to the end of that line.
                        std::string data_content = response.substr(pos + 6);
                        data_content = data_content.substr(0, data_content.find("\n"));
                        if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
                            msg_endpoint_promise.set_value(data_content);
                            msg_endpoint_received = true;
                        } else if (!sse_response_received && response.find("message") != std::string::npos) {
                            sse_promise.set_value(data_content);
                            sse_response_received = true;
                        }
                    }
                } catch (const std::exception& e) {
                    GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
                }
                // Returning false tells httplib to stop reading the stream.
                return sse_running.load();
            });
        });
        // Declared after everything the thread references, so it is destroyed
        // (and the thread joined) before any of those objects go away.
        sse_thread_joiner joiner{sse_running, *sse_client, sse_thread};

        // Wait for the message endpoint announced on the stream.
        ASSERT_EQ(msg_endpoint.wait_for(wait_limit), std::future_status::ready)
            << "timed out waiting for the message endpoint";
        std::string endpoint = msg_endpoint.get();
        ASSERT_FALSE(endpoint.empty());

        // Send an initialize request with an unsupported protocol version.
        json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
        auto res = http_client->Post(endpoint.c_str(), req.dump(), "application/json");
        // Fatal on failure: the next line dereferences the result.
        ASSERT_TRUE(res != nullptr);
        EXPECT_EQ(res->status, 202);

        // The actual JSON-RPC reply arrives on the SSE stream.
        ASSERT_EQ(sse_response.wait_for(wait_limit), std::future_status::ready)
            << "timed out waiting for the SSE response";
        auto mcp_res = json::parse(sse_response.get());
        ASSERT_TRUE(mcp_res.contains("error"));
        EXPECT_EQ(mcp_res["error"]["code"].get<int>(), static_cast<int>(error_code::invalid_params));
    } catch (const mcp_exception&) {
        ADD_FAILURE() << "unexpected mcp_exception";
    } catch (const std::exception& e) {
        // E.g. a JSON parse or type error while inspecting the reply.
        ADD_FAILURE() << "unexpected exception: " << e.what();
    }
}
/// Fixture for the ping tests.
///
/// SetUp() starts a server on localhost:8082 and creates a client for the
/// same address. TearDown() stops the server and releases both objects.
class PingTest : public ::testing::Test {
protected:
    void SetUp() override {
        // Bring up the server in non-blocking mode.
        server_ = std::make_unique<server>("localhost", 8082);
        server_->start(false);

        // Prepare the client; initialization is left to the individual tests.
        client_ = std::make_unique<client>("localhost", 8082);
    }

    void TearDown() override {
        // Stop the server, then release server and client.
        server_->stop();
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;  ///< Server under test.
    std::unique_ptr<client> client_;  ///< Client connected to server_.
};
/// After a successful initialization, a ping through the regular client must
/// succeed.
TEST_F(PingTest, PingRequest) {
    // Fatal on failure: without this check a failed handshake would only show
    // up as a confusing ping failure below.
    ASSERT_TRUE(client_->initialize("TestClient", "1.0.0"));

    bool ping_result = client_->ping();
    EXPECT_TRUE(ping_result);
}
/// A "ping" request posted directly to the message endpoint (without running
/// the initialize handshake first) must be accepted with a 2xx status and
/// answered on the SSE stream with an empty result object.
///
/// The test talks to the server with raw httplib clients: one keeps the SSE
/// stream open on a background thread, the other posts the request to the
/// message endpoint announced on that stream.
TEST_F(PingTest, DirectPing) {
    // Upper bound for each wait on the SSE stream, so that a silent server
    // fails the test instead of hanging it.
    const auto wait_limit = std::chrono::seconds(5);

    // Joins the SSE thread on every exit path (normal return, fatal assertion,
    // exception). Destroying a still-joinable std::thread calls std::terminate.
    struct sse_thread_joiner {
        std::atomic<bool>& running;
        httplib::Client& stream;
        std::thread& worker;

        ~sse_thread_joiner() {
            running.store(false);  // make the content receiver stop on its next chunk
            stream.stop();         // interrupt a read that is waiting for the next chunk
            if (worker.joinable()) {
                worker.join();
            }
        }
    };

    try {
        // Raw HTTP clients: one for the SSE stream, one for posting requests.
        std::unique_ptr<httplib::Client> sse_client = std::make_unique<httplib::Client>("localhost", 8082);
        std::unique_ptr<httplib::Client> http_client = std::make_unique<httplib::Client>("localhost", 8082);

        // Channels from the SSE thread to the test body.
        std::promise<std::string> msg_endpoint_promise;
        std::promise<std::string> sse_promise;
        std::future<std::string> msg_endpoint = msg_endpoint_promise.get_future();
        std::future<std::string> sse_response = sse_promise.get_future();
        std::atomic<bool> sse_running{true};

        // These two flags are read and written by the SSE thread only; they
        // ensure each promise is fulfilled at most once.
        bool msg_endpoint_received = false;
        bool sse_response_received = false;

        // Open the SSE connection on a background thread.
        std::thread sse_thread([&]() {
            sse_client->Get("/sse", [&](const char* data, size_t len) {
                try {
                    std::string response(data, len);
                    size_t pos = response.find("data: ");
                    if (pos != std::string::npos) {
                        // Payload is the text after "data: " up to the end of that line.
                        std::string data_content = response.substr(pos + 6);
                        data_content = data_content.substr(0, data_content.find("\n"));
                        if (!msg_endpoint_received && response.find("endpoint") != std::string::npos) {
                            msg_endpoint_promise.set_value(data_content);
                            msg_endpoint_received = true;
                        } else if (!sse_response_received && response.find("message") != std::string::npos) {
                            sse_promise.set_value(data_content);
                            sse_response_received = true;
                        }
                    }
                } catch (const std::exception& e) {
                    GTEST_LOG_(ERROR) << "SSE处理错误: " << e.what();
                }
                // Returning false tells httplib to stop reading the stream.
                return sse_running.load();
            });
        });
        // Declared after everything the thread references, so it is destroyed
        // (and the thread joined) before any of those objects go away.
        sse_thread_joiner joiner{sse_running, *sse_client, sse_thread};

        // Wait for the message endpoint announced on the stream.
        ASSERT_EQ(msg_endpoint.wait_for(wait_limit), std::future_status::ready)
            << "timed out waiting for the message endpoint";
        std::string endpoint = msg_endpoint.get();
        ASSERT_FALSE(endpoint.empty());

        // Post the ping directly; no initialize request has been sent.
        json ping_req = request::create("ping").to_json();
        auto ping_res = http_client->Post(endpoint.c_str(), ping_req.dump(), "application/json");
        // Fatal on failure: the next line dereferences the result.
        ASSERT_TRUE(ping_res != nullptr);
        EXPECT_EQ(ping_res->status / 100, 2);

        // The actual JSON-RPC reply arrives on the SSE stream.
        ASSERT_EQ(sse_response.wait_for(wait_limit), std::future_status::ready)
            << "timed out waiting for the SSE response";
        auto mcp_res = json::parse(sse_response.get());
        EXPECT_EQ(mcp_res["result"], json::object());
    } catch (const mcp_exception&) {
        ADD_FAILURE() << "unexpected mcp_exception";
    } catch (const std::exception& e) {
        // E.g. a JSON parse error while inspecting the reply.
        ADD_FAILURE() << "unexpected exception: " << e.what();
    }
}
/// Fixture for the tool tests.
///
/// SetUp() starts a server on localhost:8083 that exposes a "get_weather"
/// tool plus explicit "tools/list" and "tools/call" handlers, then creates
/// and initializes a client. TearDown() stops the server and releases both.
class ToolsTest : public ::testing::Test {
protected:
    void SetUp() override {
        server_ = std::make_unique<server>("localhost", 8083);

        // Describe the test tool.
        tool test_tool;
        test_tool.name = "get_weather";
        test_tool.description = "Get current weather information for a location";
        test_tool.parameters_schema = {
            {"type", "object"},
            {"properties", {
                {"location", {
                    {"type", "string"},
                    {"description", "City name or zip code"}
                }}
            }},
            {"required", json::array({"location"})}
        };

        // Register the tool with a trivial implementation.
        server_->register_tool(test_tool, [](const json& params) -> json {
            std::string location = params["location"];
            return {
                {"content", json::array({
                    {
                        {"type", "text"},
                        {"text", "Current weather in " + location + ":\nTemperature: 72°F\nConditions: Partly cloudy"}
                    }
                })},
                {"isError", false}
            };
        });

        // Handler for "tools/list". It returns a fixed listing, so it needs
        // neither the fixture nor the request parameters.
        server_->register_method("tools/list", [](const json& /*params*/) -> json {
            return {
                {"tools", json::array({
                    {
                        {"name", "get_weather"},
                        {"description", "Get current weather information for a location"},
                        {"inputSchema", {
                            {"type", "object"},
                            {"properties", {
                                {"location", {
                                    {"type", "string"},
                                    {"description", "City name or zip code"}
                                }}
                            }},
                            {"required", json::array({"location"})}
                        }}
                    }
                })},
                {"nextCursor", nullptr}
            };
        });

        // Handler for "tools/call". It does not use the fixture, so `this`
        // is not captured.
        server_->register_method("tools/call", [](const json& params) -> json {
            // Validate the call parameters.
            EXPECT_EQ(params["name"], "get_weather");
            EXPECT_EQ(params["arguments"]["location"], "New York");

            // Return the tool call result.
            return {
                {"content", json::array({
                    {
                        {"type", "text"},
                        {"text", "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"}
                    }
                })},
                {"isError", false}
            };
        });

        // Start the server in non-blocking mode.
        server_->start(false);

        // Create the client. A failed handshake is fatal here: the test body
        // is skipped (TearDown still runs) instead of failing later with an
        // unrelated-looking error.
        client_ = std::make_unique<client>("localhost", 8083);
        ASSERT_TRUE(client_->initialize("TestClient", "1.0.0"));
    }

    void TearDown() override {
        // Stop the server, then release server and client.
        server_->stop();
        server_.reset();
        client_.reset();
    }

    std::unique_ptr<server> server_;  ///< Server under test.
    std::unique_ptr<client> client_;  ///< Client connected to server_.
};
/// "tools/list" must return exactly one tool, "get_weather", with the
/// expected description.
TEST_F(ToolsTest, ListTools) {
    // Ask the server for its tool listing.
    json tools_list = client_->send_request("tools/list").result;

    // Validate the listing. The expected size is written as an unsigned
    // literal because size() is unsigned; comparing it with a signed 1
    // triggers sign-compare warnings inside EXPECT_EQ.
    EXPECT_TRUE(tools_list.contains("tools"));
    EXPECT_EQ(tools_list["tools"].size(), 1u);
    EXPECT_EQ(tools_list["tools"][0]["name"], "get_weather");
    EXPECT_EQ(tools_list["tools"][0]["description"], "Get current weather information for a location");
}
/// Calling "get_weather" for New York must yield a non-error result whose
/// first content item is the expected text block.
TEST_F(ToolsTest, CallTool) {
    // Invoke the tool through the client.
    json outcome = client_->call_tool("get_weather", {{"location", "New York"}});

    // The result carries content and is not flagged as an error.
    EXPECT_TRUE(outcome.contains("content"));
    EXPECT_FALSE(outcome["isError"]);

    // The first content item is the weather text.
    EXPECT_EQ(outcome["content"][0]["type"], "text");
    EXPECT_EQ(outcome["content"][0]["text"], "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy");
}
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}