cpp-mcp/src/mcp_server.cpp

372 lines
12 KiB
C++

/**
* @file mcp_server.cpp
* @brief Implementation of the MCP server
*
* This file implements the server-side functionality for the Model Context Protocol.
* Follows the 2024-11-05 basic protocol specification.
*/
#include "mcp_server.h"

#include <utility>
namespace mcp {
/**
 * @brief Construct a server for the given listen address.
 *
 * Records the host and port, sets the default server name ("MCP Server") and
 * version (MCP_VERSION), and creates the underlying HTTP server object.
 * listen() is not called here; that happens in start().
 *
 * @param host Host later passed to the HTTP server's listen() by start().
 * @param port Port later passed to the HTTP server's listen() by start().
 */
server::server(const std::string& host, int port)
    : host_(host),
      port_(port),
      name_("MCP Server"),
      version_(MCP_VERSION)
{
    // The HTTP server is created eagerly so start() only has to add the
    // route and call listen().
    http_server_ = std::make_unique<httplib::Server>();
}
/**
 * @brief Destroy the server, calling stop() first.
 */
server::~server() {
    this->stop();
}
bool server::start(bool blocking) {
if (running_) {
return true; // Already running
}
// Set up JSON-RPC endpoint
http_server_->Post("/jsonrpc", [this](const httplib::Request& req, httplib::Response& res) {
this->handle_jsonrpc(req, res);
});
// Start the server
if (blocking) {
running_ = true;
if (!http_server_->listen(host_.c_str(), port_)) {
running_ = false;
std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
return false;
}
return true;
} else {
// Start in a separate thread
server_thread_ = std::make_unique<std::thread>([this]() {
if (!http_server_->listen(host_.c_str(), port_)) {
std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
running_ = false;
return;
}
});
running_ = true;
return true;
}
}
/**
 * @brief Stop the HTTP server and join the listener thread.
 *
 * Safe to call when the server is not running. The listener thread is joined
 * even if running_ is already false: after a non-blocking start() whose
 * listen() failed, the thread has cleared running_ but is still joinable, and
 * destroying it un-joined (e.g. from the destructor) would call
 * std::terminate.
 */
void server::stop() {
    // Only ask the HTTP server to stop if we believe it is serving.
    if (running_ && http_server_) {
        http_server_->stop();
    }

    // Always reap the listener thread, whether it is still winding down or
    // already exited on its own after a failed listen().
    if (server_thread_ && server_thread_->joinable()) {
        server_thread_->join();
    }

    running_ = false;
}
/**
 * @brief Report the server's running flag.
 * @return The current value of running_, which start() sets and stop() (or a
 *         failed listen()) clears.
 */
bool server::is_running() const {
    const bool active = running_;
    return active;
}
/**
 * @brief Replace the server name and version.
 *
 * Both fields are updated while holding mutex_.
 *
 * @param name    New server name.
 * @param version New server version string.
 */
void server::set_server_info(const std::string& name, const std::string& version) {
    const std::lock_guard<std::mutex> guard(mutex_);
    name_ = name;
    version_ = version;
}
/**
 * @brief Replace the stored capabilities object.
 *
 * The assignment is performed while holding mutex_.
 *
 * @param capabilities JSON value stored in capabilities_.
 */
void server::set_capabilities(const json& capabilities) {
    const std::lock_guard<std::mutex> guard(mutex_);
    capabilities_ = capabilities;
}
/**
 * @brief Register (or replace) the handler for a JSON-RPC method.
 *
 * @param method  Method name used as the lookup key in method_handlers_.
 * @param handler Callable taking the request params and returning the result.
 *                Taken by value and moved into the table, so callers passing
 *                an rvalue pay no copy.
 */
void server::register_method(const std::string& method, std::function<json(const json&)> handler) {
    std::lock_guard<std::mutex> lock(mutex_);
    method_handlers_[method] = std::move(handler);
}
/**
 * @brief Register (or replace) the handler for a JSON-RPC notification.
 *
 * @param method  Notification name used as the lookup key in
 *                notification_handlers_.
 * @param handler Callable taking the notification params. Taken by value and
 *                moved into the table, so callers passing an rvalue pay no
 *                copy.
 */
void server::register_notification(const std::string& method, std::function<void(const json&)> handler) {
    std::lock_guard<std::mutex> lock(mutex_);
    notification_handlers_[method] = std::move(handler);
}
/**
 * @brief Register a resource under a path and make sure the built-in
 *        "resources/metadata" and "resources/access" methods exist.
 *
 * The two methods are installed only if no handler with that name is already
 * present in method_handlers_. Both expect a string "path" parameter and
 * throw mcp_exception(invalid_params) when it is missing, is not a string, or
 * does not name a usable resource.
 *
 * @param path     Key under which the resource is stored in resources_.
 * @param resource Resource to store; taken by value and moved into the table.
 */
void server::register_resource(const std::string& path, std::shared_ptr<resource> resource) {
    std::lock_guard<std::mutex> lock(mutex_);
    resources_[path] = std::move(resource);

    // Shared lookup for both built-in methods: validates "path" and returns
    // an iterator to a non-null resource, or throws invalid_params.
    auto find_resource = [this](const json& params) {
        if (!params.contains("path")) {
            throw mcp_exception(error_code::invalid_params, "Missing 'path' parameter");
        }
        // Without this check a non-string value makes the JSON conversion
        // throw, which the dispatcher reports as an internal error.
        if (!params["path"].is_string()) {
            throw mcp_exception(error_code::invalid_params, "Expected string for 'path' parameter");
        }
        const std::string key = params["path"].get<std::string>();
        auto it = resources_.find(key);
        // A null pointer registered under this path is treated as absent
        // rather than dereferenced.
        if (it == resources_.end() || !it->second) {
            throw mcp_exception(error_code::invalid_params, "Resource not found: " + key);
        }
        return it;
    };

    // Register methods for resource access
    if (method_handlers_.find("resources/metadata") == method_handlers_.end()) {
        method_handlers_["resources/metadata"] = [find_resource](const json& params) {
            return find_resource(params)->second->get_metadata();
        };
    }
    if (method_handlers_.find("resources/access") == method_handlers_.end()) {
        method_handlers_["resources/access"] = [find_resource](const json& params) {
            return find_resource(params)->second->access(params);
        };
    }
}
/**
 * @brief Register a tool with its handler and make sure the built-in
 *        "tools/list" and "tools/call" methods exist.
 *
 * The two methods are installed only if no handler with that name is already
 * present in method_handlers_.
 *  - "tools/list" returns a JSON array with to_json() of every registered tool.
 *  - "tools/call" expects a string "name" and an optional "parameters" value
 *    (defaulting to an empty object) and returns the tool handler's result.
 *    It throws mcp_exception(invalid_params) when "name" is missing, is not a
 *    string, or does not name a registered tool.
 *
 * @param tool    Tool description; stored under tool.name.
 * @param handler Callable invoked by "tools/call"; taken by value and moved
 *                into the table.
 */
void server::register_tool(const tool& tool, tool_handler handler) {
    std::lock_guard<std::mutex> lock(mutex_);
    tools_[tool.name] = std::make_pair(tool, std::move(handler));

    // Register methods for tool listing and calling
    if (method_handlers_.find("tools/list") == method_handlers_.end()) {
        method_handlers_["tools/list"] = [this](const json& /*params*/) {
            json tools_json = json::array();
            for (const auto& entry : tools_) {
                tools_json.push_back(entry.second.first.to_json());
            }
            return tools_json;
        };
    }
    if (method_handlers_.find("tools/call") == method_handlers_.end()) {
        method_handlers_["tools/call"] = [this](const json& params) {
            if (!params.contains("name")) {
                throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter");
            }
            // Without this check a non-string value makes the JSON conversion
            // throw, which the dispatcher reports as an internal error.
            if (!params["name"].is_string()) {
                throw mcp_exception(error_code::invalid_params, "Expected string for 'name' parameter");
            }
            const std::string tool_name = params["name"].get<std::string>();
            auto it = tools_.find(tool_name);
            if (it == tools_.end()) {
                throw mcp_exception(error_code::invalid_params, "Tool not found: " + tool_name);
            }
            json tool_params = params.contains("parameters") ? params["parameters"] : json::object();
            return it->second.second(tool_params);
        };
    }
}
std::vector<tool> server::get_tools() const {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<tool> tools;
for (const auto& [name, tool_pair] : tools_) {
tools.push_back(tool_pair.first);
}
return tools;
}
/**
 * @brief Store the authentication callback.
 *
 * @param handler Callable taking a string and returning bool. Taken by value
 *                and moved into auth_handler_.
 *
 * NOTE(review): none of the code visible in this file invokes auth_handler_;
 * confirm where (or whether) authentication is actually enforced.
 */
void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
    std::lock_guard<std::mutex> lock(mutex_);
    auth_handler_ = std::move(handler);
}
/**
 * @brief HTTP handler for POST /jsonrpc.
 *
 * Parses the request body as JSON, rejects batch (array) requests, converts
 * the payload into a request object and writes the JSON produced by
 * process_request() to the response. Malformed input is answered with a
 * JSON-RPC error object whose id is null.
 *
 * @param req Incoming HTTP request; only req.body is read.
 * @param res HTTP response to fill in.
 */
