/**
 * @file mcp_client.cpp
 * @brief Implementation of the MCP client
 *
 * This file implements the client-side functionality for the Model Context Protocol.
 * Follows the 2024-11-05 basic protocol specification.
 */

#include "mcp_client.h"
#include "base64.hpp"

#include <chrono>
#include <cstddef>
#include <exception>
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace mcp {

namespace {

/**
 * @brief Convert a JSON-RPC "error" member into an mcp_exception.
 *
 * Missing or mistyped "code"/"message" fields fall back to
 * error_code::internal_error and a generic message instead of letting a
 * json exception escape.
 *
 * @param error Value of the response's "error" member.
 * @return The exception describing this error (to be thrown or stored in a promise).
 */
mcp_exception jsonrpc_error_to_exception(const json& error) {
    error_code code = error_code::internal_error;
    std::string message = "Unknown JSON-RPC error";
    if (error.is_object()) {
        if (error.contains("code") && error["code"].is_number_integer()) {
            code = static_cast<error_code>(error["code"].get<int>());
        }
        if (error.contains("message") && error["message"].is_string()) {
            message = error["message"].get<std::string>();
        }
    }
    return mcp_exception(code, message);
}

/**
 * @brief Locate the end of the first complete SSE event in a receive buffer.
 *
 * An SSE event is terminated by a blank line; both LF ("\n\n") and CRLF
 * ("\r\n\r\n") line endings are recognised.
 *
 * @param buffer Stream bytes that have not been dispatched yet.
 * @return (event length without terminator, terminator length), or
 *         std::nullopt when the buffer does not yet hold a complete event.
 */
std::optional<std::pair<std::size_t, std::size_t>> find_sse_event_boundary(const std::string& buffer) {
    const std::size_t lf_end = buffer.find("\n\n");
    const std::size_t crlf_end = buffer.find("\r\n\r\n");
    if (crlf_end != std::string::npos && (lf_end == std::string::npos || crlf_end < lf_end)) {
        return std::make_pair(crlf_end, std::size_t{4});
    }
    if (lf_end != std::string::npos) {
        return std::make_pair(lf_end, std::size_t{2});
    }
    return std::nullopt;
}

/**
 * @brief Extract the value of an SSE field line of the form "<field>:<value>".
 *
 * A single space after the colon is removed, as the SSE format specifies.
 *
 * @param line  One line of an SSE event, without its line terminator.
 * @param field Field name to match (e.g. "event", "data").
 * @return The field value, or std::nullopt if the line is not that field.
 */
std::optional<std::string> sse_field_value(const std::string& line, const std::string& field) {
    if (line.size() <= field.size() || line.compare(0, field.size(), field) != 0 || line[field.size()] != ':') {
        return std::nullopt;
    }
    std::size_t value_start = field.size() + 1;
    if (value_start < line.size() && line[value_start] == ' ') {
        ++value_start;
    }
    return line.substr(value_start);
}

/**
 * @brief Apply the timeouts shared by both client construction paths.
 *
 * @param http_client     Client used for JSON-RPC POSTs.
 * @param sse_client      Client used for the SSE stream.
 * @param timeout_seconds Base timeout in seconds.
 */
void apply_common_timeouts(httplib::Client& http_client, httplib::Client& sse_client, int timeout_seconds) {
    http_client.set_connection_timeout(timeout_seconds, 0);
    http_client.set_read_timeout(timeout_seconds, 0);
    http_client.set_write_timeout(timeout_seconds, 0);

    // The SSE client gets a longer connect budget than regular requests.
    sse_client.set_connection_timeout(timeout_seconds * 2, 0);
    sse_client.set_write_timeout(timeout_seconds, 0);
}

}  // namespace

client::client(const std::string& host, int port, const json& capabilities, const std::string& sse_endpoint)
    : host_(host), port_(port), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
    init_client(host, port);
}

client::client(const std::string& base_url, const json& capabilities, const std::string& sse_endpoint)
    : base_url_(base_url), capabilities_(capabilities), sse_endpoint_(sse_endpoint) {
    init_client(base_url);
}

