/**
 * @file mcp_server.cpp
 * @brief Implementation of the MCP server
 *
 * This file implements the server-side functionality for the Model Context Protocol.
 * Follows the 2024-11-05 basic protocol specification.
 */

#include "mcp_server.h"

#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace mcp {

namespace {

/**
 * @brief Build a JSON-RPC 2.0 error envelope.
 *
 * Used by handle_jsonrpc() for failures detected before a `request` object
 * exists (parse errors, malformed envelopes).
 *
 * @param code    Error code to report.
 * @param message Human-readable error message.
 * @param id      Request id to echo back; null when it could not be determined.
 * @return The complete response object, ready to be dumped.
 */
json make_jsonrpc_error(error_code code, const std::string& message, const json& id = nullptr) {
    json error_response = {
        {"jsonrpc", "2.0"},
        {"error", {
            {"code", static_cast<int>(code)},
            {"message", message}
        }},
        {"id", id}
    };
    return error_response;
}

/**
 * @brief Fetch a required string member from a method's params object.
 *
 * @param params The request's params value (may be any JSON type).
 * @param key    Member name to read.
 * @return The member's string value.
 * @throws mcp_exception (invalid_params) if the member is missing or is not a string.
 */
std::string require_string_param(const json& params, const std::string& key) {
    if (!params.is_object() || !params.contains(key)) {
        throw mcp_exception(error_code::invalid_params, "Missing '" + key + "' parameter");
    }
    const json& value = params.at(key);
    if (!value.is_string()) {
        throw mcp_exception(error_code::invalid_params, "Expected string for '" + key + "' parameter");
    }
    return value.get<std::string>();
}

}  // namespace

server::server(const std::string& host, int port)
    : host_(host), port_(port), name_("MCP Server"), version_(MCP_VERSION) {
    http_server_ = std::make_unique<httplib::Server>();

    // Register the JSON-RPC endpoint once, here, so that repeated start()
    // calls (e.g. after stop()) do not register the route again.
    http_server_->Post("/jsonrpc", [this](const httplib::Request& req, httplib::Response& res) {
        this->handle_jsonrpc(req, res);
    });
}

server::~server() {
    stop();
}

/**
 * @brief Start listening on host_:port_.
 *
 * @param blocking If true, listen on the calling thread and return only when
 *                 the server stops (true) or fails to bind (false). If false,
 *                 listen on a background thread and return true immediately;
 *                 a bind failure is then reported on stderr and clears running_.
 * @return false only when a blocking listen fails.
 */
bool server::start(bool blocking) {
    if (running_) {
        return true;  // Already running
    }

    if (blocking) {
        running_ = true;
        if (!http_server_->listen(host_.c_str(), port_)) {
            running_ = false;
            std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
            return false;
        }
        return true;
    }

    // A previous listener thread (e.g. one whose listen() failed) may still be
    // joinable; replacing a joinable std::thread calls std::terminate.
    if (server_thread_ && server_thread_->joinable()) {
        server_thread_->join();
    }

    // Mark running *before* spawning the listener so that a failing listen()
    // resetting the flag is not overwritten afterwards by this thread.
    // NOTE(review): running_ is written from the listener thread; unless it is
    // declared std::atomic<bool> in the header this is a data race — confirm.
    running_ = true;
    server_thread_ = std::make_unique<std::thread>([this]() {
        if (!http_server_->listen(host_.c_str(), port_)) {
            std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
            running_ = false;
        }
    });
    return true;
}

/**
 * @brief Stop the HTTP server and join the background listener, if any.
 *
 * Safe to call repeatedly; called from the destructor.
 */
void server::stop() {
    if (running_ && http_server_) {
        http_server_->stop();
    }
    // Join even when running_ is already false: a listener thread whose
    // listen() failed has cleared the flag but is still joinable, and
    // destroying it unjoined (e.g. from ~server) would call std::terminate.
    if (server_thread_ && server_thread_->joinable()) {
        server_thread_->join();
    }
    running_ = false;
}

bool server::is_running() const {
    return running_;
}

void server::set_server_info(const std::string& name, const std::string& version) {
    std::lock_guard lock(mutex_);
    name_ = name;
    version_ = version;
}

void server::set_capabilities(const json& capabilities) {
    std::lock_guard lock(mutex_);
    capabilities_ = capabilities;
}

void server::register_method(const std::string& method, std::function<json(const json&)> handler) {
    std::lock_guard lock(mutex_);
    method_handlers_[method] = std::move(handler);
}

void server::register_notification(const std::string& method, std::function<void(const json&)> handler) {
    std::lock_guard lock(mutex_);
    notification_handlers_[method] = std::move(handler);
}

