Compare commits

...

3 Commits

17 changed files with 1196 additions and 568 deletions

View File

@ -33,9 +33,5 @@ if(MCP_BUILD_TESTS)
enable_testing() enable_testing()
add_subdirectory(test) add_subdirectory(test)
# Add custom test target # run_teststest/CMakeLists.txt
add_custom_target(run_tests
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure
COMMENT "Running MCP tests..."
)
endif() endif()

View File

@ -12,7 +12,7 @@
int main() { int main() {
// Create a client // Create a client
mcp::client client("localhost", 8089); mcp::client client("localhost", 8888);
// Set capabilites // Set capabilites
mcp::json capabilities = { mcp::json capabilities = {
@ -53,6 +53,8 @@ int main() {
std::cout << "- " << tool.name << ": " << tool.description << std::endl; std::cout << "- " << tool.name << ": " << tool.description << std::endl;
} }
// Get available resources
// Call the get_time tool // Call the get_time tool
std::cout << "\nCalling get_time tool..." << std::endl; std::cout << "\nCalling get_time tool..." << std::endl;
mcp::json time_result = client.call_tool("get_time"); mcp::json time_result = client.call_tool("get_time");
@ -76,17 +78,6 @@ int main() {
}; };
mcp::json calc_result = client.call_tool("calculator", calc_params); mcp::json calc_result = client.call_tool("calculator", calc_params);
std::cout << "10 + 5 = " << calc_result["content"][0]["text"].get<std::string>() << std::endl; std::cout << "10 + 5 = " << calc_result["content"][0]["text"].get<std::string>() << std::endl;
// Not implemented yet
// // Access a resource
// std::cout << "\nAccessing API resource..." << std::endl;
// mcp::json api_params = {
// {"endpoint", "hello"},
// {"name", "MCP Client"}
// };
// mcp::json api_result = client.access_resource("/api", api_params);
// std::cout << "API response: " << api_result["contents"][0]["text"].get<std::string>() << std::endl;
} catch (const mcp::mcp_exception& e) { } catch (const mcp::mcp_exception& e) {
std::cerr << "MCP error: " << e.what() << " (code: " << static_cast<int>(e.code()) << ")" << std::endl; std::cerr << "MCP error: " << e.what() << " (code: " << static_cast<int>(e.code()) << ")" << std::endl;
return 1; return 1;

View File

@ -122,12 +122,13 @@ int main() {
std::filesystem::create_directories("./files"); std::filesystem::create_directories("./files");
// Create and configure server // Create and configure server
mcp::server server("localhost", 8089); mcp::server server("localhost", 8888);
server.set_server_info("ExampleServer", "1.0.0"); server.set_server_info("ExampleServer", "1.0.0");
// Set server capabilities // Set server capabilities
mcp::json capabilities = { mcp::json capabilities = {
{"tools", {{"listChanged", true}}} {"tools", {{"listChanged", true}}},
{"resources", {{"subscribe", false}, {"listChanged", true}}}
}; };
server.set_capabilities(capabilities); server.set_capabilities(capabilities);
@ -154,18 +155,12 @@ int main() {
server.register_tool(echo_tool, echo_handler); server.register_tool(echo_tool, echo_handler);
server.register_tool(calc_tool, calculator_handler); server.register_tool(calc_tool, calculator_handler);
// Not implemented yet // Register resources
// // Register resources auto file_resource = std::make_shared<mcp::file_resource>("./Makefile");
// auto file_resource = std::make_shared<mcp::file_resource>("./files"); server.register_resource("file://./Makefile", file_resource);
// server.register_resource("/files", file_resource);
// auto api_resource = std::make_shared<mcp::api_resource>("API", "Custom API endpoints");
// api_resource->register_handler("hello", hello_handler, "Say hello");
// server.register_resource("/api", api_resource);
// Start server // Start server
std::cout << "Starting MCP server at localhost:8089..." << std::endl; std::cout << "Starting MCP server at localhost:8888..." << std::endl;
std::cout << "Press Ctrl+C to stop the server" << std::endl; std::cout << "Press Ctrl+C to stop the server" << std::endl;
server.start(true); // Blocking mode server.start(true); // Blocking mode

View File

@ -11,6 +11,7 @@
#include "mcp_message.h" #include "mcp_message.h"
#include "mcp_tool.h" #include "mcp_tool.h"
#include "mcp_logger.h"
// Include the HTTP library // Include the HTTP library
#include "httplib.h" #include "httplib.h"
@ -23,6 +24,7 @@
#include <functional> #include <functional>
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
#include <future>
namespace mcp { namespace mcp {
@ -173,43 +175,76 @@ public:
bool check_server_accessible(); bool check_server_accessible();
private: private:
std::string base_url_; // 初始化HTTP客户端
void init_client(const std::string& host, int port);
void init_client(const std::string& base_url);
// 打开SSE连接
void open_sse_connection();
// 解析SSE数据
bool parse_sse_data(const char* data, size_t length);
// 关闭SSE连接
void close_sse_connection();
// 发送JSON-RPC请求
json send_jsonrpc(const request& req);
// 服务器主机和端口
std::string host_; std::string host_;
int port_; int port_ = 8080;
std::string sse_endpoint_;
// 或者使用基础URL
std::string base_url_;
// SSE端点
std::string sse_endpoint_ = "/sse";
// 消息端点
std::string msg_endpoint_; std::string msg_endpoint_;
std::string auth_token_;
int timeout_seconds_ = 30;
json capabilities_;
std::map<std::string, std::string> default_headers_; // HTTP客户端
json server_capabilities_;
// HTTP client
std::unique_ptr<httplib::Client> http_client_; std::unique_ptr<httplib::Client> http_client_;
// Mutex for thread safety // SSE专用HTTP客户端
std::unique_ptr<httplib::Client> sse_client_;
// SSE线程
std::unique_ptr<std::thread> sse_thread_;
// SSE运行状态
std::atomic<bool> sse_running_{false};
// 认证令牌
std::string auth_token_;
// 默认请求头
std::map<std::string, std::string> default_headers_;
// 超时设置(秒)
int timeout_seconds_ = 30;
// 客户端能力
json capabilities_;
// 服务器能力
json server_capabilities_;
// 互斥锁
mutable std::mutex mutex_; mutable std::mutex mutex_;
// 条件变量,用于等待消息端点设置 // 条件变量,用于等待消息端点设置
std::condition_variable endpoint_cv_; std::condition_variable endpoint_cv_;
// SSE connection // 请求ID到Promise的映射用于异步等待响应
std::unique_ptr<std::thread> sse_thread_; std::map<json, std::promise<json>> pending_requests_;
// SSE连接状态 // 响应处理互斥锁
std::atomic<bool> sse_running_{false}; std::mutex response_mutex_;
// Initialize the client // 响应条件变量
void init_client(const std::string& host, int port); std::condition_variable response_cv_;
void init_client(const std::string& base_url);
void open_sse_connection();
void close_sse_connection();
bool parse_sse_data(const char* data, size_t length);
// Send a JSON-RPC request and get the response
json send_jsonrpc(const request& req);
}; };
} // namespace mcp } // namespace mcp

View File

@ -83,19 +83,19 @@ private:
ss << std::put_time(now_tm, "%Y-%m-%d %H:%M:%S") << " "; ss << std::put_time(now_tm, "%Y-%m-%d %H:%M:%S") << " ";
// 添加日志级别 // 添加日志级别和颜色
switch (level) { switch (level) {
case log_level::debug: case log_level::debug:
ss << "[DEBUG] "; ss << "\033[36m[DEBUG]\033[0m "; // 青色
break; break;
case log_level::info: case log_level::info:
ss << "[INFO] "; ss << "\033[32m[INFO]\033[0m "; // 绿色
break; break;
case log_level::warning: case log_level::warning:
ss << "[WARNING] "; ss << "\033[33m[WARNING]\033[0m "; // 黄色
break; break;
case log_level::error: case log_level::error:
ss << "[ERROR] "; ss << "\033[31m[ERROR]\033[0m "; // 红色
break; break;
} }

View File

@ -79,13 +79,11 @@ public:
try { try {
bool write_result = sink->write(message_copy.data(), message_copy.size()); bool write_result = sink->write(message_copy.data(), message_copy.size());
if (!write_result) { if (!write_result) {
std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl;
close(); close();
return false; return false;
} }
return true; return true;
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "写入事件数据失败: " << e.what() << std::endl;
close(); close();
return false; return false;
} }
@ -232,13 +230,6 @@ public:
*/ */
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object()); void send_request(const std::string& session_id, const std::string& method, const json& params = json::object());
/**
* @brief
*
*
*/
void print_status() const;
private: private:
std::string host_; std::string host_;
int port_; int port_;

View File

@ -31,7 +31,7 @@ struct tool {
return { return {
{"name", name}, {"name", name},
{"description", description}, {"description", description},
{"parameters", parameters_schema} {"inputSchema", parameters_schema} // You may need 'parameters' instead of 'inputSchema' for OAI format
}; };
} }
}; };

View File

