add tokenizer (lack of test in use) and content_provider (just implementation, no use)
parent
fb4c3180ae
commit
7c797864dd
|
@ -42,23 +42,18 @@ else()
|
|||
message(FATAL_ERROR "OpenSSL not found. Please install OpenSSL development libraries.")
|
||||
endif()
|
||||
|
||||
find_package(Python3 COMPONENTS Development)
|
||||
if(Python3_FOUND)
|
||||
message(STATUS "Python3 found: ${Python3_VERSION}")
|
||||
message(STATUS "Python3 include directory: ${Python3_INCLUDE_DIRS}")
|
||||
message(STATUS "Python3 libraries: ${Python3_LIBRARIES}")
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
add_compile_definitions(PYTHON_FOUND)
|
||||
else()
|
||||
message(WARNING "Python3 development libraries not found. Python interpreter will not be available.")
|
||||
endif()
|
||||
|
||||
# mcp
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mcp)
|
||||
|
||||
# server
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/server)
|
||||
|
||||
# tokenizer
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tokenizer)
|
||||
|
||||
# tests
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tests)
|
||||
|
||||
# include
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
|
||||
|
@ -68,6 +63,11 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/mcp/common)
|
|||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
file(GLOB BASIC_SOURCES
|
||||
"src/*.cpp"
|
||||
"src/*.cc"
|
||||
)
|
||||
|
||||
file(GLOB AGENT_SOURCES
|
||||
"agent/*.cpp"
|
||||
"agent/*.cc"
|
||||
|
@ -88,28 +88,24 @@ file(GLOB MEMORY_SOURCES
|
|||
"memory/*.cc"
|
||||
"memory/*/*.cpp"
|
||||
"memory/*/*.cc"
|
||||
"memory/*/*/*.cpp"
|
||||
"memory/*/*/*.cc"
|
||||
)
|
||||
|
||||
file(GLOB TOKENIZER_SOURCES
|
||||
"tokenizer/*.cpp"
|
||||
"tokenizer/*.cc"
|
||||
)
|
||||
|
||||
# humanus core
|
||||
add_library(humanus
|
||||
src/config.cpp
|
||||
src/llm.cpp
|
||||
src/prompt.cpp
|
||||
src/logger.cpp
|
||||
src/schema.cpp
|
||||
${BASIC_SOURCES}
|
||||
${AGENT_SOURCES}
|
||||
${TOOL_SOURCES}
|
||||
${FLOW_SOURCES}
|
||||
${MEMORY_SOURCES}
|
||||
${TOKENIZER_SOURCES}
|
||||
)
|
||||
|
||||
target_link_libraries(humanus PUBLIC Threads::Threads mcp ${OPENSSL_LIBRARIES})
|
||||
|
||||
if(Python3_FOUND)
|
||||
target_link_libraries(humanus PUBLIC ${Python3_LIBRARIES})
|
||||
endif()
|
||||
|
||||
# examples
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/examples)
|
48
README.md
48
README.md
|
@ -1,48 +1,14 @@
|
|||
## Introduction
|
||||
<p align="center">
|
||||
<img src="assets/logo.png" width="200"/>
|
||||
</p>
|
||||
|
||||
Humanus (meaning "human" in Latin) is a lightweight framework inspired by OpenManus, integrated with the Model Context Protocol (MCP). `humanus.cpp` enables more flexible tool choices, and provides a foundation for building powerful local LLM agents.
|
||||
# humanus.cpp
|
||||
|
||||
Let's embrace local LLM agents w/ humanus.cpp!
|
||||
Humanus (meaning "human" in Latin) is a lightweight framework inspired by [OpenManus](https://github.com/mannaandpoem/OpenManus) and [mem0](https://github.com/mem0ai/mem0), integrated with the Model Context Protocol (MCP). `humanus.cpp` enables more flexible tool choices, and provides a foundation for building powerful local LLM agents.
|
||||
|
||||
## Overview
|
||||
```bash
|
||||
humanus.cpp/
|
||||
├── 📄 config.cpp/.h # 配置系统头文件
|
||||
├── 📄 llm.cpp/.h # LLM集成主实现文件
|
||||
├── 📄 logger.cpp/.h # 日志系统实现文件
|
||||
├── 📄 main.cpp # 程序入口文件
|
||||
├── 📄 prompt.cpp/.h # 预定义提示词
|
||||
├── 📄 schema.cpp/.h # 数据结构定义实现文件
|
||||
├── 📄 toml.hpp # TOML配置文件解析库
|
||||
├── 📂 agent/ # 代理模块目录
|
||||
│ ├── 📄 base.h # 基础代理接口定义
|
||||
│ ├── 📄 humanus.h # Humanus核心代理实现
|
||||
│ ├── 📄 react.h # ReAct代理实现
|
||||
│ └── 📄 toolcall.cpp/.h # 工具调用实现文件
|
||||
├── 📂 flow/ # 工作流模块目录
|
||||
│ ├── 📄 base.h # 基础工作流接口定义
|
||||
│ ├── 📄 flow_factory.h # 工作流工厂类
|
||||
│ └── 📄 planning.cpp/.h # 规划型工作流实现文件
|
||||
├── 📂 mcp/ # 模型上下文协议(MCP)实现目录
|
||||
├── 📂 memory/ # 内存管理模块
|
||||
│ ├── 📄 base.h # 基础内存接口定义
|
||||
│ └── 📂 mem0/ # TODO: mem0记忆实现
|
||||
├── 📂 server/ # 服务器模块
|
||||
│ ├── 📄 mcp_server_main.cpp # MCP服务器入口文件
|
||||
│ └── 📄 python_execute.cpp # Python执行环境集成实现
|
||||
├── 📂 spdlog/ # 第三方日志库
|
||||
└── 📂 tool/ # 工具模块目录
|
||||
├── 📄 base.h # 基础工具接口定义
|
||||
├── 📄 filesystem.h # 文件系统操作工具
|
||||
├── 📄 planning.cpp/.h # 规划工具实现
|
||||
├── 📄 puppeteer.h # Puppeteer浏览器自动化工具
|
||||
├── 📄 python_execute.h # Python执行工具
|
||||
├── 📄 terminate.h # 终止工具
|
||||
└── 📄 tool_collection.h # 工具集合定义
|
||||
```
|
||||
Let's embrace local LLM agents **w/** humanus.cpp!
|
||||
|
||||
|
||||
## Features
|
||||
## Project Demo
|
||||
|
||||
## How to Build
|
||||
|
||||
|
|
52
agent/base.h
52
agent/base.h
|
@ -38,8 +38,6 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
|
||||
int duplicate_threshold; // Threshold for duplicate messages
|
||||
|
||||
std::string current_request; // Current request from user
|
||||
|
||||
BaseAgent(
|
||||
const std::string& name,
|
||||
const std::string& description,
|
||||
|
@ -47,9 +45,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
const std::string& next_step_prompt,
|
||||
const std::shared_ptr<LLM>& llm = nullptr,
|
||||
const std::shared_ptr<BaseMemory>& memory = nullptr,
|
||||
AgentState state = AgentState::IDLE,
|
||||
int max_steps = 10,
|
||||
int current_step = 0,
|
||||
int duplicate_threshold = 2
|
||||
) : name(name),
|
||||
description(description),
|
||||
|
@ -57,9 +53,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
next_step_prompt(next_step_prompt),
|
||||
llm(llm),
|
||||
memory(memory),
|
||||
state(state),
|
||||
max_steps(max_steps),
|
||||
current_step(current_step),
|
||||
duplicate_threshold(duplicate_threshold) {
|
||||
initialize_agent();
|
||||
}
|
||||
|
@ -70,8 +64,9 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
llm = LLM::get_instance("default");
|
||||
}
|
||||
if (!memory) {
|
||||
memory = std::make_shared<Memory>(max_steps);
|
||||
memory = std::make_shared<Memory>(MemoryConfig());
|
||||
}
|
||||
reset(true);
|
||||
}
|
||||
|
||||
// Add a message to the agent's memory
|
||||
|
@ -92,7 +87,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
|
||||
// Execute the agent's main loop asynchronously
|
||||
virtual std::string run(const std::string& request = "") {
|
||||
current_request = request;
|
||||
memory->current_request = request;
|
||||
|
||||
if (state != AgentState::IDLE) {
|
||||
throw std::runtime_error("Cannot run agent from state " + agent_state_map[state]);
|
||||
|
@ -118,7 +113,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
}
|
||||
|
||||
if (is_stuck()) {
|
||||
this->handle_stuck_state();
|
||||
handle_stuck_state();
|
||||
}
|
||||
|
||||
results.push_back("Step " + std::to_string(current_step) + ": " + step_result);
|
||||
|
@ -157,13 +152,24 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
|
||||
// Handle stuck state by adding a prompt to change strategy
|
||||
void handle_stuck_state() {
|
||||
std::string stuck_prompt = "\
|
||||
Observed duplicate responses. Consider new strategies and avoid repeating ineffective paths already attempted.";
|
||||
next_step_prompt = stuck_prompt + "\n" + next_step_prompt;
|
||||
if (!current_request.empty()) {
|
||||
next_step_prompt += "\nAnd don't for get your current task: " + current_request;
|
||||
}
|
||||
std::string stuck_prompt = "Observed duplicate responses. Consider new strategies and avoid repeating ineffective paths already attempted.";
|
||||
logger->warn("Agent detected stuck state. Added prompt: " + stuck_prompt);
|
||||
memory->add_message(Message::user_message(stuck_prompt));
|
||||
}
|
||||
|
||||
// O(nm) LCS algorithm, could basically handle current LLM context
|
||||
size_t get_lcs_length(const std::string& s1, const std::string& s2) {
|
||||
std::vector<std::vector<size_t>> dp(s1.size() + 1, std::vector<size_t>(s2.size() + 1));
|
||||
for (size_t i = 1; i <= s1.size(); i++) {
|
||||
for (size_t j = 1; j <= s2.size(); j++) {
|
||||
if (s1[i - 1] == s2[j - 1]) {
|
||||
dp[i][j] = dp[i - 1][j - 1] + 1;
|
||||
} else {
|
||||
dp[i][j] = std::max(dp[i - 1][j], dp[i][j - 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return dp[s1.size()][s2.size()];
|
||||
}
|
||||
|
||||
// Check if the agent is stuck in a loop by detecting duplicate content
|
||||
|
@ -175,18 +181,24 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
|||
}
|
||||
|
||||
const Message& last_message = messages.back();
|
||||
if (last_message.content.empty() || last_message.content.is_null()) {
|
||||
if (last_message.content.empty() || last_message.role != "assistant") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Count identical content occurrences
|
||||
int duplicate_count = 0;
|
||||
int duplicate_lcs_length = 0.6 * last_message.content.get<std::string>().size(); // TODO: make this threshold configurable
|
||||
for (auto r_it = messages.rbegin(); r_it != messages.rend(); ++r_it) {
|
||||
if (r_it == messages.rbegin()) {
|
||||
continue;
|
||||
}
|
||||
const Message& message = *r_it;
|
||||
if (message.role == "assistant" && message.content == last_message.content) {
|
||||
duplicate_count++;
|
||||
if (duplicate_count >= duplicate_threshold) {
|
||||
break;
|
||||
if (message.role == "assistant" && !message.content.empty()) {
|
||||
if (get_lcs_length(message.content, last_message.content) > duplicate_lcs_length) {
|
||||
duplicate_count++;
|
||||
if (duplicate_count >= duplicate_threshold) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include "tool/python_execute.h"
|
||||
#include "tool/terminate.h"
|
||||
#include "tool/puppeteer.h"
|
||||
#include "tool/playwright.h"
|
||||
#include "tool/filesystem.h"
|
||||
|
||||
namespace humanus {
|
||||
|
@ -25,6 +26,7 @@ struct Humanus : ToolCallAgent {
|
|||
{
|
||||
std::make_shared<PythonExecute>(),
|
||||
std::make_shared<Filesystem>(),
|
||||
std::make_shared<Playwright>(),
|
||||
std::make_shared<Terminate>()
|
||||
}
|
||||
),
|
||||
|
@ -36,9 +38,7 @@ struct Humanus : ToolCallAgent {
|
|||
const std::string& next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT,
|
||||
const std::shared_ptr<LLM>& llm = nullptr,
|
||||
const std::shared_ptr<BaseMemory>& memory = nullptr,
|
||||
AgentState state = AgentState::IDLE,
|
||||
int max_steps = 30,
|
||||
int current_step = 0,
|
||||
int duplicate_threshold = 2
|
||||
) : ToolCallAgent(
|
||||
available_tools,
|
||||
|
@ -50,11 +50,37 @@ struct Humanus : ToolCallAgent {
|
|||
next_step_prompt,
|
||||
llm,
|
||||
memory,
|
||||
state,
|
||||
max_steps,
|
||||
current_step,
|
||||
duplicate_threshold
|
||||
) {}
|
||||
|
||||
std::string run(const std::string& request = "") override {
|
||||
memory->current_request = request;
|
||||
|
||||
auto tmp_next_step_prompt = next_step_prompt;
|
||||
|
||||
size_t pos = next_step_prompt.find("{current_date}");
|
||||
if (pos != std::string::npos) {
|
||||
// %Y-%d-%m
|
||||
auto current_date = std::chrono::system_clock::now();
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||
std::stringstream ss;
|
||||
std::tm tm_info = *std::localtime(&in_time_t);
|
||||
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||
next_step_prompt.replace(pos, 14, formatted_date);
|
||||
}
|
||||
|
||||
pos = next_step_prompt.find("{current_request}");
|
||||
if (pos != std::string::npos) {
|
||||
next_step_prompt.replace(pos, 17, request);
|
||||
}
|
||||
|
||||
auto result = BaseAgent::run(request);
|
||||
next_step_prompt = tmp_next_step_prompt; // restore the original prompt
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
#ifndef HUMANUS_AGENT_MCP_H
|
||||
#define HUMANUS_AGENT_MCP_H
|
||||
|
||||
#include "base.h"
|
||||
#include "toolcall.h"
|
||||
#include "prompt.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
struct MCPAgent : ToolCallAgent {
|
||||
MCPAgent(
|
||||
const std::vector<std::string>& mcp_servers,
|
||||
const ToolCollection& available_tools = ToolCollection(
|
||||
{
|
||||
std::make_shared<Terminate>()
|
||||
}
|
||||
),
|
||||
const std::string& tool_choice = "auto",
|
||||
const std::set<std::string>& special_tool_names = {"terminate"},
|
||||
const std::string& name = "mcp_agent",
|
||||
const std::string& description = "an agent that can execute tool calls.",
|
||||
const std::string& system_prompt = prompt::toolcall::SYSTEM_PROMPT,
|
||||
const std::string& next_step_prompt = prompt::toolcall::NEXT_STEP_PROMPT,
|
||||
const std::shared_ptr<LLM>& llm = nullptr,
|
||||
const std::shared_ptr<BaseMemory>& memory = nullptr,
|
||||
int max_steps = 30,
|
||||
int duplicate_threshold = 2
|
||||
) : ToolCallAgent(
|
||||
available_tools,
|
||||
tool_choice,
|
||||
special_tool_names,
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
next_step_prompt,
|
||||
llm,
|
||||
memory,
|
||||
max_steps,
|
||||
duplicate_threshold
|
||||
) {
|
||||
for (const auto& server_name : mcp_servers) {
|
||||
this->available_tools.add_mcp_tools(server_name);
|
||||
}
|
||||
}
|
||||
|
||||
std::string run(const std::string& request = "") override {
|
||||
memory->current_request = request;
|
||||
|
||||
auto tmp_next_step_prompt = next_step_prompt;
|
||||
|
||||
size_t pos = next_step_prompt.find("{current_date}");
|
||||
if (pos != std::string::npos) {
|
||||
// %Y-%d-%m
|
||||
auto current_date = std::chrono::system_clock::now();
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||
std::stringstream ss;
|
||||
std::tm tm_info = *std::localtime(&in_time_t);
|
||||
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||
next_step_prompt.replace(pos, 14, formatted_date);
|
||||
}
|
||||
|
||||
pos = next_step_prompt.find("{current_request}");
|
||||
if (pos != std::string::npos) {
|
||||
next_step_prompt.replace(pos, 17, request);
|
||||
}
|
||||
|
||||
auto result = BaseAgent::run(request);
|
||||
next_step_prompt = tmp_next_step_prompt; // restore the original prompt
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // HUMANUS_AGENT_MCP_H
|
|
@ -13,9 +13,7 @@ struct ReActAgent : BaseAgent {
|
|||
const std::string& next_step_prompt,
|
||||
const std::shared_ptr<LLM>& llm = nullptr,
|
||||
const std::shared_ptr<BaseMemory>& memory = nullptr,
|
||||
AgentState state = AgentState::IDLE,
|
||||
int max_steps = 10,
|
||||
int current_step = 0,
|
||||
int duplicate_threshold = 2
|
||||
) : BaseAgent(
|
||||
name,
|
||||
|
@ -24,9 +22,7 @@ struct ReActAgent : BaseAgent {
|
|||
next_step_prompt,
|
||||
llm,
|
||||
memory,
|
||||
state,
|
||||
max_steps,
|
||||
current_step,
|
||||
duplicate_threshold
|
||||
) {}
|
||||
|
||||
|
@ -39,6 +35,7 @@ struct ReActAgent : BaseAgent {
|
|||
// Execute a single step: think and act.
|
||||
virtual std::string step() {
|
||||
bool should_act = think();
|
||||
logger->info("Prompt tokens: " + std::to_string(llm->total_prompt_tokens()) + ", Completion tokens: " + std::to_string(llm->total_completion_tokens()));
|
||||
if (!should_act) {
|
||||
return "Thinking complete - no action needed";
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ namespace humanus {
|
|||
bool ToolCallAgent::think() {
|
||||
// Get response with tool options
|
||||
auto response = llm->ask_tool(
|
||||
memory->get_messages(current_request),
|
||||
memory->get_messages(memory->current_request),
|
||||
system_prompt,
|
||||
next_step_prompt,
|
||||
available_tools.to_params(),
|
||||
|
@ -59,9 +59,6 @@ bool ToolCallAgent::think() {
|
|||
return !tool_calls.empty();
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("🚨 Oops! The " + name + "'s thinking process hit a snag: " + std::string(e.what()));
|
||||
memory->add_message(Message::assistant_message(
|
||||
"Error encountered while processing: " + std::string(e.what())
|
||||
));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,9 +29,7 @@ struct ToolCallAgent : ReActAgent {
|
|||
const std::string& next_step_prompt = prompt::toolcall::NEXT_STEP_PROMPT,
|
||||
const std::shared_ptr<LLM>& llm = nullptr,
|
||||
const std::shared_ptr<BaseMemory>& memory = nullptr,
|
||||
AgentState state = AgentState::IDLE,
|
||||
int max_steps = 30,
|
||||
int current_step = 0,
|
||||
int duplicate_threshold = 2
|
||||
) : ReActAgent(
|
||||
name,
|
||||
|
@ -40,14 +38,16 @@ struct ToolCallAgent : ReActAgent {
|
|||
next_step_prompt,
|
||||
llm,
|
||||
memory,
|
||||
state,
|
||||
max_steps,
|
||||
current_step,
|
||||
duplicate_threshold
|
||||
),
|
||||
available_tools(available_tools),
|
||||
tool_choice(tool_choice),
|
||||
special_tool_names(special_tool_names) {}
|
||||
special_tool_names(special_tool_names) {
|
||||
if (available_tools.tools_map.find("terminate") == available_tools.tools_map.end()) {
|
||||
throw std::runtime_error("terminate tool must be present in available_tools");
|
||||
}
|
||||
}
|
||||
|
||||
// Process current state and decide next actions using tools
|
||||
bool think() override;
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 331 KiB |
|
@ -1,4 +1,4 @@
|
|||
[default]
|
||||
[nomic-embed-text-v1.5]
|
||||
provider = "oai"
|
||||
base_url = "http://localhost:8080"
|
||||
endpoint = "/v1/embeddings"
|
||||
|
@ -6,3 +6,12 @@ model = "nomic-embed-text-v1.5.f16.gguf"
|
|||
api_key = ""
|
||||
embeddings_dim = 768
|
||||
max_retries = 3
|
||||
|
||||
[default]
|
||||
provider = "oai"
|
||||
base_url = "https://dashscope.aliyuncs.com"
|
||||
endpoint = "/compatible-mode/v1/embeddings"
|
||||
model = "text-embedding-v3"
|
||||
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
|
||||
embeddings_dim = 1024
|
||||
max_retries = 3
|
|
@ -3,3 +3,35 @@ model = "qwen-max"
|
|||
base_url = "https://dashscope.aliyuncs.com"
|
||||
endpoint = "/compatible-mode/v1/chat/completions"
|
||||
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
|
||||
|
||||
[glm-4-plus]
|
||||
model = "glm-4-plus"
|
||||
base_url = "https://open.bigmodel.cn"
|
||||
endpoint = "/api/paas/v4/chat/completions"
|
||||
api_key = "7e12e1cb8fe5786d83c74d2ef48db511.xPVWzEZt8RvIciW9"
|
||||
|
||||
[qwen-vl-max]
|
||||
model = "qwen-vl-max"
|
||||
base_url = "https://dashscope.aliyuncs.com"
|
||||
endpoint = "/compatible-mode/v1/chat/completions"
|
||||
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
|
||||
|
||||
[claude-3.5-sonnet]
|
||||
model = "anthropic/claude-3.5-sonnet"
|
||||
base_url = "https://openrouter.ai"
|
||||
endpoint = "/api/v1/chat/completions"
|
||||
api_key = "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"
|
||||
max_tokens = 8192
|
||||
|
||||
[deepseek-chat]
|
||||
model = "deepseek-chat"
|
||||
base_url = "https://api.deepseek.com"
|
||||
endpoint = "/v1/chat/completions"
|
||||
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
|
||||
|
||||
[deepseek-r1]
|
||||
model = "deepseek-reasoner"
|
||||
base_url = "https://api.deepseek.com"
|
||||
endpoint = "/v1/chat/completions"
|
||||
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
|
||||
oai_tool_support = false
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[python_execute]
|
||||
type = "sse"
|
||||
host = "localhost"
|
||||
port = 8818
|
||||
port = 8896
|
||||
sse_endpoint = "/sse"
|
||||
|
||||
[puppeteer]
|
||||
|
@ -9,6 +9,11 @@ type = "stdio"
|
|||
command = "npx"
|
||||
args = ["-y", "@modelcontextprotocol/server-puppeteer"]
|
||||
|
||||
[playwright]
|
||||
type = "stdio"
|
||||
command = "npx"
|
||||
args = ["-y", "@executeautomation/playwright-mcp-server"]
|
||||
|
||||
[filesystem]
|
||||
type = "stdio"
|
||||
command = "npx"
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
set(target humanus_chat)
|
||||
|
||||
add_executable(${target} humanus_chat.cpp)
|
||||
|
||||
# 链接到核心库
|
||||
target_link_libraries(${target} PRIVATE humanus)
|
|
@ -2,7 +2,7 @@
|
|||
#include "logger.h"
|
||||
#include "prompt.h"
|
||||
#include "flow/flow_factory.h"
|
||||
#include "memory/mem0/base.h"
|
||||
#include "memory/base.h"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
|
@ -21,49 +21,14 @@ using namespace humanus;
|
|||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
// make sure all logs are flushed
|
||||
logger->info("Interrupted by user\n");
|
||||
exit(0);
|
||||
logger->flush();
|
||||
_exit(130);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool readline_utf8(std::string & line, bool multiline_input) {
|
||||
#if defined(_WIN32)
|
||||
std::wstring wline;
|
||||
if (!std::getline(std::wcin, wline)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
||||
return false;
|
||||
}
|
||||
|
||||
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
||||
line.resize(size_needed);
|
||||
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
||||
#else
|
||||
if (!std::getline(std::cin, line)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if (!line.empty()) {
|
||||
char last = line.back();
|
||||
if (last == '/') { // Always return control on '/' symbol
|
||||
line.pop_back();
|
||||
return false;
|
||||
}
|
||||
if (last == '\\') { // '\\' changes the default action
|
||||
line.pop_back();
|
||||
multiline_input = !multiline_input;
|
||||
}
|
||||
}
|
||||
line += '\n';
|
||||
|
||||
// By default, continue input if multiline_input is set
|
||||
return multiline_input;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
// ctrl+C handling
|
||||
|
@ -85,16 +50,10 @@ int main() {
|
|||
#endif
|
||||
}
|
||||
|
||||
auto memory_config = mem0::MemoryConfig();
|
||||
|
||||
memory_config.max_messages = 1;
|
||||
memory_config.retrieval_limit = 10;
|
||||
|
||||
auto memory = std::make_shared<mem0::Memory>(memory_config);
|
||||
memory->current_request = "Chat with the user";
|
||||
auto memory = std::make_shared<Memory>(MemoryConfig());
|
||||
|
||||
Chatbot chatbot{
|
||||
"chat_mem0", // name
|
||||
"chatbot", // name
|
||||
"A chatbot agent that uses memory to remember conversation history", // description
|
||||
"You are a helpful assistant.", // system_prompt
|
||||
nullptr, // llm
|
||||
|
@ -103,13 +62,19 @@ int main() {
|
|||
|
||||
while (true) {
|
||||
std::cout << "> ";
|
||||
|
||||
std::string prompt;
|
||||
readline_utf8(prompt, false);
|
||||
if (prompt == "exit" || prompt == "exit\n") {
|
||||
|
||||
if (prompt == "exit") {
|
||||
logger->info("Goodbye!");
|
||||
break;
|
||||
}
|
||||
|
||||
logger->info("Processing your request: " + prompt);
|
||||
auto response = chatbot.run(prompt);
|
||||
std::cout << response << std::endl;
|
||||
logger->info("✨ " + chatbot.name + "'s response: " + response);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -20,49 +20,14 @@ using namespace humanus;
|
|||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
// make sure all logs are flushed
|
||||
logger->info("Interrupted by user\n");
|
||||
exit(0);
|
||||
logger->flush();
|
||||
_exit(130);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool readline_utf8(std::string & line, bool multiline_input) {
|
||||
#if defined(_WIN32)
|
||||
std::wstring wline;
|
||||
if (!std::getline(std::wcin, wline)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
||||
return false;
|
||||
}
|
||||
|
||||
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
||||
line.resize(size_needed);
|
||||
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
||||
#else
|
||||
if (!std::getline(std::cin, line)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if (!line.empty()) {
|
||||
char last = line.back();
|
||||
if (last == '/') { // Always return control on '/' symbol
|
||||
line.pop_back();
|
||||
return false;
|
||||
}
|
||||
if (last == '\\') { // '\\' changes the default action
|
||||
line.pop_back();
|
||||
multiline_input = !multiline_input;
|
||||
}
|
||||
}
|
||||
line += '\n';
|
||||
|
||||
// By default, continue input if multiline_input is set
|
||||
return multiline_input;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
// ctrl+C handling
|
||||
|
@ -85,6 +50,7 @@ int main() {
|
|||
}
|
||||
|
||||
Humanus agent = Humanus();
|
||||
|
||||
while (true) {
|
||||
if (agent.current_step == agent.max_steps) {
|
||||
std::cout << "Automatically paused after " << agent.max_steps << " steps." << std::endl;
|
||||
|
@ -93,13 +59,17 @@ int main() {
|
|||
} else {
|
||||
std::cout << "Enter your prompt (or 'exit' to quit): ";
|
||||
}
|
||||
|
||||
std::string prompt;
|
||||
readline_utf8(prompt, false);
|
||||
if (prompt == "exit" || prompt == "exit\n") {
|
||||
if (prompt == "exit") {
|
||||
logger->info("Goodbye!");
|
||||
break;
|
||||
}
|
||||
logger->info("Processing your request...");
|
||||
|
||||
logger->info("Processing your request: " + prompt);
|
||||
agent.run(prompt);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
set(target humanus_chat_mem0)
|
||||
set(target humanus_cli_mcp)
|
||||
|
||||
add_executable(${target} chat_mem0.cpp)
|
||||
add_executable(${target} humanus_mcp.cpp)
|
||||
|
||||
# 链接到核心库
|
||||
target_link_libraries(${target} PRIVATE humanus)
|
|
@ -0,0 +1,85 @@
|
|||
#include "agent/mcp.h"
|
||||
#include "logger.h"
|
||||
#include "prompt.h"
|
||||
#include "flow/flow_factory.h"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#elif defined (_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
using namespace humanus;
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
// make sure all logs are flushed
|
||||
logger->info("Interrupted by user\n");
|
||||
logger->flush();
|
||||
_exit(130);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
// ctrl+C handling
|
||||
{
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = sigint_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
SetConsoleCP(CP_UTF8);
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
_setmode(_fileno(stdin), _O_WTEXT); // wide character input mode
|
||||
#endif
|
||||
}
|
||||
|
||||
if (argc <= 1) {
|
||||
std::cout << "Usage: " << argv[0] << " <mcp_server1> <mcp_server2>..." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<std::string> mcp_servers;
|
||||
for (int i = 1; i < argc; i++) {
|
||||
mcp_servers.emplace_back(argv[i]);
|
||||
}
|
||||
|
||||
MCPAgent agent = MCPAgent(
|
||||
mcp_servers
|
||||
);
|
||||
|
||||
while (true) {
|
||||
if (agent.current_step == agent.max_steps) {
|
||||
std::cout << "Automatically paused after " << agent.max_steps << " steps." << std::endl;
|
||||
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
|
||||
agent.reset(false);
|
||||
} else {
|
||||
std::cout << "Enter your prompt (or 'exit' to quit): ";
|
||||
}
|
||||
|
||||
std::string prompt;
|
||||
readline_utf8(prompt, false);
|
||||
if (prompt == "exit") {
|
||||
logger->info("Goodbye!");
|
||||
break;
|
||||
}
|
||||
|
||||
logger->info("Processing your request: " + prompt);
|
||||
agent.run(prompt);
|
||||
}
|
||||
}
|
|
@ -20,49 +20,14 @@ using namespace humanus;
|
|||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
// make sure all logs are flushed
|
||||
logger->info("Interrupted by user\n");
|
||||
exit(0);
|
||||
logger->flush();
|
||||
_exit(130);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool readline_utf8(std::string & line, bool multiline_input) {
|
||||
#if defined(_WIN32)
|
||||
std::wstring wline;
|
||||
if (!std::getline(std::wcin, wline)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
||||
return false;
|
||||
}
|
||||
|
||||
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
||||
line.resize(size_needed);
|
||||
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
||||
#else
|
||||
if (!std::getline(std::cin, line)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if (!line.empty()) {
|
||||
char last = line.back();
|
||||
if (last == '/') { // Always return control on '/' symbol
|
||||
line.pop_back();
|
||||
return false;
|
||||
}
|
||||
if (last == '\\') { // '\\' changes the default action
|
||||
line.pop_back();
|
||||
multiline_input = !multiline_input;
|
||||
}
|
||||
}
|
||||
line += '\n';
|
||||
|
||||
// By default, continue input if multiline_input is set
|
||||
return multiline_input;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
// ctrl+C handling
|
||||
|
@ -104,23 +69,22 @@ int main() {
|
|||
std::cout << "Automatically paused after " << agent_ptr->current_step << " steps." << std::endl;
|
||||
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
|
||||
agent_ptr->reset(false);
|
||||
} else if (agent_ptr->state != AgentState::IDLE) {
|
||||
std::cout << "Enter your prompt (enter an empty line to retry or 'exit' to quit): ";
|
||||
agent_ptr->reset(false);
|
||||
} else {
|
||||
std::cout << "Enter your prompt (or 'exit' to quit): ";
|
||||
}
|
||||
|
||||
if (agent_ptr->state != AgentState::IDLE) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::string prompt;
|
||||
readline_utf8(prompt, false);
|
||||
if (prompt == "exit" || prompt == "exit\n") {
|
||||
if (prompt == "exit") {
|
||||
logger->info("Goodbye!");
|
||||
break;
|
||||
}
|
||||
|
||||
std::cout << "Processing your request..." << std::endl;
|
||||
logger->info("Processing your request: " + prompt);
|
||||
auto result = flow->execute(prompt);
|
||||
std::cout << result << std::endl;
|
||||
logger->info("🌟 " + agent_ptr->name + "'s summary: " + result);
|
||||
}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
set(target humanus_cli_plan_mem0)
|
||||
|
||||
add_executable(${target} humanus_plan_mem0.cpp)
|
||||
|
||||
# 链接到核心库
|
||||
target_link_libraries(${target} PRIVATE humanus)
|
||||
|
||||
# 设置输出目录
|
||||
set_target_properties(${target}
|
||||
PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||
)
|
|
@ -1,148 +0,0 @@
|
|||
#include "agent/humanus.h"
|
||||
#include "logger.h"
|
||||
#include "prompt.h"
|
||||
#include "flow/flow_factory.h"
|
||||
#include "memory/mem0/base.h"
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#elif defined (_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
using namespace humanus;
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
logger->info("Interrupted by user\n");
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
static bool readline_utf8(std::string & line, bool multiline_input) {
|
||||
#if defined(_WIN32)
|
||||
std::wstring wline;
|
||||
if (!std::getline(std::wcin, wline)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
||||
return false;
|
||||
}
|
||||
|
||||
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
||||
line.resize(size_needed);
|
||||
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
||||
#else
|
||||
if (!std::getline(std::cin, line)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if (!line.empty()) {
|
||||
char last = line.back();
|
||||
if (last == '/') { // Always return control on '/' symbol
|
||||
line.pop_back();
|
||||
return false;
|
||||
}
|
||||
if (last == '\\') { // '\\' changes the default action
|
||||
line.pop_back();
|
||||
multiline_input = !multiline_input;
|
||||
}
|
||||
}
|
||||
line += '\n';
|
||||
|
||||
// By default, continue input if multiline_input is set
|
||||
return multiline_input;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
// ctrl+C handling
|
||||
{
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = sigint_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
SetConsoleCP(CP_UTF8);
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
_setmode(_fileno(stdin), _O_WTEXT); // wide character input mode
|
||||
#endif
|
||||
}
|
||||
|
||||
auto memory = std::make_shared<mem0::Memory>(mem0::MemoryConfig());
|
||||
|
||||
std::shared_ptr<BaseAgent> agent_ptr = std::make_shared<Humanus>(
|
||||
ToolCollection( // Add general-purpose tools to the tool collection
|
||||
{
|
||||
std::make_shared<PythonExecute>(),
|
||||
std::make_shared<Puppeteer>(), // for web browsing
|
||||
std::make_shared<Filesystem>(),
|
||||
std::make_shared<Terminate>()
|
||||
}
|
||||
),
|
||||
"auto",
|
||||
std::set<std::string>{"terminate"},
|
||||
"humanus_mem0",
|
||||
"A versatile agent that can solve various tasks using multiple tools",
|
||||
prompt::humanus::SYSTEM_PROMPT,
|
||||
prompt::humanus::NEXT_STEP_PROMPT,
|
||||
nullptr,
|
||||
memory
|
||||
);
|
||||
|
||||
std::map<std::string, std::shared_ptr<BaseAgent>> agents;
|
||||
agents["default"] = agent_ptr;
|
||||
|
||||
auto flow = FlowFactory::create_flow(
|
||||
FlowType::PLANNING,
|
||||
nullptr, // llm
|
||||
nullptr, // planning_tool
|
||||
std::vector<std::string>{}, // executor_keys
|
||||
"", // active_plan_id
|
||||
agents, // agents
|
||||
std::vector<std::shared_ptr<BaseTool>>{}, // tools
|
||||
"default" // primary_agent_key
|
||||
);
|
||||
|
||||
while (true) {
|
||||
if (agent_ptr->current_step == agent_ptr->max_steps) {
|
||||
std::cout << "Automatically paused after " << agent_ptr->current_step << " steps." << std::endl;
|
||||
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
|
||||
agent_ptr->reset(false);
|
||||
} else {
|
||||
std::cout << "Enter your prompt (or 'exit' to quit): ";
|
||||
}
|
||||
|
||||
if (agent_ptr->state != AgentState::IDLE) {
|
||||
break;
|
||||
}
|
||||
|
||||
std::string prompt;
|
||||
readline_utf8(prompt, false);
|
||||
if (prompt == "exit" || prompt == "exit\n") {
|
||||
logger->info("Goodbye!");
|
||||
break;
|
||||
}
|
||||
|
||||
std::cout << "Processing your request..." << std::endl;
|
||||
memory->current_request = prompt;
|
||||
auto result = flow->execute(prompt);
|
||||
std::cout << result << std::endl;
|
||||
}
|
||||
}
|
|
@ -63,16 +63,16 @@ std::string PlanningFlow::execute(const std::string& input) {
|
|||
|
||||
// Refactor memory
|
||||
std::string prefix_sum = _summarize_plan(executor->memory->get_messages(step_result));
|
||||
executor->reset(true); // TODO: More fine-grained memory reset?
|
||||
executor->reset(false);
|
||||
executor->update_memory("assistant", prefix_sum);
|
||||
if (!input.empty()) {
|
||||
executor->update_memory("user", "Continue to accomplish the task: " + input);
|
||||
}
|
||||
|
||||
result += step_info.value("type", "Step " + std::to_string(current_step_index)) + ":\n" + prefix_sum + "\n\n";
|
||||
result += "##" + step_info.value("type", "Step " + std::to_string(current_step_index)) + ":\n" + prefix_sum + "\n\n";
|
||||
}
|
||||
|
||||
reset(true); // Clear memory and state for next plan
|
||||
reset(true); // Clear short-termmemory and state for next plan
|
||||
|
||||
return result;
|
||||
} catch (const std::exception& e) {
|
||||
|
@ -89,13 +89,19 @@ void PlanningFlow::_create_initial_plan(const std::string& request) {
|
|||
std::string system_prompt = "You are a planning assistant. Your task is to create a detailed plan with clear steps.";
|
||||
|
||||
// Create a user message with the request
|
||||
Message user_message = Message::user_message(
|
||||
"Create a detailed plan to accomplish this task: " + request
|
||||
);
|
||||
std::string user_prompt = "Please provide a detailed plan to accomplish this task: " + request + "\n\n";
|
||||
user_prompt += "**Note**: The following executors will be used to accomplish the plan.\n\n";
|
||||
for (const auto& [key, agent] : agents) {
|
||||
auto tool_call_agent = std::dynamic_pointer_cast<ToolCallAgent>(agent);
|
||||
if (tool_call_agent) {
|
||||
user_prompt += "Available tools for executor `" + key + "`:\n";
|
||||
user_prompt += tool_call_agent->available_tools.to_params().dump(2) + "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Call LLM with PlanningTool
|
||||
auto response = llm->ask_tool(
|
||||
{user_message},
|
||||
{Message::user_message(user_prompt)},
|
||||
system_prompt,
|
||||
"", // No next_step_prompt for initial plan creation
|
||||
json::array({planning_tool->to_param()}),
|
||||
|
@ -236,7 +242,7 @@ std::string PlanningFlow::_execute_step(const std::shared_ptr<BaseAgent>& execut
|
|||
step_prompt += plan_status.dump(2);
|
||||
step_prompt += "\n\nYOUR CURRENT TASK:\n";
|
||||
step_prompt += "You are now working on step " + std::to_string(current_step_index) + ": \"" + step_text + "\"\n";
|
||||
step_prompt += "Please execute this step using the appropriate tools. When you're done, provide a summary of what you accomplished.";
|
||||
step_prompt += "Please execute this step using the appropriate tools. When you're done, provide a summary of what you accomplished and call `terminate` to trigger the next step.";
|
||||
|
||||
// Use agent.run() to execute the step
|
||||
try {
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "base.h"
|
||||
#include "agent/base.h"
|
||||
#include "agent/toolcall.h"
|
||||
#include "llm.h"
|
||||
#include "logger.h"
|
||||
#include "schema.h"
|
||||
|
|
|
@ -5,23 +5,15 @@
|
|||
#include "prompt.h"
|
||||
#include "logger.h"
|
||||
#include "toml.hpp"
|
||||
#include "utils.h"
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <mutex>
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
|
||||
namespace humanus {
|
||||
|
||||
// Get project root directory
|
||||
static std::filesystem::path get_project_root() {
|
||||
return std::filesystem::path(__FILE__).parent_path().parent_path();
|
||||
}
|
||||
|
||||
static const std::filesystem::path PROJECT_ROOT = get_project_root();
|
||||
|
||||
struct LLMConfig {
|
||||
std::string model;
|
||||
std::string api_key;
|
||||
|
@ -159,7 +151,7 @@ struct ToolParser {
|
|||
};
|
||||
|
||||
// Read tool configuration from config_mcp.toml
|
||||
struct MCPToolConfig {
|
||||
struct MCPServerConfig {
|
||||
std::string type;
|
||||
std::string host;
|
||||
int port;
|
||||
|
@ -168,7 +160,7 @@ struct MCPToolConfig {
|
|||
std::vector<std::string> args;
|
||||
json env_vars = json::object();
|
||||
|
||||
static MCPToolConfig load_from_toml(const toml::table& tool_table);
|
||||
static MCPServerConfig load_from_toml(const toml::table& tool_table);
|
||||
};
|
||||
|
||||
enum class EmbeddingType {
|
||||
|
@ -201,19 +193,15 @@ struct VectorStoreConfig {
|
|||
Metric metric = Metric::L2;
|
||||
};
|
||||
|
||||
namespace mem0 {
|
||||
|
||||
struct MemoryConfig {
|
||||
// Base config
|
||||
int max_messages = 16; // Short-term memory capacity
|
||||
int retrieval_limit = 8; // Number of results to retrive from long-term memory
|
||||
int max_messages = 16; // Short-term memory capacity
|
||||
int max_tokens_context = 32768; // Maximum number of tokens in short-term memory
|
||||
int retrieval_limit = 32; // Number of results to retrive from long-term memory
|
||||
|
||||
// Prompt config
|
||||
std::string fact_extraction_prompt = prompt::mem0::FACT_EXTRACTION_PROMPT;
|
||||
std::string update_memory_prompt = prompt::mem0::UPDATE_MEMORY_PROMPT;
|
||||
|
||||
// Database config
|
||||
// std::string history_db_path = ":memory:";
|
||||
std::string fact_extraction_prompt = prompt::FACT_EXTRACTION_PROMPT;
|
||||
std::string update_memory_prompt = prompt::UPDATE_MEMORY_PROMPT;
|
||||
|
||||
// EmbeddingModel config
|
||||
std::shared_ptr<EmbeddingModelConfig> embedding_model_config = nullptr;
|
||||
|
@ -222,15 +210,13 @@ struct MemoryConfig {
|
|||
std::shared_ptr<VectorStoreConfig> vector_store_config = nullptr;
|
||||
FilterFunc filter = nullptr; // Filter to apply to search results
|
||||
|
||||
// Optional: LLM config
|
||||
// LLM config
|
||||
std::shared_ptr<LLMConfig> llm_config = nullptr;
|
||||
};
|
||||
|
||||
} // namespace mem0
|
||||
|
||||
struct AppConfig {
|
||||
std::unordered_map<std::string, LLMConfig> llm;
|
||||
std::unordered_map<std::string, MCPToolConfig> mcp_tool;
|
||||
std::unordered_map<std::string, MCPServerConfig> mcp_server;
|
||||
std::unordered_map<std::string, ToolParser> tool_parser;
|
||||
std::unordered_map<std::string, EmbeddingModelConfig> embedding_model;
|
||||
std::unordered_map<std::string, VectorStoreConfig> vector_store;
|
||||
|
@ -239,13 +225,12 @@ struct AppConfig {
|
|||
class Config {
|
||||
private:
|
||||
static Config* _instance;
|
||||
static std::mutex _mutex;
|
||||
bool _initialized = false;
|
||||
AppConfig _config;
|
||||
|
||||
Config() {
|
||||
_load_initial_llm_config();
|
||||
_load_initial_mcp_tool_config();
|
||||
_load_initial_mcp_server_config();
|
||||
_load_initial_embedding_model_config();
|
||||
_load_initial_vector_store_config();
|
||||
_initialized = true;
|
||||
|
@ -263,7 +248,7 @@ private:
|
|||
throw std::runtime_error("LLM Config file not found");
|
||||
}
|
||||
|
||||
static std::filesystem::path _get_mcp_tool_config_path() {
|
||||
static std::filesystem::path _get_mcp_server_config_path() {
|
||||
auto root = PROJECT_ROOT;
|
||||
auto config_path = root / "config" / "config_mcp.toml";
|
||||
if (std::filesystem::exists(config_path)) {
|
||||
|
@ -292,7 +277,7 @@ private:
|
|||
|
||||
void _load_initial_llm_config();
|
||||
|
||||
void _load_initial_mcp_tool_config();
|
||||
void _load_initial_mcp_server_config();
|
||||
|
||||
void _load_initial_embedding_model_config();
|
||||
|
||||
|
@ -305,10 +290,7 @@ public:
|
|||
*/
|
||||
static Config& get_instance() {
|
||||
if (_instance == nullptr) {
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
if (_instance == nullptr) {
|
||||
_instance = new Config();
|
||||
}
|
||||
_instance = new Config();
|
||||
}
|
||||
return *_instance;
|
||||
}
|
||||
|
@ -325,8 +307,8 @@ public:
|
|||
* @brief Get the MCP tool settings
|
||||
* @return The MCP tool settings map
|
||||
*/
|
||||
const std::unordered_map<std::string, MCPToolConfig>& mcp_tool() const {
|
||||
return _config.mcp_tool;
|
||||
const std::unordered_map<std::string, MCPServerConfig>& mcp_server() const {
|
||||
return _config.mcp_server;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -25,20 +25,27 @@ private:
|
|||
|
||||
std::shared_ptr<ToolParser> tool_parser_;
|
||||
|
||||
size_t total_prompt_tokens_;
|
||||
size_t total_completion_tokens_;
|
||||
|
||||
public:
|
||||
// Constructor
|
||||
LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& config = nullptr, const std::shared_ptr<ToolParser>& tool_parser = nullptr) : llm_config_(config), tool_parser_(tool_parser) {
|
||||
if (!llm_config_->oai_tool_support && !tool_parser_) {
|
||||
if (Config::get_instance().tool_parser().find(config_name) == Config::get_instance().tool_parser().end()) {
|
||||
throw std::invalid_argument("Tool helper config not found: " + config_name);
|
||||
logger->warn("Tool helper config not found: " + config_name + ", falling back to default tool helper config.");
|
||||
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at("default"));
|
||||
} else {
|
||||
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at(config_name));
|
||||
}
|
||||
tool_parser_ = std::make_shared<ToolParser>(Config::get_instance().tool_parser().at(config_name));
|
||||
}
|
||||
client_ = std::make_unique<httplib::Client>(llm_config_->base_url);
|
||||
client_->set_default_headers({
|
||||
{"Authorization", "Bearer " + llm_config_->api_key}
|
||||
});
|
||||
client_->set_read_timeout(llm_config_->timeout);
|
||||
total_prompt_tokens_ = 0;
|
||||
total_completion_tokens_ = 0;
|
||||
}
|
||||
|
||||
// Get the singleton instance
|
||||
|
@ -47,9 +54,11 @@ public:
|
|||
auto llm_config_ = llm_config;
|
||||
if (!llm_config_) {
|
||||
if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
|
||||
throw std::invalid_argument("LLM config not found: " + config_name);
|
||||
logger->warn("LLM config not found: " + config_name + ", falling back to default LLM config.");
|
||||
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at("default"));
|
||||
} else {
|
||||
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
|
||||
}
|
||||
llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
|
||||
}
|
||||
instances_[config_name] = std::make_shared<LLM>(config_name, llm_config_);
|
||||
}
|
||||
|
@ -110,6 +119,14 @@ public:
|
|||
const std::string& tool_choice = "auto",
|
||||
int max_retries = 3
|
||||
);
|
||||
|
||||
size_t total_prompt_tokens() const {
|
||||
return total_prompt_tokens_;
|
||||
}
|
||||
|
||||
size_t total_completion_tokens() const {
|
||||
return total_completion_tokens_;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus
|
||||
|
|
|
@ -26,10 +26,8 @@ extern const char* NEXT_STEP_PROMPT;
|
|||
extern const char* TOOL_HINT_TEMPLATE;
|
||||
} // namespace toolcall
|
||||
|
||||
namespace mem0 {
|
||||
extern const char* FACT_EXTRACTION_PROMPT;
|
||||
extern const char* UPDATE_MEMORY_PROMPT;
|
||||
} // namespace mem0
|
||||
|
||||
} // namespace prompt
|
||||
|
||||
|
|
|
@ -2,6 +2,9 @@
|
|||
#define HUMANUS_SCHEMA_H
|
||||
|
||||
#include "mcp_message.h"
|
||||
#include "utils.h"
|
||||
#include "tokenizer/utils.h"
|
||||
#include "tokenizer/bpe.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
@ -80,9 +83,15 @@ struct Message {
|
|||
std::string name;
|
||||
std::string tool_call_id;
|
||||
std::vector<ToolCall> tool_calls;
|
||||
int num_tokens;
|
||||
|
||||
// TODO: configure the tokenizer
|
||||
inline static const std::shared_ptr<BaseTokenizer> tokenizer = std::make_shared<BPETokenizer>(PROJECT_ROOT / "tokenizer" / "cl100k_base.tiktoken"); // use cl100k_base to roughly count tokens
|
||||
|
||||
Message(const std::string& role, const json& content, const std::string& name = "", const std::string& tool_call_id = "", const std::vector<ToolCall> tool_calls = {})
|
||||
: role(role), content(content), name(name), tool_call_id(tool_call_id), tool_calls(tool_calls) {}
|
||||
: role(role), content(content), name(name), tool_call_id(tool_call_id), tool_calls(tool_calls) {
|
||||
num_tokens = num_tokens_from_messages(*tokenizer, to_json());
|
||||
}
|
||||
|
||||
std::vector<Message> operator+(const Message& other) const {
|
||||
return {*this, other};
|
||||
|
@ -156,8 +165,6 @@ struct Message {
|
|||
}
|
||||
};
|
||||
|
||||
namespace mem0 {
|
||||
|
||||
struct MemoryItem {
|
||||
size_t id; // The unique identifier for the text data
|
||||
std::string memory; // The memory deduced from the text data
|
||||
|
@ -181,8 +188,6 @@ struct MemoryItem {
|
|||
|
||||
typedef std::function<bool(const MemoryItem&)> FilterFunc;
|
||||
|
||||
}
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_SCHEMA_H
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
#ifndef HUMANUS_UTILS_H
|
||||
#define HUMANUS_UTILS_H
|
||||
|
||||
#include <filesystem>
|
||||
#include <iostream>
|
||||
|
||||
#if defined (_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
namespace humanus {
|
||||
|
||||
// Get project root directory
|
||||
inline std::filesystem::path get_project_root() {
|
||||
return std::filesystem::path(__FILE__).parent_path().parent_path();
|
||||
}
|
||||
|
||||
extern const std::filesystem::path PROJECT_ROOT;
|
||||
|
||||
// return the last index of character that can form a valid string
|
||||
// if the last character is potentially cut in half, return the index before the cut
|
||||
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
|
||||
size_t validate_utf8(const std::string& text);
|
||||
|
||||
bool readline_utf8(std::string & line, bool multiline_input = false);
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif
|
400
memory/base.h
400
memory/base.h
|
@ -2,6 +2,15 @@
|
|||
#define HUMANUS_MEMORY_BASE_H
|
||||
|
||||
#include "schema.h"
|
||||
#include "memory/base.h"
|
||||
#include "vector_store/base.h"
|
||||
#include "embedding_model/base.h"
|
||||
#include "schema.h"
|
||||
#include "prompt.h"
|
||||
#include "httplib.h"
|
||||
#include "llm.h"
|
||||
#include "utils.h"
|
||||
#include "tool/fact_extract.h"
|
||||
#include <deque>
|
||||
#include <vector>
|
||||
|
||||
|
@ -9,17 +18,22 @@ namespace humanus {
|
|||
|
||||
struct BaseMemory {
|
||||
std::deque<Message> messages;
|
||||
std::string current_request;
|
||||
|
||||
// Add a message to the memory
|
||||
virtual void add_message(const Message& message) {
|
||||
virtual bool add_message(const Message& message) {
|
||||
messages.push_back(message);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Add multiple messages to the memory
|
||||
void add_messages(const std::vector<Message>& messages) {
|
||||
virtual bool add_messages(const std::vector<Message>& messages) {
|
||||
for (const auto& message : messages) {
|
||||
add_message(message);
|
||||
if (!add_message(message)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Clear all messages
|
||||
|
@ -46,16 +60,388 @@ struct BaseMemory {
|
|||
};
|
||||
|
||||
struct Memory : BaseMemory {
|
||||
MemoryConfig config;
|
||||
|
||||
std::string fact_extraction_prompt;
|
||||
std::string update_memory_prompt;
|
||||
int max_messages;
|
||||
int max_tokens_context;
|
||||
int retrieval_limit;
|
||||
FilterFunc filter;
|
||||
|
||||
Memory(int max_messages = 30) : max_messages(max_messages) {}
|
||||
std::shared_ptr<EmbeddingModel> embedding_model;
|
||||
std::shared_ptr<VectorStore> vector_store;
|
||||
std::shared_ptr<LLM> llm;
|
||||
|
||||
void add_message(const Message& message) override {
|
||||
std::shared_ptr<FactExtract> fact_extract_tool;
|
||||
|
||||
bool retrieval_enabled;
|
||||
|
||||
int num_tokens_context;
|
||||
|
||||
Memory(const MemoryConfig& config) : config(config) {
|
||||
fact_extraction_prompt = config.fact_extraction_prompt;
|
||||
update_memory_prompt = config.update_memory_prompt;
|
||||
max_messages = config.max_messages;
|
||||
retrieval_limit = config.retrieval_limit;
|
||||
filter = config.filter;
|
||||
|
||||
size_t pos = fact_extraction_prompt.find("{current_date}");
|
||||
if (pos != std::string::npos) {
|
||||
// %Y-%d-%m
|
||||
auto current_date = std::chrono::system_clock::now();
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||
std::stringstream ss;
|
||||
std::tm tm_info = *std::localtime(&in_time_t);
|
||||
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||
fact_extraction_prompt.replace(pos, 14, formatted_date);
|
||||
}
|
||||
|
||||
try {
|
||||
embedding_model = EmbeddingModel::get_instance("default", config.embedding_model_config);
|
||||
vector_store = VectorStore::get_instance("default", config.vector_store_config);
|
||||
llm = LLM::get_instance("memory", config.llm_config);
|
||||
|
||||
logger->info("🔥 Memory is warming up...");
|
||||
auto test_response = llm->ask(
|
||||
{Message::user_message("Hello")}
|
||||
);
|
||||
auto test_embedding = embedding_model->embed(test_response, EmbeddingType::ADD);
|
||||
vector_store->insert(test_embedding, 0);
|
||||
vector_store->remove(0);
|
||||
logger->info("📒 Memory is ready!");
|
||||
|
||||
retrieval_enabled = true;
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Error in initializing memory: " + std::string(e.what()) + ", fallback to default FIFO memory");
|
||||
embedding_model = nullptr;
|
||||
vector_store = nullptr;
|
||||
llm = nullptr;
|
||||
retrieval_enabled = false;
|
||||
}
|
||||
|
||||
fact_extract_tool = std::make_shared<FactExtract>();
|
||||
}
|
||||
|
||||
bool add_message(const Message& message) override {
|
||||
if (message.num_tokens > config.max_tokens_context) {
|
||||
logger->warn("Message is too long, skipping"); // TODO: use content_provider to handle this
|
||||
return false;
|
||||
}
|
||||
messages.push_back(message);
|
||||
while (!messages.empty() && (messages.size() > max_messages || messages.front().role == "assistant" || messages.front().role == "tool")) {
|
||||
// Ensure the first message is always a user or system message
|
||||
num_tokens_context += message.num_tokens;
|
||||
std::vector<Message> messages_to_memory;
|
||||
while (messages.size() > max_messages || num_tokens_context > config.max_tokens_context) {
|
||||
messages_to_memory.push_back(messages.front());
|
||||
num_tokens_context -= messages.front().num_tokens;
|
||||
messages.pop_front();
|
||||
}
|
||||
if (!messages.empty()) { // Ensure the first message is always a user or system message
|
||||
if (messages.front().role == "assistant") {
|
||||
messages.push_front(Message::user_message("Current request: " + current_request + "\n\nDue to limited memory, some previous messages are not shown."));
|
||||
num_tokens_context += messages.front().num_tokens;
|
||||
} else if (messages.front().role == "tool") {
|
||||
messages_to_memory.push_back(messages.front());
|
||||
num_tokens_context -= messages.front().num_tokens;
|
||||
messages.pop_front();
|
||||
}
|
||||
}
|
||||
if (retrieval_enabled && !messages_to_memory.empty()) {
|
||||
if (llm->enable_vision()) {
|
||||
for (auto& m : messages_to_memory) {
|
||||
m = parse_vision_message(m, llm, llm->vision_details());
|
||||
}
|
||||
} else {
|
||||
for (auto& m : messages_to_memory) {
|
||||
m = parse_vision_message(m);
|
||||
}
|
||||
}
|
||||
_add_to_vector_store(messages_to_memory);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<Message> get_messages(const std::string& query = "") const override {
|
||||
std::vector<Message> messages_with_memory;
|
||||
|
||||
if (retrieval_enabled && !query.empty()) {
|
||||
auto embeddings = embedding_model->embed(
|
||||
query.size() > 8192 ? query.substr(0, validate_utf8(query.substr(0, 8192))) : query, // TODO: split to chunks instead of truncating
|
||||
EmbeddingType::SEARCH
|
||||
);
|
||||
std::vector<MemoryItem> memories;
|
||||
|
||||
// Check if vectore store is available
|
||||
if (vector_store) {
|
||||
memories = vector_store->search(embeddings, retrieval_limit, filter);
|
||||
}
|
||||
|
||||
if (!memories.empty()) {
|
||||
sort(memories.begin(), memories.end(), [](const MemoryItem& a, const MemoryItem& b) {
|
||||
return a.updated_at < b.updated_at;
|
||||
});
|
||||
|
||||
std::string memory_prompt;
|
||||
for (const auto& memory_item : memories) {
|
||||
memory_prompt += "<memory>" + memory_item.memory + "</memory>";
|
||||
}
|
||||
|
||||
messages_with_memory.push_back(Message::user_message(memory_prompt));
|
||||
|
||||
logger->info("📤 Total retreived memories: " + std::to_string(memories.size()));
|
||||
}
|
||||
}
|
||||
|
||||
messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end());
|
||||
|
||||
return messages_with_memory;
|
||||
}
|
||||
|
||||
void clear() override {
|
||||
if (messages.empty()) {
|
||||
return;
|
||||
}
|
||||
if (retrieval_enabled) {
|
||||
std::vector<Message> messages_to_memory(messages.begin(), messages.end());
|
||||
_add_to_vector_store(messages_to_memory);
|
||||
}
|
||||
messages.clear();
|
||||
}
|
||||
|
||||
void _add_to_vector_store(const std::vector<Message>& messages) {
|
||||
// Check if vector store is available
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping memory operation");
|
||||
return;
|
||||
}
|
||||
|
||||
std::string parsed_message;
|
||||
|
||||
for (const auto& message : messages) {
|
||||
parsed_message += message.role + ": " + (message.content.is_string() ? message.content.get<std::string>() : message.content.dump()) + "\n";
|
||||
|
||||
for (const auto& tool_call : message.tool_calls) {
|
||||
parsed_message += "<tool_call>" + tool_call.to_json().dump() + "</tool_call>\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::string system_prompt = fact_extraction_prompt;
|
||||
|
||||
size_t pos = system_prompt.find("{current_request}");
|
||||
if (pos != std::string::npos) {
|
||||
system_prompt = system_prompt.replace(pos, 17, current_request);
|
||||
}
|
||||
|
||||
std::string user_prompt = "<input>" + parsed_message + "</input>";
|
||||
|
||||
Message user_message = Message::user_message(user_prompt);
|
||||
|
||||
json response = llm->ask_tool(
|
||||
{user_message},
|
||||
system_prompt,
|
||||
"",
|
||||
json::array({fact_extract_tool->to_param()}),
|
||||
"required"
|
||||
);
|
||||
|
||||
std::vector<std::string> new_facts; // ["fact1", "fact2", "fact3"]
|
||||
|
||||
try {
|
||||
auto tool_calls = ToolCall::from_json_list(response["tool_calls"]);
|
||||
for (const auto& tool_call : tool_calls) {
|
||||
if (tool_call.function.name != "fact_extract") { // might be other tools because of hallucinations (e.g. wrongly responsed to user message)
|
||||
continue;
|
||||
}
|
||||
// Parse arguments
|
||||
json args = tool_call.function.arguments;
|
||||
|
||||
if (args.is_string()) {
|
||||
args = json::parse(args.get<std::string>());
|
||||
}
|
||||
|
||||
auto facts = fact_extract_tool->execute(args).output.get<std::vector<std::string>>();
|
||||
if (!facts.empty()) {
|
||||
new_facts.insert(new_facts.end(), facts.begin(), facts.end());
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Error in new_facts: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
if (new_facts.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("📫 New facts to remember: " + json(new_facts).dump());
|
||||
|
||||
std::vector<json> old_memories;
|
||||
std::map<std::string, std::vector<float>> new_message_embeddings;
|
||||
|
||||
for (const auto& fact : new_facts) {
|
||||
auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD);
|
||||
new_message_embeddings[fact] = message_embedding;
|
||||
auto existing_memories = vector_store->search(
|
||||
message_embedding,
|
||||
5
|
||||
);
|
||||
for (const auto& memory : existing_memories) {
|
||||
old_memories.push_back({
|
||||
{"id", memory.id},
|
||||
{"text", memory.memory}
|
||||
});
|
||||
}
|
||||
}
|
||||
// sort and unique by id
|
||||
std::sort(old_memories.begin(), old_memories.end(), [](const json& a, const json& b) {
|
||||
return a["id"] < b["id"];
|
||||
});
|
||||
old_memories.resize(std::unique(old_memories.begin(), old_memories.end(), [](const json& a, const json& b) {
|
||||
return a["id"] == b["id"];
|
||||
}) - old_memories.begin());
|
||||
logger->info("📒 Existing memories about new facts: " + std::to_string(old_memories.size()));
|
||||
|
||||
// mapping UUIDs with integers for handling ID hallucinations
|
||||
std::vector<size_t> temp_id_mapping;
|
||||
for (size_t idx = 0; idx < old_memories.size(); ++idx) {
|
||||
temp_id_mapping.push_back(old_memories[idx]["id"].get<size_t>());
|
||||
old_memories[idx]["id"] = idx;
|
||||
}
|
||||
|
||||
std::string function_calling_prompt = get_update_memory_messages(old_memories, new_facts, update_memory_prompt);
|
||||
|
||||
std::string new_memories_with_actions_str;
|
||||
json new_memories_with_actions = json::array();
|
||||
|
||||
try {
|
||||
new_memories_with_actions_str = llm->ask(
|
||||
{Message::user_message(function_calling_prompt)}
|
||||
);
|
||||
new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in parsing new_memories_with_actions: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
try {
|
||||
new_memories_with_actions = json::parse(new_memories_with_actions_str);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Invalid JSON response: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
try {
|
||||
for (const auto& resp : new_memories_with_actions["memory"]) {
|
||||
logger->debug("Processing memory: " + resp.dump(2));
|
||||
try {
|
||||
if (!resp.contains("text")) {
|
||||
logger->warn("Skipping memory entry because of empty `text` field.");
|
||||
continue;
|
||||
}
|
||||
std::string event = resp.value("event", "NONE");
|
||||
size_t memory_id;
|
||||
try {
|
||||
if (event != "ADD") {
|
||||
memory_id = temp_id_mapping.at(resp["id"].get<size_t>());
|
||||
} else {
|
||||
memory_id = get_uuid_64();
|
||||
}
|
||||
} catch (...) {
|
||||
memory_id = get_uuid_64();
|
||||
}
|
||||
if (event == "ADD") {
|
||||
_create_memory(
|
||||
memory_id,
|
||||
resp["text"], // data
|
||||
new_message_embeddings // existing_embeddings
|
||||
);
|
||||
} else if (event == "UPDATE") {
|
||||
_update_memory(
|
||||
memory_id,
|
||||
resp["text"], // data
|
||||
new_message_embeddings // existing_embeddings
|
||||
);
|
||||
} else if (event == "DELETE") {
|
||||
_delete_memory(memory_id);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
void _create_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping create memory");
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("🆕 Creating memory: " + data);
|
||||
|
||||
std::vector<float> embedding;
|
||||
if (existing_embeddings.find(data) != existing_embeddings.end()) {
|
||||
embedding = existing_embeddings.at(data);
|
||||
} else {
|
||||
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
||||
}
|
||||
|
||||
MemoryItem metadata{
|
||||
memory_id,
|
||||
data,
|
||||
httplib::detail::MD5(data)
|
||||
};
|
||||
|
||||
vector_store->insert(
|
||||
embedding,
|
||||
memory_id,
|
||||
metadata
|
||||
);
|
||||
}
|
||||
|
||||
void _update_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping update memory");
|
||||
return;
|
||||
}
|
||||
|
||||
MemoryItem existing_memory;
|
||||
|
||||
try {
|
||||
existing_memory = vector_store->get(memory_id);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error fetching existing memory: " + std::string(e.what()));
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("🆕 Updating memory: (old) " + existing_memory.memory + " (new) " + data);
|
||||
|
||||
std::vector<float> embedding;
|
||||
if (existing_embeddings.find(data) != existing_embeddings.end()) {
|
||||
embedding = existing_embeddings.at(data);
|
||||
} else {
|
||||
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
||||
}
|
||||
|
||||
existing_memory.memory = data;
|
||||
existing_memory.hash = httplib::detail::MD5(data);
|
||||
existing_memory.updated_at = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
|
||||
vector_store->update(
|
||||
memory_id,
|
||||
embedding,
|
||||
existing_memory
|
||||
);
|
||||
}
|
||||
|
||||
void _delete_memory(const size_t& memory_id) {
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping delete memory");
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("❌ Deleting memory: " + std::to_string(memory_id));
|
||||
vector_store->remove(memory_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#include "base.h"
|
||||
#include "oai.h"
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<EmbeddingModel>> EmbeddingModel::instances_;
|
||||
|
||||
|
@ -10,9 +10,11 @@ std::shared_ptr<EmbeddingModel> EmbeddingModel::get_instance(const std::string&
|
|||
auto config_ = config;
|
||||
if (!config_) {
|
||||
if (Config::get_instance().embedding_model().find(config_name) == Config::get_instance().embedding_model().end()) {
|
||||
throw std::invalid_argument("Embedding model config not found: " + config_name);
|
||||
logger->warn("Embedding model config not found: " + config_name + ", falling back to default config");
|
||||
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at("default"));
|
||||
} else {
|
||||
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at(config_name));
|
||||
}
|
||||
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at(config_name));
|
||||
}
|
||||
|
||||
if (config_->provider == "oai") {
|
||||
|
@ -24,4 +26,4 @@ std::shared_ptr<EmbeddingModel> EmbeddingModel::get_instance(const std::string&
|
|||
return instances_[config_name];
|
||||
}
|
||||
|
||||
} // namespace humanus::mem0
|
||||
} // namespace humanus
|
|
@ -1,5 +1,5 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H
|
||||
#define HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H
|
||||
#ifndef HUMANUS_MEMORY_EMBEDDING_MODEL_BASE_H
|
||||
#define HUMANUS_MEMORY_EMBEDDING_MODEL_BASE_H
|
||||
|
||||
#include "httplib.h"
|
||||
#include "logger.h"
|
||||
|
@ -7,7 +7,7 @@
|
|||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
class EmbeddingModel {
|
||||
private:
|
||||
|
@ -28,6 +28,6 @@ public:
|
|||
virtual std::vector<float> embed(const std::string& text, EmbeddingType type) = 0;
|
||||
};
|
||||
|
||||
} // namespace humanus::mem0
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H
|
||||
#endif // HUMANUS_MEMORY_EMBEDDING_MODEL_BASE_H
|
|
@ -1,6 +1,6 @@
|
|||
#include "oai.h"
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
std::vector<float> OAIEmbeddingModel::embed(const std::string& text, EmbeddingType /* type */) {
|
||||
json body = {
|
||||
|
@ -18,16 +18,16 @@ std::vector<float> OAIEmbeddingModel::embed(const std::string& text, EmbeddingTy
|
|||
auto res = client_->Post(config_->endpoint, body_str, "application/json");
|
||||
|
||||
if (!res) {
|
||||
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
||||
logger->error(std::string(__func__) + ": Failed to send request: " + httplib::to_string(res.error()));
|
||||
} else if (res->status == 200) {
|
||||
try {
|
||||
json json_data = json::parse(res->body);
|
||||
return json_data["data"][0]["embedding"].get<std::vector<float>>();
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to parse response: " + std::string(e.what()));
|
||||
logger->error(std::string(__func__) + ": Failed to parse response: " + std::string(e.what()));
|
||||
}
|
||||
} else {
|
||||
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||
logger->error(std::string(__func__) + ": Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||
}
|
||||
|
||||
retry++;
|
||||
|
@ -42,7 +42,20 @@ std::vector<float> OAIEmbeddingModel::embed(const std::string& text, EmbeddingTy
|
|||
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(config_->max_retries));
|
||||
}
|
||||
|
||||
// If the logger has a file sink, log the request body
|
||||
if (logger->sinks().size() > 1) {
|
||||
auto file_sink = std::dynamic_pointer_cast<spdlog::sinks::basic_file_sink_mt>(logger->sinks()[1]);
|
||||
if (file_sink) {
|
||||
file_sink->log(spdlog::details::log_msg(
|
||||
spdlog::source_loc{},
|
||||
logger->name(),
|
||||
spdlog::level::debug,
|
||||
"Failed to get response from embedding model. Full request body: " + body_str
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("Failed to get embedding from: " + config_->base_url + " " + config_->model);
|
||||
}
|
||||
|
||||
} // namespace humanus::mem0
|
||||
} // namespace humanus
|
|
@ -1,9 +1,9 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H
|
||||
#define HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H
|
||||
#ifndef HUMANUS_MEMORY_EMBEDDING_MODEL_OAI_H
|
||||
#define HUMANUS_MEMORY_EMBEDDING_MODEL_OAI_H
|
||||
|
||||
#include "base.h"
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
class OAIEmbeddingModel : public EmbeddingModel {
|
||||
private:
|
||||
|
@ -20,6 +20,6 @@ public:
|
|||
std::vector<float> embed(const std::string& text, EmbeddingType type) override;
|
||||
};
|
||||
|
||||
} // namespace humanus::mem0
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H
|
||||
#endif // HUMANUS_MEMORY_EMBEDDING_MODEL_OAI_H
|
|
@ -1,368 +0,0 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_H
|
||||
#define HUMANUS_MEMORY_MEM0_H
|
||||
|
||||
#include "memory/base.h"
|
||||
#include "vector_store/base.h"
|
||||
#include "embedding_model/base.h"
|
||||
#include "schema.h"
|
||||
#include "prompt.h"
|
||||
#include "httplib.h"
|
||||
#include "llm.h"
|
||||
#include "utils.h"
|
||||
#include "tool/fact_extract.h"
|
||||
|
||||
namespace humanus::mem0 {
|
||||
|
||||
struct Memory : BaseMemory {
|
||||
MemoryConfig config;
|
||||
|
||||
std::string current_request;
|
||||
|
||||
std::string fact_extraction_prompt;
|
||||
std::string update_memory_prompt;
|
||||
int max_messages;
|
||||
int retrieval_limit;
|
||||
FilterFunc filter;
|
||||
|
||||
std::shared_ptr<EmbeddingModel> embedding_model;
|
||||
std::shared_ptr<VectorStore> vector_store;
|
||||
std::shared_ptr<LLM> llm;
|
||||
// std::shared_ptr<SQLiteManager> db;
|
||||
|
||||
std::shared_ptr<FactExtract> fact_extract_tool;
|
||||
|
||||
Memory(const MemoryConfig& config) : config(config) {
|
||||
fact_extraction_prompt = config.fact_extraction_prompt;
|
||||
update_memory_prompt = config.update_memory_prompt;
|
||||
max_messages = config.max_messages;
|
||||
retrieval_limit = config.retrieval_limit;
|
||||
filter = config.filter;
|
||||
|
||||
size_t pos = fact_extraction_prompt.find("{current_date}");
|
||||
if (pos != std::string::npos) {
|
||||
// %Y-%d-%m
|
||||
auto current_date = std::chrono::system_clock::now();
|
||||
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||
std::stringstream ss;
|
||||
std::tm tm_info = *std::localtime(&in_time_t);
|
||||
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||
fact_extraction_prompt.replace(pos, 14, formatted_date);
|
||||
}
|
||||
|
||||
embedding_model = EmbeddingModel::get_instance("default", config.embedding_model_config);
|
||||
vector_store = VectorStore::get_instance("default", config.vector_store_config);
|
||||
|
||||
llm = LLM::get_instance("default", config.llm_config);
|
||||
// db = std::make_shared<SQLiteManager>(config.history_db_path);
|
||||
|
||||
fact_extract_tool = std::make_shared<FactExtract>();
|
||||
}
|
||||
|
||||
void add_message(const Message& message) override {
|
||||
messages.push_back(message);
|
||||
// Message message_to_memory = message;
|
||||
// if (llm->enable_vision()) {
|
||||
// message_to_memory = parse_vision_message(message_to_memory, llm, llm->vision_details());
|
||||
// } else {
|
||||
// message_to_memory = parse_vision_message(message_to_memory);
|
||||
// }
|
||||
// _add_to_vector_store({message_to_memory});
|
||||
// while (!messages.empty() && (messages.size() > max_messages || messages.front().role == "assistant" || messages.front().role == "tool")) {
|
||||
// // Ensure the first message is always a user or system message
|
||||
// messages.pop_front();
|
||||
// }
|
||||
std::vector<Message> messages_to_memory;
|
||||
while (!messages.empty() && (messages.size() > max_messages || messages.front().role == "assistant" || messages.front().role == "tool")) {
|
||||
// Ensure the first message is always a user or system message
|
||||
messages_to_memory.push_back(messages.front());
|
||||
messages.pop_front();
|
||||
}
|
||||
if (!messages_to_memory.empty()) {
|
||||
if (llm->enable_vision()) {
|
||||
for (auto& m : messages_to_memory) {
|
||||
m = parse_vision_message(m, llm, llm->vision_details());
|
||||
}
|
||||
} else {
|
||||
for (auto& m : messages_to_memory) {
|
||||
m = parse_vision_message(m);
|
||||
}
|
||||
}
|
||||
_add_to_vector_store(messages_to_memory);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Message> get_messages(const std::string& query = "") const override {
|
||||
std::vector<Message> messages_with_memory;
|
||||
|
||||
if (!query.empty()) {
|
||||
auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH);
|
||||
std::vector<MemoryItem> memories;
|
||||
|
||||
// 检查vector_store是否已初始化
|
||||
if (vector_store) {
|
||||
memories = vector_store->search(embeddings, retrieval_limit, filter);
|
||||
}
|
||||
|
||||
if (!memories.empty()) {
|
||||
sort(memories.begin(), memories.end(), [](const MemoryItem& a, const MemoryItem& b) {
|
||||
return a.updated_at < b.updated_at;
|
||||
});
|
||||
|
||||
std::string memory_prompt;
|
||||
for (const auto& memory_item : memories) {
|
||||
memory_prompt += "<memory>" + memory_item.memory + "</memory>";
|
||||
}
|
||||
|
||||
messages_with_memory.push_back(Message::user_message(memory_prompt));
|
||||
|
||||
logger->info("📤 Total retreived memories: " + std::to_string(memories.size()));
|
||||
}
|
||||
}
|
||||
|
||||
messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end());
|
||||
|
||||
return messages_with_memory;
|
||||
}
|
||||
|
||||
void clear() override {
|
||||
if (messages.empty()) {
|
||||
return;
|
||||
}
|
||||
std::vector<Message> messages_to_memory(messages.begin(), messages.end());
|
||||
_add_to_vector_store(messages_to_memory);
|
||||
messages.clear();
|
||||
}
|
||||
|
||||
void _add_to_vector_store(const std::vector<Message>& messages) {
|
||||
// 检查vector_store是否已初始化
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping memory operation");
|
||||
return;
|
||||
}
|
||||
|
||||
std::string parsed_message;
|
||||
|
||||
for (const auto& message : messages) {
|
||||
parsed_message += message.role + ": " + (message.content.is_string() ? message.content.get<std::string>() : message.content.dump()) + "\n";
|
||||
|
||||
for (const auto& tool_call : message.tool_calls) {
|
||||
parsed_message += "<tool_call>" + tool_call.to_json().dump() + "</tool_call>\n";
|
||||
}
|
||||
}
|
||||
|
||||
std::string system_prompt = fact_extraction_prompt;
|
||||
|
||||
size_t pos = system_prompt.find("{current_request}");
|
||||
if (pos != std::string::npos) {
|
||||
system_prompt = system_prompt.replace(pos, 17, current_request);
|
||||
}
|
||||
|
||||
std::string user_prompt = "Input:\n" + parsed_message;
|
||||
|
||||
Message user_message = Message::user_message(user_prompt);
|
||||
|
||||
json response = llm->ask_tool(
|
||||
{user_message},
|
||||
system_prompt,
|
||||
"",
|
||||
json::array({fact_extract_tool->to_param()}),
|
||||
"required"
|
||||
);
|
||||
|
||||
std::vector<std::string> new_facts; // ["fact1", "fact2", "fact3"]
|
||||
|
||||
try {
|
||||
auto tool_calls = ToolCall::from_json_list(response["tool_calls"]);
|
||||
for (const auto& tool_call : tool_calls) {
|
||||
// Parse arguments
|
||||
json args = tool_call.function.arguments;
|
||||
|
||||
if (args.is_string()) {
|
||||
args = json::parse(args.get<std::string>());
|
||||
}
|
||||
|
||||
auto facts = fact_extract_tool->execute(args).output.get<std::vector<std::string>>();
|
||||
new_facts.insert(new_facts.end(), facts.begin(), facts.end());
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in new_facts: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
if (new_facts.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("📫 New facts to remember: " + json(new_facts).dump());
|
||||
|
||||
std::vector<json> old_memories;
|
||||
std::map<std::string, std::vector<float>> new_message_embeddings;
|
||||
|
||||
for (const auto& fact : new_facts) {
|
||||
auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD);
|
||||
new_message_embeddings[fact] = message_embedding;
|
||||
auto existing_memories = vector_store->search(
|
||||
message_embedding,
|
||||
5
|
||||
);
|
||||
for (const auto& memory : existing_memories) {
|
||||
old_memories.push_back({
|
||||
{"id", memory.id},
|
||||
{"text", memory.memory}
|
||||
});
|
||||
}
|
||||
}
|
||||
// sort and unique by id
|
||||
std::sort(old_memories.begin(), old_memories.end(), [](const json& a, const json& b) {
|
||||
return a["id"] < b["id"];
|
||||
});
|
||||
old_memories.resize(std::unique(old_memories.begin(), old_memories.end(), [](const json& a, const json& b) {
|
||||
return a["id"] == b["id"];
|
||||
}) - old_memories.begin());
|
||||
logger->info("🧠 Existing memories about new facts: " + std::to_string(old_memories.size()));
|
||||
|
||||
// mapping UUIDs with integers for handling ID hallucinations
|
||||
std::vector<size_t> temp_id_mapping;
|
||||
for (size_t idx = 0; idx < old_memories.size(); ++idx) {
|
||||
temp_id_mapping.push_back(old_memories[idx]["id"].get<size_t>());
|
||||
old_memories[idx]["id"] = idx;
|
||||
}
|
||||
|
||||
std::string function_calling_prompt = get_update_memory_messages(old_memories, new_facts, update_memory_prompt);
|
||||
|
||||
std::string new_memories_with_actions_str;
|
||||
json new_memories_with_actions = json::array();
|
||||
|
||||
try {
|
||||
new_memories_with_actions_str = llm->ask(
|
||||
{Message::user_message(function_calling_prompt)}
|
||||
);
|
||||
new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
try {
|
||||
new_memories_with_actions = json::parse(new_memories_with_actions_str);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Invalid JSON response: " + std::string(e.what()));
|
||||
}
|
||||
|
||||
try {
|
||||
for (const auto& resp : new_memories_with_actions["memory"]) {
|
||||
logger->debug("Processing memory: " + resp.dump(2));
|
||||
try {
|
||||
if (!resp.contains("text")) {
|
||||
logger->warn("Skipping memory entry because of empty `text` field.");
|
||||
continue;
|
||||
}
|
||||
std::string event = resp.value("event", "NONE");
|
||||
size_t memory_id;
|
||||
try {
|
||||
if (event != "ADD") {
|
||||
memory_id = temp_id_mapping.at(resp["id"].get<size_t>());
|
||||
} else {
|
||||
memory_id = get_uuid_64();
|
||||
}
|
||||
} catch (...) {
|
||||
memory_id = get_uuid_64();
|
||||
}
|
||||
if (event == "ADD") {
|
||||
_create_memory(
|
||||
memory_id,
|
||||
resp["text"], // data
|
||||
new_message_embeddings // existing_embeddings
|
||||
);
|
||||
} else if (event == "UPDATE") {
|
||||
_update_memory(
|
||||
memory_id,
|
||||
resp["text"], // data
|
||||
new_message_embeddings // existing_embeddings
|
||||
);
|
||||
} else if (event == "DELETE") {
|
||||
_delete_memory(memory_id);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||
}
|
||||
}
|
||||
|
||||
void _create_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping create memory");
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<float> embedding;
|
||||
if (existing_embeddings.find(data) != existing_embeddings.end()) {
|
||||
embedding = existing_embeddings.at(data);
|
||||
} else {
|
||||
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
||||
}
|
||||
|
||||
MemoryItem metadata{
|
||||
memory_id,
|
||||
data,
|
||||
httplib::detail::MD5(data)
|
||||
};
|
||||
|
||||
vector_store->insert(
|
||||
embedding,
|
||||
memory_id,
|
||||
metadata
|
||||
);
|
||||
}
|
||||
|
||||
void _update_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping update memory");
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("Updating memory with " + data);
|
||||
|
||||
MemoryItem existing_memory;
|
||||
|
||||
try {
|
||||
existing_memory = vector_store->get(memory_id);
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Error fetching existing memory: " + std::string(e.what()));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<float> embedding;
|
||||
if (existing_embeddings.find(data) != existing_embeddings.end()) {
|
||||
embedding = existing_embeddings.at(data);
|
||||
} else {
|
||||
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
||||
}
|
||||
|
||||
existing_memory.memory = data;
|
||||
existing_memory.hash = httplib::detail::MD5(data);
|
||||
existing_memory.updated_at = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
|
||||
vector_store->update(
|
||||
memory_id,
|
||||
embedding,
|
||||
existing_memory
|
||||
);
|
||||
}
|
||||
|
||||
void _delete_memory(const size_t& memory_id) {
|
||||
if (!vector_store) {
|
||||
logger->warn("Vector store is not initialized, skipping delete memory");
|
||||
return;
|
||||
}
|
||||
|
||||
logger->info("Deleting memory: " + std::to_string(memory_id));
|
||||
vector_store->delete_vector(memory_id);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus::mem0
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_H
|
|
@ -1,146 +0,0 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_STORAGE_H
|
||||
#define HUMANUS_MEMORY_MEM0_STORAGE_H
|
||||
|
||||
#include <sqlite3.h>
|
||||
#include <mutex>
|
||||
|
||||
namespace humanus::mem0 {
|
||||
|
||||
struct SQLiteManager {
|
||||
std::shared_ptr<sqlite3> db;
|
||||
std::mutex mutex;
|
||||
|
||||
SQLiteManager(const std::string& db_path) {
|
||||
int rc = sqlite3_open(db_path.c_str(), &db);
|
||||
if (rc) {
|
||||
throw std::runtime_error("Failed to open database: " + std::string(sqlite3_errmsg(db)));
|
||||
}
|
||||
_migrate_history_table();
|
||||
_create_history_table();
|
||||
}
|
||||
|
||||
void _migrate_history_table() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
char* errmsg = nullptr;
|
||||
sqlite3_stmt* stmt = nullptr;
|
||||
|
||||
// 检查历史表是否存在
|
||||
int rc = sqlite3_prepare_v2(db.get(), "SELECT name FROM sqlite_master WHERE type='table' AND name='history'", -1, &stmt, nullptr);
|
||||
if (rc != SQLITE_OK) {
|
||||
throw std::runtime_error("Failed to prepare statement: " + std::string(sqlite3_errmsg(db.get())));
|
||||
}
|
||||
|
||||
bool table_exists = false;
|
||||
if (sqlite3_step(stmt) == SQLITE_ROW) {
|
||||
table_exists = true;
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
|
||||
if (table_exists) {
|
||||
// 获取当前表结构
|
||||
std::map<std::string, std::string> current_schema;
|
||||
rc = sqlite3_prepare_v2(db.get(), "PRAGMA table_info(history)", -1, &stmt, nullptr);
|
||||
if (rc != SQLITE_OK) {
|
||||
throw std::runtime_error("Failed to prepare statement: " + std::string(sqlite3_errmsg(db.get())));
|
||||
}
|
||||
|
||||
while (sqlite3_step(stmt) == SQLITE_ROW) {
|
||||
std::string column_name = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
|
||||
std::string column_type = reinterpret_cast<const char*>(sqlite3_column_text(stmt, 2));
|
||||
current_schema[column_name] = column_type;
|
||||
}
|
||||
sqlite3_finalize(stmt);
|
||||
|
||||
// 定义预期表结构
|
||||
std::map<std::string, std::string> expected_schema = {
|
||||
{"id", "TEXT"},
|
||||
{"memory_id", "TEXT"},
|
||||
{"old_memory", "TEXT"},
|
||||
{"new_memory", "TEXT"},
|
||||
{"new_value", "TEXT"},
|
||||
{"event", "TEXT"},
|
||||
{"created_at", "DATETIME"},
|
||||
{"updated_at", "DATETIME"},
|
||||
{"is_deleted", "INTEGER"}
|
||||
};
|
||||
|
||||
// 检查表结构是否一致
|
||||
if (current_schema != expected_schema) {
|
||||
// 重命名旧表
|
||||
rc = sqlite3_exec(db.get(), "ALTER TABLE history RENAME TO old_history", nullptr, nullptr, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::string error = errmsg ? errmsg : "Unknown error";
|
||||
sqlite3_free(errmsg);
|
||||
throw std::runtime_error("Failed to rename table: " + error);
|
||||
}
|
||||
|
||||
// 创建新表
|
||||
rc = sqlite3_exec(db.get(),
|
||||
"CREATE TABLE IF NOT EXISTS history ("
|
||||
"id TEXT PRIMARY KEY,"
|
||||
"memory_id TEXT,"
|
||||
"old_memory TEXT,"
|
||||
"new_memory TEXT,"
|
||||
"new_value TEXT,"
|
||||
"event TEXT,"
|
||||
"created_at DATETIME,"
|
||||
"updated_at DATETIME,"
|
||||
"is_deleted INTEGER"
|
||||
")", nullptr, nullptr, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::string error = errmsg ? errmsg : "Unknown error";
|
||||
sqlite3_free(errmsg);
|
||||
throw std::runtime_error("Failed to create table: " + error);
|
||||
}
|
||||
|
||||
// 复制数据
|
||||
rc = sqlite3_exec(db.get(),
|
||||
"INSERT INTO history (id, memory_id, old_memory, new_memory, new_value, event, created_at, updated_at, is_deleted) "
|
||||
"SELECT id, memory_id, prev_value, new_value, new_value, event, timestamp, timestamp, is_deleted "
|
||||
"FROM old_history", nullptr, nullptr, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::string error = errmsg ? errmsg : "Unknown error";
|
||||
sqlite3_free(errmsg);
|
||||
throw std::runtime_error("Failed to copy data: " + error);
|
||||
}
|
||||
|
||||
// 删除旧表
|
||||
rc = sqlite3_exec(db.get(), "DROP TABLE old_history", nullptr, nullptr, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
std::string error = errmsg ? errmsg : "Unknown error";
|
||||
sqlite3_free(errmsg);
|
||||
throw std::runtime_error("Failed to drop old table: " + error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void _create_history_table() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
char* errmsg = nullptr;
|
||||
int rc = sqlite3_exec(db.get(),
|
||||
"CREATE TABLE IF NOT EXISTS history ("
|
||||
"id TEXT PRIMARY KEY,"
|
||||
"memory_id TEXT,"
|
||||
"old_memory TEXT,"
|
||||
"new_memory TEXT,"
|
||||
"new_value TEXT,"
|
||||
"event TEXT,"
|
||||
"created_at DATETIME,"
|
||||
"updated_at DATETIME,"
|
||||
"is_deleted INTEGER"
|
||||
")", nullptr, nullptr, &errmsg);
|
||||
|
||||
if (rc != SQLITE_OK) {
|
||||
std::string error = errmsg ? errmsg : "Unknown error";
|
||||
sqlite3_free(errmsg);
|
||||
throw std::runtime_error("Failed to create history table: " + error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus::mem0
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_STORAGE_H
|
|
@ -0,0 +1,36 @@
|
|||
#include "utils.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
std::string get_update_memory_messages(const json& old_memories, const json& new_facts, const std::string& update_memory_prompt) {
|
||||
std::stringstream ss;
|
||||
ss << update_memory_prompt << "\n\n";
|
||||
ss << "Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n";
|
||||
ss << old_memories.dump(2) + "\n\n";
|
||||
ss << "The new retrieved facts are mentioned below. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n";
|
||||
ss << new_facts.dump(2) + "\n\n";
|
||||
ss << "You must return your response in the following JSON structure only:\n\n";
|
||||
ss << R"json({
|
||||
"memory" : [
|
||||
{
|
||||
"id" : <interger ID of the memory>, # Use existing ID for updates/deletes, or new ID for additions
|
||||
"text" : "<Content of the memory>", # Content of the memory
|
||||
"event" : "<Operation to be performed>", # Must be "ADD", "UPDATE", "DELETE", or "NONE"
|
||||
"old_memory" : "<Old memory content>" # Required only if the event is "UPDATE"
|
||||
},
|
||||
...
|
||||
]
|
||||
})json" << "\n\n";
|
||||
ss << "Follow the instruction mentioned below:\n"
|
||||
<< "- Do not return anything from the custom few shot prompts provided above.\n"
|
||||
<< "- If the current memory is empty, then you have to add the new retrieved facts to the memory.\n"
|
||||
<< "- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.\n"
|
||||
<< "- If there is an addition, generate a new key and add the new memory corresponding to it.\n"
|
||||
<< "- If there is a deletion, the memory key-value pair should be removed from the memory.\n"
|
||||
<< "- If there is an update, the ID key should remain the same and only the value needs to be updated.\n"
|
||||
<< "\n";
|
||||
ss << "Do not return anything except the JSON format.\n";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
}
|
|
@ -1,10 +1,11 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_UTILS_H
|
||||
#define HUMANUS_MEMORY_MEM0_UTILS_H
|
||||
#ifndef HUMANUS_MEMORY_UTILS_H
|
||||
#define HUMANUS_MEMORY_UTILS_H
|
||||
|
||||
#include "../include/utils.h"
|
||||
#include "schema.h"
|
||||
#include "llm.h"
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
// Removes enclosing code block markers ```[language] and ``` from a given string.
|
||||
//
|
||||
|
@ -12,7 +13,7 @@ namespace humanus::mem0 {
|
|||
// - The function uses a regex pattern to match code blocks that may start with ``` followed by an optional language tag (letters or numbers) and end with ```.
|
||||
// - If a code block is detected, it returns only the inner content, stripping out the markers.
|
||||
// - If no code block markers are found, the original content is returned as-is.
|
||||
std::string remove_code_blocks(const std::string& text) {
|
||||
inline std::string remove_code_blocks(const std::string& text) {
|
||||
static const std::regex pattern(R"(^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$)");
|
||||
std::smatch match;
|
||||
if (std::regex_search(text, match, pattern)) {
|
||||
|
@ -21,7 +22,7 @@ std::string remove_code_blocks(const std::string& text) {
|
|||
return text;
|
||||
}
|
||||
|
||||
static size_t get_uuid_64() {
|
||||
inline size_t get_uuid_64() {
|
||||
const std::string chars = "0123456789abcdef";
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
@ -44,40 +45,11 @@ static size_t get_uuid_64() {
|
|||
return uuid_int;
|
||||
}
|
||||
|
||||
std::string get_update_memory_messages(const json& old_memories, const json& new_facts, const std::string& update_memory_prompt) {
|
||||
std::stringstream ss;
|
||||
ss << update_memory_prompt << "\n\n";
|
||||
ss << "Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n";
|
||||
ss << old_memories.dump(2) + "\n\n";
|
||||
ss << "The new retrieved facts are mentioned below. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n";
|
||||
ss << new_facts.dump(2) + "\n\n";
|
||||
ss << "You must return your response in the following JSON structure only:\n\n";
|
||||
ss << R"json({
|
||||
"memory" : [
|
||||
{
|
||||
"id" : <interger ID of the memory>, # Use existing ID for updates/deletes, or new ID for additions
|
||||
"text" : "<Content of the memory>", # Content of the memory
|
||||
"event" : "<Operation to be performed>", # Must be "ADD", "UPDATE", "DELETE", or "NONE"
|
||||
"old_memory" : "<Old memory content>" # Required only if the event is "UPDATE"
|
||||
},
|
||||
...
|
||||
]
|
||||
})json" << "\n\n";
|
||||
ss << "Follow the instruction mentioned below:\n"
|
||||
<< "- Do not return anything from the custom few shot prompts provided above.\n"
|
||||
<< "- If the current memory is empty, then you have to add the new retrieved facts to the memory.\n"
|
||||
<< "- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.\n"
|
||||
<< "- If there is an addition, generate a new key and add the new memory corresponding to it.\n"
|
||||
<< "- If there is a deletion, the memory key-value pair should be removed from the memory.\n"
|
||||
<< "- If there is an update, the ID key should remain the same and only the value needs to be updated.\n"
|
||||
<< "\n";
|
||||
ss << "Do not return anything except the JSON format.\n";
|
||||
return ss.str();
|
||||
}
|
||||
std::string get_update_memory_messages(const json& old_memories, const json& new_facts, const std::string& update_memory_prompt);
|
||||
|
||||
// Get the description of the image
|
||||
// image_url should be like: data:{mime_type};base64,{base64_data}
|
||||
std::string get_image_description(const std::string& image_url, const std::shared_ptr<LLM>& llm, const std::string& vision_details) {
|
||||
inline std::string get_image_description(const std::string& image_url, const std::shared_ptr<LLM>& llm, const std::string& vision_details) {
|
||||
if (!llm) {
|
||||
return "Here is an image failed to get description due to missing LLM instance.";
|
||||
}
|
||||
|
@ -101,7 +73,7 @@ std::string get_image_description(const std::string& image_url, const std::share
|
|||
}
|
||||
|
||||
// Parse the vision messages from the messages
|
||||
Message parse_vision_message(const Message& message, const std::shared_ptr<LLM>& llm = nullptr, const std::string& vision_details = "auto") {
|
||||
inline Message parse_vision_message(const Message& message, const std::shared_ptr<LLM>& llm = nullptr, const std::string& vision_details = "auto") {
|
||||
Message returned_message = message;
|
||||
|
||||
if (returned_message.content.is_array()) {
|
||||
|
@ -122,4 +94,4 @@ Message parse_vision_message(const Message& message, const std::shared_ptr<LLM>&
|
|||
|
||||
}
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_UTILS_H
|
||||
#endif // HUMANUS_MEMORY_UTILS_H
|
|
@ -1,7 +1,7 @@
|
|||
#include "base.h"
|
||||
#include "hnswlib.h"
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
std::unordered_map<std::string, std::shared_ptr<VectorStore>> VectorStore::instances_;
|
||||
|
||||
|
@ -10,9 +10,11 @@ std::shared_ptr<VectorStore> VectorStore::get_instance(const std::string& config
|
|||
auto config_ = config;
|
||||
if (!config_) {
|
||||
if (Config::get_instance().vector_store().find(config_name) == Config::get_instance().vector_store().end()) {
|
||||
throw std::invalid_argument("Vector store config not found: " + config_name);
|
||||
logger->warn("Vector store config not found: " + config_name + ", falling back to default config");
|
||||
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at("default"));
|
||||
} else {
|
||||
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at(config_name));
|
||||
}
|
||||
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at(config_name));
|
||||
}
|
||||
|
||||
if (config_->provider == "hnswlib") {
|
||||
|
@ -24,4 +26,4 @@ std::shared_ptr<VectorStore> VectorStore::get_instance(const std::string& config
|
|||
return instances_[config_name];
|
||||
}
|
||||
|
||||
} // namespace humanus::mem0
|
||||
} // namespace humanus
|
|
@ -1,12 +1,12 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H
|
||||
#define HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H
|
||||
#ifndef HUMANUS_MEMORY_VECTOR_STORE_BASE_H
|
||||
#define HUMANUS_MEMORY_VECTOR_STORE_BASE_H
|
||||
|
||||
#include "config.h"
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
class VectorStore {
|
||||
private:
|
||||
|
@ -34,7 +34,7 @@ public:
|
|||
*/
|
||||
virtual void insert(const std::vector<float>& vector,
|
||||
const size_t vector_id,
|
||||
const MemoryItem& metadata) = 0;
|
||||
const MemoryItem& metadata = MemoryItem()) = 0;
|
||||
|
||||
/**
|
||||
* @brief 搜索相似向量
|
||||
|
@ -51,7 +51,7 @@ public:
|
|||
* @brief 通过ID删除向量
|
||||
* @param vector_id 向量ID
|
||||
*/
|
||||
virtual void delete_vector(size_t vector_id) = 0;
|
||||
virtual void remove(size_t vector_id) = 0;
|
||||
|
||||
/**
|
||||
* @brief 更新向量及其负载
|
||||
|
@ -68,6 +68,13 @@ public:
|
|||
*/
|
||||
virtual MemoryItem get(size_t vector_id) = 0;
|
||||
|
||||
/**
|
||||
* @brief Set metadata for a vector
|
||||
* @param vector_id Vector ID
|
||||
* @param metadata New metadata
|
||||
*/
|
||||
virtual void set(size_t vector_id, const MemoryItem& metadata) = 0;
|
||||
|
||||
/**
|
||||
* @brief 列出所有记忆
|
||||
* @param limit 可选的结果数量限制
|
||||
|
@ -77,7 +84,7 @@ public:
|
|||
virtual std::vector<MemoryItem> list(size_t limit = 0, const FilterFunc& filter = nullptr) = 0;
|
||||
};
|
||||
|
||||
} // namespace humanus::mem0
|
||||
} // namespace humanus
|
||||
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H
|
||||
#endif // HUMANUS_MEMORY_VECTOR_STORE_BASE_H
|
|
@ -3,7 +3,7 @@
|
|||
#include <map>
|
||||
#include <chrono>
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
void HNSWLibVectorStore::reset() {
|
||||
if (hnsw) {
|
||||
|
@ -13,7 +13,8 @@ void HNSWLibVectorStore::reset() {
|
|||
space.reset();
|
||||
}
|
||||
|
||||
metadata_store.clear();
|
||||
cache_map.clear();
|
||||
metadata_list.clear();
|
||||
|
||||
if (config_->metric == VectorStoreConfig::Metric::L2) {
|
||||
space = std::make_shared<hnswlib::L2Space>(config_->dim);
|
||||
|
@ -27,9 +28,12 @@ void HNSWLibVectorStore::reset() {
|
|||
}
|
||||
|
||||
void HNSWLibVectorStore::insert(const std::vector<float>& vector, const size_t vector_id, const MemoryItem& metadata) {
|
||||
if (cache_map.size() >= config_->max_elements) {
|
||||
remove(metadata_list.back().id);
|
||||
}
|
||||
|
||||
hnsw->addPoint(vector.data(), vector_id);
|
||||
|
||||
// 存储元数据
|
||||
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
MemoryItem _metadata = metadata;
|
||||
if (_metadata.created_at < 0) {
|
||||
|
@ -39,7 +43,7 @@ void HNSWLibVectorStore::insert(const std::vector<float>& vector, const size_t v
|
|||
_metadata.updated_at = now;
|
||||
}
|
||||
|
||||
metadata_store[vector_id] = _metadata;
|
||||
set(vector_id, _metadata);
|
||||
}
|
||||
|
||||
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query, size_t limit, const FilterFunc& filter) {
|
||||
|
@ -52,8 +56,8 @@ std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& que
|
|||
|
||||
results.pop();
|
||||
|
||||
if (metadata_store.find(id) != metadata_store.end()) {
|
||||
MemoryItem item = metadata_store[id];
|
||||
if (cache_map.find(id) != cache_map.end()) {
|
||||
MemoryItem item = *cache_map[id];
|
||||
item.score = distance;
|
||||
memory_items.push_back(item);
|
||||
}
|
||||
|
@ -62,9 +66,13 @@ std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& que
|
|||
return memory_items;
|
||||
}
|
||||
|
||||
void HNSWLibVectorStore::delete_vector(size_t vector_id) {
|
||||
void HNSWLibVectorStore::remove(size_t vector_id) {
|
||||
hnsw->markDelete(vector_id);
|
||||
metadata_store.erase(vector_id);
|
||||
auto it = cache_map.find(vector_id);
|
||||
if (it != cache_map.end()) {
|
||||
metadata_list.erase(it->second);
|
||||
cache_map.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vector, const MemoryItem& metadata) {
|
||||
|
@ -78,8 +86,8 @@ void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vect
|
|||
MemoryItem new_metadata = metadata;
|
||||
new_metadata.id = vector_id; // Make sure the id is the same as the vector id
|
||||
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
if (metadata_store.find(vector_id) != metadata_store.end()) {
|
||||
MemoryItem old_metadata = metadata_store[vector_id];
|
||||
if (cache_map.find(vector_id) != cache_map.end()) {
|
||||
MemoryItem old_metadata = *cache_map[vector_id];
|
||||
if (new_metadata.hash == old_metadata.hash) {
|
||||
new_metadata.created_at = old_metadata.created_at;
|
||||
} else {
|
||||
|
@ -90,12 +98,34 @@ void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vect
|
|||
new_metadata.created_at = now;
|
||||
}
|
||||
new_metadata.updated_at = now;
|
||||
metadata_store[vector_id] = new_metadata;
|
||||
set(vector_id, new_metadata);
|
||||
}
|
||||
}
|
||||
|
||||
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
|
||||
return metadata_store.at(vector_id);
|
||||
auto it = cache_map.find(vector_id);
|
||||
if (it != cache_map.end()) {
|
||||
metadata_list.splice(metadata_list.begin(), metadata_list, it->second);
|
||||
return *it->second;
|
||||
}
|
||||
throw std::out_of_range("Vector id " + std::to_string(vector_id) + " not found in cache");
|
||||
}
|
||||
|
||||
void HNSWLibVectorStore::set(size_t vector_id, const MemoryItem& metadata) {
|
||||
auto it = cache_map.find(vector_id);
|
||||
if (it != cache_map.end()) { // update existing metadata
|
||||
*it->second = metadata;
|
||||
metadata_list.splice(metadata_list.begin(), metadata_list, it->second);
|
||||
} else { // insert new metadata
|
||||
if (cache_map.size() >= config_->max_elements) { // cache full
|
||||
auto last = metadata_list.back();
|
||||
cache_map.erase(last.id);
|
||||
metadata_list.pop_back();
|
||||
}
|
||||
|
||||
metadata_list.emplace_front(metadata);
|
||||
cache_map[vector_id] = metadata_list.begin();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<MemoryItem> HNSWLibVectorStore::list(size_t limit, const FilterFunc& filter) {
|
||||
|
@ -118,4 +148,5 @@ std::vector<MemoryItem> HNSWLibVectorStore::list(size_t limit, const FilterFunc&
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
};
|
|
@ -1,16 +1,18 @@
|
|||
#ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
||||
#define HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
||||
#ifndef HUMANUS_MEMORY_VECTOR_STORE_HNSWLIB_H
|
||||
#define HUMANUS_MEMORY_VECTOR_STORE_HNSWLIB_H
|
||||
|
||||
#include "base.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
#include <list>
|
||||
|
||||
namespace humanus::mem0 {
|
||||
namespace humanus {
|
||||
|
||||
class HNSWLibVectorStore : public VectorStore {
|
||||
private:
|
||||
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
|
||||
std::shared_ptr<hnswlib::SpaceInterface<float>> space; // 保持space对象的引用以确保其生命周期
|
||||
std::unordered_map<size_t, MemoryItem> metadata_store; // 存储向量的元数据
|
||||
std::unordered_map<size_t, std::list<MemoryItem>::iterator> cache_map; // LRU cache
|
||||
std::list<MemoryItem> metadata_list; // 存储向量的元数据
|
||||
|
||||
public:
|
||||
HNSWLibVectorStore(const std::shared_ptr<VectorStoreConfig>& config) : VectorStore(config) {
|
||||
|
@ -19,16 +21,18 @@ public:
|
|||
|
||||
void reset() override;
|
||||
|
||||
void insert(const std::vector<float>& vector, const size_t vector_id, const MemoryItem& metadata) override;
|
||||
void insert(const std::vector<float>& vector, const size_t vector_id, const MemoryItem& metadata = MemoryItem()) override;
|
||||
|
||||
std::vector<MemoryItem> search(const std::vector<float>& query, size_t limit, const FilterFunc& filter = nullptr) override;
|
||||
|
||||
void delete_vector(size_t vector_id) override;
|
||||
void remove(size_t vector_id) override;
|
||||
|
||||
void update(size_t vector_id, const std::vector<float>& vector = std::vector<float>(), const MemoryItem& metadata = MemoryItem()) override;
|
||||
|
||||
MemoryItem get(size_t vector_id) override;
|
||||
|
||||
void set(size_t vector_id, const MemoryItem& metadata) override;
|
||||
|
||||
std::vector<MemoryItem> list(size_t limit, const FilterFunc& filter = nullptr) override;
|
||||
};
|
||||
|
||||
|
@ -56,4 +60,4 @@ public:
|
|||
|
||||
}
|
||||
|
||||
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
||||
#endif // HUMANUS_MEMORY_VECTOR_STORE_HNSWLIB_H
|
|
@ -15,10 +15,15 @@ add_library(server STATIC ${SERVER_SOURCES})
|
|||
# 链接依赖库
|
||||
target_link_libraries(server PRIVATE mcp)
|
||||
|
||||
find_package(Python3 REQUIRED)
|
||||
|
||||
find_package(Python3 COMPONENTS Development)
|
||||
if(Python3_FOUND)
|
||||
target_link_libraries(server PRIVATE ${Python3_LIBRARIES})
|
||||
message(STATUS "Python3 found: ${Python3_VERSION}")
|
||||
message(STATUS "Python3 include directory: ${Python3_INCLUDE_DIRS}")
|
||||
message(STATUS "Python3 libraries: ${Python3_LIBRARIES}")
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
add_compile_definitions(PYTHON_FOUND)
|
||||
else()
|
||||
message(WARNING "Python3 development libraries not found. Python interpreter will not be available.")
|
||||
endif()
|
||||
|
||||
# 包含目录
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
### How to run the server
|
||||
|
||||
```bash
|
||||
./build/bin/mcp_server <port> # default port is 8818
|
||||
```
|
||||
|
||||
### SWitch Python Environment
|
||||
|
||||
```bash
|
||||
rm -rf build
|
||||
|
||||
#example
|
||||
cmake -DPython3_ROOT_DIR=/opt/anaconda3/envs/pytorch \
|
||||
-DPython3_INCLUDE_DIR=/opt/anaconda3/envs/pytorch/include/python3.9 \
|
||||
-DPython3_LIBRARY=/opt/anaconda3/envs/pytorch/lib/libpython3.9.dylib \
|
||||
-B build
|
||||
|
||||
# replace with your own python environment path
|
||||
cmake -DPython3_ROOT_DIR=/path/to/your/python/environment \
|
||||
-DPython3_INCLUDE_DIR=/path/to/your/python/environment/include/python<version> \
|
||||
-DPython3_LIBRARY=/path/to/your/python/environment/lib/libpython<version>.dylib \
|
||||
-B build
|
||||
```
|
|
@ -18,9 +18,21 @@
|
|||
// Import Python execution tool
|
||||
extern void register_python_execute_tool(mcp::server& server);
|
||||
|
||||
int main() {
|
||||
int main(int argc, char* argv[]) {
|
||||
int port;
|
||||
if (argc == 2) {
|
||||
try {
|
||||
port = std::stoi(argv[1]);
|
||||
} catch (...) {
|
||||
std::cerr << "Invalid port number: " << argv[1] << std::endl;
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
port = 8818;
|
||||
}
|
||||
|
||||
// Create and configure server
|
||||
mcp::server server("localhost", 8818);
|
||||
mcp::server server("localhost", port);
|
||||
server.set_server_info("HumanusMCPServer", "0.0.1");
|
||||
|
||||
// Set server capabilities
|
||||
|
@ -33,7 +45,7 @@ int main() {
|
|||
register_python_execute_tool(server);
|
||||
|
||||
// Start server
|
||||
std::cout << "Starting Humanus MCP server at localhost:8818..." << std::endl;
|
||||
std::cout << "Starting Humanus MCP server at localhost:" << port << "..." << std::endl;
|
||||
std::cout << "Press Ctrl+C to stop server" << std::endl;
|
||||
server.start(true); // Blocking mode
|
||||
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
namespace humanus {
|
||||
|
||||
MCPToolConfig MCPToolConfig::load_from_toml(const toml::table &tool_table) {
|
||||
MCPToolConfig config;
|
||||
MCPServerConfig MCPServerConfig::load_from_toml(const toml::table &tool_table) {
|
||||
MCPServerConfig config;
|
||||
|
||||
try {
|
||||
// Read type
|
||||
|
@ -73,7 +73,6 @@ MCPToolConfig MCPToolConfig::load_from_toml(const toml::table &tool_table) {
|
|||
|
||||
// Initialize static members
|
||||
Config* Config::_instance = nullptr;
|
||||
std::mutex Config::_mutex;
|
||||
|
||||
void Config::_load_initial_llm_config() {
|
||||
try {
|
||||
|
@ -168,9 +167,9 @@ void Config::_load_initial_llm_config() {
|
|||
}
|
||||
}
|
||||
|
||||
void Config::_load_initial_mcp_tool_config() {
|
||||
void Config::_load_initial_mcp_server_config() {
|
||||
try {
|
||||
auto config_path = _get_mcp_tool_config_path();
|
||||
auto config_path = _get_mcp_server_config_path();
|
||||
logger->info("Loading MCP tool config file from: " + config_path.string());
|
||||
|
||||
const auto& data = toml::parse_file(config_path.string());
|
||||
|
@ -179,13 +178,7 @@ void Config::_load_initial_mcp_tool_config() {
|
|||
for (const auto& [key, value] : data) {
|
||||
const auto& tool_table = *value.as_table();
|
||||
|
||||
_config.mcp_tool[std::string(key.str())] = MCPToolConfig::load_from_toml(tool_table);
|
||||
}
|
||||
|
||||
if (_config.mcp_tool.empty()) {
|
||||
throw std::runtime_error("No MCP tool configuration found");
|
||||
} else if (_config.mcp_tool.find("default") == _config.mcp_tool.end()) {
|
||||
_config.mcp_tool["default"] = _config.mcp_tool.begin()->second;
|
||||
_config.mcp_server[std::string(key.str())] = MCPServerConfig::load_from_toml(tool_table);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
logger->warn("Failed to load MCP tool configuration: " + std::string(e.what()));
|
||||
|
|
29
src/llm.cpp
29
src/llm.cpp
|
@ -142,16 +142,18 @@ std::string LLM::ask(
|
|||
auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
|
||||
|
||||
if (!res) {
|
||||
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
||||
logger->error(std::string(__func__) + ": Failed to send request: " + httplib::to_string(res.error()));
|
||||
} else if (res->status == 200) {
|
||||
try {
|
||||
json json_data = json::parse(res->body);
|
||||
total_prompt_tokens_ += json_data["usage"]["prompt_tokens"].get<size_t>();
|
||||
total_completion_tokens_ += json_data["usage"]["completion_tokens"].get<size_t>();
|
||||
return json_data["choices"][0]["message"]["content"].get<std::string>();
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to parse response: " + std::string(e.what()));
|
||||
logger->error(std::string(__func__) + ": Failed to parse response: " + std::string(e.what()));
|
||||
}
|
||||
} else {
|
||||
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||
logger->error(std::string(__func__) + ": Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||
}
|
||||
|
||||
retry++;
|
||||
|
@ -166,6 +168,19 @@ std::string LLM::ask(
|
|||
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
|
||||
}
|
||||
|
||||
// If the logger has a file sink, log the request body
|
||||
if (logger->sinks().size() > 1) {
|
||||
auto file_sink = std::dynamic_pointer_cast<spdlog::sinks::basic_file_sink_mt>(logger->sinks()[1]);
|
||||
if (file_sink) {
|
||||
file_sink->log(spdlog::details::log_msg(
|
||||
spdlog::source_loc{},
|
||||
logger->name(),
|
||||
spdlog::level::debug,
|
||||
"Failed to get response from LLM. Full request body: " + body_str
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("Failed to get response from LLM");
|
||||
}
|
||||
|
||||
|
@ -266,7 +281,7 @@ json LLM::ask_tool(
|
|||
auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
|
||||
|
||||
if (!res) {
|
||||
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
||||
logger->error(std::string(__func__) + ": Failed to send request: " + httplib::to_string(res.error()));
|
||||
} else if (res->status == 200) {
|
||||
try {
|
||||
json json_data = json::parse(res->body);
|
||||
|
@ -274,12 +289,14 @@ json LLM::ask_tool(
|
|||
if (!llm_config_->oai_tool_support && message["content"].is_string()) {
|
||||
message = tool_parser_->parse(message["content"].get<std::string>());
|
||||
}
|
||||
total_prompt_tokens_ += json_data["usage"]["prompt_tokens"].get<size_t>();
|
||||
total_completion_tokens_ += json_data["usage"]["completion_tokens"].get<size_t>();
|
||||
return message;
|
||||
} catch (const std::exception& e) {
|
||||
logger->error("Failed to parse response: " + std::string(e.what()));
|
||||
logger->error(std::string(__func__) + ": Failed to parse response: " + std::string(e.what()));
|
||||
}
|
||||
} else {
|
||||
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||
logger->error(std::string(__func__) + ": Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||
}
|
||||
|
||||
retry++;
|
||||
|
|
|
@ -6,34 +6,40 @@ namespace prompt {
|
|||
|
||||
namespace humanus {
|
||||
const char* SYSTEM_PROMPT = "\
|
||||
You are Humanus, an all-capable AI assistant, aimed at solving any task presented by the user. You have various tools at your disposal that you can call upon to efficiently complete complex requests. Whether it's programming, information retrieval, file processing you can handle it all.";
|
||||
|
||||
const char* NEXT_STEP_PROMPT = R"(You can interact with the computer using python_execute, save important content and information files through filesystem, open browsers and retrieve information with puppeteer.
|
||||
You are Humanus, an all-capable AI assistant, aimed at solving any task presented by the user. You have various tools at your disposal that you can call upon to efficiently complete complex requests. Whether it's programming, information retrieval, file processing or web browsingyou can handle it all.";
|
||||
|
||||
const char* NEXT_STEP_PROMPT = R"(You can interact with the computer using python_execute, save important content and information files through filesystem, open browsers and retrieve information with playwright.
|
||||
- python_execute: Execute Python code to interact with the computer system, data processing, automation tasks, etc.
|
||||
|
||||
- filesystem: Read/write files locally, such as txt, py, html, etc. Create/list/delete directories, move files/directories, search for files and get file metadata.
|
||||
- playwright: Interact with web pages, take screenshots, generate test code, web scraps the page and execute JavaScript in a real browser environment. Note: Most of the time you need to observer the page before executing other actions.
|
||||
|
||||
Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it.
|
||||
|
||||
After using each tool, clearly explain the execution results and suggest the next steps.
|
||||
|
||||
Unless required by user, you should always at most use one tool at a time, observe the result and then choose the next tool or action.
|
||||
|
||||
Detect the language of the user input and respond in the same language for thoughts.
|
||||
|
||||
Basically the user will not reply to you, you should make decisions and determine whether current step is finished. If you finish the current step, call `terminate`.)";
|
||||
Remember the following:
|
||||
- Today's date is {current_date}.
|
||||
- Refer to current request to determine what to do: {current_request}
|
||||
- Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it.
|
||||
- After using each tool, clearly explain the execution results and suggest the next steps.
|
||||
- Unless required by user, you should always at most use one tool at a time, observe the result and then choose the next tool or action.
|
||||
- Detect the language of the user input and respond in the same language for thoughts.
|
||||
- Basically the user will not reply to you, you should make decisions and determine whether current step is finished. If you want to stop interaction, call `terminate`.)";
|
||||
} // namespace humanus
|
||||
|
||||
namespace toolcall {
|
||||
const char* SYSTEM_PROMPT = "You are an agent that can execute tool calls";
|
||||
const char* SYSTEM_PROMPT = "You are a helpful assistant that can execute tool calls to help users with their task";
|
||||
|
||||
const char* NEXT_STEP_PROMPT = "If you want to stop interaction, use `terminate` tool/function call.";
|
||||
const char* NEXT_STEP_PROMPT = R"(You can interact with the computer using provided tools.
|
||||
|
||||
Remember the following:
|
||||
- Today's date is {current_date}.
|
||||
- Refer to current request to determine what to do: {current_request}
|
||||
- Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it.
|
||||
- After using each tool, clearly explain the execution results and suggest the next steps.
|
||||
- Unless required by user, you should always at most use one tool at a time, observe the result and then choose the next tool or action.
|
||||
- Detect the language of the user input and respond in the same language for thoughts.
|
||||
- Basically the user will not reply to you, you should make decisions and determine whether current step is finished. If you want to stop interaction, call `terminate`.)";
|
||||
|
||||
const char* TOOL_HINT_TEMPLATE = "Available tools:\n{tool_list}\n\nFor each tool call, return a json object with tool name and arguments within {tool_start}{tool_end} XML tags:\n{tool_start}\n{\"name\": <tool-name>, \"arguments\": <args-json-object>}\n{tool_end}";
|
||||
} // namespace toolcall
|
||||
|
||||
namespace mem0 {
|
||||
const char* FACT_EXTRACTION_PROMPT = R"(You are a Personal Information Organizer, specialized in accurately storing facts, user memories, and preferences. Your primary role is to extract relevant pieces of information from conversations and organize them into distinct, manageable facts. This allows for easy retrieval and personalization in future interactions. Below are the types of information you need to focus on and the detailed instructions on how to handle the input data.
|
||||
|
||||
Types of Information to Remember:
|
||||
|
@ -58,6 +64,8 @@ Remember the following:
|
|||
|
||||
Following is a message parsed from previous interactions. You have to extract the relevant facts and preferences about the user and some accomplished tasks about the assistant.
|
||||
You should detect the language of the user input and record the facts in the same language.
|
||||
|
||||
Below is the data to extract in XML tags <input> and </input>:
|
||||
)";
|
||||
|
||||
const char* UPDATE_MEMORY_PROMPT = R"(You are a smart memory manager which controls the memory of a system.
|
||||
|
@ -208,8 +216,6 @@ Please note to return the IDs in the output from the input IDs only and do not g
|
|||
}
|
||||
)";
|
||||
|
||||
} // namespace mem0
|
||||
|
||||
} // namespace prompt
|
||||
|
||||
} // namespace humanus
|
|
@ -0,0 +1,70 @@
|
|||
#include "utils.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
const std::filesystem::path PROJECT_ROOT = get_project_root();
|
||||
|
||||
size_t validate_utf8(const std::string& text) {
|
||||
size_t len = text.size();
|
||||
if (len == 0) return 0;
|
||||
|
||||
// Check the last few bytes to see if a multi-byte character is cut off
|
||||
for (size_t i = 1; i <= 4 && i <= len; ++i) {
|
||||
unsigned char c = text[len - i];
|
||||
// Check for start of a multi-byte sequence from the end
|
||||
if ((c & 0xE0) == 0xC0) {
|
||||
// 2-byte character start: 110xxxxx
|
||||
// Needs at least 2 bytes
|
||||
if (i < 2) return len - i;
|
||||
} else if ((c & 0xF0) == 0xE0) {
|
||||
// 3-byte character start: 1110xxxx
|
||||
// Needs at least 3 bytes
|
||||
if (i < 3) return len - i;
|
||||
} else if ((c & 0xF8) == 0xF0) {
|
||||
// 4-byte character start: 11110xxx
|
||||
// Needs at least 4 bytes
|
||||
if (i < 4) return len - i;
|
||||
}
|
||||
}
|
||||
|
||||
// If no cut-off multi-byte character is found, return full length
|
||||
return len;
|
||||
}
|
||||
|
||||
bool readline_utf8(std::string & line, bool multiline_input) {
|
||||
#if defined(_WIN32)
|
||||
std::wstring wline;
|
||||
if (!std::getline(std::wcin, wline)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
||||
return false;
|
||||
}
|
||||
|
||||
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
||||
line.resize(size_needed);
|
||||
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
||||
#else
|
||||
if (!std::getline(std::cin, line)) {
|
||||
// Input stream is bad or EOF received
|
||||
line.clear();
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
if (!line.empty()) {
|
||||
char last = line.back();
|
||||
if (last == '/') { // Always return control on '/' symbol
|
||||
line.pop_back();
|
||||
return false;
|
||||
}
|
||||
if (last == '\\') { // '\\' changes the default action
|
||||
line.pop_back();
|
||||
multiline_input = !multiline_input;
|
||||
}
|
||||
}
|
||||
|
||||
// By default, continue input if multiline_input is set
|
||||
return multiline_input;
|
||||
}
|
||||
|
||||
} // namespace humanus
|
|
@ -0,0 +1,11 @@
|
|||
cmake_minimum_required(VERSION 3.10)
|
||||
project(humanus_tests)
|
||||
|
||||
# 添加测试可执行文件
|
||||
add_executable(test_bpe test_bpe.cpp)
|
||||
|
||||
# 链接到主库
|
||||
target_link_libraries(test_bpe PRIVATE humanus)
|
||||
|
||||
# 添加包含路径,确保可以找到需要的头文件
|
||||
target_include_directories(test_bpe PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
|
@ -0,0 +1,159 @@
|
|||
#include "../tokenizer/bpe.h"
|
||||
#include "../tokenizer/utils.h"
|
||||
#include "../mcp/common/json.hpp"
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdexcept>
|
||||
|
||||
using namespace humanus;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#define TEST_FAILED(func, ...) std::cout << func << " \033[31mfailed\033[0m " << ("\0", ##__VA_ARGS__) << std::endl;
|
||||
#define TEST_PASSED(func, ...) std::cout << func << " \033[32mpassed\033[0m " << ("\0", ##__VA_ARGS__) << std::endl;
|
||||
|
||||
void test_encode_decode(const BPETokenizer& tokenizer) {
|
||||
std::string test_text = "Hello, world! 你好,世界!";
|
||||
|
||||
auto tokens = tokenizer.encode(test_text);
|
||||
|
||||
std::string decoded = tokenizer.decode(tokens);
|
||||
|
||||
if (decoded != test_text) {
|
||||
TEST_FAILED(__func__, "Expected " + test_text + ", got " + decoded);
|
||||
return;
|
||||
}
|
||||
|
||||
test_text = "お誕生日おめでとう";
|
||||
tokens = tokenizer.encode(test_text);
|
||||
decoded = tokenizer.decode(tokens);
|
||||
if (decoded != test_text) {
|
||||
TEST_FAILED(__func__, "Expected " + test_text + ", got " + decoded);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<size_t> expected_tokens{33334, 45918, 243, 21990, 9080, 33334, 62004, 16556, 78699};
|
||||
if (tokens.size() != expected_tokens.size()) {
|
||||
TEST_FAILED(__func__, "Expected " + std::to_string(expected_tokens.size()) + " tokens, got " + std::to_string(tokens.size()));
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
if (tokens[i] != expected_tokens[i]) {
|
||||
TEST_FAILED(__func__, "Expected " + std::to_string(expected_tokens[i]) + ", got " + std::to_string(tokens[i]));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_PASSED(__func__);
|
||||
}
|
||||
|
||||
void test_num_tokens_from_messages(const BPETokenizer& tokenizer) {
|
||||
json example_messages = json::parse(R"json([
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful, pattern-following assistant that translates corporate jargon into plain English."
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"name": "example_user",
|
||||
"content": "New synergies will help drive top-line growth."
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"name": "example_assistant",
|
||||
"content": "Things working well together will increase revenue."
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"name": "example_user",
|
||||
"content": "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage."
|
||||
},
|
||||
{
|
||||
"role": "system",
|
||||
"name": "example_assistant",
|
||||
"content": "Let's talk later when we're less busy about how to do better."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This late pivot means we don't have time to boil the ocean for the client deliverable."
|
||||
}
|
||||
])json");
|
||||
|
||||
size_t num_tokens = num_tokens_from_messages(tokenizer, example_messages);
|
||||
|
||||
if (num_tokens != 129) {
|
||||
TEST_FAILED(__func__, "Expected 129, got " + std::to_string(num_tokens));
|
||||
} else {
|
||||
TEST_PASSED(__func__);
|
||||
}
|
||||
}
|
||||
|
||||
void test_num_tokens_for_tools(const BPETokenizer& tokenizer) {
|
||||
json tools = json::parse(R"json([
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit of temperature to return",
|
||||
"enum": [
|
||||
"celsius",
|
||||
"fahrenheit"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
])json");
|
||||
|
||||
json example_messages = json::parse(R"json([
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that can answer to questions about the weather."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in San Francisco?"
|
||||
}
|
||||
])json");
|
||||
|
||||
size_t num_tokens = num_tokens_for_tools(tokenizer, tools, example_messages);
|
||||
|
||||
if (num_tokens != 105) {
|
||||
TEST_FAILED(__func__, "Expected 105, got " + std::to_string(num_tokens));
|
||||
} else {
|
||||
TEST_PASSED(__func__);
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
try {
|
||||
auto tokenizer = BPETokenizer::load_from_tiktoken("tokenizer/cl100k_base.tiktoken");
|
||||
|
||||
test_encode_decode(*tokenizer);
|
||||
|
||||
test_num_tokens_from_messages(*tokenizer);
|
||||
|
||||
test_num_tokens_for_tools(*tokenizer);
|
||||
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
TEST_FAILED("test_bpe", "Error: " + std::string(e.what()));
|
||||
return 1;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
cmake_minimum_required(VERSION 3.10)
|
||||
project(humanus_tokenizer)
|
||||
|
||||
# 复制测试数据到构建目录
|
||||
file(COPY cl100k_base.tiktoken DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
|
|
@ -0,0 +1,17 @@
|
|||
#ifndef HUMANUS_TOKENIZER_BASE_H
|
||||
#define HUMANUS_TOKENIZER_BASE_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace humanus {
|
||||
|
||||
class BaseTokenizer {
|
||||
public:
|
||||
virtual std::vector<size_t> encode(const std::string& text) const = 0;
|
||||
virtual std::string decode(const std::vector<size_t>& tokens) const = 0;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // HUMANUS_TOKENIZER_BASE_H
|
|
@ -0,0 +1,245 @@
|
|||
#ifndef HUMANUS_TOKENIZER_BPE_H
|
||||
#define HUMANUS_TOKENIZER_BPE_H
|
||||
|
||||
#include "base.h"
|
||||
#include "../mcp/common/base64.hpp"
|
||||
#include <unordered_map>
|
||||
#include <queue>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
|
||||
namespace humanus {
|
||||
|
||||
/**
|
||||
* @brief BPE (Byte Pair Encoding) Tokenizer实现
|
||||
*
|
||||
* 这个类实现了基于BPE (Byte Pair Encoding)算法的分词器,
|
||||
* 使用tiktoken格式的词汇表和合并规则文件。
|
||||
* 使用优先队列进行高效的BPE合并操作。
|
||||
*/
|
||||
class BPETokenizer : public BaseTokenizer {
|
||||
private:
|
||||
// 辅助结构,用于哈希pair
|
||||
struct PairHash {
|
||||
template <class T1, class T2>
|
||||
std::size_t operator() (const std::pair<T1, T2>& pair) const {
|
||||
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
|
||||
}
|
||||
};
|
||||
|
||||
// 字节对到对应token ID的映射
|
||||
std::unordered_map<std::string, size_t> encoder;
|
||||
// token ID到字节对的映射
|
||||
std::unordered_map<size_t, std::string> decoder;
|
||||
|
||||
// 合并优先级映射,优先级越小越优先合并
|
||||
std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash> merge_ranks;
|
||||
|
||||
/**
|
||||
* @brief 解码base64编码的字符串
|
||||
* @param encoded base64编码的字符串
|
||||
* @return 解码后的字符串
|
||||
*/
|
||||
std::string base64_decode(const std::string& encoded) const {
|
||||
if (encoded.empty()) return "";
|
||||
return base64::decode(encoded);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 用于优先队列的比较函数
|
||||
*
|
||||
* 根据merge_ranks中的优先级比较两个token对。
|
||||
* 优先级较低的值(即较小的rank值)具有更高的优先级。
|
||||
*/
|
||||
struct MergeComparator {
|
||||
const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks;
|
||||
|
||||
MergeComparator(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& r)
|
||||
: ranks(r) {}
|
||||
|
||||
bool operator()(const std::pair<std::pair<std::string, std::string>, size_t>& a,
|
||||
const std::pair<std::pair<std::string, std::string>, size_t>& b) const {
|
||||
// 首先按照merge_ranks比较,如果不存在则使用最大值
|
||||
size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits<size_t>::max();
|
||||
size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits<size_t>::max();
|
||||
|
||||
// 优先队列是最大堆,所以我们需要反向比较(较小的rank优先级更高)
|
||||
return rank_a > rank_b;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief 从tiktoken格式文件构造BPE tokenizer
|
||||
* @param tokenizer_path tiktoken格式词汇表文件的路径
|
||||
*
|
||||
* 文件格式:每行包含一个base64编码的token和对应的token ID
|
||||
* 例如: "IQ== 0",其中"IQ=="是base64编码的token,0是对应的ID
|
||||
*/
|
||||
BPETokenizer(const std::string& tokenizer_path) {
|
||||
std::ifstream file(tokenizer_path);
|
||||
if (!file.is_open()) {
|
||||
throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path);
|
||||
}
|
||||
|
||||
std::string line;
|
||||
while (std::getline(file, line)) {
|
||||
std::istringstream iss(line);
|
||||
std::string token_base64;
|
||||
size_t rank;
|
||||
|
||||
if (iss >> token_base64 >> rank) {
|
||||
std::string token = base64_decode(token_base64);
|
||||
|
||||
// 存储token和其ID的映射关系
|
||||
encoder[token] = rank;
|
||||
decoder[rank] = token;
|
||||
}
|
||||
}
|
||||
|
||||
// 构建merge_ranks
|
||||
build_merge_ranks();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 构建合并优先级映射
|
||||
*
|
||||
* 利用词汇表中的tokens推断可能的合并规则。
|
||||
* 对于长度大于1的tokens,尝试所有可能的分割,如果分割后的两个部分也在词汇表中,
|
||||
* 则假设这是一个有效的合并规则。
|
||||
*/
|
||||
void build_merge_ranks() {
|
||||
// 对于tiktoken格式,我们可以利用编码中长度>1的token来构建合并规则
|
||||
for (const auto& [token, id] : encoder) {
|
||||
if (token.length() <= 1) continue;
|
||||
|
||||
// 尝试所有可能的分割点
|
||||
for (size_t i = 1; i < token.length(); ++i) {
|
||||
std::string first = token.substr(0, i);
|
||||
std::string second = token.substr(i);
|
||||
|
||||
// 如果两个部分都在词汇表中,假设这是一个有效的合并规则
|
||||
if (encoder.count(first) && encoder.count(second)) {
|
||||
// 使用ID作为优先级 - 较小的ID表示更高的优先级
|
||||
merge_ranks[{first, second}] = id;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 设置合并优先级
|
||||
* @param ranks 新的合并优先级映射
|
||||
*/
|
||||
void set_merge_ranks(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks) {
|
||||
merge_ranks = ranks;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 执行BPE编码
|
||||
* @param text 要编码的文本
|
||||
* @return 编码后的token IDs
|
||||
*
|
||||
* 该方法使用BPE算法对输入文本进行编码。
|
||||
* 1. 首先将文本分解为单个字节
|
||||
* 2. 使用优先队列根据merge_ranks中的优先级对相邻token进行合并
|
||||
* 3. 将最终的tokens转换为对应的IDs
|
||||
*/
|
||||
std::vector<size_t> encode(const std::string& text) const override {
|
||||
if (text.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// 将文本分解为单个字符tokens
|
||||
std::vector<std::string> tokens;
|
||||
for (unsigned char c : text) {
|
||||
tokens.push_back(std::string(1, c));
|
||||
}
|
||||
|
||||
// 使用优先队列执行BPE合并
|
||||
while (tokens.size() > 1) {
|
||||
// 构建优先队列,用于选择最高优先级的合并对
|
||||
using MergePair = std::pair<std::pair<std::string, std::string>, size_t>;
|
||||
MergeComparator comparator(merge_ranks);
|
||||
std::priority_queue<MergePair, std::vector<MergePair>, MergeComparator> merge_candidates(comparator);
|
||||
|
||||
// 查找所有可能的合并对
|
||||
for (size_t i = 0; i < tokens.size() - 1; ++i) {
|
||||
std::pair<std::string, std::string> pair = {tokens[i], tokens[i+1]};
|
||||
if (merge_ranks.count(pair)) {
|
||||
merge_candidates.push({pair, i});
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有可合并的对,退出循环
|
||||
if (merge_candidates.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// 执行优先级最高的合并(优先队列中排在最前面的)
|
||||
auto top_merge = merge_candidates.top();
|
||||
auto pair = top_merge.first; // 要合并的token对
|
||||
size_t pos = top_merge.second; // 要合并的位置
|
||||
|
||||
// 合并token
|
||||
std::string merged_token = pair.first + pair.second;
|
||||
tokens[pos] = merged_token;
|
||||
tokens.erase(tokens.begin() + pos + 1);
|
||||
}
|
||||
|
||||
// 将tokens转换为IDs
|
||||
std::vector<size_t> ids;
|
||||
ids.reserve(tokens.size());
|
||||
for (const auto& token : tokens) {
|
||||
auto it = encoder.find(token);
|
||||
if (it != encoder.end()) {
|
||||
ids.push_back(it->second);
|
||||
}
|
||||
// 未知token将被跳过
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 解码BPE tokens
|
||||
* @param tokens 要解码的token IDs
|
||||
* @return 解码后的文本
|
||||
*
|
||||
* 该方法将编码后的token IDs转换回原始文本。
|
||||
* 简单地将每个ID对应的token字符串连接起来。
|
||||
*/
|
||||
std::string decode(const std::vector<size_t>& tokens) const override {
|
||||
std::string result;
|
||||
result.reserve(tokens.size() * 2); // 预估大小,避免频繁重新分配
|
||||
|
||||
for (size_t id : tokens) {
|
||||
auto it = decoder.find(id);
|
||||
if (it != decoder.end()) {
|
||||
result += it->second;
|
||||
}
|
||||
// 未知ID将被跳过
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief 从tiktoken文件加载BPE词汇和合并规则
|
||||
* @param file_path tiktoken格式文件的路径
|
||||
* @return 共享指针指向创建的BPE tokenizer
|
||||
*/
|
||||
static std::shared_ptr<BPETokenizer> load_from_tiktoken(const std::string& file_path) {
|
||||
return std::make_shared<BPETokenizer>(file_path);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_TOKENIZER_BPE_H
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,99 @@
|
|||
#include "utils.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
int num_tokens_from_messages(const BaseTokenizer& tokenizer, const json& messages) {
|
||||
// TODO: configure the magic number
|
||||
static const int tokens_per_message = 3;
|
||||
static const int tokens_per_name = 1;
|
||||
static const int tokens_per_image = 1024;
|
||||
|
||||
int num_tokens = 0;
|
||||
|
||||
json messages_;
|
||||
|
||||
if (messages.is_object()) {
|
||||
messages_ = json::array({messages});
|
||||
} else {
|
||||
messages_ = messages;
|
||||
}
|
||||
|
||||
for (const auto& message : messages_) {
|
||||
num_tokens += tokens_per_message;
|
||||
for (const auto& [key, value] : message.items()) {
|
||||
if (value.is_string()) {
|
||||
num_tokens += tokenizer.encode(value.get<std::string>()).size();
|
||||
} else if (value.is_array()) {
|
||||
for (const auto& item : value) {
|
||||
if (item.contains("text")) {
|
||||
num_tokens += tokenizer.encode(item.at("text").get<std::string>()).size();
|
||||
} else if (item.contains("image_url")) {
|
||||
num_tokens += tokens_per_image;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (key == "name") {
|
||||
num_tokens += tokens_per_name;
|
||||
}
|
||||
}
|
||||
}
|
||||
num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens;
|
||||
}
|
||||
|
||||
int num_tokens_for_tools(const BaseTokenizer& tokenizer, const json& tools, const json& messages) {
|
||||
// TODO: configure the magic number
|
||||
static const int tool_init = 10;
|
||||
static const int prop_init = 3;
|
||||
static const int prop_key = 3;
|
||||
static const int enum_init = 3;
|
||||
static const int enum_item = 3;
|
||||
static const int tool_end = 12;
|
||||
|
||||
int tool_token_count = 0;
|
||||
if (!tools.empty()) {
|
||||
for (const auto& tool : tools) {
|
||||
tool_token_count += tool_init; // Add tokens for start of each tool
|
||||
auto function = tool["function"];
|
||||
auto f_name = function["name"].get<std::string>();
|
||||
auto f_desc = function["description"].get<std::string>();
|
||||
if (f_desc.back() == '.') {
|
||||
f_desc.pop_back();
|
||||
}
|
||||
auto line = f_name + ":" + f_desc;
|
||||
tool_token_count += tokenizer.encode(line).size();
|
||||
|
||||
if (function["parameters"].contains("properties")) {
|
||||
tool_token_count += prop_init; // Add tokens for start of each property
|
||||
for (const auto& [key, value] : function["parameters"]["properties"].items()) {
|
||||
tool_token_count += prop_key; // Add tokens for each set property
|
||||
auto p_name = key;
|
||||
auto p_type = value["type"].get<std::string>();
|
||||
auto p_desc = value["description"].get<std::string>();
|
||||
|
||||
if (value.contains("enum")) {
|
||||
tool_token_count += enum_init; // Add tokens if property has enum list
|
||||
for (const auto& item : value["enum"]) {
|
||||
tool_token_count += enum_item;
|
||||
tool_token_count += tokenizer.encode(item.get<std::string>()).size();
|
||||
}
|
||||
}
|
||||
|
||||
if (p_desc.back() == '.') {
|
||||
p_desc.pop_back();
|
||||
}
|
||||
auto line = p_name + ":" + p_type + ":" + p_desc;
|
||||
tool_token_count += tokenizer.encode(line).size();
|
||||
}
|
||||
}
|
||||
}
|
||||
tool_token_count += tool_end;
|
||||
}
|
||||
|
||||
auto messages_token_count = num_tokens_from_messages(tokenizer, messages);
|
||||
auto total_token_count = tool_token_count + messages_token_count;
|
||||
|
||||
return total_token_count;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
#ifndef HUMANUS_TOKENIZER_UTILS_H
|
||||
#define HUMANUS_TOKENIZER_UTILS_H
|
||||
|
||||
#include "base.h"
|
||||
#include "../mcp/common/json.hpp"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
/**
|
||||
* @brief Roughly count the number of tokens in a message (https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
|
||||
* @param tokenizer The tokenizer to use
|
||||
* @param messages The messages to count (object or array)
|
||||
* @return The number of tokens in the messages
|
||||
*/
|
||||
int num_tokens_from_messages(const BaseTokenizer& tokenizer, const json& messages);
|
||||
|
||||
/**
|
||||
* @brief Roughly count the number of tokens in a message (https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb)
|
||||
* @param tokenizer The tokenizer to use
|
||||
* @param tools The tools to count (array)
|
||||
* @param messages The messages to count (object or array)
|
||||
* @return The number of tokens in the messages
|
||||
*/
|
||||
int num_tokens_for_tools(const BaseTokenizer& tokenizer, const json& tools, const json& messages);
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_TOKENIZER_UTILS_H
|
28
tool/base.h
28
tool/base.h
|
@ -2,7 +2,6 @@
|
|||
#define HUMANUS_TOOL_BASE_H
|
||||
|
||||
#include "schema.h"
|
||||
#include "agent/base.h"
|
||||
#include "config.h"
|
||||
#include "mcp_stdio_client.h"
|
||||
#include "mcp_sse_client.h"
|
||||
|
@ -100,13 +99,19 @@ struct BaseTool {
|
|||
|
||||
// Execute the tool with given parameters.
|
||||
struct BaseMCPTool : BaseTool {
|
||||
std::unique_ptr<mcp::client> _client;
|
||||
std::shared_ptr<mcp::client> _client;
|
||||
|
||||
BaseMCPTool(const std::string& name, const std::string& description, const json& parameters, const std::shared_ptr<mcp::client>& client)
|
||||
: BaseTool(name, description, parameters), _client(client) {}
|
||||
|
||||
BaseMCPTool(const std::string& name, const std::string& description, const json& parameters)
|
||||
: BaseTool(name, description, parameters) {
|
||||
: BaseTool(name, description, parameters), _client(create_client(name)) {}
|
||||
|
||||
static std::shared_ptr<mcp::client> create_client(const std::string& server_name) {
|
||||
std::shared_ptr<mcp::client> _client;
|
||||
try {
|
||||
// Load tool configuration from config file
|
||||
auto _config = Config::get_instance().mcp_tool().at(name);
|
||||
auto _config = Config::get_instance().mcp_server().at(server_name);
|
||||
|
||||
if (_config.type == "stdio") {
|
||||
std::string command = _config.command;
|
||||
|
@ -115,21 +120,22 @@ struct BaseMCPTool : BaseTool {
|
|||
command += " " + arg;
|
||||
}
|
||||
}
|
||||
_client = std::make_unique<mcp::stdio_client>(command, _config.env_vars);
|
||||
_client = std::make_shared<mcp::stdio_client>(command, _config.env_vars);
|
||||
} else if (_config.type == "sse") {
|
||||
if (!_config.host.empty() && _config.port > 0) {
|
||||
_client = std::make_unique<mcp::sse_client>(_config.host, _config.port);
|
||||
_client = std::make_shared<mcp::sse_client>(_config.host, _config.port);
|
||||
} else if (!_config.url.empty()) {
|
||||
_client = std::make_unique<mcp::sse_client>(_config.url, "/sse");
|
||||
_client = std::make_shared<mcp::sse_client>(_config.url, "/sse");
|
||||
} else {
|
||||
throw std::runtime_error("MCP SSE configuration missing host or port or url");
|
||||
}
|
||||
}
|
||||
|
||||
_client->initialize(name + "_client", "0.0.1");
|
||||
_client->initialize(server_name + "_client", "0.0.1");
|
||||
} catch (const std::exception& e) {
|
||||
throw std::runtime_error("Failed to initialize MCP tool client for `" + name + "`: " + std::string(e.what()));
|
||||
throw std::runtime_error("Failed to initialize MCP tool client for `" + server_name + "`: " + std::string(e.what()));
|
||||
}
|
||||
return _client;
|
||||
}
|
||||
|
||||
ToolResult execute(const json& arguments) override {
|
||||
|
@ -151,10 +157,6 @@ struct BaseMCPTool : BaseTool {
|
|||
}
|
||||
};
|
||||
|
||||
struct AgentAware : BaseTool {
|
||||
std::shared_ptr<BaseAgent> agent = nullptr;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // HUMANUS_TOOL_BASE_H
|
||||
|
|
|
@ -0,0 +1,302 @@
|
|||
#ifndef HUMANUS_TOOL_CONTENT_PROVIDER_H
|
||||
#define HUMANUS_TOOL_CONTENT_PROVIDER_H
|
||||
|
||||
#include "base.h"
|
||||
#include "utils.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
namespace humanus {
|
||||
|
||||
struct ContentProvider : BaseTool {
|
||||
inline static const std::string name_ = "content_provider";
|
||||
inline static const std::string description_ = "Use this tool to save temporary content for later use. For example, you can save a large code file (like HTML) and read it by chunks later.";
|
||||
inline static const json parameters_ = json::parse(R"json(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"operation": {
|
||||
"type": "string",
|
||||
"description": "The operation to perform: `write` to save content, `read` to retrieve content",
|
||||
"enum": ["write", "read"]
|
||||
},
|
||||
"content": {
|
||||
"type": "array",
|
||||
"description": "The content to store. Required when operation is `write`. Format: [{`type`: `text`, `text`: `content`}, {`type`: `image`, `image_url`: {`url`: `image_url`}}]",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["text", "image"]
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text content. Required when type is `text`."
|
||||
},
|
||||
"image_url": {
|
||||
"type": "object",
|
||||
"description": "Image URL information. Required when type is `image`.",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL of the image"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"cursor": {
|
||||
"type": "string",
|
||||
"description": "The cursor position for reading content. Required when operation is `read`. Use `start` for the beginning or the cursor returned from a previous read."
|
||||
},
|
||||
"max_chunk_size": {
|
||||
"type": "integer",
|
||||
"description": "Maximum size in characters for each text chunk. Default is 4000.",
|
||||
"default": 4000
|
||||
}
|
||||
},
|
||||
"required": ["operation"]
|
||||
}
|
||||
)json");
|
||||
|
||||
inline static std::map<std::string, std::vector<json>> content_store_;
|
||||
inline static size_t MAX_STORE_ID = 100;
|
||||
inline static size_t current_id_ = 0;
|
||||
|
||||
ContentProvider() : BaseTool(name_, description_, parameters_) {}
|
||||
|
||||
// 将文本分割成合适大小的块
|
||||
std::vector<json> split_text_into_chunks(const std::string& text, int max_chunk_size) {
|
||||
std::vector<json> chunks;
|
||||
|
||||
// 如果文本为空,返回空数组
|
||||
if (text.empty()) {
|
||||
return chunks;
|
||||
}
|
||||
|
||||
size_t text_length = text.length();
|
||||
size_t offset = 0;
|
||||
|
||||
while (offset < text_length) {
|
||||
// 首先确定最大可能的块大小
|
||||
size_t raw_chunk_size = std::min(static_cast<size_t>(max_chunk_size), text_length - offset);
|
||||
|
||||
// 使用 validate_utf8 确保不会截断 UTF-8 字符
|
||||
std::string potential_chunk = text.substr(offset, raw_chunk_size);
|
||||
size_t valid_utf8_length = validate_utf8(potential_chunk);
|
||||
|
||||
// 调整为有效的 UTF-8 字符边界
|
||||
size_t chunk_size = valid_utf8_length;
|
||||
|
||||
// 如果不是在文本的结尾,并且我们没有因为 UTF-8 截断而减小块大小,
|
||||
// 尝试在空格、换行或标点处分割,以获得更自然的分隔点
|
||||
if (offset + chunk_size < text_length && chunk_size == raw_chunk_size) {
|
||||
size_t break_pos = offset + chunk_size;
|
||||
|
||||
// 向后寻找一个合适的分割点
|
||||
size_t min_pos = offset + valid_utf8_length / 2; // 不要搜索太远,至少保留一半的有效内容
|
||||
while (break_pos > min_pos &&
|
||||
text[break_pos] != ' ' &&
|
||||
text[break_pos] != '\n' &&
|
||||
text[break_pos] != '.' &&
|
||||
text[break_pos] != ',' &&
|
||||
text[break_pos] != ';' &&
|
||||
text[break_pos] != ':' &&
|
||||
text[break_pos] != '!' &&
|
||||
text[break_pos] != '?') {
|
||||
break_pos--;
|
||||
}
|
||||
|
||||
// 如果找到了合适的分割点且不是原始位置
|
||||
if (break_pos > min_pos) {
|
||||
// 向前移动到分隔符后面的位置
|
||||
break_pos++;
|
||||
|
||||
// 检查新的分割点是否会导致 UTF-8 截断
|
||||
std::string new_chunk = text.substr(offset, break_pos - offset);
|
||||
size_t new_valid_length = validate_utf8(new_chunk);
|
||||
|
||||
if (new_valid_length == new_chunk.size()) {
|
||||
// 只有在不会截断 UTF-8 字符的情况下使用新的分割点
|
||||
chunk_size = break_pos - offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建一个文本块
|
||||
json chunk;
|
||||
chunk["type"] = "text";
|
||||
chunk["text"] = text.substr(offset, chunk_size);
|
||||
chunks.push_back(chunk);
|
||||
|
||||
offset += chunk_size;
|
||||
}
|
||||
|
||||
return chunks;
|
||||
}
|
||||
|
||||
// 处理写入操作
|
||||
ToolResult handle_write(const json& args) {
|
||||
int max_chunk_size = args.value("max_chunk_size", 4000);
|
||||
|
||||
if (!args.contains("content") || !args["content"].is_array()) {
|
||||
return ToolError("`content` is required and must be an array");
|
||||
}
|
||||
|
||||
std::vector<json> processed_content;
|
||||
|
||||
// 处理内容,分割大型文本
|
||||
for (const auto& item : args["content"]) {
|
||||
if (!item.contains("type")) {
|
||||
return ToolError("Each content item must have a `type` field");
|
||||
}
|
||||
|
||||
std::string type = item["type"];
|
||||
|
||||
if (type == "text") {
|
||||
if (!item.contains("text") || !item["text"].is_string()) {
|
||||
return ToolError("Text items must have a `text` field with string value");
|
||||
}
|
||||
|
||||
std::string text = item["text"];
|
||||
auto chunks = split_text_into_chunks(text, max_chunk_size);
|
||||
processed_content.insert(processed_content.end(), chunks.begin(), chunks.end());
|
||||
} else if (type == "image") {
|
||||
if (!item.contains("image_url") || !item["image_url"].is_object() ||
|
||||
!item["image_url"].contains("url") || !item["image_url"]["url"].is_string()) {
|
||||
return ToolError("Image items must have an `image_url` field with a `url` property");
|
||||
}
|
||||
|
||||
// 图像保持为一个整体
|
||||
processed_content.push_back(item);
|
||||
} else {
|
||||
return ToolError("Unsupported content type: " + type);
|
||||
}
|
||||
}
|
||||
|
||||
// 生成一个唯一的存储ID
|
||||
std::string store_id = "content_" + std::to_string(current_id_);
|
||||
current_id_ = (current_id_ + 1) % MAX_STORE_ID;
|
||||
|
||||
// 存储处理后的内容
|
||||
content_store_[store_id] = processed_content;
|
||||
|
||||
// 返回存储ID和内容项数
|
||||
json result;
|
||||
result["store_id"] = store_id;
|
||||
result["total_items"] = processed_content.size();
|
||||
|
||||
return ToolResult(result);
|
||||
}
|
||||
|
||||
// 处理读取操作
|
||||
ToolResult handle_read(const json& args) {
|
||||
if (!args.contains("cursor") || !args["cursor"].is_string()) {
|
||||
return ToolError("`cursor` is required for read operations");
|
||||
}
|
||||
|
||||
std::string cursor = args["cursor"];
|
||||
|
||||
if (cursor == "start") {
|
||||
// 列出所有可用的存储ID
|
||||
json available_stores = json::array();
|
||||
for (const auto& [id, content] : content_store_) {
|
||||
json store_info;
|
||||
store_info["store_id"] = id;
|
||||
store_info["total_items"] = content.size();
|
||||
available_stores.push_back(store_info);
|
||||
}
|
||||
|
||||
if (available_stores.empty()) {
|
||||
return ToolResult("No content available. Use `write` operation to store content first.");
|
||||
}
|
||||
|
||||
json result;
|
||||
result["available_stores"] = available_stores;
|
||||
result["next_cursor"] = "select_store";
|
||||
|
||||
return ToolResult(result);
|
||||
} else if (cursor == "select_store") {
|
||||
// 用户需要选择一个存储ID
|
||||
return ToolError("Please provide a store_id as cursor in format `store_id:content_X`");
|
||||
} else if (cursor.find("store_id:") == 0) {
|
||||
// 用户选择了一个存储ID
|
||||
std::string store_id = cursor.substr(9); // 移除 "store_id:" 前缀
|
||||
|
||||
if (content_store_.find(store_id) == content_store_.end()) {
|
||||
return ToolError("Store ID `" + store_id + "` not found");
|
||||
}
|
||||
|
||||
// 返回该存储的第一个内容项
|
||||
json result = content_store_[store_id][0];
|
||||
|
||||
// 添加导航信息
|
||||
if (content_store_[store_id].size() > 1) {
|
||||
result["next_cursor"] = store_id + ":1";
|
||||
result["remaining_items"] = content_store_[store_id].size() - 1;
|
||||
} else {
|
||||
result["next_cursor"] = "end";
|
||||
result["remaining_items"] = 0;
|
||||
}
|
||||
|
||||
return ToolResult(result);
|
||||
} else if (cursor.find(":") != std::string::npos) {
|
||||
// 用户正在浏览特定存储内的内容
|
||||
size_t delimiter_pos = cursor.find(":");
|
||||
std::string store_id = cursor.substr(0, delimiter_pos);
|
||||
size_t index = std::stoul(cursor.substr(delimiter_pos + 1));
|
||||
|
||||
if (content_store_.find(store_id) == content_store_.end()) {
|
||||
return ToolError("Store ID `" + store_id + "` not found");
|
||||
}
|
||||
|
||||
if (index >= content_store_[store_id].size()) {
|
||||
return ToolError("Index out of range");
|
||||
}
|
||||
|
||||
// 返回请求的内容项
|
||||
json result = content_store_[store_id][index];
|
||||
|
||||
// 添加导航信息
|
||||
if (index + 1 < content_store_[store_id].size()) {
|
||||
result["next_cursor"] = store_id + ":" + std::to_string(index + 1);
|
||||
result["remaining_items"] = content_store_[store_id].size() - index - 1;
|
||||
} else {
|
||||
result["next_cursor"] = "end";
|
||||
result["remaining_items"] = 0;
|
||||
}
|
||||
|
||||
return ToolResult(result);
|
||||
} else if (cursor == "end") {
|
||||
return ToolResult("You have reached the end of the content.");
|
||||
} else {
|
||||
return ToolError("Invalid cursor format");
|
||||
}
|
||||
}
|
||||
|
||||
ToolResult execute(const json& args) override {
|
||||
try {
|
||||
if (!args.contains("operation")) {
|
||||
return ToolError("`operation` is required");
|
||||
}
|
||||
|
||||
std::string operation = args["operation"];
|
||||
|
||||
if (operation == "write") {
|
||||
return handle_write(args);
|
||||
} else if (operation == "read") {
|
||||
return handle_read(args);
|
||||
} else {
|
||||
return ToolError("Unknown operation `" + operation + "`. Please use `write` or `read`");
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
return ToolError(std::string(e.what()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_TOOL_CONTENT_PROVIDER_H
|
|
@ -1,6 +1,8 @@
|
|||
#ifndef HUMANUS_TOOL_FACT_EXTRACT_H
|
||||
#define HUMANUS_TOOL_FACT_EXTRACT_H
|
||||
|
||||
#include "tool/base.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
struct FactExtract : BaseTool {
|
||||
|
|
|
@ -0,0 +1,253 @@
|
|||
#ifndef HUMANUS_TOOL_PLAYWRIGHT_H
|
||||
#define HUMANUS_TOOL_PLAYWRIGHT_H
|
||||
|
||||
#include "base.h"
|
||||
|
||||
namespace humanus {
|
||||
|
||||
struct Playwright : BaseMCPTool {
|
||||
inline static const std::string name_ = "playwright";
|
||||
inline static const std::string description_ = "Interact with web pages, take screenshots, generate test code, web scraps the page and execute JavaScript in a real browser environment. Note: Most of the time you need to observer the page before executing other actions.";
|
||||
inline static const json parameters_ = json::parse(R"json({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"navigate",
|
||||
"screenshot",
|
||||
"click",
|
||||
"iframe_click",
|
||||
"fill",
|
||||
"select",
|
||||
"hover",
|
||||
"evaluate",
|
||||
"console_logs",
|
||||
"close",
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"patch",
|
||||
"delete",
|
||||
"expect_response",
|
||||
"assert_response",
|
||||
"custom_user_agent",
|
||||
"get_visible_text",
|
||||
"get_visible_html",
|
||||
"go_back",
|
||||
"go_forward",
|
||||
"drag",
|
||||
"press_key",
|
||||
"save_as_pdf"
|
||||
],
|
||||
"description": "Specify the command to perform on the web page using Playwright."
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to navigate to, or to perform HTTP operations on. **Required by**: `navigate`, `get`, `post`, `put`, `patch`, `delete`, `expect_response`."
|
||||
},
|
||||
"selector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector for the element to interact with. Note: Use JS to determine available selectors first. **Required by**: `click`, `iframe_click`, `fill`, `select`, `hover`, `drag`, `press_key`."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name for the screenshot or file operations. **Required by**: `screenshot`."
|
||||
},
|
||||
"browserType": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"chromium",
|
||||
"firefox",
|
||||
"webkit"
|
||||
],
|
||||
"description": "Browser type to use. Defaults to chromium. **Used by**: `navigate`."
|
||||
},
|
||||
"width": {
|
||||
"type": "number",
|
||||
"description": "Viewport width in pixels. Defaults to 1280. **Used by**: `navigate`, `screenshot`."
|
||||
},
|
||||
"height": {
|
||||
"type": "number",
|
||||
"description": "Viewport height in pixels. Defaults to 720. **Used by**: `navigate`, `screenshot`."
|
||||
},
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"description": "Navigation or operation timeout in milliseconds. **Used by**: `navigate`."
|
||||
},
|
||||
"waitUntil": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"load",
|
||||
"domcontentloaded",
|
||||
"networkidle",
|
||||
"commit"
|
||||
],
|
||||
"description": "Navigation wait condition. **Used by**: `navigate`."
|
||||
},
|
||||
"headless": {
|
||||
"type": "boolean",
|
||||
"description": "Run browser in headless mode. Defaults to false. **Used by**: `navigate`."
|
||||
},
|
||||
"fullPage": {
|
||||
"type": "boolean",
|
||||
"description": "Capture the entire page. Defaults to false. **Used by**: `screenshot`."
|
||||
},
|
||||
"savePng": {
|
||||
"type": "boolean",
|
||||
"description": "Save the screenshot as a PNG file. Defaults to false. **Used by**: `screenshot`."
|
||||
},
|
||||
"storeBase64": {
|
||||
"type": "boolean",
|
||||
"description": "Store screenshot in base64 format. Defaults to true. **Used by**: `screenshot`."
|
||||
},
|
||||
"downloadsDir": {
|
||||
"type": "string",
|
||||
"description": "Path to save the file. Defaults to user's Downloads folder. **Used by**: `screenshot`."
|
||||
},
|
||||
"iframeSelector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector for the iframe containing the element to click. **Required by**: `iframe_click`."
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Value to fill in an input or select in a dropdown. **Required by**: `fill`, `select`."
|
||||
},
|
||||
"sourceSelector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector for the source element to drag. **Required by**: `drag`."
|
||||
},
|
||||
"targetSelector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector for the target location to drag to. **Required by**: `drag`."
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press on the keyboard. **Required by**: `press_key`."
|
||||
},
|
||||
"outputPath": {
|
||||
"type": "string",
|
||||
"description": "Directory path where the PDF will be saved. **Required by**: `save_as_pdf`."
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name of the PDF file. Defaults to `page.pdf`. **Used by**: `save_as_pdf`."
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"description": "Page format, e.g., 'A4', 'Letter'. **Used by**: `save_as_pdf`."
|
||||
},
|
||||
"printBackground": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to print background graphics. **Used by**: `save_as_pdf`."
|
||||
},
|
||||
"margin": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"top": {
|
||||
"type": "string"
|
||||
},
|
||||
"right": {
|
||||
"type": "string"
|
||||
},
|
||||
"bottom": {
|
||||
"type": "string"
|
||||
},
|
||||
"left": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"description": "Margins of the page. **Used by**: `save_as_pdf`."
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
})json");
|
||||
|
||||
inline static std::set<std::string> allowed_commands = {
|
||||
"navigate",
|
||||
"screenshot",
|
||||
"click",
|
||||
"iframe_click",
|
||||
"fill",
|
||||
"select",
|
||||
"hover",
|
||||
"evaluate",
|
||||
"console_logs",
|
||||
"close",
|
||||
"get",
|
||||
"post",
|
||||
"put",
|
||||
"patch",
|
||||
"delete",
|
||||
"expect_response",
|
||||
"assert_response",
|
||||
"custom_user_agent",
|
||||
"get_visible_text",
|
||||
"get_visible_html",
|
||||
"go_back",
|
||||
"go_forward",
|
||||
"drag",
|
||||
"press_key",
|
||||
"save_as_pdf"
|
||||
};
|
||||
|
||||
Playwright() : BaseMCPTool(name_, description_, parameters_) {}
|
||||
|
||||
ToolResult execute(const json& args) override {
|
||||
try {
|
||||
if (!_client) {
|
||||
return ToolError("Failed to initialize playwright client");
|
||||
}
|
||||
|
||||
std::string command;
|
||||
if (args.contains("command")) {
|
||||
if (args["command"].is_string()) {
|
||||
command = args["command"].get<std::string>();
|
||||
} else {
|
||||
return ToolError("Invalid command format");
|
||||
}
|
||||
} else {
|
||||
return ToolError("'command' is required");
|
||||
}
|
||||
|
||||
if (allowed_commands.find(command) == allowed_commands.end()) {
|
||||
return ToolError("Unknown command '" + command + "'. Please use one of the following commands: " +
|
||||
std::accumulate(allowed_commands.begin(), allowed_commands.end(), std::string(),
|
||||
[](const std::string& a, const std::string& b) {
|
||||
return a + (a.empty() ? "" : ", ") + b;
|
||||
}));
|
||||
}
|
||||
|
||||
json result = _client->call_tool("playwright_" + command, args);
|
||||
|
||||
if (result["content"].is_array()) {
|
||||
for (size_t i = 0; i < result["content"].size(); i++) {
|
||||
if (result["content"][i]["type"] == "image") {
|
||||
std::string data = result["content"][i]["data"].get<std::string>();
|
||||
std::string mimeType = result["content"][i].value("mimeType", "image/png");
|
||||
// Convert to OAI-compatible image_url format
|
||||
result["content"][i] = {
|
||||
{"type", "image_url"},
|
||||
{"image_url", {{"url", "data:" + mimeType + ";base64," + data}}}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool is_error = result.value("isError", false);
|
||||
|
||||
// Return different ToolResult based on whether there is an error
|
||||
if (is_error) {
|
||||
return ToolError(result.value("content", json::array()));
|
||||
} else {
|
||||
return ToolResult(result.value("content", json::array()));
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
return ToolError(std::string(e.what()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace humanus
|
||||
|
||||
#endif // HUMANUS_TOOL_PLAYWRIGHT_H
|
|
@ -103,7 +103,7 @@ struct Puppeteer : BaseMCPTool {
|
|||
if (result["content"][i]["type"] == "image") {
|
||||
std::string data = result["content"][i]["data"].get<std::string>();
|
||||
std::string mimeType = result["content"][i].value("mimeType", "image/png");
|
||||
// Convert to OAI-complatible image_url format
|
||||
// Convert to OAI-compatible image_url format
|
||||
result["content"][i] = {
|
||||
{"type", "image_url"},
|
||||
{"image_url", {{"url", "data:" + mimeType + ";base64," + data}}}
|
||||
|
|
|
@ -14,7 +14,7 @@ struct PythonExecute : BaseMCPTool {
|
|||
{"properties", {
|
||||
{"code", {
|
||||
{"type", "string"},
|
||||
{"description", "The Python code to execute."}
|
||||
{"description", "The Python code to execute. Note: Use absolute file paths if code will read/write files."}
|
||||
}},
|
||||
{"timeout", {
|
||||
{"type", "number"},
|
||||
|
|
|
@ -54,6 +54,14 @@ struct ToolCollection {
|
|||
tools_map[tool->name] = tool;
|
||||
}
|
||||
|
||||
void add_mcp_tools(const std::string& mcp_server_name) {
|
||||
auto client = BaseMCPTool::create_client(mcp_server_name);
|
||||
auto tool_list = client->get_tools();
|
||||
for (auto tool : tool_list) {
|
||||
add_tool(std::make_shared<BaseMCPTool>(tool.name, tool.description, tool.parameters_schema, client));
|
||||
}
|
||||
}
|
||||
|
||||
void add_tools(const std::vector<std::shared_ptr<BaseTool>>& tools) {
|
||||
for (auto tool : tools) {
|
||||
add_tool(tool);
|
||||
|
|
Loading…
Reference in New Issue