/**
 * @brief Register a resource under @p path and, on first use, install the
 *        built-in "resources/metadata" and "resources/access" methods.
 */
void server::register_resource(const std::string& path, std::shared_ptr<resource> resource) {
    std::lock_guard lock(mutex_);
    resources_[path] = std::move(resource);

    // Resolve the request's 'path' parameter to a registered resource.
    // Takes mutex_ itself because process_request() invokes method handlers
    // without holding the lock; the returned shared_ptr copy keeps the
    // resource alive while it is used outside the lock.
    auto find_resource = [this](const json& params) {
        const std::string target = require_string_param(params, "path");
        std::lock_guard guard(mutex_);
        auto it = resources_.find(target);
        if (it == resources_.end()) {
            throw mcp_exception(error_code::invalid_params, "Resource not found: " + target);
        }
        return it->second;
    };

    if (method_handlers_.find("resources/metadata") == method_handlers_.end()) {
        method_handlers_["resources/metadata"] = [find_resource](const json& params) -> json {
            return find_resource(params)->get_metadata();
        };
    }

    if (method_handlers_.find("resources/access") == method_handlers_.end()) {
        method_handlers_["resources/access"] = [find_resource](const json& params) -> json {
            return find_resource(params)->access(params);
        };
    }
}

/**
 * @brief Register a tool and, on first use, install the built-in
 *        "tools/list" and "tools/call" methods.
 */
void server::register_tool(const tool& tool, tool_handler handler) {
    std::lock_guard lock(mutex_);
    tools_[tool.name] = std::make_pair(tool, std::move(handler));

    if (method_handlers_.find("tools/list") == method_handlers_.end()) {
        method_handlers_["tools/list"] = [this](const json&) -> json {
            // Handlers run without mutex_ held (see process_request), so lock here.
            std::lock_guard guard(mutex_);
            json tools_json = json::array();
            for (const auto& entry : tools_) {
                tools_json.push_back(entry.second.first.to_json());
            }
            return tools_json;
        };
    }

    if (method_handlers_.find("tools/call") == method_handlers_.end()) {
        method_handlers_["tools/call"] = [this](const json& params) -> json {
            const std::string tool_name = require_string_param(params, "name");

            // Copy the tool's handler under the lock, then invoke it unlocked so
            // a tool that calls back into the server cannot deadlock.
            const auto call = [&] {
                std::lock_guard guard(mutex_);
                auto it = tools_.find(tool_name);
                if (it == tools_.end()) {
                    throw mcp_exception(error_code::invalid_params, "Tool not found: " + tool_name);
                }
                return it->second.second;
            }();

            // The MCP spec names the argument object "arguments"; the legacy
            // "parameters" member is still accepted for existing callers.
            json tool_params = json::object();
            if (params.contains("arguments")) {
                tool_params = params.at("arguments");
            } else if (params.contains("parameters")) {
                tool_params = params.at("parameters");
            }
            return call(tool_params);
        };
    }
}

std::vector<tool> server::get_tools() const {
    std::lock_guard lock(mutex_);
    std::vector<tool> tools;
    tools.reserve(tools_.size());
    for (const auto& entry : tools_) {
        tools.push_back(entry.second.first);
    }
    return tools;
}

void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
    // NOTE(review): auth_handler_ is stored here, but nothing in this file
    // (including handle_jsonrpc) consults it, so requests are not authenticated.
    std::lock_guard lock(mutex_);
    auth_handler_ = std::move(handler);
}

/**
 * @brief HTTP handler for POST /jsonrpc.
 *
 * Parses and validates a single JSON-RPC 2.0 request, dispatches it through
 * process_request(), and writes the JSON reply. Parse errors, batch requests,
 * and malformed envelopes are answered with a JSON-RPC error object.
 */
