330 lines
9.1 KiB
C++
330 lines
9.1 KiB
C++
/**
 * @file mcp_server.h
 * @brief MCP server interface and SSE event dispatcher
 *
 * This header declares the server-side functionality of the Model Context Protocol (MCP).
 * It follows the 2024-11-05 revision of the MCP base protocol specification.
 */
|
||
|
||
#ifndef MCP_SERVER_H
|
||
#define MCP_SERVER_H
|
||
|
||
#include "mcp_message.h"
#include "mcp_resource.h"
#include "mcp_tool.h"
#include "mcp_thread_pool.h"
#include "mcp_logger.h"

// Include the HTTP library
#include "httplib.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
|
||
|
||
|
||
namespace mcp {
|
||
|
||
class event_dispatcher {
|
||
public:
|
||
event_dispatcher() = default;
|
||
|
||
bool wait_event(httplib::DataSink* sink, const std::chrono::milliseconds& timeout = std::chrono::milliseconds(30000)) {
|
||
if (!sink) {
|
||
return false;
|
||
}
|
||
|
||
std::unique_lock<std::mutex> lk(m_);
|
||
|
||
// 如果连接已关闭,返回false
|
||
if (closed_) {
|
||
return false;
|
||
}
|
||
|
||
int id = id_;
|
||
|
||
// 使用超时等待
|
||
bool result = cv_.wait_for(lk, timeout, [&] {
|
||
return cid_ == id || closed_;
|
||
});
|
||
|
||
// 如果连接已关闭或等待超时,返回false
|
||
if (closed_) {
|
||
return false;
|
||
}
|
||
|
||
if (!result) {
|
||
std::cerr << "等待事件超时" << std::endl;
|
||
return false;
|
||
}
|
||
|
||
// 写入数据
|
||
try {
|
||
bool write_result = sink->write(message_.data(), message_.size());
|
||
if (!write_result) {
|
||
std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl;
|
||
closed_ = true;
|
||
return false;
|
||
}
|
||
return true;
|
||
} catch (const std::exception& e) {
|
||
std::cerr << "写入事件数据失败: " << e.what() << std::endl;
|
||
closed_ = true;
|
||
return false;
|
||
}
|
||
}
|
||
|
||
bool send_event(const std::string& message) {
|
||
std::lock_guard<std::mutex> lk(m_);
|
||
|
||
// 如果连接已关闭,返回失败
|
||
if (closed_) {
|
||
return false;
|
||
}
|
||
|
||
cid_ = id_++;
|
||
message_ = message;
|
||
cv_.notify_all();
|
||
return true;
|
||
}
|
||
|
||
void close() {
|
||
std::lock_guard<std::mutex> lk(m_);
|
||
closed_ = true;
|
||
cv_.notify_all();
|
||
}
|
||
|
||
bool is_closed() const {
|
||
std::lock_guard<std::mutex> lk(m_);
|
||
return closed_;
|
||
}
|
||
|
||
private:
|
||
mutable std::mutex m_;
|
||
std::condition_variable cv_;
|
||
std::atomic<int> id_{0};
|
||
std::atomic<int> cid_{-1};
|
||
std::string message_;
|
||
bool closed_ = false;
|
||
};
|
||
|
||
/**
|
||
* @class server
|
||
* @brief Main MCP server class
|
||
*
|
||
* The server class implements an HTTP server that handles JSON-RPC requests
|
||
* according to the Model Context Protocol specification.
|
||
*/
|
||
class server {
|
||
public:
|
||
/**
|
||
* @brief Constructor
|
||
* @param host The host to bind to (e.g., "localhost", "0.0.0.0")
|
||
* @param port The port to listen on
|
||
*/
|
||
server(const std::string& host = "localhost", int port = 8080,
|
||
const std::string& sse_endpoint = "/sse",
|
||
const std::string& msg_endpoint_prefix = "/message");
|
||
|
||
/**
|
||
* @brief Destructor
|
||
*/
|
||
~server();
|
||
|
||
/**
|
||
* @brief Start the server
|
||
* @param blocking If true, this call blocks until the server stops
|
||
* @return True if the server started successfully
|
||
*/
|
||
bool start(bool blocking = true);
|
||
|
||
/**
|
||
* @brief Stop the server
|
||
*/
|
||
void stop();
|
||
|
||
/**
|
||
* @brief Check if the server is running
|
||
* @return True if the server is running
|
||
*/
|
||
bool is_running() const;
|
||
|
||
/**
|
||
* @brief Set server information
|
||
* @param name The name of the server
|
||
* @param version The version of the server
|
||
*/
|
||
void set_server_info(const std::string& name, const std::string& version);
|
||
|
||
/**
|
||
* @brief Set server capabilities
|
||
* @param capabilities The capabilities of the server
|
||
*/
|
||
void set_capabilities(const json& capabilities);
|
||
|
||
/**
|
||
* @brief Register a method handler
|
||
* @param method The method name
|
||
* @param handler The function to call when the method is invoked
|
||
*/
|
||
void register_method(const std::string& method, std::function<json(const json&)> handler);
|
||
|
||
/**
|
||
* @brief Register a notification handler
|
||
* @param method The notification method name
|
||
* @param handler The function to call when the notification is received
|
||
*/
|
||
void register_notification(const std::string& method, std::function<void(const json&)> handler);
|
||
|
||
/**
|
||
* @brief Register a resource
|
||
* @param path The path to mount the resource at
|
||
* @param resource The resource to register
|
||
*/
|
||
void register_resource(const std::string& path, std::shared_ptr<resource> resource);
|
||
|
||
/**
|
||
* @brief Register a tool
|
||
* @param tool The tool to register
|
||
* @param handler The function to call when the tool is invoked
|
||
*/
|
||
void register_tool(const tool& tool, tool_handler handler);
|
||
|
||
/**
|
||
* @brief Get the list of available tools
|
||
* @return JSON array of available tools
|
||
*/
|
||
std::vector<tool> get_tools() const;
|
||
|
||
/**
|
||
* @brief Set authentication handler
|
||
* @param handler Function that takes a token and returns true if valid
|
||
*/
|
||
void set_auth_handler(std::function<bool(const std::string&)> handler);
|
||
|
||
/**
|
||
* @brief Send a request to a client
|
||
* @param session_id The session ID of the client
|
||
* @param method The method to call
|
||
* @param params The parameters to pass
|
||
*
|
||
* This method will only send requests other than ping and logging
|
||
* after the client has sent the initialized notification.
|
||
*/
|
||
void send_request(const std::string& session_id, const std::string& method, const json& params = json::object());
|
||
|
||
/**
|
||
* @brief 打印服务器状态
|
||
*
|
||
* 打印当前服务器的状态,包括活跃的会话、注册的方法等
|
||
*/
|
||
void print_status() const;
|
||
|
||
private:
|
||
std::string host_;
|
||
int port_;
|
||
std::string name_;
|
||
std::string version_;
|
||
json capabilities_;
|
||
|
||
// The HTTP server
|
||
std::unique_ptr<httplib::Server> http_server_;
|
||
|
||
// Server thread (for non-blocking mode)
|
||
std::unique_ptr<std::thread> server_thread_;
|
||
|
||
// SSE thread
|
||
std::map<std::string, std::unique_ptr<std::thread>> sse_threads_;
|
||
|
||
// Event dispatcher for server-sent events
|
||
event_dispatcher sse_dispatcher_;
|
||
|
||
// Session-specific event dispatchers
|
||
std::map<std::string, std::shared_ptr<event_dispatcher>> session_dispatchers_;
|
||
|
||
// Server-sent events endpoint
|
||
std::string sse_endpoint_;
|
||
std::string msg_endpoint_;
|
||
|
||
// Method handlers
|
||
std::map<std::string, std::function<json(const json&)>> method_handlers_;
|
||
|
||
// Notification handlers
|
||
std::map<std::string, std::function<void(const json&)>> notification_handlers_;
|
||
|
||
// Resources map (path -> resource)
|
||
std::map<std::string, std::shared_ptr<resource>> resources_;
|
||
|
||
// Tools map (name -> handler)
|
||
std::map<std::string, std::pair<tool, tool_handler>> tools_;
|
||
|
||
// Authentication handler
|
||
std::function<bool(const std::string&)> auth_handler_;
|
||
|
||
// Mutex for thread safety
|
||
mutable std::mutex mutex_;
|
||
|
||
// Running flag
|
||
bool running_ = false;
|
||
|
||
// 线程池
|
||
thread_pool thread_pool_;
|
||
|
||
// Map to track session initialization status (session_id -> initialized)
|
||
std::map<std::string, bool> session_initialized_;
|
||
|
||
// Handle SSE requests
|
||
void handle_sse(const httplib::Request& req, httplib::Response& res);
|
||
|
||
// Handle incoming JSON-RPC requests
|
||
void handle_jsonrpc(const httplib::Request& req, httplib::Response& res);
|
||
|
||
// Process a JSON-RPC request
|
||
json process_request(const request& req, const std::string& session_id);
|
||
|
||
// Handle initialization request
|
||
json handle_initialize(const request& req);
|
||
|
||
// Check if a session is initialized
|
||
bool is_session_initialized(const std::string& session_id) const;
|
||
|
||
// Set session initialization status
|
||
void set_session_initialized(const std::string& session_id, bool initialized);
|
||
|
||
// Generate a random session ID
|
||
std::string generate_session_id() const;
|
||
|
||
// 辅助函数:创建异步方法处理器
|
||
template<typename F>
|
||
std::function<std::future<json>(const json&)> make_async_handler(F&& handler) {
|
||
return [handler = std::forward<F>(handler)](const json& params) -> std::future<json> {
|
||
return std::async(std::launch::async, [handler, params]() -> json {
|
||
return handler(params);
|
||
});
|
||
};
|
||
}
|
||
|
||
// 辅助类,用于简化锁的管理
|
||
class auto_lock {
|
||
public:
|
||
explicit auto_lock(std::mutex& mutex) : lock_(mutex) {}
|
||
private:
|
||
std::lock_guard<std::mutex> lock_;
|
||
};
|
||
|
||
// 获取自动锁
|
||
auto_lock get_lock() const {
|
||
return auto_lock(mutex_);
|
||
}
|
||
};
|
||
|
||
} // namespace mcp
|
||
|
||
#endif // MCP_SERVER_H
|