// cpp-mcp/src/mcp_server.cpp
//
// 340 lines
// 10 KiB
// C++
// Raw Normal View History
//
// 2025-03-08 01:50:39 +08:00
/**
* @file mcp_server.cpp
* @brief Implementation of the MCP server
*/
#include "mcp_server.h"

#include <mutex>
#include <utility>
namespace mcp {
/**
 * @brief Construct a server bound (once started) to @p host : @p port.
 *
 * The display name starts as "MCP Server", CORS headers are disabled, and the
 * allowed origin defaults to "*". Only the httplib::Server instance is created
 * here; the built-in routes are attached by start().
 *
 * @param host Address later passed to httplib::Server::listen().
 * @param port Port later passed to httplib::Server::listen().
 */
server::server(const std::string& host, int port)
    : host_(host),
      port_(port),
      name_("MCP Server"),
      cors_enabled_(false),
      allowed_origins_("*") {
    // Owned HTTP backend used by start()/stop() and the route registrations.
    http_server_ = std::make_unique<httplib::Server>();
}
/**
 * @brief Destroy the server after shutting it down through stop().
 */
server::~server() {
    this->stop();
}
bool server::start(bool blocking) {
if (running_) {
return true; // Already running
}
// Mount handlers for common paths
// Root handler - returns server info
http_server_->Get("/", [this](const httplib::Request& req, httplib::Response& res) {
set_cors_headers(res);
json info = get_server_info();
res.set_content(info.dump(), "application/json");
});
// Tools listing
http_server_->Get("/tools", [this](const httplib::Request& req, httplib::Response& res) {
set_cors_headers(res);
json tools_json = get_tools();
res.set_content(tools_json.dump(), "application/json");
});
// General request handler for all other paths
http_server_->set_mount_point("/", "");
http_server_->Get(".*", [this](const httplib::Request& req, httplib::Response& res) {
this->handle_request(req, res);
});
http_server_->Post(".*", [this](const httplib::Request& req, httplib::Response& res) {
this->handle_request(req, res);
});
http_server_->Put(".*", [this](const httplib::Request& req, httplib::Response& res) {
this->handle_request(req, res);
});
http_server_->Delete(".*", [this](const httplib::Request& req, httplib::Response& res) {
this->handle_request(req, res);
});
http_server_->Patch(".*", [this](const httplib::Request& req, httplib::Response& res) {
this->handle_request(req, res);
});
http_server_->Options(".*", [this](const httplib::Request& req, httplib::Response& res) {
set_cors_headers(res);
res.status = 204; // No content
});
// Start the server
if (blocking) {
running_ = true;
if (!http_server_->listen(host_.c_str(), port_)) {
running_ = false;
std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
return false;
}
return true;
} else {
// Start in a separate thread
server_thread_ = std::make_unique<std::thread>([this]() {
if (!http_server_->listen(host_.c_str(), port_)) {
std::cerr << "Failed to start server on " << host_ << ":" << port_ << std::endl;
running_ = false;
return;
}
});
running_ = true;
return true;
}
}
/**
 * @brief Stop the HTTP server and join the background listener thread, if any.
 *
 * Safe to call repeatedly and on a server that was never started: this is
 * also what the destructor runs. The listener thread is joined even when
 * running_ is already false (e.g. its listen() call failed), because a
 * std::thread that is still joinable when destroyed calls std::terminate.
 */
void server::stop() {
    if (http_server_) {
        // httplib::Server::stop() does nothing when the server is not listening.
        http_server_->stop();
    }
    if (server_thread_ && server_thread_->joinable()) {
        server_thread_->join();
    }
    server_thread_.reset();
    running_ = false;
}
/**
 * @brief Report the running_ flag maintained by start() and stop().
 * @return true while the server is marked as running.
 */
bool server::is_running() const {
    return static_cast<bool>(running_);
}
/**
 * @brief Register (or replace) the resource that serves requests under @p path.
 *
 * handle_request() picks the resource whose path equals the request path or,
 * failing that, the one with the longest matching path prefix.
 *
 * @param path     Path key the resource is stored under.
 * @param resource Resource instance; an existing entry for @p path is replaced.
 */
void server::register_resource(const std::string& path, std::shared_ptr<resource> resource) {
    std::lock_guard<std::mutex> lock(mutex_);
    // The shared_ptr is a by-value sink: move it in rather than paying for a
    // second atomic reference-count increment.
    resources_[path] = std::move(resource);
}
/**
 * @brief Register a tool and expose it as POST "/tools/<tool.name>".
 *
 * The endpoint parses the request body as JSON (an empty body yields a null
 * json value), invokes the stored handler with it, and returns the handler's
 * result as "application/json". Errors are returned as
 * {"error": {"code": N, "message": ...}} with status 404 (tool missing),
 * 400 (invalid JSON) or 500 (handler threw a std::exception).
 *
 * NOTE(review): if called after start(), the POST ".*" catch-all registered
 * there is matched first, so this route would be unreachable.
 *
 * @param tool    Tool description; stored together with its handler in tools_.
 * @param handler Callable executed for each POST to the tool's endpoint.
 */
void server::register_tool(const tool& tool, tool_handler handler) {
    std::lock_guard<std::mutex> lock(mutex_);
    tools_[tool.name] = std::make_pair(tool, std::move(handler));

    // Register a POST endpoint for the tool
    http_server_->Post("/tools/" + tool.name, [this, tool_name = tool.name](
        const httplib::Request& req, httplib::Response& res) {
        set_cors_headers(res);

        // Look the tool up under the lock, but copy its handler out and release
        // the lock before running it: holding mutex_ during execution would
        // serialize all tool calls and deadlock any tool that calls back into
        // this server (get_tools(), get_server_info(), register_*, ...).
        std::unique_lock<std::mutex> lookup_lock(mutex_);
        auto it = tools_.find(tool_name);
        if (it == tools_.end()) {
            lookup_lock.unlock();
            res.status = 404;
            json error = {
                {"error", {
                    {"code", 404},
                    {"message", "Tool not found: " + tool_name}
                }}
            };
            res.set_content(error.dump(), "application/json");
            return;
        }
        const tool_handler run_tool = it->second.second;
        lookup_lock.unlock();

        // Parse the request body
        json params;
        try {
            if (!req.body.empty()) {
                params = json::parse(req.body);
            }
        } catch (const json::exception& e) {
            res.status = 400;
            json error = {
                {"error", {
                    {"code", 400},
                    {"message", "Invalid JSON: " + std::string(e.what())}
                }}
            };
            res.set_content(error.dump(), "application/json");
            return;
        }

        // Execute the tool (without holding mutex_)
        try {
            json result = run_tool(params);
            res.set_content(result.dump(), "application/json");
        } catch (const std::exception& e) {
            res.status = 500;
            json error = {
                {"error", {
                    {"code", 500},
                    {"message", "Tool execution error: " + std::string(e.what())}
                }}
            };
            res.set_content(error.dump(), "application/json");
        }
    });
}
/**
 * @brief Snapshot every registered tool description.
 * @return JSON array holding to_json() of each tool stored in tools_, read
 *         while mutex_ is held.
 */
json server::get_tools() const {
    std::lock_guard<std::mutex> guard(mutex_);
    json listing = json::array();
    for (const auto& entry : tools_) {
        const auto& descriptor = entry.second.first;
        listing.push_back(descriptor.to_json());
    }
    return listing;
}
/**
 * @brief Describe the server for the GET "/" route.
 *
 * @return JSON object with "name" (name_), "version" (MCP_VERSION),
 *         "resources" (one {"path", "type", "metadata"} object per registered
 *         resource, "type" being the resource type cast to int) and
 *         "tools_count" (number of registered tools). Built under mutex_.
 */
json server::get_server_info() const {
    std::lock_guard<std::mutex> guard(mutex_);

    json resource_list = json::array();
    for (const auto& entry : resources_) {
        const auto& handler = entry.second;
        json descriptor = {
            {"path", entry.first},
            {"type", static_cast<int>(handler->type())},
            {"metadata", handler->metadata()}
        };
        resource_list.push_back(descriptor);
    }

    json info = {
        {"name", name_},
        {"version", MCP_VERSION},
        {"resources", resource_list},
        {"tools_count", tools_.size()}
    };
    return info;
}
void server::set_cors(bool enable, const std::string& allowed_origins) {
cors_enabled_ = enable;
allowed_origins_ = allowed_origins;
}
void server::set_name(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_);
name_ = name;
}
/**
 * @brief Return true if resource path @p prefix selects request path @p path.
 *
 * The match must end on a path-segment boundary: "/api" selects "/api" and
 * "/api/x" but not "/apiary". A prefix that already ends in '/' (including the
 * root "/") selects any path starting with it; an empty prefix selects all.
 */
