/** * @file mcp_test.cpp * @brief 测试MCP框架的基本功能 * * 本文件包含对MCP框架的消息格式、生命周期、版本控制、ping和工具功能的测试。 */ #include #include #include "mcp_message.h" #include "mcp_client.h" #include "mcp_server.h" #include "mcp_tool.h" using namespace mcp; using json = nlohmann::ordered_json; // 测试消息格式 class MessageFormatTest : public ::testing::Test { protected: void SetUp() override { // 设置测试环境 } void TearDown() override { // 清理测试环境 } }; // 测试请求消息格式 TEST_F(MessageFormatTest, RequestMessageFormat) { // 创建一个请求消息 request req = request::create("test_method", {{"key", "value"}}); // 转换为JSON json req_json = req.to_json(); // 验证JSON格式是否符合规范 EXPECT_EQ(req_json["jsonrpc"], "2.0"); EXPECT_TRUE(req_json.contains("id")); EXPECT_EQ(req_json["method"], "test_method"); EXPECT_EQ(req_json["params"]["key"], "value"); } // 测试响应消息格式 TEST_F(MessageFormatTest, ResponseMessageFormat) { // 创建一个成功响应 response res = response::create_success("test_id", {{"key", "value"}}); // 转换为JSON json res_json = res.to_json(); // 验证JSON格式是否符合规范 EXPECT_EQ(res_json["jsonrpc"], "2.0"); EXPECT_EQ(res_json["id"], "test_id"); EXPECT_EQ(res_json["result"]["key"], "value"); EXPECT_FALSE(res_json.contains("error")); } // 测试错误响应消息格式 TEST_F(MessageFormatTest, ErrorResponseMessageFormat) { // 创建一个错误响应 response res = response::create_error("test_id", error_code::invalid_params, "Invalid parameters", {{"details", "Missing required field"}}); // 转换为JSON json res_json = res.to_json(); // 验证JSON格式是否符合规范 EXPECT_EQ(res_json["jsonrpc"], "2.0"); EXPECT_EQ(res_json["id"], "test_id"); EXPECT_FALSE(res_json.contains("result")); EXPECT_EQ(res_json["error"]["code"], static_cast(error_code::invalid_params)); EXPECT_EQ(res_json["error"]["message"], "Invalid parameters"); EXPECT_EQ(res_json["error"]["data"]["details"], "Missing required field"); } // 测试通知消息格式 TEST_F(MessageFormatTest, NotificationMessageFormat) { // 创建一个通知消息 request notification = request::create_notification("test_notification", {{"key", "value"}}); // 转换为JSON json notification_json = notification.to_json(); // 验证JSON格式是否符合规范 EXPECT_EQ(notification_json["jsonrpc"], "2.0"); EXPECT_FALSE(notification_json.contains("id")); EXPECT_EQ(notification_json["method"], "notifications/test_notification"); EXPECT_EQ(notification_json["params"]["key"], "value"); // 验证是否为通知消息 EXPECT_TRUE(notification.is_notification()); } // 测试生命周期 class LifecycleTest : public ::testing::Test { protected: void SetUp() override { // 设置测试环境 server_ = std::make_unique("localhost", 8080); server_->set_server_info("TestServer", "1.0.0"); // 设置服务器能力 json server_capabilities = { {"logging", json::object()}, {"prompts", {{"listChanged", true}}}, {"resources", {{"subscribe", true}, {"listChanged", true}}}, {"tools", {{"listChanged", true}}} }; server_->set_capabilities(server_capabilities); // 注册初始化方法处理器 server_->register_method("initialize", [this, server_capabilities](const json& params) -> json { // 验证初始化请求参数 EXPECT_EQ(params["protocolVersion"], MCP_VERSION); EXPECT_TRUE(params.contains("capabilities")); EXPECT_TRUE(params.contains("clientInfo")); // 返回初始化响应 return { {"protocolVersion", MCP_VERSION}, {"capabilities", server_capabilities}, {"serverInfo", { {"name", "TestServer"}, {"version", "1.0.0"} }} }; }); // 启动服务器(非阻塞模式) server_->start(false); // 创建客户端 json client_capabilities = { {"roots", {{"listChanged", true}}}, {"sampling", json::object()} }; client_ = std::make_unique("localhost", 8080); client_->set_capabilities(client_capabilities); } void TearDown() override { // 清理测试环境 server_->stop(); server_.reset(); client_.reset(); } std::unique_ptr server_; std::unique_ptr client_; }; // 测试初始化流程 TEST_F(LifecycleTest, InitializeProcess) { // 执行初始化 bool init_result = client_->initialize("TestClient", "1.0.0"); // 验证初始化结果 EXPECT_TRUE(init_result); // 验证服务器能力 json server_capabilities = client_->get_server_capabilities(); EXPECT_TRUE(server_capabilities.contains("logging")); EXPECT_TRUE(server_capabilities.contains("prompts")); EXPECT_TRUE(server_capabilities.contains("resources")); EXPECT_TRUE(server_capabilities.contains("tools")); } // 测试版本控制 class VersioningTest : public ::testing::Test { protected: void SetUp() override { // 设置测试环境 server_ = std::make_unique("localhost", 8081); server_->set_server_info("TestServer", "1.0.0"); // 设置服务器能力 json server_capabilities = { {"logging", json::object()}, {"prompts", {{"listChanged", true}}}, {"resources", {{"subscribe", true}, {"listChanged", true}}}, {"tools", {{"listChanged", true}}} }; server_->set_capabilities(server_capabilities); // 注册初始化方法处理器,检查版本 server_->register_method("initialize", [this, server_capabilities](const json& params) -> json { // 检查协议版本 std::string requested_version = params["protocolVersion"]; if (requested_version != MCP_VERSION) { throw mcp_exception(error_code::invalid_params, "Unsupported protocol version"); } return { {"protocolVersion", MCP_VERSION}, {"capabilities", server_capabilities}, {"serverInfo", { {"name", "TestServer"}, {"version", "1.0.0"} }} }; }); // 启动服务器(非阻塞模式) server_->start(false); } void TearDown() override { // 清理测试环境 server_->stop(); server_.reset(); } std::unique_ptr server_; }; // 测试支持的版本 TEST_F(VersioningTest, SupportedVersion) { // 创建使用正确版本的客户端 client client_correct("localhost", 8081); // 执行初始化 bool init_result = client_correct.initialize("TestClient", "1.0.0"); // 验证初始化结果 EXPECT_TRUE(init_result); } // 测试不支持的版本 TEST_F(VersioningTest, UnsupportedVersion) { // Use httplib::Client to send a request with an unsupported version // Note: Open SSE connection first httplib::Client client("localhost", 8081); auto sse_response = client.Get("/sse"); // EXPECT_EQ(sse_response->status, 200); std::string msg_endpoint = sse_response->body; json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json(); auto res = client.Post(msg_endpoint.c_str(), req.dump(), "application/json"); EXPECT_EQ(res->status, 400); try { auto mcp_res = response::from_json(json::parse(res->body)); EXPECT_EQ(mcp_res.error["code"], error_code::invalid_params); } catch (const mcp_exception& e) { EXPECT_TRUE(false); } } // 测试Ping功能 class PingTest : public ::testing::Test { protected: void SetUp() override { // 设置测试环境 server_ = std::make_unique("localhost", 8082); // 注册ping方法处理器 server_->register_method("ping", [](const json& params) -> json { return json::object(); // 返回空对象 }); // 启动服务器(非阻塞模式) server_->start(false); // 创建客户端 client_ = std::make_unique("localhost", 8082); } void TearDown() override { // 清理测试环境 server_->stop(); server_.reset(); client_.reset(); } std::unique_ptr server_; std::unique_ptr client_; }; // 测试Ping请求 TEST_F(PingTest, PingRequest) { // 发送ping请求 bool ping_result = client_->ping(); // 验证ping结果 EXPECT_TRUE(ping_result); } // 测试工具功能 class ToolsTest : public ::testing::Test { protected: void SetUp() override { // 设置测试环境 server_ = std::make_unique("localhost", 8083); // 创建一个测试工具 tool test_tool; test_tool.name = "get_weather"; test_tool.description = "Get current weather information for a location"; test_tool.parameters_schema = { {"type", "object"}, {"properties", { {"location", { {"type", "string"}, {"description", "City name or zip code"} }} }}, {"required", json::array({"location"})} }; // 注册工具 server_->register_tool(test_tool, [](const json& params) -> json { // 简单的工具实现 std::string location = params["location"]; return { {"content", json::array({ { {"type", "text"}, {"text", "Current weather in " + location + ":\nTemperature: 72°F\nConditions: Partly cloudy"} } })}, {"isError", false} }; }); // 注册工具列表方法 server_->register_method("tools/list", [this](const json& params) -> json { return { {"tools", json::array({ { {"name", "get_weather"}, {"description", "Get current weather information for a location"}, {"inputSchema", { {"type", "object"}, {"properties", { {"location", { {"type", "string"}, {"description", "City name or zip code"} }} }}, {"required", json::array({"location"})} }} } })}, {"nextCursor", nullptr} }; }); // 注册工具调用方法 server_->register_method("tools/call", [this](const json& params) -> json { // 验证参数 EXPECT_EQ(params["name"], "get_weather"); EXPECT_EQ(params["arguments"]["location"], "New York"); // 返回工具调用结果 return { {"content", json::array({ { {"type", "text"}, {"text", "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"} } })}, {"isError", false} }; }); // 启动服务器(非阻塞模式) server_->start(false); // 创建客户端 client_ = std::make_unique("localhost", 8083); client_->initialize("TestClient", "1.0.0"); } void TearDown() override { // 清理测试环境 server_->stop(); server_.reset(); client_.reset(); } std::unique_ptr server_; std::unique_ptr client_; }; // 测试列出工具 TEST_F(ToolsTest, ListTools) { // 调用列出工具方法 json tools_list = client_->send_request("tools/list").result; // 验证工具列表 EXPECT_TRUE(tools_list.contains("tools")); EXPECT_EQ(tools_list["tools"].size(), 1); EXPECT_EQ(tools_list["tools"][0]["name"], "get_weather"); EXPECT_EQ(tools_list["tools"][0]["description"], "Get current weather information for a location"); } // 测试调用工具 TEST_F(ToolsTest, CallTool) { // 调用工具 json tool_result = client_->call_tool("get_weather", {{"location", "New York"}}); // 验证工具调用结果 EXPECT_TRUE(tool_result.contains("content")); EXPECT_FALSE(tool_result["isError"]); EXPECT_EQ(tool_result["content"][0]["type"], "text"); EXPECT_EQ(tool_result["content"][0]["text"], "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"); } int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); }