cpp-mcp/include/mcp_server.h

330 lines
9.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/**
* @file mcp_server.h
* @brief MCP Server implementation
*
* This file implements the server-side functionality for the Model Context Protocol.
* Follows the 2024-11-05 basic protocol specification.
*/
#ifndef MCP_SERVER_H
#define MCP_SERVER_H
#include "mcp_message.h"
#include "mcp_resource.h"
#include "mcp_tool.h"
#include "mcp_thread_pool.h"
#include "mcp_logger.h"
// Include the HTTP library
#include "httplib.h"
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
namespace mcp {
/**
 * @class event_dispatcher
 * @brief Hands server-sent-event payloads from producer threads to the code
 *        that writes them to an httplib::DataSink.
 *
 * Producers call send_event(); the writing side calls wait_event(), which
 * blocks until at least one payload is pending (or the dispatcher is closed,
 * or the wait times out) and then writes every pending payload, oldest first,
 * to the supplied sink.
 *
 * Payloads are queued rather than stored in a single "latest message" slot, so
 * several send_event() calls made before the waiter wakes up are all delivered
 * instead of overwriting each other. The sink write is performed without
 * holding the internal mutex so a slow client cannot block producers.
 *
 * NOTE(review): designed for one consumer per dispatcher; with several
 * concurrent wait_event() callers each payload is written by only one of them.
 * Pending payloads are kept until consumed or until close() is called.
 */
class event_dispatcher {
public:
    event_dispatcher() = default;

    /**
     * @brief Wait for pending events and write them to @p sink.
     * @param sink    Destination for the event data; a null pointer yields false.
     * @param timeout Maximum time to wait for an event (default 30 s).
     * @return true if all pending events were written; false if @p sink is
     *         null, the dispatcher is (or becomes) closed, the wait timed out,
     *         or a write failed. A failed or throwing write closes the
     *         dispatcher.
     */
    bool wait_event(httplib::DataSink* sink, const std::chrono::milliseconds& timeout = std::chrono::milliseconds(30000)) {
        if (!sink) {
            return false;
        }

        std::vector<std::string> batch;
        {
            std::unique_lock<std::mutex> lk(m_);
            // Closed connections never deliver anything.
            if (closed_) {
                return false;
            }
            const bool ready = cv_.wait_for(lk, timeout, [this] {
                return !pending_.empty() || closed_;
            });
            if (closed_) {
                return false;
            }
            if (!ready) {
                std::cerr << "等待事件超时" << std::endl;
                return false;
            }
            // Take ownership of everything queued so far; producers can keep
            // enqueuing while we write.
            batch.swap(pending_);
        }

        // Write without holding m_: a slow or stalled client must not block
        // threads calling send_event().
        try {
            for (const std::string& message : batch) {
                if (!sink->write(message.data(), message.size())) {
                    std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl;
                    close();
                    return false;
                }
            }
            return true;
        } catch (const std::exception& e) {
            std::cerr << "写入事件数据失败: " << e.what() << std::endl;
            close();
            return false;
        }
    }

    /**
     * @brief Queue @p message for the next wait_event() call and wake waiters.
     * @return false if the dispatcher is closed (the message is discarded).
     */
    bool send_event(const std::string& message) {
        {
            std::lock_guard<std::mutex> lk(m_);
            if (closed_) {
                return false;
            }
            pending_.push_back(message);
        }
        cv_.notify_all();
        return true;
    }

    /**
     * @brief Mark the dispatcher closed and wake every waiter.
     *
     * Subsequent send_event() and wait_event() calls return false.
     */
    void close() {
        std::lock_guard<std::mutex> lk(m_);
        closed_ = true;
        cv_.notify_all();
    }