void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) {
    // Every exit path ends in set_content(..., "application/json"), which
    // sets Content-Type itself, so no separate set_header() call is made
    // (a second one can produce a duplicate header).

    // Writes a JSON-RPC error envelope with a null id to the response.
    const auto send_error = [&res](error_code code, const std::string& message) {
        json error_response = {
            {"jsonrpc", "2.0"},
            {"error", {
                {"code", static_cast<int>(code)},
                {"message", message}
            }},
            {"id", nullptr}
        };
        res.set_content(error_response.dump(), "application/json");
    };

    // Parse the request
    json req_json;
    try {
        req_json = json::parse(req.body);
    } catch (const json::exception& e) {
        // Invalid JSON
        send_error(error_code::parse_error, "Parse error: " + std::string(e.what()));
        return;
    }

    // Batch requests are not supported yet
    if (req_json.is_array()) {
        send_error(error_code::invalid_request, "Batch requests are not supported");
        return;
    }

    // Convert to request object. A missing or wrongly typed "jsonrpc" or
    // "method" makes the JSON conversion throw, which is reported as an
    // invalid request.
    request mcp_req;
    try {
        mcp_req.jsonrpc = req_json["jsonrpc"];
        mcp_req.method = req_json["method"];
        if (req_json.contains("id")) {
            mcp_req.id = req_json["id"];
        }
        if (req_json.contains("params")) {
            mcp_req.params = req_json["params"];
        }
    } catch (const json::exception& e) {
        send_error(error_code::invalid_request, "Invalid request: " + std::string(e.what()));
        return;
    }

    // Process the request
    json result = process_request(mcp_req);
    res.set_content(result.dump(), "application/json");
}
/**
 * @brief Dispatch a parsed JSON-RPC request and build its response.
 *
 * - Notifications are handed to a detached thread that looks up and runs the
 *   registered notification handler; an empty JSON object is returned
 *   immediately.
 * - "initialize" is routed to handle_initialize().
 * - Any other method is looked up in method_handlers_ and invoked.
 *
 * Handlers run while mutex_ is held, so a handler that calls back into a
 * member function that also locks mutex_ on the same server would block on
 * itself.
 *
 * NOTE(review): the detached notification thread captures `this`; nothing
 * here keeps the server alive until that thread finishes.
 *
 * @param req Parsed request.
 * @return JSON response: a success envelope, an error envelope
 *         (method_not_found, the code of a thrown mcp_exception, or
 *         internal_error for other std::exception), or an empty object for
 *         notifications.
 */
