cleaned a bit; convert Chinese to English
parent
6c746359aa
commit
d7f0f63149
|
@ -11,8 +11,6 @@
|
|||
*.gcda
|
||||
*.gcno
|
||||
*.gcov
|
||||
*.gguf
|
||||
*.gguf.json
|
||||
*.lastModified
|
||||
*.log
|
||||
*.metallib
|
||||
|
@ -33,7 +31,6 @@
|
|||
.vscode/
|
||||
nppBackup
|
||||
|
||||
|
||||
# Coverage
|
||||
|
||||
gcovr-report/
|
||||
|
@ -49,16 +46,12 @@ build*
|
|||
!build-info.sh
|
||||
!build.zig
|
||||
!docs/build.md
|
||||
/libllama.so
|
||||
/llama-*
|
||||
/vulkan-shaders-gen
|
||||
android-ndk-*
|
||||
arm_neon.h
|
||||
cmake-build-*
|
||||
CMakeSettings.json
|
||||
compile_commands.json
|
||||
ggml-metal-embed.metal
|
||||
llama-batched-swift
|
||||
/rpc-server
|
||||
out/
|
||||
tmp/
|
||||
|
@ -74,14 +67,9 @@ models/*
|
|||
models-mnt
|
||||
!models/.editorconfig
|
||||
|
||||
# Zig
|
||||
|
||||
# Logs
|
||||
|
||||
|
||||
|
||||
# Examples
|
||||
|
||||
*.log
|
||||
|
||||
# Server Web UI temporary files
|
||||
node_modules
|
||||
|
@ -93,14 +81,3 @@ examples/server/webui/dist
|
|||
__pycache__/
|
||||
*/poetry.lock
|
||||
poetry.toml
|
||||
|
||||
# Nix
|
||||
/result
|
||||
|
||||
# Test binaries
|
||||
|
||||
# Scripts
|
||||
|
||||
# Test models for lora adapters
|
||||
|
||||
# Local scripts
|
||||
|
|
18
README.md
18
README.md
|
@ -43,7 +43,7 @@ Start a MCP server with tool `python_execute` on port 8818:
|
|||
.\build\bin\Release\mcp_server.exe # Windows
|
||||
```
|
||||
|
||||
Run agent `Humanus` with tools `python_execute`, `filesystem` and `puppeteer` (for browser use):
|
||||
Run agent `humanus` with tools `python_execute`, `filesystem` and `playwright` (for browser use):
|
||||
|
||||
```bash
|
||||
./build/bin/humanus_cli # Unix/MacOS
|
||||
|
@ -53,7 +53,7 @@ Run agent `Humanus` with tools `python_execute`, `filesystem` and `puppeteer` (f
|
|||
.\build\bin\Release\humanus_cli.exe # Windows
|
||||
```
|
||||
|
||||
Run experimental planning flow (only agent `Humanus` as executor):
|
||||
Run experimental planning flow (only agent `humanus` as executor):
|
||||
```bash
|
||||
./build/bin/humanus_cli_plan # Unix/MacOS
|
||||
```
|
||||
|
@ -65,6 +65,16 @@ Run experimental planning flow (only agent `Humanus` as executor):
|
|||
## Acknowledgement
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/whu.png" height="150"/>
|
||||
<img src="assets/myth.png" height="150"/>
|
||||
<img src="assets/whu.png" height="180"/>
|
||||
<img src="assets/myth.png" height="180"/>
|
||||
</p>
|
||||
|
||||
## Cite
|
||||
|
||||
```
|
||||
@misc{humanuscpp,
|
||||
author = {Zihong Zhang and Zuchao Li},
|
||||
title = {humanus.cpp: A Lightweight C++ Framework for Local LLM Agents},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
|
@ -64,7 +64,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
llm = LLM::get_instance("default");
|
||||
}
|
||||
if (!memory) {
|
||||
memory = std::make_shared<Memory>(MemoryConfig());
|
||||
memory = std::make_shared<Memory>(Config::get_memory_config("default"));
|
||||
}
|
||||
reset(true);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
#include "humanus.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
std::string Humanus::run(const std::string& request) {
|
||||
memory->current_request = request;
|
||||
|
||||
auto tmp_next_step_prompt = next_step_prompt;
|
||||
|
||||
size_t pos = next_step_prompt.find("{current_date}");
|
||||
if (pos != std::string::npos) {
|
||||
// %Y-%d-%m
|
||||
auto current_date = std::chrono::system_clock::now();
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||
std::stringstream ss;
|
||||
std::tm tm_info = *std::localtime(&in_time_t);
|
||||
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||
next_step_prompt.replace(pos, 14, formatted_date);
|
||||
}
|
||||
|
||||
pos = next_step_prompt.find("{current_request}");
|
||||
if (pos != std::string::npos) {
|
||||
next_step_prompt.replace(pos, 17, request);
|
||||
}
|
||||
|
||||
auto result = BaseAgent::run(request);
|
||||
next_step_prompt = tmp_next_step_prompt; // restore the original prompt
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Humanus Humanus::load_from_toml(const toml::table& config_table) {
|
||||
try {
|
||||
// Tools
|
||||
std::vector<std::string> tools;
|
||||
std::vector<std::string> mcp_servers;
|
||||
|
||||
if (config_table.contains("tools")) {
|
||||
auto tools_table = *config_table["tools"].as_array();
|
||||
for (const auto& tool : tools_table) {
|
||||
tools.push_back(tool.as_string()->get());
|
||||
}
|
||||
} else {
|
||||
tools = {"python_execute", "filesystem", "playwright", "image_loader", "content_provider", "terminate"};
|
||||
}
|
||||
|
||||
if (config_table.contains("mcp_servers")) {
|
||||
auto mcp_servers_table = *config_table["mcp_servers"].as_array();
|
||||
for (const auto& server : mcp_servers_table) {
|
||||
mcp_servers.push_back(server.as_string()->get());
|
||||
}
|
||||
}
|
||||
|
||||
ToolCollection available_tools;
|
||||
|
||||
for (const auto& tool : tools) {
|
||||
auto tool_ptr = ToolFactory::create(tool);
|
||||
if (tool_ptr) {
|
||||
available_tools.add_tool(tool_ptr);
|
||||
} else {
|
||||
logger->warn("Tool `" + tool + "` not found in tool registry, skipping...");
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& mcp_server : mcp_servers) {
|
||||
available_tools.add_mcp_tools(mcp_server);
|
||||
}
|
||||
|
||||
// General settings
|
||||
std::string name, description, system_prompt, next_step_prompt;
|
||||
if (config_table.contains("name")) {
|
||||
name = config_table["name"].as_string()->get();
|
||||
} else {
|
||||
name = "humanus";
|
||||
}
|
||||
if (config_table.contains("description")) {
|
||||
description = config_table["description"].as_string()->get();
|
||||
} else {
|
||||
description = "A versatile agent that can solve various tasks using multiple tools";
|
||||
}
|
||||
if (config_table.contains("system_prompt")) {
|
||||
system_prompt = config_table["system_prompt"].as_string()->get();
|
||||
} else {
|
||||
system_prompt = prompt::humanus::SYSTEM_PROMPT;
|
||||
}
|
||||
if (config_table.contains("next_step_prompt")) {
|
||||
next_step_prompt = config_table["next_step_prompt"].as_string()->get();
|
||||
} else {
|
||||
next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT;
|
||||
}
|
||||
|
||||
// Workflow settings
|
||||
std::shared_ptr<LLM> llm = nullptr;
|
||||
if (config_table.contains("llm")) {
|
||||
llm = LLM::get_instance(config_table["llm"].as_string()->get());
|
||||
}
|
||||
|
||||
std::shared_ptr<Memory> memory = nullptr;
|
||||
if (config_table.contains("memory")) {
|
||||
memory = std::make_shared<Memory>(Config::get_memory_config(config_table["memory"].as_string()->get()));
|
||||
}
|
||||
|
||||
int max_steps = 30;
|
||||
if (config_table.contains("max_steps")) {
|
||||
max_steps = config_table["max_steps"].as_integer()->get();
|
||||
}
|
||||
|
||||
int duplicate_threshold = 2;
|
||||
if (config_table.contains("duplicate_threshold")) {
|
||||
duplicate_threshold = config_table["duplicate_threshold"].as_integer()->get();
|
||||
}
|
||||
|
||||
return Humanus(
|
||||
available_tools,
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
next_step_prompt,
|
||||
llm,
|
||||
memory,
|
||||
max_steps,
|
||||
duplicate_threshold
|
||||
);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error loading Humanus from TOML: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
Humanus Humanus::load_from_json(const json& config_json) {
|
||||
try {
|
||||
// Tools
|
||||
std::vector<std::string> tools;
|
||||
std::vector<std::string> mcp_servers;
|
||||
|
||||
if (config_json.contains("tools")) {
|
||||
tools = config_json["tools"].get<std::vector<std::string>>();
|
||||
}
|
||||
|
||||
if (config_json.contains("mcp_servers")) {
|
||||
mcp_servers = config_json["mcp_servers"].get<std::vector<std::string>>();
|
||||
}
|
||||
|
||||
ToolCollection available_tools;
|
||||
|
||||
for (const auto& tool : tools) {
|
||||
auto tool_ptr = ToolFactory::create(tool);
|
||||
if (tool_ptr) {
|
||||
available_tools.add_tool(tool_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& mcp_server : mcp_servers) {
|
||||
available_tools.add_mcp_tools(mcp_server);
|
||||
}
|
||||
|
||||
// General settings
|
||||
std::string name, description, system_prompt, next_step_prompt;
|
||||
if (config_json.contains("name")) {
|
||||
name = config_json["name"].get<std::string>();
|
||||
} else {
|
||||
name = "humanus";
|
||||
}
|
||||
|
||||
if (config_json.contains("description")) {
|
||||
description = config_json["description"].get<std::string>();
|
||||
} else {
|
||||
description = "A versatile agent that can solve various tasks using multiple tools";
|
||||
}
|
||||
|
||||
if (config_json.contains("system_prompt")) {
|
||||
system_prompt = config_json["system_prompt"].get<std::string>();
|
||||
} else {
|
||||
system_prompt = prompt::humanus::SYSTEM_PROMPT;
|
||||
}
|
||||
|
||||
if (config_json.contains("next_step_prompt")) {
|
||||
next_step_prompt = config_json["next_step_prompt"].get<std::string>();
|
||||
} else {
|
||||
next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT;
|
||||
}
|
||||
|
||||
// Workflow settings
|
||||
std::shared_ptr<LLM> llm = nullptr;
|
||||
if (config_json.contains("llm")) {
|
||||
llm = LLM::get_instance(config_json["llm"].get<std::string>());
|
||||
}
|
||||
|
||||
std::shared_ptr<Memory> memory = nullptr;
|
||||
if (config_json.contains("memory")) {
|
||||
memory = std::make_shared<Memory>(Config::get_memory_config(config_json["memory"].get<std::string>()));
|
||||
}
|
||||
|
||||
int max_steps = 30;
|
||||
if (config_json.contains("max_steps")) {
|
||||
max_steps = config_json["max_steps"].get<int>();
|
||||
}
|
||||
|
||||
int duplicate_threshold = 2;
|
||||
if (config_json.contains("duplicate_threshold")) {
|
||||
duplicate_threshold = config_json["duplicate_threshold"].get<int>();
|
||||
}
|
||||
|
||||
return Humanus(
|
||||
available_tools,
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
next_step_prompt,
|
||||
llm,
|
||||
memory,
|
||||
max_steps,
|
||||
duplicate_threshold
|
||||
);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error loading Humanus from JSON: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -10,6 +10,7 @@
|
|||
#include "tool/playwright.h"
|
||||
#include "tool/filesystem.h"
|
||||
#include "tool/image_loader.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
/**
|
||||
|
@ -31,8 +32,6 @@ struct Humanus : ToolCallAgent {
|
|||
std::make_shared<Terminate>()
|
||||
}
|
||||
),
|
||||
const std::string& tool_choice = "auto",
|
||||
const std::set<std::string>& special_tool_names = {"terminate"},
|
||||
const std::string& name = "humanus",
|
||||
const std::string& description = "A versatile agent that can solve various tasks using multiple tools",
|
||||
const std::string& system_prompt = prompt::humanus::SYSTEM_PROMPT,
|
||||
|
@ -43,8 +42,8 @@ struct Humanus : ToolCallAgent {
|
|||
int duplicate_threshold = 2
|
||||
) : ToolCallAgent(
|
||||
available_tools,
|
||||
tool_choice,
|
||||
special_tool_names,
|
||||
"auto", // tool_choice
|
||||
{"terminate"}, // special_tool_names
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
|
@ -55,33 +54,11 @@ struct Humanus : ToolCallAgent {
|
|||
duplicate_threshold
|
||||
) {}
|
||||
|
||||
std::string run(const std::string& request = "") override {
|
||||
memory->current_request = request;
|
||||
std::string run(const std::string& request = "") override;
|
||||
|
||||
auto tmp_next_step_prompt = next_step_prompt;
|
||||
static Humanus load_from_toml(const toml::table& config_table);
|
||||
|
||||
size_t pos = next_step_prompt.find("{current_date}");
|
||||
if (pos != std::string::npos) {
|
||||
// %Y-%d-%m
|
||||
auto current_date = std::chrono::system_clock::now();
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||
std::stringstream ss;
|
||||
std::tm tm_info = *std::localtime(&in_time_t);
|
||||
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||
next_step_prompt.replace(pos, 14, formatted_date);
|
||||
}
|
||||
|
||||
pos = next_step_prompt.find("{current_request}");
|
||||
if (pos != std::string::npos) {
|
||||
next_step_prompt.replace(pos, 17, request);
|
||||
}
|
||||
|
||||
auto result = BaseAgent::run(request);
|
||||
next_step_prompt = tmp_next_step_prompt; // restore the original prompt
|
||||
|
||||
return result;
|
||||
}
|
||||
static Humanus load_from_json(const json& config_json);
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -78,12 +78,9 @@ std::string ToolCallAgent::act() {
|
|||
std::string result_str;
|
||||
|
||||
for (const auto& tool_call : tool_calls) {
|
||||
if (state != AgentState::RUNNING) {
|
||||
result_str += "Agent is not running, so no more tool calls will be executed.\n\n";
|
||||
break;
|
||||
}
|
||||
|
||||
auto result = execute_tool(tool_call);
|
||||
auto result = state == AgentState::RUNNING ?
|
||||
execute_tool(tool_call) :
|
||||
ToolError("Agent is not running, so no more tool calls will be executed.");
|
||||
|
||||
logger->info(
|
||||
"🎯 Tool `" + tool_call.function.name + "` completed its mission! Result: " + result.to_string(500)
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
[humanus_cli]
|
||||
llm = "qwen-max-latest"
|
||||
memory = "long-context"
|
||||
tools = ["filesystem", "playwright", "image_loader"]
|
||||
mcp_servers = ["python_execute"]
|
||||
max_steps = 30
|
||||
duplicate_threshold = 2
|
|
@ -1,4 +1,4 @@
|
|||
[nomic-embed-text-v1.5]
|
||||
["nomic-embed-text-v1.5"]
|
||||
provider = "oai"
|
||||
base_url = "http://localhost:8080"
|
||||
endpoint = "/v1/embeddings"
|
||||
|
@ -7,7 +7,7 @@ api_key = ""
|
|||
embeddings_dim = 768
|
||||
max_retries = 3
|
||||
|
||||
[default]
|
||||
[qwen-text-embedding-v3]
|
||||
provider = "oai"
|
||||
base_url = "https://dashscope.aliyuncs.com"
|
||||
endpoint = "/compatible-mode/v1/embeddings"
|
||||
|
|
|
@ -1,29 +1,30 @@
|
|||
[memory]
|
||||
model = "qwen-max"
|
||||
[qwen-max-latest]
|
||||
model = "qwen-max-latest"
|
||||
base_url = "https://dashscope.aliyuncs.com"
|
||||
endpoint = "/compatible-mode/v1/chat/completions"
|
||||
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
|
||||
|
||||
[glm-4-plus]
|
||||
model = "glm-4-plus"
|
||||
base_url = "https://open.bigmodel.cn"
|
||||
endpoint = "/api/paas/v4/chat/completions"
|
||||
api_key = "7e12e1cb8fe5786d83c74d2ef48db511.xPVWzEZt8RvIciW9"
|
||||
|
||||
[default]
|
||||
[qwen-vl-max-latest]
|
||||
model = "qwen-vl-max-latest"
|
||||
base_url = "https://dashscope.aliyuncs.com"
|
||||
endpoint = "/compatible-mode/v1/chat/completions"
|
||||
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
|
||||
enable_vision = true
|
||||
|
||||
[claude-3.5-sonnet]
|
||||
["claude-3.5-sonnet"]
|
||||
model = "anthropic/claude-3.5-sonnet"
|
||||
base_url = "https://openrouter.ai"
|
||||
endpoint = "/api/v1/chat/completions"
|
||||
api_key = "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"
|
||||
enable_vision = true
|
||||
|
||||
["claude-3.7-sonnet"]
|
||||
model = "anthropic/claude-3.7-sonnet"
|
||||
base_url = "https://openrouter.ai"
|
||||
endpoint = "/api/v1/chat/completions"
|
||||
api_key = "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"
|
||||
enable_vision = true
|
||||
|
||||
[deepseek-chat]
|
||||
model = "deepseek-chat"
|
||||
base_url = "https://api.deepseek.com"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[python_execute]
|
||||
type = "sse"
|
||||
host = "localhost"
|
||||
port = 8896
|
||||
port = 8895
|
||||
sse_endpoint = "/sse"
|
||||
|
||||
[puppeteer]
|
||||
|
@ -20,8 +20,3 @@ command = "npx"
|
|||
args = ["-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/Users/hyde/Desktop"]
|
||||
|
||||
[shell]
|
||||
type = "stdio"
|
||||
command = "npx"
|
||||
args = ["-y", "@kevinwatt/shell-mcp"]
|
|
@ -1,6 +1,19 @@
|
|||
[memory]
|
||||
max_messages = 32
|
||||
[default]
|
||||
max_messages = 16
|
||||
max_tokens_message = 32768
|
||||
max_tokens_messages = 65536
|
||||
max_tokens_context = 131072
|
||||
retrieval_limit = 32
|
||||
embedding_model = "qwen-text-embedding-v3"
|
||||
vector_store = "hnswlib"
|
||||
llm = "qwen-max-latest"
|
||||
|
||||
[long-context]
|
||||
max_messages = 32
|
||||
max_tokens_message = 64000
|
||||
max_tokens_messages = 128000
|
||||
max_tokens_context = 128000
|
||||
retrieval_limit = 32
|
||||
embedding_model = "qwen-text-embedding-v3"
|
||||
vector_store = "hnswlib"
|
||||
llm = "qwen-max-latest"
|
|
@ -1,4 +1,4 @@
|
|||
[default]
|
||||
[hnswlib]
|
||||
provider = "hnswlib"
|
||||
dim = 768 # Dimension of the elements
|
||||
max_elements = 100 # Maximum number of elements, should be known beforehand
|
||||
|
|
|
@ -50,14 +50,12 @@ int main() {
|
|||
#endif
|
||||
}
|
||||
|
||||
auto memory = std::make_shared<Memory>(MemoryConfig());
|
||||
|
||||
Chatbot chatbot{
|
||||
"chatbot", // name
|
||||
"A chatbot agent that uses memory to remember conversation history", // description
|
||||
"You are a helpful assistant.", // system_prompt
|
||||
nullptr, // llm
|
||||
memory // memory
|
||||
LLM::get_instance("chatbot"), // llm
|
||||
std::make_shared<Memory>(Config::get_memory_config("chatbot")) // memory
|
||||
};
|
||||
|
||||
while (true) {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#include "agent/humanus.h"
|
||||
#include "logger.h"
|
||||
#include "prompt.h"
|
||||
#include "flow/flow_factory.h"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
|
@ -49,7 +48,15 @@ int main() {
|
|||
#endif
|
||||
}
|
||||
|
||||
Humanus agent = Humanus();
|
||||
const auto& config_data = toml::parse_file((PROJECT_ROOT / "config" / "config.toml").string());
|
||||
|
||||
if (!config_data.contains("humanus_cli")) {
|
||||
throw std::runtime_error("humanus_cli section not found in config.toml");
|
||||
}
|
||||
|
||||
const auto& config_table = *config_data["humanus_cli"].as_table();
|
||||
|
||||
Humanus agent = Humanus::load_from_toml(config_table);
|
||||
|
||||
while (true) {
|
||||
if (agent.current_step == agent.max_steps) {
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
set(target humanus_cli_mcp)
|
||||
|
||||
add_executable(${target} humanus_mcp.cpp)
|
||||
|
||||
target_link_libraries(${target} PRIVATE humanus)
|
|
@ -1,85 +0,0 @@
|
|||
#include "agent/mcp.h"
|
||||
#include "logger.h"
|
||||
#include "prompt.h"
|
||||
#include "flow/flow_factory.h"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#elif defined (_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
using namespace humanus;
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
// make sure all logs are flushed
|
||||
logger->info("Interrupted by user\n");
|
||||
logger->flush();
|
||||
_exit(130);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
// ctrl+C handling
|
||||
{
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = sigint_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
SetConsoleCP(CP_UTF8);
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
_setmode(_fileno(stdin), _O_WTEXT); // wide character input mode
|
||||
#endif
|
||||
}
|
||||
|
||||
if (argc <= 1) {
|
||||
std::cout << "Usage: " << argv[0] << " <mcp_server1> <mcp_server2>..." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<std::string> mcp_servers;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
mcp_servers.emplace_back(argv[i]);
|
||||
}
|
||||
|
||||
MCPAgent agent = MCPAgent(
|
||||
mcp_servers
|
||||
);
|
||||
|
||||
while (true) {
|
||||
if (agent.current_step == agent.max_steps) {
|
||||
std::cout << "Automatically paused after " << agent.max_steps << " steps." << std::endl;
|
||||
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
|
||||
agent.reset(false);
|
||||
} else {
|
||||
std::cout << "Enter your prompt (or 'exit' to quit): ";
|
||||
}
|
||||
|
||||
std::string prompt;
|
||||
readline_utf8(prompt, false);
|
||||
if (prompt == "exit") {
|
||||
logger->info("Goodbye!");
|
||||
break;
|
||||
}
|
||||
|
||||
logger->info("Processing your request: " + prompt);
|
||||
agent.run(prompt);
|
||||
}
|
||||
}
|
|
@ -49,9 +49,18 @@ int main() {
|
|||
#endif
|
||||
}
|
||||
|
||||
std::shared_ptr<BaseAgent> agent_ptr = std::make_shared<Humanus>();
|
||||
const auto& config_data = toml::parse_file((PROJECT_ROOT / "config" / "config.toml").string());
|
||||
|
||||
if (!config_data.contains("humanus_plan")) {
|
||||
throw std::runtime_error("humanus_plan section not found in config.toml");
|
||||
}
|
||||
|
||||
const auto& config_table = *config_data["humanus_plan"].as_table();
|
||||
|
||||
Humanus agent = Humanus::load_from_toml(config_table);
|
||||
|
||||
std::map<std::string, std::shared_ptr<BaseAgent>> agents;
|
||||
agents["default"] = agent_ptr;
|
||||
agents["default"] = std::make_shared<Humanus>(agent);
|
||||
|
||||
auto flow = FlowFactory::create_flow(
|
||||
FlowType::PLANNING,
|
||||
|
@ -65,13 +74,13 @@ int main() {
|
|||
);
|
||||
|
||||
while (true) {
|
||||
if (agent_ptr->current_step == agent_ptr->max_steps) {
|
||||
std::cout << "Automatically paused after " << agent_ptr->current_step << " steps." << std::endl;
|
||||
if (agent.current_step == agent.max_steps) {
|
||||
std::cout << "Automatically paused after " << agent.current_step << " steps." << std::endl;
|
||||
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
|
||||
agent_ptr->reset(false);
|
||||
} else if (agent_ptr->state != AgentState::IDLE) {
|
||||
agent.reset(false);
|
||||
} else if (agent.state != AgentState::IDLE) {
|
||||
std::cout << "Enter your prompt (enter an empty line to retry or 'exit' to quit): ";
|
||||
agent_ptr->reset(false);
|
||||
agent.reset(false);
|
||||
} else {
|
||||
std::cout << "Enter your prompt (or 'exit' to quit): ";
|
||||
}
|
||||
|
@ -85,6 +94,6 @@ int main() {
|
|||
|
||||
logger->info("Processing your request: " + prompt);
|
||||
auto result = flow->execute(prompt);
|
||||
logger->info("🌟 " + agent_ptr->name + "'s summary: " + result);
|
||||
logger->info("🌟 " + agent.name + "'s summary: " + result);
|
||||
}
|
||||
}
|
|
@ -35,6 +35,11 @@ public:
|
|||
return agent;
|
||||
}
|
||||
|
||||
void set_agent(const std::string& session_id, const std::shared_ptr<Humanus>& agent) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
agents_[session_id] = agent;
|
||||
}
|
||||
|
||||
static std::vector<std::string> get_logs_buffer(const std::string& session_id) {
|
||||
return session_sink->get_buffer(session_id);
|
||||
}
|
||||
|
@ -107,7 +112,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// Create and configure server
|
||||
mcp::server server("localhost", port, "HumanusServer", "0.0.1");
|
||||
mcp::server server("localhost", port, "humanus_server", "0.1.0");
|
||||
|
||||
// Set server capabilities
|
||||
mcp::json capabilities = {
|
||||
|
@ -117,6 +122,33 @@ int main(int argc, char** argv) {
|
|||
|
||||
auto session_manager = std::make_shared<SessionManager>();
|
||||
|
||||
auto initialize_tool = mcp::tool_builder("humanus_initialize")
|
||||
.with_description("Initialize the agent")
|
||||
.with_string_param("llm", "The LLM configuration to use. Default: default")
|
||||
.with_string_param("memory", "The memory configuration to use. Default: default")
|
||||
.with_array_param("tools", "The tools of the agent. Default: filesystem, playwright (for browser use), image_loader, content_provider, terminate", "string")
|
||||
.with_array_param("mcp_servers", "The MCP servers offering tools for the agent. Default: python_execute", "string")
|
||||
.with_number_param("max_steps", "The maximum steps of the agent. Default: 30")
|
||||
.with_number_param("duplicate_threshold", "The duplicate threshold of the agent. Default: 2")
|
||||
.build();
|
||||
|
||||
server.register_tool(initialize_tool, [session_manager](const json& args, const std::string& session_id) -> json {
|
||||
if (session_manager->has_session(session_id)) {
|
||||
throw mcp::mcp_exception(mcp::error_code::invalid_request, "Session already initialized");
|
||||
}
|
||||
|
||||
try {
|
||||
session_manager->set_agent(session_id, std::make_shared<Humanus>(Humanus::load_from_json(args)));
|
||||
} catch (const std::exception& e) {
|
||||
throw mcp::mcp_exception(mcp::error_code::invalid_params, "Invalid agent configuration: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
return {{
|
||||
{"type", "text"},
|
||||
{"text", "Agent initialized."}
|
||||
}};
|
||||
});
|
||||
|
||||
auto run_tool = mcp::tool_builder("humanus_run")
|
||||
.with_description("Request to start a new task. Best to give clear and concise prompts.")
|
||||
.with_string_param("prompt", "The prompt text to process", true)
|
||||
|
|
194
include/config.h
194
include/config.h
|
@ -11,47 +11,11 @@
|
|||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
|
||||
namespace humanus {
|
||||
|
||||
struct LLMConfig {
|
||||
std::string model;
|
||||
std::string api_key;
|
||||
std::string base_url;
|
||||
std::string endpoint;
|
||||
std::string vision_details;
|
||||
int max_tokens;
|
||||
int timeout;
|
||||
double temperature;
|
||||
bool enable_vision;
|
||||
bool enable_tool;
|
||||
|
||||
LLMConfig(
|
||||
const std::string& model = "deepseek-chat",
|
||||
const std::string& api_key = "sk-",
|
||||
const std::string& base_url = "https://api.deepseek.com",
|
||||
const std::string& endpoint = "/v1/chat/completions",
|
||||
const std::string& vision_details = "auto",
|
||||
int max_tokens = -1, // -1 for default
|
||||
int timeout = 120,
|
||||
double temperature = -1, // -1 for default
|
||||
bool enable_vision = false,
|
||||
bool enable_tool = true
|
||||
) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
|
||||
max_tokens(max_tokens), timeout(timeout), temperature(temperature), enable_vision(enable_vision), enable_tool(enable_tool) {}
|
||||
|
||||
json to_json() const {
|
||||
json j;
|
||||
j["model"] = model;
|
||||
j["api_key"] = api_key;
|
||||
j["base_url"] = base_url;
|
||||
j["endpoint"] = endpoint;
|
||||
j["max_tokens"] = max_tokens;
|
||||
j["temperature"] = temperature;
|
||||
return j;
|
||||
}
|
||||
};
|
||||
|
||||
struct ToolParser {
|
||||
std::string tool_start;
|
||||
std::string tool_end;
|
||||
|
@ -60,11 +24,6 @@ struct ToolParser {
|
|||
ToolParser(const std::string& tool_start = "<tool_call>", const std::string& tool_end = "</tool_call>", const std::string& tool_hint_template = prompt::toolcall::TOOL_HINT_TEMPLATE)
|
||||
: tool_start(tool_start), tool_end(tool_end), tool_hint_template(tool_hint_template) {}
|
||||
|
||||
static ToolParser get_instance() {
|
||||
static ToolParser instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
static std::string str_replace(std::string& str, const std::string& from, const std::string& to) {
|
||||
size_t start_pos = 0;
|
||||
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
|
||||
|
@ -150,6 +109,38 @@ struct ToolParser {
|
|||
}
|
||||
};
|
||||
|
||||
struct LLMConfig {
|
||||
std::string model;
|
||||
std::string api_key;
|
||||
std::string base_url;
|
||||
std::string endpoint;
|
||||
std::string vision_details;
|
||||
int max_tokens;
|
||||
int timeout;
|
||||
double temperature;
|
||||
bool enable_vision;
|
||||
bool enable_tool;
|
||||
|
||||
ToolParser tool_parser;
|
||||
|
||||
LLMConfig(
|
||||
const std::string& model = "deepseek-chat",
|
||||
const std::string& api_key = "sk-",
|
||||
const std::string& base_url = "https://api.deepseek.com",
|
||||
const std::string& endpoint = "/v1/chat/completions",
|
||||
const std::string& vision_details = "auto",
|
||||
int max_tokens = -1, // -1 for default
|
||||
int timeout = 120,
|
||||
double temperature = -1, // -1 for default
|
||||
bool enable_vision = false,
|
||||
bool enable_tool = true,
|
||||
const ToolParser& tool_parser = ToolParser()
|
||||
) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
|
||||
max_tokens(max_tokens), timeout(timeout), temperature(temperature), enable_vision(enable_vision), enable_tool(enable_tool), tool_parser(tool_parser) {}
|
||||
|
||||
static LLMConfig load_from_toml(const toml::table& config_table);
|
||||
};
|
||||
|
||||
// Read tool configuration from config_mcp.toml
|
||||
struct MCPServerConfig {
|
||||
std::string type;
|
||||
|
@ -160,7 +151,7 @@ struct MCPServerConfig {
|
|||
std::vector<std::string> args;
|
||||
json env_vars = json::object();
|
||||
|
||||
static MCPServerConfig load_from_toml(const toml::table& tool_table);
|
||||
static MCPServerConfig load_from_toml(const toml::table& config_table);
|
||||
};
|
||||
|
||||
enum class EmbeddingType {
|
||||
|
@ -177,6 +168,8 @@ struct EmbeddingModelConfig {
|
|||
std::string api_key = "";
|
||||
int embedding_dims = 768;
|
||||
int max_retries = 3;
|
||||
|
||||
static EmbeddingModelConfig load_from_toml(const toml::table& config_table);
|
||||
};
|
||||
|
||||
struct VectorStoreConfig {
|
||||
|
@ -191,50 +184,56 @@ struct VectorStoreConfig {
|
|||
IP
|
||||
};
|
||||
Metric metric = Metric::L2;
|
||||
|
||||
static VectorStoreConfig load_from_toml(const toml::table& config_table);
|
||||
};
|
||||
|
||||
struct MemoryConfig {
|
||||
// Base config
|
||||
int max_messages = 32; // Maximum number of messages in short-term memory
|
||||
int max_messages = 16; // Maximum number of messages in short-term memory
|
||||
int max_tokens_message = 1 << 15; // Maximum number of tokens in single message
|
||||
int max_tokens_messages = 1 << 19; // Maximum number of tokens in short-term memory
|
||||
int max_tokens_context = 1 << 20; // Maximum number of tokens in short-term memory
|
||||
int retrieval_limit = 32; // Number of results to retrive from long-term memory
|
||||
int max_tokens_messages = 1 << 16; // Maximum number of tokens in short-term memory
|
||||
int max_tokens_context = 1 << 17; // Maximum number of tokens in context (used by `get_messages`)
|
||||
int retrieval_limit = 32; // Maximum number of results to retrive from long-term memory
|
||||
|
||||
// Prompt config
|
||||
std::string fact_extraction_prompt = prompt::FACT_EXTRACTION_PROMPT;
|
||||
std::string update_memory_prompt = prompt::UPDATE_MEMORY_PROMPT;
|
||||
|
||||
// EmbeddingModel config
|
||||
std::string embedding_model = "default";
|
||||
std::shared_ptr<EmbeddingModelConfig> embedding_model_config = nullptr;
|
||||
|
||||
// Vector store config
|
||||
std::string vector_store = "default";
|
||||
std::shared_ptr<VectorStoreConfig> vector_store_config = nullptr;
|
||||
FilterFunc filter = nullptr; // Filter to apply to search results
|
||||
|
||||
// LLM config
|
||||
std::string llm = "default";
|
||||
std::shared_ptr<LLMConfig> llm_config = nullptr;
|
||||
};
|
||||
std::string llm_vision = "vision_default";
|
||||
std::shared_ptr<LLMConfig> llm_vision_config = nullptr;
|
||||
|
||||
struct AppConfig {
|
||||
std::unordered_map<std::string, LLMConfig> llm;
|
||||
std::unordered_map<std::string, MCPServerConfig> mcp_server;
|
||||
std::unordered_map<std::string, ToolParser> tool_parser;
|
||||
std::unordered_map<std::string, EmbeddingModelConfig> embedding_model;
|
||||
std::unordered_map<std::string, VectorStoreConfig> vector_store;
|
||||
static MemoryConfig load_from_toml(const toml::table& config_table);
|
||||
};
|
||||
|
||||
class Config {
|
||||
private:
|
||||
static Config* _instance;
|
||||
static std::shared_mutex _config_mutex;
|
||||
bool _initialized = false;
|
||||
AppConfig _config;
|
||||
std::unordered_map<std::string, LLMConfig> llm;
|
||||
std::unordered_map<std::string, MCPServerConfig> mcp_server;
|
||||
std::unordered_map<std::string, MemoryConfig> memory;
|
||||
std::unordered_map<std::string, EmbeddingModelConfig> embedding_model;
|
||||
std::unordered_map<std::string, VectorStoreConfig> vector_store;
|
||||
|
||||
Config() {
|
||||
_load_initial_llm_config();
|
||||
_load_initial_mcp_server_config();
|
||||
_load_initial_embedding_model_config();
|
||||
_load_initial_vector_store_config();
|
||||
_load_llm_config();
|
||||
_load_mcp_server_config();
|
||||
_load_memory_config();
|
||||
_load_embedding_model_config();
|
||||
_load_vector_store_config();
|
||||
_initialized = true;
|
||||
}
|
||||
|
||||
|
@ -259,6 +258,15 @@ private:
|
|||
throw std::runtime_error("MCP Tool Config file not found");
|
||||
}
|
||||
|
||||
static std::filesystem::path _get_memory_config_path() {
|
||||
auto root = PROJECT_ROOT;
|
||||
auto config_path = root / "config" / "config_mem.toml";
|
||||
if (std::filesystem::exists(config_path)) {
|
||||
return config_path;
|
||||
}
|
||||
throw std::runtime_error("Memory Config file not found");
|
||||
}
|
||||
|
||||
static std::filesystem::path _get_embedding_model_config_path() {
|
||||
auto root = PROJECT_ROOT;
|
||||
auto config_path = root / "config" / "config_embd.toml";
|
||||
|
@ -277,13 +285,15 @@ private:
|
|||
throw std::runtime_error("Vector Store Config file not found");
|
||||
}
|
||||
|
||||
void _load_initial_llm_config();
|
||||
void _load_llm_config();
|
||||
|
||||
void _load_initial_mcp_server_config();
|
||||
void _load_mcp_server_config();
|
||||
|
||||
void _load_initial_embedding_model_config();
|
||||
void _load_memory_config();
|
||||
|
||||
void _load_initial_vector_store_config();
|
||||
void _load_embedding_model_config();
|
||||
|
||||
void _load_vector_store_config();
|
||||
|
||||
public:
|
||||
/**
|
||||
|
@ -291,59 +301,19 @@ public:
|
|||
* @return The config instance
|
||||
*/
|
||||
static Config& get_instance() {
|
||||
if (_instance == nullptr) {
|
||||
_instance = new Config();
|
||||
}
|
||||
return *_instance;
|
||||
static Config instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the LLM settings
|
||||
* @return The LLM settings map
|
||||
*/
|
||||
const std::unordered_map<std::string, LLMConfig>& llm() const {
|
||||
return _config.llm;
|
||||
}
|
||||
static LLMConfig get_llm_config(const std::string& config_name);
|
||||
|
||||
/**
|
||||
* @brief Get the MCP tool settings
|
||||
* @return The MCP tool settings map
|
||||
*/
|
||||
const std::unordered_map<std::string, MCPServerConfig>& mcp_server() const {
|
||||
return _config.mcp_server;
|
||||
}
|
||||
static MCPServerConfig get_mcp_server_config(const std::string& config_name);
|
||||
|
||||
/**
|
||||
* @brief Get the tool helpers
|
||||
* @return The tool helpers map
|
||||
*/
|
||||
const std::unordered_map<std::string, ToolParser>& tool_parser() const {
|
||||
return _config.tool_parser;
|
||||
}
|
||||
static MemoryConfig get_memory_config(const std::string& config_name);
|
||||
|
||||
/**
|
||||
* @brief Get the embedding model settings
|
||||
* @return The embedding model settings map
|
||||
*/
|
||||
const std::unordered_map<std::string, EmbeddingModelConfig>& embedding_model() const {
|
||||
return _config.embedding_model;
|
||||
}
|
||||
static EmbeddingModelConfig get_embedding_model_config(const std::string& config_name);
|
||||
|
||||
/**
|
||||
* @brief Get the vector store settings
|
||||
* @return The vector store settings map
|
||||
*/
|
||||
const std::unordered_map<std::string, VectorStoreConfig>& vector_store() const {
|
||||
return _config.vector_store;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the app config
|
||||
* @return The app config
|
||||
*/
|
||||
const AppConfig& get_config() const {
|
||||
return _config;
|
||||
}
|
||||
static VectorStoreConfig get_vector_store_config(const std::string& config_name);
|
||||
};
|
||||
|
||||
} // namespace humanus
|
||||
|
|
|
@ -23,22 +23,12 @@ private:
|
|||
|
||||
std::shared_ptr<LLMConfig> llm_config_;
|
||||
|
||||
std::shared_ptr<ToolParser> tool_parser_;
|
||||
|
||||
size_t total_prompt_tokens_;
|
||||
size_t total_completion_tokens_;
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& config = nullptr, const std::shared_ptr<ToolParser>& tool_parser = nullptr) : llm_config_(config), tool_parser_(tool_parser) {
|
||||
if (!llm_config_->enable_tool && !tool_parser_) {
|
||||
if (Config::get_instance().tool_parser().find(config_name) == Config::get_instance().tool_parser().end()) {
|
||||
logger->warn("Tool helper config not found: " + config_name + ", falling back to default tool helper config.");
|
||||
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at("default"));
|
||||
} else {
|
||||
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at(config_name));
|
||||
}
|
||||
}
|
||||
LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& config = nullptr) : llm_config_(config) {
|
||||
client_ = std::make_unique<httplib::Client>(llm_config_->base_url);
|
||||
client_->set_default_headers({
|
||||
{"Authorization", "Bearer " + llm_config_->api_key}
|
||||
|
@ -53,12 +43,7 @@ public:
|
|||
if (instances_.find(config_name) == instances_.end()) {
|
||||
auto llm_config_ = llm_config;
|
||||
if (!llm_config_) {
|
||||
if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
|
||||
logger->warn("LLM config not found: " + config_name + ", falling back to default LLM config.");
|
||||
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at("default"));
|
||||
} else {
|
||||
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
|
||||
}
|
||||
llm_config_ = std::make_shared<LLMConfig>(Config::get_llm_config(config_name));
|
||||
}
|
||||
instances_[config_name] = std::make_shared<LLM>(config_name, llm_config_);
|
||||
}
|
||||
|
|
|
@ -74,6 +74,7 @@ struct Memory : BaseMemory {
|
|||
std::shared_ptr<EmbeddingModel> embedding_model;
|
||||
std::shared_ptr<VectorStore> vector_store;
|
||||
std::shared_ptr<LLM> llm;
|
||||
std::shared_ptr<LLM> llm_vision;
|
||||
|
||||
std::shared_ptr<FactExtract> fact_extract_tool;
|
||||
|
||||
|
@ -104,9 +105,10 @@ struct Memory : BaseMemory {
|
|||
}
|
||||
|
||||
try {
|
||||
embedding_model = EmbeddingModel::get_instance("default", config.embedding_model_config);
|
||||
vector_store = VectorStore::get_instance("default", config.vector_store_config);
|
||||
llm = LLM::get_instance("memory", config.llm_config);
|
||||
embedding_model = EmbeddingModel::get_instance(config.embedding_model, config.embedding_model_config);
|
||||
vector_store = VectorStore::get_instance(config.vector_store, config.vector_store_config);
|
||||
llm = LLM::get_instance(config.llm, config.llm_config);
|
||||
llm_vision = LLM::get_instance(config.llm_vision, config.llm_vision_config);
|
||||
|
||||
logger->info("🔥 Memory is warming up...");
|
||||
auto test_response = llm->ask(
|
||||
|
@ -123,9 +125,14 @@ struct Memory : BaseMemory {
|
|||
embedding_model = nullptr;
|
||||
vector_store = nullptr;
|
||||
llm = nullptr;
|
||||
llm_vision = nullptr;
|
||||
retrieval_enabled = false;
|
||||
}
|
||||
|
||||
if (llm_vision && llm_vision->enable_vision() == false) { // Make sure it can handle vision messages
|
||||
llm_vision = nullptr;
|
||||
}
|
||||
|
||||
fact_extract_tool = std::make_shared<FactExtract>();
|
||||
}
|
||||
|
||||
|
@ -153,11 +160,11 @@ struct Memory : BaseMemory {
|
|||
}
|
||||
}
|
||||
if (retrieval_enabled && !messages_to_memory.empty()) {
|
||||
if (llm->enable_vision()) {
|
||||
if (llm_vision) { // TODO: configure to use multimodal embedding model instead of converting to text
|
||||
for (auto& m : messages_to_memory) {
|
||||
m = parse_vision_message(m, llm, llm->vision_details());
|
||||
m = parse_vision_message(m, llm_vision, llm_vision->vision_details());
|
||||
}
|
||||
} else {
|
||||
} else { // Convert to a padding message indicating that the message is a vision message (but not description)
|
||||
for (auto& m : messages_to_memory) {
|
||||
m = parse_vision_message(m);
|
||||
}
|
||||
|
|
|
@ -9,12 +9,7 @@ std::shared_ptr<EmbeddingModel> EmbeddingModel::get_instance(const std::string&
|
|||
if (instances_.find(config_name) == instances_.end()) {
|
||||
auto config_ = config;
|
||||
if (!config_) {
|
||||
if (Config::get_instance().embedding_model().find(config_name) == Config::get_instance().embedding_model().end()) {
|
||||
logger->warn("Embedding model config not found: " + config_name + ", falling back to default config");
|
||||
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at("default"));
|
||||
} else {
|
||||
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at(config_name));
|
||||
}
|
||||
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_embedding_model_config(config_name));
|
||||
}
|
||||
|
||||
if (config_->provider == "oai") {
|
||||
|
|
|
@ -9,12 +9,7 @@ std::shared_ptr<VectorStore> VectorStore::get_instance(const std::string& config
|
|||
if (instances_.find(config_name) == instances_.end()) {
|
||||
auto config_ = config;
|
||||
if (!config_) {
|
||||
if (Config::get_instance().vector_store().find(config_name) == Config::get_instance().vector_store().end()) {
|
||||
logger->warn("Vector store config not found: " + config_name + ", falling back to default config");
|
||||
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at("default"));
|
||||
} else {
|
||||
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at(config_name));
|
||||
}
|
||||
config_ = std::make_shared<VectorStoreConfig>(Config::get_vector_store_config(config_name));
|
||||
}
|
||||
|
||||
if (config_->provider == "hnswlib") {
|
||||
|
|
|
@ -27,59 +27,59 @@ public:
|
|||
virtual void reset() = 0;
|
||||
|
||||
/**
|
||||
* @brief 插入向量到集合中
|
||||
* @param vector 向量数据
|
||||
* @param vector_id 向量ID
|
||||
* @param metadata 元数据
|
||||
* @brief Insert a vector with metadata
|
||||
* @param vector vector data
|
||||
* @param vector_id vector ID
|
||||
* @param metadata metadata
|
||||
*/
|
||||
virtual void insert(const std::vector<float>& vector,
|
||||
const size_t vector_id,
|
||||
const MemoryItem& metadata = MemoryItem()) = 0;
|
||||
|
||||
/**
|
||||
* @brief 搜索相似向量
|
||||
* @param query 查询向量
|
||||
* @param limit 返回结果数量限制
|
||||
* @param filter 可选的过滤条件
|
||||
* @return 相似向量的MemoryItem列表
|
||||
* @brief Search similar vectors
|
||||
* @param query query vector
|
||||
* @param limit limit of returned results
|
||||
* @param filter optional filter (returns true for allowed vectors)
|
||||
* @return list of similar vectors
|
||||
*/
|
||||
virtual std::vector<MemoryItem> search(const std::vector<float>& query,
|
||||
size_t limit = 5,
|
||||
const FilterFunc& filter = nullptr) = 0;
|
||||
|
||||
/**
|
||||
* @brief 通过ID删除向量
|
||||
* @param vector_id 向量ID
|
||||
* @brief Remove a vector by ID
|
||||
* @param vector_id vector ID
|
||||
*/
|
||||
virtual void remove(size_t vector_id) = 0;
|
||||
|
||||
/**
|
||||
* @brief 更新向量及其负载
|
||||
* @param vector_id 向量ID
|
||||
* @param vector 新向量数据
|
||||
* @param metadata 新负载数据
|
||||
* @brief Update a vector and its metadata
|
||||
* @param vector_id vector ID
|
||||
* @param vector new vector data
|
||||
* @param metadata new metadata
|
||||
*/
|
||||
virtual void update(size_t vector_id, const std::vector<float>& vector = std::vector<float>(), const MemoryItem& metadata = MemoryItem()) = 0;
|
||||
|
||||
/**
|
||||
* @brief 通过ID获取向量
|
||||
* @param vector_id 向量ID
|
||||
* @return 向量数据
|
||||
* @brief Get a vector by ID
|
||||
* @param vector_id vector ID
|
||||
* @return vector data
|
||||
*/
|
||||
virtual MemoryItem get(size_t vector_id) = 0;
|
||||
|
||||
/**
|
||||
* @brief Set metadata for a vector
|
||||
* @param vector_id Vector ID
|
||||
* @param metadata New metadata
|
||||
* @param vector_id vector ID
|
||||
* @param metadata new metadata
|
||||
*/
|
||||
virtual void set(size_t vector_id, const MemoryItem& metadata) = 0;
|
||||
|
||||
/**
|
||||
* @brief 列出所有记忆
|
||||
* @param limit 可选的结果数量限制
|
||||
* @param filter 可选的过滤条件(isIDAllowed)
|
||||
* @return 记忆ID列表
|
||||
* @brief List all memories
|
||||
* @param limit optional limit of returned results
|
||||
* @param filter optional filter (returns true for allowed memories)
|
||||
* @return list of memories
|
||||
*/
|
||||
virtual std::vector<MemoryItem> list(size_t limit = 0, const FilterFunc& filter = nullptr) = 0;
|
||||
};
|
||||
|
|
|
@ -76,7 +76,6 @@ void HNSWLibVectorStore::remove(size_t vector_id) {
|
|||
}
|
||||
|
||||
void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vector, const MemoryItem& metadata) {
|
||||
// 检查向量是否需要更新
|
||||
if (!vector.empty()) {
|
||||
hnsw->markDelete(vector_id);
|
||||
hnsw->addPoint(vector.data(), vector_id);
|
||||
|
@ -134,7 +133,6 @@ std::vector<MemoryItem> HNSWLibVectorStore::list(size_t limit, const FilterFunc&
|
|||
|
||||
for (size_t i = 0; i < count; i++) {
|
||||
if (!hnsw->isMarkedDeleted(i)) {
|
||||
// 如果有过滤条件,检查元数据是否匹配
|
||||
auto memory_item = get(i);
|
||||
if (filter && !filter(memory_item)) {
|
||||
continue;
|
||||
|
|
|
@ -10,9 +10,9 @@ namespace humanus {
|
|||
class HNSWLibVectorStore : public VectorStore {
|
||||
private:
|
||||
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
|
||||
std::shared_ptr<hnswlib::SpaceInterface<float>> space; // 保持space对象的引用以确保其生命周期
|
||||
std::shared_ptr<hnswlib::SpaceInterface<float>> space;
|
||||
std::unordered_map<size_t, std::list<MemoryItem>::iterator> cache_map; // LRU cache
|
||||
std::list<MemoryItem> metadata_list; // 存储向量的元数据
|
||||
std::list<MemoryItem> metadata_list; // Metadata for stored vectors
|
||||
|
||||
public:
|
||||
HNSWLibVectorStore(const std::shared_ptr<VectorStoreConfig>& config) : VectorStore(config) {
|
||||
|
|
535
src/config.cpp
535
src/config.cpp
|
@ -2,27 +2,97 @@
|
|||
|
||||
namespace humanus {
|
||||
|
||||
MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
|
||||
LLMConfig LLMConfig::load_from_toml(const toml::table& config_table) {
|
||||
LLMConfig config;
|
||||
|
||||
try {
|
||||
if (config_table.contains("model")) {
|
||||
config.model = config_table["model"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("api_key")) {
|
||||
config.api_key = config_table["api_key"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("base_url")) {
|
||||
config.base_url = config_table["base_url"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("endpoint")) {
|
||||
config.endpoint = config_table["endpoint"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("vision_details")) {
|
||||
config.vision_details = config_table["vision_details"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("max_tokens")) {
|
||||
config.max_tokens = config_table["max_tokens"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("timeout")) {
|
||||
config.timeout = config_table["timeout"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("temperature")) {
|
||||
config.temperature = config_table["temperature"].as_floating_point()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("enable_vision")) {
|
||||
config.enable_vision = config_table["enable_vision"].as_boolean()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("enable_tool")) {
|
||||
config.enable_tool = config_table["enable_tool"].as_boolean()->get();
|
||||
}
|
||||
|
||||
if (!config.enable_tool) {
|
||||
// Load tool parser configuration
|
||||
ToolParser tool_parser;
|
||||
if (config_table.contains("tool_start")) {
|
||||
tool_parser.tool_start = config_table["tool_start"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("tool_end")) {
|
||||
tool_parser.tool_end = config_table["tool_end"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("tool_hint_template")) {
|
||||
tool_parser.tool_hint_template = config_table["tool_hint_template"].as_string()->get();
|
||||
}
|
||||
config.tool_parser = tool_parser;
|
||||
}
|
||||
|
||||
return config;
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to load LLM configuration: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
MCPServerConfig MCPServerConfig::load_from_toml(const toml::table& config_table) {
|
||||
MCPServerConfig config;
|
||||
|
||||
try {
|
||||
// Read type
|
||||
if (!tool_table.contains("type") || !tool_table["type"].is_string()) {
|
||||
if (!config_table.contains("type") || !config_table["type"].is_string()) {
|
||||
throw std::runtime_error("Tool configuration missing type field, expected sse or stdio.");
|
||||
}
|
||||
|
||||
config.type = tool_table["type"].as_string()->get();
|
||||
config.type = config_table["type"].as_string()->get();
|
||||
|
||||
if (config.type == "stdio") {
|
||||
// Read command
|
||||
if (!tool_table.contains("command") || !tool_table["command"].is_string()) {
|
||||
if (!config_table.contains("command") || !config_table["command"].is_string()) {
|
||||
throw std::runtime_error("stdio type tool configuration missing command field.");
|
||||
}
|
||||
config.command = tool_table["command"].as_string()->get();
|
||||
config.command = config_table["command"].as_string()->get();
|
||||
|
||||
// Read arguments (if any)
|
||||
if (tool_table.contains("args") && tool_table["args"].is_array()) {
|
||||
const auto& args_array = *tool_table["args"].as_array();
|
||||
if (config_table.contains("args")) {
|
||||
const auto& args_array = *config_table["args"].as_array();
|
||||
for (const auto& arg : args_array) {
|
||||
if (arg.is_string()) {
|
||||
config.args.push_back(arg.as_string()->get());
|
||||
|
@ -31,8 +101,8 @@ MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
|
|||
}
|
||||
|
||||
// Read environment variables
|
||||
if (tool_table.contains("env") && tool_table["env"].is_table()) {
|
||||
const auto& env_table = *tool_table["env"].as_table();
|
||||
if (config_table.contains("env")) {
|
||||
const auto& env_table = *config_table["env"].as_table();
|
||||
for (const auto& [key, value] : env_table) {
|
||||
if (value.is_string()) {
|
||||
config.env_vars[key] = value.as_string()->get();
|
||||
|
@ -47,18 +117,18 @@ MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
|
|||
}
|
||||
} else if (config.type == "sse") {
|
||||
// Read host and port or url
|
||||
if (tool_table.contains("url") && tool_table["url"].is_string()) {
|
||||
config.url = tool_table["url"].as_string()->get();
|
||||
if (config_table.contains("url")) {
|
||||
config.url = config_table["url"].as_string()->get();
|
||||
} else {
|
||||
if (!tool_table.contains("host") || !tool_table["host"].is_string()) {
|
||||
if (!config_table.contains("host")) {
|
||||
throw std::runtime_error("sse type tool configuration missing host field");
|
||||
}
|
||||
config.host = tool_table["host"].as_string()->get();
|
||||
config.host = config_table["host"].as_string()->get();
|
||||
|
||||
if (!tool_table.contains("port") || !tool_table["port"].is_integer()) {
|
||||
if (!config_table.contains("port")) {
|
||||
throw std::runtime_error("sse type tool configuration missing port field");
|
||||
}
|
||||
config.port = tool_table["port"].as_integer()->get();
|
||||
config.port = config_table["port"].as_integer()->get();
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported tool type: " + config.type);
|
||||
|
@ -71,10 +141,152 @@ MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
|
|||
return config;
|
||||
}
|
||||
|
||||
// Initialize static members
|
||||
Config* Config::_instance = nullptr;
|
||||
EmbeddingModelConfig EmbeddingModelConfig::load_from_toml(const toml::table& config_table) {
|
||||
EmbeddingModelConfig config;
|
||||
|
||||
void Config::_load_initial_llm_config() {
|
||||
try {
|
||||
if (config_table.contains("provider")) {
|
||||
config.provider = config_table["provider"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("base_url")) {
|
||||
config.base_url = config_table["base_url"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("endpoint")) {
|
||||
config.endpoint = config_table["endpoint"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("model")) {
|
||||
config.model = config_table["model"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("api_key")) {
|
||||
config.api_key = config_table["api_key"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("embedding_dims")) {
|
||||
config.embedding_dims = config_table["embedding_dims"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("max_retries")) {
|
||||
config.max_retries = config_table["max_retries"].as_integer()->get();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to load embedding model configuration: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
VectorStoreConfig VectorStoreConfig::load_from_toml(const toml::table& config_table) {
|
||||
VectorStoreConfig config;
|
||||
|
||||
try {
|
||||
if (config_table.contains("provider")) {
|
||||
config.provider = config_table["provider"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("dim")) {
|
||||
config.dim = config_table["dim"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("max_elements")) {
|
||||
config.max_elements = config_table["max_elements"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("M")) {
|
||||
config.M = config_table["M"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("ef_construction")) {
|
||||
config.ef_construction = config_table["ef_construction"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("metric")) {
|
||||
const auto& metric_str = config_table["metric"].as_string()->get();
|
||||
if (metric_str == "L2") {
|
||||
config.metric = VectorStoreConfig::Metric::L2;
|
||||
} else if (metric_str == "IP") {
|
||||
config.metric = VectorStoreConfig::Metric::IP;
|
||||
} else {
|
||||
throw std::runtime_error("Invalid metric: " + metric_str);
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to load vector store configuration: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
MemoryConfig MemoryConfig::load_from_toml(const toml::table& config_table) {
|
||||
MemoryConfig config;
|
||||
|
||||
try {
|
||||
// Base config
|
||||
if (config_table.contains("max_messages")) {
|
||||
config.max_messages = config_table["max_messages"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("max_tokens_message")) {
|
||||
config.max_tokens_message = config_table["max_tokens_message"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("max_tokens_messages")) {
|
||||
config.max_tokens_messages = config_table["max_tokens_messages"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("max_tokens_context")) {
|
||||
config.max_tokens_context = config_table["max_tokens_context"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("retrieval_limit")) {
|
||||
config.retrieval_limit = config_table["retrieval_limit"].as_integer()->get();
|
||||
}
|
||||
|
||||
// Prompt config
|
||||
if (config_table.contains("fact_extraction_prompt")) {
|
||||
config.fact_extraction_prompt = config_table["fact_extraction_prompt"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("update_memory_prompt")) {
|
||||
config.update_memory_prompt = config_table["update_memory_prompt"].as_string()->get();
|
||||
}
|
||||
|
||||
// EmbeddingModel config
|
||||
if (config_table.contains("embedding_model")) {
|
||||
config.embedding_model = config_table["embedding_model"].as_string()->get();
|
||||
}
|
||||
|
||||
// Vector store config
|
||||
if (config_table.contains("vector_store")) {
|
||||
config.vector_store = config_table["vector_store"].as_string()->get();
|
||||
}
|
||||
|
||||
// LLM config
|
||||
if (config_table.contains("llm")) {
|
||||
config.llm = config_table["llm"].as_string()->get();
|
||||
}
|
||||
|
||||
if (config_table.contains("llm_vision")) {
|
||||
config.llm_vision = config_table["llm_vision"].as_string()->get();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to load memory configuration: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
// Initialize static members
|
||||
std::shared_mutex Config::_config_mutex;
|
||||
|
||||
void Config::_load_llm_config() {
|
||||
std::unique_lock<std::shared_mutex> lock(_config_mutex);
|
||||
try {
|
||||
auto config_path = _get_llm_config_path();
|
||||
logger->info("Loading LLM config file from: " + config_path.string());
|
||||
|
@ -83,109 +295,66 @@ void Config::_load_initial_llm_config() {
|
|||
|
||||
// Load LLM configuration
|
||||
for (const auto& [key, value] : data) {
|
||||
const auto& llm_table = *value.as_table();
|
||||
|
||||
LLMConfig llm_config;
|
||||
|
||||
if (llm_table.contains("model") && llm_table["model"].is_string()) {
|
||||
llm_config.model = llm_table["model"].as_string()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
|
||||
llm_config.api_key = llm_table["api_key"].as_string()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
|
||||
llm_config.base_url = llm_table["base_url"].as_string()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("endpoint") && llm_table["endpoint"].is_string()) {
|
||||
llm_config.endpoint = llm_table["endpoint"].as_string()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("vision_details") && llm_table["vision_details"].is_string()) {
|
||||
llm_config.vision_details = llm_table["vision_details"].as_string()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
|
||||
llm_config.max_tokens = llm_table["max_tokens"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
|
||||
llm_config.timeout = llm_table["timeout"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
|
||||
llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("enable_vision") && llm_table["enable_vision"].is_boolean()) {
|
||||
llm_config.enable_vision = llm_table["enable_vision"].as_boolean()->get();
|
||||
}
|
||||
|
||||
if (llm_table.contains("enable_tool") && llm_table["enable_tool"].is_boolean()) {
|
||||
llm_config.enable_tool = llm_table["enable_tool"].as_boolean()->get();
|
||||
}
|
||||
|
||||
_config.llm[std::string(key.str())] = llm_config;
|
||||
|
||||
if (!llm_config.enable_tool) {
|
||||
// Load tool helper configuration
|
||||
ToolParser tool_parser;
|
||||
if (llm_table.contains("tool_parser") && llm_table["tool_parser"].is_table()) {
|
||||
const auto& tool_parser_table = *llm_table["tool_parser"].as_table();
|
||||
if (tool_parser_table.contains("tool_start")) {
|
||||
tool_parser.tool_start = tool_parser_table["tool_start"].as_string()->get();
|
||||
}
|
||||
|
||||
if (tool_parser_table.contains("tool_end")) {
|
||||
tool_parser.tool_end = tool_parser_table["tool_end"].as_string()->get();
|
||||
}
|
||||
|
||||
if (tool_parser_table.contains("tool_hint_template")) {
|
||||
tool_parser.tool_hint_template = tool_parser_table["tool_hint_template"].as_string()->get();
|
||||
}
|
||||
}
|
||||
_config.tool_parser[std::string(key.str())] = tool_parser;
|
||||
const auto& config_table = *value.as_table();
|
||||
logger->info("Loading LLM config: " + std::string(key.str()));
|
||||
auto config = LLMConfig::load_from_toml(config_table);
|
||||
llm[std::string(key.str())] = config;
|
||||
if (config.enable_vision && llm.find("vision_default") == llm.end()) {
|
||||
llm["vision_default"] = config;
|
||||
}
|
||||
}
|
||||
|
||||
if (_config.llm.empty()) {
|
||||
if (llm.empty()) {
|
||||
throw std::runtime_error("No LLM configuration found");
|
||||
} else if (_config.llm.find("default") == _config.llm.end()) {
|
||||
_config.llm["default"] = _config.llm.begin()->second;
|
||||
}
|
||||
|
||||
if (_config.tool_parser.find("default") == _config.tool_parser.end()) {
|
||||
_config.tool_parser["default"] = ToolParser();
|
||||
} else if (llm.find("default") == llm.end()) {
|
||||
llm["default"] = llm.begin()->second;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Failed to load LLM configuration: " + std::string(e.what()));
|
||||
// Set default configuration
|
||||
_config.llm["default"] = LLMConfig();
|
||||
_config.tool_parser["default"] = ToolParser();
|
||||
logger->error("Failed to load LLM configuration: " + std::string(e.what()));
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void Config::_load_initial_mcp_server_config() {
|
||||
void Config::_load_mcp_server_config() {
|
||||
std::unique_lock<std::shared_mutex> lock(_config_mutex);
|
||||
try {
|
||||
auto config_path = _get_mcp_server_config_path();
|
||||
logger->info("Loading MCP tool config file from: " + config_path.string());
|
||||
logger->info("Loading MCP server config file from: " + config_path.string());
|
||||
|
||||
const auto& data = toml::parse_file(config_path.string());
|
||||
|
||||
// Load MCP tool configuration
|
||||
// Load MCP server configuration
|
||||
for (const auto& [key, value] : data) {
|
||||
const auto& tool_table = *value.as_table();
|
||||
|
||||
_config.mcp_server[std::string(key.str())] = MCPServerConfig::load_from_toml(tool_table);
|
||||
const auto& config_table = *value.as_table();
|
||||
logger->info("Loading MCP server config: " + std::string(key.str()));
|
||||
mcp_server[std::string(key.str())] = MCPServerConfig::load_from_toml(config_table);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Failed to load MCP tool configuration: " + std::string(e.what()));
|
||||
logger->warn("Failed to load MCP server configuration: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
void Config::_load_initial_embedding_model_config() {
|
||||
void Config::_load_memory_config() {
|
||||
std::unique_lock<std::shared_mutex> lock(_config_mutex);
|
||||
try {
|
||||
auto config_path = _get_memory_config_path();
|
||||
logger->info("Loading memory config file from: " + config_path.string());
|
||||
|
||||
const auto& data = toml::parse_file(config_path.string());
|
||||
|
||||
// Load memory configuration
|
||||
for (const auto& [key, value] : data) {
|
||||
const auto& config_table = *value.as_table();
|
||||
logger->info("Loading memory config: " + std::string(key.str()));
|
||||
memory[std::string(key.str())] = MemoryConfig::load_from_toml(config_table);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Failed to load memory configuration: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
void Config::_load_embedding_model_config() {
|
||||
std::unique_lock<std::shared_mutex> lock(_config_mutex);
|
||||
try {
|
||||
auto config_path = _get_embedding_model_config_path();
|
||||
logger->info("Loading embedding model config file from: " + config_path.string());
|
||||
|
@ -194,54 +363,25 @@ void Config::_load_initial_embedding_model_config() {
|
|||
|
||||
// Load embedding model configuration
|
||||
for (const auto& [key, value] : data) {
|
||||
const auto& embd_table = *value.as_table();
|
||||
|
||||
EmbeddingModelConfig embd_config;
|
||||
|
||||
if (embd_table.contains("provider") && embd_table["provider"].is_string()) {
|
||||
embd_config.provider = embd_table["provider"].as_string()->get();
|
||||
}
|
||||
|
||||
if (embd_table.contains("base_url") && embd_table["base_url"].is_string()) {
|
||||
embd_config.base_url = embd_table["base_url"].as_string()->get();
|
||||
}
|
||||
|
||||
if (embd_table.contains("endpoint") && embd_table["endpoint"].is_string()) {
|
||||
embd_config.endpoint = embd_table["endpoint"].as_string()->get();
|
||||
}
|
||||
|
||||
if (embd_table.contains("model") && embd_table["model"].is_string()) {
|
||||
embd_config.model = embd_table["model"].as_string()->get();
|
||||
}
|
||||
|
||||
if (embd_table.contains("api_key") && embd_table["api_key"].is_string()) {
|
||||
embd_config.api_key = embd_table["api_key"].as_string()->get();
|
||||
}
|
||||
|
||||
if (embd_table.contains("embedding_dims") && embd_table["embedding_dims"].is_integer()) {
|
||||
embd_config.embedding_dims = embd_table["embedding_dims"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (embd_table.contains("max_retries") && embd_table["max_retries"].is_integer()) {
|
||||
embd_config.max_retries = embd_table["max_retries"].as_integer()->get();
|
||||
}
|
||||
|
||||
_config.embedding_model[std::string(key.str())] = embd_config;
|
||||
const auto& config_table = *value.as_table();
|
||||
logger->info("Loading embedding model config: " + std::string(key.str()));
|
||||
embedding_model[std::string(key.str())] = EmbeddingModelConfig::load_from_toml(config_table);
|
||||
}
|
||||
|
||||
if (_config.embedding_model.empty()) {
|
||||
if (embedding_model.empty()) {
|
||||
throw std::runtime_error("No embedding model configuration found");
|
||||
} else if (_config.embedding_model.find("default") == _config.embedding_model.end()) {
|
||||
_config.embedding_model["default"] = _config.embedding_model.begin()->second;
|
||||
} else if (embedding_model.find("default") == embedding_model.end()) {
|
||||
embedding_model["default"] = embedding_model.begin()->second;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Failed to load embedding model configuration: " + std::string(e.what()));
|
||||
// Set default configuration
|
||||
_config.embedding_model["default"] = EmbeddingModelConfig();
|
||||
embedding_model["default"] = EmbeddingModelConfig();
|
||||
}
|
||||
}
|
||||
|
||||
void Config::_load_initial_vector_store_config() {
|
||||
void Config::_load_vector_store_config() {
|
||||
std::unique_lock<std::shared_mutex> lock(_config_mutex);
|
||||
try {
|
||||
auto config_path = _get_vector_store_config_path();
|
||||
logger->info("Loading vector store config file from: " + config_path.string());
|
||||
|
@ -250,53 +390,90 @@ void Config::_load_initial_vector_store_config() {
|
|||
|
||||
// Load vector store configuration
|
||||
for (const auto& [key, value] : data) {
|
||||
const auto& vs_table = *value.as_table();
|
||||
|
||||
VectorStoreConfig vs_config;
|
||||
|
||||
if (vs_table.contains("provider") && vs_table["provider"].is_string()) {
|
||||
vs_config.provider = vs_table["provider"].as_string()->get();
|
||||
}
|
||||
|
||||
if (vs_table.contains("dim") && vs_table["dim"].is_integer()) {
|
||||
vs_config.dim = vs_table["dim"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (vs_table.contains("max_elements") && vs_table["max_elements"].is_integer()) {
|
||||
vs_config.max_elements = vs_table["max_elements"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (vs_table.contains("M") && vs_table["M"].is_integer()) {
|
||||
vs_config.M = vs_table["M"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (vs_table.contains("ef_construction") && vs_table["ef_construction"].is_integer()) {
|
||||
vs_config.ef_construction = vs_table["ef_construction"].as_integer()->get();
|
||||
}
|
||||
|
||||
if (vs_table.contains("metric") && vs_table["metric"].is_string()) {
|
||||
const auto& metric_str = vs_table["metric"].as_string()->get();
|
||||
if (metric_str == "L2") {
|
||||
vs_config.metric = VectorStoreConfig::Metric::L2;
|
||||
} else if (metric_str == "IP") {
|
||||
vs_config.metric = VectorStoreConfig::Metric::IP;
|
||||
} else {
|
||||
throw std::runtime_error("Invalid metric: " + metric_str);
|
||||
}
|
||||
}
|
||||
|
||||
_config.vector_store[std::string(key.str())] = vs_config;
|
||||
const auto& config_table = *value.as_table();
|
||||
logger->info("Loading vector store config: " + std::string(key.str()));
|
||||
vector_store[std::string(key.str())] = VectorStoreConfig::load_from_toml(config_table);
|
||||
}
|
||||
|
||||
if (_config.vector_store.empty()) {
|
||||
if (vector_store.empty()) {
|
||||
throw std::runtime_error("No vector store configuration found");
|
||||
} else if (_config.vector_store.find("default") == _config.vector_store.end()) {
|
||||
_config.vector_store["default"] = _config.vector_store.begin()->second;
|
||||
} else if (vector_store.find("default") == vector_store.end()) {
|
||||
vector_store["default"] = vector_store.begin()->second;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Failed to load vector store configuration: " + std::string(e.what()));
|
||||
// Set default configuration
|
||||
_config.vector_store["default"] = VectorStoreConfig();
|
||||
vector_store["default"] = VectorStoreConfig();
|
||||
}
|
||||
}
|
||||
|
||||
LLMConfig Config::get_llm_config(const std::string& config_name) {
|
||||
auto& instance = get_instance();
|
||||
std::shared_lock<std::shared_mutex> lock(_config_mutex);
|
||||
|
||||
bool need_default = instance.llm.find(config_name) == instance.llm.end();
|
||||
if (need_default) {
|
||||
std::string message = "LLM config not found: " + config_name + ", falling back to default LLM config.";
|
||||
lock.unlock();
|
||||
logger->warn(message);
|
||||
lock.lock();
|
||||
return instance.llm.at("default");
|
||||
} else {
|
||||
return instance.llm.at(config_name);
|
||||
}
|
||||
}
|
||||
|
||||
MCPServerConfig Config::get_mcp_server_config(const std::string& config_name) {
|
||||
auto& instance = get_instance();
|
||||
std::shared_lock<std::shared_mutex> lock(_config_mutex);
|
||||
return instance.mcp_server.at(config_name);
|
||||
}
|
||||
|
||||
MemoryConfig Config::get_memory_config(const std::string& config_name) {
|
||||
auto& instance = get_instance();
|
||||
std::shared_lock<std::shared_mutex> lock(_config_mutex);
|
||||
|
||||
bool need_default = instance.memory.find(config_name) == instance.memory.end();
|
||||
if (need_default) {
|
||||
std::string message = "Memory config not found: " + config_name + ", falling back to default memory config.";
|
||||
lock.unlock();
|
||||
logger->warn(message);
|
||||
lock.lock();
|
||||
return instance.memory.at("default");
|
||||
} else {
|
||||
return instance.memory.at(config_name);
|
||||
}
|
||||
}
|
||||
|
||||
EmbeddingModelConfig Config::get_embedding_model_config(const std::string& config_name) {
|
||||
auto& instance = get_instance();
|
||||
std::shared_lock<std::shared_mutex> lock(_config_mutex);
|
||||
|
||||
bool need_default = instance.embedding_model.find(config_name) == instance.embedding_model.end();
|
||||
if (need_default) {
|
||||
std::string message = "Embedding model config not found: " + config_name + ", falling back to default embedding model config.";
|
||||
lock.unlock();
|
||||
logger->warn(message);
|
||||
lock.lock();
|
||||
return instance.embedding_model.at("default");
|
||||
} else {
|
||||
return instance.embedding_model.at(config_name);
|
||||
}
|
||||
}
|
||||
|
||||
VectorStoreConfig Config::get_vector_store_config(const std::string& config_name) {
|
||||
auto& instance = get_instance();
|
||||
std::shared_lock<std::shared_mutex> lock(_config_mutex);
|
||||
|
||||
bool need_default = instance.vector_store.find(config_name) == instance.vector_store.end();
|
||||
if (need_default) {
|
||||
std::string message = "Vector store config not found: " + config_name + ", falling back to default vector store config.";
|
||||
lock.unlock();
|
||||
logger->warn(message);
|
||||
lock.lock();
|
||||
return instance.vector_store.at("default");
|
||||
} else {
|
||||
return instance.vector_store.at(config_name);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
10
src/llm.cpp
10
src/llm.cpp
|
@ -51,7 +51,7 @@ json LLM::format_messages(const std::vector<Message>& messages) {
|
|||
formatted_messages.back()["role"] = "user";
|
||||
formatted_messages.back()["content"] = concat_content("Tool result for `" + message.name + "`:\n\n", formatted_messages.back()["content"]);
|
||||
} else if (!formatted_messages.back()["tool_calls"].empty()) {
|
||||
std::string tool_calls_str = tool_parser_->dump(formatted_messages.back()["tool_calls"]);
|
||||
std::string tool_calls_str = llm_config_->tool_parser.dump(formatted_messages.back()["tool_calls"]);
|
||||
formatted_messages.back().erase("tool_calls");
|
||||
formatted_messages.back()["content"] = concat_content(formatted_messages.back()["content"], tool_calls_str);
|
||||
}
|
||||
|
@ -257,14 +257,14 @@ json LLM::ask_tool(
|
|||
if (body["messages"].empty() || body["messages"].back()["role"] != "user") {
|
||||
body["messages"].push_back({
|
||||
{"role", "user"},
|
||||
{"content", tool_parser_->hint(tools.dump(2))}
|
||||
{"content", llm_config_->tool_parser.hint(tools.dump(2))}
|
||||
});
|
||||
} else if (body["messages"].back()["content"].is_string()) {
|
||||
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + tool_parser_->hint(tools.dump(2));
|
||||
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + llm_config_->tool_parser.hint(tools.dump(2));
|
||||
} else if (body["messages"].back()["content"].is_array()) {
|
||||
body["messages"].back()["content"].push_back({
|
||||
{"type", "text"},
|
||||
{"text", tool_parser_->hint(tools.dump(2))}
|
||||
{"text", llm_config_->tool_parser.hint(tools.dump(2))}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -284,7 +284,7 @@ json LLM::ask_tool(
|
|||
json json_data = json::parse(res->body);
|
||||
json message = json_data["choices"][0]["message"];
|
||||
if (!llm_config_->enable_tool && message["content"].is_string()) {
|
||||
message = tool_parser_->parse(message["content"].get<std::string>());
|
||||
message = llm_config_->tool_parser.parse(message["content"].get<std::string>());
|
||||
}
|
||||
total_prompt_tokens_ += json_data["usage"]["prompt_tokens"].get<size_t>();
|
||||
total_completion_tokens_ += json_data["usage"]["completion_tokens"].get<size_t>();
|
||||
|
|
|
@ -14,6 +14,9 @@ const char* NEXT_STEP_PROMPT = R"(You can interact with the computer using pytho
|
|||
- playwright: Interact with web pages, take screenshots, generate test code, web scraps the page and execute JavaScript in a real browser environment. Note: Most of the time you need to observer the page before executing other actions.
|
||||
- image_loader: Get base64 image from file or url.
|
||||
- content_provider: Save content and retrieve by chunks.
|
||||
- terminate: Terminate the current task.
|
||||
|
||||
Besides, you may get access to other tools, refer to their descriptions and use them if necessary. Some tools are not available in the current context, you should tell by yourself and do not use them.
|
||||
|
||||
Remember the following:
|
||||
- Today's date is {current_date}.
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
cmake_minimum_required(VERSION 3.10)
|
||||
project(humanus_tests)
|
||||
|
||||
# 添加测试可执行文件
|
||||
add_executable(test_bpe test_bpe.cpp)
|
||||
|
||||
# 链接到主库
|
||||
target_link_libraries(test_bpe PRIVATE humanus)
|
||||
|
||||
# 添加包含路径,确保可以找到需要的头文件
|
||||
target_include_directories(test_bpe PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
119
tokenizer/bpe.h
119
tokenizer/bpe.h
|
@ -17,15 +17,14 @@
|
|||
namespace humanus {
|
||||
|
||||
/**
|
||||
* @brief BPE (Byte Pair Encoding) Tokenizer实现
|
||||
* @brief BPE (Byte Pair Encoding) Tokenizer Implementation
|
||||
*
|
||||
* 这个类实现了基于BPE (Byte Pair Encoding)算法的分词器,
|
||||
* 使用tiktoken格式的词汇表和合并规则文件。
|
||||
* 使用优先队列进行高效的BPE合并操作。
|
||||
* Uses tiktoken format vocabulary and merge rules file.
|
||||
* Uses a priority queue for efficient BPE merging.
|
||||
*/
|
||||
class BPETokenizer : public BaseTokenizer {
|
||||
private:
|
||||
// 辅助结构,用于哈希pair
|
||||
// Helper structure for hashing pairs
|
||||
struct PairHash {
|
||||
template <class T1, class T2>
|
||||
std::size_t operator() (const std::pair<T1, T2>& pair) const {
|
||||
|
@ -33,30 +32,21 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
// 字节对到对应token ID的映射
|
||||
// UTF-8 bytes to token mapping
|
||||
std::unordered_map<std::string, size_t> encoder;
|
||||
// token ID到字节对的映射
|
||||
// Token ID to UTF-8 bytes mapping
|
||||
std::unordered_map<size_t, std::string> decoder;
|
||||
|
||||
// 合并优先级映射,优先级越小越优先合并
|
||||
// Merge priority mapping, lower rank with higher priority
|
||||
std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash> merge_ranks;
|
||||
|
||||
/**
|
||||
* @brief 解码base64编码的字符串
|
||||
* @param encoded base64编码的字符串
|
||||
* @return 解码后的字符串
|
||||
*/
|
||||
// Used by tiktoken merge ranks
|
||||
std::string base64_decode(const std::string& encoded) const {
|
||||
if (encoded.empty()) return "";
|
||||
return base64::decode(encoded);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 用于优先队列的比较函数
|
||||
*
|
||||
* 根据merge_ranks中的优先级比较两个token对。
|
||||
* 优先级较低的值(即较小的rank值)具有更高的优先级。
|
||||
*/
|
||||
// Lower rank with higher priority
|
||||
struct MergeComparator {
|
||||
const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks;
|
||||
|
||||
|
@ -65,27 +55,27 @@ private:
|
|||
|
||||
bool operator()(const std::pair<std::pair<std::string, std::string>, size_t>& a,
|
||||
const std::pair<std::pair<std::string, std::string>, size_t>& b) const {
|
||||
// 首先按照merge_ranks比较,如果不存在则使用最大值
|
||||
// First compare by merge_ranks, if not exist then use max value
|
||||
size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits<size_t>::max();
|
||||
size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits<size_t>::max();
|
||||
|
||||
// 优先队列是最大堆,所以我们需要反向比较(较小的rank优先级更高)
|
||||
// Max heap, so we need to reverse compare (smaller rank has higher priority)
|
||||
return rank_a > rank_b;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief 从tiktoken格式文件构造BPE tokenizer
|
||||
* @param tokenizer_path tiktoken格式词汇表文件的路径
|
||||
* @brief Construct BPE tokenizer from tiktoken format file
|
||||
* @param tokenizer_path path to tiktoken format vocabulary file
|
||||
*
|
||||
* 文件格式:每行包含一个base64编码的token和对应的token ID
|
||||
* 例如: "IQ== 0",其中"IQ=="是base64编码的token,0是对应的ID
|
||||
* File format: Each line contains a base64 encoded token and its corresponding token ID
|
||||
* Example: "IQ== 0", where "IQ==" is the base64 encoded token and 0 is the corresponding ID
|
||||
*/
|
||||
BPETokenizer(const std::string& tokenizer_path) {
|
||||
std::ifstream file(tokenizer_path);
|
||||
if (!file.is_open()) {
|
||||
throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path);
|
||||
throw std::runtime_error("Failed to open tokenizer file: " + tokenizer_path);
|
||||
}
|
||||
|
||||
std::string line;
|
||||
|
@ -97,36 +87,33 @@ public:
|
|||
if (iss >> token_base64 >> rank) {
|
||||
std::string token = base64_decode(token_base64);
|
||||
|
||||
// 存储token和其ID的映射关系
|
||||
encoder[token] = rank;
|
||||
decoder[rank] = token;
|
||||
}
|
||||
}
|
||||
|
||||
// 构建merge_ranks
|
||||
build_merge_ranks();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 构建合并优先级映射
|
||||
* @brief Build merge priority mapping
|
||||
*
|
||||
* 利用词汇表中的tokens推断可能的合并规则。
|
||||
* 对于长度大于1的tokens,尝试所有可能的分割,如果分割后的两个部分也在词汇表中,
|
||||
* 则假设这是一个有效的合并规则。
|
||||
* Use tokens in vocabulary to infer possible merge rules.
|
||||
* For tokens longer than 1, try all possible splits, if the split parts are also in the vocabulary,
|
||||
* then assume this is a valid merge rule.
|
||||
*/
|
||||
void build_merge_ranks() {
|
||||
// 对于tiktoken格式,我们可以利用编码中长度>1的token来构建合并规则
|
||||
for (const auto& [token, id] : encoder) {
|
||||
if (token.length() <= 1) continue;
|
||||
|
||||
// 尝试所有可能的分割点
|
||||
// Try all possible split points
|
||||
for (size_t i = 1; i < token.length(); ++i) {
|
||||
std::string first = token.substr(0, i);
|
||||
std::string second = token.substr(i);
|
||||
|
||||
// 如果两个部分都在词汇表中,假设这是一个有效的合并规则
|
||||
// If both parts are in the vocabulary, assume this is a valid merge rule
|
||||
if (encoder.count(first) && encoder.count(second)) {
|
||||
// 使用ID作为优先级 - 较小的ID表示更高的优先级
|
||||
// Use ID as priority - smaller ID means higher priority
|
||||
merge_ranks[{first, second}] = id;
|
||||
}
|
||||
}
|
||||
|
@ -134,42 +121,42 @@ public:
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief 设置合并优先级
|
||||
* @param ranks 新的合并优先级映射
|
||||
* @brief Set merge priority
|
||||
* @param ranks new merge priority mapping
|
||||
*/
|
||||
void set_merge_ranks(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks) {
|
||||
merge_ranks = ranks;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 执行BPE编码
|
||||
* @param text 要编码的文本
|
||||
* @return 编码后的token IDs
|
||||
* @brief Encode text using BPE
|
||||
* @param text text to encode
|
||||
* @return encoded token IDs
|
||||
*
|
||||
* 该方法使用BPE算法对输入文本进行编码。
|
||||
* 1. 首先将文本分解为单个字节
|
||||
* 2. 使用优先队列根据merge_ranks中的优先级对相邻token进行合并
|
||||
* 3. 将最终的tokens转换为对应的IDs
|
||||
* This method uses BPE algorithm to encode the input text.
|
||||
* 1. First decompose the text into single byte tokens
|
||||
* 2. Use priority queue to merge adjacent tokens based on merge_ranks
|
||||
* 3. Convert the final tokens to corresponding IDs
|
||||
*/
|
||||
std::vector<size_t> encode(const std::string& text) const override {
|
||||
if (text.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// 将文本分解为单个字符tokens
|
||||
// Decompose the text into single character tokens
|
||||
std::vector<std::string> tokens;
|
||||
for (unsigned char c : text) {
|
||||
tokens.push_back(std::string(1, c));
|
||||
}
|
||||
|
||||
// 使用优先队列执行BPE合并
|
||||
// Use priority queue to execute BPE merging
|
||||
while (tokens.size() > 1) {
|
||||
// 构建优先队列,用于选择最高优先级的合并对
|
||||
// Build priority queue to select the highest priority merge pair
|
||||
using MergePair = std::pair<std::pair<std::string, std::string>, size_t>;
|
||||
MergeComparator comparator(merge_ranks);
|
||||
std::priority_queue<MergePair, std::vector<MergePair>, MergeComparator> merge_candidates(comparator);
|
||||
|
||||
// 查找所有可能的合并对
|
||||
// Find all possible merge pairs
|
||||
for (size_t i = 0; i < tokens.size() - 1; ++i) {
|
||||
std::pair<std::string, std::string> pair = {tokens[i], tokens[i+1]};
|
||||
if (merge_ranks.count(pair)) {
|
||||
|
@ -177,23 +164,23 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// 如果没有可合并的对,退出循环
|
||||
// If there are no merge pairs, exit loop
|
||||
if (merge_candidates.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// 执行优先级最高的合并(优先队列中排在最前面的)
|
||||
// Execute the highest priority merge (the first in the priority queue)
|
||||
auto top_merge = merge_candidates.top();
|
||||
auto pair = top_merge.first; // 要合并的token对
|
||||
size_t pos = top_merge.second; // 要合并的位置
|
||||
auto pair = top_merge.first; // The token pair to merge
|
||||
size_t pos = top_merge.second; // The position to merge
|
||||
|
||||
// 合并token
|
||||
// Merge tokens
|
||||
std::string merged_token = pair.first + pair.second;
|
||||
tokens[pos] = merged_token;
|
||||
tokens.erase(tokens.begin() + pos + 1);
|
||||
}
|
||||
|
||||
// 将tokens转换为IDs
|
||||
// Convert tokens to IDs
|
||||
std::vector<size_t> ids;
|
||||
ids.reserve(tokens.size());
|
||||
for (const auto& token : tokens) {
|
||||
|
@ -201,39 +188,39 @@ public:
|
|||
if (it != encoder.end()) {
|
||||
ids.push_back(it->second);
|
||||
}
|
||||
// 未知token将被跳过
|
||||
// Unknown tokens will be skipped
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 解码BPE tokens
|
||||
* @param tokens 要解码的token IDs
|
||||
* @return 解码后的文本
|
||||
* @brief Decode BPE tokens
|
||||
* @param tokens token IDs to decode
|
||||
* @return decoded text
|
||||
*
|
||||
* 该方法将编码后的token IDs转换回原始文本。
|
||||
* 简单地将每个ID对应的token字符串连接起来。
|
||||
* This method converts encoded token IDs back to the original text.
|
||||
* Simply concatenate the token strings corresponding to each ID.
|
||||
*/
|
||||
std::string decode(const std::vector<size_t>& tokens) const override {
|
||||
std::string result;
|
||||
result.reserve(tokens.size() * 2); // 预估大小,避免频繁重新分配
|
||||
result.reserve(tokens.size() * 2);
|
||||
|
||||
for (size_t id : tokens) {
|
||||
auto it = decoder.find(id);
|
||||
if (it != decoder.end()) {
|
||||
result += it->second;
|
||||
}
|
||||
// 未知ID将被跳过
|
||||
// Unknown IDs will be skipped
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 从tiktoken文件加载BPE词汇和合并规则
|
||||
* @param file_path tiktoken格式文件的路径
|
||||
* @return 共享指针指向创建的BPE tokenizer
|
||||
* @brief Load merge ranks from tiktoken format file
|
||||
* @param file_path path to tiktoken format file
|
||||
* @return shared pointer to created BPE tokenizer
|
||||
*/
|
||||
static std::shared_ptr<BPETokenizer> load_from_tiktoken(const std::string& file_path) {
|
||||
return std::make_shared<BPETokenizer>(file_path);
|
||||
|
|
|
@ -130,7 +130,7 @@ struct BaseMCPTool : BaseTool {
|
|||
std::shared_ptr<mcp::client> _client;
|
||||
try {
|
||||
// Load tool configuration from config file
|
||||
auto _config = Config::get_instance().mcp_server().at(server_name);
|
||||
auto _config = Config::get_mcp_server_config(server_name);
|
||||
|
||||
if (_config.type == "stdio") {
|
||||
std::string command = _config.command;
|
||||
|
@ -150,7 +150,7 @@ struct BaseMCPTool : BaseTool {
|
|||
}
|
||||
}
|
||||
|
||||
_client->initialize(server_name + "_client", "0.0.1");
|
||||
_client->initialize(server_name + "_client", "0.1.0");
|
||||
} catch (const std::exception& e) {
|
||||
throw std::runtime_error("Failed to initialize MCP tool client for `" + server_name + "`: " + std::string(e.what()));
|
||||
}
|
||||
|
|
|
@ -67,11 +67,9 @@ struct ContentProvider : BaseTool {
|
|||
|
||||
ContentProvider() : BaseTool(name_, description_, parameters_) {}
|
||||
|
||||
// 将文本分割成合适大小的块
|
||||
std::vector<json> split_text_into_chunks(const std::string& text, int max_chunk_size) {
|
||||
std::vector<json> chunks;
|
||||
|
||||
// 如果文本为空,返回空数组
|
||||
if (text.empty()) {
|
||||
return chunks;
|
||||
}
|
||||
|
@ -80,23 +78,22 @@ struct ContentProvider : BaseTool {
|
|||
size_t offset = 0;
|
||||
|
||||
while (offset < text_length) {
|
||||
// 首先确定最大可能的块大小
|
||||
size_t raw_chunk_size = std::min(static_cast<size_t>(max_chunk_size), text_length - offset);
|
||||
|
||||
// 使用 validate_utf8 确保不会截断 UTF-8 字符
|
||||
// Make sure the chunk is valid UTF-8
|
||||
std::string potential_chunk = text.substr(offset, raw_chunk_size);
|
||||
size_t valid_utf8_length = validate_utf8(potential_chunk);
|
||||
|
||||
// 调整为有效的 UTF-8 字符边界
|
||||
// Adjust the chunk size to a valid UTF-8 character boundary
|
||||
size_t chunk_size = valid_utf8_length;
|
||||
|
||||
// 如果不是在文本的结尾,并且我们没有因为 UTF-8 截断而减小块大小,
|
||||
// 尝试在空格、换行或标点处分割,以获得更自然的分隔点
|
||||
// If not at the end of the text and we didn't reduce the chunk size due to UTF-8 truncation,
|
||||
// try to split at a natural break point (space, newline, punctuation)
|
||||
if (offset + chunk_size < text_length && chunk_size == raw_chunk_size) {
|
||||
size_t break_pos = offset + chunk_size;
|
||||
|
||||
// 向后寻找一个合适的分割点
|
||||
size_t min_pos = offset + valid_utf8_length / 2; // 不要搜索太远,至少保留一半的有效内容
|
||||
// Search backward for a natural break point (space, newline, punctuation)
|
||||
size_t min_pos = offset + valid_utf8_length / 2; // Don't search too far, at least keep half of the valid content
|
||||
while (break_pos > min_pos &&
|
||||
text[break_pos] != ' ' &&
|
||||
text[break_pos] != '\n' &&
|
||||
|
@ -109,7 +106,7 @@ struct ContentProvider : BaseTool {
|
|||
break_pos--;
|
||||
}
|
||||
|
||||
// 如果找到了合适的分割点且不是原始位置
|
||||
// If a suitable break point is found and it's not the original position
|
||||
if (break_pos > min_pos) {
|
||||
break_pos++; // Include the last character
|
||||
std::string new_chunk = text.substr(offset, break_pos - offset);
|
||||
|
@ -130,7 +127,7 @@ struct ContentProvider : BaseTool {
|
|||
return chunks;
|
||||
}
|
||||
|
||||
// 处理写入操作
|
||||
// Handle write operation
|
||||
ToolResult handle_write(const json& args) {
|
||||
int max_chunk_size = args.value("max_chunk_size", 4000);
|
||||
|
||||
|
@ -142,7 +139,7 @@ struct ContentProvider : BaseTool {
|
|||
|
||||
std::string text_content;
|
||||
|
||||
// 处理内容,分割大型文本
|
||||
// Process content, split large text or mixture of text and image
|
||||
for (const auto& item : args["content"]) {
|
||||
if (!item.contains("type")) {
|
||||
return ToolError("Each content item must have a `type` field");
|
||||
|
@ -168,7 +165,7 @@ struct ContentProvider : BaseTool {
|
|||
return ToolError("Image items must have an `image_url` field with a `url` property");
|
||||
}
|
||||
|
||||
// 图像保持为一个整体
|
||||
// Image remains as a whole
|
||||
processed_content.push_back(item);
|
||||
} else {
|
||||
return ToolError("Unsupported content type: " + type);
|
||||
|
@ -181,7 +178,7 @@ struct ContentProvider : BaseTool {
|
|||
text_content.clear();
|
||||
}
|
||||
|
||||
// 生成一个唯一的存储ID
|
||||
// Generate a unique store ID
|
||||
std::string store_id = "content_" + std::to_string(current_id_);
|
||||
|
||||
if (content_store_.find(store_id) != content_store_.end()) {
|
||||
|
@ -190,18 +187,18 @@ struct ContentProvider : BaseTool {
|
|||
|
||||
current_id_ = (current_id_ + 1) % MAX_STORE_ID;
|
||||
|
||||
// 存储处理后的内容
|
||||
// Store processed content
|
||||
content_store_[store_id] = processed_content;
|
||||
|
||||
// 返回存储ID和内容项数
|
||||
// Return store ID and number of content items
|
||||
json result;
|
||||
result["store_id"] = store_id;
|
||||
result["total_items"] = processed_content.size();
|
||||
|
||||
return ToolResult(result);
|
||||
return ToolResult(result.dump(2));;
|
||||
}
|
||||
|
||||
// 处理读取操作
|
||||
// Handle read operation
|
||||
ToolResult handle_read(const json& args) {
|
||||
if (!args.contains("cursor") || !args["cursor"].is_string()) {
|
||||
return ToolError("`cursor` is required for read operations");
|
||||
|
@ -210,7 +207,7 @@ struct ContentProvider : BaseTool {
|
|||
std::string cursor = args["cursor"];
|
||||
|
||||
if (cursor == "start") {
|
||||
// 列出所有可用的存储ID
|
||||
// List all available store IDs
|
||||
json available_stores = json::array();
|
||||
for (const auto& [id, content] : content_store_) {
|
||||
json store_info;
|
||||
|
@ -227,12 +224,12 @@ struct ContentProvider : BaseTool {
|
|||
result["available_stores"] = available_stores;
|
||||
result["next_cursor"] = "select_store";
|
||||
|
||||
return ToolResult(result);
|
||||
return ToolResult(result.dump(2));
|
||||
} else if (cursor == "select_store") {
|
||||
// 用户需要选择一个存储ID
|
||||
// User needs to select a store ID
|
||||
return ToolError("Please provide a store_id as cursor in format `content_X:Y`");
|
||||
} else if (cursor.find(":") != std::string::npos) { // content_X:Y
|
||||
// 用户正在浏览特定存储内的内容
|
||||
// User is browsing specific content in a store
|
||||
size_t delimiter_pos = cursor.find(":");
|
||||
std::string store_id = cursor.substr(0, delimiter_pos);
|
||||
size_t index = std::stoul(cursor.substr(delimiter_pos + 1));
|
||||
|
@ -245,10 +242,10 @@ struct ContentProvider : BaseTool {
|
|||
return ToolError("Index out of range");
|
||||
}
|
||||
|
||||
// 返回请求的内容项
|
||||
// Return the requested content item
|
||||
json result = content_store_[store_id][index];
|
||||
|
||||
// 添加导航信息
|
||||
// Add navigation information
|
||||
if (index + 1 < content_store_[store_id].size()) {
|
||||
result["next_cursor"] = store_id + ":" + std::to_string(index + 1);
|
||||
result["remaining_items"] = content_store_[store_id].size() - index - 1;
|
||||
|
@ -257,7 +254,7 @@ struct ContentProvider : BaseTool {
|
|||
result["remaining_items"] = 0;
|
||||
}
|
||||
|
||||
return ToolResult(result);
|
||||
return ToolResult(result.dump(2));
|
||||
} else if (cursor == "end") {
|
||||
return ToolResult("You have reached the end of the content.");
|
||||
} else {
|
||||
|
|
|
@ -13,7 +13,7 @@ struct ImageLoader : BaseTool {
|
|||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL of the image to load. Supports HTTP/HTTPS URLs and local file paths. If the URL is a local file path, it must start with file://"
|
||||
"description": "The URL of the image to load. Supports HTTP/HTTPS URLs and absolute local file paths. If the URL is a local file path, it must start with file://"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#include "planning.h"
|
||||
#include <iomanip> // 添加iomanip头文件,用于std::setprecision
|
||||
#include <iomanip>
|
||||
|
||||
namespace humanus {
|
||||
|
||||
|
|
|
@ -15,11 +15,6 @@ struct PythonExecute : BaseMCPTool {
|
|||
{"code", {
|
||||
{"type", "string"},
|
||||
{"description", "The Python code to execute. Note: Use absolute file paths if code will read/write files."}
|
||||
}},
|
||||
{"timeout", {
|
||||
{"type", "number"},
|
||||
{"description", "The timeout for the Python code execution in seconds."},
|
||||
{"default", 5}
|
||||
}}
|
||||
}},
|
||||
{"required", {"code"}}
|
||||
|
|
|
@ -2,6 +2,13 @@
|
|||
#define HUMANUS_TOOL_COLLECTION_H
|
||||
|
||||
#include "base.h"
|
||||
#include "image_loader.h"
|
||||
#include "python_execute.h"
|
||||
#include "filesystem.h"
|
||||
#include "playwright.h"
|
||||
#include "puppeteer.h"
|
||||
#include "content_provider.h"
|
||||
#include "terminate.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
|
@ -9,6 +16,8 @@ struct ToolCollection {
|
|||
std::vector<std::shared_ptr<BaseTool>> tools;
|
||||
std::map<std::string, std::shared_ptr<BaseTool>> tools_map;
|
||||
|
||||
ToolCollection() = default;
|
||||
|
||||
ToolCollection(std::vector<std::shared_ptr<BaseTool>> tools) : tools(tools) {
|
||||
for (auto tool : tools) {
|
||||
tools_map[tool->name] = tool;
|
||||
|
@ -67,8 +76,47 @@ struct ToolCollection {
|
|||
add_tool(tool);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<BaseTool> get_tool(const std::string& name) const {
|
||||
auto tool_iter = tools_map.find(name);
|
||||
if (tool_iter == tools_map.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return tool_iter->second;
|
||||
}
|
||||
|
||||
std::vector<std::string> get_tool_names() const {
|
||||
std::vector<std::string> names;
|
||||
for (const auto& tool : tools) {
|
||||
names.push_back(tool->name);
|
||||
}
|
||||
return names;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
class ToolFactory {
|
||||
public:
|
||||
static std::shared_ptr<BaseTool> create(const std::string& name) {
|
||||
if (name == "python_execute") {
|
||||
return std::make_shared<PythonExecute>();
|
||||
} else if (name == "filesystem") {
|
||||
return std::make_shared<Filesystem>();
|
||||
} else if (name == "playwright") {
|
||||
return std::make_shared<Playwright>();
|
||||
} else if (name == "puppeteer") {
|
||||
return std::make_shared<Puppeteer>();
|
||||
} else if (name == "image_loader") {
|
||||
return std::make_shared<ImageLoader>();
|
||||
} else if (name == "content_provider") {
|
||||
return std::make_shared<ContentProvider>();
|
||||
} else if (name == "terminate") {
|
||||
return std::make_shared<Terminate>();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_TOOL_COLLECTION_H
|
Loading…
Reference in New Issue