cpp-mcp/src/mcp_client.cpp

372 lines
11 KiB
C++
Raw Normal View History

2025-03-08 01:50:39 +08:00
/**
* @file mcp_client.cpp
* @brief Implementation of the MCP client
*
* This file implements the client-side functionality for the Model Context Protocol.
* It follows the 2024-11-05 basic protocol specification.
*/
#include "mcp_client.h"
#include "base64.hpp"
namespace mcp {
2025-03-11 23:29:38 +08:00
client::client(const std::string& host, int port, const json& capabilities, const std::string& sse_endpoint)
: host_(host), port_(port), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
2025-03-08 01:50:39 +08:00
2025-03-10 03:24:54 +08:00
init_client(host, port);
}
2025-03-11 23:29:38 +08:00
client::client(const std::string& base_url, const json& capabilities, const std::string& sse_endpoint)
: base_url_(base_url), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
2025-03-10 03:24:54 +08:00
init_client(base_url);
2025-03-08 01:50:39 +08:00
}
/**
 * @brief Destroy the client.
 *
 * The SSE background thread captures `this`, so it must be stopped and
 * joined before any member is torn down. The HTTP client is destroyed
 * automatically afterwards.
 */
client::~client() {
    close_sse_connection();
}
2025-03-10 03:24:54 +08:00
void client::init_client(const std::string& host, int port) {
// Create the HTTP client
http_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
// Set timeout
http_client_->set_connection_timeout(timeout_seconds_, 0);
http_client_->set_read_timeout(timeout_seconds_, 0);
http_client_->set_write_timeout(timeout_seconds_, 0);
}
void client::init_client(const std::string& base_url) {
2025-03-08 01:50:39 +08:00
// Create the HTTP client
2025-03-10 03:24:54 +08:00
http_client_ = std::make_unique<httplib::Client>(base_url.c_str());
2025-03-08 01:50:39 +08:00
// Set timeout
http_client_->set_connection_timeout(timeout_seconds_, 0);
http_client_->set_read_timeout(timeout_seconds_, 0);
http_client_->set_write_timeout(timeout_seconds_, 0);
}
2025-03-08 22:49:19 +08:00
bool client::initialize(const std::string& client_name, const std::string& client_version) {
// Create initialization request
request req = request::create("initialize", {
2025-03-09 15:45:09 +08:00
{"protocolVersion", MCP_VERSION},
{"capabilities", capabilities_},
2025-03-08 22:49:19 +08:00
{"clientInfo", {
{"name", client_name},
{"version", client_version}
}}
});
try {
2025-03-11 23:29:38 +08:00
open_sse_connection();
2025-03-08 22:49:19 +08:00
// Send the request
json result = send_jsonrpc(req);
// Store server capabilities
server_capabilities_ = result["capabilities"];
// Send initialized notification
request notification = request::create_notification("initialized");
send_jsonrpc(notification);
return true;
} catch (const std::exception& e) {
2025-03-11 23:29:38 +08:00
// 初始化失败关闭SSE连接
std::cerr << "初始化失败: " << e.what() << std::endl;
close_sse_connection();
2025-03-08 22:49:19 +08:00
return false;
2025-03-09 17:24:46 +08:00
}
}
bool client::ping() {
// Create ping request
request req = request::create("ping", {});
try {
// Send the request
json result = send_jsonrpc(req);
// The receiver MUST respond promptly with an empty response
if (result.empty()) {
return true;
} else {
return false;
}
} catch (const std::exception& e) {
// Ping failed
return false;
2025-03-08 22:49:19 +08:00
}
}
void client::set_auth_token(const std::string& token) {
std::lock_guard<std::mutex> lock(mutex_);
auth_token_ = token;
2025-03-08 01:50:39 +08:00
// Add to default headers
2025-03-08 22:49:19 +08:00
set_header("Authorization", "Bearer " + auth_token_);
2025-03-08 01:50:39 +08:00
}
/**
 * @brief Add or overwrite a header that send_jsonrpc() attaches to every request.
 *
 * Thread-safe: guarded by mutex_.
 */
void client::set_header(const std::string& key, const std::string& value) {
    std::lock_guard<std::mutex> guard(mutex_);
    default_headers_.insert_or_assign(key, value);
}
/**
 * @brief Change the connect/read/write timeout of the HTTP client.
 *
 * @param timeout_seconds  New timeout, in seconds, applied to all three
 *                         timeouts of the existing client.
 */