@ -23,42 +23,41 @@ client::client(const std::string& base_url, const json& capabilities, const std:
} }
client::~client() { client::~client() {
// 关闭SSE连接
close_sse_connection(); close_sse_connection();
// httplib::Client将自动销毁
} }
void client::init_client(const std::string& host, int port) { void client::init_client(const std::string& host, int port) {
// Create the HTTP client
http_client_ = std::make_unique<httplib::Client>(host.c_str(), port); http_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
sse_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
// Set timeout
http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_connection_timeout(timeout_seconds_, 0);
http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0);
http_client_->set_write_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0);
sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
sse_client_->set_write_timeout(timeout_seconds_, 0);
} }
void client::init_client(const std::string& base_url) { void client::init_client(const std::string& base_url) {
// Create the HTTP client
http_client_ = std::make_unique<httplib::Client>(base_url.c_str()); http_client_ = std::make_unique<httplib::Client>(base_url.c_str());
sse_client_ = std::make_unique<httplib::Client>(base_url.c_str());
// Set timeout
http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_connection_timeout(timeout_seconds_, 0);
http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_read_timeout(timeout_seconds_, 0);
http_client_->set_write_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0);
sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
sse_client_->set_read_timeout(0, 0);
sse_client_->set_write_timeout(timeout_seconds_, 0);
} }
bool client::initialize(const std::string& client_name, const std::string& client_version) { bool client::initialize(const std::string& client_name, const std::string& client_version) {
std::cerr << "开始初始化MCP客户端..." << std::endl; LOG_INFO("Initializing MCP client...");
// 检查服务器是否可访问
if (!check_server_accessible()) { if (!check_server_accessible()) {
std::cerr << "服务器不可访问,初始化失败" << std::endl;
return false; return false;
} }
// Create initialization request
request req = request::create("initialize", { request req = request::create("initialize", {
{"protocolVersion", MCP_VERSION}, {"protocolVersion", MCP_VERSION},
{"capabilities", capabilities_}, {"capabilities", capabilities_},
@ -69,88 +68,63 @@ bool client::initialize(const std::string& client_name, const std::string& clien
}); });
try { try {
// 打开SSE连接 LOG_INFO("Opening SSE connection...");
std::cerr << "正在打开SSE连接..." << std::endl;
open_sse_connection(); open_sse_connection();
// 等待SSE连接建立并获取消息端点 const auto timeout = std::chrono::milliseconds(5000);
// 使用条件变量和超时机制
const auto timeout = std::chrono::milliseconds(5000); // 5秒超时
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
// 检查初始状态
if (!msg_endpoint_.empty()) {
std::cerr << "消息端点已经设置: " << msg_endpoint_ << std::endl;
} else {
std::cerr << "等待条件变量..." << std::endl;
}
bool success = endpoint_cv_.wait_for(lock, timeout, [this]() { bool success = endpoint_cv_.wait_for(lock, timeout, [this]() {
if (!sse_running_) { if (!sse_running_) {
std::cerr << "SSE连接已关闭停止等待" << std::endl; LOG_WARNING("SSE connection closed, stopping wait");
return true; return true;
} }
if (!msg_endpoint_.empty()) { if (!msg_endpoint_.empty()) {
std::cerr << "消息端点已设置,停止等待" << std::endl; LOG_INFO("Message endpoint set, stopping wait");
return true; return true;
} }
return false; return false;
}); });
// 检查等待结果
if (!success) { if (!success) {
std::cerr << "条件变量等待超时" << std::endl; LOG_WARNING("Condition variable wait timed out");
} }
// 如果SSE连接已关闭或等待超时抛出异常
if (!sse_running_) { if (!sse_running_) {
throw std::runtime_error("SSE连接已关闭,未能获取消息端点"); throw std::runtime_error("SSE connection closed, failed to get message endpoint");
} }
if (msg_endpoint_.empty()) { if (msg_endpoint_.empty()) {
throw std::runtime_error("等待SSE连接超时未能获取消息端点"); throw std::runtime_error("Timeout waiting for SSE connection, failed to get message endpoint");
} }
std::cerr << "成功获取消息端点: " << msg_endpoint_ << std::endl; LOG_INFO("Successfully got message endpoint: ", msg_endpoint_);
} }
// 发送初始化请求
json result = send_jsonrpc(req); json result = send_jsonrpc(req);
// 存储服务器能力
server_capabilities_ = result["capabilities"]; server_capabilities_ = result["capabilities"];
// 发送已初始化通知
request notification = request::create_notification("initialized"); request notification = request::create_notification("initialized");
send_jsonrpc(notification); send_jsonrpc(notification);
return true; return true;
} catch (const std::exception& e) { } catch (const std::exception& e) {
// 初始化失败关闭SSE连接 LOG_ERROR("Initialization failed: ", e.what());
std::cerr << "初始化失败: " << e.what() << std::endl;
close_sse_connection(); close_sse_connection();
return false; return false;
} }
} }
bool client::ping() { bool client::ping() {
// Create ping request
request req = request::create("ping", {}); request req = request::create("ping", {});
try { try {
// Send the request
json result = send_jsonrpc(req); json result = send_jsonrpc(req);
return result.empty();
// The receiver MUST respond promptly with an empty response
if (result.empty()) {
return true;
} else {
return false;
}
} catch (const std::exception& e) { } catch (const std::exception& e) {
// Ping failed
return false; return false;
} }
} }
@ -158,24 +132,34 @@ bool client::ping() {
void client::set_auth_token(const std::string& token) { void client::set_auth_token(const std::string& token) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auth_token_ = token; auth_token_ = token;
// Add to default headers
set_header("Authorization", "Bearer " + auth_token_); set_header("Authorization", "Bearer " + auth_token_);
} }
void client::set_header(const std::string& key, const std::string& value) { void client::set_header(const std::string& key, const std::string& value) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
default_headers_[key] = value; default_headers_[key] = value;
if (http_client_) {
http_client_->set_default_headers({{key, value}});
}
if (sse_client_) {
sse_client_->set_default_headers({{key, value}});
}
} }
void client::set_timeout(int timeout_seconds) { void client::set_timeout(int timeout_seconds) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
timeout_seconds_ = timeout_seconds; timeout_seconds_ = timeout_seconds;
// Update the client's timeout if (http_client_) {
http_client_->set_connection_timeout(timeout_seconds_, 0); http_client_->set_connection_timeout(timeout_seconds_, 0);
http_client_->set_read_timeout(timeout_seconds_, 0); http_client_->set_write_timeout(timeout_seconds_, 0);
http_client_->set_write_timeout(timeout_seconds_, 0); }
if (sse_client_) {
sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
sse_client_->set_write_timeout(timeout_seconds_, 0);
}
} }
void client::set_capabilities(const json& capabilities) { void client::set_capabilities(const json& capabilities) {
@ -212,21 +196,28 @@ json client::call_tool(const std::string& tool_name, const json& arguments) {
} }
std::vector<tool> client::get_tools() { std::vector<tool> client::get_tools() {
json tools_json = send_request("tools/list", {}).result; json response_json = send_request("tools/list", {}).result;
std::vector<tool> tools; std::vector<tool> tools;
if (tools_json.is_array()) { json tools_json;
for (const auto& tool_json : tools_json) { if (response_json.contains("tools") && response_json["tools"].is_array()) {
tool t; tools_json = response_json["tools"];
t.name = tool_json["name"]; } else if (response_json.is_array()) {
t.description = tool_json["description"]; tools_json = response_json;
} else {
return tools;
}
if (tool_json.contains("inputSchema")) { for (const auto& tool_json : tools_json) {
t.parameters_schema = tool_json["inputSchema"]; tool t;
} t.name = tool_json["name"];
t.description = tool_json["description"];
tools.push_back(t); if (tool_json.contains("inputSchema")) {
t.parameters_schema = tool_json["inputSchema"];
} }
tools.push_back(t);
} }
return tools; return tools;
@ -261,321 +252,334 @@ json client::list_resource_templates() {
} }
void client::open_sse_connection() { void client::open_sse_connection() {
// 设置SSE连接状态为运行中
sse_running_ = true; sse_running_ = true;
// 清空消息端点
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
msg_endpoint_.clear(); msg_endpoint_.clear();
// 通知等待的线程虽然消息端点为空但可以让等待的线程检查sse_running_状态
endpoint_cv_.notify_all(); endpoint_cv_.notify_all();
} }
// 打印连接信息(调试用)
std::string connection_info; std::string connection_info;
if (!base_url_.empty()) { if (!base_url_.empty()) {
connection_info = "Base URL: " + base_url_ + ", SSE Endpoint: " + sse_endpoint_; connection_info = "Base URL: " + base_url_ + ", SSE Endpoint: " + sse_endpoint_;
} else { } else {
connection_info = "Host: " + host_ + ", Port: " + std::to_string(port_) + ", SSE Endpoint: " + sse_endpoint_; connection_info = "Host: " + host_ + ", Port: " + std::to_string(port_) + ", SSE Endpoint: " + sse_endpoint_;
} }
std::cerr << "尝试建立SSE连接: " << connection_info << std::endl; LOG_INFO("Attempting to establish SSE connection: ", connection_info);
// 创建并启动SSE线程
sse_thread_ = std::make_unique<std::thread>([this]() { sse_thread_ = std::make_unique<std::thread>([this]() {
int retry_count = 0; int retry_count = 0;
const int max_retries = 5; const int max_retries = 5;
const int retry_delay_base = 1000; // 毫秒 const int retry_delay_base = 1000;
while (sse_running_) { while (sse_running_) {
try { try {
// 尝试建立SSE连接 LOG_INFO("SSE thread: Attempting to connect to ", sse_endpoint_);
std::cerr << "SSE线程: 尝试连接到 " << sse_endpoint_ << std::endl;
auto res = http_client_->Get(sse_endpoint_.c_str(), auto res = sse_client_->Get(sse_endpoint_.c_str(),
[this](const char *data, size_t data_length) { [this](const char *data, size_t data_length) {
// 解析SSE数据
std::cerr << "SSE线程: 收到数据,长度: " << data_length << std::endl;
if (!parse_sse_data(data, data_length)) { if (!parse_sse_data(data, data_length)) {
std::cerr << "SSE线程: 解析数据失败" << std::endl; LOG_ERROR("SSE thread: Failed to parse data");
return false; // 解析失败,关闭连接 return false;
} }
// 检查是否应该关闭连接
bool should_continue = sse_running_.load(); bool should_continue = sse_running_.load();
if (!should_continue) { if (!should_continue) {
std::cerr << "SSE线程: sse_running_为false关闭连接" << std::endl; LOG_INFO("SSE thread: sse_running_ is false, closing connection");
} }
return should_continue; // 如果sse_running_为false关闭连接 return should_continue;
}); });
// 检查连接是否成功
if (!res) { if (!res) {
std::string error_msg = "SSE连接失败: "; std::string error_msg = "SSE connection failed: ";
error_msg += "错误代码: " + std::to_string(static_cast<int>(res.error())); error_msg += httplib::to_string(res.error());
// 添加更详细的错误信息
switch (res.error()) {
case httplib::Error::Connection:
error_msg += " (连接错误,服务器可能未运行或无法访问)";
break;
case httplib::Error::Read:
error_msg += " (读取错误,服务器可能关闭了连接或响应格式不正确)";
break;
case httplib::Error::Write:
error_msg += " (写入错误)";
break;
case httplib::Error::ConnectionTimeout:
error_msg += " (连接超时)";
break;
case httplib::Error::Canceled:
error_msg += " (请求被取消)";
break;
default:
error_msg += " (未知错误)";
break;
}
throw std::runtime_error(error_msg); throw std::runtime_error(error_msg);
} }
// 连接成功后重置重试计数
retry_count = 0; retry_count = 0;
std::cerr << "SSE线程: 连接成功" << std::endl; LOG_INFO("SSE thread: Connection successful");
} catch (const std::exception& e) { } catch (const std::exception& e) {
// 记录错误 LOG_ERROR("SSE connection error: ", e.what());
std::cerr << "SSE连接错误: " << e.what() << std::endl;
// 如果已达到最大重试次数,停止尝试 if (!sse_running_) {
if (++retry_count > max_retries) { LOG_INFO("SSE connection actively closed, no retry needed");
std::cerr << "达到最大重试次数停止SSE连接尝试" << std::endl;
break; break;
} }
// 指数退避重试 if (++retry_count > max_retries) {
int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay LOG_ERROR("Maximum retry count reached, stopping SSE connection attempts");
std::cerr << "将在 " << delay << " 毫秒后重试 (尝试 " << retry_count << "/" << max_retries << ")" << std::endl; break;
std::this_thread::sleep_for(std::chrono::milliseconds(delay)); }
int delay = retry_delay_base * (1 << (retry_count - 1));
LOG_INFO("Will retry in ", delay, " ms (attempt ", retry_count, "/", max_retries, ")");
const int check_interval = 100;
for (int waited = 0; waited < delay && sse_running_; waited += check_interval) {
std::this_thread::sleep_for(std::chrono::milliseconds(check_interval));
}
if (!sse_running_) {
LOG_INFO("SSE connection actively closed during retry wait, stopping retry");
break;
}
} }
} }
std::cerr << "SSE线程: 退出" << std::endl; LOG_INFO("SSE thread: Exiting");
}); });
} }
// 新增方法解析SSE数据
bool client::parse_sse_data(const char* data, size_t length) { bool client::parse_sse_data(const char* data, size_t length) {
try { try {
std::string sse_data(data, length); std::string sse_data(data, length);
// 查找"data:"标记 std::string event_type = "message";
auto event_pos = sse_data.find("event: ");
if (event_pos != std::string::npos) {
auto event_end = sse_data.find("\n", event_pos);
if (event_end != std::string::npos) {
event_type = sse_data.substr(event_pos + 7, event_end - (event_pos + 7));
if (!event_type.empty() && event_type.back() == '\r') {
event_type.pop_back();
}
}
}
auto data_pos = sse_data.find("data: "); auto data_pos = sse_data.find("data: ");
if (data_pos == std::string::npos) { if (data_pos == std::string::npos) {
return true; // 不是数据事件,可能是注释或心跳,继续保持连接
}
// 查找数据行结束位置
auto newline_pos = sse_data.find("\n", data_pos);
if (newline_pos == std::string::npos) {
newline_pos = sse_data.length(); // 如果没有换行符,使用整个字符串
}
// 提取数据内容
std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6));
// 检查是否是心跳事件
if (sse_data.find("event: heartbeat") != std::string::npos) {
// 心跳事件,不需要处理数据
return true; return true;
} }
// 更新消息端点 auto newline_pos = sse_data.find("\n", data_pos);
{ if (newline_pos == std::string::npos) {
std::lock_guard<std::mutex> lock(mutex_); newline_pos = sse_data.length();
msg_endpoint_ = data_content;
// 通知等待的线程
endpoint_cv_.notify_all();
} }
return true; std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6));
if (event_type == "heartbeat") {
return true;
} else if (event_type == "endpoint") {
std::lock_guard<std::mutex> lock(mutex_);
msg_endpoint_ = data_content;
endpoint_cv_.notify_all();
return true;
} else if (event_type == "message") {
try {
json response = json::parse(data_content);
if (response.contains("jsonrpc") && response.contains("id") && !response["id"].is_null()) {
json id = response["id"];
std::lock_guard<std::mutex> lock(response_mutex_);
auto it = pending_requests_.find(id);
if (it != pending_requests_.end()) {
if (response.contains("result")) {
it->second.set_value(response["result"]);
} else if (response.contains("error")) {
json error_result = {
{"isError", true},
{"error", response["error"]}
};
it->second.set_value(error_result);
} else {
it->second.set_value(json::object());
}
pending_requests_.erase(it);
} else {
LOG_WARNING("Received response for unknown request ID: ", id);
}
} else {
LOG_WARNING("Received invalid JSON-RPC response: ", response.dump());
}
} catch (const json::exception& e) {
LOG_ERROR("Failed to parse JSON-RPC response: ", e.what());
}
return true;
} else {
LOG_WARNING("Received unknown event type: ", event_type);
return true;
}
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "解析SSE数据错误: " << e.what() << std::endl; LOG_ERROR("Error parsing SSE data: ", e.what());
return false; return false;
} }
} }
// 新增方法关闭SSE连接
void client::close_sse_connection() { void client::close_sse_connection() {
// 设置标志这将导致SSE回调函数返回false从而关闭连接 if (!sse_running_) {
LOG_INFO("SSE connection already closed");
return;
}
LOG_INFO("Actively closing SSE connection (normal exit flow)...");
sse_running_ = false; sse_running_ = false;
// 给一些时间让回调函数返回false并关闭连接
std::this_thread::sleep_for(std::chrono::milliseconds(500)); std::this_thread::sleep_for(std::chrono::milliseconds(500));
// 等待SSE线程结束
if (sse_thread_ && sse_thread_->joinable()) { if (sse_thread_ && sse_thread_->joinable()) {
// 设置一个合理的超时时间例如5秒
auto timeout = std::chrono::seconds(5); auto timeout = std::chrono::seconds(5);
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
// 尝试在超时前等待线程结束 LOG_INFO("Waiting for SSE thread to end...");
while (sse_thread_->joinable() && while (sse_thread_->joinable() &&
std::chrono::steady_clock::now() - start < timeout) { std::chrono::steady_clock::now() - start < timeout) {
try { try {
// 尝试立即加入线程
sse_thread_->join(); sse_thread_->join();
break; // 如果成功加入,跳出循环 LOG_INFO("SSE thread successfully ended");
break;
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "等待SSE线程时出错: " << e.what() << std::endl; LOG_ERROR("Error waiting for SSE thread: ", e.what());
// 短暂休眠避免CPU占用过高
std::this_thread::sleep_for(std::chrono::milliseconds(100)); std::this_thread::sleep_for(std::chrono::milliseconds(100));
} }
} }
// 如果线程仍然没有结束,记录警告并分离线程
if (sse_thread_->joinable()) { if (sse_thread_->joinable()) {
std::cerr << "警告: SSE线程未能在超时时间内结束分离线程" << std::endl; LOG_WARNING("SSE thread did not end within timeout, detaching thread");
sse_thread_->detach(); sse_thread_->detach();
} }
} }
// 清空消息端点
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
msg_endpoint_.clear(); msg_endpoint_.clear();
// 通知等待的线程虽然消息端点为空但可以让等待的线程检查sse_running_状态
endpoint_cv_.notify_all(); endpoint_cv_.notify_all();
} }
std::cerr << "SSE连接已关闭" << std::endl; LOG_INFO("SSE connection successfully closed (normal exit flow)");
} }
json client::send_jsonrpc(const request& req) { json client::send_jsonrpc(const request& req) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
// 检查消息端点是否已设置
if (msg_endpoint_.empty()) { if (msg_endpoint_.empty()) {
throw mcp_exception(error_code::internal_error, "消息端点未设置SSE连接可能未建立"); throw mcp_exception(error_code::internal_error, "Message endpoint not set, SSE connection may not be established");
} }
// 打印请求信息(调试用)
std::cerr << "发送JSON-RPC请求: 方法=" << req.method << ", 端点=" << msg_endpoint_ << std::endl;
// Convert request to JSON
json req_json = req.to_json(); json req_json = req.to_json();
std::string req_body = req_json.dump(); std::string req_body = req_json.dump();
// Prepare headers
httplib::Headers headers; httplib::Headers headers;
headers.emplace("Content-Type", "application/json"); headers.emplace("Content-Type", "application/json");
// Add default headers
for (const auto& [key, value] : default_headers_) { for (const auto& [key, value] : default_headers_) {
headers.emplace(key, value); headers.emplace(key, value);
} }
// Send the request if (req.is_notification()) {
auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json"); auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
if (!result) { if (!result) {
// Error occurred auto err = result.error();
auto err = result.error(); std::string error_msg = httplib::to_string(err);
std::string error_msg; LOG_ERROR("JSON-RPC request failed: ", error_msg);
throw mcp_exception(error_code::internal_error, error_msg);
switch (err) {
case httplib::Error::Connection:
error_msg = "连接错误,服务器可能未运行或无法访问";
break;
case httplib::Error::Read:
error_msg = "读取错误,服务器可能关闭了连接或响应格式不正确";
break;
case httplib::Error::Write:
error_msg = "写入错误";
break;
case httplib::Error::ConnectionTimeout:
error_msg = "连接超时";
break;
default:
error_msg = "HTTP客户端错误: " + std::to_string(static_cast<int>(err));
break;
} }
std::cerr << "JSON-RPC请求失败: " << error_msg << std::endl;
throw mcp_exception(error_code::internal_error, error_msg);
}
// Check if it's a notification (no response expected)
if (req.is_notification()) {
return json::object(); return json::object();
} }
// Parse response std::promise<json> response_promise;
try { std::future<json> response_future = response_promise.get_future();
json res_json = json::parse(result->body);
// 打印响应信息(调试用) {
std::cerr << "收到JSON-RPC响应: " << res_json.dump() << std::endl; std::lock_guard<std::mutex> response_lock(response_mutex_);
pending_requests_[req.id] = std::move(response_promise);
}
// Check for error auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
if (res_json.contains("error")) {
int code = res_json["error"]["code"];
std::string message = res_json["error"]["message"];
throw mcp_exception(static_cast<error_code>(code), message); if (!result) {
auto err = result.error();
std::string error_msg = httplib::to_string(err);
{
std::lock_guard<std::mutex> response_lock(response_mutex_);
pending_requests_.erase(req.id);
} }
// Return result LOG_ERROR("JSON-RPC request failed: ", error_msg);
if (res_json.contains("result")) { throw mcp_exception(error_code::internal_error, error_msg);
return res_json["result"]; }
if (result->status != 202) {
try {
json res_json = json::parse(result->body);
{
std::lock_guard<std::mutex> response_lock(response_mutex_);
pending_requests_.erase(req.id);
}
if (res_json.contains("error")) {
int code = res_json["error"]["code"];
std::string message = res_json["error"]["message"];
throw mcp_exception(static_cast<error_code>(code), message);
}
if (res_json.contains("result")) {
return res_json["result"];
} else {
return json::object();
}
} catch (const json::exception& e) {
{
std::lock_guard<std::mutex> response_lock(response_mutex_);
pending_requests_.erase(req.id);
}
throw mcp_exception(error_code::parse_error,
"Failed to parse JSON-RPC response: " + std::string(e.what()));
}
} else {
const auto timeout = std::chrono::seconds(timeout_seconds_);
auto status = response_future.wait_for(timeout);
if (status == std::future_status::ready) {
json response = response_future.get();
if (response.contains("isError") && response["isError"].get<bool>()) {
int code = response["error"]["code"];
std::string message = response["error"]["message"];
throw mcp_exception(static_cast<error_code>(code), message);
}
return response;
} else { } else {
return json::object(); {
std::lock_guard<std::mutex> response_lock(response_mutex_);
pending_requests_.erase(req.id);
}
throw mcp_exception(error_code::internal_error, "Timeout waiting for SSE response");
} }
} catch (const json::exception& e) {
throw mcp_exception(error_code::parse_error,
"Failed to parse JSON-RPC response: " + std::string(e.what()));
} }
} }
bool client::check_server_accessible() { bool client::check_server_accessible() {
std::cerr << "检查服务器是否可访问..." << std::endl; LOG_INFO("Checking if server is accessible...");
try { try {
// 尝试发送一个简单的GET请求到服务器
auto res = http_client_->Get("/"); auto res = http_client_->Get("/");
if (res) { if (res) {
std::cerr << "服务器可访问,状态码: " << res->status << std::endl; LOG_INFO("Server is accessible, status code: ", res->status);
return true; return true;
} else { } else {
std::string error_msg = "服务器不可访问,错误代码: " + std::to_string(static_cast<int>(res.error())); std::string error_msg = "Server not accessible: " + httplib::to_string(res.error());
LOG_ERROR(error_msg);
// 添加更详细的错误信息
switch (res.error()) {
case httplib::Error::Connection:
error_msg += " (连接错误,服务器可能未运行或无法访问)";
break;
case httplib::Error::Read:
error_msg += " (读取错误)";
break;
case httplib::Error::Write:
error_msg += " (写入错误)";
break;
case httplib::Error::ConnectionTimeout:
error_msg += " (连接超时)";
break;
default:
error_msg += " (未知错误)";
break;
}
std::cerr << error_msg << std::endl;
return false; return false;
} }
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "检查服务器可访问性时发生异常: " << e.what() << std::endl; LOG_ERROR("Exception while checking server accessibility: ", e.what());
return false; return false;
} }
} }