client::~client() {
    close_sse_connection();
}

/**
 * @brief Create the HTTP and SSE clients for a host/port pair and configure timeouts.
 */
void client::init_client(const std::string& host, int port) {
    http_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
    sse_client_ = std::make_unique<httplib::Client>(host.c_str(), port);
    apply_common_timeouts(*http_client_, *sse_client_, timeout_seconds_);
    // NOTE(review): unlike the base-URL overload, this overload leaves the SSE
    // read timeout at httplib's default. Confirm which behaviour the long-lived
    // SSE stream needs and make both overloads agree.
}

/**
 * @brief Create the HTTP and SSE clients for a base URL and configure timeouts.
 */
void client::init_client(const std::string& base_url) {
    http_client_ = std::make_unique<httplib::Client>(base_url.c_str());
    sse_client_ = std::make_unique<httplib::Client>(base_url.c_str());
    apply_common_timeouts(*http_client_, *sse_client_, timeout_seconds_);
    sse_client_->set_read_timeout(0, 0);
}

/**
 * @brief Perform the MCP handshake.
 *
 * Checks that the server answers HTTP, opens the SSE stream, waits (up to 5 s)
 * for the server to announce the message endpoint, sends "initialize", stores
 * the server capabilities and finally sends the "initialized" notification.
 *
 * @return true on success; false on any failure (the SSE stream is closed again).
 */
bool client::initialize(const std::string& client_name, const std::string& client_version) {
    LOG_INFO("Initializing MCP client...");

    if (!check_server_accessible()) {
        return false;
    }

    request req = request::create("initialize", {
        {"protocolVersion", MCP_VERSION},
        {"capabilities", capabilities_},
        {"clientInfo", {
            {"name", client_name},
            {"version", client_version}
        }}
    });

    try {
        LOG_INFO("Opening SSE connection...");
        open_sse_connection();

        const auto timeout = std::chrono::milliseconds(5000);
        {
            std::unique_lock lock(mutex_);
            const bool success = endpoint_cv_.wait_for(lock, timeout, [this]() {
                if (!sse_running_) {
                    LOG_WARNING("SSE connection closed, stopping wait");
                    return true;
                }
                if (!msg_endpoint_.empty()) {
                    LOG_INFO("Message endpoint set, stopping wait");
                    return true;
                }
                return false;
            });

            if (!success) {
                LOG_WARNING("Condition variable wait timed out");
            }
            if (!sse_running_) {
                throw std::runtime_error("SSE connection closed, failed to get message endpoint");
            }
            if (msg_endpoint_.empty()) {
                throw std::runtime_error("Timeout waiting for SSE connection, failed to get message endpoint");
            }
            LOG_INFO("Successfully got message endpoint: ", msg_endpoint_);
        }

        json result = send_jsonrpc(req);
        server_capabilities_ = result["capabilities"];

        request notification = request::create_notification("initialized");
        send_jsonrpc(notification);
        return true;
    } catch (const std::exception& e) {
        LOG_ERROR("Initialization failed: ", e.what());
        close_sse_connection();
        return false;
    }
}

/**
 * @brief Send a "ping" request.
 * @return true if the server answered with an empty result; false otherwise or on error.
 */
bool client::ping() {
    request req = request::create("ping", {});
    try {
        json result = send_jsonrpc(req);
        return result.empty();
    } catch (const std::exception&) {
        return false;
    }
}

/**
 * @brief Store a bearer token and send it as the "Authorization" header.
 */
void client::set_auth_token(const std::string& token) {
    {
        std::lock_guard lock(mutex_);
        auth_token_ = token;
    }
    // set_header() takes mutex_ itself; calling it while still holding the
    // (non-recursive) lock would self-deadlock.
    set_header("Authorization", "Bearer " + token);
}

/**
 * @brief Add or replace a header sent with every HTTP and SSE request.
 */
