/**
 * @file mcp_server.cpp
 * @brief Implementation of the MCP server
 *
 * This file implements the server-side functionality for the Model Context Protocol.
 * Follows the 2024-11-05 basic protocol specification.
 */
|
|
|
|
#include "mcp_server.h"
|
|
|
|
namespace mcp {
|
|
|
|
/**
 * @brief Construct a server for the given endpoint.
 *
 * Records the host and port, sets the advertised name to "MCP Server" and the
 * advertised version to MCP_VERSION, and allocates the underlying
 * httplib::Server. No socket is opened here; listening starts in start().
 *
 * @param host Host name or address that start() passes to listen().
 * @param port Port that start() passes to listen().
 */
server::server(const std::string& host, int port)
    : host_(host),
      port_(port),
      name_("MCP Server"),
      version_(MCP_VERSION) {
    // The HTTP backend is created eagerly so start() only has to add routes.
    http_server_ = std::make_unique<httplib::Server>();
}
|
|
|
|
/**
 * @brief Destroy the server, shutting it down first via stop().
 */
server::~server() {
    this->stop();
}
|
|
|
|
bool server::start(bool blocking) {
|
|
if (running_) {
|
|
return true; // Already running
|
|
}
|
|
|
|
// Set up JSON-RPC endpoint
|
|
http_server_->Post("/jsonrpc", [this](const httplib::Request& req, httplib::Response& res) {
|
|
this->handle_jsonrpc(req, res);
|
|
});
|
|
|
|
// Start the server
|
|
if (blocking) {
|
|
running_ = true;
|
|
if (!http_server_->listen(host_.c_str(), port_)) {
|
|
running_ = false;
|
|
std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
|
|
return false;
|
|
}
|
|
return true;
|
|
} else {
|
|
// Start in a separate thread
|
|
server_thread_ = std::make_unique<std::thread>([this]() {
|
|
if (!http_server_->listen(host_.c_str(), port_)) {
|
|
std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
|
|
running_ = false;
|
|
return;
|
|
}
|
|
});
|
|
running_ = true;
|
|
return true;
|
|
}
|
|
}
|
|
|
|
void server::stop() {
|
|
if (!running_) {
|
|
return;
|
|
}
|
|
|
|
if (http_server_) {
|
|
http_server_->stop();
|
|
}
|
|
|
|
if (server_thread_ && server_thread_->joinable()) {
|
|
server_thread_->join();
|
|
}
|
|
|
|
running_ = false;
|
|
}
|
|
|
|
bool server::is_running() const {
|
|
return running_;
|
|
}
|
|
|
|
void server::set_server_info(const std::string& name, const std::string& version) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
name_ = name;
|
|
version_ = version;
|
|
}
|
|
|
|
void server::set_capabilities(const json& capabilities) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
capabilities_ = capabilities;
|
|
}
|
|
|
|
void server::register_method(const std::string& method, std::function<json(const json&)> handler) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
method_handlers_[method] = handler;
|
|
}
|
|
|
|
void server::register_notification(const std::string& method, std::function<void(const json&)> handler) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
notification_handlers_[method] = handler;
|
|
}
|
|
|
|
void server::register_resource(const std::string& path, std::shared_ptr<resource> resource) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
resources_[path] = resource;
|
|
|
|
// Register methods for resource access
|
|
if (method_handlers_.find("resources/metadata") == method_handlers_.end()) {
|
|
method_handlers_["resources/metadata"] = [this](const json& params) {
|
|
if (!params.contains("path")) {
|
|
throw mcp_exception(error_code::invalid_params, "Missing 'path' parameter");
|
|
}
|
|
|
|
std::string path = params["path"];
|
|
auto it = resources_.find(path);
|
|
if (it == resources_.end()) {
|
|
throw mcp_exception(error_code::invalid_params, "Resource not found: " + path);
|
|
}
|
|
|
|
return it->second->get_metadata();
|
|
};
|
|
}
|
|
|
|
if (method_handlers_.find("resources/access") == method_handlers_.end()) {
|
|
method_handlers_["resources/access"] = [this](const json& params) {
|
|
if (!params.contains("path")) {
|
|
throw mcp_exception(error_code::invalid_params, "Missing 'path' parameter");
|
|
}
|
|
|
|
std::string path = params["path"];
|
|
auto it = resources_.find(path);
|
|
if (it == resources_.end()) {
|
|
throw mcp_exception(error_code::invalid_params, "Resource not found: " + path);
|
|
}
|
|
|
|
return it->second->access(params);
|
|
};
|
|
}
|
|
}
|
|
|
|
void server::register_tool(const tool& tool, tool_handler handler) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
tools_[tool.name] = std::make_pair(tool, handler);
|
|
|
|
// Register methods for tool listing and calling
|
|
if (method_handlers_.find("tools/list") == method_handlers_.end()) {
|
|
method_handlers_["tools/list"] = [this](const json& params) {
|
|
json tools_json = json::array();
|
|
for (const auto& [name, tool_pair] : tools_) {
|
|
tools_json.push_back(tool_pair.first.to_json());
|
|
}
|
|
return tools_json;
|
|
};
|
|
}
|
|
|
|
if (method_handlers_.find("tools/call") == method_handlers_.end()) {
|
|
method_handlers_["tools/call"] = [this](const json& params) {
|
|
if (!params.contains("name")) {
|
|
throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter");
|
|
}
|
|
|
|
std::string tool_name = params["name"];
|
|
auto it = tools_.find(tool_name);
|
|
if (it == tools_.end()) {
|
|
throw mcp_exception(error_code::invalid_params, "Tool not found: " + tool_name);
|
|
}
|
|
|
|
json tool_params = params.contains("parameters") ? params["parameters"] : json::object();
|
|
return it->second.second(tool_params);
|
|
};
|
|
}
|
|
}
|
|
|
|
std::vector<tool> server::get_tools() const {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
std::vector<tool> tools;
|
|
|
|
for (const auto& [name, tool_pair] : tools_) {
|
|
tools.push_back(tool_pair.first);
|
|
}
|
|
|
|
return tools;
|
|
}
|
|
|
|
void server::set_auth_handler(std::function<bool(const std::string&)> handler) {
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
auth_handler_ = handler;
|
|
}
|
|
|
|
/**
 * @brief HTTP handler for POST /jsonrpc.
 *
 * Parses the request body as JSON, rejects batch (array) requests, converts
 * the JSON into a `request`, and writes the JSON produced by
 * process_request() as the response body. Every failure along the way is
 * answered with a JSON-RPC error envelope whose id is null:
 *  - body is not valid JSON            -> error_code::parse_error
 *  - body is a JSON array              -> error_code::invalid_request
 *  - fields cannot be converted        -> error_code::invalid_request
 *
 * NOTE(review): auth_handler_ is not consulted in this function.
 *
 * @param req Incoming HTTP request; only its body is read.
 * @param res HTTP response that receives the JSON payload.
 */
