From 4c7a7908116a4ab2b428715cce9029c11c0ca6f8 Mon Sep 17 00:00:00 2001 From: hkr04 Date: Fri, 28 Mar 2025 17:57:20 +0800 Subject: [PATCH] refactor BaseTool -> BaseTool + BaseMCPTool --- config.cpp | 84 ++++++++++++++++++++++ config.h | 18 ++++- tool/base.h | 161 ++++++++++-------------------------------- tool/filesystem.h | 4 +- tool/puppeteer.h | 4 +- tool/python_execute.h | 4 +- tool/shell.h | 4 +- tool/terminate.h | 2 +- 8 files changed, 144 insertions(+), 137 deletions(-) diff --git a/config.cpp b/config.cpp index 8fb092e..67ed213 100644 --- a/config.cpp +++ b/config.cpp @@ -6,6 +6,90 @@ namespace humanus { +MCPToolConfig MCPToolConfig::load_from_toml(const std::string& tool_name) { + MCPToolConfig config; + + try { + // Get config file path + auto config_path = PROJECT_ROOT / "config" / "config_mcp.toml"; + if (!std::filesystem::exists(config_path)) { + throw std::runtime_error("MCP config file not found: " + config_path.string()); + } + + // Parse TOML file + const auto& data = toml::parse_file(config_path.string()); + + // Check if tool config exists + if (!data.contains(tool_name) || !data[tool_name].is_table()) { + throw std::runtime_error("Tool configuration not found in MCP config file: " + tool_name); + } + + const auto& tool_table = *data[tool_name].as_table(); + + // Read type + if (!tool_table.contains("type") || !tool_table["type"].is_string()) { + throw std::runtime_error("Tool configuration missing type field: " + tool_name); + } + config.type = tool_table["type"].as_string()->get(); + + if (config.type == "stdio") { + // Read command + if (!tool_table.contains("command") || !tool_table["command"].is_string()) { + throw std::runtime_error("stdio type tool configuration missing command field: " + tool_name); + } + config.command = tool_table["command"].as_string()->get(); + + // Read arguments (if any) + if (tool_table.contains("args") && tool_table["args"].is_array()) { + const auto& args_array = *tool_table["args"].as_array(); + for (const auto& arg : args_array) { + if (arg.is_string()) { + config.args.push_back(arg.as_string()->get()); + } + } + } + + // Read environment variables + if (tool_table.contains("env") && tool_table["env"].is_table()) { + const auto& env_table = *tool_table["env"].as_table(); + for (const auto& [key, value] : env_table) { + if (value.is_string()) { + config.env_vars[key] = value.as_string()->get(); + } else if (value.is_integer()) { + config.env_vars[key] = value.as_integer()->get(); + } else if (value.is_floating_point()) { + config.env_vars[key] = value.as_floating_point()->get(); + } else if (value.is_boolean()) { + config.env_vars[key] = value.as_boolean()->get(); + } + } + } + } else if (config.type == "sse") { + // Read host and port or url + if (tool_table.contains("url") && tool_table["url"].is_string()) { + config.url = tool_table["url"].as_string()->get(); + } else { + if (!tool_table.contains("host") || !tool_table["host"].is_string()) { + throw std::runtime_error("sse type tool configuration missing host field: " + tool_name); + } + config.host = tool_table["host"].as_string()->get(); + + if (!tool_table.contains("port") || !tool_table["port"].is_integer()) { + throw std::runtime_error("sse type tool configuration missing port field: " + tool_name); + } + config.port = tool_table["port"].as_integer()->get(); + } + } else { + throw std::runtime_error("Unsupported tool type: " + config.type); + } + } catch (const std::exception& e) { + std::cerr << "Failed to load MCP tool configuration: " << e.what() << std::endl; + throw; + } + + return config; +} + // Initialize static members Config* Config::_instance = nullptr; std::mutex Config::_mutex; diff --git a/config.h b/config.h index 2b70dab..ea70141 100644 --- a/config.h +++ b/config.h @@ -1,6 +1,8 @@ #ifndef HUMANUS_CONFIG_H #define HUMANUS_CONFIG_H +#include "schema.h" +#include "prompt.h" #include #include #include @@ -9,9 +11,6 @@ #include #include -#include "schema.h" -#include "prompt.h" - namespace humanus { // Get project root directory @@ -157,6 +156,19 @@ struct ToolParser { } }; +// Read tool configuration from config_mcp.toml +struct MCPToolConfig { + std::string type; + std::string host; + int port; + std::string url; + std::string command; + std::vector args; + json env_vars = json::object(); + + static MCPToolConfig load_from_toml(const std::string& tool_name); +}; + enum class EmbeddingType { ADD = 0, SEARCH = 1, diff --git a/tool/base.h b/tool/base.h index f18db10..c4e29d6 100644 --- a/tool/base.h +++ b/tool/base.h @@ -1,111 +1,15 @@ #ifndef HUMANUS_TOOL_BASE_H #define HUMANUS_TOOL_BASE_H -#include "toml.hpp" #include "schema.h" #include "agent/base.h" -#include "mcp/include/mcp_client.h" -#include "mcp/include/mcp_stdio_client.h" -#include "mcp/include/mcp_sse_client.h" +#include "config.h" +#include "mcp_stdio_client.h" +#include "mcp_sse_client.h" #include namespace humanus { -// Read tool configuration from config_mcp.toml -struct MCPToolConfig { - std::string type; - std::string host; - int port; - std::string url; - std::string command; - std::vector args; - json env_vars = json::object(); - - static MCPToolConfig load_from_toml(const std::string& tool_name) { - MCPToolConfig config; - - try { - // Get config file path - auto config_path = PROJECT_ROOT / "config" / "config_mcp.toml"; - if (!std::filesystem::exists(config_path)) { - throw std::runtime_error("MCP config file not found: " + config_path.string()); - } - - // Parse TOML file - const auto& data = toml::parse_file(config_path.string()); - - // Check if tool config exists - if (!data.contains(tool_name) || !data[tool_name].is_table()) { - throw std::runtime_error("Tool configuration not found in MCP config file: " + tool_name); - } - - const auto& tool_table = *data[tool_name].as_table(); - - // Read type - if (!tool_table.contains("type") || !tool_table["type"].is_string()) { - throw std::runtime_error("Tool configuration missing type field: " + tool_name); - } - config.type = tool_table["type"].as_string()->get(); - - if (config.type == "stdio") { - // Read command - if (!tool_table.contains("command") || !tool_table["command"].is_string()) { - throw std::runtime_error("stdio type tool configuration missing command field: " + tool_name); - } - config.command = tool_table["command"].as_string()->get(); - - // Read arguments (if any) - if (tool_table.contains("args") && tool_table["args"].is_array()) { - const auto& args_array = *tool_table["args"].as_array(); - for (const auto& arg : args_array) { - if (arg.is_string()) { - config.args.push_back(arg.as_string()->get()); - } - } - } - - // Read environment variables - if (tool_table.contains("env") && tool_table["env"].is_table()) { - const auto& env_table = *tool_table["env"].as_table(); - for (const auto& [key, value] : env_table) { - if (value.is_string()) { - config.env_vars[key] = value.as_string()->get(); - } else if (value.is_integer()) { - config.env_vars[key] = value.as_integer()->get(); - } else if (value.is_floating_point()) { - config.env_vars[key] = value.as_floating_point()->get(); - } else if (value.is_boolean()) { - config.env_vars[key] = value.as_boolean()->get(); - } - } - } - } else if (config.type == "sse") { - // Read host and port or url - if (tool_table.contains("url") && tool_table["url"].is_string()) { - config.url = tool_table["url"].as_string()->get(); - } else { - if (!tool_table.contains("host") || !tool_table["host"].is_string()) { - throw std::runtime_error("sse type tool configuration missing host field: " + tool_name); - } - config.host = tool_table["host"].as_string()->get(); - - if (!tool_table.contains("port") || !tool_table["port"].is_integer()) { - throw std::runtime_error("sse type tool configuration missing port field: " + tool_name); - } - config.port = tool_table["port"].as_integer()->get(); - } - } else { - throw std::runtime_error("Unsupported tool type: " + config.type); - } - } catch (const std::exception& e) { - std::cerr << "Failed to load MCP tool configuration: " << e.what() << std::endl; - throw; - } - - return config; - } -}; - // Represents the result of a tool execution. struct ToolResult { json output; @@ -166,21 +70,45 @@ struct ToolError : ToolResult { ToolError(const json& error) : ToolResult({}, error) {} }; -// Execute the tool with given parameters. struct BaseTool { - inline static std::set special_tool_name = {"terminate", "planning", "fact_extract", "memory_update"}; + std::string name; + std::string description; + json parameters; + BaseTool(const std::string& name, const std::string& description, const json& parameters) : + name(name), description(description), parameters(parameters) {} + + // Execute the tool with given parameters. + ToolResult operator()(const json& arguments) { + return execute(arguments); + } + + // Execute the tool with given parameters. + virtual ToolResult execute(const json& arguments) = 0; + + json to_param() const { + return { + {"type", "function"}, + {"function", { + {"name", name}, + {"description", description}, + {"parameters", parameters} + }} + }; + } +}; + +// Execute the tool with given parameters. +struct BaseMCPTool : BaseTool { std::string name; std::string description; json parameters; std::unique_ptr _client; - BaseTool(const std::string& name, const std::string& description, const json& parameters) : - name(name), description(description), parameters(parameters) { - if (special_tool_name.find(name) != special_tool_name.end()) { - return; - } + BaseMCPTool(const std::string& name, const std::string& description, const json& parameters) + : BaseTool(name, description, parameters) { + // Load tool configuration from config file auto _config = MCPToolConfig::load_from_toml(name); @@ -204,14 +132,8 @@ struct BaseTool { _client->initialize(name + "_client", "0.0.1"); } - - // Execute the tool with given parameters. - ToolResult operator()(const json& arguments) { - return execute(arguments); - } - // Execute the tool with given parameters. - virtual ToolResult execute(const json& arguments) { + ToolResult execute(const json& arguments) override { try { if (!_client) { throw std::runtime_error("MCP client not initialized"); @@ -228,20 +150,9 @@ struct BaseTool { return ToolError(e.what()); } } - - json to_param() const { - return { - {"type", "function"}, - {"function", { - {"name", name}, - {"description", description}, - {"parameters", parameters} - }} - }; - } }; -struct AgentAware : ToolResult { +struct AgentAware : BaseTool { std::shared_ptr agent = nullptr; }; diff --git a/tool/filesystem.h b/tool/filesystem.h index 1b02222..84836a2 100644 --- a/tool/filesystem.h +++ b/tool/filesystem.h @@ -6,7 +6,7 @@ namespace humanus { // https://github.com/modelcontextprotocol/servers/tree/HEAD/src/filesystem -struct Filesystem : BaseTool { +struct Filesystem : BaseMCPTool { inline static const std::string name_ = "filesystem"; inline static const std::string description_ = "## Features\n\n- Read/write files\n- Create/list/delete directories\n- Move files/directories\n- Search files\n- Get file metadata"; inline static const json parameters_ = json::parse(R"json({ @@ -105,7 +105,7 @@ struct Filesystem : BaseTool { "list_allowed_directories" }; - Filesystem() : BaseTool(name_, description_, parameters_) {} + Filesystem() : BaseMCPTool(name_, description_, parameters_) {} ToolResult execute(const json& args) override { try { diff --git a/tool/puppeteer.h b/tool/puppeteer.h index 912dec6..9f8df3f 100644 --- a/tool/puppeteer.h +++ b/tool/puppeteer.h @@ -6,7 +6,7 @@ namespace humanus { // https://github.com/modelcontextprotocol/servers/tree/HEAD/src/puppeteer -struct Puppeteer : BaseTool { +struct Puppeteer : BaseMCPTool { inline static const std::string name_ = "puppeteer"; inline static const std::string description_ = "A Model Context Protocol server that provides browser automation capabilities using Puppeteer."; inline static const json parameters_ = json::parse(R"json({ @@ -69,7 +69,7 @@ struct Puppeteer : BaseTool { "evaluate" }; - Puppeteer() : BaseTool(name_, description_, parameters_) {} + Puppeteer() : BaseMCPTool(name_, description_, parameters_) {} ToolResult execute(const json& args) override { try { diff --git a/tool/python_execute.h b/tool/python_execute.h index 25d384c..743bebc 100644 --- a/tool/python_execute.h +++ b/tool/python_execute.h @@ -6,7 +6,7 @@ namespace humanus { -struct PythonExecute : BaseTool { +struct PythonExecute : BaseMCPTool { inline static const std::string name_ = "python_execute"; inline static const std::string description_ = "Executes Python code string. Note: Only print outputs are visible, function return values are not captured. Use print statements to see results."; inline static const json parameters_ = { @@ -25,7 +25,7 @@ struct PythonExecute : BaseTool { {"required", {"code"}} }; - PythonExecute() : BaseTool(name_, description_, parameters_) {} + PythonExecute() : BaseMCPTool(name_, description_, parameters_) {} }; } diff --git a/tool/shell.h b/tool/shell.h index 144510f..792d9b7 100644 --- a/tool/shell.h +++ b/tool/shell.h @@ -11,7 +11,7 @@ namespace humanus { // A tool for executing shell commands using MCP shell server // https://github.com/kevinwatt/shell-mcp.git -struct Shell : BaseTool { +struct Shell : BaseMCPTool { inline static const std::string name_ = "shell"; inline static const std::string description_ = "Execute a shell command in the terminal."; inline static const json parameters_ = json::parse(R"json({ @@ -78,7 +78,7 @@ struct Shell : BaseTool { "whereis" }; - Shell() : BaseTool(name_, description_, parameters_) {} + Shell() : BaseMCPTool(name_, description_, parameters_) {} ToolResult execute(const json& args) override { try { diff --git a/tool/terminate.h b/tool/terminate.h index f1e061d..2275ef6 100644 --- a/tool/terminate.h +++ b/tool/terminate.h @@ -6,7 +6,7 @@ namespace humanus { -struct Terminate : humanus::BaseTool { +struct Terminate : BaseTool { inline static const std::string name_ = "terminate"; inline static const std::string description_ = "Terminate the interaction when the request is met OR if the assistant cannot proceed further with the task."; inline static const humanus::json parameters_ = {