static bool matches_path_prefix(const std::string& prefix, const std::string& path) {
    if (path.compare(0, prefix.size(), prefix) != 0) {
        return false;
    }
    return prefix.empty()
        || path.size() == prefix.size()
        || prefix.back() == '/'
        || path[prefix.size()] == '/';
}

/**
 * @brief Dispatch an HTTP request to the best-matching registered resource.
 *
 * Resolution: an exact entry for the request path wins; otherwise the
 * registered path that is the longest segment-aligned prefix is used. Unless
 * the match is "/", the matched prefix is stripped and the remainder is made
 * to start with '/' before calling resource::handle_request(). An
 * mcp_exception maps to its own error code, any other std::exception to
 * internal_server_error, and no match to not_found.
 *
 * @param req HTTP request from httplib.
 * @param res HTTP response to fill (CORS headers first, then apply_response()).
 */
void server::handle_request(const httplib::Request& req, httplib::Response& res) {
    set_cors_headers(res);

    // Convert the httplib request to our internal request format
    request mcp_req = convert_request(req);

    // Find a resource that matches the path
    std::shared_ptr<resource> resource_ptr = nullptr;
    std::string matched_path;
    {
        std::lock_guard<std::mutex> lock(mutex_);
        // Look for an exact match first
        auto it = resources_.find(mcp_req.path);
        if (it != resources_.end()) {
            resource_ptr = it->second;
            matched_path = mcp_req.path;
        } else {
            // Fall back to the longest segment-aligned prefix match
            for (const auto& [path, candidate] : resources_) {
                if (!matches_path_prefix(path, mcp_req.path)) {
                    continue;
                }
                if (matched_path.empty() || path.length() > matched_path.length()) {
                    resource_ptr = candidate;
                    matched_path = path;
                }
            }
        }
    }

    response mcp_res;
    if (resource_ptr) {
        try {
            // Make the path relative to the resource's mount path
            if (!matched_path.empty() && matched_path != "/") {
                mcp_req.path = mcp_req.path.substr(matched_path.length());
                // Ensure the path starts with a slash
                if (mcp_req.path.empty() || mcp_req.path[0] != '/') {
                    mcp_req.path = "/" + mcp_req.path;
                }
            }
            mcp_res = resource_ptr->handle_request(mcp_req);
        } catch (const mcp_exception& e) {
            mcp_res.set_error(e.code(), e.what());
        } catch (const std::exception& e) {
            mcp_res.set_error(error_code::internal_server_error, e.what());
        }
    } else {
        // No resource found
        mcp_res.set_error(error_code::not_found, "Resource not found: " + mcp_req.path);
    }

    // Apply the response
    apply_response(mcp_res, res);
}
/**
 * @brief Translate an httplib request into the internal request type.
 *
 * Copies the method (via string_to_http_method), path and body. Headers and
 * query parameters are copied key by key; when a key occurs several times in
 * the source, the last value visited overwrites earlier ones.
 *
 * @param req Incoming httplib request.
 * @return The converted request.
 */