void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) {
    // Set response headers
    res.set_header("Content-Type", "application/json");

    // Writes a JSON-RPC error envelope (id = null) into the response.
    const auto reply_error = [&res](int code, const std::string& message) {
        json envelope = {
            {"jsonrpc", "2.0"},
            {"error", {
                {"code", code},
                {"message", message}
            }},
            {"id", nullptr}
        };
        res.set_content(envelope.dump(), "application/json");
    };

    // Step 1: the body must be valid JSON.
    json payload;
    try {
        payload = json::parse(req.body);
    } catch (const json::exception& parse_failure) {
        reply_error(static_cast<int>(error_code::parse_error),
                    "Parse error: " + std::string(parse_failure.what()));
        return;
    }

    // Step 2: batch requests (top-level arrays) are not supported yet.
    if (payload.is_array()) {
        reply_error(static_cast<int>(error_code::invalid_request),
                    "Batch requests are not supported");
        return;
    }

    // Step 3: map the JSON fields onto a request object. "jsonrpc" and
    // "method" are always read; "id" and "params" only when present.
    request rpc;
    try {
        rpc.jsonrpc = payload["jsonrpc"];
        rpc.method = payload["method"];

        if (payload.contains("id")) {
            rpc.id = payload["id"];
        }

        if (payload.contains("params")) {
            rpc.params = payload["params"];
        }
    } catch (const json::exception& shape_failure) {
        reply_error(static_cast<int>(error_code::invalid_request),
                    "Invalid request: " + std::string(shape_failure.what()));
        return;
    }

    // Step 4: dispatch and send back whatever process_request() produced.
    const json reply = process_request(rpc);
    res.set_content(reply.dump(), "application/json");
}
|
|
|
|
/**
 * @brief Dispatch a parsed JSON-RPC request and build its response.
 *
 * - Notifications (req.is_notification()) are handed to a detached thread
 *   that looks up and runs the registered notification handler; an empty
 *   JSON object is returned immediately.
 * - "initialize" is answered by handle_initialize().
 * - Any other method is looked up in method_handlers_. The handler is
 *   invoked while mutex_ is held, so a handler must not call back into
 *   server functions that lock mutex_.
 *
 * Exceptions never escape: mcp_exception becomes an error response with the
 * exception's code, std::exception and any other exception become an
 * error_code::internal_error response. Inside the notification thread all
 * exceptions are logged to std::cerr and swallowed, because an exception
 * leaving a thread's entry function calls std::terminate.
 *
 * @param req Parsed request.
 * @return JSON response (success or error), or an empty object for notifications.
 */
