cleaned a bit; convert Chinese to English

main
hkr04 2025-04-12 21:29:18 +08:00
parent 6c746359aa
commit d7f0f63149
37 changed files with 969 additions and 660 deletions

25
.gitignore vendored
View File

@ -11,8 +11,6 @@
*.gcda
*.gcno
*.gcov
*.gguf
*.gguf.json
*.lastModified
*.log
*.metallib
@ -33,7 +31,6 @@
.vscode/
nppBackup
# Coverage
gcovr-report/
@ -49,16 +46,12 @@ build*
!build-info.sh
!build.zig
!docs/build.md
/libllama.so
/llama-*
/vulkan-shaders-gen
android-ndk-*
arm_neon.h
cmake-build-*
CMakeSettings.json
compile_commands.json
ggml-metal-embed.metal
llama-batched-swift
/rpc-server
out/
tmp/
@ -74,14 +67,9 @@ models/*
models-mnt
!models/.editorconfig
# Zig
# Logs
# Examples
*.log
# Server Web UI temporary files
node_modules
@ -93,14 +81,3 @@ examples/server/webui/dist
__pycache__/
*/poetry.lock
poetry.toml
# Nix
/result
# Test binaries
# Scripts
# Test models for lora adapters
# Local scripts

View File

@ -43,7 +43,7 @@ Start a MCP server with tool `python_execute` on port 8818:
.\build\bin\Release\mcp_server.exe # Windows
```
Run agent `Humanus` with tools `python_execute`, `filesystem` and `puppeteer` (for browser use):
Run agent `humanus` with tools `python_execute`, `filesystem` and `playwright` (for browser use):
```bash
./build/bin/humanus_cli # Unix/MacOS
@ -53,7 +53,7 @@ Run agent `Humanus` with tools `python_execute`, `filesystem` and `puppeteer` (f
.\build\bin\Release\humanus_cli.exe # Windows
```
Run experimental planning flow (only agent `Humanus` as executor):
Run experimental planning flow (only agent `humanus` as executor):
```bash
./build/bin/humanus_cli_plan # Unix/MacOS
```
@ -65,6 +65,16 @@ Run experimental planning flow (only agent `Humanus` as executor):
## Acknowledgement
<p align="center">
<img src="assets/whu.png" height="150"/>
<img src="assets/myth.png" height="150"/>
<img src="assets/whu.png" height="180"/>
<img src="assets/myth.png" height="180"/>
</p>
## Cite
```
@misc{humanuscpp,
author = {Zihong Zhang and Zuchao Li},
title = {humanus.cpp: A Lightweight C++ Framework for Local LLM Agents},
year = {2025}
}
```

View File

@ -64,7 +64,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
llm = LLM::get_instance("default");
}
if (!memory) {
memory = std::make_shared<Memory>(MemoryConfig());
memory = std::make_shared<Memory>(Config::get_memory_config("default"));
}
reset(true);
}

222
agent/humanus.cpp 100644
View File

@ -0,0 +1,222 @@
#include "humanus.h"
namespace humanus {
std::string Humanus::run(const std::string& request) {
memory->current_request = request;
auto tmp_next_step_prompt = next_step_prompt;
size_t pos = next_step_prompt.find("{current_date}");
if (pos != std::string::npos) {
// %Y-%m-%d
auto current_date = std::chrono::system_clock::now();
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
std::stringstream ss;
std::tm tm_info = *std::localtime(&in_time_t);
ss << std::put_time(&tm_info, "%Y-%m-%d");
std::string formatted_date = ss.str(); // YYYY-MM-DD
next_step_prompt.replace(pos, 14, formatted_date);
}
pos = next_step_prompt.find("{current_request}");
if (pos != std::string::npos) {
next_step_prompt.replace(pos, 17, request);
}
auto result = BaseAgent::run(request);
next_step_prompt = tmp_next_step_prompt; // restore the original prompt
return result;
}
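
For orientation, here is a sketch of what a `next_step_prompt` template containing these placeholders could look like. It is purely illustrative; the actual template lives in `prompt::humanus::NEXT_STEP_PROMPT` and may differ.

```cpp
#include <string>

// Hypothetical template (not the real prompt::humanus::NEXT_STEP_PROMPT):
// Humanus::run() substitutes the first occurrence of each placeholder before delegating to BaseAgent::run().
const std::string example_next_step_prompt =
    "Current date: {current_date}\n"
    "Current request: {current_request}\n"
    "Decide the next step: call a tool, or terminate once the request is fulfilled.";
```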
Humanus Humanus::load_from_toml(const toml::table& config_table) {
try {
// Tools
std::vector<std::string> tools;
std::vector<std::string> mcp_servers;
if (config_table.contains("tools")) {
auto tools_table = *config_table["tools"].as_array();
for (const auto& tool : tools_table) {
tools.push_back(tool.as_string()->get());
}
} else {
tools = {"python_execute", "filesystem", "playwright", "image_loader", "content_provider", "terminate"};
}
if (config_table.contains("mcp_servers")) {
auto mcp_servers_table = *config_table["mcp_servers"].as_array();
for (const auto& server : mcp_servers_table) {
mcp_servers.push_back(server.as_string()->get());
}
}
ToolCollection available_tools;
for (const auto& tool : tools) {
auto tool_ptr = ToolFactory::create(tool);
if (tool_ptr) {
available_tools.add_tool(tool_ptr);
} else {
logger->warn("Tool `" + tool + "` not found in tool registry, skipping...");
}
}
for (const auto& mcp_server : mcp_servers) {
available_tools.add_mcp_tools(mcp_server);
}
// General settings
std::string name, description, system_prompt, next_step_prompt;
if (config_table.contains("name")) {
name = config_table["name"].as_string()->get();
} else {
name = "humanus";
}
if (config_table.contains("description")) {
description = config_table["description"].as_string()->get();
} else {
description = "A versatile agent that can solve various tasks using multiple tools";
}
if (config_table.contains("system_prompt")) {
system_prompt = config_table["system_prompt"].as_string()->get();
} else {
system_prompt = prompt::humanus::SYSTEM_PROMPT;
}
if (config_table.contains("next_step_prompt")) {
next_step_prompt = config_table["next_step_prompt"].as_string()->get();
} else {
next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT;
}
// Workflow settings
std::shared_ptr<LLM> llm = nullptr;
if (config_table.contains("llm")) {
llm = LLM::get_instance(config_table["llm"].as_string()->get());
}
std::shared_ptr<Memory> memory = nullptr;
if (config_table.contains("memory")) {
memory = std::make_shared<Memory>(Config::get_memory_config(config_table["memory"].as_string()->get()));
}
int max_steps = 30;
if (config_table.contains("max_steps")) {
max_steps = config_table["max_steps"].as_integer()->get();
}
int duplicate_threshold = 2;
if (config_table.contains("duplicate_threshold")) {
duplicate_threshold = config_table["duplicate_threshold"].as_integer()->get();
}
return Humanus(
available_tools,
name,
description,
system_prompt,
next_step_prompt,
llm,
memory,
max_steps,
duplicate_threshold
);
} catch (const std::exception& e) {
logger->error("Error loading Humanus from TOML: " + std::string(e.what()));
throw;
}
}
Humanus Humanus::load_from_json(const json& config_json) {
try {
// Tools
std::vector<std::string> tools;
std::vector<std::string> mcp_servers;
if (config_json.contains("tools")) {
tools = config_json["tools"].get<std::vector<std::string>>();
}
if (config_json.contains("mcp_servers")) {
mcp_servers = config_json["mcp_servers"].get<std::vector<std::string>>();
}
ToolCollection available_tools;
for (const auto& tool : tools) {
auto tool_ptr = ToolFactory::create(tool);
if (tool_ptr) {
available_tools.add_tool(tool_ptr);
}
}
for (const auto& mcp_server : mcp_servers) {
available_tools.add_mcp_tools(mcp_server);
}
// General settings
std::string name, description, system_prompt, next_step_prompt;
if (config_json.contains("name")) {
name = config_json["name"].get<std::string>();
} else {
name = "humanus";
}
if (config_json.contains("description")) {
description = config_json["description"].get<std::string>();
} else {
description = "A versatile agent that can solve various tasks using multiple tools";
}
if (config_json.contains("system_prompt")) {
system_prompt = config_json["system_prompt"].get<std::string>();
} else {
system_prompt = prompt::humanus::SYSTEM_PROMPT;
}
if (config_json.contains("next_step_prompt")) {
next_step_prompt = config_json["next_step_prompt"].get<std::string>();
} else {
next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT;
}
// Workflow settings
std::shared_ptr<LLM> llm = nullptr;
if (config_json.contains("llm")) {
llm = LLM::get_instance(config_json["llm"].get<std::string>());
}
std::shared_ptr<Memory> memory = nullptr;
if (config_json.contains("memory")) {
memory = std::make_shared<Memory>(Config::get_memory_config(config_json["memory"].get<std::string>()));
}
int max_steps = 30;
if (config_json.contains("max_steps")) {
max_steps = config_json["max_steps"].get<int>();
}
int duplicate_threshold = 2;
if (config_json.contains("duplicate_threshold")) {
duplicate_threshold = config_json["duplicate_threshold"].get<int>();
}
return Humanus(
available_tools,
name,
description,
system_prompt,
next_step_prompt,
llm,
memory,
max_steps,
duplicate_threshold
);
} catch (const std::exception& e) {
logger->error("Error loading Humanus from JSON: " + std::string(e.what()));
throw;
}
}
}

View File

@ -10,6 +10,7 @@
#include "tool/playwright.h"
#include "tool/filesystem.h"
#include "tool/image_loader.h"
namespace humanus {
/**
@ -31,8 +32,6 @@ struct Humanus : ToolCallAgent {
std::make_shared<Terminate>()
}
),
const std::string& tool_choice = "auto",
const std::set<std::string>& special_tool_names = {"terminate"},
const std::string& name = "humanus",
const std::string& description = "A versatile agent that can solve various tasks using multiple tools",
const std::string& system_prompt = prompt::humanus::SYSTEM_PROMPT,
@ -43,8 +42,8 @@ struct Humanus : ToolCallAgent {
int duplicate_threshold = 2
) : ToolCallAgent(
available_tools,
tool_choice,
special_tool_names,
"auto", // tool_choice
{"terminate"}, // special_tool_names
name,
description,
system_prompt,
@ -55,33 +54,11 @@ struct Humanus : ToolCallAgent {
duplicate_threshold
) {}
std::string run(const std::string& request = "") override {
memory->current_request = request;
std::string run(const std::string& request = "") override;
auto tmp_next_step_prompt = next_step_prompt;
static Humanus load_from_toml(const toml::table& config_table);
size_t pos = next_step_prompt.find("{current_date}");
if (pos != std::string::npos) {
// %Y-%d-%m
auto current_date = std::chrono::system_clock::now();
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
std::stringstream ss;
std::tm tm_info = *std::localtime(&in_time_t);
ss << std::put_time(&tm_info, "%Y-%m-%d");
std::string formatted_date = ss.str(); // YYYY-MM-DD
next_step_prompt.replace(pos, 14, formatted_date);
}
pos = next_step_prompt.find("{current_request}");
if (pos != std::string::npos) {
next_step_prompt.replace(pos, 17, request);
}
auto result = BaseAgent::run(request);
next_step_prompt = tmp_next_step_prompt; // restore the original prompt
return result;
}
static Humanus load_from_json(const json& config_json);
};
}

View File

@ -78,12 +78,9 @@ std::string ToolCallAgent::act() {
std::string result_str;
for (const auto& tool_call : tool_calls) {
if (state != AgentState::RUNNING) {
result_str += "Agent is not running, so no more tool calls will be executed.\n\n";
break;
}
auto result = execute_tool(tool_call);
auto result = state == AgentState::RUNNING ?
execute_tool(tool_call) :
ToolError("Agent is not running, so no more tool calls will be executed.");
logger->info(
"🎯 Tool `" + tool_call.function.name + "` completed its mission! Result: " + result.to_string(500)

View File

@ -0,0 +1,7 @@
[humanus_cli]
llm = "qwen-max-latest"
memory = "long-context"
tools = ["filesystem", "playwright", "image_loader"]
mcp_servers = ["python_execute"]
max_steps = 30
duplicate_threshold = 2

View File

@ -1,4 +1,4 @@
[nomic-embed-text-v1.5]
["nomic-embed-text-v1.5"]
provider = "oai"
base_url = "http://localhost:8080"
endpoint = "/v1/embeddings"
@ -7,7 +7,7 @@ api_key = ""
embeddings_dim = 768
max_retries = 3
[default]
[qwen-text-embedding-v3]
provider = "oai"
base_url = "https://dashscope.aliyuncs.com"
endpoint = "/compatible-mode/v1/embeddings"

View File

@ -1,29 +1,30 @@
[memory]
model = "qwen-max"
[qwen-max-latest]
model = "qwen-max-latest"
base_url = "https://dashscope.aliyuncs.com"
endpoint = "/compatible-mode/v1/chat/completions"
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
[glm-4-plus]
model = "glm-4-plus"
base_url = "https://open.bigmodel.cn"
endpoint = "/api/paas/v4/chat/completions"
api_key = "7e12e1cb8fe5786d83c74d2ef48db511.xPVWzEZt8RvIciW9"
[default]
[qwen-vl-max-latest]
model = "qwen-vl-max-latest"
base_url = "https://dashscope.aliyuncs.com"
endpoint = "/compatible-mode/v1/chat/completions"
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
enable_vision = true
[claude-3.5-sonnet]
["claude-3.5-sonnet"]
model = "anthropic/claude-3.5-sonnet"
base_url = "https://openrouter.ai"
endpoint = "/api/v1/chat/completions"
api_key = "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"
enable_vision = true
["claude-3.7-sonnet"]
model = "anthropic/claude-3.7-sonnet"
base_url = "https://openrouter.ai"
endpoint = "/api/v1/chat/completions"
api_key = "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"
enable_vision = true
[deepseek-chat]
model = "deepseek-chat"
base_url = "https://api.deepseek.com"

View File

@ -1,7 +1,7 @@
[python_execute]
type = "sse"
host = "localhost"
port = 8896
port = 8895
sse_endpoint = "/sse"
[puppeteer]
@ -20,8 +20,3 @@ command = "npx"
args = ["-y",
"@modelcontextprotocol/server-filesystem",
"/Users/hyde/Desktop"]
[shell]
type = "stdio"
command = "npx"
args = ["-y", "@kevinwatt/shell-mcp"]

View File

@ -1,6 +1,19 @@
[memory]
max_messages = 32
[default]
max_messages = 16
max_tokens_message = 32768
max_tokens_messages = 65536
max_tokens_context = 131072
retrieval_limit = 32
embedding_model = "qwen-text-embedding-v3"
vector_store = "hnswlib"
llm = "qwen-max-latest"
[long-context]
max_messages = 32
max_tokens_message = 64000
max_tokens_messages = 128000
max_tokens_context = 128000
retrieval_limit = 32
embedding_model = "qwen-text-embedding-v3"
vector_store = "hnswlib"
llm = "qwen-max-latest"

View File

@ -1,4 +1,4 @@
[default]
[hnswlib]
provider = "hnswlib"
dim = 768 # Dimension of the elements
max_elements = 100 # Maximum number of elements, should be known beforehand

View File

@ -50,14 +50,12 @@ int main() {
#endif
}
auto memory = std::make_shared<Memory>(MemoryConfig());
Chatbot chatbot{
"chatbot", // name
"A chatbot agent that uses memory to remember conversation history", // description
"You are a helpful assistant.", // system_prompt
nullptr, // llm
memory // memory
LLM::get_instance("chatbot"), // llm
std::make_shared<Memory>(Config::get_memory_config("chatbot")) // memory
};
while (true) {

View File

@ -1,7 +1,6 @@
#include "agent/humanus.h"
#include "logger.h"
#include "prompt.h"
#include "flow/flow_factory.h"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
@ -49,7 +48,15 @@ int main() {
#endif
}
Humanus agent = Humanus();
const auto& config_data = toml::parse_file((PROJECT_ROOT / "config" / "config.toml").string());
if (!config_data.contains("humanus_cli")) {
throw std::runtime_error("humanus_cli section not found in config.toml");
}
const auto& config_table = *config_data["humanus_cli"].as_table();
Humanus agent = Humanus::load_from_toml(config_table);
while (true) {
if (agent.current_step == agent.max_steps) {

View File

@ -1,5 +0,0 @@
set(target humanus_cli_mcp)
add_executable(${target} humanus_mcp.cpp)
target_link_libraries(${target} PRIVATE humanus)

View File

@ -1,85 +0,0 @@
#include "agent/mcp.h"
#include "logger.h"
#include "prompt.h"
#include "flow/flow_factory.h"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif
using namespace humanus;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
// make sure all logs are flushed
logger->info("Interrupted by user\n");
logger->flush();
_exit(130);
}
}
#endif
int main(int argc, char* argv[]) {
// ctrl+C handling
{
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
SetConsoleCP(CP_UTF8);
SetConsoleOutputCP(CP_UTF8);
_setmode(_fileno(stdin), _O_WTEXT); // wide character input mode
#endif
}
if (argc <= 1) {
std::cout << "Usage: " << argv[0] << " <mcp_server1> <mcp_server2>..." << std::endl;
return 0;
}
std::vector<std::string> mcp_servers;
for (int i = 1; i < argc; i++) {
mcp_servers.emplace_back(argv[i]);
}
MCPAgent agent = MCPAgent(
mcp_servers
);
while (true) {
if (agent.current_step == agent.max_steps) {
std::cout << "Automatically paused after " << agent.max_steps << " steps." << std::endl;
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
agent.reset(false);
} else {
std::cout << "Enter your prompt (or 'exit' to quit): ";
}
std::string prompt;
readline_utf8(prompt, false);
if (prompt == "exit") {
logger->info("Goodbye!");
break;
}
logger->info("Processing your request: " + prompt);
agent.run(prompt);
}
}

View File

@ -49,9 +49,18 @@ int main() {
#endif
}
std::shared_ptr<BaseAgent> agent_ptr = std::make_shared<Humanus>();
const auto& config_data = toml::parse_file((PROJECT_ROOT / "config" / "config.toml").string());
if (!config_data.contains("humanus_plan")) {
throw std::runtime_error("humanus_plan section not found in config.toml");
}
const auto& config_table = *config_data["humanus_plan"].as_table();
Humanus agent = Humanus::load_from_toml(config_table);
std::map<std::string, std::shared_ptr<BaseAgent>> agents;
agents["default"] = agent_ptr;
agents["default"] = std::make_shared<Humanus>(agent);
auto flow = FlowFactory::create_flow(
FlowType::PLANNING,
@ -65,13 +74,13 @@ int main() {
);
while (true) {
if (agent_ptr->current_step == agent_ptr->max_steps) {
std::cout << "Automatically paused after " << agent_ptr->current_step << " steps." << std::endl;
if (agent.current_step == agent.max_steps) {
std::cout << "Automatically paused after " << agent.current_step << " steps." << std::endl;
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
agent_ptr->reset(false);
} else if (agent_ptr->state != AgentState::IDLE) {
agent.reset(false);
} else if (agent.state != AgentState::IDLE) {
std::cout << "Enter your prompt (enter an empty line to retry or 'exit' to quit): ";
agent_ptr->reset(false);
agent.reset(false);
} else {
std::cout << "Enter your prompt (or 'exit' to quit): ";
}
@ -85,6 +94,6 @@ int main() {
logger->info("Processing your request: " + prompt);
auto result = flow->execute(prompt);
logger->info("🌟 " + agent_ptr->name + "'s summary: " + result);
logger->info("🌟 " + agent.name + "'s summary: " + result);
}
}

View File

@ -35,6 +35,11 @@ public:
return agent;
}
void set_agent(const std::string& session_id, const std::shared_ptr<Humanus>& agent) {
std::lock_guard<std::mutex> lock(mutex_);
agents_[session_id] = agent;
}
static std::vector<std::string> get_logs_buffer(const std::string& session_id) {
return session_sink->get_buffer(session_id);
}
@ -107,7 +112,7 @@ int main(int argc, char** argv) {
}
// Create and configure server
mcp::server server("localhost", port, "HumanusServer", "0.0.1");
mcp::server server("localhost", port, "humanus_server", "0.1.0");
// Set server capabilities
mcp::json capabilities = {
@ -117,6 +122,33 @@ int main(int argc, char** argv) {
auto session_manager = std::make_shared<SessionManager>();
auto initialize_tool = mcp::tool_builder("humanus_initialize")
.with_description("Initialize the agent")
.with_string_param("llm", "The LLM configuration to use. Default: default")
.with_string_param("memory", "The memory configuration to use. Default: default")
.with_array_param("tools", "The tools of the agent. Default: filesystem, playwright (for browser use), image_loader, content_provider, terminate", "string")
.with_array_param("mcp_servers", "The MCP servers offering tools for the agent. Default: python_execute", "string")
.with_number_param("max_steps", "The maximum steps of the agent. Default: 30")
.with_number_param("duplicate_threshold", "The duplicate threshold of the agent. Default: 2")
.build();
server.register_tool(initialize_tool, [session_manager](const json& args, const std::string& session_id) -> json {
if (session_manager->has_session(session_id)) {
throw mcp::mcp_exception(mcp::error_code::invalid_request, "Session already initialized");
}
try {
session_manager->set_agent(session_id, std::make_shared<Humanus>(Humanus::load_from_json(args)));
} catch (const std::exception& e) {
throw mcp::mcp_exception(mcp::error_code::invalid_params, "Invalid agent configuration: " + std::string(e.what()));
}
return {{
{"type", "text"},
{"text", "Agent initialized."}
}};
});
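
For clarity, a sketch of the argument object a client could send when calling `humanus_initialize`. The field names mirror the parameters registered above and are forwarded to `Humanus::load_from_json`; the values follow the defaults stated in the parameter descriptions and are illustrative only.

```cpp
// Illustrative arguments for `humanus_initialize`; `json` is the alias used in this file.
// The values follow the defaults documented in the parameter descriptions above.
json init_args = {
    {"llm", "default"},
    {"memory", "default"},
    {"tools", {"filesystem", "playwright", "image_loader", "content_provider", "terminate"}},
    {"mcp_servers", {"python_execute"}},
    {"max_steps", 30},
    {"duplicate_threshold", 2}
};
```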
auto run_tool = mcp::tool_builder("humanus_run")
.with_description("Request to start a new task. Best to give clear and concise prompts.")
.with_string_param("prompt", "The prompt text to process", true)

View File

@ -11,47 +11,11 @@
#include <fstream>
#include <sstream>
#include <memory>
#include <mutex>
#include <shared_mutex>
namespace humanus {
struct LLMConfig {
std::string model;
std::string api_key;
std::string base_url;
std::string endpoint;
std::string vision_details;
int max_tokens;
int timeout;
double temperature;
bool enable_vision;
bool enable_tool;
LLMConfig(
const std::string& model = "deepseek-chat",
const std::string& api_key = "sk-",
const std::string& base_url = "https://api.deepseek.com",
const std::string& endpoint = "/v1/chat/completions",
const std::string& vision_details = "auto",
int max_tokens = -1, // -1 for default
int timeout = 120,
double temperature = -1, // -1 for default
bool enable_vision = false,
bool enable_tool = true
) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
max_tokens(max_tokens), timeout(timeout), temperature(temperature), enable_vision(enable_vision), enable_tool(enable_tool) {}
json to_json() const {
json j;
j["model"] = model;
j["api_key"] = api_key;
j["base_url"] = base_url;
j["endpoint"] = endpoint;
j["max_tokens"] = max_tokens;
j["temperature"] = temperature;
return j;
}
};
struct ToolParser {
std::string tool_start;
std::string tool_end;
@ -60,11 +24,6 @@ struct ToolParser {
ToolParser(const std::string& tool_start = "<tool_call>", const std::string& tool_end = "</tool_call>", const std::string& tool_hint_template = prompt::toolcall::TOOL_HINT_TEMPLATE)
: tool_start(tool_start), tool_end(tool_end), tool_hint_template(tool_hint_template) {}
static ToolParser get_instance() {
static ToolParser instance;
return instance;
}
static std::string str_replace(std::string& str, const std::string& from, const std::string& to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
@ -150,6 +109,38 @@ struct ToolParser {
}
};
struct LLMConfig {
std::string model;
std::string api_key;
std::string base_url;
std::string endpoint;
std::string vision_details;
int max_tokens;
int timeout;
double temperature;
bool enable_vision;
bool enable_tool;
ToolParser tool_parser;
LLMConfig(
const std::string& model = "deepseek-chat",
const std::string& api_key = "sk-",
const std::string& base_url = "https://api.deepseek.com",
const std::string& endpoint = "/v1/chat/completions",
const std::string& vision_details = "auto",
int max_tokens = -1, // -1 for default
int timeout = 120,
double temperature = -1, // -1 for default
bool enable_vision = false,
bool enable_tool = true,
const ToolParser& tool_parser = ToolParser()
) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
max_tokens(max_tokens), timeout(timeout), temperature(temperature), enable_vision(enable_vision), enable_tool(enable_tool), tool_parser(tool_parser) {}
static LLMConfig load_from_toml(const toml::table& config_table);
};
// Read tool configuration from config_mcp.toml
struct MCPServerConfig {
std::string type;
@ -160,7 +151,7 @@ struct MCPServerConfig {
std::vector<std::string> args;
json env_vars = json::object();
static MCPServerConfig load_from_toml(const toml::table& tool_table);
static MCPServerConfig load_from_toml(const toml::table& config_table);
};
enum class EmbeddingType {
@ -177,6 +168,8 @@ struct EmbeddingModelConfig {
std::string api_key = "";
int embedding_dims = 768;
int max_retries = 3;
static EmbeddingModelConfig load_from_toml(const toml::table& config_table);
};
struct VectorStoreConfig {
@ -191,50 +184,56 @@ struct VectorStoreConfig {
IP
};
Metric metric = Metric::L2;
static VectorStoreConfig load_from_toml(const toml::table& config_table);
};
struct MemoryConfig {
// Base config
int max_messages = 32; // Maximum number of messages in short-term memory
int max_messages = 16; // Maximum number of messages in short-term memory
int max_tokens_message = 1 << 15; // Maximum number of tokens in a single message
int max_tokens_messages = 1 << 19; // Maximum number of tokens in short-term memory
int max_tokens_context = 1 << 20; // Maximum number of tokens in short-term memory
int retrieval_limit = 32; // Number of results to retrieve from long-term memory
int max_tokens_messages = 1 << 16; // Maximum number of tokens in short-term memory
int max_tokens_context = 1 << 17; // Maximum number of tokens in context (used by `get_messages`)
int retrieval_limit = 32; // Maximum number of results to retrieve from long-term memory
// Prompt config
std::string fact_extraction_prompt = prompt::FACT_EXTRACTION_PROMPT;
std::string update_memory_prompt = prompt::UPDATE_MEMORY_PROMPT;
// EmbeddingModel config
std::string embedding_model = "default";
std::shared_ptr<EmbeddingModelConfig> embedding_model_config = nullptr;
// Vector store config
std::string vector_store = "default";
std::shared_ptr<VectorStoreConfig> vector_store_config = nullptr;
FilterFunc filter = nullptr; // Filter to apply to search results
// LLM config
std::string llm = "default";
std::shared_ptr<LLMConfig> llm_config = nullptr;
};
std::string llm_vision = "vision_default";
std::shared_ptr<LLMConfig> llm_vision_config = nullptr;
struct AppConfig {
std::unordered_map<std::string, LLMConfig> llm;
std::unordered_map<std::string, MCPServerConfig> mcp_server;
std::unordered_map<std::string, ToolParser> tool_parser;
std::unordered_map<std::string, EmbeddingModelConfig> embedding_model;
std::unordered_map<std::string, VectorStoreConfig> vector_store;
static MemoryConfig load_from_toml(const toml::table& config_table);
};
class Config {
private:
static Config* _instance;
static std::shared_mutex _config_mutex;
bool _initialized = false;
AppConfig _config;
std::unordered_map<std::string, LLMConfig> llm;
std::unordered_map<std::string, MCPServerConfig> mcp_server;
std::unordered_map<std::string, MemoryConfig> memory;
std::unordered_map<std::string, EmbeddingModelConfig> embedding_model;
std::unordered_map<std::string, VectorStoreConfig> vector_store;
Config() {
_load_initial_llm_config();
_load_initial_mcp_server_config();
_load_initial_embedding_model_config();
_load_initial_vector_store_config();
_load_llm_config();
_load_mcp_server_config();
_load_memory_config();
_load_embedding_model_config();
_load_vector_store_config();
_initialized = true;
}
@ -259,6 +258,15 @@ private:
throw std::runtime_error("MCP Tool Config file not found");
}
static std::filesystem::path _get_memory_config_path() {
auto root = PROJECT_ROOT;
auto config_path = root / "config" / "config_mem.toml";
if (std::filesystem::exists(config_path)) {
return config_path;
}
throw std::runtime_error("Memory Config file not found");
}
static std::filesystem::path _get_embedding_model_config_path() {
auto root = PROJECT_ROOT;
auto config_path = root / "config" / "config_embd.toml";
@ -277,13 +285,15 @@ private:
throw std::runtime_error("Vector Store Config file not found");
}
void _load_initial_llm_config();
void _load_llm_config();
void _load_initial_mcp_server_config();
void _load_mcp_server_config();
void _load_initial_embedding_model_config();
void _load_memory_config();
void _load_initial_vector_store_config();
void _load_embedding_model_config();
void _load_vector_store_config();
public:
/**
@ -291,59 +301,19 @@ public:
* @return The config instance
*/
static Config& get_instance() {
if (_instance == nullptr) {
_instance = new Config();
}
return *_instance;
static Config instance;
return instance;
}
/**
* @brief Get the LLM settings
* @return The LLM settings map
*/
const std::unordered_map<std::string, LLMConfig>& llm() const {
return _config.llm;
}
static LLMConfig get_llm_config(const std::string& config_name);
/**
* @brief Get the MCP tool settings
* @return The MCP tool settings map
*/
const std::unordered_map<std::string, MCPServerConfig>& mcp_server() const {
return _config.mcp_server;
}
static MCPServerConfig get_mcp_server_config(const std::string& config_name);
/**
* @brief Get the tool helpers
* @return The tool helpers map
*/
const std::unordered_map<std::string, ToolParser>& tool_parser() const {
return _config.tool_parser;
}
static MemoryConfig get_memory_config(const std::string& config_name);
/**
* @brief Get the embedding model settings
* @return The embedding model settings map
*/
const std::unordered_map<std::string, EmbeddingModelConfig>& embedding_model() const {
return _config.embedding_model;
}
static EmbeddingModelConfig get_embedding_model_config(const std::string& config_name);
/**
* @brief Get the vector store settings
* @return The vector store settings map
*/
const std::unordered_map<std::string, VectorStoreConfig>& vector_store() const {
return _config.vector_store;
}
/**
* @brief Get the app config
* @return The app config
*/
const AppConfig& get_config() const {
return _config;
}
static VectorStoreConfig get_vector_store_config(const std::string& config_name);
};
} // namespace humanus

View File

@ -23,22 +23,12 @@ private:
std::shared_ptr<LLMConfig> llm_config_;
std::shared_ptr<ToolParser> tool_parser_;
size_t total_prompt_tokens_;
size_t total_completion_tokens_;
public:
// Constructor
LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& config = nullptr, const std::shared_ptr<ToolParser>& tool_parser = nullptr) : llm_config_(config), tool_parser_(tool_parser) {
if (!llm_config_->enable_tool && !tool_parser_) {
if (Config::get_instance().tool_parser().find(config_name) == Config::get_instance().tool_parser().end()) {
logger->warn("Tool helper config not found: " + config_name + ", falling back to default tool helper config.");
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at("default"));
} else {
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at(config_name));
}
}
LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& config = nullptr) : llm_config_(config) {
client_ = std::make_unique<httplib::Client>(llm_config_->base_url);
client_->set_default_headers({
{"Authorization", "Bearer " + llm_config_->api_key}
@ -53,12 +43,7 @@ public:
if (instances_.find(config_name) == instances_.end()) {
auto llm_config_ = llm_config;
if (!llm_config_) {
if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
logger->warn("LLM config not found: " + config_name + ", falling back to default LLM config.");
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at("default"));
} else {
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
}
llm_config_ = std::make_shared<LLMConfig>(Config::get_llm_config(config_name));
}
instances_[config_name] = std::make_shared<LLM>(config_name, llm_config_);
}

View File

@ -74,6 +74,7 @@ struct Memory : BaseMemory {
std::shared_ptr<EmbeddingModel> embedding_model;
std::shared_ptr<VectorStore> vector_store;
std::shared_ptr<LLM> llm;
std::shared_ptr<LLM> llm_vision;
std::shared_ptr<FactExtract> fact_extract_tool;
@ -104,9 +105,10 @@ struct Memory : BaseMemory {
}
try {
embedding_model = EmbeddingModel::get_instance("default", config.embedding_model_config);
vector_store = VectorStore::get_instance("default", config.vector_store_config);
llm = LLM::get_instance("memory", config.llm_config);
embedding_model = EmbeddingModel::get_instance(config.embedding_model, config.embedding_model_config);
vector_store = VectorStore::get_instance(config.vector_store, config.vector_store_config);
llm = LLM::get_instance(config.llm, config.llm_config);
llm_vision = LLM::get_instance(config.llm_vision, config.llm_vision_config);
logger->info("🔥 Memory is warming up...");
auto test_response = llm->ask(
@ -123,9 +125,14 @@ struct Memory : BaseMemory {
embedding_model = nullptr;
vector_store = nullptr;
llm = nullptr;
llm_vision = nullptr;
retrieval_enabled = false;
}
if (llm_vision && llm_vision->enable_vision() == false) { // Make sure it can handle vision messages
llm_vision = nullptr;
}
fact_extract_tool = std::make_shared<FactExtract>();
}
@ -153,11 +160,11 @@ struct Memory : BaseMemory {
}
}
if (retrieval_enabled && !messages_to_memory.empty()) {
if (llm->enable_vision()) {
if (llm_vision) { // TODO: configure to use multimodal embedding model instead of converting to text
for (auto& m : messages_to_memory) {
m = parse_vision_message(m, llm, llm->vision_details());
m = parse_vision_message(m, llm_vision, llm_vision->vision_details());
}
} else {
} else { // Convert to a placeholder message indicating a vision message (no description is generated)
for (auto& m : messages_to_memory) {
m = parse_vision_message(m);
}

View File

@ -9,12 +9,7 @@ std::shared_ptr<EmbeddingModel> EmbeddingModel::get_instance(const std::string&
if (instances_.find(config_name) == instances_.end()) {
auto config_ = config;
if (!config_) {
if (Config::get_instance().embedding_model().find(config_name) == Config::get_instance().embedding_model().end()) {
logger->warn("Embedding model config not found: " + config_name + ", falling back to default config");
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at("default"));
} else {
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at(config_name));
}
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_embedding_model_config(config_name));
}
if (config_->provider == "oai") {

View File

@ -9,12 +9,7 @@ std::shared_ptr<VectorStore> VectorStore::get_instance(const std::string& config
if (instances_.find(config_name) == instances_.end()) {
auto config_ = config;
if (!config_) {
if (Config::get_instance().vector_store().find(config_name) == Config::get_instance().vector_store().end()) {
logger->warn("Vector store config not found: " + config_name + ", falling back to default config");
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at("default"));
} else {
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at(config_name));
}
config_ = std::make_shared<VectorStoreConfig>(Config::get_vector_store_config(config_name));
}
if (config_->provider == "hnswlib") {

View File

@ -27,59 +27,59 @@ public:
virtual void reset() = 0;
/**
* @brief
* @param vector
* @param vector_id ID
* @param metadata
* @brief Insert a vector with metadata
* @param vector vector data
* @param vector_id vector ID
* @param metadata metadata
*/
virtual void insert(const std::vector<float>& vector,
const size_t vector_id,
const MemoryItem& metadata = MemoryItem()) = 0;
/**
* @brief
* @param query
* @param limit
* @param filter
* @return MemoryItem
* @brief Search similar vectors
* @param query query vector
* @param limit limit of returned results
* @param filter optional filter (returns true for allowed vectors)
* @return list of similar vectors
*/
virtual std::vector<MemoryItem> search(const std::vector<float>& query,
size_t limit = 5,
const FilterFunc& filter = nullptr) = 0;
/**
* @brief ID
* @param vector_id ID
* @brief Remove a vector by ID
* @param vector_id vector ID
*/
virtual void remove(size_t vector_id) = 0;
/**
* @brief
* @param vector_id ID
* @param vector
* @param metadata
* @brief Update a vector and its metadata
* @param vector_id vector ID
* @param vector new vector data
* @param metadata new metadata
*/
virtual void update(size_t vector_id, const std::vector<float>& vector = std::vector<float>(), const MemoryItem& metadata = MemoryItem()) = 0;
/**
* @brief ID
* @param vector_id ID
* @return
* @brief Get a vector by ID
* @param vector_id vector ID
* @return vector data
*/
virtual MemoryItem get(size_t vector_id) = 0;
/**
* @brief Set metadata for a vector
* @param vector_id Vector ID
* @param metadata New metadata
* @param vector_id vector ID
* @param metadata new metadata
*/
virtual void set(size_t vector_id, const MemoryItem& metadata) = 0;
/**
* @brief
* @param limit
* @param filter isIDAllowed
* @return ID
* @brief List all memories
* @param limit optional limit of returned results
* @param filter optional filter (returns true for allowed memories)
* @return list of memories
*/
virtual std::vector<MemoryItem> list(size_t limit = 0, const FilterFunc& filter = nullptr) = 0;
};
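
A minimal usage sketch of this interface, assuming the HNSWLib-backed store and the 768-dimensional `[hnswlib]` section from config_vec.toml shown earlier; the vector values are placeholders, not repository code.

```cpp
// Illustrative only: exercise the VectorStore interface through a named config.
auto store = VectorStore::get_instance("hnswlib");          // resolves Config::get_vector_store_config("hnswlib")
std::vector<float> embedding(768, 0.0f);                    // dimension must match the configured `dim`
store->insert(embedding, /*vector_id=*/0, MemoryItem());
auto hits = store->search(embedding, /*limit=*/5);          // nearest neighbours under the configured metric
store->remove(0);
```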

View File

@ -76,7 +76,6 @@ void HNSWLibVectorStore::remove(size_t vector_id) {
}
void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vector, const MemoryItem& metadata) {
// Check whether the vector needs to be updated
if (!vector.empty()) {
hnsw->markDelete(vector_id);
hnsw->addPoint(vector.data(), vector_id);
@ -134,7 +133,6 @@ std::vector<MemoryItem> HNSWLibVectorStore::list(size_t limit, const FilterFunc&
for (size_t i = 0; i < count; i++) {
if (!hnsw->isMarkedDeleted(i)) {
// If a filter is provided, check whether the metadata matches
auto memory_item = get(i);
if (filter && !filter(memory_item)) {
continue;

View File

@ -10,9 +10,9 @@ namespace humanus {
class HNSWLibVectorStore : public VectorStore {
private:
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
std::shared_ptr<hnswlib::SpaceInterface<float>> space; // Keep a reference to the space object to ensure its lifetime
std::shared_ptr<hnswlib::SpaceInterface<float>> space;
std::unordered_map<size_t, std::list<MemoryItem>::iterator> cache_map; // LRU cache
std::list<MemoryItem> metadata_list; // Metadata of stored vectors
std::list<MemoryItem> metadata_list; // Metadata for stored vectors
public:
HNSWLibVectorStore(const std::shared_ptr<VectorStoreConfig>& config) : VectorStore(config) {

View File

@ -2,27 +2,97 @@
namespace humanus {
MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
LLMConfig LLMConfig::load_from_toml(const toml::table& config_table) {
LLMConfig config;
try {
if (config_table.contains("model")) {
config.model = config_table["model"].as_string()->get();
}
if (config_table.contains("api_key")) {
config.api_key = config_table["api_key"].as_string()->get();
}
if (config_table.contains("base_url")) {
config.base_url = config_table["base_url"].as_string()->get();
}
if (config_table.contains("endpoint")) {
config.endpoint = config_table["endpoint"].as_string()->get();
}
if (config_table.contains("vision_details")) {
config.vision_details = config_table["vision_details"].as_string()->get();
}
if (config_table.contains("max_tokens")) {
config.max_tokens = config_table["max_tokens"].as_integer()->get();
}
if (config_table.contains("timeout")) {
config.timeout = config_table["timeout"].as_integer()->get();
}
if (config_table.contains("temperature")) {
config.temperature = config_table["temperature"].as_floating_point()->get();
}
if (config_table.contains("enable_vision")) {
config.enable_vision = config_table["enable_vision"].as_boolean()->get();
}
if (config_table.contains("enable_tool")) {
config.enable_tool = config_table["enable_tool"].as_boolean()->get();
}
if (!config.enable_tool) {
// Load tool parser configuration
ToolParser tool_parser;
if (config_table.contains("tool_start")) {
tool_parser.tool_start = config_table["tool_start"].as_string()->get();
}
if (config_table.contains("tool_end")) {
tool_parser.tool_end = config_table["tool_end"].as_string()->get();
}
if (config_table.contains("tool_hint_template")) {
tool_parser.tool_hint_template = config_table["tool_hint_template"].as_string()->get();
}
config.tool_parser = tool_parser;
}
} catch (const std::exception& e) {
logger->error("Failed to load LLM configuration: " + std::string(e.what()));
throw;
}
return config;
}
MCPServerConfig MCPServerConfig::load_from_toml(const toml::table& config_table) {
MCPServerConfig config;
try {
// Read type
if (!tool_table.contains("type") || !tool_table["type"].is_string()) {
if (!config_table.contains("type") || !config_table["type"].is_string()) {
throw std::runtime_error("Tool configuration missing type field, expected sse or stdio.");
}
config.type = tool_table["type"].as_string()->get();
config.type = config_table["type"].as_string()->get();
if (config.type == "stdio") {
// Read command
if (!tool_table.contains("command") || !tool_table["command"].is_string()) {
if (!config_table.contains("command") || !config_table["command"].is_string()) {
throw std::runtime_error("stdio type tool configuration missing command field.");
}
config.command = tool_table["command"].as_string()->get();
config.command = config_table["command"].as_string()->get();
// Read arguments (if any)
if (tool_table.contains("args") && tool_table["args"].is_array()) {
const auto& args_array = *tool_table["args"].as_array();
if (config_table.contains("args")) {
const auto& args_array = *config_table["args"].as_array();
for (const auto& arg : args_array) {
if (arg.is_string()) {
config.args.push_back(arg.as_string()->get());
@ -31,8 +101,8 @@ MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
}
// Read environment variables
if (tool_table.contains("env") && tool_table["env"].is_table()) {
const auto& env_table = *tool_table["env"].as_table();
if (config_table.contains("env")) {
const auto& env_table = *config_table["env"].as_table();
for (const auto& [key, value] : env_table) {
if (value.is_string()) {
config.env_vars[key] = value.as_string()->get();
@ -47,18 +117,18 @@ MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
}
} else if (config.type == "sse") {
// Read host and port or url
if (tool_table.contains("url") && tool_table["url"].is_string()) {
config.url = tool_table["url"].as_string()->get();
if (config_table.contains("url")) {
config.url = config_table["url"].as_string()->get();
} else {
if (!tool_table.contains("host") || !tool_table["host"].is_string()) {
if (!config_table.contains("host")) {
throw std::runtime_error("sse type tool configuration missing host field");
}
config.host = tool_table["host"].as_string()->get();
config.host = config_table["host"].as_string()->get();
if (!tool_table.contains("port") || !tool_table["port"].is_integer()) {
if (!config_table.contains("port")) {
throw std::runtime_error("sse type tool configuration missing port field");
}
config.port = tool_table["port"].as_integer()->get();
config.port = config_table["port"].as_integer()->get();
}
} else {
throw std::runtime_error("Unsupported tool type: " + config.type);
@ -71,10 +141,152 @@ MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
return config;
}
// Initialize static members
Config* Config::_instance = nullptr;
EmbeddingModelConfig EmbeddingModelConfig::load_from_toml(const toml::table& config_table) {
EmbeddingModelConfig config;
void Config::_load_initial_llm_config() {
try {
if (config_table.contains("provider")) {
config.provider = config_table["provider"].as_string()->get();
}
if (config_table.contains("base_url")) {
config.base_url = config_table["base_url"].as_string()->get();
}
if (config_table.contains("endpoint")) {
config.endpoint = config_table["endpoint"].as_string()->get();
}
if (config_table.contains("model")) {
config.model = config_table["model"].as_string()->get();
}
if (config_table.contains("api_key")) {
config.api_key = config_table["api_key"].as_string()->get();
}
if (config_table.contains("embedding_dims")) {
config.embedding_dims = config_table["embedding_dims"].as_integer()->get();
}
if (config_table.contains("max_retries")) {
config.max_retries = config_table["max_retries"].as_integer()->get();
}
} catch (const std::exception& e) {
logger->error("Failed to load embedding model configuration: " + std::string(e.what()));
throw;
}
return config;
}
VectorStoreConfig VectorStoreConfig::load_from_toml(const toml::table& config_table) {
VectorStoreConfig config;
try {
if (config_table.contains("provider")) {
config.provider = config_table["provider"].as_string()->get();
}
if (config_table.contains("dim")) {
config.dim = config_table["dim"].as_integer()->get();
}
if (config_table.contains("max_elements")) {
config.max_elements = config_table["max_elements"].as_integer()->get();
}
if (config_table.contains("M")) {
config.M = config_table["M"].as_integer()->get();
}
if (config_table.contains("ef_construction")) {
config.ef_construction = config_table["ef_construction"].as_integer()->get();
}
if (config_table.contains("metric")) {
const auto& metric_str = config_table["metric"].as_string()->get();
if (metric_str == "L2") {
config.metric = VectorStoreConfig::Metric::L2;
} else if (metric_str == "IP") {
config.metric = VectorStoreConfig::Metric::IP;
} else {
throw std::runtime_error("Invalid metric: " + metric_str);
}
}
} catch (const std::exception& e) {
logger->error("Failed to load vector store configuration: " + std::string(e.what()));
throw;
}
return config;
}
MemoryConfig MemoryConfig::load_from_toml(const toml::table& config_table) {
MemoryConfig config;
try {
// Base config
if (config_table.contains("max_messages")) {
config.max_messages = config_table["max_messages"].as_integer()->get();
}
if (config_table.contains("max_tokens_message")) {
config.max_tokens_message = config_table["max_tokens_message"].as_integer()->get();
}
if (config_table.contains("max_tokens_messages")) {
config.max_tokens_messages = config_table["max_tokens_messages"].as_integer()->get();
}
if (config_table.contains("max_tokens_context")) {
config.max_tokens_context = config_table["max_tokens_context"].as_integer()->get();
}
if (config_table.contains("retrieval_limit")) {
config.retrieval_limit = config_table["retrieval_limit"].as_integer()->get();
}
// Prompt config
if (config_table.contains("fact_extraction_prompt")) {
config.fact_extraction_prompt = config_table["fact_extraction_prompt"].as_string()->get();
}
if (config_table.contains("update_memory_prompt")) {
config.update_memory_prompt = config_table["update_memory_prompt"].as_string()->get();
}
// EmbeddingModel config
if (config_table.contains("embedding_model")) {
config.embedding_model = config_table["embedding_model"].as_string()->get();
}
// Vector store config
if (config_table.contains("vector_store")) {
config.vector_store = config_table["vector_store"].as_string()->get();
}
// LLM config
if (config_table.contains("llm")) {
config.llm = config_table["llm"].as_string()->get();
}
if (config_table.contains("llm_vision")) {
config.llm_vision = config_table["llm_vision"].as_string()->get();
}
} catch (const std::exception& e) {
logger->error("Failed to load memory configuration: " + std::string(e.what()));
throw;
}
return config;
}
// Initialize static members
std::shared_mutex Config::_config_mutex;
void Config::_load_llm_config() {
std::unique_lock<std::shared_mutex> lock(_config_mutex);
try {
auto config_path = _get_llm_config_path();
logger->info("Loading LLM config file from: " + config_path.string());
@ -83,109 +295,66 @@ void Config::_load_initial_llm_config() {
// Load LLM configuration
for (const auto& [key, value] : data) {
const auto& llm_table = *value.as_table();
LLMConfig llm_config;
if (llm_table.contains("model") && llm_table["model"].is_string()) {
llm_config.model = llm_table["model"].as_string()->get();
}
if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
llm_config.api_key = llm_table["api_key"].as_string()->get();
}
if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
llm_config.base_url = llm_table["base_url"].as_string()->get();
}
if (llm_table.contains("endpoint") && llm_table["endpoint"].is_string()) {
llm_config.endpoint = llm_table["endpoint"].as_string()->get();
}
if (llm_table.contains("vision_details") && llm_table["vision_details"].is_string()) {
llm_config.vision_details = llm_table["vision_details"].as_string()->get();
}
if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
llm_config.max_tokens = llm_table["max_tokens"].as_integer()->get();
}
if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
llm_config.timeout = llm_table["timeout"].as_integer()->get();
}
if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
}
if (llm_table.contains("enable_vision") && llm_table["enable_vision"].is_boolean()) {
llm_config.enable_vision = llm_table["enable_vision"].as_boolean()->get();
}
if (llm_table.contains("enable_tool") && llm_table["enable_tool"].is_boolean()) {
llm_config.enable_tool = llm_table["enable_tool"].as_boolean()->get();
}
_config.llm[std::string(key.str())] = llm_config;
if (!llm_config.enable_tool) {
// Load tool helper configuration
ToolParser tool_parser;
if (llm_table.contains("tool_parser") && llm_table["tool_parser"].is_table()) {
const auto& tool_parser_table = *llm_table["tool_parser"].as_table();
if (tool_parser_table.contains("tool_start")) {
tool_parser.tool_start = tool_parser_table["tool_start"].as_string()->get();
}
if (tool_parser_table.contains("tool_end")) {
tool_parser.tool_end = tool_parser_table["tool_end"].as_string()->get();
}
if (tool_parser_table.contains("tool_hint_template")) {
tool_parser.tool_hint_template = tool_parser_table["tool_hint_template"].as_string()->get();
}
}
_config.tool_parser[std::string(key.str())] = tool_parser;
const auto& config_table = *value.as_table();
logger->info("Loading LLM config: " + std::string(key.str()));
auto config = LLMConfig::load_from_toml(config_table);
llm[std::string(key.str())] = config;
if (config.enable_vision && llm.find("vision_default") == llm.end()) {
llm["vision_default"] = config;
}
}
if (_config.llm.empty()) {
if (llm.empty()) {
throw std::runtime_error("No LLM configuration found");
} else if (_config.llm.find("default") == _config.llm.end()) {
_config.llm["default"] = _config.llm.begin()->second;
}
if (_config.tool_parser.find("default") == _config.tool_parser.end()) {
_config.tool_parser["default"] = ToolParser();
} else if (llm.find("default") == llm.end()) {
llm["default"] = llm.begin()->second;
}
} catch (const std::exception& e) {
logger->warn("Failed to load LLM configuration: " + std::string(e.what()));
// Set default configuration
_config.llm["default"] = LLMConfig();
_config.tool_parser["default"] = ToolParser();
logger->error("Failed to load LLM configuration: " + std::string(e.what()));
throw;
}
}
void Config::_load_initial_mcp_server_config() {
void Config::_load_mcp_server_config() {
std::unique_lock<std::shared_mutex> lock(_config_mutex);
try {
auto config_path = _get_mcp_server_config_path();
logger->info("Loading MCP tool config file from: " + config_path.string());
logger->info("Loading MCP server config file from: " + config_path.string());
const auto& data = toml::parse_file(config_path.string());
// Load MCP tool configuration
// Load MCP server configuration
for (const auto& [key, value] : data) {
const auto& tool_table = *value.as_table();
_config.mcp_server[std::string(key.str())] = MCPServerConfig::load_from_toml(tool_table);
const auto& config_table = *value.as_table();
logger->info("Loading MCP server config: " + std::string(key.str()));
mcp_server[std::string(key.str())] = MCPServerConfig::load_from_toml(config_table);
}
} catch (const std::exception& e) {
logger->warn("Failed to load MCP tool configuration: " + std::string(e.what()));
logger->warn("Failed to load MCP server configuration: " + std::string(e.what()));
}
}
void Config::_load_initial_embedding_model_config() {
void Config::_load_memory_config() {
std::unique_lock<std::shared_mutex> lock(_config_mutex);
try {
auto config_path = _get_memory_config_path();
logger->info("Loading memory config file from: " + config_path.string());
const auto& data = toml::parse_file(config_path.string());
// Load memory configuration
for (const auto& [key, value] : data) {
const auto& config_table = *value.as_table();
logger->info("Loading memory config: " + std::string(key.str()));
memory[std::string(key.str())] = MemoryConfig::load_from_toml(config_table);
}
} catch (const std::exception& e) {
logger->warn("Failed to load memory configuration: " + std::string(e.what()));
}
}
void Config::_load_embedding_model_config() {
std::unique_lock<std::shared_mutex> lock(_config_mutex);
try {
auto config_path = _get_embedding_model_config_path();
logger->info("Loading embedding model config file from: " + config_path.string());
@ -194,54 +363,25 @@ void Config::_load_initial_embedding_model_config() {
// Load embedding model configuration
for (const auto& [key, value] : data) {
const auto& embd_table = *value.as_table();
EmbeddingModelConfig embd_config;
if (embd_table.contains("provider") && embd_table["provider"].is_string()) {
embd_config.provider = embd_table["provider"].as_string()->get();
const auto& config_table = *value.as_table();
logger->info("Loading embedding model config: " + std::string(key.str()));
embedding_model[std::string(key.str())] = EmbeddingModelConfig::load_from_toml(config_table);
}
if (embd_table.contains("base_url") && embd_table["base_url"].is_string()) {
embd_config.base_url = embd_table["base_url"].as_string()->get();
}
if (embd_table.contains("endpoint") && embd_table["endpoint"].is_string()) {
embd_config.endpoint = embd_table["endpoint"].as_string()->get();
}
if (embd_table.contains("model") && embd_table["model"].is_string()) {
embd_config.model = embd_table["model"].as_string()->get();
}
if (embd_table.contains("api_key") && embd_table["api_key"].is_string()) {
embd_config.api_key = embd_table["api_key"].as_string()->get();
}
if (embd_table.contains("embedding_dims") && embd_table["embedding_dims"].is_integer()) {
embd_config.embedding_dims = embd_table["embedding_dims"].as_integer()->get();
}
if (embd_table.contains("max_retries") && embd_table["max_retries"].is_integer()) {
embd_config.max_retries = embd_table["max_retries"].as_integer()->get();
}
_config.embedding_model[std::string(key.str())] = embd_config;
}
if (_config.embedding_model.empty()) {
if (embedding_model.empty()) {
throw std::runtime_error("No embedding model configuration found");
} else if (_config.embedding_model.find("default") == _config.embedding_model.end()) {
_config.embedding_model["default"] = _config.embedding_model.begin()->second;
} else if (embedding_model.find("default") == embedding_model.end()) {
embedding_model["default"] = embedding_model.begin()->second;
}
} catch (const std::exception& e) {
logger->warn("Failed to load embedding model configuration: " + std::string(e.what()));
// Set default configuration
_config.embedding_model["default"] = EmbeddingModelConfig();
embedding_model["default"] = EmbeddingModelConfig();
}
}
void Config::_load_initial_vector_store_config() {
void Config::_load_vector_store_config() {
std::unique_lock<std::shared_mutex> lock(_config_mutex);
try {
auto config_path = _get_vector_store_config_path();
logger->info("Loading vector store config file from: " + config_path.string());
@ -250,53 +390,90 @@ void Config::_load_initial_vector_store_config() {
// Load vector store configuration
for (const auto& [key, value] : data) {
const auto& vs_table = *value.as_table();
VectorStoreConfig vs_config;
if (vs_table.contains("provider") && vs_table["provider"].is_string()) {
vs_config.provider = vs_table["provider"].as_string()->get();
const auto& config_table = *value.as_table();
logger->info("Loading vector store config: " + std::string(key.str()));
vector_store[std::string(key.str())] = VectorStoreConfig::load_from_toml(config_table);
}
if (vs_table.contains("dim") && vs_table["dim"].is_integer()) {
vs_config.dim = vs_table["dim"].as_integer()->get();
}
if (vs_table.contains("max_elements") && vs_table["max_elements"].is_integer()) {
vs_config.max_elements = vs_table["max_elements"].as_integer()->get();
}
if (vs_table.contains("M") && vs_table["M"].is_integer()) {
vs_config.M = vs_table["M"].as_integer()->get();
}
if (vs_table.contains("ef_construction") && vs_table["ef_construction"].is_integer()) {
vs_config.ef_construction = vs_table["ef_construction"].as_integer()->get();
}
if (vs_table.contains("metric") && vs_table["metric"].is_string()) {
const auto& metric_str = vs_table["metric"].as_string()->get();
if (metric_str == "L2") {
vs_config.metric = VectorStoreConfig::Metric::L2;
} else if (metric_str == "IP") {
vs_config.metric = VectorStoreConfig::Metric::IP;
} else {
throw std::runtime_error("Invalid metric: " + metric_str);
}
}
_config.vector_store[std::string(key.str())] = vs_config;
}
if (_config.vector_store.empty()) {
if (vector_store.empty()) {
throw std::runtime_error("No vector store configuration found");
} else if (_config.vector_store.find("default") == _config.vector_store.end()) {
_config.vector_store["default"] = _config.vector_store.begin()->second;
} else if (vector_store.find("default") == vector_store.end()) {
vector_store["default"] = vector_store.begin()->second;
}
} catch (const std::exception& e) {
logger->warn("Failed to load vector store configuration: " + std::string(e.what()));
// Set default configuration
_config.vector_store["default"] = VectorStoreConfig();
vector_store["default"] = VectorStoreConfig();
}
}
LLMConfig Config::get_llm_config(const std::string& config_name) {
auto& instance = get_instance();
std::shared_lock<std::shared_mutex> lock(_config_mutex);
bool need_default = instance.llm.find(config_name) == instance.llm.end();
if (need_default) {
std::string message = "LLM config not found: " + config_name + ", falling back to default LLM config.";
lock.unlock();
logger->warn(message);
lock.lock();
return instance.llm.at("default");
} else {
return instance.llm.at(config_name);
}
}
MCPServerConfig Config::get_mcp_server_config(const std::string& config_name) {
auto& instance = get_instance();
std::shared_lock<std::shared_mutex> lock(_config_mutex);
return instance.mcp_server.at(config_name);
}
MemoryConfig Config::get_memory_config(const std::string& config_name) {
auto& instance = get_instance();
std::shared_lock<std::shared_mutex> lock(_config_mutex);
bool need_default = instance.memory.find(config_name) == instance.memory.end();
if (need_default) {
std::string message = "Memory config not found: " + config_name + ", falling back to default memory config.";
lock.unlock();
logger->warn(message);
lock.lock();
return instance.memory.at("default");
} else {
return instance.memory.at(config_name);
}
}
EmbeddingModelConfig Config::get_embedding_model_config(const std::string& config_name) {
auto& instance = get_instance();
std::shared_lock<std::shared_mutex> lock(_config_mutex);
bool need_default = instance.embedding_model.find(config_name) == instance.embedding_model.end();
if (need_default) {
std::string message = "Embedding model config not found: " + config_name + ", falling back to default embedding model config.";
lock.unlock();
logger->warn(message);
lock.lock();
return instance.embedding_model.at("default");
} else {
return instance.embedding_model.at(config_name);
}
}
VectorStoreConfig Config::get_vector_store_config(const std::string& config_name) {
auto& instance = get_instance();
std::shared_lock<std::shared_mutex> lock(_config_mutex);
bool need_default = instance.vector_store.find(config_name) == instance.vector_store.end();
if (need_default) {
std::string message = "Vector store config not found: " + config_name + ", falling back to default vector store config.";
lock.unlock();
logger->warn(message);
lock.lock();
return instance.vector_store.at("default");
} else {
return instance.vector_store.at(config_name);
}
}
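
For orientation, here is a minimal usage sketch of the accessors above. It assumes the declarations in `config.h` and is not part of this commit; member names on the returned structs are not shown in this diff, so none are accessed here.

```cpp
// Sketch only: demonstrates the fallback-to-"default" behaviour of the getters above.
#include "config.h"   // assumed header for humanus::Config
#include <iostream>

using namespace humanus;

int main() {
    // Unknown section names log a warning and fall back to the "default" entry.
    LLMConfig llm = Config::get_llm_config("no_such_section");
    MemoryConfig memory = Config::get_memory_config("default");
    VectorStoreConfig vectors = Config::get_vector_store_config("default");

    // Note: get_mcp_server_config has no fallback and throws for unknown names.
    // MCPServerConfig server = Config::get_mcp_server_config("python_execute");

    std::cout << "Configs loaded (or defaulted) successfully" << std::endl;
    return 0;
}
```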

View File

@ -51,7 +51,7 @@ json LLM::format_messages(const std::vector<Message>& messages) {
formatted_messages.back()["role"] = "user";
formatted_messages.back()["content"] = concat_content("Tool result for `" + message.name + "`:\n\n", formatted_messages.back()["content"]);
} else if (!formatted_messages.back()["tool_calls"].empty()) {
std::string tool_calls_str = tool_parser_->dump(formatted_messages.back()["tool_calls"]);
std::string tool_calls_str = llm_config_->tool_parser.dump(formatted_messages.back()["tool_calls"]);
formatted_messages.back().erase("tool_calls");
formatted_messages.back()["content"] = concat_content(formatted_messages.back()["content"], tool_calls_str);
}
@ -257,14 +257,14 @@ json LLM::ask_tool(
if (body["messages"].empty() || body["messages"].back()["role"] != "user") {
body["messages"].push_back({
{"role", "user"},
{"content", tool_parser_->hint(tools.dump(2))}
{"content", llm_config_->tool_parser.hint(tools.dump(2))}
});
} else if (body["messages"].back()["content"].is_string()) {
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + tool_parser_->hint(tools.dump(2));
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + llm_config_->tool_parser.hint(tools.dump(2));
} else if (body["messages"].back()["content"].is_array()) {
body["messages"].back()["content"].push_back({
{"type", "text"},
{"text", tool_parser_->hint(tools.dump(2))}
{"text", llm_config_->tool_parser.hint(tools.dump(2))}
});
}
}
@ -284,7 +284,7 @@ json LLM::ask_tool(
json json_data = json::parse(res->body);
json message = json_data["choices"][0]["message"];
if (!llm_config_->enable_tool && message["content"].is_string()) {
message = tool_parser_->parse(message["content"].get<std::string>());
message = llm_config_->tool_parser.parse(message["content"].get<std::string>());
}
total_prompt_tokens_ += json_data["usage"]["prompt_tokens"].get<size_t>();
total_completion_tokens_ += json_data["usage"]["completion_tokens"].get<size_t>();
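
To illustrate what the tool parser is doing in the hunks above, here is a self-contained toy, not the project's real `ToolParser`: `hint()` advertises the tool schema in plain text, and `parse()` recovers a structured tool call from the model's text reply. The class name and the parsing heuristic below are illustrative assumptions.

```cpp
// Toy stand-in for the hint/parse round trip used when enable_tool is false.
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

struct ToyToolParser {
    // Produce a prompt fragment that advertises the available tools.
    std::string hint(const std::string& tools_dump) const {
        return "Available tools (reply with a JSON object to call one):\n" + tools_dump;
    }
    // Recover an assistant message with tool_calls from a plain-text reply.
    json parse(const std::string& content) const {
        json message = {{"role", "assistant"}, {"content", content}};
        auto start = content.find('{');
        auto end = content.rfind('}');
        if (start != std::string::npos && end != std::string::npos && end > start) {
            json call = json::parse(content.substr(start, end - start + 1));
            message["tool_calls"] = json::array({call});
            message["content"] = "";
        }
        return message;
    }
};

int main() {
    ToyToolParser parser;
    json tools = json::array({{{"name", "python_execute"}}});
    std::cout << parser.hint(tools.dump(2)) << "\n";

    std::string reply = R"(Sure: {"name":"python_execute","arguments":{"code":"print(1)"}})";
    std::cout << parser.parse(reply).dump(2) << std::endl;
    return 0;
}
```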

View File

@ -14,6 +14,9 @@ const char* NEXT_STEP_PROMPT = R"(You can interact with the computer using pytho
- playwright: Interact with web pages, take screenshots, generate test code, scrape pages, and execute JavaScript in a real browser environment. Note: Most of the time you need to observe the page before executing other actions.
- image_loader: Get base64 image from file or url.
- content_provider: Save content and retrieve by chunks.
- terminate: Terminate the current task.
Besides, you may have access to other tools; refer to their descriptions and use them if necessary. Some tools may not be available in the current context; recognize this yourself and do not use them.
Remember the following:
- Today's date is {current_date}.

View File

@ -1,11 +1,8 @@
cmake_minimum_required(VERSION 3.10)
project(humanus_tests)
# Add test executable
add_executable(test_bpe test_bpe.cpp)
# Link against the humanus library
target_link_libraries(test_bpe PRIVATE humanus)
# Add include directories
target_include_directories(test_bpe PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)

View File

@ -17,15 +17,14 @@
namespace humanus {
/**
* @brief BPE (Byte Pair Encoding) Tokenizer Implementation
*
* Uses tiktoken format vocabulary and merge rules file.
* Uses a priority queue for efficient BPE merging.
*/
class BPETokenizer : public BaseTokenizer {
private:
// Helper structure for hashing pairs
struct PairHash {
template <class T1, class T2>
std::size_t operator() (const std::pair<T1, T2>& pair) const {
@ -33,30 +32,21 @@ private:
}
};
// UTF-8 bytes to token mapping
std::unordered_map<std::string, size_t> encoder;
// Token ID to UTF-8 bytes mapping
std::unordered_map<size_t, std::string> decoder;
// Merge priority mapping, lower rank with higher priority
std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash> merge_ranks;
// Decode a base64-encoded string (used by tiktoken merge ranks)
std::string base64_decode(const std::string& encoded) const {
if (encoded.empty()) return "";
return base64::decode(encoded);
}
// Comparator for merge candidates: lower rank means higher merge priority
struct MergeComparator {
const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks;
@ -65,27 +55,27 @@ private:
bool operator()(const std::pair<std::pair<std::string, std::string>, size_t>& a,
const std::pair<std::pair<std::string, std::string>, size_t>& b) const {
// First compare by merge_ranks; if a pair is not present, use the maximum value
size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits<size_t>::max();
size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits<size_t>::max();
// Max heap, so we need to reverse compare (smaller rank has higher priority)
return rank_a > rank_b;
}
};
public:
/**
* @brief Construct BPE tokenizer from tiktoken format file
* @param tokenizer_path path to tiktoken format vocabulary file
*
* File format: Each line contains a base64 encoded token and its corresponding token ID
* Example: "IQ== 0", where "IQ==" is the base64 encoded token and 0 is the corresponding ID
*/
BPETokenizer(const std::string& tokenizer_path) {
std::ifstream file(tokenizer_path);
if (!file.is_open()) {
throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path);
throw std::runtime_error("Failed to open tokenizer file: " + tokenizer_path);
}
std::string line;
@ -97,36 +87,33 @@ public:
if (iss >> token_base64 >> rank) {
std::string token = base64_decode(token_base64);
// Store the mapping between the token and its ID
encoder[token] = rank;
decoder[rank] = token;
}
}
// Build merge_ranks
build_merge_ranks();
}
/**
* @brief Build merge priority mapping
*
* Use tokens in the vocabulary to infer possible merge rules.
* For tokens longer than 1, try all possible splits; if both parts are also in the vocabulary,
* assume this is a valid merge rule.
*/
void build_merge_ranks() {
// For the tiktoken format, we can use tokens of length > 1 in the encoder to build merge rules
for (const auto& [token, id] : encoder) {
if (token.length() <= 1) continue;
// Try all possible split points
for (size_t i = 1; i < token.length(); ++i) {
std::string first = token.substr(0, i);
std::string second = token.substr(i);
// If both parts are in the vocabulary, assume this is a valid merge rule
if (encoder.count(first) && encoder.count(second)) {
// Use ID as priority - smaller ID means higher priority
merge_ranks[{first, second}] = id;
}
}
@ -134,42 +121,42 @@ public:
}
/**
* @brief Set merge priority
* @param ranks new merge priority mapping
*/
void set_merge_ranks(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks) {
merge_ranks = ranks;
}
/**
* @brief Encode text using BPE
* @param text text to encode
* @return encoded token IDs
*
* This method uses the BPE algorithm to encode the input text:
* 1. First decompose the text into single-byte tokens
* 2. Use a priority queue to merge adjacent tokens based on merge_ranks
* 3. Convert the final tokens to their corresponding IDs
*/
std::vector<size_t> encode(const std::string& text) const override {
if (text.empty()) {
return {};
}
// Decompose the text into single character tokens
std::vector<std::string> tokens;
for (unsigned char c : text) {
tokens.push_back(std::string(1, c));
}
// Use priority queue to execute BPE merging
while (tokens.size() > 1) {
// Build priority queue to select the highest priority merge pair
using MergePair = std::pair<std::pair<std::string, std::string>, size_t>;
MergeComparator comparator(merge_ranks);
std::priority_queue<MergePair, std::vector<MergePair>, MergeComparator> merge_candidates(comparator);
// Find all possible merge pairs
for (size_t i = 0; i < tokens.size() - 1; ++i) {
std::pair<std::string, std::string> pair = {tokens[i], tokens[i+1]};
if (merge_ranks.count(pair)) {
@ -177,23 +164,23 @@ public:
}
}
// If there are no merge pairs, exit the loop
if (merge_candidates.empty()) {
break;
}
// Execute the highest priority merge (the first in the priority queue)
auto top_merge = merge_candidates.top();
auto pair = top_merge.first; // The token pair to merge
size_t pos = top_merge.second; // The position to merge
// Merge tokens
std::string merged_token = pair.first + pair.second;
tokens[pos] = merged_token;
tokens.erase(tokens.begin() + pos + 1);
}
// Convert tokens to IDs
std::vector<size_t> ids;
ids.reserve(tokens.size());
for (const auto& token : tokens) {
@ -201,39 +188,39 @@ public:
if (it != encoder.end()) {
ids.push_back(it->second);
}
// Unknown tokens will be skipped
}
return ids;
}
/**
* @brief Decode BPE tokens
* @param tokens token IDs to decode
* @return decoded text
*
* This method converts encoded token IDs back to the original text.
* Simply concatenate the token strings corresponding to each ID.
*/
std::string decode(const std::vector<size_t>& tokens) const override {
std::string result;
result.reserve(tokens.size() * 2); // Reserve an estimated size to avoid frequent reallocations
for (size_t id : tokens) {
auto it = decoder.find(id);
if (it != decoder.end()) {
result += it->second;
}
// Unknown IDs will be skipped
}
return result;
}
/**
* @brief Create a BPE tokenizer from a tiktoken format file
* @param file_path path to tiktoken format file
* @return shared pointer to the created BPE tokenizer
*/
static std::shared_ptr<BPETokenizer> load_from_tiktoken(const std::string& file_path) {
return std::make_shared<BPETokenizer>(file_path);
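
For context, a minimal round-trip sketch of the tokenizer declared above. The header path and the vocabulary file name are placeholders; this assumes the class compiles as shown in this diff.

```cpp
// Sketch: encode/decode round trip with BPETokenizer (paths are placeholders).
#include "tokenizer/bpe.h"   // assumed header location
#include <iostream>
#include <string>
#include <vector>

int main() {
    auto tokenizer = humanus::BPETokenizer::load_from_tiktoken("models/cl100k_base.tiktoken");

    std::vector<size_t> ids = tokenizer->encode("Hello, world!");
    std::string text = tokenizer->decode(ids);

    // Unknown tokens/IDs are skipped, so the round trip is only lossless when
    // every byte sequence of the input is covered by the vocabulary.
    std::cout << ids.size() << " tokens -> \"" << text << "\"" << std::endl;
    return 0;
}
```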

View File

@ -130,7 +130,7 @@ struct BaseMCPTool : BaseTool {
std::shared_ptr<mcp::client> _client;
try {
// Load tool configuration from config file
auto _config = Config::get_instance().mcp_server().at(server_name);
auto _config = Config::get_mcp_server_config(server_name);
if (_config.type == "stdio") {
std::string command = _config.command;
@ -150,7 +150,7 @@ struct BaseMCPTool : BaseTool {
}
}
_client->initialize(server_name + "_client", "0.0.1");
_client->initialize(server_name + "_client", "0.1.0");
} catch (const std::exception& e) {
throw std::runtime_error("Failed to initialize MCP tool client for `" + server_name + "`: " + std::string(e.what()));
}

View File

@ -67,11 +67,9 @@ struct ContentProvider : BaseTool {
ContentProvider() : BaseTool(name_, description_, parameters_) {}
// Split text into appropriately sized chunks
std::vector<json> split_text_into_chunks(const std::string& text, int max_chunk_size) {
std::vector<json> chunks;
// If the text is empty, return an empty array
if (text.empty()) {
return chunks;
}
@ -80,23 +78,22 @@ struct ContentProvider : BaseTool {
size_t offset = 0;
while (offset < text_length) {
// First determine the maximum possible chunk size
size_t raw_chunk_size = std::min(static_cast<size_t>(max_chunk_size), text_length - offset);
// Make sure the chunk is valid UTF-8
std::string potential_chunk = text.substr(offset, raw_chunk_size);
size_t valid_utf8_length = validate_utf8(potential_chunk);
// Adjust the chunk size to a valid UTF-8 character boundary
size_t chunk_size = valid_utf8_length;
// If not at the end of the text and we didn't reduce the chunk size due to UTF-8 truncation,
// try to split at a natural break point (space, newline, punctuation)
if (offset + chunk_size < text_length && chunk_size == raw_chunk_size) {
size_t break_pos = offset + chunk_size;
// Search backward for a natural break point (space, newline, punctuation)
size_t min_pos = offset + valid_utf8_length / 2; // Don't search too far, at least keep half of the valid content
while (break_pos > min_pos &&
text[break_pos] != ' ' &&
text[break_pos] != '\n' &&
@ -109,7 +106,7 @@ struct ContentProvider : BaseTool {
break_pos--;
}
// If a suitable break point is found and it's not the original position
if (break_pos > min_pos) {
break_pos++; // Include the last character
std::string new_chunk = text.substr(offset, break_pos - offset);
@ -130,7 +127,7 @@ struct ContentProvider : BaseTool {
return chunks;
}
// Handle write operation
ToolResult handle_write(const json& args) {
int max_chunk_size = args.value("max_chunk_size", 4000);
@ -142,7 +139,7 @@ struct ContentProvider : BaseTool {
std::string text_content;
// Process content, split large text or mixture of text and image
for (const auto& item : args["content"]) {
if (!item.contains("type")) {
return ToolError("Each content item must have a `type` field");
@ -168,7 +165,7 @@ struct ContentProvider : BaseTool {
return ToolError("Image items must have an `image_url` field with a `url` property");
}
// Image remains as a whole
processed_content.push_back(item);
} else {
return ToolError("Unsupported content type: " + type);
@ -181,7 +178,7 @@ struct ContentProvider : BaseTool {
text_content.clear();
}
// Generate a unique store ID
std::string store_id = "content_" + std::to_string(current_id_);
if (content_store_.find(store_id) != content_store_.end()) {
@ -190,18 +187,18 @@ struct ContentProvider : BaseTool {
current_id_ = (current_id_ + 1) % MAX_STORE_ID;
// Store processed content
content_store_[store_id] = processed_content;
// Return store ID and number of content items
json result;
result["store_id"] = store_id;
result["total_items"] = processed_content.size();
return ToolResult(result);
return ToolResult(result.dump(2));
}
// Handle read operation
ToolResult handle_read(const json& args) {
if (!args.contains("cursor") || !args["cursor"].is_string()) {
return ToolError("`cursor` is required for read operations");
@ -210,7 +207,7 @@ struct ContentProvider : BaseTool {
std::string cursor = args["cursor"];
if (cursor == "start") {
// List all available store IDs
json available_stores = json::array();
for (const auto& [id, content] : content_store_) {
json store_info;
@ -227,12 +224,12 @@ struct ContentProvider : BaseTool {
result["available_stores"] = available_stores;
result["next_cursor"] = "select_store";
return ToolResult(result);
return ToolResult(result.dump(2));
} else if (cursor == "select_store") {
// User needs to select a store ID
return ToolError("Please provide a store_id as cursor in format `content_X:Y`");
} else if (cursor.find(":") != std::string::npos) { // content_X:Y
// User is browsing specific content in a store
size_t delimiter_pos = cursor.find(":");
std::string store_id = cursor.substr(0, delimiter_pos);
size_t index = std::stoul(cursor.substr(delimiter_pos + 1));
@ -245,10 +242,10 @@ struct ContentProvider : BaseTool {
return ToolError("Index out of range");
}
// Return the requested content item
json result = content_store_[store_id][index];
// Add navigation information
if (index + 1 < content_store_[store_id].size()) {
result["next_cursor"] = store_id + ":" + std::to_string(index + 1);
result["remaining_items"] = content_store_[store_id].size() - index - 1;
@ -257,7 +254,7 @@ struct ContentProvider : BaseTool {
result["remaining_items"] = 0;
}
return ToolResult(result);
return ToolResult(result.dump(2));
} else if (cursor == "end") {
return ToolResult("You have reached the end of the content.");
} else {
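
To make the cursor protocol above concrete, the following sketch shows the argument shapes this tool appears to accept. Field names are taken from the handlers in this diff except `operation`, which is an assumed dispatch key; values are illustrative.

```cpp
// Illustrative argument shapes for ContentProvider ("operation" is an assumed dispatch key).
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    // A write call stores mixed text/image content, splitting large text into chunks,
    // and returns a store ID such as "content_0" plus the number of stored items.
    json write_args = {
        {"operation", "write"},
        {"max_chunk_size", 4000},
        {"content", json::array({
            {{"type", "text"}, {"text", "a long document ..."}},
            {{"type", "image_url"}, {"image_url", {{"url", "file:///tmp/figure.png"}}}}
        })}
    };

    // Read calls walk the stored chunks with a cursor:
    // "start" lists the stores, then "content_X:Y" fetches item Y of store X.
    json read_stores = {{"operation", "read"}, {"cursor", "start"}};
    json read_item   = {{"operation", "read"}, {"cursor", "content_0:0"}};

    std::cout << write_args.dump(2) << "\n" << read_stores.dump(2) << "\n"
              << read_item.dump(2) << std::endl;
    return 0;
}
```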

View File

@ -13,7 +13,7 @@ struct ImageLoader : BaseTool {
"properties": {
"url": {
"type": "string",
"description": "The URL of the image to load. Supports HTTP/HTTPS URLs and local file paths. If the URL is a local file path, it must start with file://"
"description": "The URL of the image to load. Supports HTTP/HTTPS URLs and absolute local file paths. If the URL is a local file path, it must start with file://"
}
},
"required": ["url"]

View File

@ -1,5 +1,5 @@
#include "planning.h"
#include <iomanip> // for std::setprecision
namespace humanus {

View File

@ -15,11 +15,6 @@ struct PythonExecute : BaseMCPTool {
{"code", {
{"type", "string"},
{"description", "The Python code to execute. Note: Use absolute file paths if code will read/write files."}
}},
{"timeout", {
{"type", "number"},
{"description", "The timeout for the Python code execution in seconds."},
{"default", 5}
}}
}},
{"required", {"code"}}

View File

@ -2,6 +2,13 @@
#define HUMANUS_TOOL_COLLECTION_H
#include "base.h"
#include "image_loader.h"
#include "python_execute.h"
#include "filesystem.h"
#include "playwright.h"
#include "puppeteer.h"
#include "content_provider.h"
#include "terminate.h"
namespace humanus {
@ -9,6 +16,8 @@ struct ToolCollection {
std::vector<std::shared_ptr<BaseTool>> tools;
std::map<std::string, std::shared_ptr<BaseTool>> tools_map;
ToolCollection() = default;
ToolCollection(std::vector<std::shared_ptr<BaseTool>> tools) : tools(tools) {
for (auto tool : tools) {
tools_map[tool->name] = tool;
@ -67,8 +76,47 @@ struct ToolCollection {
add_tool(tool);
}
}
};
std::shared_ptr<BaseTool> get_tool(const std::string& name) const {
auto tool_iter = tools_map.find(name);
if (tool_iter == tools_map.end()) {
return nullptr;
}
return tool_iter->second;
}
std::vector<std::string> get_tool_names() const {
std::vector<std::string> names;
for (const auto& tool : tools) {
names.push_back(tool->name);
}
return names;
}
};
class ToolFactory {
public:
static std::shared_ptr<BaseTool> create(const std::string& name) {
if (name == "python_execute") {
return std::make_shared<PythonExecute>();
} else if (name == "filesystem") {
return std::make_shared<Filesystem>();
} else if (name == "playwright") {
return std::make_shared<Playwright>();
} else if (name == "puppeteer") {
return std::make_shared<Puppeteer>();
} else if (name == "image_loader") {
return std::make_shared<ImageLoader>();
} else if (name == "content_provider") {
return std::make_shared<ContentProvider>();
} else if (name == "terminate") {
return std::make_shared<Terminate>();
} else {
return nullptr;
}
}
};
} // namespace humanus
#endif // HUMANUS_TOOL_COLLECTION_H
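
Finally, a hedged sketch of how the new `ToolFactory` and the extended `ToolCollection` might be combined to build an agent's toolset from configured names. The `add_tool` signature is inferred from its use in this header, and creating MCP-backed tools may require the corresponding servers to be running.

```cpp
// Sketch: assemble a ToolCollection from a list of configured tool names.
#include "tool/tool_collection.h"   // assumed header path
#include <iostream>
#include <memory>
#include <string>
#include <vector>

using namespace humanus;

ToolCollection build_tools(const std::vector<std::string>& names) {
    ToolCollection collection;
    for (const auto& name : names) {
        if (auto tool = ToolFactory::create(name)) {
            collection.add_tool(tool);     // add_tool is used elsewhere in this header
        } else {
            std::cerr << "Unknown tool name, skipping: " << name << std::endl;
        }
    }
    return collection;
}

int main() {
    auto tools = build_tools({"python_execute", "filesystem", "terminate"});
    for (const auto& name : tools.get_tool_names()) {
        std::cout << "registered tool: " << name << std::endl;
    }
    return 0;
}
```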