Compare commits
3 Commits
3cd85f0b3b
...
1b54308d81
Author | SHA1 | Date |
---|---|---|
|
1b54308d81 | |
|
0814f1a6a7 | |
|
61264dfb49 |
|
@ -33,9 +33,5 @@ if(MCP_BUILD_TESTS)
|
|||
enable_testing()
|
||||
add_subdirectory(test)
|
||||
|
||||
# Add custom test target
|
||||
add_custom_target(run_tests
|
||||
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
|
||||
COMMENT "Running MCP tests..."
|
||||
)
|
||||
# 注意:run_tests目标已在test/CMakeLists.txt中定义,此处不再重复定义
|
||||
endif()
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
|
||||
int main() {
|
||||
// Create a client
|
||||
mcp::client client("localhost", 8089);
|
||||
mcp::client client("localhost", 8888);
|
||||
|
||||
// Set capabilites
|
||||
mcp::json capabilities = {
|
||||
|
@ -53,6 +53,8 @@ int main() {
|
|||
std::cout << "- " << tool.name << ": " << tool.description << std::endl;
|
||||
}
|
||||
|
||||
// Get available resources
|
||||
|
||||
// Call the get_time tool
|
||||
std::cout << "\nCalling get_time tool..." << std::endl;
|
||||
mcp::json time_result = client.call_tool("get_time");
|
||||
|
@ -76,17 +78,6 @@ int main() {
|
|||
};
|
||||
mcp::json calc_result = client.call_tool("calculator", calc_params);
|
||||
std::cout << "10 + 5 = " << calc_result["content"][0]["text"].get<std::string>() << std::endl;
|
||||
|
||||
// Not implemented yet
|
||||
// // Access a resource
|
||||
// std::cout << "\nAccessing API resource..." << std::endl;
|
||||
// mcp::json api_params = {
|
||||
// {"endpoint", "hello"},
|
||||
// {"name", "MCP Client"}
|
||||
// };
|
||||
// mcp::json api_result = client.access_resource("/api", api_params);
|
||||
// std::cout << "API response: " << api_result["contents"][0]["text"].get<std::string>() << std::endl;
|
||||
|
||||
} catch (const mcp::mcp_exception& e) {
|
||||
std::cerr << "MCP error: " << e.what() << " (code: " << static_cast<int>(e.code()) << ")" << std::endl;
|
||||
return 1;
|
||||
|
|
|
@ -122,12 +122,13 @@ int main() {
|
|||
std::filesystem::create_directories("./files");
|
||||
|
||||
// Create and configure server
|
||||
mcp::server server("localhost", 8089);
|
||||
mcp::server server("localhost", 8888);
|
||||
server.set_server_info("ExampleServer", "1.0.0");
|
||||
|
||||
// Set server capabilities
|
||||
mcp::json capabilities = {
|
||||
{"tools", {{"listChanged", true}}}
|
||||
{"tools", {{"listChanged", true}}},
|
||||
{"resources", {{"subscribe", false}, {"listChanged", true}}}
|
||||
};
|
||||
server.set_capabilities(capabilities);
|
||||
|
||||
|
@ -154,18 +155,12 @@ int main() {
|
|||
server.register_tool(echo_tool, echo_handler);
|
||||
server.register_tool(calc_tool, calculator_handler);
|
||||
|
||||
// Not implemented yet
|
||||
// // Register resources
|
||||
// auto file_resource = std::make_shared<mcp::file_resource>("./files");
|
||||
// server.register_resource("/files", file_resource);
|
||||
|
||||
// auto api_resource = std::make_shared<mcp::api_resource>("API", "Custom API endpoints");
|
||||
// api_resource->register_handler("hello", hello_handler, "Say hello");
|
||||
|
||||
// server.register_resource("/api", api_resource);
|
||||
// Register resources
|
||||
auto file_resource = std::make_shared<mcp::file_resource>("./Makefile");
|
||||
server.register_resource("file://./Makefile", file_resource);
|
||||
|
||||
// Start server
|
||||
std::cout << "Starting MCP server at localhost:8089..." << std::endl;
|
||||
std::cout << "Starting MCP server at localhost:8888..." << std::endl;
|
||||
std::cout << "Press Ctrl+C to stop the server" << std::endl;
|
||||
server.start(true); // Blocking mode
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
#include "mcp_message.h"
|
||||
#include "mcp_tool.h"
|
||||
#include "mcp_logger.h"
|
||||
|
||||
// Include the HTTP library
|
||||
#include "httplib.h"
|
||||
|
@ -23,6 +24,7 @@
|
|||
#include <functional>
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
|
||||
namespace mcp {
|
||||
|
||||
|
@ -173,43 +175,76 @@ public:
|
|||
bool check_server_accessible();
|
||||
|
||||
private:
|
||||
std::string base_url_;
|
||||
// 初始化HTTP客户端
|
||||
void init_client(const std::string& host, int port);
|
||||
void init_client(const std::string& base_url);
|
||||
|
||||
// 打开SSE连接
|
||||
void open_sse_connection();
|
||||
|
||||
// 解析SSE数据
|
||||
bool parse_sse_data(const char* data, size_t length);
|
||||
|
||||
// 关闭SSE连接
|
||||
void close_sse_connection();
|
||||
|
||||
// 发送JSON-RPC请求
|
||||
json send_jsonrpc(const request& req);
|
||||
|
||||
// 服务器主机和端口
|
||||
std::string host_;
|
||||
int port_;
|
||||
std::string sse_endpoint_;
|
||||
int port_ = 8080;
|
||||
|
||||
// 或者使用基础URL
|
||||
std::string base_url_;
|
||||
|
||||
// SSE端点
|
||||
std::string sse_endpoint_ = "/sse";
|
||||
|
||||
// 消息端点
|
||||
std::string msg_endpoint_;
|
||||
std::string auth_token_;
|
||||
int timeout_seconds_ = 30;
|
||||
json capabilities_;
|
||||
|
||||
std::map<std::string, std::string> default_headers_;
|
||||
json server_capabilities_;
|
||||
|
||||
|
||||
// HTTP client
|
||||
// HTTP客户端
|
||||
std::unique_ptr<httplib::Client> http_client_;
|
||||
|
||||
// Mutex for thread safety
|
||||
// SSE专用HTTP客户端
|
||||
std::unique_ptr<httplib::Client> sse_client_;
|
||||
|
||||
// SSE线程
|
||||
std::unique_ptr<std::thread> sse_thread_;
|
||||
|
||||
// SSE运行状态
|
||||
std::atomic<bool> sse_running_{false};
|
||||
|
||||
// 认证令牌
|
||||
std::string auth_token_;
|
||||
|
||||
// 默认请求头
|
||||
std::map<std::string, std::string> default_headers_;
|
||||
|
||||
// 超时设置(秒)
|
||||
int timeout_seconds_ = 30;
|
||||
|
||||
// 客户端能力
|
||||
json capabilities_;
|
||||
|
||||
// 服务器能力
|
||||
json server_capabilities_;
|
||||
|
||||
// 互斥锁
|
||||
mutable std::mutex mutex_;
|
||||
|
||||
// 条件变量,用于等待消息端点设置
|
||||
std::condition_variable endpoint_cv_;
|
||||
|
||||
// SSE connection
|
||||
std::unique_ptr<std::thread> sse_thread_;
|
||||
// 请求ID到Promise的映射,用于异步等待响应
|
||||
std::map<json, std::promise<json>> pending_requests_;
|
||||
|
||||
// SSE连接状态
|
||||
std::atomic<bool> sse_running_{false};
|
||||
// 响应处理互斥锁
|
||||
std::mutex response_mutex_;
|
||||
|
||||
// Initialize the client
|
||||
void init_client(const std::string& host, int port);
|
||||
void init_client(const std::string& base_url);
|
||||
void open_sse_connection();
|
||||
void close_sse_connection();
|
||||
bool parse_sse_data(const char* data, size_t length);
|
||||
|
||||
// Send a JSON-RPC request and get the response
|
||||
json send_jsonrpc(const request& req);
|
||||
// 响应条件变量
|
||||
std::condition_variable response_cv_;
|
||||
};
|
||||
|
||||
} // namespace mcp
|
||||
|
|
|
@ -83,19 +83,19 @@ private:
|
|||
|
||||
ss << std::put_time(now_tm, "%Y-%m-%d %H:%M:%S") << " ";
|
||||
|
||||
// 添加日志级别
|
||||
// 添加日志级别和颜色
|
||||
switch (level) {
|
||||
case log_level::debug:
|
||||
ss << "[DEBUG] ";
|
||||
ss << "\033[36m[DEBUG]\033[0m "; // 青色
|
||||
break;
|
||||
case log_level::info:
|
||||
ss << "[INFO] ";
|
||||
ss << "\033[32m[INFO]\033[0m "; // 绿色
|
||||
break;
|
||||
case log_level::warning:
|
||||
ss << "[WARNING] ";
|
||||
ss << "\033[33m[WARNING]\033[0m "; // 黄色
|
||||
break;
|
||||
case log_level::error:
|
||||
ss << "[ERROR] ";
|
||||
ss << "\033[31m[ERROR]\033[0m "; // 红色
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
@ -79,13 +79,11 @@ public:
|
|||
try {
|
||||
bool write_result = sink->write(message_copy.data(), message_copy.size());
|
||||
if (!write_result) {
|
||||
std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl;
|
||||
close();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "写入事件数据失败: " << e.what() << std::endl;
|
||||
close();
|
||||
return false;
|
||||
}
|
||||
|
@ -232,13 +230,6 @@ public:
|
|||
*/
|
||||
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object());
|
||||
|
||||
/**
|
||||
* @brief 打印服务器状态
|
||||
*
|
||||
* 打印当前服务器的状态,包括活跃的会话、注册的方法等
|
||||
*/
|
||||
void print_status() const;
|
||||
|
||||
private:
|
||||
std::string host_;
|
||||
int port_;
|
||||
|
|
|
@ -31,7 +31,7 @@ struct tool {
|
|||
return {
|
||||
{"name", name},
|
||||
{"description", description},
|
||||
{"parameters", parameters_schema}
|
||||
{"inputSchema", parameters_schema} // You may need 'parameters' instead of 'inputSchema' for OAI format
|
||||
};
|
||||
}
|
||||
};
|
||||
|
|
|
@ -23,42 +23,41 @@ client::client(const std::string& base_url, const json& capabilities, const std:
|
|||
}
|
||||
|
||||
client::~client() {
|
||||
// 关闭SSE连接
|
||||
close_sse_connection();
|
||||
|
||||
// httplib::Client将自动销毁
|
||||
}
|
||||
|
||||
void client::init_client(const std::string& host, int port) {
|
||||
// Create the HTTP client
|
||||
http_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
|
||||
sse_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
|
||||
|
||||
// Set timeout
|
||||
http_client_->set_connection_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_read_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
|
||||
sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
|
||||
sse_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
}
|
||||
|
||||
void client::init_client(const std::string& base_url) {
|
||||
// Create the HTTP client
|
||||
http_client_ = std::make_unique<httplib::Client>(base_url.c_str());
|
||||
sse_client_ = std::make_unique<httplib::Client>(base_url.c_str());
|
||||
|
||||
// Set timeout
|
||||
http_client_->set_connection_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_read_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
|
||||
sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
|
||||
sse_client_->set_read_timeout(0, 0);
|
||||
sse_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
}
|
||||
|
||||
bool client::initialize(const std::string& client_name, const std::string& client_version) {
|
||||
std::cerr << "开始初始化MCP客户端..." << std::endl;
|
||||
LOG_INFO("Initializing MCP client...");
|
||||
|
||||
// 检查服务器是否可访问
|
||||
if (!check_server_accessible()) {
|
||||
std::cerr << "服务器不可访问,初始化失败" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create initialization request
|
||||
request req = request::create("initialize", {
|
||||
{"protocolVersion", MCP_VERSION},
|
||||
{"capabilities", capabilities_},
|
||||
|
@ -69,88 +68,63 @@ bool client::initialize(const std::string& client_name, const std::string& clien
|
|||
});
|
||||
|
||||
try {
|
||||
// 打开SSE连接
|
||||
std::cerr << "正在打开SSE连接..." << std::endl;
|
||||
LOG_INFO("Opening SSE connection...");
|
||||
open_sse_connection();
|
||||
|
||||
// 等待SSE连接建立并获取消息端点
|
||||
// 使用条件变量和超时机制
|
||||
const auto timeout = std::chrono::milliseconds(5000); // 5秒超时
|
||||
const auto timeout = std::chrono::milliseconds(5000);
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
// 检查初始状态
|
||||
if (!msg_endpoint_.empty()) {
|
||||
std::cerr << "消息端点已经设置: " << msg_endpoint_ << std::endl;
|
||||
} else {
|
||||
std::cerr << "等待条件变量..." << std::endl;
|
||||
}
|
||||
|
||||
bool success = endpoint_cv_.wait_for(lock, timeout, [this]() {
|
||||
if (!sse_running_) {
|
||||
std::cerr << "SSE连接已关闭,停止等待" << std::endl;
|
||||
LOG_WARNING("SSE connection closed, stopping wait");
|
||||
return true;
|
||||
}
|
||||
if (!msg_endpoint_.empty()) {
|
||||
std::cerr << "消息端点已设置,停止等待" << std::endl;
|
||||
LOG_INFO("Message endpoint set, stopping wait");
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
// 检查等待结果
|
||||
if (!success) {
|
||||
std::cerr << "条件变量等待超时" << std::endl;
|
||||
LOG_WARNING("Condition variable wait timed out");
|
||||
}
|
||||
|
||||
// 如果SSE连接已关闭或等待超时,抛出异常
|
||||
if (!sse_running_) {
|
||||
throw std::runtime_error("SSE连接已关闭,未能获取消息端点");
|
||||
throw std::runtime_error("SSE connection closed, failed to get message endpoint");
|
||||
}
|
||||
|
||||
if (msg_endpoint_.empty()) {
|
||||
throw std::runtime_error("等待SSE连接超时,未能获取消息端点");
|
||||
throw std::runtime_error("Timeout waiting for SSE connection, failed to get message endpoint");
|
||||
}
|
||||
|
||||
std::cerr << "成功获取消息端点: " << msg_endpoint_ << std::endl;
|
||||
LOG_INFO("Successfully got message endpoint: ", msg_endpoint_);
|
||||
}
|
||||
|
||||
// 发送初始化请求
|
||||
json result = send_jsonrpc(req);
|
||||
|
||||
// 存储服务器能力
|
||||
server_capabilities_ = result["capabilities"];
|
||||
|
||||
// 发送已初始化通知
|
||||
request notification = request::create_notification("initialized");
|
||||
send_jsonrpc(notification);
|
||||
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
// 初始化失败,关闭SSE连接
|
||||
std::cerr << "初始化失败: " << e.what() << std::endl;
|
||||
LOG_ERROR("Initialization failed: ", e.what());
|
||||
close_sse_connection();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool client::ping() {
|
||||
// Create ping request
|
||||
request req = request::create("ping", {});
|
||||
|
||||
try {
|
||||
// Send the request
|
||||
json result = send_jsonrpc(req);
|
||||
|
||||
// The receiver MUST respond promptly with an empty response
|
||||
if (result.empty()) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return result.empty();
|
||||
} catch (const std::exception& e) {
|
||||
// Ping failed
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -158,24 +132,34 @@ bool client::ping() {
|
|||
void client::set_auth_token(const std::string& token) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auth_token_ = token;
|
||||
|
||||
// Add to default headers
|
||||
set_header("Authorization", "Bearer " + auth_token_);
|
||||
}
|
||||
|
||||
void client::set_header(const std::string& key, const std::string& value) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
default_headers_[key] = value;
|
||||
|
||||
if (http_client_) {
|
||||
http_client_->set_default_headers({{key, value}});
|
||||
}
|
||||
if (sse_client_) {
|
||||
sse_client_->set_default_headers({{key, value}});
|
||||
}
|
||||
}
|
||||
|
||||
void client::set_timeout(int timeout_seconds) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
timeout_seconds_ = timeout_seconds;
|
||||
|
||||
// Update the client's timeout
|
||||
http_client_->set_connection_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_read_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
if (http_client_) {
|
||||
http_client_->set_connection_timeout(timeout_seconds_, 0);
|
||||
http_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
}
|
||||
|
||||
if (sse_client_) {
|
||||
sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
|
||||
sse_client_->set_write_timeout(timeout_seconds_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void client::set_capabilities(const json& capabilities) {
|
||||
|
@ -212,21 +196,28 @@ json client::call_tool(const std::string& tool_name, const json& arguments) {
|
|||
}
|
||||
|
||||
std::vector<tool> client::get_tools() {
|
||||
json tools_json = send_request("tools/list", {}).result;
|
||||
json response_json = send_request("tools/list", {}).result;
|
||||
std::vector<tool> tools;
|
||||
|
||||
if (tools_json.is_array()) {
|
||||
for (const auto& tool_json : tools_json) {
|
||||
tool t;
|
||||
t.name = tool_json["name"];
|
||||
t.description = tool_json["description"];
|
||||
json tools_json;
|
||||
if (response_json.contains("tools") && response_json["tools"].is_array()) {
|
||||
tools_json = response_json["tools"];
|
||||
} else if (response_json.is_array()) {
|
||||
tools_json = response_json;
|
||||
} else {
|
||||
return tools;
|
||||
}
|
||||
|
||||
if (tool_json.contains("inputSchema")) {
|
||||
t.parameters_schema = tool_json["inputSchema"];
|
||||
}
|
||||
for (const auto& tool_json : tools_json) {
|
||||
tool t;
|
||||
t.name = tool_json["name"];
|
||||
t.description = tool_json["description"];
|
||||
|
||||
tools.push_back(t);
|
||||
if (tool_json.contains("inputSchema")) {
|
||||
t.parameters_schema = tool_json["inputSchema"];
|
||||
}
|
||||
|
||||
tools.push_back(t);
|
||||
}
|
||||
|
||||
return tools;
|
||||
|
@ -261,321 +252,334 @@ json client::list_resource_templates() {
|
|||
}
|
||||
|
||||
void client::open_sse_connection() {
|
||||
// 设置SSE连接状态为运行中
|
||||
sse_running_ = true;
|
||||
|
||||
// 清空消息端点
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
msg_endpoint_.clear();
|
||||
|
||||
// 通知等待的线程(虽然消息端点为空,但可以让等待的线程检查sse_running_状态)
|
||||
endpoint_cv_.notify_all();
|
||||
}
|
||||
|
||||
// 打印连接信息(调试用)
|
||||
std::string connection_info;
|
||||
if (!base_url_.empty()) {
|
||||
connection_info = "Base URL: " + base_url_ + ", SSE Endpoint: " + sse_endpoint_;
|
||||
} else {
|
||||
connection_info = "Host: " + host_ + ", Port: " + std::to_string(port_) + ", SSE Endpoint: " + sse_endpoint_;
|
||||
}
|
||||
std::cerr << "尝试建立SSE连接: " << connection_info << std::endl;
|
||||
LOG_INFO("Attempting to establish SSE connection: ", connection_info);
|
||||
|
||||
// 创建并启动SSE线程
|
||||
sse_thread_ = std::make_unique<std::thread>([this]() {
|
||||
int retry_count = 0;
|
||||
const int max_retries = 5;
|
||||
const int retry_delay_base = 1000; // 毫秒
|
||||
const int retry_delay_base = 1000;
|
||||
|
||||
while (sse_running_) {
|
||||
try {
|
||||
// 尝试建立SSE连接
|
||||
std::cerr << "SSE线程: 尝试连接到 " << sse_endpoint_ << std::endl;
|
||||
LOG_INFO("SSE thread: Attempting to connect to ", sse_endpoint_);
|
||||
|
||||
auto res = http_client_->Get(sse_endpoint_.c_str(),
|
||||
auto res = sse_client_->Get(sse_endpoint_.c_str(),
|
||||
[this](const char *data, size_t data_length) {
|
||||
// 解析SSE数据
|
||||
std::cerr << "SSE线程: 收到数据,长度: " << data_length << std::endl;
|
||||
if (!parse_sse_data(data, data_length)) {
|
||||
std::cerr << "SSE线程: 解析数据失败" << std::endl;
|
||||
return false; // 解析失败,关闭连接
|
||||
LOG_ERROR("SSE thread: Failed to parse data");
|
||||
return false;
|
||||
}
|
||||
|
||||
// 检查是否应该关闭连接
|
||||
bool should_continue = sse_running_.load();
|
||||
if (!should_continue) {
|
||||
std::cerr << "SSE线程: sse_running_为false,关闭连接" << std::endl;
|
||||
LOG_INFO("SSE thread: sse_running_ is false, closing connection");
|
||||
}
|
||||
return should_continue; // 如果sse_running_为false,关闭连接
|
||||
return should_continue;
|
||||
});
|
||||
|
||||
// 检查连接是否成功
|
||||
if (!res) {
|
||||
std::string error_msg = "SSE连接失败: ";
|
||||
error_msg += "错误代码: " + std::to_string(static_cast<int>(res.error()));
|
||||
|
||||
// 添加更详细的错误信息
|
||||
switch (res.error()) {
|
||||
case httplib::Error::Connection:
|
||||
error_msg += " (连接错误,服务器可能未运行或无法访问)";
|
||||
break;
|
||||
case httplib::Error::Read:
|
||||
error_msg += " (读取错误,服务器可能关闭了连接或响应格式不正确)";
|
||||
break;
|
||||
case httplib::Error::Write:
|
||||
error_msg += " (写入错误)";
|
||||
break;
|
||||
case httplib::Error::ConnectionTimeout:
|
||||
error_msg += " (连接超时)";
|
||||
break;
|
||||
case httplib::Error::Canceled:
|
||||
error_msg += " (请求被取消)";
|
||||
break;
|
||||
default:
|
||||
error_msg += " (未知错误)";
|
||||
break;
|
||||
}
|
||||
|
||||
std::string error_msg = "SSE connection failed: ";
|
||||
error_msg += httplib::to_string(res.error());
|
||||
throw std::runtime_error(error_msg);
|
||||
}
|
||||
|
||||
// 连接成功后重置重试计数
|
||||
retry_count = 0;
|
||||
std::cerr << "SSE线程: 连接成功" << std::endl;
|
||||
LOG_INFO("SSE thread: Connection successful");
|
||||
} catch (const std::exception& e) {
|
||||
// 记录错误
|
||||
std::cerr << "SSE连接错误: " << e.what() << std::endl;
|
||||
LOG_ERROR("SSE connection error: ", e.what());
|
||||
|
||||
// 如果已达到最大重试次数,停止尝试
|
||||
if (++retry_count > max_retries) {
|
||||
std::cerr << "达到最大重试次数,停止SSE连接尝试" << std::endl;
|
||||
if (!sse_running_) {
|
||||
LOG_INFO("SSE connection actively closed, no retry needed");
|
||||
break;
|
||||
}
|
||||
|
||||
// 指数退避重试
|
||||
int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay
|
||||
std::cerr << "将在 " << delay << " 毫秒后重试 (尝试 " << retry_count << "/" << max_retries << ")" << std::endl;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(delay));
|
||||
if (++retry_count > max_retries) {
|
||||
LOG_ERROR("Maximum retry count reached, stopping SSE connection attempts");
|
||||
break;
|
||||
}
|
||||
|
||||
int delay = retry_delay_base * (1 << (retry_count - 1));
|
||||
LOG_INFO("Will retry in ", delay, " ms (attempt ", retry_count, "/", max_retries, ")");
|
||||
|
||||
const int check_interval = 100;
|
||||
for (int waited = 0; waited < delay && sse_running_; waited += check_interval) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(check_interval));
|
||||
}
|
||||
|
||||
if (!sse_running_) {
|
||||
LOG_INFO("SSE connection actively closed during retry wait, stopping retry");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "SSE线程: 退出" << std::endl;
|
||||
LOG_INFO("SSE thread: Exiting");
|
||||
});
|
||||
}
|
||||
|
||||
// 新增方法:解析SSE数据
|
||||
bool client::parse_sse_data(const char* data, size_t length) {
|
||||
try {
|
||||
std::string sse_data(data, length);
|
||||
|
||||
// 查找"data:"标记
|
||||
std::string event_type = "message";
|
||||
auto event_pos = sse_data.find("event: ");
|
||||
if (event_pos != std::string::npos) {
|
||||
auto event_end = sse_data.find("\n", event_pos);
|
||||
if (event_end != std::string::npos) {
|
||||
event_type = sse_data.substr(event_pos + 7, event_end - (event_pos + 7));
|
||||
if (!event_type.empty() && event_type.back() == '\r') {
|
||||
event_type.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto data_pos = sse_data.find("data: ");
|
||||
if (data_pos == std::string::npos) {
|
||||
return true; // 不是数据事件,可能是注释或心跳,继续保持连接
|
||||
}
|
||||
|
||||
// 查找数据行结束位置
|
||||
auto newline_pos = sse_data.find("\n", data_pos);
|
||||
if (newline_pos == std::string::npos) {
|
||||
newline_pos = sse_data.length(); // 如果没有换行符,使用整个字符串
|
||||
}
|
||||
|
||||
// 提取数据内容
|
||||
std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6));
|
||||
|
||||
// 检查是否是心跳事件
|
||||
if (sse_data.find("event: heartbeat") != std::string::npos) {
|
||||
// 心跳事件,不需要处理数据
|
||||
return true;
|
||||
}
|
||||
|
||||
// 更新消息端点
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
msg_endpoint_ = data_content;
|
||||
|
||||
// 通知等待的线程
|
||||
endpoint_cv_.notify_all();
|
||||
auto newline_pos = sse_data.find("\n", data_pos);
|
||||
if (newline_pos == std::string::npos) {
|
||||
newline_pos = sse_data.length();
|
||||
}
|
||||
|
||||
return true;
|
||||
std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6));
|
||||
|
||||
if (event_type == "heartbeat") {
|
||||
return true;
|
||||
} else if (event_type == "endpoint") {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
msg_endpoint_ = data_content;
|
||||
endpoint_cv_.notify_all();
|
||||
return true;
|
||||
} else if (event_type == "message") {
|
||||
try {
|
||||
json response = json::parse(data_content);
|
||||
|
||||
if (response.contains("jsonrpc") && response.contains("id") && !response["id"].is_null()) {
|
||||
json id = response["id"];
|
||||
|
||||
std::lock_guard<std::mutex> lock(response_mutex_);
|
||||
auto it = pending_requests_.find(id);
|
||||
if (it != pending_requests_.end()) {
|
||||
if (response.contains("result")) {
|
||||
it->second.set_value(response["result"]);
|
||||
} else if (response.contains("error")) {
|
||||
json error_result = {
|
||||
{"isError", true},
|
||||
{"error", response["error"]}
|
||||
};
|
||||
it->second.set_value(error_result);
|
||||
} else {
|
||||
it->second.set_value(json::object());
|
||||
}
|
||||
|
||||
pending_requests_.erase(it);
|
||||
} else {
|
||||
LOG_WARNING("Received response for unknown request ID: ", id);
|
||||
}
|
||||
} else {
|
||||
LOG_WARNING("Received invalid JSON-RPC response: ", response.dump());
|
||||
}
|
||||
} catch (const json::exception& e) {
|
||||
LOG_ERROR("Failed to parse JSON-RPC response: ", e.what());
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
LOG_WARNING("Received unknown event type: ", event_type);
|
||||
return true;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "解析SSE数据错误: " << e.what() << std::endl;
|
||||
LOG_ERROR("Error parsing SSE data: ", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 新增方法:关闭SSE连接
|
||||
void client::close_sse_connection() {
|
||||
// 设置标志,这将导致SSE回调函数返回false,从而关闭连接
|
||||
if (!sse_running_) {
|
||||
LOG_INFO("SSE connection already closed");
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_INFO("Actively closing SSE connection (normal exit flow)...");
|
||||
|
||||
sse_running_ = false;
|
||||
|
||||
// 给一些时间让回调函数返回false并关闭连接
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
|
||||
// 等待SSE线程结束
|
||||
if (sse_thread_ && sse_thread_->joinable()) {
|
||||
// 设置一个合理的超时时间,例如5秒
|
||||
auto timeout = std::chrono::seconds(5);
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
|
||||
// 尝试在超时前等待线程结束
|
||||
LOG_INFO("Waiting for SSE thread to end...");
|
||||
|
||||
while (sse_thread_->joinable() &&
|
||||
std::chrono::steady_clock::now() - start < timeout) {
|
||||
try {
|
||||
// 尝试立即加入线程
|
||||
sse_thread_->join();
|
||||
break; // 如果成功加入,跳出循环
|
||||
LOG_INFO("SSE thread successfully ended");
|
||||
break;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "等待SSE线程时出错: " << e.what() << std::endl;
|
||||
// 短暂休眠,避免CPU占用过高
|
||||
LOG_ERROR("Error waiting for SSE thread: ", e.what());
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
}
|
||||
|
||||
// 如果线程仍然没有结束,记录警告并分离线程
|
||||
if (sse_thread_->joinable()) {
|
||||
std::cerr << "警告: SSE线程未能在超时时间内结束,分离线程" << std::endl;
|
||||
LOG_WARNING("SSE thread did not end within timeout, detaching thread");
|
||||
sse_thread_->detach();
|
||||
}
|
||||
}
|
||||
|
||||
// 清空消息端点
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
msg_endpoint_.clear();
|
||||
|
||||
// 通知等待的线程(虽然消息端点为空,但可以让等待的线程检查sse_running_状态)
|
||||
endpoint_cv_.notify_all();
|
||||
}
|
||||
|
||||
std::cerr << "SSE连接已关闭" << std::endl;
|
||||
LOG_INFO("SSE connection successfully closed (normal exit flow)");
|
||||
}
|
||||
|
||||
json client::send_jsonrpc(const request& req) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
// 检查消息端点是否已设置
|
||||
if (msg_endpoint_.empty()) {
|
||||
throw mcp_exception(error_code::internal_error, "消息端点未设置,SSE连接可能未建立");
|
||||
throw mcp_exception(error_code::internal_error, "Message endpoint not set, SSE connection may not be established");
|
||||
}
|
||||
|
||||
// 打印请求信息(调试用)
|
||||
std::cerr << "发送JSON-RPC请求: 方法=" << req.method << ", 端点=" << msg_endpoint_ << std::endl;
|
||||
|
||||
// Convert request to JSON
|
||||
json req_json = req.to_json();
|
||||
std::string req_body = req_json.dump();
|
||||
|
||||
// Prepare headers
|
||||
httplib::Headers headers;
|
||||
headers.emplace("Content-Type", "application/json");
|
||||
|
||||
// Add default headers
|
||||
for (const auto& [key, value] : default_headers_) {
|
||||
headers.emplace(key, value);
|
||||
}
|
||||
|
||||
// Send the request
|
||||
auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
|
||||
if (req.is_notification()) {
|
||||
auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
|
||||
|
||||
if (!result) {
|
||||
// Error occurred
|
||||
auto err = result.error();
|
||||
std::string error_msg;
|
||||
|
||||
switch (err) {
|
||||
case httplib::Error::Connection:
|
||||
error_msg = "连接错误,服务器可能未运行或无法访问";
|
||||
break;
|
||||
case httplib::Error::Read:
|
||||
error_msg = "读取错误,服务器可能关闭了连接或响应格式不正确";
|
||||
break;
|
||||
case httplib::Error::Write:
|
||||
error_msg = "写入错误";
|
||||
break;
|
||||
case httplib::Error::ConnectionTimeout:
|
||||
error_msg = "连接超时";
|
||||
break;
|
||||
default:
|
||||
error_msg = "HTTP客户端错误: " + std::to_string(static_cast<int>(err));
|
||||
break;
|
||||
if (!result) {
|
||||
auto err = result.error();
|
||||
std::string error_msg = httplib::to_string(err);
|
||||
LOG_ERROR("JSON-RPC request failed: ", error_msg);
|
||||
throw mcp_exception(error_code::internal_error, error_msg);
|
||||
}
|
||||
|
||||
std::cerr << "JSON-RPC请求失败: " << error_msg << std::endl;
|
||||
throw mcp_exception(error_code::internal_error, error_msg);
|
||||
}
|
||||
|
||||
// Check if it's a notification (no response expected)
|
||||
if (req.is_notification()) {
|
||||
return json::object();
|
||||
}
|
||||
|
||||
// Parse response
|
||||
try {
|
||||
json res_json = json::parse(result->body);
|
||||
std::promise<json> response_promise;
|
||||
std::future<json> response_future = response_promise.get_future();
|
||||
|
||||
// 打印响应信息(调试用)
|
||||
std::cerr << "收到JSON-RPC响应: " << res_json.dump() << std::endl;
|
||||
{
|
||||
std::lock_guard<std::mutex> response_lock(response_mutex_);
|
||||
pending_requests_[req.id] = std::move(response_promise);
|
||||
}
|
||||
|
||||
// Check for error
|
||||
if (res_json.contains("error")) {
|
||||
int code = res_json["error"]["code"];
|
||||
std::string message = res_json["error"]["message"];
|
||||
auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
|
||||
|
||||
throw mcp_exception(static_cast<error_code>(code), message);
|
||||
if (!result) {
|
||||
auto err = result.error();
|
||||
std::string error_msg = httplib::to_string(err);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> response_lock(response_mutex_);
|
||||
pending_requests_.erase(req.id);
|
||||
}
|
||||
|
||||
// Return result
|
||||
if (res_json.contains("result")) {
|
||||
return res_json["result"];
|
||||
LOG_ERROR("JSON-RPC request failed: ", error_msg);
|
||||
throw mcp_exception(error_code::internal_error, error_msg);
|
||||
}
|
||||
|
||||
if (result->status != 202) {
|
||||
try {
|
||||
json res_json = json::parse(result->body);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> response_lock(response_mutex_);
|
||||
pending_requests_.erase(req.id);
|
||||
}
|
||||
|
||||
if (res_json.contains("error")) {
|
||||
int code = res_json["error"]["code"];
|
||||
std::string message = res_json["error"]["message"];
|
||||
|
||||
throw mcp_exception(static_cast<error_code>(code), message);
|
||||
}
|
||||
|
||||
if (res_json.contains("result")) {
|
||||
return res_json["result"];
|
||||
} else {
|
||||
return json::object();
|
||||
}
|
||||
} catch (const json::exception& e) {
|
||||
{
|
||||
std::lock_guard<std::mutex> response_lock(response_mutex_);
|
||||
pending_requests_.erase(req.id);
|
||||
}
|
||||
|
||||
throw mcp_exception(error_code::parse_error,
|
||||
"Failed to parse JSON-RPC response: " + std::string(e.what()));
|
||||
}
|
||||
} else {
|
||||
const auto timeout = std::chrono::seconds(timeout_seconds_);
|
||||
|
||||
auto status = response_future.wait_for(timeout);
|
||||
|
||||
if (status == std::future_status::ready) {
|
||||
json response = response_future.get();
|
||||
|
||||
if (response.contains("isError") && response["isError"].get<bool>()) {
|
||||
int code = response["error"]["code"];
|
||||
std::string message = response["error"]["message"];
|
||||
|
||||
throw mcp_exception(static_cast<error_code>(code), message);
|
||||
}
|
||||
|
||||
return response;
|
||||
} else {
|
||||
return json::object();
|
||||
{
|
||||
std::lock_guard<std::mutex> response_lock(response_mutex_);
|
||||
pending_requests_.erase(req.id);
|
||||
}
|
||||
|
||||
throw mcp_exception(error_code::internal_error, "Timeout waiting for SSE response");
|
||||
}
|
||||
} catch (const json::exception& e) {
|
||||
throw mcp_exception(error_code::parse_error,
|
||||
"Failed to parse JSON-RPC response: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
bool client::check_server_accessible() {
|
||||
std::cerr << "检查服务器是否可访问..." << std::endl;
|
||||
LOG_INFO("Checking if server is accessible...");
|
||||
|
||||
try {
|
||||
// 尝试发送一个简单的GET请求到服务器
|
||||
auto res = http_client_->Get("/");
|
||||
|
||||
if (res) {
|
||||
std::cerr << "服务器可访问,状态码: " << res->status << std::endl;
|
||||
LOG_INFO("Server is accessible, status code: ", res->status);
|
||||
return true;
|
||||
} else {
|
||||
std::string error_msg = "服务器不可访问,错误代码: " + std::to_string(static_cast<int>(res.error()));
|
||||
|
||||
// 添加更详细的错误信息
|
||||
switch (res.error()) {
|
||||
case httplib::Error::Connection:
|
||||
error_msg += " (连接错误,服务器可能未运行或无法访问)";
|
||||
break;
|
||||
case httplib::Error::Read:
|
||||
error_msg += " (读取错误)";
|
||||
break;
|
||||
case httplib::Error::Write:
|
||||
error_msg += " (写入错误)";
|
||||
break;
|
||||
case httplib::Error::ConnectionTimeout:
|
||||
error_msg += " (连接超时)";
|
||||
break;
|
||||
default:
|
||||
error_msg += " (未知错误)";
|
||||
break;
|
||||
}
|
||||
|
||||
std::cerr << error_msg << std::endl;
|
||||
std::string error_msg = "Server not accessible: " + httplib::to_string(res.error());
|
||||
LOG_ERROR(error_msg);
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "检查服务器可访问性时发生异常: " << e.what() << std::endl;
|
||||
LOG_ERROR("Exception while checking server accessibility: ", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,14 +14,6 @@ server::server(const std::string& host, int port, const std::string& sse_endpoin
|
|||
: host_(host), port_(port), sse_endpoint_(sse_endpoint), msg_endpoint_(msg_endpoint_prefix), name_("MCP Server"), version_(MCP_VERSION) {
|
||||
|
||||
http_server_ = std::make_unique<httplib::Server>();
|
||||
|
||||
// Set default capabilities
|
||||
capabilities_ = {
|
||||
{"resources", {
|
||||
{"subscribe", true},
|
||||
{"listChanged", true}
|
||||
}}
|
||||
};
|
||||
}
|
||||
|
||||
server::~server() {
|
||||
|
@ -33,7 +25,7 @@ bool server::start(bool blocking) {
|
|||
return true; // Already running
|
||||
}
|
||||
|
||||
std::cerr << "启动MCP服务器: " << host_ << ":" << port_ << std::endl;
|
||||
LOG_INFO("Starting MCP server on ", host_, ":", port_);
|
||||
|
||||
// 设置CORS处理
|
||||
http_server_->Options(".*", [](const httplib::Request& req, httplib::Response& res) {
|
||||
|
@ -46,36 +38,36 @@ bool server::start(bool blocking) {
|
|||
// 设置JSON-RPC端点
|
||||
http_server_->Post(msg_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) {
|
||||
this->handle_jsonrpc(req, res);
|
||||
LOG_INFO(req.remote_addr, ":", req.remote_port, " - \"POST ", req.path, " HTTP/1.1\" ", res.status);
|
||||
});
|
||||
|
||||
// 设置SSE端点
|
||||
http_server_->Get(sse_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) {
|
||||
this->handle_sse(req, res);
|
||||
LOG_INFO(req.remote_addr, ":", req.remote_port, " - \"GET ", req.path, " HTTP/1.1\" ", res.status);
|
||||
});
|
||||
|
||||
// 启动服务器
|
||||
if (blocking) {
|
||||
running_ = true;
|
||||
std::cerr << "以阻塞模式启动服务器" << std::endl;
|
||||
print_status();
|
||||
LOG_INFO("Starting server in blocking mode");
|
||||
if (!http_server_->listen(host_.c_str(), port_)) {
|
||||
running_ = false;
|
||||
std::cerr << "无法在 " << host_ << ":" << port_ << " 上启动服务器" << std::endl;
|
||||
LOG_ERROR("Failed to start server on ", host_, ":", port_);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
// 在单独的线程中启动
|
||||
server_thread_ = std::make_unique<std::thread>([this]() {
|
||||
std::cerr << "在单独的线程中启动服务器" << std::endl;
|
||||
LOG_INFO("Starting server in separate thread");
|
||||
if (!http_server_->listen(host_.c_str(), port_)) {
|
||||
std::cerr << "无法在 " << host_ << ":" << port_ << " 上启动服务器" << std::endl;
|
||||
LOG_ERROR("Failed to start server on ", host_, ":", port_);
|
||||
running_ = false;
|
||||
return;
|
||||
}
|
||||
});
|
||||
running_ = true;
|
||||
print_status();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -85,15 +77,13 @@ void server::stop() {
|
|||
return;
|
||||
}
|
||||
|
||||
std::cerr << "正在停止MCP服务器..." << std::endl;
|
||||
print_status();
|
||||
LOG_INFO("Stopping MCP server...");
|
||||
running_ = false;
|
||||
|
||||
// 关闭所有SSE连接
|
||||
std::vector<std::string> session_ids;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
// 先收集所有会话ID
|
||||
for (const auto& [session_id, _] : session_dispatchers_) {
|
||||
session_ids.push_back(session_id);
|
||||
}
|
||||
|
@ -102,31 +92,25 @@ void server::stop() {
|
|||
// 关闭每个会话的分发器
|
||||
for (const auto& session_id : session_ids) {
|
||||
try {
|
||||
std::cerr << "关闭会话: " << session_id << std::endl;
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = session_dispatchers_.find(session_id);
|
||||
if (it != session_dispatchers_.end()) {
|
||||
it->second->close();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "关闭会话时发生异常: " << session_id << ", " << e.what() << std::endl;
|
||||
LOG_ERROR("Exception while closing session ", session_id, ": ", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
// 等待一段时间,让会话线程有机会自行清理
|
||||
std::cerr << "等待会话线程自行清理..." << std::endl;
|
||||
std::this_thread::sleep_for(std::chrono::seconds(1));
|
||||
|
||||
// 清理剩余的线程
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (auto& [session_id, thread] : sse_threads_) {
|
||||
if (thread && thread->joinable()) {
|
||||
try {
|
||||
std::cerr << "分离会话线程: " << session_id << std::endl;
|
||||
thread->detach();
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "分离会话线程时发生异常: " << session_id << ", " << e.what() << std::endl;
|
||||
LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -137,16 +121,14 @@ void server::stop() {
|
|||
}
|
||||
|
||||
if (http_server_) {
|
||||
std::cerr << "停止HTTP服务器..." << std::endl;
|
||||
http_server_->stop();
|
||||
}
|
||||
|
||||
if (server_thread_ && server_thread_->joinable()) {
|
||||
std::cerr << "等待服务器线程结束..." << std::endl;
|
||||
server_thread_->join();
|
||||
}
|
||||
|
||||
std::cerr << "MCP服务器已停止" << std::endl;
|
||||
LOG_INFO("MCP server stopped");
|
||||
}
|
||||
|
||||
bool server::is_running() const {
|
||||
|
@ -212,10 +194,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
|
|||
{"resources", resources}
|
||||
};
|
||||
|
||||
// Handle pagination if cursor is provided
|
||||
if (params.contains("cursor")) {
|
||||
// In this implementation, we don't actually paginate
|
||||
// but we include the nextCursor field for compatibility
|
||||
result["nextCursor"] = "";
|
||||
}
|
||||
|
||||
|
@ -235,15 +214,12 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
|
|||
throw mcp_exception(error_code::invalid_params, "Resource not found: " + uri);
|
||||
}
|
||||
|
||||
// In a real implementation, we would register a subscription here
|
||||
// For now, just return success
|
||||
return json::object();
|
||||
};
|
||||
}
|
||||
|
||||
if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) {
|
||||
method_handlers_["resources/templates/list"] = [this](const json& params) -> json {
|
||||
// In this implementation, we don't support resource templates
|
||||
return json::array();
|
||||
};
|
||||
}
|
||||
|
@ -260,7 +236,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
|
|||
for (const auto& [name, tool_pair] : tools_) {
|
||||
tools_json.push_back(tool_pair.first.to_json());
|
||||
}
|
||||
return tools_json;
|
||||
return json{{"tools", tools_json}};
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -310,18 +286,9 @@ void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
|
|||
}
|
||||
|
||||
void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
||||
// 生成会话ID
|
||||
std::string session_id = generate_session_id();
|
||||
std::string session_uri = msg_endpoint_ + "?session_id=" + session_id;
|
||||
|
||||
std::cerr << "新的SSE连接: 客户端=" << req.remote_addr << ", 会话ID=" << session_id << std::endl;
|
||||
std::cerr << "会话URI: " << session_uri << std::endl;
|
||||
std::cerr << "请求头: ";
|
||||
for (const auto& [key, value] : req.headers) {
|
||||
std::cerr << key << "=" << value << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
|
||||
// 设置SSE响应头
|
||||
res.set_header("Content-Type", "text/event-stream");
|
||||
res.set_header("Cache-Control", "no-cache");
|
||||
|
@ -337,62 +304,46 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
|||
session_dispatchers_[session_id] = session_dispatcher;
|
||||
}
|
||||
|
||||
// 创建会话线程,使用值捕获而不是引用捕获
|
||||
auto thread = std::make_unique<std::thread>([this, session_id, session_uri, session_dispatcher]() {
|
||||
// 创建会话线程
|
||||
auto thread = std::make_unique<std::thread>([this, res, session_id, session_uri, session_dispatcher]() {
|
||||
try {
|
||||
std::cerr << "SSE会话线程启动: " << session_id << std::endl;
|
||||
|
||||
// 发送初始会话URI - 使用endpoint事件类型,符合MCP规范
|
||||
// 发送初始会话URI
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
std::stringstream ss;
|
||||
ss << "event: endpoint\ndata: " << session_uri << "\n\n";
|
||||
session_dispatcher->send_event(ss.str());
|
||||
std::cerr << "发送会话URI: " << session_uri << " 到会话: " << session_id << std::endl;
|
||||
|
||||
// 定期发送心跳,检测连接状态
|
||||
int heartbeat_count = 0;
|
||||
while (running_ && !session_dispatcher->is_closed()) {
|
||||
std::this_thread::sleep_for(std::chrono::seconds(10));
|
||||
|
||||
// 检查分发器是否已关闭
|
||||
if (session_dispatcher->is_closed()) {
|
||||
std::cerr << "会话已关闭,停止心跳: " << session_id << std::endl;
|
||||
if (session_dispatcher->is_closed() || !running_) {
|
||||
break;
|
||||
}
|
||||
|
||||
// 检查服务器是否仍在运行
|
||||
if (!running_) {
|
||||
std::cerr << "服务器已停止,停止心跳: " << session_id << std::endl;
|
||||
break;
|
||||
}
|
||||
|
||||
// 发送心跳事件 - 使用自定义heartbeat事件类型
|
||||
std::stringstream heartbeat;
|
||||
heartbeat << "event: heartbeat\ndata: " << heartbeat_count++ << "\n\n";
|
||||
|
||||
try {
|
||||
bool sent = session_dispatcher->send_event(heartbeat.str());
|
||||
if (!sent) {
|
||||
std::cerr << "发送心跳失败,客户端可能已关闭连接: " << session_id << std::endl;
|
||||
LOG_WARNING("Failed to send heartbeat, client may have closed connection: ", session_id);
|
||||
break;
|
||||
}
|
||||
std::cerr << "发送心跳到会话: " << session_id << ", 计数: " << heartbeat_count << std::endl;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "发送心跳失败,假定连接已关闭: " << e.what() << std::endl;
|
||||
LOG_ERROR("Failed to send heartbeat: ", e.what());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "SSE会话线程退出: " << session_id << std::endl;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "SSE会话线程异常: " << session_id << ", " << e.what() << std::endl;
|
||||
LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what());
|
||||
}
|
||||
|
||||
// 安全地清理资源
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
// 先关闭分发器
|
||||
auto dispatcher_it = session_dispatchers_.find(session_id);
|
||||
if (dispatcher_it != session_dispatchers_.end()) {
|
||||
if (!dispatcher_it->second->is_closed()) {
|
||||
|
@ -401,18 +352,13 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
|||
session_dispatchers_.erase(dispatcher_it);
|
||||
}
|
||||
|
||||
// 再移除线程
|
||||
auto thread_it = sse_threads_.find(session_id);
|
||||
if (thread_it != sse_threads_.end()) {
|
||||
// 不要在线程内部join或detach自己
|
||||
// 只从映射中移除
|
||||
thread_it->second.release(); // 释放所有权但不删除线程对象
|
||||
thread_it->second.release();
|
||||
sse_threads_.erase(thread_it);
|
||||
}
|
||||
|
||||
std::cerr << "会话资源已清理: " << session_id << std::endl;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "清理会话资源时发生异常: " << session_id << ", " << e.what() << std::endl;
|
||||
LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -430,9 +376,6 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
|||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = session_dispatchers_.find(session_id);
|
||||
if (it == session_dispatchers_.end() || it->second->is_closed()) {
|
||||
std::cerr << "会话已关闭,停止内容提供: " << session_id << std::endl;
|
||||
|
||||
// 不在这里清理资源,让会话线程自己清理
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -440,7 +383,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
|||
// 等待事件
|
||||
bool result = session_dispatcher->wait_event(&sink);
|
||||
if (!result) {
|
||||
std::cerr << "等待事件失败,关闭连接: " << session_id << std::endl;
|
||||
LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
|
||||
|
||||
// 关闭会话分发器,但不清理资源
|
||||
{
|
||||
|
@ -454,10 +397,9 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
|||
return false;
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "SSE内容提供者异常: " << e.what() << std::endl;
|
||||
LOG_ERROR("SSE content provider exception: ", e.what());
|
||||
|
||||
// 关闭会话分发器,但不清理资源
|
||||
try {
|
||||
|
@ -467,7 +409,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
|
|||
it->second->close();
|
||||
}
|
||||
} catch (const std::exception& e2) {
|
||||
std::cerr << "关闭会话分发器时发生异常: " << e2.what() << std::endl;
|
||||
LOG_ERROR("Exception while closing session dispatcher: ", e2.what());
|
||||
}
|
||||
|
||||
return false;
|
||||
|
@ -492,143 +434,110 @@ void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res)
|
|||
auto it = req.params.find("session_id");
|
||||
std::string session_id = it != req.params.end() ? it->second : "";
|
||||
|
||||
std::cerr << "收到JSON-RPC请求: 会话ID=" << session_id << ", 路径=" << req.path << std::endl;
|
||||
std::cerr << "请求参数: ";
|
||||
for (const auto& [key, value] : req.params) {
|
||||
std::cerr << key << "=" << value << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
|
||||
// 检查会话是否存在
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (!session_id.empty()) {
|
||||
std::cerr << "检查会话是否存在: " << session_id << std::endl;
|
||||
std::cerr << "当前活跃会话: ";
|
||||
for (const auto& [id, _] : session_dispatchers_) {
|
||||
std::cerr << id << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
|
||||
if (session_dispatchers_.find(session_id) == session_dispatchers_.end()) {
|
||||
std::cerr << "会话不存在: " << session_id << std::endl;
|
||||
json error_response = {
|
||||
{"jsonrpc", "2.0"},
|
||||
{"error", {
|
||||
{"code", static_cast<int>(error_code::invalid_request)},
|
||||
{"message", "Session not found"}
|
||||
}},
|
||||
{"id", nullptr}
|
||||
};
|
||||
res.set_content(error_response.dump(), "application/json");
|
||||
return;
|
||||
} else {
|
||||
std::cerr << "会话存在: " << session_id << std::endl;
|
||||
}
|
||||
} else {
|
||||
std::cerr << "请求中没有会话ID" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
json req_json;
|
||||
try {
|
||||
req_json = json::parse(req.body);
|
||||
std::cerr << "请求内容: " << req_json.dump() << std::endl;
|
||||
} catch (const json::exception& e) {
|
||||
// 无效的JSON
|
||||
std::cerr << "解析JSON失败: " << e.what() << std::endl;
|
||||
json error_response = {
|
||||
{"jsonrpc", "2.0"},
|
||||
{"error", {
|
||||
{"code", static_cast<int>(error_code::parse_error)},
|
||||
{"message", "Parse error: " + std::string(e.what())}
|
||||
}},
|
||||
{"id", nullptr}
|
||||
};
|
||||
res.set_content(error_response.dump(), "application/json");
|
||||
LOG_ERROR("Failed to parse JSON request: ", e.what());
|
||||
res.status = 400;
|
||||
res.set_content("{\"error\":\"Invalid JSON\"}", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
// 检查是否是批量请求
|
||||
if (req_json.is_array()) {
|
||||
// 批量请求暂不支持
|
||||
std::cerr << "不支持批量请求" << std::endl;
|
||||
json error_response = {
|
||||
{"jsonrpc", "2.0"},
|
||||
{"error", {
|
||||
{"code", static_cast<int>(error_code::invalid_request)},
|
||||
{"message", "Batch requests are not supported"}
|
||||
}},
|
||||
{"id", nullptr}
|
||||
};
|
||||
res.set_content(error_response.dump(), "application/json");
|
||||
return;
|
||||
// 检查会话是否存在
|
||||
std::shared_ptr<event_dispatcher> dispatcher;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto disp_it = session_dispatchers_.find(session_id);
|
||||
if (disp_it == session_dispatchers_.end()) {
|
||||
// 处理ping请求
|
||||
if (req_json["method"] == "ping") {
|
||||
res.status = 202;
|
||||
res.set_content("Accepted", "text/plain");
|
||||
return;
|
||||
}
|
||||
LOG_ERROR("Session not found: ", session_id);
|
||||
res.status = 404;
|
||||
res.set_content("{\"error\":\"Session not found\"}", "application/json");
|
||||
return;
|
||||
}
|
||||
dispatcher = disp_it->second;
|
||||
}
|
||||
|
||||
// 转换为请求对象
|
||||
// 创建请求对象
|
||||
request mcp_req;
|
||||
try {
|
||||
mcp_req.jsonrpc = req_json["jsonrpc"];
|
||||
mcp_req.method = req_json["method"];
|
||||
|
||||
if (req_json.contains("id")) {
|
||||
mcp_req.jsonrpc = req_json["jsonrpc"].get<std::string>();
|
||||
if (req_json.contains("id") && !req_json["id"].is_null()) {
|
||||
mcp_req.id = req_json["id"];
|
||||
}
|
||||
|
||||
mcp_req.method = req_json["method"].get<std::string>();
|
||||
if (req_json.contains("params")) {
|
||||
mcp_req.params = req_json["params"];
|
||||
}
|
||||
} catch (const json::exception& e) {
|
||||
// 无效的请求
|
||||
std::cerr << "无效的请求: " << e.what() << std::endl;
|
||||
json error_response = {
|
||||
{"jsonrpc", "2.0"},
|
||||
{"error", {
|
||||
{"code", static_cast<int>(error_code::invalid_request)},
|
||||
{"message", "Invalid request: " + std::string(e.what())}
|
||||
}},
|
||||
{"id", nullptr}
|
||||
};
|
||||
res.set_content(error_response.dump(), "application/json");
|
||||
} catch (const std::exception& e) {
|
||||
LOG_ERROR("Failed to create request object: ", e.what());
|
||||
res.status = 400;
|
||||
res.set_content("{\"error\":\"Invalid request format\"}", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
// 处理请求
|
||||
std::cerr << "处理方法: " << mcp_req.method << std::endl;
|
||||
json result = process_request(mcp_req, session_id);
|
||||
std::cerr << "响应: " << result.dump() << std::endl;
|
||||
res.set_content(result.dump(), "application/json");
|
||||
// 如果是通知(没有ID),直接处理并返回202状态码
|
||||
if (mcp_req.is_notification()) {
|
||||
// 在线程池中异步处理通知
|
||||
thread_pool_.enqueue([this, mcp_req, session_id]() {
|
||||
process_request(mcp_req, session_id);
|
||||
});
|
||||
|
||||
// 返回202 Accepted
|
||||
res.status = 202;
|
||||
res.set_content("Accepted", "text/plain");
|
||||
return;
|
||||
}
|
||||
|
||||
// 对于有ID的请求,在线程池中处理并通过SSE返回结果
|
||||
thread_pool_.enqueue([this, mcp_req, session_id, dispatcher]() {
|
||||
// 处理请求
|
||||
json response_json = process_request(mcp_req, session_id);
|
||||
|
||||
// 通过SSE发送响应
|
||||
std::stringstream ss;
|
||||
ss << "event: message\ndata: " << response_json.dump() << "\n\n";
|
||||
bool result = dispatcher->send_event(ss.str());
|
||||
|
||||
if (!result) {
|
||||
LOG_ERROR("Failed to send response via SSE: session_id=", session_id);
|
||||
}
|
||||
});
|
||||
|
||||
// 返回202 Accepted
|
||||
res.status = 202;
|
||||
res.set_content("Accepted", "text/plain");
|
||||
}
|
||||
|
||||
json server::process_request(const request& req, const std::string& session_id) {
|
||||
// 检查是否是通知
|
||||
if (req.is_notification()) {
|
||||
std::cerr << "处理通知: " << req.method << std::endl;
|
||||
// 通知没有响应
|
||||
if (req.method == "notifications/initialized") {
|
||||
std::cerr << "收到客户端initialized通知,会话: " << session_id << std::endl;
|
||||
set_session_initialized(session_id, true);
|
||||
std::cerr << "会话已设置为初始化状态: " << session_id << std::endl;
|
||||
}
|
||||
return json::object();
|
||||
}
|
||||
|
||||
// 处理方法调用
|
||||
try {
|
||||
LOG_INFO("处理方法调用: ", req.method);
|
||||
LOG_INFO("Processing method call: ", req.method);
|
||||
|
||||
// 特殊情况:初始化
|
||||
if (req.method == "initialize") {
|
||||
return handle_initialize(req, session_id);
|
||||
} else if (req.method == "ping") {
|
||||
// 接收者必须立即响应一个空响应
|
||||
LOG_INFO("处理ping请求");
|
||||
return response::create_success(req.id, {}).to_json();
|
||||
return response::create_success(req.id, json::object()).to_json();
|
||||
}
|
||||
|
||||
if (!is_session_initialized(session_id)) {
|
||||
LOG_WARNING("会话未初始化: ", session_id);
|
||||
LOG_WARNING("Session not initialized: ", session_id);
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::invalid_request,
|
||||
|
@ -648,19 +557,19 @@ json server::process_request(const request& req, const std::string& session_id)
|
|||
|
||||
if (handler) {
|
||||
// 调用处理器
|
||||
LOG_INFO("调用方法处理器: ", req.method);
|
||||
LOG_INFO("Calling method handler: ", req.method);
|
||||
auto future = thread_pool_.enqueue([handler, params = req.params]() -> json {
|
||||
return handler(params);
|
||||
});
|
||||
json result = future.get();
|
||||
|
||||
// 创建成功响应
|
||||
LOG_INFO("方法调用成功: ", req.method);
|
||||
LOG_INFO("Method call successful: ", req.method);
|
||||
return response::create_success(req.id, result).to_json();
|
||||
}
|
||||
|
||||
// 方法未找到
|
||||
LOG_WARNING("方法未找到: ", req.method);
|
||||
LOG_WARNING("Method not found: ", req.method);
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::method_not_found,
|
||||
|
@ -668,7 +577,7 @@ json server::process_request(const request& req, const std::string& session_id)
|
|||
).to_json();
|
||||
} catch (const mcp_exception& e) {
|
||||
// MCP异常
|
||||
LOG_ERROR("MCP异常: ", e.what(), ", 代码: ", static_cast<int>(e.code()));
|
||||
LOG_ERROR("MCP exception: ", e.what(), ", code: ", static_cast<int>(e.code()));
|
||||
return response::create_error(
|
||||
req.id,
|
||||
e.code(),
|
||||
|
@ -676,7 +585,7 @@ json server::process_request(const request& req, const std::string& session_id)
|
|||
).to_json();
|
||||
} catch (const std::exception& e) {
|
||||
// 其他异常
|
||||
LOG_ERROR("处理请求时发生异常: ", e.what());
|
||||
LOG_ERROR("Exception while processing request: ", e.what());
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::internal_error,
|
||||
|
@ -684,7 +593,7 @@ json server::process_request(const request& req, const std::string& session_id)
|
|||
).to_json();
|
||||
} catch (...) {
|
||||
// 未知异常
|
||||
LOG_ERROR("处理请求时发生未知异常");
|
||||
LOG_ERROR("Unknown exception while processing request");
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::internal_error,
|
||||
|
@ -696,11 +605,9 @@ json server::process_request(const request& req, const std::string& session_id)
|
|||
json server::handle_initialize(const request& req, const std::string& session_id) {
|
||||
const json& params = req.params;
|
||||
|
||||
std::cerr << "处理initialize请求,会话ID: " << session_id << std::endl;
|
||||
|
||||
// Version negotiation
|
||||
if (!params.contains("protocolVersion") || !params["protocolVersion"].is_string()) {
|
||||
std::cerr << "缺少protocolVersion参数或格式不正确" << std::endl;
|
||||
LOG_ERROR("Missing or invalid protocolVersion parameter");
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::invalid_params,
|
||||
|
@ -709,10 +616,10 @@ json server::handle_initialize(const request& req, const std::string& session_id
|
|||
}
|
||||
|
||||
std::string requested_version = params["protocolVersion"].get<std::string>();
|
||||
std::cerr << "客户端请求的协议版本: " << requested_version << std::endl;
|
||||
LOG_INFO("Client requested protocol version: ", requested_version);
|
||||
|
||||
if (requested_version != MCP_VERSION) {
|
||||
std::cerr << "不支持的协议版本: " << requested_version << ", 服务器支持: " << MCP_VERSION << std::endl;
|
||||
LOG_ERROR("Unsupported protocol version: ", requested_version, ", server supports: ", MCP_VERSION);
|
||||
return response::create_error(
|
||||
req.id,
|
||||
error_code::invalid_params,
|
||||
|
@ -738,7 +645,7 @@ json server::handle_initialize(const request& req, const std::string& session_id
|
|||
}
|
||||
|
||||
// Log connection
|
||||
std::cerr << "客户端连接: " << client_name << " " << client_version << std::endl;
|
||||
LOG_INFO("Client connected: ", client_name, " ", client_version);
|
||||
|
||||
// Return server info and capabilities
|
||||
json server_info = {
|
||||
|
@ -752,7 +659,7 @@ json server::handle_initialize(const request& req, const std::string& session_id
|
|||
{"serverInfo", server_info}
|
||||
};
|
||||
|
||||
std::cerr << "初始化成功,等待客户端发送notifications/initialized通知" << std::endl;
|
||||
LOG_INFO("Initialization successful, waiting for notifications/initialized notification");
|
||||
|
||||
return response::create_success(req.id, result).to_json();
|
||||
}
|
||||
|
@ -763,9 +670,7 @@ void server::send_request(const std::string& session_id, const std::string& meth
|
|||
|
||||
// Check if client is initialized or if this is an allowed method
|
||||
if (!is_allowed_before_init && !is_session_initialized(session_id)) {
|
||||
// Client not initialized and method is not allowed before initialization
|
||||
std::cerr << "Cannot send " << method << " request to session " << session_id
|
||||
<< " before it is initialized" << std::endl;
|
||||
LOG_WARNING("Cannot send ", method, " request to session ", session_id, " before it is initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -778,38 +683,33 @@ void server::send_request(const std::string& session_id, const std::string& meth
|
|||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = session_dispatchers_.find(session_id);
|
||||
if (it == session_dispatchers_.end()) {
|
||||
std::cerr << "会话不存在: " << session_id << std::endl;
|
||||
LOG_ERROR("Session not found: ", session_id);
|
||||
return;
|
||||
}
|
||||
dispatcher = it->second;
|
||||
}
|
||||
|
||||
// 发送请求 - 使用message事件类型,符合MCP规范
|
||||
// 发送请求
|
||||
std::stringstream ss;
|
||||
ss << "event: message\ndata: " << req.to_json().dump() << "\n\n";
|
||||
bool result = dispatcher->send_event(ss.str());
|
||||
|
||||
if (!result) {
|
||||
std::cerr << "向会话发送请求失败: " << session_id << std::endl;
|
||||
} else {
|
||||
std::cerr << "成功向会话 " << session_id << " 发送请求: " << method << std::endl;
|
||||
LOG_ERROR("Failed to send request to session: ", session_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if a session is initialized
|
||||
bool server::is_session_initialized(const std::string& session_id) const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = session_initialized_.find(session_id);
|
||||
return (it != session_initialized_.end() && it->second);
|
||||
}
|
||||
|
||||
// Set session initialization status
|
||||
void server::set_session_initialized(const std::string& session_id, bool initialized) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
session_initialized_[session_id] = initialized;
|
||||
}
|
||||
|
||||
// Generate a random session ID in UUID format
|
||||
std::string server::generate_session_id() const {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
@ -846,48 +746,4 @@ std::string server::generate_session_id() const {
|
|||
return ss.str();
|
||||
}
|
||||
|
||||
void server::print_status() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
std::cerr << "=== MCP服务器状态 ===" << std::endl;
|
||||
std::cerr << "服务器地址: " << host_ << ":" << port_ << std::endl;
|
||||
std::cerr << "运行状态: " << (running_ ? "运行中" : "已停止") << std::endl;
|
||||
std::cerr << "SSE端点: " << sse_endpoint_ << std::endl;
|
||||
std::cerr << "消息端点前缀: " << msg_endpoint_ << std::endl;
|
||||
|
||||
std::cerr << "活跃会话数: " << session_dispatchers_.size() << std::endl;
|
||||
for (const auto& [session_id, dispatcher] : session_dispatchers_) {
|
||||
std::cerr << " - 会话ID: " << session_id << ", 状态: " << (dispatcher->is_closed() ? "已关闭" : "活跃") << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << "SSE线程数: " << sse_threads_.size() << std::endl;
|
||||
|
||||
std::cerr << "注册的方法数: " << method_handlers_.size() << std::endl;
|
||||
for (const auto& [method, _] : method_handlers_) {
|
||||
std::cerr << " - " << method << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << "注册的通知数: " << notification_handlers_.size() << std::endl;
|
||||
for (const auto& [method, _] : notification_handlers_) {
|
||||
std::cerr << " - " << method << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << "注册的资源数: " << resources_.size() << std::endl;
|
||||
for (const auto& [path, _] : resources_) {
|
||||
std::cerr << " - " << path << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << "注册的工具数: " << tools_.size() << std::endl;
|
||||
for (const auto& [name, _] : tools_) {
|
||||
std::cerr << " - " << name << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << "已初始化的会话数: " << session_initialized_.size() << std::endl;
|
||||
for (const auto& [session_id, initialized] : session_initialized_) {
|
||||
std::cerr << " - " << session_id << ": " << (initialized ? "已初始化" : "未初始化") << std::endl;
|
||||
}
|
||||
|
||||
std::cerr << "======================" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace mcp
|
|
@ -150,7 +150,7 @@ tool tool_builder::build() const {
|
|||
|
||||
// Create the parameters schema
|
||||
json schema = parameters_;
|
||||
schema["type"] = "object";
|
||||
schema["type"] = "object";;
|
||||
|
||||
if (!required_params_.empty()) {
|
||||
schema["required"] = required_params_;
|
||||
|
|
|
@ -29,6 +29,9 @@ set(TEST_SOURCES
|
|||
test_mcp_client.cpp
|
||||
test_mcp_server.cpp
|
||||
test_mcp_direct_requests.cpp
|
||||
test_mcp_lifecycle_transport.cpp
|
||||
test_mcp_versioning.cpp
|
||||
test_mcp_tools_extended.cpp
|
||||
)
|
||||
|
||||
# Create test executable
|
||||
|
@ -57,3 +60,9 @@ add_test(
|
|||
NAME ${TEST_PROJECT_NAME}
|
||||
COMMAND ${TEST_PROJECT_NAME}
|
||||
)
|
||||
|
||||
# Add custom target to run tests
|
||||
add_custom_target(run_tests
|
||||
COMMAND ${CMAKE_CTEST_COMMAND} --verbose
|
||||
DEPENDS ${TEST_PROJECT_NAME}
|
||||
)
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* @file test_client.cpp
|
||||
* @brief 测试MCP客户端
|
||||
*/
|
||||
|
||||
#include "mcp_client.h"
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
|
||||
int main() {
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8080);
|
||||
|
||||
// 设置超时
|
||||
client.set_timeout(30);
|
||||
|
||||
// 初始化客户端
|
||||
std::cout << "正在初始化客户端..." << std::endl;
|
||||
bool success = client.initialize("TestClient", "1.0.0");
|
||||
|
||||
if (!success) {
|
||||
std::cerr << "初始化失败" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::cout << "初始化成功" << std::endl;
|
||||
|
||||
// 获取服务器能力
|
||||
std::cout << "服务器能力: " << client.get_server_capabilities().dump(2) << std::endl;
|
||||
|
||||
// 获取可用工具
|
||||
std::cout << "正在获取可用工具..." << std::endl;
|
||||
auto tools = client.get_tools();
|
||||
std::cout << "可用工具数量: " << tools.size() << std::endl;
|
||||
|
||||
for (const auto& tool : tools) {
|
||||
std::cout << "工具: " << tool.name << " - " << tool.description << std::endl;
|
||||
}
|
||||
|
||||
// 发送ping请求
|
||||
std::cout << "正在发送ping请求..." << std::endl;
|
||||
bool ping_result = client.ping();
|
||||
std::cout << "Ping结果: " << (ping_result ? "成功" : "失败") << std::endl;
|
||||
|
||||
// 列出资源
|
||||
std::cout << "正在列出资源..." << std::endl;
|
||||
auto resources = client.list_resources();
|
||||
std::cout << "资源: " << resources.dump(2) << std::endl;
|
||||
|
||||
// 测试多个并发请求
|
||||
std::cout << "测试并发请求..." << std::endl;
|
||||
for (int i = 0; i < 5; i++) {
|
||||
std::cout << "请求 " << i << "..." << std::endl;
|
||||
auto response = client.send_request("ping");
|
||||
std::cout << "响应 " << i << ": " << response.result.dump() << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "测试完成" << std::endl;
|
||||
return 0;
|
||||
}
|
|
@ -96,6 +96,15 @@ protected:
|
|||
return mcp::json::object();
|
||||
}
|
||||
|
||||
// 检查状态码,202表示请求已接受,但响应将通过SSE发送
|
||||
if (res->status == 202) {
|
||||
// 在实际测试中,我们需要等待SSE响应
|
||||
// 但在这个测试中,我们只是返回一个空对象
|
||||
// 实际应用中应该使用客户端类来处理这种情况
|
||||
std::cout << "收到202 Accepted响应,实际响应将通过SSE发送" << std::endl;
|
||||
return mcp::json::object();
|
||||
}
|
||||
|
||||
EXPECT_EQ(res->status, 200);
|
||||
|
||||
try {
|
||||
|
@ -419,8 +428,8 @@ TEST_F(DirectRequestTest, SendNotification) {
|
|||
|
||||
// Verify response (notifications may have empty response or error response)
|
||||
EXPECT_TRUE(res != nullptr);
|
||||
EXPECT_EQ(res->status, 200);
|
||||
// Don't check if response body is empty, as server implementation may return an empty object
|
||||
// 状态码可能是200或202,取决于服务器实现
|
||||
EXPECT_TRUE(res->status == 200 || res->status == 202);
|
||||
}
|
||||
|
||||
// Test error handling - method not found
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
/**
|
||||
* @file test_mcp_lifecycle_transport.cpp
|
||||
* @brief 测试MCP消息生命周期和传输相关功能
|
||||
*
|
||||
* 本文件包含对MCP消息生命周期和SSE传输的单元测试,基于规范2024-11-05。
|
||||
*/
|
||||
|
||||
#include "mcp_message.h"
|
||||
#include "mcp_server.h"
|
||||
#include "mcp_client.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <future>
|
||||
#include <chrono>
|
||||
|
||||
// 测试类,用于设置服务器和客户端
|
||||
class McpLifecycleTransportTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
// 创建服务器
|
||||
server = std::make_unique<mcp::server>("localhost", 8096);
|
||||
server->set_server_info("TestServer", "2024-11-05");
|
||||
|
||||
// 设置服务器能力
|
||||
mcp::json capabilities = {
|
||||
{"tools", {{"listChanged", true}}},
|
||||
{"transport", {{"sse", true}}}
|
||||
};
|
||||
server->set_capabilities(capabilities);
|
||||
|
||||
// 注册一个简单的方法
|
||||
server->register_method("test_method", [](const mcp::json& params) -> mcp::json {
|
||||
return {{"result", "success"}, {"params_received", params}};
|
||||
});
|
||||
|
||||
// 注册一个通知处理器
|
||||
server->register_notification("test_notification", [this](const mcp::json& params) {
|
||||
notification_received = true;
|
||||
notification_params = params;
|
||||
});
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
// 停止服务器
|
||||
if (server && server_thread.joinable()) {
|
||||
server->stop();
|
||||
server_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
void start_server() {
|
||||
server_thread = std::thread([this]() {
|
||||
server->start(false);
|
||||
});
|
||||
// 等待服务器启动
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
std::unique_ptr<mcp::server> server;
|
||||
std::thread server_thread;
|
||||
bool notification_received = false;
|
||||
mcp::json notification_params;
|
||||
};
|
||||
|
||||
// 测试消息生命周期 - 初始化
|
||||
TEST_F(McpLifecycleTransportTest, InitializationTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8096);
|
||||
client.set_timeout(5);
|
||||
|
||||
// 测试初始化
|
||||
bool init_result = client.initialize("TestClient", "1.0.0");
|
||||
EXPECT_TRUE(init_result);
|
||||
|
||||
// 获取服务器能力
|
||||
mcp::json server_capabilities = client.get_server_capabilities();
|
||||
EXPECT_TRUE(server_capabilities.contains("tools"));
|
||||
EXPECT_TRUE(server_capabilities.contains("transport"));
|
||||
EXPECT_TRUE(server_capabilities["transport"]["sse"].get<bool>());
|
||||
}
|
||||
|
||||
// 测试消息生命周期 - 请求和响应
|
||||
TEST_F(McpLifecycleTransportTest, RequestResponseTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8096);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", "1.0.0");
|
||||
|
||||
// 发送请求并获取响应
|
||||
mcp::json params = {{"key", "value"}, {"number", 42}};
|
||||
mcp::response response = client.send_request("test_method", params);
|
||||
|
||||
// 验证响应
|
||||
EXPECT_FALSE(response.is_error());
|
||||
EXPECT_EQ(response.result["result"], "success");
|
||||
EXPECT_EQ(response.result["params_received"]["key"], "value");
|
||||
EXPECT_EQ(response.result["params_received"]["number"], 42);
|
||||
}
|
||||
|
||||
// 测试消息生命周期 - 通知
|
||||
TEST_F(McpLifecycleTransportTest, NotificationTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8096);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", "1.0.0");
|
||||
|
||||
// 发送通知
|
||||
mcp::json params = {{"event", "update"}, {"status", "completed"}};
|
||||
client.send_notification("test_notification", params);
|
||||
|
||||
// 等待通知处理
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
// 验证通知已接收
|
||||
EXPECT_TRUE(notification_received);
|
||||
EXPECT_EQ(notification_params["event"], "update");
|
||||
EXPECT_EQ(notification_params["status"], "completed");
|
||||
}
|
||||
|
||||
// 测试SSE传输 - 使用ping方法测试SSE连接
|
||||
TEST_F(McpLifecycleTransportTest, SseTransportTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 注册一个特殊的方法,用于测试SSE连接
|
||||
server->register_method("sse_test", [](const mcp::json& params) -> mcp::json {
|
||||
return {{"sse_test_result", true}};
|
||||
});
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8096);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", "1.0.0");
|
||||
|
||||
// 等待SSE连接建立
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
|
||||
// 测试SSE连接是否正常工作 - 使用ping方法
|
||||
bool ping_result = client.ping();
|
||||
EXPECT_TRUE(ping_result);
|
||||
|
||||
// 发送请求并获取响应,验证SSE连接正常工作
|
||||
mcp::response response = client.send_request("sse_test");
|
||||
EXPECT_FALSE(response.is_error());
|
||||
EXPECT_TRUE(response.result["sse_test_result"].get<bool>());
|
||||
}
|
||||
|
||||
// 测试ping功能
|
||||
TEST_F(McpLifecycleTransportTest, PingTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8096);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", "1.0.0");
|
||||
|
||||
// 测试ping
|
||||
bool ping_result = client.ping();
|
||||
EXPECT_TRUE(ping_result);
|
||||
|
||||
// 停止服务器
|
||||
server->stop();
|
||||
server_thread.join();
|
||||
|
||||
// 等待服务器完全停止
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
|
||||
// 再次测试ping,应该失败
|
||||
ping_result = client.ping();
|
||||
EXPECT_FALSE(ping_result);
|
||||
}
|
|
@ -39,11 +39,11 @@ TEST(McpToolTest, ToolStructTest) {
|
|||
mcp::json json_tool = tool.to_json();
|
||||
EXPECT_EQ(json_tool["name"], "test_tool");
|
||||
EXPECT_EQ(json_tool["description"], "测试工具");
|
||||
EXPECT_EQ(json_tool["parameters"]["type"], "object");
|
||||
EXPECT_EQ(json_tool["parameters"]["properties"]["param1"]["type"], "string");
|
||||
EXPECT_EQ(json_tool["parameters"]["properties"]["param1"]["description"], "第一个参数");
|
||||
EXPECT_EQ(json_tool["parameters"]["properties"]["param2"]["type"], "number");
|
||||
EXPECT_EQ(json_tool["parameters"]["required"][0], "param1");
|
||||
EXPECT_EQ(json_tool["inputSchema"]["type"], "object");
|
||||
EXPECT_EQ(json_tool["inputSchema"]["properties"]["param1"]["type"], "string");
|
||||
EXPECT_EQ(json_tool["inputSchema"]["properties"]["param1"]["description"], "第一个参数");
|
||||
EXPECT_EQ(json_tool["inputSchema"]["properties"]["param2"]["type"], "number");
|
||||
EXPECT_EQ(json_tool["inputSchema"]["required"][0], "param1");
|
||||
}
|
||||
|
||||
// 测试工具构建器
|
||||
|
|
|
@ -0,0 +1,357 @@
|
|||
/**
|
||||
* @file test_mcp_tools_extended.cpp
|
||||
* @brief 测试MCP工具相关功能的扩展测试
|
||||
*
|
||||
* 本文件包含对MCP工具模块的扩展单元测试,基于规范2024-11-05。
|
||||
*/
|
||||
|
||||
#include "mcp_tool.h"
|
||||
#include "mcp_server.h"
|
||||
#include "mcp_client.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <future>
|
||||
#include <chrono>
|
||||
|
||||
// 测试类,用于设置服务器和客户端
|
||||
class McpToolsExtendedTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
// 创建服务器
|
||||
server = std::make_unique<mcp::server>("localhost", 8098);
|
||||
server->set_server_info("TestServer", "2024-11-05");
|
||||
|
||||
// 设置服务器能力
|
||||
mcp::json capabilities = {
|
||||
{"tools", {{"listChanged", true}}}
|
||||
};
|
||||
server->set_capabilities(capabilities);
|
||||
|
||||
// 注册计算器工具
|
||||
mcp::tool calculator = mcp::tool_builder("calculator")
|
||||
.with_description("计算器工具")
|
||||
.with_string_param("operation", "操作类型 (add, subtract, multiply, divide)")
|
||||
.with_number_param("a", "第一个操作数")
|
||||
.with_number_param("b", "第二个操作数")
|
||||
.build();
|
||||
|
||||
server->register_tool(calculator, [](const mcp::json& params) -> mcp::json {
|
||||
std::string operation = params["operation"];
|
||||
double a = params["a"];
|
||||
double b = params["b"];
|
||||
double result = 0;
|
||||
|
||||
if (operation == "add") {
|
||||
result = a + b;
|
||||
} else if (operation == "subtract") {
|
||||
result = a - b;
|
||||
} else if (operation == "multiply") {
|
||||
result = a * b;
|
||||
} else if (operation == "divide") {
|
||||
if (b == 0) {
|
||||
return {{"error", "除数不能为零"}};
|
||||
}
|
||||
result = a / b;
|
||||
} else {
|
||||
return {{"error", "未知操作: " + operation}};
|
||||
}
|
||||
|
||||
return {{"result", result}};
|
||||
});
|
||||
|
||||
// 注册文本处理工具
|
||||
mcp::tool text_processor = mcp::tool_builder("text_processor")
|
||||
.with_description("文本处理工具")
|
||||
.with_string_param("text", "要处理的文本")
|
||||
.with_string_param("operation", "操作类型 (uppercase, lowercase, reverse)")
|
||||
.build();
|
||||
|
||||
server->register_tool(text_processor, [](const mcp::json& params) -> mcp::json {
|
||||
std::string text = params["text"];
|
||||
std::string operation = params["operation"];
|
||||
std::string result;
|
||||
|
||||
if (operation == "uppercase") {
|
||||
result = text;
|
||||
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
|
||||
} else if (operation == "lowercase") {
|
||||
result = text;
|
||||
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
|
||||
} else if (operation == "reverse") {
|
||||
result = text;
|
||||
std::reverse(result.begin(), result.end());
|
||||
} else {
|
||||
return {{"error", "未知操作: " + operation}};
|
||||
}
|
||||
|
||||
return {{"result", result}};
|
||||
});
|
||||
|
||||
// 注册列表处理工具
|
||||
mcp::tool list_processor = mcp::tool_builder("list_processor")
|
||||
.with_description("列表处理工具")
|
||||
.with_array_param("items", "要处理的项目列表", "string")
|
||||
.with_string_param("operation", "操作类型 (sort, reverse, count)")
|
||||
.build();
|
||||
|
||||
server->register_tool(list_processor, [](const mcp::json& params) -> mcp::json {
|
||||
auto items = params["items"].get<std::vector<std::string>>();
|
||||
std::string operation = params["operation"];
|
||||
|
||||
if (operation == "sort") {
|
||||
std::sort(items.begin(), items.end());
|
||||
return {{"result", items}};
|
||||
} else if (operation == "reverse") {
|
||||
std::reverse(items.begin(), items.end());
|
||||
return {{"result", items}};
|
||||
} else if (operation == "count") {
|
||||
return {{"result", items.size()}};
|
||||
} else {
|
||||
return {{"error", "未知操作: " + operation}};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
// 停止服务器
|
||||
if (server && server_thread.joinable()) {
|
||||
server->stop();
|
||||
server_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
void start_server() {
|
||||
server_thread = std::thread([this]() {
|
||||
server->start(false);
|
||||
});
|
||||
// 等待服务器启动
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
std::unique_ptr<mcp::server> server;
|
||||
std::thread server_thread;
|
||||
};
|
||||
|
||||
// 测试获取工具列表
|
||||
TEST_F(McpToolsExtendedTest, GetToolsTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8098);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
|
||||
// 获取工具列表
|
||||
auto tools = client.get_tools();
|
||||
|
||||
// 验证工具列表
|
||||
EXPECT_EQ(tools.size(), 3);
|
||||
|
||||
// 验证工具名称
|
||||
std::vector<std::string> tool_names;
|
||||
for (const auto& tool : tools) {
|
||||
tool_names.push_back(tool.name);
|
||||
}
|
||||
|
||||
EXPECT_THAT(tool_names, ::testing::UnorderedElementsAre("calculator", "text_processor", "list_processor"));
|
||||
}
|
||||
|
||||
// 测试调用计算器工具
|
||||
TEST_F(McpToolsExtendedTest, CalculatorToolTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8098);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
|
||||
// 调用加法
|
||||
mcp::json add_result = client.call_tool("calculator", {
|
||||
{"operation", "add"},
|
||||
{"a", 5},
|
||||
{"b", 3}
|
||||
});
|
||||
EXPECT_EQ(add_result["result"], 8);
|
||||
|
||||
// 调用减法
|
||||
mcp::json subtract_result = client.call_tool("calculator", {
|
||||
{"operation", "subtract"},
|
||||
{"a", 10},
|
||||
{"b", 4}
|
||||
});
|
||||
EXPECT_EQ(subtract_result["result"], 6);
|
||||
|
||||
// 调用乘法
|
||||
mcp::json multiply_result = client.call_tool("calculator", {
|
||||
{"operation", "multiply"},
|
||||
{"a", 6},
|
||||
{"b", 7}
|
||||
});
|
||||
EXPECT_EQ(multiply_result["result"], 42);
|
||||
|
||||
// 调用除法
|
||||
mcp::json divide_result = client.call_tool("calculator", {
|
||||
{"operation", "divide"},
|
||||
{"a", 20},
|
||||
{"b", 5}
|
||||
});
|
||||
EXPECT_EQ(divide_result["result"], 4);
|
||||
|
||||
// 测试除以零
|
||||
mcp::json divide_by_zero = client.call_tool("calculator", {
|
||||
{"operation", "divide"},
|
||||
{"a", 10},
|
||||
{"b", 0}
|
||||
});
|
||||
EXPECT_TRUE(divide_by_zero.contains("error"));
|
||||
}
|
||||
|
||||
// 测试调用文本处理工具
|
||||
TEST_F(McpToolsExtendedTest, TextProcessorToolTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8098);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
|
||||
// 测试转大写
|
||||
mcp::json uppercase_result = client.call_tool("text_processor", {
|
||||
{"text", "Hello World"},
|
||||
{"operation", "uppercase"}
|
||||
});
|
||||
EXPECT_EQ(uppercase_result["result"], "HELLO WORLD");
|
||||
|
||||
// 测试转小写
|
||||
mcp::json lowercase_result = client.call_tool("text_processor", {
|
||||
{"text", "Hello World"},
|
||||
{"operation", "lowercase"}
|
||||
});
|
||||
EXPECT_EQ(lowercase_result["result"], "hello world");
|
||||
|
||||
// 测试反转
|
||||
mcp::json reverse_result = client.call_tool("text_processor", {
|
||||
{"text", "Hello World"},
|
||||
{"operation", "reverse"}
|
||||
});
|
||||
EXPECT_EQ(reverse_result["result"], "dlroW olleH");
|
||||
}
|
||||
|
||||
// 测试调用列表处理工具
|
||||
TEST_F(McpToolsExtendedTest, ListProcessorToolTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8098);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
|
||||
// 准备测试数据
|
||||
std::vector<std::string> items = {"banana", "apple", "orange", "grape"};
|
||||
|
||||
// 测试排序
|
||||
mcp::json sort_result = client.call_tool("list_processor", {
|
||||
{"items", items},
|
||||
{"operation", "sort"}
|
||||
});
|
||||
std::vector<std::string> sorted_items = sort_result["result"];
|
||||
EXPECT_THAT(sorted_items, ::testing::ElementsAre("apple", "banana", "grape", "orange"));
|
||||
|
||||
// 测试反转
|
||||
mcp::json reverse_result = client.call_tool("list_processor", {
|
||||
{"items", items},
|
||||
{"operation", "reverse"}
|
||||
});
|
||||
std::vector<std::string> reversed_items = reverse_result["result"];
|
||||
EXPECT_THAT(reversed_items, ::testing::ElementsAre("grape", "orange", "apple", "banana"));
|
||||
|
||||
// 测试计数
|
||||
mcp::json count_result = client.call_tool("list_processor", {
|
||||
{"items", items},
|
||||
{"operation", "count"}
|
||||
});
|
||||
EXPECT_EQ(count_result["result"], 4);
|
||||
}
|
||||
|
||||
// 测试工具参数验证
|
||||
TEST_F(McpToolsExtendedTest, ToolParameterValidationTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8098);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
|
||||
// 测试缺少必需参数
|
||||
try {
|
||||
client.call_tool("calculator", {
|
||||
{"a", 5}
|
||||
// 缺少 operation 和 b
|
||||
});
|
||||
FAIL() << "应该抛出异常,因为缺少必需参数";
|
||||
} catch (const mcp::mcp_exception& e) {
|
||||
EXPECT_EQ(e.code(), mcp::error_code::invalid_params);
|
||||
}
|
||||
|
||||
// 测试参数类型错误
|
||||
try {
|
||||
client.call_tool("calculator", {
|
||||
{"operation", "add"},
|
||||
{"a", "not_a_number"}, // 应该是数字
|
||||
{"b", 3}
|
||||
});
|
||||
FAIL() << "应该抛出异常,因为参数类型错误";
|
||||
} catch (const mcp::mcp_exception& e) {
|
||||
EXPECT_EQ(e.code(), mcp::error_code::invalid_params);
|
||||
}
|
||||
}
|
||||
|
||||
// 测试工具注册和注销
|
||||
TEST_F(McpToolsExtendedTest, ToolRegistrationAndUnregistrationTest) {
|
||||
// 创建工具注册表
|
||||
mcp::tool_registry& registry = mcp::tool_registry::instance();
|
||||
|
||||
// 创建一个测试工具
|
||||
mcp::tool test_tool = mcp::tool_builder("test_tool")
|
||||
.with_description("测试工具")
|
||||
.with_string_param("input", "输入参数")
|
||||
.build();
|
||||
|
||||
// 注册工具
|
||||
registry.register_tool(test_tool, [](const mcp::json& params) -> mcp::json {
|
||||
return {{"output", "处理结果: " + params["input"].get<std::string>()}};
|
||||
});
|
||||
|
||||
// 验证工具已注册
|
||||
auto registered_tool = registry.get_tool("test_tool");
|
||||
ASSERT_NE(registered_tool, nullptr);
|
||||
EXPECT_EQ(registered_tool->first.name, "test_tool");
|
||||
EXPECT_EQ(registered_tool->first.description, "测试工具");
|
||||
|
||||
// 调用工具
|
||||
mcp::json result = registry.call_tool("test_tool", {{"input", "测试输入"}});
|
||||
EXPECT_EQ(result["output"], "处理结果: 测试输入");
|
||||
|
||||
// 注销工具
|
||||
bool unregistered = registry.unregister_tool("test_tool");
|
||||
EXPECT_TRUE(unregistered);
|
||||
|
||||
// 验证工具已注销
|
||||
EXPECT_EQ(registry.get_tool("test_tool"), nullptr);
|
||||
|
||||
// 尝试调用已注销的工具
|
||||
try {
|
||||
registry.call_tool("test_tool", {{"input", "测试输入"}});
|
||||
FAIL() << "应该抛出异常,因为工具已注销";
|
||||
} catch (const mcp::mcp_exception& e) {
|
||||
EXPECT_EQ(e.code(), mcp::error_code::method_not_found);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
/**
|
||||
* @file test_mcp_versioning.cpp
|
||||
* @brief 测试MCP版本控制相关功能
|
||||
*
|
||||
* 本文件包含对MCP版本控制的单元测试,基于规范2024-11-05。
|
||||
*/
|
||||
|
||||
#include "mcp_message.h"
|
||||
#include "mcp_server.h"
|
||||
#include "mcp_client.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <future>
|
||||
#include <chrono>
|
||||
|
||||
// 测试类,用于设置服务器和客户端
|
||||
class McpVersioningTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
// 创建服务器
|
||||
server = std::make_unique<mcp::server>("localhost", 8097);
|
||||
server->set_server_info("TestServer", "2024-11-05");
|
||||
|
||||
// 设置服务器能力
|
||||
mcp::json capabilities = {
|
||||
{"tools", {{"listChanged", true}}},
|
||||
{"transport", {{"sse", true}}}
|
||||
};
|
||||
server->set_capabilities(capabilities);
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
// 停止服务器
|
||||
if (server && server_thread.joinable()) {
|
||||
server->stop();
|
||||
server_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
void start_server() {
|
||||
server_thread = std::thread([this]() {
|
||||
server->start(false);
|
||||
});
|
||||
// 等待服务器启动
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
std::unique_ptr<mcp::server> server;
|
||||
std::thread server_thread;
|
||||
};
|
||||
|
||||
// 测试版本常量
|
||||
TEST(McpVersioningTest, VersionConstantTest) {
|
||||
// 验证MCP版本常量
|
||||
EXPECT_EQ(std::string(mcp::MCP_VERSION), "2024-11-05");
|
||||
}
|
||||
|
||||
// 测试版本匹配
|
||||
TEST_F(McpVersioningTest, VersionMatchTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端,使用匹配的版本
|
||||
mcp::client client("localhost", 8097);
|
||||
client.set_timeout(5);
|
||||
|
||||
// 测试初始化,应该成功
|
||||
bool init_result = client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
EXPECT_TRUE(init_result);
|
||||
}
|
||||
|
||||
// 测试版本不匹配
|
||||
TEST_F(McpVersioningTest, VersionMismatchTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端,使用不匹配的版本
|
||||
mcp::client client("localhost", 8097);
|
||||
client.set_timeout(5);
|
||||
|
||||
// 测试初始化,应该失败或返回警告
|
||||
// 注意:根据实际实现,这可能会成功但有警告,或者完全失败
|
||||
bool init_result = client.initialize("TestClient", "2023-01-01");
|
||||
|
||||
// 如果实现允许版本不匹配,这个测试可能需要调整
|
||||
// 这里我们假设实现会拒绝不匹配的版本
|
||||
EXPECT_FALSE(init_result);
|
||||
}
|
||||
|
||||
// 测试服务器版本信息
|
||||
TEST_F(McpVersioningTest, ServerVersionInfoTest) {
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8097);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", mcp::MCP_VERSION);
|
||||
|
||||
// 获取服务器信息
|
||||
mcp::response response = client.send_request("server/info");
|
||||
|
||||
// 验证服务器信息
|
||||
EXPECT_FALSE(response.is_error());
|
||||
EXPECT_EQ(response.result["name"], "TestServer");
|
||||
EXPECT_EQ(response.result["version"], "2024-11-05");
|
||||
EXPECT_TRUE(response.result.contains("capabilities"));
|
||||
}
|
||||
|
||||
// 测试客户端版本信息
|
||||
TEST_F(McpVersioningTest, ClientVersionInfoTest) {
|
||||
// 创建一个处理器来捕获初始化请求
|
||||
mcp::json captured_init_params;
|
||||
server->register_method("initialize", [&captured_init_params](const mcp::json& params) -> mcp::json {
|
||||
captured_init_params = params;
|
||||
return {
|
||||
{"name", "TestServer"},
|
||||
{"version", "2024-11-05"},
|
||||
{"capabilities", {
|
||||
{"tools", {{"listChanged", true}}},
|
||||
{"transport", {{"sse", true}}}
|
||||
}}
|
||||
};
|
||||
});
|
||||
|
||||
// 启动服务器
|
||||
start_server();
|
||||
|
||||
// 创建客户端
|
||||
mcp::client client("localhost", 8097);
|
||||
client.set_timeout(5);
|
||||
client.initialize("TestClient", "1.0.0");
|
||||
|
||||
// 验证客户端版本信息
|
||||
EXPECT_EQ(captured_init_params["name"], "TestClient");
|
||||
EXPECT_EQ(captured_init_params["version"], "1.0.0");
|
||||
}
|
Loading…
Reference in New Issue