void client::set_header(const std::string& key, const std::string& value) {
    std::lock_guard lock(mutex_);
    default_headers_[key] = value;

    // httplib's set_default_headers() replaces the whole set, so hand it every
    // header accumulated so far; otherwise each call drops the earlier ones.
    httplib::Headers headers;
    for (const auto& [name, header_value] : default_headers_) {
        headers.emplace(name, header_value);
    }
    if (http_client_) {
        http_client_->set_default_headers(headers);
    }
    if (sse_client_) {
        sse_client_->set_default_headers(std::move(headers));
    }
}

/**
 * @brief Change the request timeout (also used when waiting for SSE replies).
 */
void client::set_timeout(int timeout_seconds) {
    std::lock_guard lock(mutex_);
    timeout_seconds_ = timeout_seconds;

    if (http_client_) {
        http_client_->set_connection_timeout(timeout_seconds_, 0);
        // Keep the read timeout in step with init_client(), which also sets it.
        http_client_->set_read_timeout(timeout_seconds_, 0);
        http_client_->set_write_timeout(timeout_seconds_, 0);
    }
    if (sse_client_) {
        sse_client_->set_connection_timeout(timeout_seconds_ * 2, 0);
        sse_client_->set_write_timeout(timeout_seconds_, 0);
    }
}

void client::set_capabilities(const json& capabilities) {
    std::lock_guard lock(mutex_);
    capabilities_ = capabilities;
}

/**
 * @brief Send a request and wrap its result in a response object.
 */
response client::send_request(const std::string& method, const json& params) {
    request req = request::create(method, params);
    json result = send_jsonrpc(req);

    response res;
    res.jsonrpc = "2.0";
    res.id = req.id;
    res.result = result;
    return res;
}

void client::send_notification(const std::string& method, const json& params) {
    request req = request::create_notification(method, params);
    send_jsonrpc(req);
}

json client::get_server_capabilities() {
    return server_capabilities_;
}

json client::call_tool(const std::string& tool_name, const json& arguments) {
    return send_request("tools/call", {
        {"name", tool_name},
        {"arguments", arguments}
    }).result;
}

/**
 * @brief Fetch the server's tool list.
 *
 * Accepts either {"tools": [...]} or a bare array. Entries without a string
 * "name" are skipped; "description" is optional.
 */
std::vector<tool> client::get_tools() {
    json response_json = send_request("tools/list", {}).result;
    std::vector<tool> tools;

    json tools_json;
    if (response_json.contains("tools") && response_json["tools"].is_array()) {
        tools_json = response_json["tools"];
    } else if (response_json.is_array()) {
        tools_json = response_json;
    } else {
        return tools;
    }

    for (const auto& tool_json : tools_json) {
        // const operator[] on a missing key is undefined behaviour, so check first.
        if (!tool_json.is_object() || !tool_json.contains("name") || !tool_json["name"].is_string()) {
            LOG_WARNING("Skipping malformed tool entry: ", tool_json.dump());
            continue;
        }

        tool t;
        t.name = tool_json["name"].get<std::string>();
        if (tool_json.contains("description") && tool_json["description"].is_string()) {
            t.description = tool_json["description"].get<std::string>();
        }
        if (tool_json.contains("inputSchema")) {
            t.parameters_schema = tool_json["inputSchema"];
        }
        tools.push_back(std::move(t));
    }

    return tools;
}

json client::get_capabilities() {
    return capabilities_;
}

json client::list_resources(const std::string& cursor) {
    json params = json::object();
    if (!cursor.empty()) {
        params["cursor"] = cursor;
    }
    return send_request("resources/list", params).result;
}

json client::read_resource(const std::string& resource_uri) {
    return send_request("resources/read", {
        {"uri", resource_uri}
    }).result;
}

json client::subscribe_to_resource(const std::string& resource_uri) {
    return send_request("resources/subscribe", {
        {"uri", resource_uri}
    }).result;
}

json client::list_resource_templates() {
    return send_request("resources/templates/list").result;
}

/**
 * @brief Start the background thread that keeps the SSE stream open.
 *
 * The worker reconnects with exponential backoff (1 s, 2 s, ... up to 5
 * attempts). Stream bytes are buffered per connection and split into complete
 * events (terminated by a blank line) before being handed to parse_sse_data().
 * If the worker gives up, it clears the running flag and the message endpoint
 * and wakes waiters, so callers fail fast instead of posting to a dead session.
 */