json server::process_request(const request& req) {
    // Check if it's a notification
    if (req.is_notification()) {
        // Process notification asynchronously
        std::thread([this, req]() {
            try {
                std::lock_guard<std::mutex> lock(mutex_);
                auto it = notification_handlers_.find(req.method);
                if (it != notification_handlers_.end()) {
                    it->second(req.params);
                }
            } catch (const std::exception& e) {
                // Log error but don't send response
                std::cerr << "Error processing notification: " << e.what() << std::endl;
            } catch (...) {
                // An exception escaping a thread entry point calls
                // std::terminate, so anything not derived from
                // std::exception is swallowed and logged here.
                std::cerr << "Error processing notification: unknown exception" << std::endl;
            }
        }).detach();
        // No response for notifications
        return json::object();
    }

    // Handle method call
    try {
        // Special case for initialize (handled before mutex_ is taken)
        if (req.method == "initialize") {
            return handle_initialize(req);
        }

        // Look for registered method handler
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = method_handlers_.find(req.method);
        if (it != method_handlers_.end()) {
            // Call the handler
            json result = it->second(req.params);
            // Create success response
            return response::create_success(req.id, result).to_json();
        }

        // Method not found
        return response::create_error(
            req.id,
            error_code::method_not_found,
            "Method not found: " + req.method
        ).to_json();
    } catch (const mcp_exception& e) {
        // MCP exception: forward its error code and message
        return response::create_error(
            req.id,
            e.code(),
            e.what()
        ).to_json();
    } catch (const std::exception& e) {
        // Generic exception
        return response::create_error(
            req.id,
            error_code::internal_error,
            "Internal error: " + std::string(e.what())
        ).to_json();
    }
}
/**
 * @brief Handle the "initialize" request: check the protocol version and
 *        return server info and capabilities.
 *
 * "protocolVersion" is read directly from the request params. For backward
 * compatibility, if it is absent there but the params contain a nested
 * "params" object, that nested object is consulted instead (the only shape
 * the previous implementation looked at).
 *
 * Takes mutex_ while reading name_, version_ and capabilities_, so it must
 * not be called with mutex_ already held.
 *
 * @param req The initialize request.
 * @return A success envelope with "protocolVersion", "capabilities" and
 *         "serverInfo", or an invalid_params error envelope when
 *         "protocolVersion" is missing, not a string, or differs from
 *         MCP_VERSION.
 */