    /// @return true once close() has been called or a write has failed.
    bool is_closed() const {
        std::lock_guard<std::mutex> lk(m_);
        return closed_;
    }

private:
    mutable std::mutex m_;
    std::condition_variable cv_;
    // Payloads not yet written to a sink, oldest first. Guarded by m_.
    std::vector<std::string> pending_;
    // Guarded by m_.
    bool closed_ = false;
};
/**
* @class server
* @brief Main MCP server class
*
* The server class implements an HTTP server that handles JSON-RPC requests
* according to the Model Context Protocol specification.
*/
class server {
public:
/**
* @brief Constructor
* @param host The host to bind to (e.g., "localhost", "0.0.0.0")
* @param port The port to listen on
*/
server(const std::string& host = "localhost", int port = 8080,
const std::string& sse_endpoint = "/sse",
const std::string& msg_endpoint_prefix = "/message");
/**
* @brief Destructor
*/
~server();
/**
* @brief Start the server
* @param blocking If true, this call blocks until the server stops
* @return True if the server started successfully
*/
bool start(bool blocking = true);
/**
* @brief Stop the server
*/
void stop();
/**
* @brief Check if the server is running
* @return True if the server is running
*/
bool is_running() const;
/**
* @brief Set server information
* @param name The name of the server
* @param version The version of the server
*/
void set_server_info(const std::string& name, const std::string& version);
/**
* @brief Set server capabilities
* @param capabilities The capabilities of the server
*/
void set_capabilities(const json& capabilities);
/**
* @brief Register a method handler
* @param method The method name
* @param handler The function to call when the method is invoked
*/
void register_method(const std::string& method, std::function<json(const json&)> handler);
/**
* @brief Register a notification handler
* @param method The notification method name
* @param handler The function to call when the notification is received
*/
void register_notification(const std::string& method, std::function<void(const json&)> handler);
/**
* @brief Register a resource
* @param path The path to mount the resource at
* @param resource The resource to register
*/
void register_resource(const std::string& path, std::shared_ptr<resource> resource);
/**
* @brief Register a tool
* @param tool The tool to register
* @param handler The function to call when the tool is invoked
*/
void register_tool(const tool& tool, tool_handler handler);
/**
* @brief Get the list of available tools
* @return JSON array of available tools
*/
std::vector<tool> get_tools() const;
/**
* @brief Set authentication handler
* @param handler Function that takes a token and returns true if valid
*/
void set_auth_handler(std::function<bool(const std::string&)> handler);
/**
* @brief Send a request to a client
* @param session_id The session ID of the client
* @param method The method to call
* @param params The parameters to pass
*
* This method will only send requests other than ping and logging
* after the client has sent the initialized notification.
*/
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object());
/**
* @brief 打印服务器状态
*
* 打印当前服务器的状态,包括活跃的会话、注册的方法等
*/
void print_status() const;
private:
std::string host_;
int port_;
std::string name_;
std::string version_;
json capabilities_;
// The HTTP server
std::unique_ptr<httplib::Server> http_server_;
// Server thread (for non-blocking mode)
std::unique_ptr<std::thread> server_thread_;
// SSE thread
std::map<std::string, std::unique_ptr<std::thread>> sse_threads_;
// Event dispatcher for server-sent events
event_dispatcher sse_dispatcher_;
// Session-specific event dispatchers
std::map<std::string, std::shared_ptr<event_dispatcher>> session_dispatchers_;
// Server-sent events endpoint
std::string sse_endpoint_;
std::string msg_endpoint_;
// Method handlers
std::map<std::string, std::function<json(const json&)>> method_handlers_;
// Notification handlers
std::map<std::string, std::function<void(const json&)>> notification_handlers_;
// Resources map (path -> resource)
std::map<std::string, std::shared_ptr<resource>> resources_;
// Tools map (name -> handler)
std::map<std::string, std::pair<tool, tool_handler>> tools_;
// Authentication handler
std::function<bool(const std::string&)> auth_handler_;
// Mutex for thread safety
mutable std::mutex mutex_;
// Running flag
bool running_ = false;
// 线程池
thread_pool thread_pool_;
// Map to track session initialization status (session_id -> initialized)
std::map<std::string, bool> session_initialized_;
// Handle SSE requests
void handle_sse(const httplib::Request& req, httplib::Response& res);
// Handle incoming JSON-RPC requests
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
// Process a JSON-RPC request
json process_request(const request& req, const std::string& session_id);
// Handle initialization request
json handle_initialize(const request& req);
// Check if a session is initialized
bool is_session_initialized(const std::string& session_id) const;
// Set session initialization status
void set_session_initialized(const std::string& session_id, bool initialized);
// Generate a random session ID
std::string generate_session_id() const;
// 辅助函数:创建异步方法处理器
template<typename F>
std::function<std::future<json>(const json&)> make_async_handler(F&& handler) {
return [handler = std::forward<F>(handler)](const json& params) -> std::future<json> {
return std::async(std::launch::async, [handler, params]() -> json {
return handler(params);
});
};
}
// 辅助类,用于简化锁的管理
class auto_lock {
public:
explicit auto_lock(std::mutex& mutex) : lock_(mutex) {}
private:
std::lock_guard<std::mutex> lock_;
};
// 获取自动锁
auto_lock get_lock() const {
return auto_lock(mutex_);
}
};
} // namespace mcp
#endif // MCP_SERVER_H