json server::process_request(const request& req) {
    // Check if it's a notification
    if (req.is_notification()) {
        // Process notification asynchronously.
        // NOTE(review): the detached thread captures `this`; if the server is
        // destroyed while a notification is still being processed the thread
        // would use a dangling pointer. Fixing that needs thread tracking in
        // the class, which is outside this function.
        std::thread([this, req]() {
            try {
                std::lock_guard<std::mutex> lock(mutex_);
                auto it = notification_handlers_.find(req.method);
                if (it != notification_handlers_.end()) {
                    it->second(req.params);
                }
            } catch (const std::exception& e) {
                // Log error but don't send response
                std::cerr << "Error processing notification: " << e.what() << std::endl;
            } catch (...) {
                // Non-std exceptions must not leave the thread entry point.
                std::cerr << "Error processing notification: unknown exception" << std::endl;
            }
        }).detach();

        // No response for notifications
        return json::object();
    }

    // Handle method call
    try {
        // Special case for initialize (handled before mutex_ is taken)
        if (req.method == "initialize") {
            return handle_initialize(req);
        }

        // Look for registered method handler
        std::lock_guard<std::mutex> lock(mutex_);
        auto it = method_handlers_.find(req.method);
        if (it != method_handlers_.end()) {
            // Call the handler
            json result = it->second(req.params);

            // Create success response
            return response::create_success(req.id, result).to_json();
        }

        // Method not found
        return response::create_error(
            req.id,
            error_code::method_not_found,
            "Method not found: " + req.method
        ).to_json();
    } catch (const mcp_exception& e) {
        // MCP exception: propagate the handler-chosen error code
        return response::create_error(
            req.id,
            e.code(),
            e.what()
        ).to_json();
    } catch (const std::exception& e) {
        // Generic exception
        return response::create_error(
            req.id,
            error_code::internal_error,
            "Internal error: " + std::string(e.what())
        ).to_json();
    } catch (...) {
        // Anything else a handler may throw; keep it from reaching the HTTP layer
        return response::create_error(
            req.id,
            error_code::internal_error,
            "Internal error: unknown exception"
        ).to_json();
    }
}
|
|
|
|
/**
 * @brief Answer the "initialize" request: negotiate the protocol version and
 *        return server info and capabilities.
 *
 * The requested version is read from the "protocolVersion" entry of the
 * request params. For backward compatibility it is also accepted one level
 * deeper, under a nested "params" object, which is the only location this
 * function used to read.
 *
 * Error responses (error_code::invalid_params):
 *  - "protocolVersion" is missing or not a string;
 *  - the requested version differs from MCP_VERSION (the response data
 *    carries both the supported and the requested version).
 *
 * name_, version_ and capabilities_ are copied under mutex_, the same lock
 * their setters use. process_request() calls this function before it takes
 * mutex_, so locking here cannot self-deadlock on that path.
 *
 * @param req The initialize request.
 * @return JSON success response with protocolVersion, capabilities and
 *         serverInfo, or a JSON error response.
 */
json server::handle_initialize(const request& req) {
    const json& params = req.params;

    // Version negotiation.
    // contains() is checked before every operator[]: indexing a const json
    // with a missing key is undefined behaviour.
    const json* version_field = nullptr;
    if (params.contains("protocolVersion")) {
        version_field = &params["protocolVersion"];
    } else if (params.contains("params") && params["params"].contains("protocolVersion")) {
        // Legacy nested layout, kept so existing clients continue to work.
        version_field = &params["params"]["protocolVersion"];
    }

    if (version_field == nullptr || !version_field->is_string()) {
        return response::create_error(
            req.id,
            error_code::invalid_params,
            "Expected string for 'protocolVersion' parameter"
        ).to_json();
    }

    const std::string requested_version = version_field->get<std::string>();

    if (requested_version != MCP_VERSION) {
        return response::create_error(
            req.id,
            error_code::invalid_params,
            "Unsupported protocol version",
            {
                {"supported", MCP_VERSION},
                {"requested", requested_version}
            }
        ).to_json();
    }

    // Extract client info. Only string values are accepted so a malformed
    // clientInfo cannot turn a valid initialize into an internal error.
    // These values currently feed only the disabled log line below.
    std::string client_name = "UnknownClient";
    std::string client_version = "UnknownVersion";

    if (params.contains("clientInfo")) {
        const json& client_info = params["clientInfo"];
        if (client_info.contains("name") && client_info["name"].is_string()) {
            client_name = client_info["name"].get<std::string>();
        }
        if (client_info.contains("version") && client_info["version"].is_string()) {
            client_version = client_info["version"].get<std::string>();
        }
    }

    // Log connection
    // std::cout << "Client connected: " << client_name << " " << client_version << std::endl;

    // Snapshot the mutable server state under the lock its setters use.
    json server_info;
    json capabilities;
    {
        std::lock_guard<std::mutex> lock(mutex_);
        server_info = {
            {"name", name_},
            {"version", version_}
        };
        capabilities = capabilities_;
    }

    // Return server info and capabilities
    json result = {
        {"protocolVersion", MCP_VERSION},
        {"capabilities", capabilities},
        {"serverInfo", server_info}
    };

    return response::create_success(req.id, result).to_json();
}
|
|
|
|
} // namespace mcp
|