void client::open_sse_connection() {
    // Assigning over a still-joinable std::thread calls std::terminate, so shut
    // down any previous worker first (e.g. when initialize() is called twice).
    if (sse_thread_ && sse_thread_->joinable()) {
        close_sse_connection();
    }

    sse_running_ = true;

    {
        std::lock_guard lock(mutex_);
        msg_endpoint_.clear();
        endpoint_cv_.notify_all();
    }

    std::string connection_info;
    if (!base_url_.empty()) {
        connection_info = "Base URL: " + base_url_ + ", SSE Endpoint: " + sse_endpoint_;
    } else {
        connection_info = "Host: " + host_ + ", Port: " + std::to_string(port_) + ", SSE Endpoint: " + sse_endpoint_;
    }
    LOG_INFO("Attempting to establish SSE connection: ", connection_info);

    sse_thread_ = std::make_unique<std::thread>([this]() {
        int retry_count = 0;
        const int max_retries = 5;
        const int retry_delay_base = 1000;

        while (sse_running_) {
            try {
                LOG_INFO("SSE thread: Attempting to connect to ", sse_endpoint_);

                // Bytes received on this connection that do not yet form a
                // complete event; events may be split across or share chunks.
                std::string pending;

                auto res = sse_client_->Get(sse_endpoint_.c_str(), [this, &pending](const char* data, size_t data_length) {
                    pending.append(data, data_length);
                    while (const auto boundary = find_sse_event_boundary(pending)) {
                        if (!parse_sse_data(pending.data(), boundary->first)) {
                            LOG_ERROR("SSE thread: Failed to parse data");
                            return false;
                        }
                        pending.erase(0, boundary->first + boundary->second);
                    }

                    const bool should_continue = sse_running_.load();
                    if (!should_continue) {
                        LOG_INFO("SSE thread: sse_running_ is false, closing connection");
                    }
                    return should_continue;
                });

                if (!res) {
                    std::string error_msg = "SSE connection failed: ";
                    error_msg += httplib::to_string(res.error());
                    throw std::runtime_error(error_msg);
                }
                // A non-200 reply (404, 401, ...) is not a stream; treat it as a
                // failure so backoff applies instead of reconnecting in a tight loop.
                if (res->status != 200) {
                    throw std::runtime_error("SSE connection failed: HTTP status " + std::to_string(res->status));
                }

                retry_count = 0;
                LOG_INFO("SSE thread: Connection successful");
            } catch (const std::exception& e) {
                LOG_ERROR("SSE connection error: ", e.what());

                if (!sse_running_) {
                    LOG_INFO("SSE connection actively closed, no retry needed");
                    break;
                }

                if (++retry_count > max_retries) {
                    LOG_ERROR("Maximum retry count reached, stopping SSE connection attempts");
                    // Nothing can deliver replies any more: let waiters and
                    // later send_jsonrpc() calls fail fast.
                    sse_running_ = false;
                    {
                        std::lock_guard lock(mutex_);
                        msg_endpoint_.clear();
                        endpoint_cv_.notify_all();
                    }
                    break;
                }

                const int delay = retry_delay_base * (1 << (retry_count - 1));
                LOG_INFO("Will retry in ", delay, " ms (attempt ", retry_count, "/", max_retries, ")");

                // Sleep in small slices so close_sse_connection() is honoured quickly.
                const int check_interval = 100;
                for (int waited = 0; waited < delay && sse_running_; waited += check_interval) {
                    std::this_thread::sleep_for(std::chrono::milliseconds(check_interval));
                }

                if (!sse_running_) {
                    LOG_INFO("SSE connection actively closed during retry wait, stopping retry");
                    break;
                }
            }
        }

        LOG_INFO("SSE thread: Exiting");
    });
}

