cpp-mcp/test/test_mcp_tools_extended.cpp

357 lines
12 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/**
* @file test_mcp_tools_extended.cpp
* @brief 测试MCP工具相关功能的扩展测试
*
* 本文件包含对MCP工具模块的扩展单元测试,基于MCP规范 2024-11-05。
*/
#include "mcp_tool.h"
#include "mcp_server.h"
#include "mcp_client.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <algorithm>
#include <cctype>
#include <chrono>
#include <exception>
#include <future>
#include <memory>
#include <string>
#include <thread>
#include <vector>
// 测试类,用于设置服务器和客户端
class McpToolsExtendedTest : public ::testing::Test {
protected:
void SetUp() override {
// 创建服务器
server = std::make_unique<mcp::server>("localhost", 8098);
server->set_server_info("TestServer", "2024-11-05");
// 设置服务器能力
mcp::json capabilities = {
{"tools", {{"listChanged", true}}}
};
server->set_capabilities(capabilities);
// 注册计算器工具
mcp::tool calculator = mcp::tool_builder("calculator")
.with_description("计算器工具")
.with_string_param("operation", "操作类型 (add, subtract, multiply, divide)")
.with_number_param("a", "第一个操作数")
.with_number_param("b", "第二个操作数")
.build();
server->register_tool(calculator, [](const mcp::json& params) -> mcp::json {
std::string operation = params["operation"];
double a = params["a"];
double b = params["b"];
double result = 0;
if (operation == "add") {
result = a + b;
} else if (operation == "subtract") {
result = a - b;
} else if (operation == "multiply") {
result = a * b;
} else if (operation == "divide") {
if (b == 0) {
return {{"error", "除数不能为零"}};
}
result = a / b;
} else {
return {{"error", "未知操作: " + operation}};
}
return {{"result", result}};
});
// 注册文本处理工具
mcp::tool text_processor = mcp::tool_builder("text_processor")
.with_description("文本处理工具")
.with_string_param("text", "要处理的文本")
.with_string_param("operation", "操作类型 (uppercase, lowercase, reverse)")
.build();
server->register_tool(text_processor, [](const mcp::json& params) -> mcp::json {
std::string text = params["text"];
std::string operation = params["operation"];
std::string result;
if (operation == "uppercase") {
result = text;
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
} else if (operation == "lowercase") {
result = text;
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
} else if (operation == "reverse") {
result = text;
std::reverse(result.begin(), result.end());
} else {
return {{"error", "未知操作: " + operation}};
}
return {{"result", result}};
});
// 注册列表处理工具
mcp::tool list_processor = mcp::tool_builder("list_processor")
.with_description("列表处理工具")
.with_array_param("items", "要处理的项目列表", "string")
.with_string_param("operation", "操作类型 (sort, reverse, count)")
.build();
server->register_tool(list_processor, [](const mcp::json& params) -> mcp::json {
auto items = params["items"].get<std::vector<std::string>>();
std::string operation = params["operation"];
if (operation == "sort") {
std::sort(items.begin(), items.end());
return {{"result", items}};
} else if (operation == "reverse") {
std::reverse(items.begin(), items.end());
return {{"result", items}};
} else if (operation == "count") {
return {{"result", items.size()}};
} else {
return {{"error", "未知操作: " + operation}};
}
});
}
void TearDown() override {
// 停止服务器
if (server && server_thread.joinable()) {
server->stop();
server_thread.join();
}
}
// 启动服务器
void start_server() {
server_thread = std::thread([this]() {
server->start(false);
});
// 等待服务器启动
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::unique_ptr<mcp::server> server;
std::thread server_thread;
};
// Checks that the client lists exactly the three tools registered in SetUp().
TEST_F(McpToolsExtendedTest, GetToolsTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    const auto tools = client.get_tools();

    // Unsigned literal: assuming get_tools() returns a standard container,
    // size() is unsigned, and comparing it with a plain int inside EXPECT_EQ
    // triggers -Wsign-compare.
    EXPECT_EQ(tools.size(), 3u);

    // Collect the names and compare them as an unordered set.
    std::vector<std::string> tool_names;
    tool_names.reserve(tools.size());
    for (const auto& tool : tools) {
        tool_names.push_back(tool.name);
    }
    EXPECT_THAT(tool_names, ::testing::UnorderedElementsAre("calculator", "text_processor", "list_processor"));
}
// Exercises add / subtract / multiply / divide, plus the divide-by-zero
// error reply of the calculator tool.
TEST_F(McpToolsExtendedTest, CalculatorToolTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    // Sends one "calculator" request with the given operation and integer operands.
    auto calculate = [&client](const char* operation, int lhs, int rhs) {
        return client.call_tool("calculator", {
            {"operation", operation},
            {"a", lhs},
            {"b", rhs}
        });
    };

    mcp::json add_result = calculate("add", 5, 3);
    EXPECT_EQ(add_result["result"], 8);

    mcp::json subtract_result = calculate("subtract", 10, 4);
    EXPECT_EQ(subtract_result["result"], 6);

    mcp::json multiply_result = calculate("multiply", 6, 7);
    EXPECT_EQ(multiply_result["result"], 42);

    mcp::json divide_result = calculate("divide", 20, 5);
    EXPECT_EQ(divide_result["result"], 4);

    // Division by zero must come back as an error object, not a result.
    mcp::json divide_by_zero = calculate("divide", 10, 0);
    EXPECT_TRUE(divide_by_zero.contains("error"));
}
// Runs the uppercase / lowercase / reverse operations of text_processor on the
// same input string and checks each reply.
TEST_F(McpToolsExtendedTest, TextProcessorToolTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    const std::string input = "Hello World";

    // Sends one "text_processor" request applying `operation` to `input`.
    auto process_text = [&client, &input](const char* operation) {
        return client.call_tool("text_processor", {
            {"text", input},
            {"operation", operation}
        });
    };

    mcp::json uppercase_result = process_text("uppercase");
    EXPECT_EQ(uppercase_result["result"], "HELLO WORLD");

    mcp::json lowercase_result = process_text("lowercase");
    EXPECT_EQ(lowercase_result["result"], "hello world");

    mcp::json reverse_result = process_text("reverse");
    EXPECT_EQ(reverse_result["result"], "dlroW olleH");
}
// Runs the sort / reverse / count operations of list_processor on a fixed
// four-element list and checks each reply.
TEST_F(McpToolsExtendedTest, ListProcessorToolTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    const std::vector<std::string> items = {"banana", "apple", "orange", "grape"};

    // Sends one "list_processor" request applying `operation` to `items`.
    auto process_items = [&client, &items](const char* operation) {
        return client.call_tool("list_processor", {
            {"items", items},
            {"operation", operation}
        });
    };

    mcp::json sort_result = process_items("sort");
    const std::vector<std::string> sorted_items =
        sort_result["result"].get<std::vector<std::string>>();
    EXPECT_THAT(sorted_items, ::testing::ElementsAre("apple", "banana", "grape", "orange"));

    mcp::json reverse_result = process_items("reverse");
    const std::vector<std::string> reversed_items =
        reverse_result["result"].get<std::vector<std::string>>();
    EXPECT_THAT(reversed_items, ::testing::ElementsAre("grape", "orange", "apple", "banana"));

    mcp::json count_result = process_items("count");
    EXPECT_EQ(count_result["result"], 4);
}
// Verifies that malformed calculator arguments are rejected with an
// mcp_exception carrying error_code::invalid_params. Each scenario is checked
// independently: ADD_FAILURE() (unlike FAIL()) does not return from the test
// body, so a failure in the first scenario no longer hides the second.
TEST_F(McpToolsExtendedTest, ToolParameterValidationTest) {
    start_server();

    mcp::client client("localhost", 8098);
    client.set_timeout(5);
    client.initialize("TestClient", mcp::MCP_VERSION);

    // Calls the calculator with `args` and expects invalid_params; `reason` is
    // reported if the call succeeds or fails with an unexpected exception type.
    auto expect_invalid_params = [&client](const mcp::json& args, const char* reason) {
        try {
            client.call_tool("calculator", args);
            ADD_FAILURE() << reason;
        } catch (const mcp::mcp_exception& e) {
            EXPECT_EQ(e.code(), mcp::error_code::invalid_params) << reason;
        } catch (const std::exception& e) {
            ADD_FAILURE() << reason << " (unexpected exception: " << e.what() << ")";
        }
    };

    // Missing required parameters: "operation" and "b" are absent.
    expect_invalid_params({
        {"a", 5}
    }, "应该抛出异常,因为缺少必需参数");

    // Wrong parameter type: "a" must be a number.
    expect_invalid_params({
        {"operation", "add"},
        {"a", "not_a_number"},
        {"b", 3}
    }, "应该抛出异常,因为参数类型错误");
}
// Registers a tool directly in mcp::tool_registry, calls it, unregisters it,
// and checks that a later call fails with error_code::method_not_found.
TEST_F(McpToolsExtendedTest, ToolRegistrationAndUnregistrationTest) {
    mcp::tool_registry& registry = mcp::tool_registry::instance();

    // tool_registry::instance() appears to be a process-wide singleton, so make
    // sure "test_tool" cannot leak into other tests if this one aborts early
    // (e.g. an exception escaping registry.call_tool). The guard runs again after
    // the explicit unregister below; NOTE(review): this presumes unregister_tool
    // simply returns false for an unknown tool and never throws — confirm.
    struct UnregisterOnExit {
        mcp::tool_registry& reg;
        ~UnregisterOnExit() { static_cast<void>(reg.unregister_tool("test_tool")); }
    } cleanup{registry};

    mcp::tool test_tool = mcp::tool_builder("test_tool")
        .with_description("测试工具")
        .with_string_param("input", "输入参数")
        .build();

    registry.register_tool(test_tool, [](const mcp::json& params) -> mcp::json {
        return {{"output", "处理结果: " + params["input"].get<std::string>()}};
    });

    // The tool must be retrievable with its metadata intact.
    auto registered_tool = registry.get_tool("test_tool");
    ASSERT_NE(registered_tool, nullptr);
    EXPECT_EQ(registered_tool->first.name, "test_tool");
    EXPECT_EQ(registered_tool->first.description, "测试工具");

    mcp::json result = registry.call_tool("test_tool", {{"input", "测试输入"}});
    EXPECT_EQ(result["output"], "处理结果: 测试输入");

    bool unregistered = registry.unregister_tool("test_tool");
    EXPECT_TRUE(unregistered);
    EXPECT_EQ(registry.get_tool("test_tool"), nullptr);

    // Calling an unregistered tool must be reported as method_not_found.
    try {
        registry.call_tool("test_tool", {{"input", "测试输入"}});
        FAIL() << "应该抛出异常,因为工具已注销";
    } catch (const mcp::mcp_exception& e) {
        EXPECT_EQ(e.code(), mcp::error_code::method_not_found);
    }
}