void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) {
    // Every exit path goes through set_content(), which sets Content-Type.
    const auto reply = [&res](const json& body) {
        res.set_content(body.dump(), "application/json");
    };

    json req_json;
    try {
        req_json = json::parse(req.body);
    } catch (const json::exception& e) {
        reply(make_jsonrpc_error(error_code::parse_error, "Parse error: " + std::string(e.what())));
        return;
    }

    if (req_json.is_array()) {
        // Batch request not supported yet
        reply(make_jsonrpc_error(error_code::invalid_request, "Batch requests are not supported"));
        return;
    }

    if (!req_json.is_object()) {
        reply(make_jsonrpc_error(error_code::invalid_request, "Invalid request: expected a JSON object"));
        return;
    }

    // Echo the client's id in envelope errors when one was supplied.
    const json id = req_json.contains("id") ? req_json.at("id") : json(nullptr);

    request mcp_req;
    try {
        mcp_req.jsonrpc = req_json.at("jsonrpc").get<std::string>();
        mcp_req.method = req_json.at("method").get<std::string>();
        if (req_json.contains("id")) {
            mcp_req.id = req_json.at("id");
        }
        if (req_json.contains("params")) {
            mcp_req.params = req_json.at("params");
        }
    } catch (const json::exception& e) {
        reply(make_jsonrpc_error(error_code::invalid_request, "Invalid request: " + std::string(e.what()), id));
        return;
    }

    // JSON-RPC 2.0 requires the version member to be exactly "2.0".
    if (mcp_req.jsonrpc != "2.0") {
        reply(make_jsonrpc_error(error_code::invalid_request, "Invalid request: 'jsonrpc' must be \"2.0\"", id));
        return;
    }

    reply(process_request(mcp_req));
}

/**
 * @brief Dispatch a parsed request to its handler.
 *
 * Notifications are run on a detached thread and answered with an empty
 * object. Method calls return a success or error response object. Handlers
 * are invoked *without* mutex_ held, so they may call back into the server.
 */
json server::process_request(const request& req) {
    if (req.is_notification()) {
        // Snapshot the handler under the lock.
        std::function<void(const json&)> handler;
        {
            std::lock_guard lock(mutex_);
            auto it = notification_handlers_.find(req.method);
            if (it != notification_handlers_.end()) {
                handler = it->second;
            }
        }
        if (handler) {
            // The thread owns copies of the handler and params and never touches
            // server members, so it does not depend on the server outliving it
            // (state captured by the handler itself is the caller's concern).
            std::thread([handler = std::move(handler), params = req.params]() {
                try {
                    handler(params);
                } catch (const std::exception& e) {
                    // Log error but don't send response
                    std::cerr << "Error processing notification: " << e.what() << std::endl;
                } catch (...) {
                    std::cerr << "Error processing notification: unknown exception" << std::endl;
                }
            }).detach();
        }
        // No response for notifications
        return json::object();
    }

    try {
        // Special case for initialize
        if (req.method == "initialize") {
            return handle_initialize(req);
        }

        std::function<json(const json&)> handler;
        {
            std::lock_guard lock(mutex_);
            auto it = method_handlers_.find(req.method);
            if (it != method_handlers_.end()) {
                handler = it->second;
            }
        }

        if (!handler) {
            return response::create_error(
                req.id,
                error_code::method_not_found,
                "Method not found: " + req.method
            ).to_json();
        }

        json result = handler(req.params);
        return response::create_success(req.id, result).to_json();
    } catch (const mcp_exception& e) {
        return response::create_error(req.id, e.code(), e.what()).to_json();
    } catch (const std::exception& e) {
        return response::create_error(
            req.id,
            error_code::internal_error,
            "Internal error: " + std::string(e.what())
        ).to_json();
    }
}

/**
 * @brief Handle the "initialize" request: check the protocol version and
 *        return server info and capabilities.
 */
json server::handle_initialize(const request& req) {
    const json& params = req.params;

    // Spec-conformant clients send protocolVersion directly in params; a nested
    // params.params object (the shape this function previously read) is
    // accepted as a fallback.
    const json* init = &params;
    if (params.is_object() && !params.contains("protocolVersion") &&
        params.contains("params") && params.at("params").is_object()) {
        init = &params.at("params");
    }

    if (!init->is_object() || !init->contains("protocolVersion") ||
        !init->at("protocolVersion").is_string()) {
        return response::create_error(
            req.id,
            error_code::invalid_params,
            "Expected string for 'protocolVersion' parameter"
        ).to_json();
    }

    const std::string requested_version = init->at("protocolVersion").get<std::string>();
    if (requested_version != MCP_VERSION) {
        return response::create_error(
            req.id,
            error_code::invalid_params,
            "Unsupported protocol version",
            {
                {"supported", MCP_VERSION},
                {"requested", requested_version}
            }
        ).to_json();
    }

    // name_, version_ and capabilities_ are written under mutex_ by
    // set_server_info()/set_capabilities(), so read them under it too.
    json result;
    {
        std::lock_guard lock(mutex_);
        json server_info = {
            {"name", name_},
            {"version", version_}
        };
        result = {
            {"protocolVersion", MCP_VERSION},
            {"capabilities", capabilities_},
            {"serverInfo", server_info}
        };
    }

    return response::create_success(req.id, result).to_json();
}

} // namespace mcp