diff --git a/examples/server_example.cpp b/examples/server_example.cpp index f995ab0..9846cc6 100644 --- a/examples/server_example.cpp +++ b/examples/server_example.cpp @@ -123,7 +123,7 @@ int main() { // Create and configure server mcp::server server("localhost", 8089); - server.set_server_info("ExampleServer", "2024-11-05"); + server.set_server_info("ExampleServer", "1.0.0"); // Set server capabilities mcp::json capabilities = { diff --git a/include/mcp_server.h b/include/mcp_server.h index d0e7ef3..49e19e0 100644 --- a/include/mcp_server.h +++ b/include/mcp_server.h @@ -35,48 +35,58 @@ namespace mcp { class event_dispatcher { public: event_dispatcher() = default; + + ~event_dispatcher() { + close(); + } bool wait_event(httplib::DataSink* sink, const std::chrono::milliseconds& timeout = std::chrono::milliseconds(30000)) { if (!sink) { return false; } - std::unique_lock lk(m_); - - // 如果连接已关闭,返回false - if (closed_) { - return false; + std::string message_copy; + { + std::unique_lock lk(m_); + + // 如果连接已关闭,返回false + if (closed_) { + return false; + } + + int id = id_; + + // 使用超时等待 + bool result = cv_.wait_for(lk, timeout, [&] { + return cid_ == id || closed_; + }); + + // 如果连接已关闭或等待超时,返回false + if (closed_) { + return false; + } + + if (!result) { + std::cerr << "等待事件超时" << std::endl; + return false; + } + + // 复制消息,避免在锁外访问共享数据 + message_copy = message_; } - int id = id_; - - // 使用超时等待 - bool result = cv_.wait_for(lk, timeout, [&] { - return cid_ == id || closed_; - }); - - // 如果连接已关闭或等待超时,返回false - if (closed_) { - return false; - } - - if (!result) { - std::cerr << "等待事件超时" << std::endl; - return false; - } - - // 写入数据 + // 写入数据 - 在锁外进行,避免长时间持有锁 try { - bool write_result = sink->write(message_.data(), message_.size()); + bool write_result = sink->write(message_copy.data(), message_copy.size()); if (!write_result) { std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl; - closed_ = true; + close(); return false; } return true; } catch (const std::exception& e) { std::cerr << "写入事件数据失败: " << e.what() << std::endl; - closed_ = true; + close(); return false; } } @@ -97,8 +107,10 @@ public: void close() { std::lock_guard lk(m_); - closed_ = true; - cv_.notify_all(); + if (!closed_) { + closed_ = true; + cv_.notify_all(); + } } bool is_closed() const { @@ -290,7 +302,7 @@ private: json process_request(const request& req, const std::string& session_id); // Handle initialization request - json handle_initialize(const request& req); + json handle_initialize(const request& req, const std::string& session_id); // Check if a session is initialized bool is_session_initialized(const std::string& session_id) const; diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index c37c827..790d2a9 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -90,35 +90,50 @@ void server::stop() { running_ = false; // 关闭所有SSE连接 + std::vector session_ids; { std::lock_guard lock(mutex_); - for (auto& [session_id, dispatcher] : session_dispatchers_) { - try { - std::cerr << "关闭会话: " << session_id << std::endl; - dispatcher->close(); - } catch (const std::exception& e) { - std::cerr << "关闭会话时发生异常: " << session_id << ", " << e.what() << std::endl; - } + // 先收集所有会话ID + for (const auto& [session_id, _] : session_dispatchers_) { + session_ids.push_back(session_id); } } - // 等待所有SSE线程结束 + // 关闭每个会话的分发器 + for (const auto& session_id : session_ids) { + try { + std::cerr << "关闭会话: " << session_id << std::endl; + std::lock_guard lock(mutex_); + auto it = session_dispatchers_.find(session_id); + if (it != session_dispatchers_.end()) { + it->second->close(); + } + } catch (const std::exception& e) { + std::cerr << "关闭会话时发生异常: " << session_id << ", " << e.what() << std::endl; + } + } + + // 等待一段时间,让会话线程有机会自行清理 + std::cerr << "等待会话线程自行清理..." << std::endl; + std::this_thread::sleep_for(std::chrono::seconds(1)); + + // 清理剩余的线程 { std::lock_guard lock(mutex_); - for (auto it = sse_threads_.begin(); it != sse_threads_.end();) { - auto& [session_id, thread] = *it; + for (auto& [session_id, thread] : sse_threads_) { if (thread && thread->joinable()) { try { - std::cerr << "等待会话线程结束: " << session_id << std::endl; - // 分离线程而不是等待它结束,避免可能的死锁 + std::cerr << "分离会话线程: " << session_id << std::endl; thread->detach(); } catch (const std::exception& e) { - std::cerr << "等待会话线程结束时发生异常: " << session_id << ", " << e.what() << std::endl; + std::cerr << "分离会话线程时发生异常: " << session_id << ", " << e.what() << std::endl; } } - it = sse_threads_.erase(it); } + + // 清空映射 session_dispatchers_.clear(); + sse_threads_.clear(); } if (http_server_) { @@ -300,6 +315,18 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { std::string session_uri = msg_endpoint_ + "?session_id=" + session_id; std::cerr << "新的SSE连接: 客户端=" << req.remote_addr << ", 会话ID=" << session_id << std::endl; + std::cerr << "会话URI: " << session_uri << std::endl; + std::cerr << "请求头: "; + for (const auto& [key, value] : req.headers) { + std::cerr << key << "=" << value << " "; + } + std::cerr << std::endl; + + // 设置SSE响应头 + res.set_header("Content-Type", "text/event-stream"); + res.set_header("Cache-Control", "no-cache"); + res.set_header("Connection", "keep-alive"); + res.set_header("Access-Control-Allow-Origin", "*"); // 创建会话特定的事件分发器 auto session_dispatcher = std::make_shared(); @@ -315,8 +342,8 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { try { std::cerr << "SSE会话线程启动: " << session_id << std::endl; - // 发送初始会话URI - std::this_thread::sleep_for(std::chrono::seconds(1)); + // 发送初始会话URI - 使用endpoint事件类型,符合MCP规范 + std::this_thread::sleep_for(std::chrono::milliseconds(500)); std::stringstream ss; ss << "event: endpoint\ndata: " << session_uri << "\n\n"; session_dispatcher->send_event(ss.str()); @@ -339,7 +366,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { break; } - // 发送心跳事件 + // 发送心跳事件 - 使用自定义heartbeat事件类型 std::stringstream heartbeat; heartbeat << "event: heartbeat\ndata: " << heartbeat_count++ << "\n\n"; @@ -361,11 +388,31 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { std::cerr << "SSE会话线程异常: " << session_id << ", " << e.what() << std::endl; } - // 清理资源 - { + // 安全地清理资源 + try { std::lock_guard lock(mutex_); - session_dispatchers_.erase(session_id); - sse_threads_.erase(session_id); + + // 先关闭分发器 + auto dispatcher_it = session_dispatchers_.find(session_id); + if (dispatcher_it != session_dispatchers_.end()) { + if (!dispatcher_it->second->is_closed()) { + dispatcher_it->second->close(); + } + session_dispatchers_.erase(dispatcher_it); + } + + // 再移除线程 + auto thread_it = sse_threads_.find(session_id); + if (thread_it != sse_threads_.end()) { + // 不要在线程内部join或detach自己 + // 只从映射中移除 + thread_it->second.release(); // 释放所有权但不删除线程对象 + sse_threads_.erase(thread_it); + } + + std::cerr << "会话资源已清理: " << session_id << std::endl; + } catch (const std::exception& e) { + std::cerr << "清理会话资源时发生异常: " << session_id << ", " << e.what() << std::endl; } }); @@ -385,14 +432,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { if (it == session_dispatchers_.end() || it->second->is_closed()) { std::cerr << "会话已关闭,停止内容提供: " << session_id << std::endl; - // 清理资源 - auto thread_it = sse_threads_.find(session_id); - if (thread_it != sse_threads_.end() && thread_it->second && thread_it->second->joinable()) { - thread_it->second->detach(); // 分离线程,让它自行清理 - } - session_dispatchers_.erase(session_id); - sse_threads_.erase(session_id); - + // 不在这里清理资源,让会话线程自己清理 return false; } } @@ -402,21 +442,13 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { if (!result) { std::cerr << "等待事件失败,关闭连接: " << session_id << std::endl; - // 关闭会话 + // 关闭会话分发器,但不清理资源 { std::lock_guard lock(mutex_); auto it = session_dispatchers_.find(session_id); if (it != session_dispatchers_.end()) { it->second->close(); } - - // 清理资源 - auto thread_it = sse_threads_.find(session_id); - if (thread_it != sse_threads_.end() && thread_it->second && thread_it->second->joinable()) { - thread_it->second->detach(); // 分离线程,让它自行清理 - } - session_dispatchers_.erase(session_id); - sse_threads_.erase(session_id); } return false; @@ -426,6 +458,18 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { return true; } catch (const std::exception& e) { std::cerr << "SSE内容提供者异常: " << e.what() << std::endl; + + // 关闭会话分发器,但不清理资源 + try { + std::lock_guard lock(mutex_); + auto it = session_dispatchers_.find(session_id); + if (it != session_dispatchers_.end()) { + it->second->close(); + } + } catch (const std::exception& e2) { + std::cerr << "关闭会话分发器时发生异常: " << e2.what() << std::endl; + } + return false; } }); @@ -449,22 +493,40 @@ void server::handle_jsonrpc(const httplib::Request& req, httplib::Response& res) std::string session_id = it != req.params.end() ? it->second : ""; std::cerr << "收到JSON-RPC请求: 会话ID=" << session_id << ", 路径=" << req.path << std::endl; + std::cerr << "请求参数: "; + for (const auto& [key, value] : req.params) { + std::cerr << key << "=" << value << " "; + } + std::cerr << std::endl; // 检查会话是否存在 { std::lock_guard lock(mutex_); - if (!session_id.empty() && session_dispatchers_.find(session_id) == session_dispatchers_.end()) { - std::cerr << "会话不存在: " << session_id << std::endl; - json error_response = { - {"jsonrpc", "2.0"}, - {"error", { - {"code", static_cast(error_code::invalid_request)}, - {"message", "Session not found"} - }}, - {"id", nullptr} - }; - res.set_content(error_response.dump(), "application/json"); - return; + if (!session_id.empty()) { + std::cerr << "检查会话是否存在: " << session_id << std::endl; + std::cerr << "当前活跃会话: "; + for (const auto& [id, _] : session_dispatchers_) { + std::cerr << id << " "; + } + std::cerr << std::endl; + + if (session_dispatchers_.find(session_id) == session_dispatchers_.end()) { + std::cerr << "会话不存在: " << session_id << std::endl; + json error_response = { + {"jsonrpc", "2.0"}, + {"error", { + {"code", static_cast(error_code::invalid_request)}, + {"message", "Session not found"} + }}, + {"id", nullptr} + }; + res.set_content(error_response.dump(), "application/json"); + return; + } else { + std::cerr << "会话存在: " << session_id << std::endl; + } + } else { + std::cerr << "请求中没有会话ID" << std::endl; } } @@ -545,7 +607,9 @@ json server::process_request(const request& req, const std::string& session_id) std::cerr << "处理通知: " << req.method << std::endl; // 通知没有响应 if (req.method == "notifications/initialized") { + std::cerr << "收到客户端initialized通知,会话: " << session_id << std::endl; set_session_initialized(session_id, true); + std::cerr << "会话已设置为初始化状态: " << session_id << std::endl; } return json::object(); } @@ -556,7 +620,7 @@ json server::process_request(const request& req, const std::string& session_id) // 特殊情况:初始化 if (req.method == "initialize") { - return handle_initialize(req); + return handle_initialize(req, session_id); } else if (req.method == "ping") { // 接收者必须立即响应一个空响应 LOG_INFO("处理ping请求"); @@ -629,11 +693,14 @@ json server::process_request(const request& req, const std::string& session_id) } } -json server::handle_initialize(const request& req) { +json server::handle_initialize(const request& req, const std::string& session_id) { const json& params = req.params; + std::cerr << "处理initialize请求,会话ID: " << session_id << std::endl; + // Version negotiation if (!params.contains("protocolVersion") || !params["protocolVersion"].is_string()) { + std::cerr << "缺少protocolVersion参数或格式不正确" << std::endl; return response::create_error( req.id, error_code::invalid_params, @@ -642,8 +709,10 @@ json server::handle_initialize(const request& req) { } std::string requested_version = params["protocolVersion"].get(); + std::cerr << "客户端请求的协议版本: " << requested_version << std::endl; if (requested_version != MCP_VERSION) { + std::cerr << "不支持的协议版本: " << requested_version << ", 服务器支持: " << MCP_VERSION << std::endl; return response::create_error( req.id, error_code::invalid_params, @@ -669,7 +738,7 @@ json server::handle_initialize(const request& req) { } // Log connection - // std::cout << "Client connected: " << client_name << " " << client_version << std::endl; + std::cerr << "客户端连接: " << client_name << " " << client_version << std::endl; // Return server info and capabilities json server_info = { @@ -683,8 +752,8 @@ json server::handle_initialize(const request& req) { {"serverInfo", server_info} }; - // set_session_initialized(session_id, false); - + std::cerr << "初始化成功,等待客户端发送notifications/initialized通知" << std::endl; + return response::create_success(req.id, result).to_json(); } @@ -703,8 +772,28 @@ void server::send_request(const std::string& session_id, const std::string& meth // Create request request req = request::create(method, params); - // TODO: Implement actual request sending logic - // This would typically involve sending an HTTP request to the client + // 获取会话分发器 + std::shared_ptr dispatcher; + { + std::lock_guard lock(mutex_); + auto it = session_dispatchers_.find(session_id); + if (it == session_dispatchers_.end()) { + std::cerr << "会话不存在: " << session_id << std::endl; + return; + } + dispatcher = it->second; + } + + // 发送请求 - 使用message事件类型,符合MCP规范 + std::stringstream ss; + ss << "event: message\ndata: " << req.to_json().dump() << "\n\n"; + bool result = dispatcher->send_event(ss.str()); + + if (!result) { + std::cerr << "向会话发送请求失败: " << session_id << std::endl; + } else { + std::cerr << "成功向会话 " << session_id << " 发送请求: " << method << std::endl; + } } // Check if a session is initialized @@ -720,7 +809,7 @@ void server::set_session_initialized(const std::string& session_id, bool initial session_initialized_[session_id] = initialized; } -// Generate a random session ID +// Generate a random session ID in UUID format std::string server::generate_session_id() const { std::random_device rd; std::mt19937 gen(rd()); @@ -729,7 +818,28 @@ std::string server::generate_session_id() const { std::stringstream ss; ss << std::hex; - for (int i = 0; i < 32; ++i) { + // UUID format: 8-4-4-4-12 hexadecimal digits + for (int i = 0; i < 8; ++i) { + ss << dis(gen); + } + ss << "-"; + + for (int i = 0; i < 4; ++i) { + ss << dis(gen); + } + ss << "-"; + + for (int i = 0; i < 4; ++i) { + ss << dis(gen); + } + ss << "-"; + + for (int i = 0; i < 4; ++i) { + ss << dis(gen); + } + ss << "-"; + + for (int i = 0; i < 12; ++i) { ss << dis(gen); }