/**
 * @brief Handle one complete SSE event.
 *
 * Parses "event:" and "data:" lines (multiple data lines are joined with '\n',
 * a trailing '\r' is stripped). "endpoint" events set the message endpoint;
 * "message" events complete the matching pending request, with JSON-RPC
 * errors delivered as an mcp_exception through the promise. Events without
 * data and "heartbeat" events are ignored.
 *
 * @return false only if an unexpected exception occurred (the stream is then dropped).
 */
bool client::parse_sse_data(const char* data, size_t length) {
    try {
        const std::string sse_data(data, length);

        std::string event_type = "message";
        std::string data_content;
        bool has_data = false;

        std::size_t line_start = 0;
        while (line_start < sse_data.size()) {
            std::size_t line_end = sse_data.find('\n', line_start);
            if (line_end == std::string::npos) {
                line_end = sse_data.size();
            }
            std::string line = sse_data.substr(line_start, line_end - line_start);
            if (!line.empty() && line.back() == '\r') {
                line.pop_back();
            }

            if (auto event_value = sse_field_value(line, "event")) {
                event_type = std::move(*event_value);
            } else if (auto data_value = sse_field_value(line, "data")) {
                if (has_data) {
                    data_content += '\n';
                }
                data_content += *data_value;
                has_data = true;
            }
            line_start = line_end + 1;
        }

        if (!has_data) {
            return true;
        }

        if (event_type == "heartbeat") {
            return true;
        } else if (event_type == "endpoint") {
            std::lock_guard lock(mutex_);
            msg_endpoint_ = data_content;
            endpoint_cv_.notify_all();
            return true;
        } else if (event_type == "message") {
            try {
                const json message = json::parse(data_content);

                if (message.contains("jsonrpc") && message.contains("id") && !message["id"].is_null()) {
                    const json id = message["id"];

                    std::lock_guard lock(response_mutex_);
                    auto it = pending_requests_.find(id);
                    if (it != pending_requests_.end()) {
                        if (message.contains("result")) {
                            it->second.set_value(message["result"]);
                        } else if (message.contains("error")) {
                            // Deliver as an exception, not a sentinel value: a tool
                            // result may legitimately carry its own "isError" field.
                            it->second.set_exception(std::make_exception_ptr(jsonrpc_error_to_exception(message["error"])));
                        } else {
                            it->second.set_value(json::object());
                        }
                        pending_requests_.erase(it);
                    } else {
                        LOG_WARNING("Received response for unknown request ID: ", id);
                    }
                } else {
                    LOG_WARNING("Received invalid JSON-RPC response: ", message.dump());
                }
            } catch (const json::exception& e) {
                LOG_ERROR("Failed to parse JSON-RPC response: ", e.what());
            }
            return true;
        } else {
            LOG_WARNING("Received unknown event type: ", event_type);
            return true;
        }
    } catch (const std::exception& e) {
        LOG_ERROR("Error parsing SSE data: ", e.what());
        return false;
    }
}

/**
 * @brief Stop the SSE worker thread and forget the message endpoint.
 *
 * Safe to call repeatedly. Also joins a worker that already stopped on its own
 * (e.g. after exhausting its retries), so the thread is never left joinable.
 */
void client::close_sse_connection() {
    const bool was_running = sse_running_.exchange(false);
    const bool has_thread = sse_thread_ && sse_thread_->joinable();
    if (!was_running && !has_thread) {
        LOG_INFO("SSE connection already closed");
        return;
    }

    LOG_INFO("Actively closing SSE connection (normal exit flow)...");

    if (sse_client_) {
        // Abort a Get() blocked waiting for stream data so join() below returns
        // promptly. std::thread::join() itself cannot time out.
        sse_client_->stop();
    }

    if (has_thread) {
        LOG_INFO("Waiting for SSE thread to end...");
        try {
            sse_thread_->join();
            LOG_INFO("SSE thread successfully ended");
        } catch (const std::exception& e) {
            // join() throws std::system_error, e.g. when invoked from the SSE thread itself.
            LOG_ERROR("Error waiting for SSE thread: ", e.what());
            if (sse_thread_->joinable()) {
                LOG_WARNING("SSE thread could not be joined, detaching thread");
                sse_thread_->detach();
            }
        }
    }

    {
        std::lock_guard lock(mutex_);
        msg_endpoint_.clear();
        endpoint_cv_.notify_all();
    }

    LOG_INFO("SSE connection successfully closed (normal exit flow)");
}

