/** * @file mcp_server.h * @brief MCP Server implementation * * This file implements the server-side functionality for the Model Context Protocol. * Follows the 2024-11-05 basic protocol specification. */ #ifndef MCP_SERVER_H #define MCP_SERVER_H #include "mcp_message.h" #include "mcp_resource.h" #include "mcp_tool.h" #include "mcp_thread_pool.h" #include "mcp_logger.h" // Include the HTTP library #include "httplib.h" #include #include #include #include #include #include #include #include #include #include namespace mcp { class event_dispatcher { public: event_dispatcher() = default; bool wait_event(httplib::DataSink* sink, const std::chrono::milliseconds& timeout = std::chrono::milliseconds(30000)) { if (!sink) { return false; } std::unique_lock lk(m_); // 如果连接已关闭,返回false if (closed_) { return false; } int id = id_; // 使用超时等待 bool result = cv_.wait_for(lk, timeout, [&] { return cid_ == id || closed_; }); // 如果连接已关闭或等待超时,返回false if (closed_) { return false; } if (!result) { std::cerr << "等待事件超时" << std::endl; return false; } // 写入数据 try { bool write_result = sink->write(message_.data(), message_.size()); if (!write_result) { std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl; closed_ = true; return false; } return true; } catch (const std::exception& e) { std::cerr << "写入事件数据失败: " << e.what() << std::endl; closed_ = true; return false; } } bool send_event(const std::string& message) { std::lock_guard lk(m_); // 如果连接已关闭,返回失败 if (closed_) { return false; } cid_ = id_++; message_ = message; cv_.notify_all(); return true; } void close() { std::lock_guard lk(m_); closed_ = true; cv_.notify_all(); } bool is_closed() const { std::lock_guard lk(m_); return closed_; } private: mutable std::mutex m_; std::condition_variable cv_; std::atomic id_{0}; std::atomic cid_{-1}; std::string message_; bool closed_ = false; }; /** * @class server * @brief Main MCP server class * * The server class implements an HTTP server that handles JSON-RPC requests * according to the Model Context Protocol specification. */ class server { public: /** * @brief Constructor * @param host The host to bind to (e.g., "localhost", "0.0.0.0") * @param port The port to listen on */ server(const std::string& host = "localhost", int port = 8080, const std::string& sse_endpoint = "/sse", const std::string& msg_endpoint_prefix = "/message"); /** * @brief Destructor */ ~server(); /** * @brief Start the server * @param blocking If true, this call blocks until the server stops * @return True if the server started successfully */ bool start(bool blocking = true); /** * @brief Stop the server */ void stop(); /** * @brief Check if the server is running * @return True if the server is running */ bool is_running() const; /** * @brief Set server information * @param name The name of the server * @param version The version of the server */ void set_server_info(const std::string& name, const std::string& version); /** * @brief Set server capabilities * @param capabilities The capabilities of the server */ void set_capabilities(const json& capabilities); /** * @brief Register a method handler * @param method The method name * @param handler The function to call when the method is invoked */ void register_method(const std::string& method, std::function handler); /** * @brief Register a notification handler * @param method The notification method name * @param handler The function to call when the notification is received */ void register_notification(const std::string& method, std::function handler); /** * @brief Register a resource * @param path The path to mount the resource at * @param resource The resource to register */ void register_resource(const std::string& path, std::shared_ptr resource); /** * @brief Register a tool * @param tool The tool to register * @param handler The function to call when the tool is invoked */ void register_tool(const tool& tool, tool_handler handler); /** * @brief Get the list of available tools * @return JSON array of available tools */ std::vector get_tools() const; /** * @brief Set authentication handler * @param handler Function that takes a token and returns true if valid */ void set_auth_handler(std::function handler); /** * @brief Send a request to a client * @param session_id The session ID of the client * @param method The method to call * @param params The parameters to pass * * This method will only send requests other than ping and logging * after the client has sent the initialized notification. */ void send_request(const std::string& session_id, const std::string& method, const json& params = json::object()); /** * @brief 打印服务器状态 * * 打印当前服务器的状态,包括活跃的会话、注册的方法等 */ void print_status() const; private: std::string host_; int port_; std::string name_; std::string version_; json capabilities_; // The HTTP server std::unique_ptr http_server_; // Server thread (for non-blocking mode) std::unique_ptr server_thread_; // SSE thread std::map> sse_threads_; // Event dispatcher for server-sent events event_dispatcher sse_dispatcher_; // Session-specific event dispatchers std::map> session_dispatchers_; // Server-sent events endpoint std::string sse_endpoint_; std::string msg_endpoint_; // Method handlers std::map> method_handlers_; // Notification handlers std::map> notification_handlers_; // Resources map (path -> resource) std::map> resources_; // Tools map (name -> handler) std::map> tools_; // Authentication handler std::function auth_handler_; // Mutex for thread safety mutable std::mutex mutex_; // Running flag bool running_ = false; // 线程池 thread_pool thread_pool_; // Map to track session initialization status (session_id -> initialized) std::map session_initialized_; // Handle SSE requests void handle_sse(const httplib::Request& req, httplib::Response& res); // Handle incoming JSON-RPC requests void handle_jsonrpc(const httplib::Request& req, httplib::Response& res); // Process a JSON-RPC request json process_request(const request& req, const std::string& session_id); // Handle initialization request json handle_initialize(const request& req); // Check if a session is initialized bool is_session_initialized(const std::string& session_id) const; // Set session initialization status void set_session_initialized(const std::string& session_id, bool initialized); // Generate a random session ID std::string generate_session_id() const; // 辅助函数:创建异步方法处理器 template std::function(const json&)> make_async_handler(F&& handler) { return [handler = std::forward(handler)](const json& params) -> std::future { return std::async(std::launch::async, [handler, params]() -> json { return handler(params); }); }; } // 辅助类,用于简化锁的管理 class auto_lock { public: explicit auto_lock(std::mutex& mutex) : lock_(mutex) {} private: std::lock_guard lock_; }; // 获取自动锁 auto_lock get_lock() const { return auto_lock(mutex_); } }; } // namespace mcp #endif // MCP_SERVER_H