From c08c8fe6869f9b3d05910296f9c27222f5222e4c Mon Sep 17 00:00:00 2001 From: hkr04 Date: Fri, 14 Mar 2025 15:53:58 +0800 Subject: [PATCH] mcp_stdio_client: add Windows supprt --- CMakeLists.txt | 21 +++ include/mcp_stdio_client.h | 18 ++- include/mcp_tool.h | 2 + src/CMakeLists.txt | 4 +- src/mcp_client.cpp | 2 +- src/mcp_server.cpp | 2 +- src/mcp_stdio_client.cpp | 301 +++++++++++++++++++++++++++++++++++-- test/CMakeLists.txt | 9 +- test/mcp_test.cpp | 4 +- 9 files changed, 343 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index da91741..15b5535 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,27 @@ cmake_minimum_required(VERSION 3.10) project(MCP VERSION 2024.11.05 LANGUAGES CXX) +set(CMAKE_WARN_UNUSED_CLI YES) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +option(BUILD_SHARED_LIBS "build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) + +if (WIN32) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) +endif() + +if (MSVC) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/bigobj>") + add_compile_options("$<$:/bigobj>") +endif() # Set C++ standard set(CMAKE_CXX_STANDARD 17) diff --git a/include/mcp_stdio_client.h b/include/mcp_stdio_client.h index 5acf72f..dc80a13 100644 --- a/include/mcp_stdio_client.h +++ b/include/mcp_stdio_client.h @@ -25,6 +25,11 @@ #include #include +// 添加Windows平台特定的头文件 +#if defined(_WIN32) || defined(_WIN64) +#include +#endif + namespace mcp { /** @@ -176,11 +181,20 @@ private: // 进程ID int process_id_ = -1; - // 标准输入管道 +#if defined(_WIN32) || defined(_WIN64) + // Windows平台特定的进程句柄 + HANDLE process_handle_ = NULL; + + // 标准输入输出管道 (Windows) + HANDLE stdin_pipe_[2] = {NULL, NULL}; + HANDLE stdout_pipe_[2] = {NULL, NULL}; +#else + // 标准输入管道 (POSIX) int stdin_pipe_[2] = {-1, -1}; - // 标准输出管道 + // 标准输出管道 (POSIX) int stdout_pipe_[2] = {-1, -1}; +#endif // 读取线程 std::unique_ptr read_thread_; diff --git a/include/mcp_tool.h b/include/mcp_tool.h index 2824e16..e20acbe 100644 --- a/include/mcp_tool.h +++ b/include/mcp_tool.h @@ -14,6 +14,8 @@ #include #include #include +#include +#include namespace mcp { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 627d15b..7cc9917 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -15,9 +15,7 @@ add_library(${TARGET} STATIC ../include/mcp_stdio_client.h ) -target_link_libraries(${TARGET} PUBLIC - Threads::Threads -) +target_link_libraries(${TARGET} PUBLIC ${CMAKE_THREAD_LIBS_INIT}) # 如果找到OpenSSL,链接OpenSSL库 if(OPENSSL_FOUND) diff --git a/src/mcp_client.cpp b/src/mcp_client.cpp index 94d8708..a82265d 100644 --- a/src/mcp_client.cpp +++ b/src/mcp_client.cpp @@ -124,7 +124,7 @@ bool client::ping() { try { json result = send_jsonrpc(req); return result.empty(); - } catch (const std::exception& e) { + } catch (...) { return false; } } diff --git a/src/mcp_server.cpp b/src/mcp_server.cpp index 4f0f28a..220c914 100644 --- a/src/mcp_server.cpp +++ b/src/mcp_server.cpp @@ -353,7 +353,7 @@ void server::register_tool(const tool& tool, tool_handler handler) { try { tool_result["content"] = it->second.second(tool_args); - } catch (const std::exception& e) { + } catch (...) { tool_result["isError"] = true; } diff --git a/src/mcp_stdio_client.cpp b/src/mcp_stdio_client.cpp index 11a0d88..7dde0d8 100644 --- a/src/mcp_stdio_client.cpp +++ b/src/mcp_stdio_client.cpp @@ -9,11 +9,17 @@ #include "mcp_stdio_client.h" +#if defined(_WIN32) || defined(_WIN64) +#include +#include +#else #include #include #include #include #include +#endif + #include #include #include @@ -77,7 +83,7 @@ bool stdio_client::ping() { try { json result = send_jsonrpc(req); return result.empty(); - } catch (const std::exception& e) { + } catch (...) { return false; } } @@ -199,6 +205,130 @@ bool stdio_client::start_server_process() { LOG_INFO("Starting server process: ", command_); +#if defined(_WIN32) || defined(_WIN64) + // Windows实现 + SECURITY_ATTRIBUTES sa; + sa.nLength = sizeof(SECURITY_ATTRIBUTES); + sa.bInheritHandle = TRUE; + sa.lpSecurityDescriptor = NULL; + + // 创建管道 + HANDLE child_stdin_read = NULL; + HANDLE child_stdin_write = NULL; + HANDLE child_stdout_read = NULL; + HANDLE child_stdout_write = NULL; + + if (!CreatePipe(&child_stdin_read, &child_stdin_write, &sa, 0)) { + LOG_ERROR("Failed to create stdin pipe: ", GetLastError()); + return false; + } + + if (!SetHandleInformation(child_stdin_write, HANDLE_FLAG_INHERIT, 0)) { + LOG_ERROR("Failed to set stdin pipe properties: ", GetLastError()); + CloseHandle(child_stdin_read); + CloseHandle(child_stdin_write); + return false; + } + + if (!CreatePipe(&child_stdout_read, &child_stdout_write, &sa, 0)) { + LOG_ERROR("Failed to create stdout pipe: ", GetLastError()); + CloseHandle(child_stdin_read); + CloseHandle(child_stdin_write); + return false; + } + + if (!SetHandleInformation(child_stdout_read, HANDLE_FLAG_INHERIT, 0)) { + LOG_ERROR("Failed to set stdout pipe properties: ", GetLastError()); + CloseHandle(child_stdin_read); + CloseHandle(child_stdin_write); + CloseHandle(child_stdout_read); + CloseHandle(child_stdout_write); + return false; + } + + // 准备进程启动信息 + STARTUPINFOA si; + PROCESS_INFORMATION pi; + + ZeroMemory(&si, sizeof(STARTUPINFOA)); + si.cb = sizeof(STARTUPINFOA); + si.hStdInput = child_stdin_read; + si.hStdOutput = child_stdout_write; + si.hStdError = child_stdout_write; + si.dwFlags |= STARTF_USESTDHANDLES; + + ZeroMemory(&pi, sizeof(PROCESS_INFORMATION)); + + // 准备环境变量 + std::string env_block; + if (!env_vars_.empty()) { + char* system_env = GetEnvironmentStringsA(); + if (system_env) { + // 复制系统环境变量 + const char* env_ptr = system_env; + while (*env_ptr) { + std::string env_var(env_ptr); + env_block += env_var + '\0'; + env_ptr += env_var.size() + 1; + } + FreeEnvironmentStringsA(system_env); + + // 添加自定义环境变量 + for (auto it = env_vars_.begin(); it != env_vars_.end(); ++it) { + std::string env_var = it.key() + "=" + it.value().get(); + env_block += env_var + '\0'; + } + + // 添加结束符 + env_block += '\0'; + } + } + + // 创建子进程 + std::string cmd_line = command_; + char* cmd_str = const_cast(cmd_line.c_str()); + + BOOL success = CreateProcessA( + NULL, // 应用程序名称 + cmd_str, // 命令行 + NULL, // 进程安全属性 + NULL, // 线程安全属性 + TRUE, // 继承句柄 + CREATE_NO_WINDOW, // 创建标志 + env_vars_.empty() ? NULL : (LPVOID)env_block.c_str(), // 环境变量 + NULL, // 当前目录 + &si, // 启动信息 + &pi // 进程信息 + ); + + if (!success) { + LOG_ERROR("Failed to create process: ", GetLastError()); + CloseHandle(child_stdin_read); + CloseHandle(child_stdin_write); + CloseHandle(child_stdout_read); + CloseHandle(child_stdout_write); + return false; + } + + // 关闭不需要的句柄 + CloseHandle(child_stdin_read); + CloseHandle(child_stdout_write); + CloseHandle(pi.hThread); + + // 保存进程信息 + process_id_ = pi.dwProcessId; + process_handle_ = pi.hProcess; + stdin_pipe_[0] = NULL; + stdin_pipe_[1] = child_stdin_write; + stdout_pipe_[0] = child_stdout_read; + stdout_pipe_[1] = NULL; + + // 设置非阻塞模式 + DWORD mode = PIPE_NOWAIT; + SetNamedPipeHandleState(stdout_pipe_[0], &mode, NULL, NULL); + +#else + // POSIX实现 // 创建管道 if (pipe(stdin_pipe_) == -1) { LOG_ERROR("Failed to create stdin pipe: ", strerror(errno)); @@ -292,14 +422,6 @@ bool stdio_client::start_server_process() { int flags = fcntl(stdout_pipe_[0], F_GETFL, 0); fcntl(stdout_pipe_[0], F_SETFL, flags | O_NONBLOCK); - running_ = true; - - // 启动读取线程 - read_thread_ = std::make_unique(&stdio_client::read_thread_func, this); - - // 等待一段时间,确保进程启动 - std::this_thread::sleep_for(std::chrono::milliseconds(500)); - // 检查进程是否仍在运行 int status; pid_t result = waitpid(process_id_, &status, WNOHANG); @@ -329,6 +451,34 @@ bool stdio_client::start_server_process() { return false; } +#endif + + running_ = true; + + // 启动读取线程 + read_thread_ = std::make_unique(&stdio_client::read_thread_func, this); + + // 等待一段时间,确保进程启动 + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + +#if defined(_WIN32) || defined(_WIN64) + // 检查进程是否仍在运行 + DWORD exit_code; + if (GetExitCodeProcess(process_handle_, &exit_code) && exit_code != STILL_ACTIVE) { + LOG_ERROR("Server process exited immediately with status: ", exit_code); + running_ = false; + + if (read_thread_ && read_thread_->joinable()) { + read_thread_->join(); + } + + CloseHandle(stdin_pipe_[1]); + CloseHandle(stdout_pipe_[0]); + CloseHandle(process_handle_); + + return false; + } +#endif LOG_INFO("Server process started successfully, PID: ", process_id_); return true; @@ -343,6 +493,46 @@ void stdio_client::stop_server_process() { running_ = false; +#if defined(_WIN32) || defined(_WIN64) + // Windows实现 + // 关闭管道 + if (stdin_pipe_[1] != NULL) { + CloseHandle(stdin_pipe_[1]); + stdin_pipe_[1] = NULL; + } + + if (stdout_pipe_[0] != NULL) { + CloseHandle(stdout_pipe_[0]); + stdout_pipe_[0] = NULL; + } + + // 等待读取线程结束 + if (read_thread_ && read_thread_->joinable()) { + read_thread_->join(); + } + + // 终止进程 + if (process_handle_ != NULL) { + LOG_INFO("Terminating process: ", process_id_); + TerminateProcess(process_handle_, 0); + + // 等待进程结束 + WaitForSingleObject(process_handle_, 2000); + + DWORD exit_code; + if (GetExitCodeProcess(process_handle_, &exit_code) && exit_code == STILL_ACTIVE) { + // 进程仍在运行,强制终止 + LOG_WARNING("Process did not terminate, forcing termination"); + TerminateProcess(process_handle_, 1); + WaitForSingleObject(process_handle_, 1000); + } + + CloseHandle(process_handle_); + process_handle_ = NULL; + process_id_ = -1; + } +#else + // POSIX实现 // 关闭管道 if (stdin_pipe_[1] != -1) { close(stdin_pipe_[1]); @@ -384,6 +574,7 @@ void stdio_client::stop_server_process() { process_id_ = -1; } +#endif LOG_INFO("Server process stopped"); } @@ -395,6 +586,84 @@ void stdio_client::read_thread_func() { char buffer[buffer_size]; std::string data_buffer; +#if defined(_WIN32) || defined(_WIN64) + // Windows实现 + DWORD bytes_read; + + while (running_) { + // 读取数据 + BOOL success = ReadFile(stdout_pipe_[0], buffer, buffer_size - 1, &bytes_read, NULL); + + if (success && bytes_read > 0) { + buffer[bytes_read] = '\0'; + data_buffer.append(buffer, bytes_read); + + // 处理完整的JSON-RPC消息 + size_t pos = 0; + while ((pos = data_buffer.find('\n')) != std::string::npos) { + std::string line = data_buffer.substr(0, pos); + data_buffer.erase(0, pos + 1); + + if (!line.empty()) { + try { + json message = json::parse(line); + + if (message.contains("jsonrpc") && message["jsonrpc"] == "2.0") { + if (message.contains("id") && !message["id"].is_null()) { + // 这是一个响应 + json id = message["id"]; + + std::lock_guard lock(response_mutex_); + auto it = pending_requests_.find(id); + + if (it != pending_requests_.end()) { + if (message.contains("result")) { + it->second.set_value(message["result"]); + } else if (message.contains("error")) { + json error_result = { + {"isError", true}, + {"error", message["error"]} + }; + it->second.set_value(error_result); + } else { + it->second.set_value(json::object()); + } + + pending_requests_.erase(it); + } else { + LOG_WARNING("Received response for unknown request ID: ", id); + } + } else if (message.contains("method")) { + // 这是一个请求或通知 + LOG_INFO("Received request/notification: ", message["method"]); + // 目前不处理服务器发来的请求 + } + } + } catch (const json::exception& e) { + LOG_ERROR("Failed to parse JSON-RPC message: ", e.what(), ", message: ", line); + } + } + } + } else if (!success) { + DWORD error = GetLastError(); + if (error == ERROR_BROKEN_PIPE || error == ERROR_NO_DATA) { + // 管道已关闭或没有数据 + LOG_WARNING("Pipe closed by server or no data available"); + break; + } else if (error != ERROR_IO_PENDING) { + LOG_ERROR("Error reading from pipe: ", error); + break; + } + + // 非阻塞模式下没有数据可读 + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } else { + // 非阻塞模式下没有数据可读 + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } +#else + // POSIX实现 while (running_) { // 读取数据 ssize_t bytes_read = read(stdout_pipe_[0], buffer, buffer_size - 1); @@ -463,6 +732,7 @@ void stdio_client::read_thread_func() { } } } +#endif LOG_INFO("Read thread stopped"); } @@ -475,13 +745,24 @@ json stdio_client::send_jsonrpc(const request& req) { json req_json = req.to_json(); std::string req_str = req_json.dump() + "\n"; - // 发送请求 +#if defined(_WIN32) || defined(_WIN64) + // Windows实现 + DWORD bytes_written; + BOOL success = WriteFile(stdin_pipe_[1], req_str.c_str(), static_cast(req_str.size()), &bytes_written, NULL); + + if (!success || bytes_written != static_cast(req_str.size())) { + LOG_ERROR("Failed to write complete request: ", GetLastError()); + throw mcp_exception(error_code::internal_error, "Failed to write to pipe"); + } +#else + // POSIX实现 ssize_t bytes_written = write(stdin_pipe_[1], req_str.c_str(), req_str.size()); if (bytes_written != static_cast(req_str.size())) { LOG_ERROR("Failed to write complete request: ", strerror(errno)); throw mcp_exception(error_code::internal_error, "Failed to write to pipe"); } +#endif // 如果是通知,不需要等待响应 if (req.is_notification()) { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ce6c7d0..8a1ef27 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -35,13 +35,20 @@ add_executable(${TEST_PROJECT_NAME} ${TEST_SOURCES}) # Link directories link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../build/src) +# 根据平台设置正确的库文件名 +if(WIN32) + set(MCP_LIB_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../build/src/mcp.lib") +else() + set(MCP_LIB_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../build/src/libmcp.a") +endif() + # Link Google Test and MCP library target_link_libraries(${TEST_PROJECT_NAME} PRIVATE gtest gtest_main gmock gmock_main - ${CMAKE_CURRENT_SOURCE_DIR}/../build/src/libmcp.a + mcp Threads::Threads ) diff --git a/test/mcp_test.cpp b/test/mcp_test.cpp index 9103a2f..a4c9252 100644 --- a/test/mcp_test.cpp +++ b/test/mcp_test.cpp @@ -304,7 +304,7 @@ TEST_F(VersioningTest, UnsupportedVersion) { // 添加延迟,确保资源完全释放 std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } catch (const mcp_exception& e) { + } catch (...) { EXPECT_TRUE(false); } } @@ -435,7 +435,7 @@ TEST_F(PingTest, DirectPing) { // 添加延迟,确保资源完全释放 std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } catch (const mcp_exception& e) { + } catch (...) { EXPECT_TRUE(false); } }