humanus.cpp/tool/base.h

205 lines
6.8 KiB
C
Raw Normal View History

2025-03-16 17:17:01 +08:00
#ifndef HUMANUS_TOOL_BASE_H
#define HUMANUS_TOOL_BASE_H
#include <cstdint>
#include <filesystem>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "../schema.h"
#include "../agent/base.h"
namespace humanus {
// Execute the tool with given parameters.
struct BaseTool {
std::string name;
std::string description;
json parameters;
std::unique_ptr<mcp::client> client_;
BaseTool(const std::string& name, const std::string& description, const json& parameters) :
name(name), description(description), parameters(parameters) {
// 从配置文件加载工具配置
_config = MCPToolConfig::load_from_toml(name);
if (_config.type == "stdio") {
std::string command = _config.command;
if (!_config.args.empty()) {
for (const auto& arg : _config.args) {
command += " " + arg;
}
}
_client = std::make_unique<mcp::stdio_client>(command, _config.env_vars);
} else if (_config.type == "sse") {
if (!_config.host.empty() && !_config.port.empty()) {
_client = std::make_unique<mcp::sse_client>(_config.host, _config.port);
} else if (!_config.url.empty()) {
_client = std::make_unique<mcp::sse_client>(_config.url);
} else {
throw std::runtime_error("MCP SSE 配置缺少 host 或 port 或 url");
}
}
client_->initialize(name + "_client", "0.0.1");
}
virtual ToolResult execute(const json& arguments) {
try {
if (!_client) {
throw std::runtime_error("MCP 客户端未初始化");
}
json result = _client->tool_call(name, arguments);
bool is_error = result.value("isError", false);
// 根据是否有错误返回不同的ToolResult
if (is_error) {
return ToolError(result.value("content", json::array()));
} else {
return ToolResult(result.value("content", json::array()));
}
} catch (const std::exception& e) {
return ToolError(e.what());
}
}
json to_param() const {
return {
{"type", "function"},
{"function", {
{"name", name},
{"description", description},
{"parameters", parameters}
}}
};
}
};
// Represents the result of a tool execution.
struct ToolResult {
json output;
json error;
json system;
ToolResult(const json& output, const json& error = {}, const json& system = {})
: output(output), error(error), system(system) {}
bool is_null() const {
return output.is_null() && error.is_null() && system.is_null();
}
ToolResult operator+(const ToolResult& other) const {
auto combined_field = [](const json& field, const json& other_field) {
if (field.is_null()) {
return other_field;
}
if (other_field.is_null()) {
return field;
}
json result = json::array();
if (field.is_array()) {
result.insert(result.end(), field.begin(), field.end());
} else {
result.push_back(field);
}
if (other_field.is_array()) {
result.insert(result.end(), other_field.begin(), other_field.end());
} else {
result.push_back(other_field);
}
return result;
};
return {
combined_field(output, other.output),
combined_field(error, other.error),
combined_field(system, other.system)
};
}
std::string to_string() const {
return !error.is_null() ? "Error: " + error.dump() : output.dump();
}
};
// A ToolResult that can be rendered as a CLI output.
struct CLIResult : ToolResult {
};
// A ToolResult that represents a failure.
struct ToolError : ToolResult {
ToolError(const std::string& error) : ToolResult({}, error) {}
};
struct AgentAware : ToolResult {
std::shared_ptr<BaseAgent> agent = nullptr;
};
// 从config_mcp.toml中读取工具配置
struct MCPToolConfig {
std::string command;
std::vector<std::string> args;
json env_vars = json::object();
static MCPToolConfig load_from_toml(const std::string& tool_name) {
MCPToolConfig config;
try {
// 获取配置文件路径
auto config_path = PROJECT_ROOT / "config" / "config_mcp.toml";
if (!std::filesystem::exists(config_path)) {
throw std::runtime_error("找不到MCP配置文件: " + config_path.string());
}
// 解析TOML文件
auto data = toml::parse_file(config_path.string());
// 检查工具配置是否存在
if (!data.contains(tool_name) || !data[tool_name].is_table()) {
throw std::runtime_error("MCP配置文件中找不到工具配置: " + tool_name);
}
auto& tool_table = data[tool_name].as_table();
// 读取命令
if (tool_table.contains("command") && tool_table["command"].is_string()) {
config.command = tool_table["command"].as_string();
} else {
throw std::runtime_error("工具配置缺少command字段: " + tool_name);
}
// 读取参数(如果有)
if (tool_table.contains("args") && tool_table["args"].is_array()) {
auto& args_array = tool_table["args"].as_array();
for (const auto& arg : args_array) {
if (arg.is_string()) {
config.args.push_back(arg.as_string());
}
}
}
// 读取环境变量
std::string env_section = tool_name + ".env";
if (data.contains(env_section) && data[env_section].is_table()) {
auto& env_table = data[env_section].as_table();
for (const auto& [key, value] : env_table) {
if (value.is_string()) {
config.env_vars[key] = value.as_string();
} else if (value.is_integer()) {
config.env_vars[key] = value.as_integer();
} else if (value.is_floating()) {
config.env_vars[key] = value.as_floating();
} else if (value.is_boolean()) {
config.env_vars[key] = value.as_boolean();
}
}
}
} catch (const std::exception& e) {
std::cerr << "加载MCP工具配置失败: " << e.what() << std::endl;
}
return config;
}
};
}
#endif // HUMANUS_TOOL_BASE_H