void client::set_timeout(int timeout_seconds) {
    std::lock_guard<std::mutex> guard(mutex_);
    timeout_seconds_ = timeout_seconds;

    auto& http = *http_client_;
    http.set_connection_timeout(timeout_seconds_, 0);
    http.set_read_timeout(timeout_seconds_, 0);
    http.set_write_timeout(timeout_seconds_, 0);
}
2025-03-08 22:49:19 +08:00
void client::set_capabilities(const json& capabilities) {
std::lock_guard<std::mutex> lock(mutex_);
capabilities_ = capabilities;
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
response client::send_request(const std::string& method, const json& params) {
request req = request::create(method, params);
json result = send_jsonrpc(req);
2025-03-08 01:50:39 +08:00
2025-03-08 22:49:19 +08:00
response res;
res.jsonrpc = "2.0";
res.id = req.id;
res.result = result;
return res;
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
void client::send_notification(const std::string& method, const json& params) {
request req = request::create_notification(method, params);
send_jsonrpc(req);
}
/**
 * @brief Return a copy of the capabilities the server reported.
 *
 * The value is whatever initialize() stored.
 */
json client::get_server_capabilities() {
    json caps = server_capabilities_;
    return caps;
}
2025-03-09 17:10:01 +08:00
json client::call_tool(const std::string& tool_name, const json& arguments) {
2025-03-08 22:49:19 +08:00
return send_request("tools/call", {
{"name", tool_name},
2025-03-09 17:10:01 +08:00
{"arguments", arguments}
2025-03-08 22:49:19 +08:00
}).result;
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
std::vector<tool> client::get_tools() {
json tools_json = send_request("tools/list", {}).result;
std::vector<tool> tools;
if (tools_json.is_array()) {
for (const auto& tool_json : tools_json) {
tool t;
t.name = tool_json["name"];
t.description = tool_json["description"];
2025-03-09 17:10:01 +08:00
if (tool_json.contains("inputSchema")) {
t.parameters_schema = tool_json["inputSchema"];
2025-03-08 22:49:19 +08:00
}
tools.push_back(t);
}
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
return tools;
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
json client::get_capabilities() {
return capabilities_;
2025-03-08 01:50:39 +08:00
}
2025-03-10 03:24:54 +08:00
json client::list_resources(const std::string& cursor) {
json params = json::object();
if (!cursor.empty()) {
params["cursor"] = cursor;
2025-03-08 22:49:19 +08:00
}
2025-03-10 03:24:54 +08:00
return send_request("resources/list", params).result;
}
/**
 * @brief Read one resource via "resources/read".
 *
 * @param resource_uri  URI of the resource.
 * @return The "result" of the call.
 */
json client::read_resource(const std::string& resource_uri) {
    json params = json::object();
    params["uri"] = resource_uri;
    return send_request("resources/read", params).result;
}
/**
 * @brief Subscribe to updates of a resource via "resources/subscribe".
 *
 * @param resource_uri  URI of the resource.
 * @return The "result" of the call.
 */
json client::subscribe_to_resource(const std::string& resource_uri) {
    json params = json::object();
    params["uri"] = resource_uri;
    return send_request("resources/subscribe", params).result;
}
/**
 * @brief List resource templates via "resources/templates/list".
 *
 * send_request() is called without explicit params, so its declared
 * default is used.
 */
json client::list_resource_templates() {
    auto reply = send_request("resources/templates/list");
    return reply.result;
}
2025-03-11 23:29:38 +08:00
void client::open_sse_connection() {
// 设置SSE连接状态为运行中
sse_running_ = true;
// 创建并启动SSE线程
sse_thread_ = std::make_unique<std::thread>([this]() {
int retry_count = 0;
const int max_retries = 5;
const int retry_delay_base = 1000; // 毫秒
while (sse_running_) {
try {
// 尝试建立SSE连接
auto res = http_client_->Get(sse_endpoint_.c_str(),
[this](const char *data, size_t data_length) {
// 解析SSE数据
if (!parse_sse_data(data, data_length)) {
return false; // 解析失败,关闭连接
}
return sse_running_.load(); // 如果sse_running_为false关闭连接
});
// 检查连接是否成功
if (!res) {
throw std::runtime_error("SSE连接失败: " + std::to_string(static_cast<int>(res.error())));
}
// 连接成功后重置重试计数
retry_count = 0;
} catch (const std::exception& e) {
// 记录错误
std::cerr << "SSE连接错误: " << e.what() << std::endl;
// 如果已达到最大重试次数,停止尝试
if (++retry_count > max_retries) {
std::cerr << "达到最大重试次数停止SSE连接尝试" << std::endl;
break;
}
// 指数退避重试
int delay = retry_delay_base * (1 << (retry_count - 1)); // 2^(retry_count-1) * base_delay
std::this_thread::sleep_for(std::chrono::milliseconds(delay));
}
}
});
}
// 新增方法解析SSE数据
bool client::parse_sse_data(const char* data, size_t length) {
try {
std::string sse_data(data, length);
// 查找"data:"标记
auto data_pos = sse_data.find("data: ");
if (data_pos == std::string::npos) {
return true; // 不是数据事件,可能是注释或心跳,继续保持连接
}
// 查找数据行结束位置
auto newline_pos = sse_data.find("\n", data_pos);
if (newline_pos == std::string::npos) {
newline_pos = sse_data.length(); // 如果没有换行符,使用整个字符串
}
// 提取数据内容
std::string data_content = sse_data.substr(data_pos + 6, newline_pos - (data_pos + 6));
// 检查是否是心跳事件
if (sse_data.find("event: heartbeat") != std::string::npos) {
// 心跳事件,不需要处理数据
return true;
}
// 更新消息端点
{
std::lock_guard<std::mutex> lock(mutex_);
msg_endpoint_ = data_content;
}
return true;
} catch (const std::exception& e) {
std::cerr << "解析SSE数据错误: " << e.what() << std::endl;
return false;
}
}
/**
 * @brief Stop the SSE worker thread and wait for it to exit.
 *
 * Clearing sse_running_ ends the worker loop and makes the stream's content
 * receiver return false on its next chunk. The join can therefore block
 * until the in-flight read delivers data or times out.
 */
void client::close_sse_connection() {
    sse_running_ = false;

    const bool has_worker = sse_thread_ && sse_thread_->joinable();
    if (!has_worker) {
        return;
    }
    sse_thread_->join();
}
2025-03-08 22:49:19 +08:00
json client::send_jsonrpc(const request& req) {
2025-03-08 01:50:39 +08:00
std::lock_guard<std::mutex> lock(mutex_);
2025-03-08 22:49:19 +08:00
// Convert request to JSON
json req_json = req.to_json();
std::string req_body = req_json.dump();
// Prepare headers
httplib::Headers headers;
headers.emplace("Content-Type", "application/json");
2025-03-08 01:50:39 +08:00
// Add default headers
for (const auto& [key, value] : default_headers_) {
2025-03-08 22:49:19 +08:00
headers.emplace(key, value);
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
// Send the request
2025-03-11 23:29:38 +08:00
auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
2025-03-08 01:50:39 +08:00
if (!result) {
// Error occurred
auto err = result.error();
switch (err) {
case httplib::Error::Connection:
2025-03-08 22:49:19 +08:00
throw mcp_exception(error_code::server_error_start, "Connection error");
2025-03-08 01:50:39 +08:00
case httplib::Error::Read:
2025-03-08 22:49:19 +08:00
throw mcp_exception(error_code::internal_error, "Read error");
2025-03-08 01:50:39 +08:00
case httplib::Error::Write:
2025-03-08 22:49:19 +08:00
throw mcp_exception(error_code::internal_error, "Write error");
2025-03-08 01:50:39 +08:00
case httplib::Error::ConnectionTimeout:
2025-03-08 22:49:19 +08:00
throw mcp_exception(error_code::server_error_start, "Timeout error");
2025-03-08 01:50:39 +08:00
default:
2025-03-08 22:49:19 +08:00
throw mcp_exception(error_code::internal_error,
2025-03-08 01:50:39 +08:00
"HTTP client error: " + std::to_string(static_cast<int>(err)));
}
}
2025-03-08 22:49:19 +08:00
// Check if it's a notification (no response expected)
if (req.is_notification()) {
return json::object();
2025-03-08 01:50:39 +08:00
}
2025-03-08 22:49:19 +08:00
// Parse response
try {
json res_json = json::parse(result->body);
// Check for error
if (res_json.contains("error")) {
int code = res_json["error"]["code"];
std::string message = res_json["error"]["message"];
throw mcp_exception(static_cast<error_code>(code), message);
}
// Return result
if (res_json.contains("result")) {
return res_json["result"];
} else {
return json::object();
}
} catch (const json::exception& e) {
throw mcp_exception(error_code::parse_error,
"Failed to parse JSON-RPC response: " + std::string(e.what()));
}
2025-03-08 01:50:39 +08:00
}
} // namespace mcp