cpp-mcp/test/test_mcp_tools_extended.cpp

357 lines
12 KiB
C++
Raw Normal View History

2025-03-12 22:45:17 +08:00
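/**
* @file test_mcp_tools_extended.cpp
* @brief Extended tests for MCP tool functionality
*
* Covers tool listing, tool invocation, parameter validation, and tool
* registration/unregistration against MCP protocol version 2024-11-05.
*/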
#include "mcp_tool.h"
#include "mcp_server.h"
#include "mcp_client.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <algorithm>
#include <cctype>
#include <chrono>
#include <future>
#include <memory>
#include <string>
#include <thread>
#include <vector>
// 测试类,用于设置服务器和客户端
class McpToolsExtendedTest : public ::testing::Test {
protected:
void SetUp() override {
// 创建服务器
server = std::make_unique<mcp::server>("localhost", 8098);
server->set_server_info("TestServer", "2024-11-05");
// 设置服务器能力
mcp::json capabilities = {
{"tools", {{"listChanged", true}}}
};
server->set_capabilities(capabilities);
// 注册计算器工具
mcp::tool calculator = mcp::tool_builder("calculator")
.with_description("计算器工具")
.with_string_param("operation", "操作类型 (add, subtract, multiply, divide)")
.with_number_param("a", "第一个操作数")
.with_number_param("b", "第二个操作数")
.build();
server->register_tool(calculator, [](const mcp::json& params) -> mcp::json {
std::string operation = params["operation"];
double a = params["a"];
double b = params["b"];
double result = 0;
if (operation == "add") {
result = a + b;
} else if (operation == "subtract") {
result = a - b;
} else if (operation == "multiply") {
result = a * b;
} else if (operation == "divide") {
if (b == 0) {
return {{"error", "除数不能为零"}};
}
result = a / b;
} else {
return {{"error", "未知操作: " + operation}};
}
return {{"result", result}};
});
// 注册文本处理工具
mcp::tool text_processor = mcp::tool_builder("text_processor")
.with_description("文本处理工具")
.with_string_param("text", "要处理的文本")
.with_string_param("operation", "操作类型 (uppercase, lowercase, reverse)")
.build();
server->register_tool(text_processor, [](const mcp::json& params) -> mcp::json {
std::string text = params["text"];
std::string operation = params["operation"];
std::string result;
if (operation == "uppercase") {
result = text;
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
} else if (operation == "lowercase") {
result = text;
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
} else if (operation == "reverse") {
result = text;
std::reverse(result.begin(), result.end());
} else {
return {{"error", "未知操作: " + operation}};
}
return {{"result", result}};
});
// 注册列表处理工具
mcp::tool list_processor = mcp::tool_builder("list_processor")
.with_description("列表处理工具")
.with_array_param("items", "要处理的项目列表", "string")
.with_string_param("operation", "操作类型 (sort, reverse, count)")
.build();
server->register_tool(list_processor, [](const mcp::json& params) -> mcp::json {
auto items = params["items"].get<std::vector<std::string>>();
std::string operation = params["operation"];
if (operation == "sort") {
std::sort(items.begin(), items.end());
return {{"result", items}};
} else if (operation == "reverse") {
std::reverse(items.begin(), items.end());
return {{"result", items}};
} else if (operation == "count") {
return {{"result", items.size()}};
} else {
return {{"error", "未知操作: " + operation}};
}
});
}
void TearDown() override {
// 停止服务器
if (server && server_thread.joinable()) {
server->stop();
server_thread.join();
}
}
// 启动服务器
void start_server() {
server_thread = std::thread([this]() {
server->start(false);
});
// 等待服务器启动
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::unique_ptr<mcp::server> server;
std::thread server_thread;
};
// Verifies that the tool listing returned to a client contains exactly the
// three tools the fixture registers.
TEST_F(McpToolsExtendedTest, GetToolsTest) {
    start_server();

    // Connect and perform the initialization handshake.
    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    const auto tools = client.get_tools();
    EXPECT_EQ(tools.size(), 3);

    // Compare by name only; listing order is not part of the expectation.
    std::vector<std::string> names;
    names.reserve(tools.size());
    for (const auto& entry : tools) {
        names.push_back(entry.name);
    }
    EXPECT_THAT(names, ::testing::UnorderedElementsAre("calculator", "text_processor", "list_processor"));
}
// Exercises every calculator operation through the client, including the
// divide-by-zero path, which must come back as an "error" field.
TEST_F(McpToolsExtendedTest, CalculatorToolTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    // 5 + 3
    mcp::json sum = client.call_tool("calculator", {
        {"operation", "add"},
        {"a", 5},
        {"b", 3}
    });
    EXPECT_EQ(sum["result"], 8);

    // 10 - 4
    mcp::json difference = client.call_tool("calculator", {
        {"operation", "subtract"},
        {"a", 10},
        {"b", 4}
    });
    EXPECT_EQ(difference["result"], 6);

    // 6 * 7
    mcp::json product = client.call_tool("calculator", {
        {"operation", "multiply"},
        {"a", 6},
        {"b", 7}
    });
    EXPECT_EQ(product["result"], 42);

    // 20 / 5
    mcp::json quotient = client.call_tool("calculator", {
        {"operation", "divide"},
        {"a", 20},
        {"b", 5}
    });
    EXPECT_EQ(quotient["result"], 4);

    // 10 / 0 must yield an error payload rather than a result.
    mcp::json zero_division = client.call_tool("calculator", {
        {"operation", "divide"},
        {"a", 10},
        {"b", 0}
    });
    EXPECT_TRUE(zero_division.contains("error"));
}
// Exercises the text_processor tool: uppercase, lowercase and reverse.
TEST_F(McpToolsExtendedTest, TextProcessorToolTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    // Uppercase conversion.
    mcp::json upper = client.call_tool("text_processor", {
        {"text", "Hello World"},
        {"operation", "uppercase"}
    });
    EXPECT_EQ(upper["result"], "HELLO WORLD");

    // Lowercase conversion.
    mcp::json lower = client.call_tool("text_processor", {
        {"text", "Hello World"},
        {"operation", "lowercase"}
    });
    EXPECT_EQ(lower["result"], "hello world");

    // Character-order reversal.
    mcp::json reversed = client.call_tool("text_processor", {
        {"text", "Hello World"},
        {"operation", "reverse"}
    });
    EXPECT_EQ(reversed["result"], "dlroW olleH");
}
// Exercises the list_processor tool: sort, reverse and count on a string list.
TEST_F(McpToolsExtendedTest, ListProcessorToolTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    // Shared input for all three operations.
    const std::vector<std::string> items = {"banana", "apple", "orange", "grape"};

    // Sorting yields ascending lexicographic order.
    mcp::json sort_response = client.call_tool("list_processor", {
        {"items", items},
        {"operation", "sort"}
    });
    const auto sorted = sort_response["result"].get<std::vector<std::string>>();
    EXPECT_THAT(sorted, ::testing::ElementsAre("apple", "banana", "grape", "orange"));

    // Reversal flips the original input order.
    mcp::json reverse_response = client.call_tool("list_processor", {
        {"items", items},
        {"operation", "reverse"}
    });
    const auto flipped = reverse_response["result"].get<std::vector<std::string>>();
    EXPECT_THAT(flipped, ::testing::ElementsAre("grape", "orange", "apple", "banana"));

    // Count reports the number of items supplied.
    mcp::json count_response = client.call_tool("list_processor", {
        {"items", items},
        {"operation", "count"}
    });
    EXPECT_EQ(count_response["result"], 4);
}
// Checks that invalid tool arguments surface to the client as an
// mcp_exception carrying error_code::invalid_params.
TEST_F(McpToolsExtendedTest, ToolParameterValidationTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    // Case 1: required parameters are missing.
    try {
        client.call_tool("calculator", {
            {"a", 5}
            // "operation" and "b" are intentionally omitted
        });
        FAIL() << "应该抛出异常,因为缺少必需参数";
    } catch (const mcp::mcp_exception& ex) {
        EXPECT_EQ(ex.code(), mcp::error_code::invalid_params);
    }

    // Case 2: a parameter has the wrong type.
    try {
        client.call_tool("calculator", {
            {"operation", "add"},
            {"a", "not_a_number"}, // should be a number
            {"b", 3}
        });
        FAIL() << "应该抛出异常,因为参数类型错误";
    } catch (const mcp::mcp_exception& ex) {
        EXPECT_EQ(ex.code(), mcp::error_code::invalid_params);
    }
}
// Walks a tool through its lifecycle on the tool_registry singleton:
// register, look up, call, unregister, then confirm it is gone.
// The fixture's server is not started by this test.
TEST_F(McpToolsExtendedTest, ToolRegistrationAndUnregistrationTest) {
    mcp::tool_registry& registry = mcp::tool_registry::instance();

    // Define a tool with a single string input.
    mcp::tool sample_tool = mcp::tool_builder("test_tool")
        .with_description("测试工具")
        .with_string_param("input", "输入参数")
        .build();

    // Register it with a handler that echoes the input behind a fixed prefix.
    registry.register_tool(sample_tool, [](const mcp::json& params) -> mcp::json {
        return {{"output", "处理结果: " + params["input"].get<std::string>()}};
    });

    // The registry must now know the tool and expose its metadata.
    auto found = registry.get_tool("test_tool");
    ASSERT_NE(found, nullptr);
    EXPECT_EQ(found->first.name, "test_tool");
    EXPECT_EQ(found->first.description, "测试工具");

    // Calling through the registry runs the handler.
    mcp::json response = registry.call_tool("test_tool", {{"input", "测试输入"}});
    EXPECT_EQ(response["output"], "处理结果: 测试输入");

    // Unregistering reports success and removes the entry.
    bool removed = registry.unregister_tool("test_tool");
    EXPECT_TRUE(removed);
    EXPECT_EQ(registry.get_tool("test_tool"), nullptr);

    // A call to the removed tool must fail with method_not_found.
    try {
        registry.call_tool("test_tool", {{"input", "测试输入"}});
        FAIL() << "应该抛出异常,因为工具已注销";
    } catch (const mcp::mcp_exception& ex) {
        EXPECT_EQ(ex.code(), mcp::error_code::method_not_found);
    }
}