View File

@ -14,14 +14,6 @@ server::server(const std::string& host, int port, const std::string& sse_endpoin
: host_(host), port_(port), sse_endpoint_(sse_endpoint), msg_endpoint_(msg_endpoint_prefix), name_("MCP Server"), version_(MCP_VERSION) { : host_(host), port_(port), sse_endpoint_(sse_endpoint), msg_endpoint_(msg_endpoint_prefix), name_("MCP Server"), version_(MCP_VERSION) {
http_server_ = std::make_unique<httplib::Server>(); http_server_ = std::make_unique<httplib::Server>();
// Set default capabilities
capabilities_ = {
{"resources", {
{"subscribe", true},
{"listChanged", true}
}}
};
} }
server::~server() { server::~server() {
@ -33,7 +25,7 @@ bool server::start(bool blocking) {
return true; // Already running return true; // Already running
} }
std::cerr << "启动MCP服务器: " << host_ << ":" << port_ << std::endl; LOG_INFO("Starting MCP server on ", host_, ":", port_);
// 设置CORS处理 // 设置CORS处理
http_server_->Options(".*", [](const httplib::Request& req, httplib::Response& res) { http_server_->Options(".*", [](const httplib::Request& req, httplib::Response& res) {
@ -46,36 +38,36 @@ bool server::start(bool blocking) {
// 设置JSON-RPC端点 // 设置JSON-RPC端点
http_server_->Post(msg_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) { http_server_->Post(msg_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) {
this->handle_jsonrpc(req, res); this->handle_jsonrpc(req, res);
LOG_INFO(req.remote_addr, ":", req.remote_port, " - \"POST ", req.path, " HTTP/1.1\" ", res.status);
}); });
// 设置SSE端点 // 设置SSE端点
http_server_->Get(sse_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) { http_server_->Get(sse_endpoint_.c_str(), [this](const httplib::Request& req, httplib::Response& res) {
this->handle_sse(req, res); this->handle_sse(req, res);
LOG_INFO(req.remote_addr, ":", req.remote_port, " - \"GET ", req.path, " HTTP/1.1\" ", res.status);
}); });
// 启动服务器 // 启动服务器
if (blocking) { if (blocking) {
running_ = true; running_ = true;
std::cerr << "以阻塞模式启动服务器" << std::endl; LOG_INFO("Starting server in blocking mode");
print_status();
if (!http_server_->listen(host_.c_str(), port_)) { if (!http_server_->listen(host_.c_str(), port_)) {
running_ = false; running_ = false;
std::cerr << "无法在 " << host_ << ":" << port_ << " 上启动服务器" << std::endl; LOG_ERROR("Failed to start server on ", host_, ":", port_);
return false; return false;
} }
return true; return true;
} else { } else {
// 在单独的线程中启动 // 在单独的线程中启动
server_thread_ = std::make_unique<std::thread>([this]() { server_thread_ = std::make_unique<std::thread>([this]() {
std::cerr << "在单独的线程中启动服务器" << std::endl; LOG_INFO("Starting server in separate thread");
if (!http_server_->listen(host_.c_str(), port_)) { if (!http_server_->listen(host_.c_str(), port_)) {
std::cerr << "无法在 " << host_ << ":" << port_ << " 上启动服务器" << std::endl; LOG_ERROR("Failed to start server on ", host_, ":", port_);
running_ = false; running_ = false;
return; return;
} }
}); });
running_ = true; running_ = true;
print_status();
return true; return true;
} }
} }
@ -85,15 +77,13 @@ void server::stop() {
return; return;
} }
std::cerr << "正在停止MCP服务器..." << std::endl; LOG_INFO("Stopping MCP server...");
print_status();
running_ = false; running_ = false;
// 关闭所有SSE连接 // 关闭所有SSE连接
std::vector<std::string> session_ids; std::vector<std::string> session_ids;
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
// 先收集所有会话ID
for (const auto& [session_id, _] : session_dispatchers_) { for (const auto& [session_id, _] : session_dispatchers_) {
session_ids.push_back(session_id); session_ids.push_back(session_id);
} }
@ -102,31 +92,25 @@ void server::stop() {
// 关闭每个会话的分发器 // 关闭每个会话的分发器
for (const auto& session_id : session_ids) { for (const auto& session_id : session_ids) {
try { try {
std::cerr << "关闭会话: " << session_id << std::endl;
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = session_dispatchers_.find(session_id); auto it = session_dispatchers_.find(session_id);
if (it != session_dispatchers_.end()) { if (it != session_dispatchers_.end()) {
it->second->close(); it->second->close();
} }
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "关闭会话时发生异常: " << session_id << ", " << e.what() << std::endl; LOG_ERROR("Exception while closing session ", session_id, ": ", e.what());
} }
} }
// 等待一段时间,让会话线程有机会自行清理
std::cerr << "等待会话线程自行清理..." << std::endl;
std::this_thread::sleep_for(std::chrono::seconds(1));
// 清理剩余的线程 // 清理剩余的线程
{ {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
for (auto& [session_id, thread] : sse_threads_) { for (auto& [session_id, thread] : sse_threads_) {
if (thread && thread->joinable()) { if (thread && thread->joinable()) {
try { try {
std::cerr << "分离会话线程: " << session_id << std::endl;
thread->detach(); thread->detach();
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "分离会话线程时发生异常: " << session_id << ", " << e.what() << std::endl; LOG_ERROR("Exception while detaching session thread ", session_id, ": ", e.what());
} }
} }
} }
@ -137,16 +121,14 @@ void server::stop() {
} }
if (http_server_) { if (http_server_) {
std::cerr << "停止HTTP服务器..." << std::endl;
http_server_->stop(); http_server_->stop();
} }
if (server_thread_ && server_thread_->joinable()) { if (server_thread_ && server_thread_->joinable()) {
std::cerr << "等待服务器线程结束..." << std::endl;
server_thread_->join(); server_thread_->join();
} }
std::cerr << "MCP服务器已停止" << std::endl; LOG_INFO("MCP server stopped");
} }
bool server::is_running() const { bool server::is_running() const {
@ -212,10 +194,7 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
{"resources", resources} {"resources", resources}
}; };
// Handle pagination if cursor is provided
if (params.contains("cursor")) { if (params.contains("cursor")) {
// In this implementation, we don't actually paginate
// but we include the nextCursor field for compatibility
result["nextCursor"] = ""; result["nextCursor"] = "";
} }
@ -235,15 +214,12 @@ void server::register_resource(const std::string& path, std::shared_ptr<resource
throw mcp_exception(error_code::invalid_params, "Resource not found: " + uri); throw mcp_exception(error_code::invalid_params, "Resource not found: " + uri);
} }
// In a real implementation, we would register a subscription here
// For now, just return success
return json::object(); return json::object();
}; };
} }
if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) { if (method_handlers_.find("resources/templates/list") == method_handlers_.end()) {
method_handlers_["resources/templates/list"] = [this](const json& params) -> json { method_handlers_["resources/templates/list"] = [this](const json& params) -> json {
// In this implementation, we don't support resource templates
return json::array(); return json::array();
}; };
} }
@ -260,7 +236,7 @@ void server::register_tool(const tool& tool, tool_handler handler) {
for (const auto& [name, tool_pair] : tools_) { for (const auto& [name, tool_pair] : tools_) {
tools_json.push_back(tool_pair.first.to_json()); tools_json.push_back(tool_pair.first.to_json());
} }
return tools_json; return json{{"tools", tools_json}};
}; };
} }
@ -310,18 +286,9 @@ void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
} }
void server::handle_sse(const httplib::Request& req, httplib::Response& res) { void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
// 生成会话ID
std::string session_id = generate_session_id(); std::string session_id = generate_session_id();
std::string session_uri = msg_endpoint_ + "?session_id=" + session_id; std::string session_uri = msg_endpoint_ + "?session_id=" + session_id;
std::cerr << "新的SSE连接: 客户端=" << req.remote_addr << ", 会话ID=" << session_id << std::endl;
std::cerr << "会话URI: " << session_uri << std::endl;
std::cerr << "请求头: ";
for (const auto& [key, value] : req.headers) {
std::cerr << key << "=" << value << " ";
}
std::cerr << std::endl;
// 设置SSE响应头 // 设置SSE响应头
res.set_header("Content-Type", "text/event-stream"); res.set_header("Content-Type", "text/event-stream");
res.set_header("Cache-Control", "no-cache"); res.set_header("Cache-Control", "no-cache");
@ -337,62 +304,46 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
session_dispatchers_[session_id] = session_dispatcher; session_dispatchers_[session_id] = session_dispatcher;
} }
// 创建会话线程,使用值捕获而不是引用捕获 // 创建会话线程
auto thread = std::make_unique<std::thread>([this, session_id, session_uri, session_dispatcher]() { auto thread = std::make_unique<std::thread>([this, res, session_id, session_uri, session_dispatcher]() {
try { try {
std::cerr << "SSE会话线程启动: " << session_id << std::endl; // 发送初始会话URI
// 发送初始会话URI - 使用endpoint事件类型符合MCP规范
std::this_thread::sleep_for(std::chrono::milliseconds(500)); std::this_thread::sleep_for(std::chrono::milliseconds(500));
std::stringstream ss; std::stringstream ss;
ss << "event: endpoint\ndata: " << session_uri << "\n\n"; ss << "event: endpoint\ndata: " << session_uri << "\n\n";
session_dispatcher->send_event(ss.str()); session_dispatcher->send_event(ss.str());
std::cerr << "发送会话URI: " << session_uri << " 到会话: " << session_id << std::endl;
// 定期发送心跳,检测连接状态 // 定期发送心跳,检测连接状态
int heartbeat_count = 0; int heartbeat_count = 0;
while (running_ && !session_dispatcher->is_closed()) { while (running_ && !session_dispatcher->is_closed()) {
std::this_thread::sleep_for(std::chrono::seconds(10)); std::this_thread::sleep_for(std::chrono::seconds(10));
// 检查分发器是否已关闭 if (session_dispatcher->is_closed() || !running_) {
if (session_dispatcher->is_closed()) {
std::cerr << "会话已关闭,停止心跳: " << session_id << std::endl;
break; break;
} }
// 检查服务器是否仍在运行
if (!running_) {
std::cerr << "服务器已停止,停止心跳: " << session_id << std::endl;
break;
}
// 发送心跳事件 - 使用自定义heartbeat事件类型
std::stringstream heartbeat; std::stringstream heartbeat;
heartbeat << "event: heartbeat\ndata: " << heartbeat_count++ << "\n\n"; heartbeat << "event: heartbeat\ndata: " << heartbeat_count++ << "\n\n";
try { try {
bool sent = session_dispatcher->send_event(heartbeat.str()); bool sent = session_dispatcher->send_event(heartbeat.str());
if (!sent) { if (!sent) {
std::cerr << "发送心跳失败,客户端可能已关闭连接: " << session_id << std::endl; LOG_WARNING("Failed to send heartbeat, client may have closed connection: ", session_id);
break; break;
} }
std::cerr << "发送心跳到会话: " << session_id << ", 计数: " << heartbeat_count << std::endl;
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "发送心跳失败,假定连接已关闭: " << e.what() << std::endl; LOG_ERROR("Failed to send heartbeat: ", e.what());
break; break;
} }
} }
std::cerr << "SSE会话线程退出: " << session_id << std::endl;
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "SSE会话线程异常: " << session_id << ", " << e.what() << std::endl; LOG_ERROR("SSE session thread exception: ", session_id, ", ", e.what());
} }
// 安全地清理资源 // 安全地清理资源
try { try {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
// 先关闭分发器
auto dispatcher_it = session_dispatchers_.find(session_id); auto dispatcher_it = session_dispatchers_.find(session_id);
if (dispatcher_it != session_dispatchers_.end()) { if (dispatcher_it != session_dispatchers_.end()) {
if (!dispatcher_it->second->is_closed()) { if (!dispatcher_it->second->is_closed()) {
@ -401,18 +352,13 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
session_dispatchers_.erase(dispatcher_it); session_dispatchers_.erase(dispatcher_it);
} }
// 再移除线程
auto thread_it = sse_threads_.find(session_id); auto thread_it = sse_threads_.find(session_id);
if (thread_it != sse_threads_.end()) { if (thread_it != sse_threads_.end()) {
// 不要在线程内部join或detach自己 thread_it->second.release();
// 只从映射中移除
thread_it->second.release(); // 释放所有权但不删除线程对象
sse_threads_.erase(thread_it); sse_threads_.erase(thread_it);
} }
std::cerr << "会话资源已清理: " << session_id << std::endl;
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "清理会话资源时发生异常: " << session_id << ", " << e.what() << std::endl; LOG_ERROR("Exception while cleaning up session resources: ", session_id, ", ", e.what());
} }
}); });
@ -430,9 +376,6 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = session_dispatchers_.find(session_id); auto it = session_dispatchers_.find(session_id);
if (it == session_dispatchers_.end() || it->second->is_closed()) { if (it == session_dispatchers_.end() || it->second->is_closed()) {
std::cerr << "会话已关闭,停止内容提供: " << session_id << std::endl;
// 不在这里清理资源,让会话线程自己清理
return false; return false;
} }
} }
@ -440,7 +383,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
// 等待事件 // 等待事件
bool result = session_dispatcher->wait_event(&sink); bool result = session_dispatcher->wait_event(&sink);
if (!result) { if (!result) {
std::cerr << "等待事件失败,关闭连接: " << session_id << std::endl; LOG_WARNING("Failed to wait for event, closing connection: ", session_id);
// 关闭会话分发器,但不清理资源 // 关闭会话分发器,但不清理资源
{ {
@ -454,10 +397,9 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
return false; return false;
} }
return true; return true;
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << "SSE内容提供者异常: " << e.what() << std::endl; LOG_ERROR("SSE content provider exception: ", e.what());
// 关闭会话分发器,但不清理资源 // 关闭会话分发器,但不清理资源
try { try {
@ -467,7 +409,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) {
it->second->close(); it->second->close();
} }
} catch (const std::exception& e2) { } catch (const std::exception& e2) {
std::cerr << "关闭会话分发器时发生异常: " << e2.what() << std::endl; LOG_ERROR("Exception while closing session dispatcher: ", e2.what());
} }
return false; return false;
@ -492,143 +434,110 @@ void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res)
auto it = req.params.find("session_id"); auto it = req.params.find("session_id");
std::string session_id = it != req.params.end() ? it->second : ""; std::string session_id = it != req.params.end() ? it->second : "";
std::cerr << "收到JSON-RPC请求: 会话ID=" << session_id << ", 路径=" << req.path << std::endl;
std::cerr << "请求参数: ";
for (const auto& [key, value] : req.params) {
std::cerr << key << "=" << value << " ";
}
std::cerr << std::endl;
// 检查会话是否存在
{
std::lock_guard<std::mutex> lock(mutex_);
if (!session_id.empty()) {
std::cerr << "检查会话是否存在: " << session_id << std::endl;
std::cerr << "当前活跃会话: ";
for (const auto& [id, _] : session_dispatchers_) {
std::cerr << id << " ";
}
std::cerr << std::endl;
if (session_dispatchers_.find(session_id) == session_dispatchers_.end()) {
std::cerr << "会话不存在: " << session_id << std::endl;
json error_response = {
{"jsonrpc", "2.0"},
{"error", {
{"code", static_cast<int>(error_code::invalid_request)},
{"message", "Session not found"}
}},
{"id", nullptr}
};
res.set_content(error_response.dump(), "application/json");
return;
} else {
std::cerr << "会话存在: " << session_id << std::endl;
}
} else {
std::cerr << "请求中没有会话ID" << std::endl;
}
}
// 解析请求 // 解析请求
json req_json; json req_json;
try { try {
req_json = json::parse(req.body); req_json = json::parse(req.body);
std::cerr << "请求内容: " << req_json.dump() << std::endl;
} catch (const json::exception& e) { } catch (const json::exception& e) {
// 无效的JSON LOG_ERROR("Failed to parse JSON request: ", e.what());
std::cerr << "解析JSON失败: " << e.what() << std::endl; res.status = 400;
json error_response = { res.set_content("{\"error\":\"Invalid JSON\"}", "application/json");
{"jsonrpc", "2.0"},
{"error", {
{"code", static_cast<int>(error_code::parse_error)},
{"message", "Parse error: " + std::string(e.what())}
}},
{"id", nullptr}
};
res.set_content(error_response.dump(), "application/json");
return; return;
} }
// 检查是否是批量请求 // 检查会话是否存在
if (req_json.is_array()) { std::shared_ptr<event_dispatcher> dispatcher;
// 批量请求暂不支持 {
std::cerr << "不支持批量请求" << std::endl; std::lock_guard<std::mutex> lock(mutex_);
json error_response = { auto disp_it = session_dispatchers_.find(session_id);
{"jsonrpc", "2.0"}, if (disp_it == session_dispatchers_.end()) {
{"error", { // 处理ping请求
{"code", static_cast<int>(error_code::invalid_request)}, if (req_json["method"] == "ping") {
{"message", "Batch requests are not supported"} res.status = 202;
}}, res.set_content("Accepted", "text/plain");
{"id", nullptr} return;
}; }
res.set_content(error_response.dump(), "application/json"); LOG_ERROR("Session not found: ", session_id);
return; res.status = 404;
res.set_content("{\"error\":\"Session not found\"}", "application/json");
return;
}
dispatcher = disp_it->second;
} }
// 转换为请求对象 // 创建请求对象
request mcp_req; request mcp_req;
try { try {
mcp_req.jsonrpc = req_json["jsonrpc"]; mcp_req.jsonrpc = req_json["jsonrpc"].get<std::string>();
mcp_req.method = req_json["method"]; if (req_json.contains("id") && !req_json["id"].is_null()) {
if (req_json.contains("id")) {
mcp_req.id = req_json["id"]; mcp_req.id = req_json["id"];
} }
mcp_req.method = req_json["method"].get<std::string>();
if (req_json.contains("params")) { if (req_json.contains("params")) {
mcp_req.params = req_json["params"]; mcp_req.params = req_json["params"];
} }
} catch (const json::exception& e) { } catch (const std::exception& e) {
// 无效的请求 LOG_ERROR("Failed to create request object: ", e.what());
std::cerr << "无效的请求: " << e.what() << std::endl; res.status = 400;
json error_response = { res.set_content("{\"error\":\"Invalid request format\"}", "application/json");
{"jsonrpc", "2.0"},
{"error", {
{"code", static_cast<int>(error_code::invalid_request)},
{"message", "Invalid request: " + std::string(e.what())}
}},
{"id", nullptr}
};
res.set_content(error_response.dump(), "application/json");
return; return;
} }
// 处理请求 // 如果是通知没有ID直接处理并返回202状态码
std::cerr << "处理方法: " << mcp_req.method << std::endl; if (mcp_req.is_notification()) {
json result = process_request(mcp_req, session_id); // 在线程池中异步处理通知
std::cerr << "响应: " << result.dump() << std::endl; thread_pool_.enqueue([this, mcp_req, session_id]() {
res.set_content(result.dump(), "application/json"); process_request(mcp_req, session_id);
});
// 返回202 Accepted
res.status = 202;
res.set_content("Accepted", "text/plain");
return;
}
// 对于有ID的请求在线程池中处理并通过SSE返回结果
thread_pool_.enqueue([this, mcp_req, session_id, dispatcher]() {
// 处理请求
json response_json = process_request(mcp_req, session_id);
// 通过SSE发送响应
std::stringstream ss;
ss << "event: message\ndata: " << response_json.dump() << "\n\n";
bool result = dispatcher->send_event(ss.str());
if (!result) {
LOG_ERROR("Failed to send response via SSE: session_id=", session_id);
}
});
// 返回202 Accepted
res.status = 202;
res.set_content("Accepted", "text/plain");
} }
json server::process_request(const request& req, const std::string& session_id) { json server::process_request(const request& req, const std::string& session_id) {
// 检查是否是通知 // 检查是否是通知
if (req.is_notification()) { if (req.is_notification()) {
std::cerr << "处理通知: " << req.method << std::endl;
// 通知没有响应
if (req.method == "notifications/initialized") { if (req.method == "notifications/initialized") {
std::cerr << "收到客户端initialized通知会话: " << session_id << std::endl;
set_session_initialized(session_id, true); set_session_initialized(session_id, true);
std::cerr << "会话已设置为初始化状态: " << session_id << std::endl;
} }
return json::object(); return json::object();
} }
// 处理方法调用 // 处理方法调用
try { try {
LOG_INFO("处理方法调用: ", req.method); LOG_INFO("Processing method call: ", req.method);
// 特殊情况:初始化 // 特殊情况:初始化
if (req.method == "initialize") { if (req.method == "initialize") {
return handle_initialize(req, session_id); return handle_initialize(req, session_id);
} else if (req.method == "ping") { } else if (req.method == "ping") {
// 接收者必须立即响应一个空响应 return response::create_success(req.id, json::object()).to_json();
LOG_INFO("处理ping请求");
return response::create_success(req.id, {}).to_json();
} }
if (!is_session_initialized(session_id)) { if (!is_session_initialized(session_id)) {
LOG_WARNING("会话未初始化: ", session_id); LOG_WARNING("Session not initialized: ", session_id);
return response::create_error( return response::create_error(
req.id, req.id,
error_code::invalid_request, error_code::invalid_request,
@ -648,19 +557,19 @@ json server::process_request(const request& req, const std::string& session_id)
if (handler) { if (handler) {
// 调用处理器 // 调用处理器
LOG_INFO("调用方法处理器: ", req.method); LOG_INFO("Calling method handler: ", req.method);
auto future = thread_pool_.enqueue([handler, params = req.params]() -> json { auto future = thread_pool_.enqueue([handler, params = req.params]() -> json {
return handler(params); return handler(params);
}); });
json result = future.get(); json result = future.get();
// 创建成功响应 // 创建成功响应
LOG_INFO("方法调用成功: ", req.method); LOG_INFO("Method call successful: ", req.method);
return response::create_success(req.id, result).to_json(); return response::create_success(req.id, result).to_json();
} }
// 方法未找到 // 方法未找到
LOG_WARNING("方法未找到: ", req.method); LOG_WARNING("Method not found: ", req.method);
return response::create_error( return response::create_error(
req.id, req.id,
error_code::method_not_found, error_code::method_not_found,
@ -668,7 +577,7 @@ json server::process_request(const request& req, const std::string& session_id)
).to_json(); ).to_json();
} catch (const mcp_exception& e) { } catch (const mcp_exception& e) {
// MCP异常 // MCP异常
LOG_ERROR("MCP异常: ", e.what(), ", 代码: ", static_cast<int>(e.code())); LOG_ERROR("MCP exception: ", e.what(), ", code: ", static_cast<int>(e.code()));
return response::create_error( return response::create_error(
req.id, req.id,
e.code(), e.code(),
@ -676,7 +585,7 @@ json server::process_request(const request& req, const std::string& session_id)
).to_json(); ).to_json();
} catch (const std::exception& e) { } catch (const std::exception& e) {
// 其他异常 // 其他异常
LOG_ERROR("处理请求时发生异常: ", e.what()); LOG_ERROR("Exception while processing request: ", e.what());
return response::create_error( return response::create_error(
req.id, req.id,
error_code::internal_error, error_code::internal_error,
@ -684,7 +593,7 @@ json server::process_request(const request& req, const std::string& session_id)
).to_json(); ).to_json();
} catch (...) { } catch (...) {
// 未知异常 // 未知异常
LOG_ERROR("处理请求时发生未知异常"); LOG_ERROR("Unknown exception while processing request");
return response::create_error( return response::create_error(
req.id, req.id,
error_code::internal_error, error_code::internal_error,
@ -696,11 +605,9 @@ json server::process_request(const request& req, const std::string& session_id)
json server::handle_initialize(const request& req, const std::string& session_id) { json server::handle_initialize(const request& req, const std::string& session_id) {
const json& params = req.params; const json& params = req.params;
std::cerr << "处理initialize请求会话ID: " << session_id << std::endl;
// Version negotiation // Version negotiation
if (!params.contains("protocolVersion") || !params["protocolVersion"].is_string()) { if (!params.contains("protocolVersion") || !params["protocolVersion"].is_string()) {
std::cerr << "缺少protocolVersion参数或格式不正确" << std::endl; LOG_ERROR("Missing or invalid protocolVersion parameter");
return response::create_error( return response::create_error(
req.id, req.id,
error_code::invalid_params, error_code::invalid_params,
@ -709,10 +616,10 @@ json server::handle_initialize(const request& req, const std::string& session_id
} }
std::string requested_version = params["protocolVersion"].get<std::string>(); std::string requested_version = params["protocolVersion"].get<std::string>();
std::cerr << "客户端请求的协议版本: " << requested_version << std::endl; LOG_INFO("Client requested protocol version: ", requested_version);
if (requested_version != MCP_VERSION) { if (requested_version != MCP_VERSION) {
std::cerr << "不支持的协议版本: " << requested_version << ", 服务器支持: " << MCP_VERSION << std::endl; LOG_ERROR("Unsupported protocol version: ", requested_version, ", server supports: ", MCP_VERSION);
return response::create_error( return response::create_error(
req.id, req.id,
error_code::invalid_params, error_code::invalid_params,
@ -738,7 +645,7 @@ json server::handle_initialize(const request& req, const std::string& session_id
} }
// Log connection // Log connection
std::cerr << "客户端连接: " << client_name << " " << client_version << std::endl; LOG_INFO("Client connected: ", client_name, " ", client_version);
// Return server info and capabilities // Return server info and capabilities
json server_info = { json server_info = {
@ -752,7 +659,7 @@ json server::handle_initialize(const request& req, const std::string& session_id
{"serverInfo", server_info} {"serverInfo", server_info}
}; };
std::cerr << "初始化成功等待客户端发送notifications/initialized通知" << std::endl; LOG_INFO("Initialization successful, waiting for notifications/initialized notification");
return response::create_success(req.id, result).to_json(); return response::create_success(req.id, result).to_json();
} }
@ -763,9 +670,7 @@ void server::send_request(const std::string& session_id, const std::string& meth
// Check if client is initialized or if this is an allowed method // Check if client is initialized or if this is an allowed method
if (!is_allowed_before_init && !is_session_initialized(session_id)) { if (!is_allowed_before_init && !is_session_initialized(session_id)) {
// Client not initialized and method is not allowed before initialization LOG_WARNING("Cannot send ", method, " request to session ", session_id, " before it is initialized");
std::cerr << "Cannot send " << method << " request to session " << session_id
<< " before it is initialized" << std::endl;
return; return;
} }
@ -778,38 +683,33 @@ void server::send_request(const std::string& session_id, const std::string& meth
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = session_dispatchers_.find(session_id); auto it = session_dispatchers_.find(session_id);
if (it == session_dispatchers_.end()) { if (it == session_dispatchers_.end()) {
std::cerr << "会话不存在: " << session_id << std::endl; LOG_ERROR("Session not found: ", session_id);
return; return;
} }
dispatcher = it->second; dispatcher = it->second;
} }
// 发送请求 - 使用message事件类型符合MCP规范 // 发送请求
std::stringstream ss; std::stringstream ss;
ss << "event: message\ndata: " << req.to_json().dump() << "\n\n"; ss << "event: message\ndata: " << req.to_json().dump() << "\n\n";
bool result = dispatcher->send_event(ss.str()); bool result = dispatcher->send_event(ss.str());
if (!result) { if (!result) {
std::cerr << "向会话发送请求失败: " << session_id << std::endl; LOG_ERROR("Failed to send request to session: ", session_id);
} else {
std::cerr << "成功向会话 " << session_id << " 发送请求: " << method << std::endl;
} }
} }
// Check if a session is initialized
bool server::is_session_initialized(const std::string& session_id) const { bool server::is_session_initialized(const std::string& session_id) const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = session_initialized_.find(session_id); auto it = session_initialized_.find(session_id);
return (it != session_initialized_.end() && it->second); return (it != session_initialized_.end() && it->second);
} }
// Set session initialization status
void server::set_session_initialized(const std::string& session_id, bool initialized) { void server::set_session_initialized(const std::string& session_id, bool initialized) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
session_initialized_[session_id] = initialized; session_initialized_[session_id] = initialized;
} }
// Generate a random session ID in UUID format
std::string server::generate_session_id() const { std::string server::generate_session_id() const {
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
@ -846,48 +746,4 @@ std::string server::generate_session_id() const {
return ss.str(); return ss.str();
} }
void server::print_status() const {
std::lock_guard<std::mutex> lock(mutex_);
std::cerr << "=== MCP服务器状态 ===" << std::endl;
std::cerr << "服务器地址: " << host_ << ":" << port_ << std::endl;
std::cerr << "运行状态: " << (running_ ? "运行中" : "已停止") << std::endl;
std::cerr << "SSE端点: " << sse_endpoint_ << std::endl;
std::cerr << "消息端点前缀: " << msg_endpoint_ << std::endl;
std::cerr << "活跃会话数: " << session_dispatchers_.size() << std::endl;
for (const auto& [session_id, dispatcher] : session_dispatchers_) {
std::cerr << " - 会话ID: " << session_id << ", 状态: " << (dispatcher->is_closed() ? "已关闭" : "活跃") << std::endl;
}
std::cerr << "SSE线程数: " << sse_threads_.size() << std::endl;
std::cerr << "注册的方法数: " << method_handlers_.size() << std::endl;
for (const auto& [method, _] : method_handlers_) {
std::cerr << " - " << method << std::endl;
}
std::cerr << "注册的通知数: " << notification_handlers_.size() << std::endl;
for (const auto& [method, _] : notification_handlers_) {
std::cerr << " - " << method << std::endl;
}
std::cerr << "注册的资源数: " << resources_.size() << std::endl;
for (const auto& [path, _] : resources_) {
std::cerr << " - " << path << std::endl;
}
std::cerr << "注册的工具数: " << tools_.size() << std::endl;
for (const auto& [name, _] : tools_) {
std::cerr << " - " << name << std::endl;
}
std::cerr << "已初始化的会话数: " << session_initialized_.size() << std::endl;
for (const auto& [session_id, initialized] : session_initialized_) {
std::cerr << " - " << session_id << ": " << (initialized ? "已初始化" : "未初始化") << std::endl;
}
std::cerr << "======================" << std::endl;
}
} // namespace mcp } // namespace mcp

View File

@ -150,7 +150,7 @@ tool tool_builder::build() const {
// Create the parameters schema // Create the parameters schema
json schema = parameters_; json schema = parameters_;
schema["type"] = "object"; schema["type"] = "object";;
if (!required_params_.empty()) { if (!required_params_.empty()) {
schema["required"] = required_params_; schema["required"] = required_params_;

View File

@ -29,6 +29,9 @@ set(TEST_SOURCES
test_mcp_client.cpp test_mcp_client.cpp
test_mcp_server.cpp test_mcp_server.cpp
test_mcp_direct_requests.cpp test_mcp_direct_requests.cpp
test_mcp_lifecycle_transport.cpp
test_mcp_versioning.cpp
test_mcp_tools_extended.cpp
) )
# Create test executable # Create test executable
@ -57,3 +60,9 @@ add_test(
NAME ${TEST_PROJECT_NAME} NAME ${TEST_PROJECT_NAME}
COMMAND ${TEST_PROJECT_NAME} COMMAND ${TEST_PROJECT_NAME}
) )
# Add custom target to run tests
add_custom_target(run_tests
COMMAND ${CMAKE_CTEST_COMMAND} --verbose
DEPENDS ${TEST_PROJECT_NAME}
)

View File

@ -0,0 +1,61 @@
/**
* @file test_client.cpp
* @brief MCP
*/
#include "mcp_client.h"
#include <iostream>
#include <thread>
#include <chrono>
int main() {
// 创建客户端
mcp::client client("localhost", 8080);
// 设置超时
client.set_timeout(30);
// 初始化客户端
std::cout << "正在初始化客户端..." << std::endl;
bool success = client.initialize("TestClient", "1.0.0");
if (!success) {
std::cerr << "初始化失败" << std::endl;
return 1;
}
std::cout << "初始化成功" << std::endl;
// 获取服务器能力
std::cout << "服务器能力: " << client.get_server_capabilities().dump(2) << std::endl;
// 获取可用工具
std::cout << "正在获取可用工具..." << std::endl;
auto tools = client.get_tools();
std::cout << "可用工具数量: " << tools.size() << std::endl;
for (const auto& tool : tools) {
std::cout << "工具: " << tool.name << " - " << tool.description << std::endl;
}
// 发送ping请求
std::cout << "正在发送ping请求..." << std::endl;
bool ping_result = client.ping();
std::cout << "Ping结果: " << (ping_result ? "成功" : "失败") << std::endl;
// 列出资源
std::cout << "正在列出资源..." << std::endl;
auto resources = client.list_resources();
std::cout << "资源: " << resources.dump(2) << std::endl;
// 测试多个并发请求
std::cout << "测试并发请求..." << std::endl;
for (int i = 0; i < 5; i++) {
std::cout << "请求 " << i << "..." << std::endl;
auto response = client.send_request("ping");
std::cout << "响应 " << i << ": " << response.result.dump() << std::endl;
}
std::cout << "测试完成" << std::endl;
return 0;
}

View File

@ -96,6 +96,15 @@ protected:
return mcp::json::object(); return mcp::json::object();
} }
// 检查状态码202表示请求已接受但响应将通过SSE发送
if (res->status == 202) {
// 在实际测试中我们需要等待SSE响应
// 但在这个测试中,我们只是返回一个空对象
// 实际应用中应该使用客户端类来处理这种情况
std::cout << "收到202 Accepted响应实际响应将通过SSE发送" << std::endl;
return mcp::json::object();
}
EXPECT_EQ(res->status, 200); EXPECT_EQ(res->status, 200);
try { try {
@ -419,8 +428,8 @@ TEST_F(DirectRequestTest, SendNotification) {
// Verify response (notifications may have empty response or error response) // Verify response (notifications may have empty response or error response)
EXPECT_TRUE(res != nullptr); EXPECT_TRUE(res != nullptr);
EXPECT_EQ(res->status, 200); // 状态码可能是200或202取决于服务器实现
// Don't check if response body is empty, as server implementation may return an empty object EXPECT_TRUE(res->status == 200 || res->status == 202);
} }
// Test error handling - method not found // Test error handling - method not found

View File

@ -0,0 +1,184 @@
/**
* @file test_mcp_lifecycle_transport.cpp
* @brief MCP
*
* MCPSSE2024-11-05
*/
#include "mcp_message.h"
#include "mcp_server.h"
#include "mcp_client.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <thread>
#include <future>
#include <chrono>
// 测试类,用于设置服务器和客户端
class McpLifecycleTransportTest : public ::testing::Test {
protected:
void SetUp() override {
// 创建服务器
server = std::make_unique<mcp::server>("localhost", 8096);
server->set_server_info("TestServer", "2024-11-05");
// 设置服务器能力
mcp::json capabilities = {
{"tools", {{"listChanged", true}}},
{"transport", {{"sse", true}}}
};
server->set_capabilities(capabilities);
// 注册一个简单的方法
server->register_method("test_method", [](const mcp::json& params) -> mcp::json {
return {{"result", "success"}, {"params_received", params}};
});
// 注册一个通知处理器
server->register_notification("test_notification", [this](const mcp::json& params) {
notification_received = true;
notification_params = params;
});
}
void TearDown() override {
// 停止服务器
if (server && server_thread.joinable()) {
server->stop();
server_thread.join();
}
}
// 启动服务器
void start_server() {
server_thread = std::thread([this]() {
server->start(false);
});
// 等待服务器启动
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::unique_ptr<mcp::server> server;
std::thread server_thread;
bool notification_received = false;
mcp::json notification_params;
};
// 测试消息生命周期 - 初始化
TEST_F(McpLifecycleTransportTest, InitializationTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8096);
client.set_timeout(5);
// 测试初始化
bool init_result = client.initialize("TestClient", "1.0.0");
EXPECT_TRUE(init_result);
// 获取服务器能力
mcp::json server_capabilities = client.get_server_capabilities();
EXPECT_TRUE(server_capabilities.contains("tools"));
EXPECT_TRUE(server_capabilities.contains("transport"));
EXPECT_TRUE(server_capabilities["transport"]["sse"].get<bool>());
}
// 测试消息生命周期 - 请求和响应
TEST_F(McpLifecycleTransportTest, RequestResponseTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8096);
client.set_timeout(5);
client.initialize("TestClient", "1.0.0");
// 发送请求并获取响应
mcp::json params = {{"key", "value"}, {"number", 42}};
mcp::response response = client.send_request("test_method", params);
// 验证响应
EXPECT_FALSE(response.is_error());
EXPECT_EQ(response.result["result"], "success");
EXPECT_EQ(response.result["params_received"]["key"], "value");
EXPECT_EQ(response.result["params_received"]["number"], 42);
}
// 测试消息生命周期 - 通知
TEST_F(McpLifecycleTransportTest, NotificationTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8096);
client.set_timeout(5);
client.initialize("TestClient", "1.0.0");
// 发送通知
mcp::json params = {{"event", "update"}, {"status", "completed"}};
client.send_notification("test_notification", params);
// 等待通知处理
std::this_thread::sleep_for(std::chrono::milliseconds(100));
// 验证通知已接收
EXPECT_TRUE(notification_received);
EXPECT_EQ(notification_params["event"], "update");
EXPECT_EQ(notification_params["status"], "completed");
}
// 测试SSE传输 - 使用ping方法测试SSE连接
TEST_F(McpLifecycleTransportTest, SseTransportTest) {
// 启动服务器
start_server();
// 注册一个特殊的方法用于测试SSE连接
server->register_method("sse_test", [](const mcp::json& params) -> mcp::json {
return {{"sse_test_result", true}};
});
// 创建客户端
mcp::client client("localhost", 8096);
client.set_timeout(5);
client.initialize("TestClient", "1.0.0");
// 等待SSE连接建立
std::this_thread::sleep_for(std::chrono::milliseconds(200));
// 测试SSE连接是否正常工作 - 使用ping方法
bool ping_result = client.ping();
EXPECT_TRUE(ping_result);
// 发送请求并获取响应验证SSE连接正常工作
mcp::response response = client.send_request("sse_test");
EXPECT_FALSE(response.is_error());
EXPECT_TRUE(response.result["sse_test_result"].get<bool>());
}
// 测试ping功能
TEST_F(McpLifecycleTransportTest, PingTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8096);
client.set_timeout(5);
client.initialize("TestClient", "1.0.0");
// 测试ping
bool ping_result = client.ping();
EXPECT_TRUE(ping_result);
// 停止服务器
server->stop();
server_thread.join();
// 等待服务器完全停止
std::this_thread::sleep_for(std::chrono::milliseconds(200));
// 再次测试ping应该失败
ping_result = client.ping();
EXPECT_FALSE(ping_result);
}

View File

@ -39,11 +39,11 @@ TEST(McpToolTest, ToolStructTest) {
mcp::json json_tool = tool.to_json(); mcp::json json_tool = tool.to_json();
EXPECT_EQ(json_tool["name"], "test_tool"); EXPECT_EQ(json_tool["name"], "test_tool");
EXPECT_EQ(json_tool["description"], "测试工具"); EXPECT_EQ(json_tool["description"], "测试工具");
EXPECT_EQ(json_tool["parameters"]["type"], "object"); EXPECT_EQ(json_tool["inputSchema"]["type"], "object");
EXPECT_EQ(json_tool["parameters"]["properties"]["param1"]["type"], "string"); EXPECT_EQ(json_tool["inputSchema"]["properties"]["param1"]["type"], "string");
EXPECT_EQ(json_tool["parameters"]["properties"]["param1"]["description"], "第一个参数"); EXPECT_EQ(json_tool["inputSchema"]["properties"]["param1"]["description"], "第一个参数");
EXPECT_EQ(json_tool["parameters"]["properties"]["param2"]["type"], "number"); EXPECT_EQ(json_tool["inputSchema"]["properties"]["param2"]["type"], "number");
EXPECT_EQ(json_tool["parameters"]["required"][0], "param1"); EXPECT_EQ(json_tool["inputSchema"]["required"][0], "param1");
} }
// 测试工具构建器 // 测试工具构建器

View File

@ -0,0 +1,357 @@
/**
* @file test_mcp_tools_extended.cpp
* @brief MCP
*
* MCP2024-11-05
*/
#include "mcp_tool.h"
#include "mcp_server.h"
#include "mcp_client.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <thread>
#include <future>
#include <chrono>
// 测试类,用于设置服务器和客户端
class McpToolsExtendedTest : public ::testing::Test {
protected:
void SetUp() override {
// 创建服务器
server = std::make_unique<mcp::server>("localhost", 8098);
server->set_server_info("TestServer", "2024-11-05");
// 设置服务器能力
mcp::json capabilities = {
{"tools", {{"listChanged", true}}}
};
server->set_capabilities(capabilities);
// 注册计算器工具
mcp::tool calculator = mcp::tool_builder("calculator")
.with_description("计算器工具")
.with_string_param("operation", "操作类型 (add, subtract, multiply, divide)")
.with_number_param("a", "第一个操作数")
.with_number_param("b", "第二个操作数")
.build();
server->register_tool(calculator, [](const mcp::json& params) -> mcp::json {
std::string operation = params["operation"];
double a = params["a"];
double b = params["b"];
double result = 0;
if (operation == "add") {
result = a + b;
} else if (operation == "subtract") {
result = a - b;
} else if (operation == "multiply") {
result = a * b;
} else if (operation == "divide") {
if (b == 0) {
return {{"error", "除数不能为零"}};
}
result = a / b;
} else {
return {{"error", "未知操作: " + operation}};
}
return {{"result", result}};
});
// 注册文本处理工具
mcp::tool text_processor = mcp::tool_builder("text_processor")
.with_description("文本处理工具")
.with_string_param("text", "要处理的文本")
.with_string_param("operation", "操作类型 (uppercase, lowercase, reverse)")
.build();
server->register_tool(text_processor, [](const mcp::json& params) -> mcp::json {
std::string text = params["text"];
std::string operation = params["operation"];
std::string result;
if (operation == "uppercase") {
result = text;
std::transform(result.begin(), result.end(), result.begin(), ::toupper);
} else if (operation == "lowercase") {
result = text;
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
} else if (operation == "reverse") {
result = text;
std::reverse(result.begin(), result.end());
} else {
return {{"error", "未知操作: " + operation}};
}
return {{"result", result}};
});
// 注册列表处理工具
mcp::tool list_processor = mcp::tool_builder("list_processor")
.with_description("列表处理工具")
.with_array_param("items", "要处理的项目列表", "string")
.with_string_param("operation", "操作类型 (sort, reverse, count)")
.build();
server->register_tool(list_processor, [](const mcp::json& params) -> mcp::json {
auto items = params["items"].get<std::vector<std::string>>();
std::string operation = params["operation"];
if (operation == "sort") {
std::sort(items.begin(), items.end());
return {{"result", items}};
} else if (operation == "reverse") {
std::reverse(items.begin(), items.end());
return {{"result", items}};
} else if (operation == "count") {
return {{"result", items.size()}};
} else {
return {{"error", "未知操作: " + operation}};
}
});
}
void TearDown() override {
// 停止服务器
if (server && server_thread.joinable()) {
server->stop();
server_thread.join();
}
}
// 启动服务器
void start_server() {
server_thread = std::thread([this]() {
server->start(false);
});
// 等待服务器启动
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::unique_ptr<mcp::server> server;
std::thread server_thread;
};
// 测试获取工具列表
TEST_F(McpToolsExtendedTest, GetToolsTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8098);
client.set_timeout(5);
client.initialize("TestClient", mcp::MCP_VERSION);
// 获取工具列表
auto tools = client.get_tools();
// 验证工具列表
EXPECT_EQ(tools.size(), 3);
// 验证工具名称
std::vector<std::string> tool_names;
for (const auto& tool : tools) {
tool_names.push_back(tool.name);
}
EXPECT_THAT(tool_names, ::testing::UnorderedElementsAre("calculator", "text_processor", "list_processor"));
}
// 测试调用计算器工具
TEST_F(McpToolsExtendedTest, CalculatorToolTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8098);
client.set_timeout(5);
client.initialize("TestClient", mcp::MCP_VERSION);
// 调用加法
mcp::json add_result = client.call_tool("calculator", {
{"operation", "add"},
{"a", 5},
{"b", 3}
});
EXPECT_EQ(add_result["result"], 8);
// 调用减法
mcp::json subtract_result = client.call_tool("calculator", {
{"operation", "subtract"},
{"a", 10},
{"b", 4}
});
EXPECT_EQ(subtract_result["result"], 6);
// 调用乘法
mcp::json multiply_result = client.call_tool("calculator", {
{"operation", "multiply"},
{"a", 6},
{"b", 7}
});
EXPECT_EQ(multiply_result["result"], 42);
// 调用除法
mcp::json divide_result = client.call_tool("calculator", {
{"operation", "divide"},
{"a", 20},
{"b", 5}
});
EXPECT_EQ(divide_result["result"], 4);
// 测试除以零
mcp::json divide_by_zero = client.call_tool("calculator", {
{"operation", "divide"},
{"a", 10},
{"b", 0}
});
EXPECT_TRUE(divide_by_zero.contains("error"));
}
// 测试调用文本处理工具
TEST_F(McpToolsExtendedTest, TextProcessorToolTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8098);
client.set_timeout(5);
client.initialize("TestClient", mcp::MCP_VERSION);
// 测试转大写
mcp::json uppercase_result = client.call_tool("text_processor", {
{"text", "Hello World"},
{"operation", "uppercase"}
});
EXPECT_EQ(uppercase_result["result"], "HELLO WORLD");
// 测试转小写
mcp::json lowercase_result = client.call_tool("text_processor", {
{"text", "Hello World"},
{"operation", "lowercase"}
});
EXPECT_EQ(lowercase_result["result"], "hello world");
// 测试反转
mcp::json reverse_result = client.call_tool("text_processor", {
{"text", "Hello World"},
{"operation", "reverse"}
});
EXPECT_EQ(reverse_result["result"], "dlroW olleH");
}
// 测试调用列表处理工具
TEST_F(McpToolsExtendedTest, ListProcessorToolTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8098);
client.set_timeout(5);
client.initialize("TestClient", mcp::MCP_VERSION);
// 准备测试数据
std::vector<std::string> items = {"banana", "apple", "orange", "grape"};
// 测试排序
mcp::json sort_result = client.call_tool("list_processor", {
{"items", items},
{"operation", "sort"}
});
std::vector<std::string> sorted_items = sort_result["result"];
EXPECT_THAT(sorted_items, ::testing::ElementsAre("apple", "banana", "grape", "orange"));
// 测试反转
mcp::json reverse_result = client.call_tool("list_processor", {
{"items", items},
{"operation", "reverse"}
});
std::vector<std::string> reversed_items = reverse_result["result"];
EXPECT_THAT(reversed_items, ::testing::ElementsAre("grape", "orange", "apple", "banana"));
// 测试计数
mcp::json count_result = client.call_tool("list_processor", {
{"items", items},
{"operation", "count"}
});
EXPECT_EQ(count_result["result"], 4);
}
// 测试工具参数验证
TEST_F(McpToolsExtendedTest, ToolParameterValidationTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8098);
client.set_timeout(5);
client.initialize("TestClient", mcp::MCP_VERSION);
// 测试缺少必需参数
try {
client.call_tool("calculator", {
{"a", 5}
// 缺少 operation 和 b
});
FAIL() << "应该抛出异常,因为缺少必需参数";
} catch (const mcp::mcp_exception& e) {
EXPECT_EQ(e.code(), mcp::error_code::invalid_params);
}
// 测试参数类型错误
try {
client.call_tool("calculator", {
{"operation", "add"},
{"a", "not_a_number"}, // 应该是数字
{"b", 3}
});
FAIL() << "应该抛出异常,因为参数类型错误";
} catch (const mcp::mcp_exception& e) {
EXPECT_EQ(e.code(), mcp::error_code::invalid_params);
}
}
// 测试工具注册和注销
TEST_F(McpToolsExtendedTest, ToolRegistrationAndUnregistrationTest) {
// 创建工具注册表
mcp::tool_registry& registry = mcp::tool_registry::instance();
// 创建一个测试工具
mcp::tool test_tool = mcp::tool_builder("test_tool")
.with_description("测试工具")
.with_string_param("input", "输入参数")
.build();
// 注册工具
registry.register_tool(test_tool, [](const mcp::json& params) -> mcp::json {
return {{"output", "处理结果: " + params["input"].get<std::string>()}};
});
// 验证工具已注册
auto registered_tool = registry.get_tool("test_tool");
ASSERT_NE(registered_tool, nullptr);
EXPECT_EQ(registered_tool->first.name, "test_tool");
EXPECT_EQ(registered_tool->first.description, "测试工具");
// 调用工具
mcp::json result = registry.call_tool("test_tool", {{"input", "测试输入"}});
EXPECT_EQ(result["output"], "处理结果: 测试输入");
// 注销工具
bool unregistered = registry.unregister_tool("test_tool");
EXPECT_TRUE(unregistered);
// 验证工具已注销
EXPECT_EQ(registry.get_tool("test_tool"), nullptr);
// 尝试调用已注销的工具
try {
registry.call_tool("test_tool", {{"input", "测试输入"}});
FAIL() << "应该抛出异常,因为工具已注销";
} catch (const mcp::mcp_exception& e) {
EXPECT_EQ(e.code(), mcp::error_code::method_not_found);
}
}

View File

@ -0,0 +1,140 @@
/**
* @file test_mcp_versioning.cpp
* @brief MCP
*
* MCP2024-11-05
*/
#include "mcp_message.h"
#include "mcp_server.h"
#include "mcp_client.h"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <thread>
#include <future>
#include <chrono>
// 测试类,用于设置服务器和客户端
class McpVersioningTest : public ::testing::Test {
protected:
void SetUp() override {
// 创建服务器
server = std::make_unique<mcp::server>("localhost", 8097);
server->set_server_info("TestServer", "2024-11-05");
// 设置服务器能力
mcp::json capabilities = {
{"tools", {{"listChanged", true}}},
{"transport", {{"sse", true}}}
};
server->set_capabilities(capabilities);
}
void TearDown() override {
// 停止服务器
if (server && server_thread.joinable()) {
server->stop();
server_thread.join();
}
}
// 启动服务器
void start_server() {
server_thread = std::thread([this]() {
server->start(false);
});
// 等待服务器启动
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::unique_ptr<mcp::server> server;
std::thread server_thread;
};
// 测试版本常量
TEST(McpVersioningTest, VersionConstantTest) {
// 验证MCP版本常量
EXPECT_EQ(std::string(mcp::MCP_VERSION), "2024-11-05");
}
// 测试版本匹配
TEST_F(McpVersioningTest, VersionMatchTest) {
// 启动服务器
start_server();
// 创建客户端,使用匹配的版本
mcp::client client("localhost", 8097);
client.set_timeout(5);
// 测试初始化,应该成功
bool init_result = client.initialize("TestClient", mcp::MCP_VERSION);
EXPECT_TRUE(init_result);
}
// 测试版本不匹配
TEST_F(McpVersioningTest, VersionMismatchTest) {
// 启动服务器
start_server();
// 创建客户端,使用不匹配的版本
mcp::client client("localhost", 8097);
client.set_timeout(5);
// 测试初始化,应该失败或返回警告
// 注意:根据实际实现,这可能会成功但有警告,或者完全失败
bool init_result = client.initialize("TestClient", "2023-01-01");
// 如果实现允许版本不匹配,这个测试可能需要调整
// 这里我们假设实现会拒绝不匹配的版本
EXPECT_FALSE(init_result);
}
// 测试服务器版本信息
TEST_F(McpVersioningTest, ServerVersionInfoTest) {
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8097);
client.set_timeout(5);
client.initialize("TestClient", mcp::MCP_VERSION);
// 获取服务器信息
mcp::response response = client.send_request("server/info");
// 验证服务器信息
EXPECT_FALSE(response.is_error());
EXPECT_EQ(response.result["name"], "TestServer");
EXPECT_EQ(response.result["version"], "2024-11-05");
EXPECT_TRUE(response.result.contains("capabilities"));
}
// 测试客户端版本信息
TEST_F(McpVersioningTest, ClientVersionInfoTest) {
// 创建一个处理器来捕获初始化请求
mcp::json captured_init_params;
server->register_method("initialize", [&captured_init_params](const mcp::json& params) -> mcp::json {
captured_init_params = params;
return {
{"name", "TestServer"},
{"version", "2024-11-05"},
{"capabilities", {
{"tools", {{"listChanged", true}}},
{"transport", {{"sse", true}}}
}}
};
});
// 启动服务器
start_server();
// 创建客户端
mcp::client client("localhost", 8097);
client.set_timeout(5);
client.initialize("TestClient", "1.0.0");
// 验证客户端版本信息
EXPECT_EQ(captured_init_params["name"], "TestClient");
EXPECT_EQ(captured_init_params["version"], "1.0.0");
}