request server::convert_request(const httplib::Request& req) {
    request converted;
    converted.method = string_to_http_method(req.method);
    converted.path = req.path;
    converted.body = req.body;

    for (const auto& header : req.headers) {
        converted.headers[header.first] = header.second;
    }
    for (const auto& param : req.params) {
        converted.query_params[param.first] = param.second;
    }
    return converted;
}
/**
 * @brief Copy an internal response onto an httplib response.
 *
 * Transfers the status code, every header, and the body. If no Content-Type
 * header is present after that and the body is non-empty, one is guessed:
 * "application/json" when the body starts with '{' or '[', else "text/plain".
 *
 * @param mcp_res  Internal response produced by a resource.
 * @param http_res httplib response to populate.
 */
void server::apply_response(const response& mcp_res, httplib::Response& http_res) {
    http_res.status = mcp_res.status_code;
    for (const auto& header : mcp_res.headers) {
        http_res.set_header(header.first.c_str(), header.second.c_str());
    }
    http_res.body = mcp_res.body;

    const bool needs_content_type =
        !mcp_res.body.empty() && http_res.get_header_value("Content-Type").empty();
    if (!needs_content_type) {
        return;
    }

    const char first = mcp_res.body.front();
    const bool looks_like_json = (first == '{' || first == '[');
    http_res.set_header("Content-Type", looks_like_json ? "application/json" : "text/plain");
}
/**
 * @brief Add CORS response headers when CORS is enabled via set_cors().
 *
 * Always sends Allow-Origin (allowed_origins_), Allow-Methods and
 * Allow-Headers. Access-Control-Allow-Credentials is sent only for a concrete
 * origin: the CORS protocol rejects credentialed requests answered with the
 * wildcard origin "*", so pairing it with "*" (the default) is invalid.
 *
 * @param res Response that receives the headers.
 */
void server::set_cors_headers(httplib::Response& res) {
    if (!cors_enabled_) {
        return;
    }
    res.set_header("Access-Control-Allow-Origin", allowed_origins_.c_str());
    res.set_header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS");
    res.set_header("Access-Control-Allow-Headers", "Content-Type, Authorization");
    if (allowed_origins_ != "*") {
        res.set_header("Access-Control-Allow-Credentials", "true");
    }
}
} // namespace mcp