json server::handle_initialize(const request& req) {
    const json& params = req.params;

    // Locate the object that carries the initialize fields. contains() is
    // false for non-objects, so the const operator[] below is only ever
    // applied to a key known to exist.
    const json* init_params = &params;
    if (!params.contains("protocolVersion") &&
        params.contains("params") && params["params"].is_object()) {
        init_params = &params["params"];
    }

    // Version negotiation
    if (!init_params->contains("protocolVersion") ||
        !(*init_params)["protocolVersion"].is_string()) {
        return response::create_error(
            req.id,
            error_code::invalid_params,
            "Expected string for 'protocolVersion' parameter"
        ).to_json();
    }
    const std::string requested_version = (*init_params)["protocolVersion"].get<std::string>();
    if (requested_version != MCP_VERSION) {
        return response::create_error(
            req.id,
            error_code::invalid_params,
            "Unsupported protocol version",
            {
                {"supported", MCP_VERSION},
                {"requested", requested_version}
            }
        ).to_json();
    }

    // Extract client info. These values only feed the disabled log line
    // below, so a missing or non-string field falls back to the placeholder
    // instead of failing the whole request.
    std::string client_name = "UnknownClient";
    std::string client_version = "UnknownVersion";
    if (init_params->contains("clientInfo")) {
        const json& client_info = (*init_params)["clientInfo"];
        if (client_info.contains("name") && client_info["name"].is_string()) {
            client_name = client_info["name"].get<std::string>();
        }
        if (client_info.contains("version") && client_info["version"].is_string()) {
            client_version = client_info["version"].get<std::string>();
        }
    }
    (void)client_name;
    (void)client_version;
    // Log connection
    // std::cout << "Client connected: " << client_name << " " << client_version << std::endl;

    // Return server info and capabilities. The setters for these members
    // lock mutex_, so they are read under the same lock.
    std::lock_guard<std::mutex> lock(mutex_);
    json server_info = {
        {"name", name_},
        {"version", version_}
    };
    json result = {
        {"protocolVersion", MCP_VERSION},
        {"capabilities", capabilities_},
        {"serverInfo", server_info}
    };
    return response::create_success(req.id, result).to_json();
}
} // namespace mcp