413 lines
13 KiB
C++
413 lines
13 KiB
C++
|
/**
 * @file mcp_test.cpp
 * @brief Tests for the basic functionality of the MCP framework.
 *
 * This file contains tests for the MCP framework's message format, lifecycle,
 * versioning, ping, and tool features.
 */
|
||
|
|
||
|
#include <exception>
#include <memory>
#include <string>

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "mcp_message.h"
#include "mcp_client.h"
#include "mcp_server.h"
#include "mcp_tool.h"
|
||
|
|
||
|
using namespace mcp;
|
||
|
using json = nlohmann::ordered_json;
|
||
|
|
||
|
// 测试消息格式
|
||
|
class MessageFormatTest : public ::testing::Test {
|
||
|
protected:
|
||
|
void SetUp() override {
|
||
|
// 设置测试环境
|
||
|
}
|
||
|
|
||
|
void TearDown() override {
|
||
|
// 清理测试环境
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// 测试请求消息格式
|
||
|
TEST_F(MessageFormatTest, RequestMessageFormat) {
|
||
|
// 创建一个请求消息
|
||
|
request req = request::create("test_method", {{"key", "value"}});
|
||
|
|
||
|
// 转换为JSON
|
||
|
json req_json = req.to_json();
|
||
|
|
||
|
// 验证JSON格式是否符合规范
|
||
|
EXPECT_EQ(req_json["jsonrpc"], "2.0");
|
||
|
EXPECT_TRUE(req_json.contains("id"));
|
||
|
EXPECT_EQ(req_json["method"], "test_method");
|
||
|
EXPECT_EQ(req_json["params"]["key"], "value");
|
||
|
}
|
||
|
|
||
|
// 测试响应消息格式
|
||
|
TEST_F(MessageFormatTest, ResponseMessageFormat) {
|
||
|
// 创建一个成功响应
|
||
|
response res = response::create_success("test_id", {{"key", "value"}});
|
||
|
|
||
|
// 转换为JSON
|
||
|
json res_json = res.to_json();
|
||
|
|
||
|
// 验证JSON格式是否符合规范
|
||
|
EXPECT_EQ(res_json["jsonrpc"], "2.0");
|
||
|
EXPECT_EQ(res_json["id"], "test_id");
|
||
|
EXPECT_EQ(res_json["result"]["key"], "value");
|
||
|
EXPECT_FALSE(res_json.contains("error"));
|
||
|
}
|
||
|
|
||
|
// 测试错误响应消息格式
|
||
|
TEST_F(MessageFormatTest, ErrorResponseMessageFormat) {
|
||
|
// 创建一个错误响应
|
||
|
response res = response::create_error("test_id", error_code::invalid_params, "Invalid parameters", {{"details", "Missing required field"}});
|
||
|
|
||
|
// 转换为JSON
|
||
|
json res_json = res.to_json();
|
||
|
|
||
|
// 验证JSON格式是否符合规范
|
||
|
EXPECT_EQ(res_json["jsonrpc"], "2.0");
|
||
|
EXPECT_EQ(res_json["id"], "test_id");
|
||
|
EXPECT_FALSE(res_json.contains("result"));
|
||
|
EXPECT_EQ(res_json["error"]["code"], static_cast<int>(error_code::invalid_params));
|
||
|
EXPECT_EQ(res_json["error"]["message"], "Invalid parameters");
|
||
|
EXPECT_EQ(res_json["error"]["data"]["details"], "Missing required field");
|
||
|
}
|
||
|
|
||
|
// 测试通知消息格式
|
||
|
TEST_F(MessageFormatTest, NotificationMessageFormat) {
|
||
|
// 创建一个通知消息
|
||
|
request notification = request::create_notification("test_notification", {{"key", "value"}});
|
||
|
|
||
|
// 转换为JSON
|
||
|
json notification_json = notification.to_json();
|
||
|
|
||
|
// 验证JSON格式是否符合规范
|
||
|
EXPECT_EQ(notification_json["jsonrpc"], "2.0");
|
||
|
EXPECT_FALSE(notification_json.contains("id"));
|
||
|
EXPECT_EQ(notification_json["method"], "notifications/test_notification");
|
||
|
EXPECT_EQ(notification_json["params"]["key"], "value");
|
||
|
|
||
|
// 验证是否为通知消息
|
||
|
EXPECT_TRUE(notification.is_notification());
|
||
|
}
|
||
|
|
||
|
// 测试生命周期
|
||
|
class LifecycleTest : public ::testing::Test {
|
||
|
protected:
|
||
|
void SetUp() override {
|
||
|
// 设置测试环境
|
||
|
server_ = std::make_unique<server>("localhost", 8080);
|
||
|
server_->set_server_info("TestServer", "1.0.0");
|
||
|
|
||
|
// 设置服务器能力
|
||
|
json server_capabilities = {
|
||
|
{"logging", json::object()},
|
||
|
{"prompts", {{"listChanged", true}}},
|
||
|
{"resources", {{"subscribe", true}, {"listChanged", true}}},
|
||
|
{"tools", {{"listChanged", true}}}
|
||
|
};
|
||
|
server_->set_capabilities(server_capabilities);
|
||
|
|
||
|
// 注册初始化方法处理器
|
||
|
server_->register_method("initialize", [this, server_capabilities](const json& params) -> json {
|
||
|
// 验证初始化请求参数
|
||
|
EXPECT_EQ(params["protocolVersion"], MCP_VERSION);
|
||
|
EXPECT_TRUE(params.contains("capabilities"));
|
||
|
EXPECT_TRUE(params.contains("clientInfo"));
|
||
|
|
||
|
// 返回初始化响应
|
||
|
return {
|
||
|
{"protocolVersion", MCP_VERSION},
|
||
|
{"capabilities", server_capabilities},
|
||
|
{"serverInfo", {
|
||
|
{"name", "TestServer"},
|
||
|
{"version", "1.0.0"}
|
||
|
}}
|
||
|
};
|
||
|
});
|
||
|
|
||
|
// 启动服务器(非阻塞模式)
|
||
|
server_->start(false);
|
||
|
|
||
|
// 创建客户端
|
||
|
json client_capabilities = {
|
||
|
{"roots", {{"listChanged", true}}},
|
||
|
{"sampling", json::object()}
|
||
|
};
|
||
|
client_ = std::make_unique<client>("localhost", 8080);
|
||
|
client_->set_capabilities(client_capabilities);
|
||
|
}
|
||
|
|
||
|
void TearDown() override {
|
||
|
// 清理测试环境
|
||
|
server_->stop();
|
||
|
server_.reset();
|
||
|
client_.reset();
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<server> server_;
|
||
|
std::unique_ptr<client> client_;
|
||
|
};
|
||
|
|
||
|
// 测试初始化流程
|
||
|
TEST_F(LifecycleTest, InitializeProcess) {
|
||
|
// 执行初始化
|
||
|
bool init_result = client_->initialize("TestClient", "1.0.0");
|
||
|
|
||
|
// 验证初始化结果
|
||
|
EXPECT_TRUE(init_result);
|
||
|
|
||
|
// 验证服务器能力
|
||
|
json server_capabilities = client_->get_server_capabilities();
|
||
|
EXPECT_TRUE(server_capabilities.contains("logging"));
|
||
|
EXPECT_TRUE(server_capabilities.contains("prompts"));
|
||
|
EXPECT_TRUE(server_capabilities.contains("resources"));
|
||
|
EXPECT_TRUE(server_capabilities.contains("tools"));
|
||
|
}
|
||
|
|
||
|
// 测试版本控制
|
||
|
class VersioningTest : public ::testing::Test {
|
||
|
protected:
|
||
|
void SetUp() override {
|
||
|
// 设置测试环境
|
||
|
server_ = std::make_unique<server>("localhost", 8081);
|
||
|
server_->set_server_info("TestServer", "1.0.0");
|
||
|
|
||
|
// 设置服务器能力
|
||
|
json server_capabilities = {
|
||
|
{"logging", json::object()},
|
||
|
{"prompts", {{"listChanged", true}}},
|
||
|
{"resources", {{"subscribe", true}, {"listChanged", true}}},
|
||
|
{"tools", {{"listChanged", true}}}
|
||
|
};
|
||
|
server_->set_capabilities(server_capabilities);
|
||
|
|
||
|
// 注册初始化方法处理器,检查版本
|
||
|
server_->register_method("initialize", [this, server_capabilities](const json& params) -> json {
|
||
|
// 检查协议版本
|
||
|
std::string requested_version = params["protocolVersion"];
|
||
|
if (requested_version != MCP_VERSION) {
|
||
|
throw mcp_exception(error_code::invalid_params, "Unsupported protocol version");
|
||
|
}
|
||
|
|
||
|
return {
|
||
|
{"protocolVersion", MCP_VERSION},
|
||
|
{"capabilities", server_capabilities},
|
||
|
{"serverInfo", {
|
||
|
{"name", "TestServer"},
|
||
|
{"version", "1.0.0"}
|
||
|
}}
|
||
|
};
|
||
|
});
|
||
|
|
||
|
// 启动服务器(非阻塞模式)
|
||
|
server_->start(false);
|
||
|
}
|
||
|
|
||
|
void TearDown() override {
|
||
|
// 清理测试环境
|
||
|
server_->stop();
|
||
|
server_.reset();
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<server> server_;
|
||
|
};
|
||
|
|
||
|
// 测试支持的版本
|
||
|
TEST_F(VersioningTest, SupportedVersion) {
|
||
|
// 创建使用正确版本的客户端
|
||
|
client client_correct("localhost", 8081);
|
||
|
|
||
|
// 执行初始化
|
||
|
bool init_result = client_correct.initialize("TestClient", "1.0.0");
|
||
|
|
||
|
// 验证初始化结果
|
||
|
EXPECT_TRUE(init_result);
|
||
|
}
|
||
|
|
||
|
// 测试不支持的版本
|
||
|
TEST_F(VersioningTest, UnsupportedVersion) {
|
||
|
// Use httplib::Client to send a request with an unsupported version
|
||
|
// Note: Open SSE connection first
|
||
|
httplib::Client client("localhost", 8081);
|
||
|
auto sse_response = client.Get("/sse");
|
||
|
// EXPECT_EQ(sse_response->status, 200);
|
||
|
|
||
|
std::string msg_endpoint = sse_response->body;
|
||
|
|
||
|
json req = request::create("initialize", {{"protocolVersion", "0.0.1"}}).to_json();
|
||
|
auto res = client.Post(msg_endpoint.c_str(), req.dump(), "application/json");
|
||
|
EXPECT_EQ(res->status, 400);
|
||
|
try {
|
||
|
auto mcp_res = response::from_json(json::parse(res->body));
|
||
|
EXPECT_EQ(mcp_res.error["code"], error_code::invalid_params);
|
||
|
} catch (const mcp_exception& e) {
|
||
|
EXPECT_TRUE(false);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// 测试Ping功能
|
||
|
class PingTest : public ::testing::Test {
|
||
|
protected:
|
||
|
void SetUp() override {
|
||
|
// 设置测试环境
|
||
|
server_ = std::make_unique<server>("localhost", 8082);
|
||
|
|
||
|
// 注册ping方法处理器
|
||
|
server_->register_method("ping", [](const json& params) -> json {
|
||
|
return json::object(); // 返回空对象
|
||
|
});
|
||
|
|
||
|
// 启动服务器(非阻塞模式)
|
||
|
server_->start(false);
|
||
|
|
||
|
// 创建客户端
|
||
|
client_ = std::make_unique<client>("localhost", 8082);
|
||
|
}
|
||
|
|
||
|
void TearDown() override {
|
||
|
// 清理测试环境
|
||
|
server_->stop();
|
||
|
server_.reset();
|
||
|
client_.reset();
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<server> server_;
|
||
|
std::unique_ptr<client> client_;
|
||
|
};
|
||
|
|
||
|
// 测试Ping请求
|
||
|
TEST_F(PingTest, PingRequest) {
|
||
|
// 发送ping请求
|
||
|
bool ping_result = client_->ping();
|
||
|
|
||
|
// 验证ping结果
|
||
|
EXPECT_TRUE(ping_result);
|
||
|
}
|
||
|
|
||
|
// 测试工具功能
|
||
|
class ToolsTest : public ::testing::Test {
|
||
|
protected:
|
||
|
void SetUp() override {
|
||
|
// 设置测试环境
|
||
|
server_ = std::make_unique<server>("localhost", 8083);
|
||
|
|
||
|
// 创建一个测试工具
|
||
|
tool test_tool;
|
||
|
test_tool.name = "get_weather";
|
||
|
test_tool.description = "Get current weather information for a location";
|
||
|
test_tool.parameters_schema = {
|
||
|
{"type", "object"},
|
||
|
{"properties", {
|
||
|
{"location", {
|
||
|
{"type", "string"},
|
||
|
{"description", "City name or zip code"}
|
||
|
}}
|
||
|
}},
|
||
|
{"required", json::array({"location"})}
|
||
|
};
|
||
|
|
||
|
// 注册工具
|
||
|
server_->register_tool(test_tool, [](const json& params) -> json {
|
||
|
// 简单的工具实现
|
||
|
std::string location = params["location"];
|
||
|
return {
|
||
|
{"content", json::array({
|
||
|
{
|
||
|
{"type", "text"},
|
||
|
{"text", "Current weather in " + location + ":\nTemperature: 72°F\nConditions: Partly cloudy"}
|
||
|
}
|
||
|
})},
|
||
|
{"isError", false}
|
||
|
};
|
||
|
});
|
||
|
|
||
|
// 注册工具列表方法
|
||
|
server_->register_method("tools/list", [this](const json& params) -> json {
|
||
|
return {
|
||
|
{"tools", json::array({
|
||
|
{
|
||
|
{"name", "get_weather"},
|
||
|
{"description", "Get current weather information for a location"},
|
||
|
{"inputSchema", {
|
||
|
{"type", "object"},
|
||
|
{"properties", {
|
||
|
{"location", {
|
||
|
{"type", "string"},
|
||
|
{"description", "City name or zip code"}
|
||
|
}}
|
||
|
}},
|
||
|
{"required", json::array({"location"})}
|
||
|
}}
|
||
|
}
|
||
|
})},
|
||
|
{"nextCursor", nullptr}
|
||
|
};
|
||
|
});
|
||
|
|
||
|
// 注册工具调用方法
|
||
|
server_->register_method("tools/call", [this](const json& params) -> json {
|
||
|
// 验证参数
|
||
|
EXPECT_EQ(params["name"], "get_weather");
|
||
|
EXPECT_EQ(params["arguments"]["location"], "New York");
|
||
|
|
||
|
// 返回工具调用结果
|
||
|
return {
|
||
|
{"content", json::array({
|
||
|
{
|
||
|
{"type", "text"},
|
||
|
{"text", "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"}
|
||
|
}
|
||
|
})},
|
||
|
{"isError", false}
|
||
|
};
|
||
|
});
|
||
|
|
||
|
// 启动服务器(非阻塞模式)
|
||
|
server_->start(false);
|
||
|
|
||
|
// 创建客户端
|
||
|
client_ = std::make_unique<client>("localhost", 8083);
|
||
|
client_->initialize("TestClient", "1.0.0");
|
||
|
}
|
||
|
|
||
|
void TearDown() override {
|
||
|
// 清理测试环境
|
||
|
server_->stop();
|
||
|
server_.reset();
|
||
|
client_.reset();
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<server> server_;
|
||
|
std::unique_ptr<client> client_;
|
||
|
};
|
||
|
|
||
|
// 测试列出工具
|
||
|
TEST_F(ToolsTest, ListTools) {
|
||
|
// 调用列出工具方法
|
||
|
json tools_list = client_->send_request("tools/list").result;
|
||
|
|
||
|
// 验证工具列表
|
||
|
EXPECT_TRUE(tools_list.contains("tools"));
|
||
|
EXPECT_EQ(tools_list["tools"].size(), 1);
|
||
|
EXPECT_EQ(tools_list["tools"][0]["name"], "get_weather");
|
||
|
EXPECT_EQ(tools_list["tools"][0]["description"], "Get current weather information for a location");
|
||
|
}
|
||
|
|
||
|
// 测试调用工具
|
||
|
TEST_F(ToolsTest, CallTool) {
|
||
|
// 调用工具
|
||
|
json tool_result = client_->call_tool("get_weather", {{"location", "New York"}});
|
||
|
|
||
|
// 验证工具调用结果
|
||
|
EXPECT_TRUE(tool_result.contains("content"));
|
||
|
EXPECT_FALSE(tool_result["isError"]);
|
||
|
EXPECT_EQ(tool_result["content"][0]["type"], "text");
|
||
|
EXPECT_EQ(tool_result["content"][0]["text"], "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy");
|
||
|
}
|
||
|
|
||
|
int main(int argc, char **argv) {
|
||
|
::testing::InitGoogleTest(&argc, argv);
|
||
|
return RUN_ALL_TESTS();
|
||
|
}
|