/**
 * @brief POST a JSON-RPC message to the message endpoint and return its result.
 *
 * Notifications return an empty object once the POST succeeds. For requests,
 * a non-202 reply is parsed from the HTTP body. A 202 means the reply will
 * arrive on the SSE stream, which is awaited for timeout_seconds_ without
 * holding mutex_.
 *
 * @throws mcp_exception on transport failure, JSON-RPC error, unparsable reply
 *         or timeout.
 */
json client::send_jsonrpc(const request& req) {
    // mutex_ guards msg_endpoint_, default_headers_, timeout_seconds_ and use of
    // http_client_; it is released before waiting for an SSE-delivered reply.
    std::unique_lock lock(mutex_);

    if (msg_endpoint_.empty()) {
        throw mcp_exception(error_code::internal_error, "Message endpoint not set, SSE connection may not be established");
    }

    const std::string req_body = req.to_json().dump();

    httplib::Headers headers;
    headers.emplace("Content-Type", "application/json");
    for (const auto& [key, value] : default_headers_) {
        headers.emplace(key, value);
    }

    if (req.is_notification()) {
        auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
        if (!result) {
            const std::string error_msg = httplib::to_string(result.error());
            LOG_ERROR("JSON-RPC request failed: ", error_msg);
            throw mcp_exception(error_code::internal_error, error_msg);
        }
        return json::object();
    }

    auto forget_request = [this, &req]() {
        std::lock_guard response_lock(response_mutex_);
        pending_requests_.erase(req.id);
    };

    // Register before posting: the reply may arrive over SSE before Post() returns.
    std::promise<json> response_promise;
    std::future<json> response_future = response_promise.get_future();
    {
        std::lock_guard response_lock(response_mutex_);
        pending_requests_[req.id] = std::move(response_promise);
    }

    auto result = http_client_->Post(msg_endpoint_, headers, req_body, "application/json");
    if (!result) {
        const std::string error_msg = httplib::to_string(result.error());
        forget_request();
        LOG_ERROR("JSON-RPC request failed: ", error_msg);
        throw mcp_exception(error_code::internal_error, error_msg);
    }

    if (result->status != 202) {
        // The server answered synchronously in the HTTP body; the SSE slot is unused.
        forget_request();
        try {
            json res_json = json::parse(result->body);
            if (res_json.contains("error")) {
                throw jsonrpc_error_to_exception(res_json["error"]);
            }
            if (res_json.contains("result")) {
                return res_json["result"];
            }
            return json::object();
        } catch (const json::exception& e) {
            throw mcp_exception(error_code::parse_error, "Failed to parse JSON-RPC response: " + std::string(e.what()));
        }
    }

    const auto timeout = std::chrono::seconds(timeout_seconds_);
    lock.unlock();  // don't block other callers while the reply travels over SSE

    if (response_future.wait_for(timeout) != std::future_status::ready) {
        forget_request();
        throw mcp_exception(error_code::internal_error, "Timeout waiting for SSE response");
    }
    // get() rethrows the mcp_exception stored by parse_sse_data() for JSON-RPC errors.
    return response_future.get();
}

/**
 * @brief Check that the server answers an HTTP GET on "/".
 * @return true if any HTTP response (regardless of status) was received.
 */
bool client::check_server_accessible() {
    LOG_INFO("Checking if server is accessible...");

    try {
        auto res = http_client_->Get("/");
        if (res) {
            LOG_INFO("Server is accessible, status code: ", res->status);
            return true;
        } else {
            std::string error_msg = "Server not accessible: " + httplib::to_string(res.error());
            LOG_ERROR(error_msg);
            return false;
        }
    } catch (const std::exception& e) {
        LOG_ERROR("Exception while checking server accessibility: ", e.what());
        return false;
    }
}

} // namespace mcp