diff --git a/include/mcp_logger.h b/include/mcp_logger.h new file mode 100644 index 0000000..62b9baa --- /dev/null +++ b/include/mcp_logger.h @@ -0,0 +1,121 @@ +/** + * @file mcp_logger.h + * @brief 简单的日志记录机制 + */ + +#ifndef MCP_LOGGER_H +#define MCP_LOGGER_H + +#include +#include +#include +#include +#include +#include + +namespace mcp { + +enum class log_level { + debug, + info, + warning, + error +}; + +class logger { +public: + static logger& instance() { + static logger instance; + return instance; + } + + void set_level(log_level level) { + std::lock_guard lock(mutex_); + level_ = level; + } + + template + void debug(Args&&... args) { + log(log_level::debug, std::forward(args)...); + } + + template + void info(Args&&... args) { + log(log_level::info, std::forward(args)...); + } + + template + void warning(Args&&... args) { + log(log_level::warning, std::forward(args)...); + } + + template + void error(Args&&... args) { + log(log_level::error, std::forward(args)...); + } + +private: + logger() : level_(log_level::info) {} + + template + void log_impl(std::stringstream& ss, T&& arg) { + ss << std::forward(arg); + } + + template + void log_impl(std::stringstream& ss, T&& arg, Args&&... args) { + ss << std::forward(arg); + log_impl(ss, std::forward(args)...); + } + + template + void log(log_level level, Args&&... args) { + if (level < level_) { + return; + } + + std::stringstream ss; + + // 添加时间戳 + auto now = std::chrono::system_clock::now(); + auto now_c = std::chrono::system_clock::to_time_t(now); + auto now_tm = std::localtime(&now_c); + + ss << std::put_time(now_tm, "%Y-%m-%d %H:%M:%S") << " "; + + // 添加日志级别 + switch (level) { + case log_level::debug: + ss << "[DEBUG] "; + break; + case log_level::info: + ss << "[INFO] "; + break; + case log_level::warning: + ss << "[WARNING] "; + break; + case log_level::error: + ss << "[ERROR] "; + break; + } + + // 添加日志内容 + log_impl(ss, std::forward(args)...); + + // 输出日志 + std::lock_guard lock(mutex_); + std::cerr << ss.str() << std::endl; + } + + log_level level_; + std::mutex mutex_; +}; + +#define LOG_DEBUG(...) mcp::logger::instance().debug(__VA_ARGS__) +#define LOG_INFO(...) mcp::logger::instance().info(__VA_ARGS__) +#define LOG_WARNING(...) mcp::logger::instance().warning(__VA_ARGS__) +#define LOG_ERROR(...) mcp::logger::instance().error(__VA_ARGS__) + +} // namespace mcp + +#endif // MCP_LOGGER_H \ No newline at end of file diff --git a/include/mcp_server.h b/include/mcp_server.h index 7e5e91e..d0e7ef3 100644 --- a/include/mcp_server.h +++ b/include/mcp_server.h @@ -12,6 +12,8 @@ #include "mcp_message.h" #include "mcp_resource.h" #include "mcp_tool.h" +#include "mcp_thread_pool.h" +#include "mcp_logger.h" // Include the HTTP library #include "httplib.h" @@ -25,6 +27,7 @@ #include #include #include +#include namespace mcp { @@ -64,7 +67,12 @@ public: // 写入数据 try { - sink->write(message_.data(), message_.size()); + bool write_result = sink->write(message_.data(), message_.size()); + if (!write_result) { + std::cerr << "写入事件数据失败: 客户端可能已关闭连接" << std::endl; + closed_ = true; + return false; + } return true; } catch (const std::exception& e) { std::cerr << "写入事件数据失败: " << e.what() << std::endl; @@ -73,17 +81,18 @@ public: } } - void send_event(const std::string& message) { + bool send_event(const std::string& message) { std::lock_guard lk(m_); - // 如果连接已关闭,抛出异常 + // 如果连接已关闭,返回失败 if (closed_) { - throw std::runtime_error("连接已关闭"); + return false; } cid_ = id_++; message_ = message; cv_.notify_all(); + return true; } void close() { @@ -265,6 +274,9 @@ private: // Running flag bool running_ = false; + // 线程池 + thread_pool thread_pool_; + // Map to track session initialization status (session_id -> initialized) std::map session_initialized_; @@ -288,8 +300,30 @@ private: // Generate a random session ID std::string generate_session_id() const; -}; + + // 辅助函数:创建异步方法处理器 + template + std::function(const json&)> make_async_handler(F&& handler) { + return [handler = std::forward(handler)](const json& params) -> std::future { + return std::async(std::launch::async, [handler, params]() -> json { + return handler(params); + }); + }; + } + // 辅助类,用于简化锁的管理 + class auto_lock { + public: + explicit auto_lock(std::mutex& mutex) : lock_(mutex) {} + private: + std::lock_guard lock_; + }; + + // 获取自动锁 + auto_lock get_lock() const { + return auto_lock(mutex_); + } +}; } // namespace mcp diff --git a/include/mcp_thread_pool.h b/include/mcp_thread_pool.h new file mode 100644 index 0000000..735597f --- /dev/null +++ b/include/mcp_thread_pool.h @@ -0,0 +1,117 @@ +/** + * @file mcp_thread_pool.h + * @brief 简单的线程池实现 + */ + +#ifndef MCP_THREAD_POOL_H +#define MCP_THREAD_POOL_H + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mcp { + +class thread_pool { +public: + /** + * @brief 构造函数 + * @param num_threads 线程池中的线程数量 + */ + explicit thread_pool(size_t num_threads = std::thread::hardware_concurrency()) : stop_(false) { + for (size_t i = 0; i < num_threads; ++i) { + workers_.emplace_back([this] { + while (true) { + std::function task; + + { + std::unique_lock lock(queue_mutex_); + condition_.wait(lock, [this] { + return stop_ || !tasks_.empty(); + }); + + if (stop_ && tasks_.empty()) { + return; + } + + task = std::move(tasks_.front()); + tasks_.pop(); + } + + task(); + } + }); + } + } + + /** + * @brief 析构函数 + */ + ~thread_pool() { + { + std::unique_lock lock(queue_mutex_); + stop_ = true; + } + + condition_.notify_all(); + + for (std::thread& worker : workers_) { + if (worker.joinable()) { + worker.join(); + } + } + } + + /** + * @brief 提交任务到线程池 + * @param f 任务函数 + * @param args 任务参数 + * @return 任务的future + */ + template + auto enqueue(F&& f, Args&&... args) -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future result = task->get_future(); + + { + std::unique_lock lock(queue_mutex_); + + if (stop_) { + throw std::runtime_error("线程池已停止,无法添加任务"); + } + + tasks_.emplace([task]() { (*task)(); }); + } + + condition_.notify_one(); + return result; + } + +private: + // 工作线程 + std::vector workers_; + + // 任务队列 + std::queue> tasks_; + + // 互斥锁和条件变量 + std::mutex queue_mutex_; + std::condition_variable condition_; + + // 停止标志 + std::atomic stop_; +}; + +} // namespace mcp + +#endif // MCP_THREAD_POOL_H \ No newline at end of file diff --git a/src/mcp_client.cpp b/src/mcp_client.cpp index b5c6d48..fd7e608 100644 --- a/src/mcp_client.cpp +++ b/src/mcp_client.cpp @@ -301,7 +301,13 @@ void client::open_sse_connection() { std::cerr << "SSE线程: 解析数据失败" << std::endl; return false; // 解析失败,关闭连接 } - return sse_running_.load(); // 如果sse_running_为false,关闭连接 + + // 检查是否应该关闭连接 + bool should_continue = sse_running_.load(); + if (!should_continue) { + std::cerr << "SSE线程: sse_running_为false,关闭连接" << std::endl; + } + return should_continue; // 如果sse_running_为false,关闭连接 }); // 检查连接是否成功 @@ -402,10 +408,37 @@ bool client::parse_sse_data(const char* data, size_t length) { // 新增方法:关闭SSE连接 void client::close_sse_connection() { + // 设置标志,这将导致SSE回调函数返回false,从而关闭连接 sse_running_ = false; + // 给一些时间让回调函数返回false并关闭连接 + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + // 等待SSE线程结束 if (sse_thread_ && sse_thread_->joinable()) { - sse_thread_->join(); + // 设置一个合理的超时时间,例如5秒 + auto timeout = std::chrono::seconds(5); + auto start = std::chrono::steady_clock::now(); + + // 尝试在超时前等待线程结束 + while (sse_thread_->joinable() && + std::chrono::steady_clock::now() - start < timeout) { + try { + // 尝试立即加入线程 + sse_thread_->join(); + break; // 如果成功加入,跳出循环 + } catch (const std::exception& e) { + std::cerr << "等待SSE线程时出错: " << e.what() << std::endl; + // 短暂休眠,避免CPU占用过高 + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } + + // 如果线程仍然没有结束,记录警告并分离线程 + if (sse_thread_->joinable()) { + std::cerr << "警告: SSE线程未能在超时时间内结束,分离线程" << std::endl; + sse_thread_->detach(); + } } // 清空消息端点 @@ -416,6 +449,8 @@ void client::close_sse_connection() { // 通知等待的线程(虽然消息端点为空,但可以让等待的线程检查sse_running_状态) endpoint_cv_.notify_all(); } + + std::cerr << "SSE连接已关闭" << std::endl; } json client::send_jsonrpc(const request& req) { diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index ff5e0ab..c37c827 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -165,7 +165,7 @@ void server::register_resource(const std::string& path, std::shared_ptr json { if (!params.contains("uri")) { throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter"); } @@ -186,9 +186,9 @@ void server::register_resource(const std::string& path, std::shared_ptr json { json resources = json::array(); - + for (const auto& [uri, res] : resources_) { resources.push_back(res->get_metadata()); } @@ -209,7 +209,7 @@ void server::register_resource(const std::string& path, std::shared_ptr json { if (!params.contains("uri")) { throw mcp_exception(error_code::invalid_params, "Missing 'uri' parameter"); } @@ -227,11 +227,9 @@ void server::register_resource(const std::string& path, std::shared_ptr json { // In this implementation, we don't support resource templates - return json{ - {"resourceTemplates", json::array()} - }; + return json::array(); }; } } @@ -242,7 +240,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { // Register methods for tool listing and calling if (method_handlers_.find("tools/list") == method_handlers_.end()) { - method_handlers_["tools/list"] = [this](const json& params) { + method_handlers_["tools/list"] = [this](const json& params) -> json { json tools_json = json::array(); for (const auto& [name, tool_pair] : tools_) { tools_json.push_back(tool_pair.first.to_json()); @@ -252,7 +250,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { } if (method_handlers_.find("tools/call") == method_handlers_.end()) { - method_handlers_["tools/call"] = [this](const json& params) { + method_handlers_["tools/call"] = [this](const json& params) -> json { if (!params.contains("name")) { throw mcp_exception(error_code::invalid_params, "Missing 'name' parameter"); } @@ -335,12 +333,22 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { break; } + // 检查服务器是否仍在运行 + if (!running_) { + std::cerr << "服务器已停止,停止心跳: " << session_id << std::endl; + break; + } + // 发送心跳事件 std::stringstream heartbeat; heartbeat << "event: heartbeat\ndata: " << heartbeat_count++ << "\n\n"; try { - session_dispatcher->send_event(heartbeat.str()); + bool sent = session_dispatcher->send_event(heartbeat.str()); + if (!sent) { + std::cerr << "发送心跳失败,客户端可能已关闭连接: " << session_id << std::endl; + break; + } std::cerr << "发送心跳到会话: " << session_id << ", 计数: " << heartbeat_count << std::endl; } catch (const std::exception& e) { std::cerr << "发送心跳失败,假定连接已关闭: " << e.what() << std::endl; @@ -413,6 +421,7 @@ void server::handle_sse(const httplib::Request& req, httplib::Response& res) { return false; } + return true; } catch (const std::exception& e) { @@ -543,18 +552,19 @@ json server::process_request(const request& req, const std::string& session_id) // 处理方法调用 try { - std::cerr << "处理方法调用: " << req.method << std::endl; + LOG_INFO("处理方法调用: ", req.method); // 特殊情况:初始化 if (req.method == "initialize") { return handle_initialize(req); } else if (req.method == "ping") { // 接收者必须立即响应一个空响应 - std::cerr << "处理ping请求" << std::endl; + LOG_INFO("处理ping请求"); return response::create_success(req.id, {}).to_json(); } if (!is_session_initialized(session_id)) { + LOG_WARNING("会话未初始化: ", session_id); return response::create_error( req.id, error_code::invalid_request, @@ -563,41 +573,59 @@ json server::process_request(const request& req, const std::string& session_id) } // 查找注册的方法处理器 - std::lock_guard lock(mutex_); - auto it = method_handlers_.find(req.method); - if (it != method_handlers_.end()) { + std::function handler; + { + std::lock_guard lock(mutex_); + auto it = method_handlers_.find(req.method); + if (it != method_handlers_.end()) { + handler = it->second; + } + } + + if (handler) { // 调用处理器 - std::cerr << "调用方法处理器: " << req.method << std::endl; - json result = it->second(req.params); + LOG_INFO("调用方法处理器: ", req.method); + auto future = thread_pool_.enqueue([handler, params = req.params]() -> json { + return handler(params); + }); + json result = future.get(); // 创建成功响应 - std::cerr << "方法调用成功: " << req.method << std::endl; + LOG_INFO("方法调用成功: ", req.method); return response::create_success(req.id, result).to_json(); } // 方法未找到 - std::cerr << "方法未找到: " << req.method << std::endl; + LOG_WARNING("方法未找到: ", req.method); return response::create_error( - req.id, - error_code::method_not_found, + req.id, + error_code::method_not_found, "Method not found: " + req.method ).to_json(); } catch (const mcp_exception& e) { // MCP异常 - std::cerr << "MCP异常: " << e.what() << ", 代码: " << static_cast(e.code()) << std::endl; + LOG_ERROR("MCP异常: ", e.what(), ", 代码: ", static_cast(e.code())); return response::create_error( - req.id, - e.code(), + req.id, + e.code(), e.what() ).to_json(); } catch (const std::exception& e) { - // 通用异常 - std::cerr << "处理请求时发生异常: " << e.what() << std::endl; + // 其他异常 + LOG_ERROR("处理请求时发生异常: ", e.what()); return response::create_error( - req.id, - error_code::internal_error, + req.id, + error_code::internal_error, "Internal error: " + std::string(e.what()) ).to_json(); + } catch (...) { + // 未知异常 + LOG_ERROR("处理请求时发生未知异常"); + return response::create_error( + req.id, + error_code::internal_error, + "Unknown internal error" + ).to_json(); } }