add mem0 (compilation success w/o test)

main
hkr04 2025-03-26 00:38:43 +08:00
parent 47cc44f717
commit 4258a71d7a
26 changed files with 1230 additions and 309 deletions

View File

@ -84,6 +84,10 @@ file(GLOB FLOW_SOURCES
file(GLOB MEMORY_SOURCES
"memory/*.cpp"
"memory/*.cc"
+"memory/*/*.cpp"
+"memory/*/*.cc"
+"memory/*/*/*.cpp"
+"memory/*/*/*.cc"
)
# humanus

View File

@ -10,10 +10,10 @@ namespace humanus {
Config* Config::_instance = nullptr;
std::mutex Config::_mutex;
-void Config::_load_initial_config() {
+void Config::_load_initial_llm_config() {
try {
-auto config_path = _get_config_path();
-std::cout << "Loading config file from: " << config_path.string() << std::endl;
+auto config_path = _get_llm_config_path();
+std::cout << "Loading LLM config file from: " << config_path.string() << std::endl;
const auto& data = toml::parse_file(config_path.string());
@ -35,8 +35,12 @@ void Config::_load_initial_config() {
llm_config.base_url = llm_table["base_url"].as_string()->get();
}
-if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) {
-llm_config.end_point = llm_table["end_point"].as_string()->get();
+if (llm_table.contains("endpoint") && llm_table["endpoint"].is_string()) {
+llm_config.endpoint = llm_table["endpoint"].as_string()->get();
+}
+if (llm_table.contains("vision_details") && llm_table["vision_details"].is_string()) {
+llm_config.vision_details = llm_table["vision_details"].as_string()->get();
}
if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
@ -51,6 +55,10 @@ void Config::_load_initial_config() {
llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
}
+if (llm_table.contains("enable_vision") && llm_table["enable_vision"].is_boolean()) {
+llm_config.enable_vision = llm_table["enable_vision"].as_boolean()->get();
+}
if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
llm_config.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
}
@ -95,4 +103,119 @@ void Config::_load_initial_config() {
}
}
void Config::_load_initial_embedding_model_config() {
try {
auto config_path = _get_embedding_model_config_path();
std::cout << "Loading embedding model config file from: " << config_path.string() << std::endl;
const auto& data = toml::parse_file(config_path.string());
// Load embedding model configuration
for (const auto& [key, value] : data) {
const auto& embd_table = *value.as_table();
EmbeddingModelConfig embd_config;
if (embd_table.contains("provider") && embd_table["provider"].is_string()) {
embd_config.provider = embd_table["provider"].as_string()->get();
}
if (embd_table.contains("base_url") && embd_table["base_url"].is_string()) {
embd_config.base_url = embd_table["base_url"].as_string()->get();
}
if (embd_table.contains("endpoint") && embd_table["endpoint"].is_string()) {
embd_config.endpoint = embd_table["endpoint"].as_string()->get();
}
if (embd_table.contains("model") && embd_table["model"].is_string()) {
embd_config.model = embd_table["model"].as_string()->get();
}
if (embd_table.contains("api_key") && embd_table["api_key"].is_string()) {
embd_config.api_key = embd_table["api_key"].as_string()->get();
}
if (embd_table.contains("embedding_dims") && embd_table["embedding_dims"].is_integer()) {
embd_config.embedding_dims = embd_table["embedding_dims"].as_integer()->get();
}
if (embd_table.contains("max_retries") && embd_table["max_retries"].is_integer()) {
embd_config.max_retries = embd_table["max_retries"].as_integer()->get();
}
_config.embedding_model[std::string(key.str())] = embd_config;
}
if (_config.embedding_model.empty()) {
throw std::runtime_error("No embedding model configuration found");
} else if (_config.embedding_model.find("default") == _config.embedding_model.end()) {
_config.embedding_model["default"] = _config.embedding_model.begin()->second;
}
} catch (const std::exception& e) {
std::cerr << "Loading embedding model config file failed: " << e.what() << std::endl;
// Set default configuration
_config.embedding_model["default"] = EmbeddingModelConfig();
}
}
void Config::_load_initial_vector_store_config() {
try {
auto config_path = _get_vector_store_config_path();
std::cout << "Loading vector store config file from: " << config_path.string() << std::endl;
const auto& data = toml::parse_file(config_path.string());
// Load vector store configuration
for (const auto& [key, value] : data) {
const auto& vs_table = *value.as_table();
VectorStoreConfig vs_config;
if (vs_table.contains("provider") && vs_table["provider"].is_string()) {
vs_config.provider = vs_table["provider"].as_string()->get();
}
if (vs_table.contains("dim") && vs_table["dim"].is_integer()) {
vs_config.dim = vs_table["dim"].as_integer()->get();
}
if (vs_table.contains("max_elements") && vs_table["max_elements"].is_integer()) {
vs_config.max_elements = vs_table["max_elements"].as_integer()->get();
}
if (vs_table.contains("M") && vs_table["M"].is_integer()) {
vs_config.M = vs_table["M"].as_integer()->get();
}
if (vs_table.contains("ef_construction") && vs_table["ef_construction"].is_integer()) {
vs_config.ef_construction = vs_table["ef_construction"].as_integer()->get();
}
if (vs_table.contains("metric") && vs_table["metric"].is_string()) {
const auto& metric_str = vs_table["metric"].as_string()->get();
if (metric_str == "L2") {
vs_config.metric = VectorStoreConfig::Metric::L2;
} else if (metric_str == "IP") {
vs_config.metric = VectorStoreConfig::Metric::IP;
} else {
throw std::runtime_error("Invalid metric: " + metric_str);
}
}
_config.vector_store[std::string(key.str())] = vs_config;
}
if (_config.vector_store.empty()) {
throw std::runtime_error("No vector store configuration found");
} else if (_config.vector_store.find("default") == _config.vector_store.end()) {
_config.vector_store["default"] = _config.vector_store.begin()->second;
}
} catch (const std::exception& e) {
std::cerr << "Loading vector store config file failed: " << e.what() << std::endl;
// Set default configuration
_config.vector_store["default"] = VectorStoreConfig();
}
}
} // namespace humanus

config.h
View File

@ -25,30 +25,34 @@ struct LLMConfig {
std::string model;
std::string api_key;
std::string base_url;
-std::string end_point;
+std::string endpoint;
+std::string vision_details;
int max_tokens;
int timeout;
double temperature;
+bool enable_vision;
bool oai_tool_support;
LLMConfig(
const std::string& model = "deepseek-chat",
const std::string& api_key = "sk-",
const std::string& base_url = "https://api.deepseek.com",
-const std::string& end_point = "/v1/chat/completions",
+const std::string& endpoint = "/v1/chat/completions",
+const std::string& vision_details = "auto",
int max_tokens = 4096,
int timeout = 120,
double temperature = 1.0,
+bool enable_vision = false,
bool oai_tool_support = true
-) : model(model), api_key(api_key), base_url(base_url), end_point(end_point),
-max_tokens(max_tokens), timeout(timeout), temperature(temperature), oai_tool_support(oai_tool_support) {}
+) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
+max_tokens(max_tokens), timeout(timeout), temperature(temperature), enable_vision(enable_vision), oai_tool_support(oai_tool_support) {}
json to_json() const {
json j;
j["model"] = model;
j["api_key"] = api_key;
j["base_url"] = base_url;
-j["end_point"] = end_point;
+j["endpoint"] = endpoint;
j["max_tokens"] = max_tokens;
j["temperature"] = temperature;
return j;
@ -153,9 +157,76 @@ struct ToolParser {
}
};
enum class EmbeddingType {
ADD = 0,
SEARCH = 1,
UPDATE = 2
};
struct EmbeddingModelConfig {
std::string provider = "oai";
std::string base_url = "http://localhost:8080";
std::string endpoint = "/v1/embeddings";
std::string model = "nomic-embed-text-v1.5.f16.gguf";
std::string api_key = "";
int embedding_dims = 768;
int max_retries = 3;
};
struct VectorStoreConfig {
std::string provider = "hnswlib";
int dim = 16; // Dimension of the elements
int max_elements = 10000; // Maximum number of elements, should be known beforehand
int M = 16; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption
int ef_construction = 200; // Controls index search speed/build speed tradeoff
enum class Metric {
L2,
IP
};
Metric metric = Metric::L2;
};
namespace mem0 {
struct MemoryConfig {
// Base config
int max_messages = 5; // Short-term memory capacity
int limit = 5; // Number of results to retrieve from long-term memory
std::string filters = ""; // Filters to apply to search results
// Prompt config
std::string fact_extraction_prompt = prompt::mem0::FACT_EXTRACTION_PROMPT;
std::string update_memory_prompt = prompt::mem0::UPDATE_MEMORY_PROMPT;
// Database config
// std::string history_db_path = ":memory:";
// EmbeddingModel config
std::shared_ptr<EmbeddingModelConfig> embedding_model_config = nullptr;
// Vector store config
std::shared_ptr<VectorStoreConfig> vector_store_config = nullptr;
// Optional: LLM config
std::shared_ptr<LLMConfig> llm_config = nullptr;
};
struct MemoryItem {
size_t id; // The unique identifier for the text data
std::string memory; // The memory deduced from the text data
std::string hash; // The hash of the memory
json metadata; // Any additional metadata associated with the memory, like 'created_at' or 'updated_at'
float score; // The score associated with the text data, used for ranking and sorting
};
} // namespace mem0
struct AppConfig {
-std::map<std::string, LLMConfig> llm;
-std::map<std::string, ToolParser> tool_parser;
+std::unordered_map<std::string, LLMConfig> llm;
+std::unordered_map<std::string, ToolParser> tool_parser;
+std::unordered_map<std::string, EmbeddingModelConfig> embedding_model;
+std::unordered_map<std::string, VectorStoreConfig> vector_store;
};
class Config {
@ -166,7 +237,9 @@ private:
AppConfig _config;
Config() {
-_load_initial_config();
+_load_initial_llm_config();
+_load_initial_embedding_model_config();
+_load_initial_vector_store_config();
_initialized = true;
}
@ -177,19 +250,47 @@ private:
* @brief Get the config path
* @return The config path
*/
-static std::filesystem::path _get_config_path() {
+static std::filesystem::path _get_llm_config_path() {
auto root = PROJECT_ROOT;
auto config_path = root / "config" / "config_llm.toml";
if (std::filesystem::exists(config_path)) {
return config_path;
}
-throw std::runtime_error("Config file not found");
+throw std::runtime_error("LLM Config file not found");
}
static std::filesystem::path _get_embedding_model_config_path() {
auto root = PROJECT_ROOT;
auto config_path = root / "config" / "config_embd.toml";
if (std::filesystem::exists(config_path)) {
return config_path;
}
throw std::runtime_error("Embedding Model Config file not found");
}
static std::filesystem::path _get_vector_store_config_path() {
auto root = PROJECT_ROOT;
auto config_path = root / "config" / "config_vec.toml";
if (std::filesystem::exists(config_path)) {
return config_path;
}
throw std::runtime_error("Vector Store Config file not found");
}
/**
* @brief Load the initial config
*/
-void _load_initial_config();
+void _load_initial_llm_config();
/**
* @brief Load the initial embedding model config
*/
void _load_initial_embedding_model_config();
/**
* @brief Load the initial vector store config
*/
void _load_initial_vector_store_config();
public:
/**
@ -210,7 +311,7 @@ public:
* @brief Get the LLM settings
* @return The LLM settings map
*/
-const std::map<std::string, LLMConfig>& llm() const {
+const std::unordered_map<std::string, LLMConfig>& llm() const {
return _config.llm;
}
@ -218,10 +319,26 @@ public:
* @brief Get the tool helpers
* @return The tool helpers map
*/
-const std::map<std::string, ToolParser>& tool_parser() const {
+const std::unordered_map<std::string, ToolParser>& tool_parser() const {
return _config.tool_parser;
}
/**
* @brief Get the embedding model settings
* @return The embedding model settings map
*/
const std::unordered_map<std::string, EmbeddingModelConfig>& embedding_model() const {
return _config.embedding_model;
}
/**
* @brief Get the vector store settings
* @return The vector store settings map
*/
const std::unordered_map<std::string, VectorStoreConfig>& vector_store() const {
return _config.vector_store;
}
/**
* @brief Get the app config
* @return The app config

View File

@ -0,0 +1,8 @@
[default]
provider = "oai"
base_url = "http://localhost:8080"
endpoint = "/v1/embeddings"
model = "nomic-embed-text-v1.5.f16.gguf"
api_key = ""
embedding_dims = 768
max_retries = 3

View File

@ -1,7 +1,7 @@
[default]
model = "deepseek-reasoner"
base_url = "https://api.deepseek.com"
-end_point = "/v1/chat/completions"
+endpoint = "/v1/chat/completions"
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
max_tokens = 8192
oai_tool_support = false

View File

@ -1,28 +1,28 @@
[llm]
model = "anthropic/claude-3.7-sonnet"
base_url = "https://openrouter.ai"
-end_point = "/api/v1/chat/completions"
+endpoint = "/api/v1/chat/completions"
api_key = "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"
max_tokens = 8192
[llm]
model = "deepseek-chat"
base_url = "https://api.deepseek.com"
-end_point = "/v1/chat/completions"
+endpoint = "/v1/chat/completions"
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
max_tokens = 8192
[llm]
model = "qwen-max"
base_url = "https://dashscope.aliyuncs.com"
-end_point = "/compatible-mode/v1/chat/completions"
+endpoint = "/compatible-mode/v1/chat/completions"
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
max_tokens = 8192
[llm]
model = "deepseek-reasoner"
base_url = "https://api.deepseek.com"
-end_point = "/v1/chat/completions"
+endpoint = "/v1/chat/completions"
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
max_tokens = 8192
oai_tool_support = false

View File

@ -0,0 +1,8 @@
[default]
provider = "hnswlib"
dim = 768 # Dimension of the elements
max_elements = 100 # Maximum number of elements, should be known beforehand
M = 16 # Tightly connected with internal dimensionality of the data
# strongly affects the memory consumption
ef_construction = 200 # Controls index search speed/build speed tradeoff
metric = "L2" # Distance metric to use, can be L2 or IP

View File

@ -0,0 +1,12 @@
set(target humanus_cli_plan_mem0)
add_executable(${target} humanus_plan_mem0.cpp)
#
target_link_libraries(${target} PRIVATE humanus)
#
set_target_properties(${target}
PROPERTIES
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
)

View File

@ -0,0 +1,145 @@
#include "agent/humanus.h"
#include "logger.h"
#include "prompt.h"
#include "flow/flow_factory.h"
#include "memory/mem0/base.h"
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif
using namespace humanus;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
logger->info("Interrupted by user\n");
exit(0);
}
}
#endif
static bool readline_utf8(std::string & line, bool multiline_input) {
#if defined(_WIN32)
std::wstring wline;
if (!std::getline(std::wcin, wline)) {
// Input stream is bad or EOF received
line.clear();
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
return false;
}
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
line.resize(size_needed);
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
#else
if (!std::getline(std::cin, line)) {
// Input stream is bad or EOF received
line.clear();
return false;
}
#endif
if (!line.empty()) {
char last = line.back();
if (last == '/') { // Always return control on '/' symbol
line.pop_back();
return false;
}
if (last == '\\') { // '\\' changes the default action
line.pop_back();
multiline_input = !multiline_input;
}
}
line += '\n';
// By default, continue input if multiline_input is set
return multiline_input;
}
int main() {
// ctrl+C handling
{
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
SetConsoleCP(CP_UTF8);
SetConsoleOutputCP(CP_UTF8);
_setmode(_fileno(stdin), _O_WTEXT); // wide character input mode
#endif
}
std::shared_ptr<BaseAgent> agent_ptr = std::make_shared<Humanus>(
ToolCollection( // Add general-purpose tools to the tool collection
{
std::make_shared<PythonExecute>(),
std::make_shared<Puppeteer>(), // for web browsing
std::make_shared<Filesystem>(),
std::make_shared<Terminate>()
}
),
"auto",
std::set<std::string>{"terminate"},
"humanus_mem0",
"A versatile agent that can solve various tasks using multiple tools",
prompt::humanus::SYSTEM_PROMPT,
prompt::humanus::NEXT_STEP_PROMPT,
nullptr,
std::make_shared<mem0::Memory>(mem0::MemoryConfig())
);
std::map<std::string, std::shared_ptr<BaseAgent>> agents;
agents["default"] = agent_ptr;
auto flow = FlowFactory::create_flow(
FlowType::PLANNING,
nullptr, // llm
nullptr, // planning_tool
std::vector<std::string>{}, // executor_keys
"", // active_plan_id
agents, // agents
std::vector<std::shared_ptr<BaseTool>>{}, // tools
"default" // primary_agent_key
);
while (true) {
if (agent_ptr->current_step == agent_ptr->max_steps) {
std::cout << "Automatically paused after " << agent_ptr->current_step << " steps." << std::endl;
std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): ";
agent_ptr->reset(false);
} else {
std::cout << "Enter your prompt (or 'exit' to quit): ";
}
if (agent_ptr->state != AgentState::IDLE) {
break;
}
std::string prompt;
readline_utf8(prompt, false);
if (prompt == "exit" || prompt == "exit\n") {
logger->info("Goodbye!");
break;
}
std::cout << "Processing your request..." << std::endl;
auto result = flow->execute(prompt);
std::cout << result << std::endl;
}
}

View File

@ -2,5 +2,5 @@
namespace humanus {
// Define the static member variable
-std::map<std::string, std::shared_ptr<LLM>> LLM::_instances;
+std::unordered_map<std::string, std::shared_ptr<LLM>> LLM::instances_;
}

llm.h
View File

@ -17,7 +17,7 @@ namespace humanus {
class LLM {
private:
-static std::map<std::string, std::shared_ptr<LLM>> _instances;
+static std::unordered_map<std::string, std::shared_ptr<LLM>> instances_;
std::unique_ptr<httplib::Client> client_;
@ -27,13 +27,7 @@ private:
public:
// Constructor
-LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& llm_config = nullptr, const std::shared_ptr<ToolParser>& tool_parser = nullptr) : llm_config_(llm_config), tool_parser_(tool_parser) {
-if (!llm_config_) {
-if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
-throw std::invalid_argument("LLM config not found: " + config_name);
-}
-llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
-}
+LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& config = nullptr, const std::shared_ptr<ToolParser>& tool_parser = nullptr) : llm_config_(config), tool_parser_(tool_parser) {
if (!llm_config_->oai_tool_support && !tool_parser_) {
if (Config::get_instance().tool_parser().find(config_name) == Config::get_instance().tool_parser().end()) {
throw std::invalid_argument("Tool helper config not found: " + config_name);
@ -49,10 +43,17 @@ public:
// Get the singleton instance
static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMConfig>& llm_config = nullptr) {
-if (_instances.find(config_name) == _instances.end()) {
-_instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
-}
-return _instances[config_name];
+if (instances_.find(config_name) == instances_.end()) {
+auto llm_config_ = llm_config;
+if (!llm_config_) {
+if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
+throw std::invalid_argument("LLM config not found: " + config_name);
+}
+llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
+}
+instances_[config_name] = std::make_shared<LLM>(config_name, llm_config_);
+}
+return instances_[config_name];
}
/** /**
@ -194,7 +195,7 @@ public:
while (retry <= max_retries) {
// send request
-auto res = client_->Post(llm_config_->end_point, body_str, "application/json");
+auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
if (!res) {
logger->error("Failed to send request: " + httplib::to_string(res.error()));
@ -325,7 +326,7 @@ public:
while (retry <= max_retries) {
// send request
-auto res = client_->Post(llm_config_->end_point, body_str, "application/json");
+auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
if (!res) {
logger->error("Failed to send request: " + httplib::to_string(res.error()));

View File

@ -25,7 +25,7 @@ struct BaseMemory {
messages.clear();
}
-virtual std::vector<Message> get_messages() const {
+virtual std::vector<Message> get_messages(const std::string& query = "") const {
return messages;
}

memory/mem0/base.h 100644
View File

@ -0,0 +1,280 @@
#ifndef HUMANUS_MEMORY_MEM0_H
#define HUMANUS_MEMORY_MEM0_H
#include "memory/base.h"
#include "vector_store/base.h"
#include "embedding_model/base.h"
#include "schema.h"
#include "prompt.h"
#include "httplib.h"
#include "llm.h"
#include "utils.h"
namespace humanus::mem0 {
struct Memory : BaseMemory {
MemoryConfig config;
std::string fact_extraction_prompt;
std::string update_memory_prompt;
int max_messages;
int limit;
std::string filters;
std::shared_ptr<EmbeddingModel> embedding_model;
std::shared_ptr<VectorStore> vector_store;
std::shared_ptr<LLM> llm;
// std::shared_ptr<SQLiteManager> db;
Memory(const MemoryConfig& config) : config(config) {
fact_extraction_prompt = config.fact_extraction_prompt;
update_memory_prompt = config.update_memory_prompt;
max_messages = config.max_messages;
limit = config.limit;
filters = config.filters;
embedding_model = EmbeddingModel::get_instance("mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this)), config.embedding_model_config);
vector_store = VectorStore::get_instance("mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this)), config.vector_store_config);
llm = LLM::get_instance("mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this)), config.llm_config);
// db = std::make_shared<SQLiteManager>(config.history_db_path);
}
void add_message(const Message& message) override {
while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
// Ensure the first message is always a user or system message
Message front_message = *messages.begin();
messages.erase(messages.begin());
if (config.llm_config->enable_vision) {
front_message = parse_vision_message(front_message, llm, config.llm_config->vision_details);
} else {
front_message = parse_vision_message(front_message);
}
_add_to_vector_store(front_message);
}
messages.push_back(message);
}
std::vector<Message> get_messages(const std::string& query = "") const override {
auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH);
std::vector<MemoryItem> memories;
// Check whether the vector_store has been initialized
if (vector_store) {
memories = vector_store->search(embeddings, limit, filters);
}
std::string memory_prompt;
for (const auto& memory_item : memories) {
memory_prompt += "<memory>" + memory_item.memory + "</memory>";
}
std::vector<Message> messages_with_memory{Message::user_message(memory_prompt)};
messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end());
return messages_with_memory;
}
void _add_to_vector_store(const Message& message) {
// Check whether the vector_store has been initialized
if (!vector_store) {
logger->warn("Vector store is not initialized, skipping memory operation");
return;
}
std::string parsed_message = message.role + ": " + (message.content.is_string() ? message.content.get<std::string>() : message.content.dump());
for (const auto& tool_call : message.tool_calls) {
parsed_message += "<tool_call>" + tool_call.to_json().dump() + "</tool_call>";
}
std::string system_prompt = fact_extraction_prompt;
std::string user_prompt = "Input:\n" + parsed_message;
Message user_message = Message::user_message(user_prompt);
std::string response = llm->ask(
{user_message},
system_prompt
);
json new_retrieved_facts; // ["fact1", "fact2", "fact3"]
try {
// response = remove_code_blocks(response);
new_retrieved_facts = json::parse(response)["facts"];
} catch (const std::exception& e) {
logger->error("Error in new_retrieved_facts: " + std::string(e.what()));
}
std::vector<json> retrieved_old_memory;
std::map<std::string, std::vector<float>> new_message_embeddings;
for (const auto& fact : new_retrieved_facts) {
auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD);
new_message_embeddings[fact] = message_embedding;
auto existing_memories = vector_store->search(
message_embedding,
5
);
for (const auto& memory : existing_memories) {
retrieved_old_memory.push_back({
{"id", memory.id},
{"text", memory.metadata["data"]}
});
}
}
// sort and unique by id
std::sort(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) {
return a["id"] < b["id"];
});
retrieved_old_memory.resize(std::unique(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) {
return a["id"] == b["id"];
}) - retrieved_old_memory.begin());
logger->info("Total existing memories: " + std::to_string(retrieved_old_memory.size()));
// mapping UUIDs with integers for handling UUID hallucinations
std::vector<size_t> temp_uuid_mapping;
for (size_t idx = 0; idx < retrieved_old_memory.size(); ++idx) {
temp_uuid_mapping.push_back(retrieved_old_memory[idx]["id"].get<size_t>());
retrieved_old_memory[idx]["id"] = idx;
}
std::string function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, fact_extraction_prompt, update_memory_prompt);
std::string new_memories_with_actions_str;
json new_memories_with_actions = json::array();
try {
new_memories_with_actions_str = llm->ask(
{Message::user_message(function_calling_prompt)}
);
new_memories_with_actions = json::parse(new_memories_with_actions_str);
} catch (const std::exception& e) {
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
}
try {
// new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str);
new_memories_with_actions = json::parse(new_memories_with_actions_str);
} catch (const std::exception& e) {
logger->error("Invalid JSON response: " + std::string(e.what()));
}
try {
for (const auto& resp : new_memories_with_actions.value("memory", json::array())) {
logger->info("Processing memory: " + resp.dump(2));
try {
if (!resp.contains("text")) {
logger->info("Skipping memory entry because of empty `text` field.");
continue;
}
std::string event = resp.value("event", "NONE");
size_t memory_id = resp.contains("id") ? temp_uuid_mapping[resp["id"].get<size_t>()] : uuid();
if (event == "ADD") {
_create_memory(
memory_id,
resp["text"], // data
new_message_embeddings // existing_embeddings
);
} else if (event == "UPDATE") {
_update_memory(
memory_id,
resp["text"], // data
new_message_embeddings // existing_embeddings
);
} else if (event == "DELETE") {
_delete_memory(memory_id);
} else if (event == "NONE") {
logger->info("NOOP for Memory.");
}
} catch (const std::exception& e) {
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
}
}
} catch (const std::exception& e) {
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
}
}
void _create_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
if (!vector_store) {
logger->warn("Vector store is not initialized, skipping create memory");
return;
}
std::vector<float> embedding;
if (existing_embeddings.find(data) != existing_embeddings.end()) {
embedding = existing_embeddings.at(data);
} else {
embedding = embedding_model->embed(data, EmbeddingType::ADD);
}
auto created_at = std::chrono::system_clock::now();
json metadata = {
{"data", data},
{"hash", httplib::detail::MD5(data)},
{"created_at", std::chrono::system_clock::now().time_since_epoch().count()}
};
vector_store->insert(
embedding,
memory_id,
metadata
);
}
void _update_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
if (!vector_store) {
logger->warn("Vector store is not initialized, skipping update memory");
return;
}
logger->info("Updating memory with " + data);
MemoryItem existing_memory;
try {
existing_memory = vector_store->get(memory_id);
} catch (const std::exception& e) {
logger->error("Error fetching existing memory: " + std::string(e.what()));
return;
}
std::vector<float> embedding;
if (existing_embeddings.find(data) != existing_embeddings.end()) {
embedding = existing_embeddings.at(data);
} else {
embedding = embedding_model->embed(data, EmbeddingType::ADD);
}
json metadata = existing_memory.metadata;
metadata["data"] = data;
metadata["hash"] = httplib::detail::MD5(data);
metadata["updated_at"] = std::chrono::system_clock::now().time_since_epoch().count();
vector_store->update(
memory_id,
embedding,
metadata
);
}
void _delete_memory(const size_t& memory_id) {
if (!vector_store) {
logger->warn("Vector store is not initialized, skipping delete memory");
return;
}
logger->info("Deleting memory: " + std::to_string(memory_id));
vector_store->delete_vector(memory_id);
}
};
} // namespace humanus::mem0
#endif // HUMANUS_MEMORY_MEM0_H

View File

@ -0,0 +1,27 @@
#include "base.h"
#include "oai.h"
namespace humanus::mem0 {
std::unordered_map<std::string, std::shared_ptr<EmbeddingModel>> EmbeddingModel::instances_;
std::shared_ptr<EmbeddingModel> EmbeddingModel::get_instance(const std::string& config_name, const std::shared_ptr<EmbeddingModelConfig>& config) {
if (instances_.find(config_name) == instances_.end()) {
auto config_ = config;
if (!config_) {
if (Config::get_instance().embedding_model().find(config_name) == Config::get_instance().embedding_model().end()) {
throw std::invalid_argument("Embedding model config not found: " + config_name);
}
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at(config_name));
}
if (config_->provider == "oai") {
instances_[config_name] = std::make_shared<OAIEmbeddingModel>(config_);
} else {
throw std::invalid_argument("Unsupported embedding model provider: " + config_->provider);
}
}
return instances_[config_name];
}
} // namespace humanus::mem0

View File

@ -0,0 +1,33 @@
#ifndef HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H
#define HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H
#include "httplib.h"
#include "logger.h"
#include <vector>
#include <unordered_map>
#include <memory>
namespace humanus::mem0 {
class EmbeddingModel {
private:
static std::unordered_map<std::string, std::shared_ptr<EmbeddingModel>> instances_;
protected:
std::shared_ptr<EmbeddingModelConfig> config_;
// Constructor
EmbeddingModel(const std::shared_ptr<EmbeddingModelConfig>& config) : config_(config) {}
public:
// Get the singleton instance
static std::shared_ptr<EmbeddingModel> get_instance(const std::string& config_name = "default", const std::shared_ptr<EmbeddingModelConfig>& config = nullptr);
virtual ~EmbeddingModel() = default;
virtual std::vector<float> embed(const std::string& text, EmbeddingType type) = 0;
};
} // namespace humanus::mem0
#endif // HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H

View File

@ -0,0 +1,48 @@
#include "oai.h"
namespace humanus::mem0 {
std::vector<float> OAIEmbeddingModel::embed(const std::string& text, EmbeddingType /* type */) {
json body = {
{"model", config_->model},
{"input", text},
{"encoding_format", "float"}
};
std::string body_str = body.dump();
int retry = 0;
while (retry <= config_->max_retries) {
// send request
auto res = client_->Post(config_->endpoint, body_str, "application/json");
if (!res) {
logger->error("Failed to send request: " + httplib::to_string(res.error()));
} else if (res->status == 200) {
try {
json json_data = json::parse(res->body);
return json_data["data"][0]["embedding"].get<std::vector<float>>();
} catch (const std::exception& e) {
logger->error("Failed to parse response: " + std::string(e.what()));
}
} else {
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
}
retry++;
if (retry > config_->max_retries) {
break;
}
// wait for a while before retrying
std::this_thread::sleep_for(std::chrono::milliseconds(500));
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(config_->max_retries));
}
throw std::runtime_error("Failed to get embedding from: " + config_->base_url + " " + config_->model);
}
} // namespace humanus::mem0

View File

@ -0,0 +1,25 @@
#ifndef HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H
#define HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H
#include "base.h"
namespace humanus::mem0 {
class OAIEmbeddingModel : public EmbeddingModel {
private:
std::unique_ptr<httplib::Client> client_;
public:
OAIEmbeddingModel(const std::shared_ptr<EmbeddingModelConfig>& config) : EmbeddingModel(config) {
client_ = std::make_unique<httplib::Client>(config_->base_url);
client_->set_default_headers({
{"Authorization", "Bearer " + config_->api_key}
});
}
std::vector<float> embed(const std::string& text, EmbeddingType type) override;
};
} // namespace humanus::mem0
#endif // HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H

View File

@ -1,116 +0,0 @@
#ifndef HUMANUS_MEMORY_MEM0_H
#define HUMANUS_MEMORY_MEM0_H
#include "memory/base.h"
#include "storage.h"
#include "vector_store.h"
#include "prompt.h"
namespace humanus {
namespace mem0 {
struct Config {
// Prompt config
std::string fact_extraction_prompt;
std::string update_memory_prompt;
// Database config
// std::string history_db_path = ":memory:";
// Embedder config
EmbedderConfig embedder_config;
// Vector store config
VectorStoreConfig vector_store_config;
// Optional: LLM config
LLMConfig llm_config;
};
struct Memory : BaseMemory {
Config config;
std::string fact_extraction_prompt;
std::string update_memory_prompt;
std::shared_ptr<Embedder> embedder;
std::shared_ptr<VectorStore> vector_store;
std::shared_ptr<LLM> llm;
// std::shared_ptr<SQLiteManager> db;
Memory(const Config& config) : config(config) {
fact_extraction_prompt = config.fact_extraction_prompt;
update_memory_prompt = config.update_memory_prompt;
embedder = std::make_shared<Embedder>(config.embedder_config);
vector_store = std::make_shared<VectorStore>(config.vector_store_config);
llm = std::make_shared<LLM>(config.llm_config);
// db = std::make_shared<SQLiteManager>(config.history_db_path);
}
void add_message(const Message& message) override {
if (config.llm_config.enable_vision) {
message = parse_vision_messages(message, llm, config.llm_config.vision_details);
} else {
message = parse_vision_messages(message);
}
_add_to_vector_store(message);
}
void _add_to_vector_store(const Message& message) {
std::string parsed_message = parse_message(message);
std::string system_prompt;
std::string user_prompt = "Input:\n" + parsed_message;
if (!fact_extraction_prompt.empty()) {
system_prompt = fact_extraction_prompt;
} else {
system_prompt = FACT_EXTRACTION_PROMPT;
}
Message user_message = Message::user_message(user_prompt);
std::string response = llm->ask(
{user_message},
system_prompt
);
std::vector<json> new_retrieved_facts;
try {
response = remove_code_blocks(response);
new_retrieved_facts = json::parse(response)["facts"].get<std::vector<json>>();
} catch (const std::exception& e) {
LOG_ERROR("Error in new_retrieved_facts: " + std::string(e.what()));
}
std::vector<json> retrieved_old_memory;
std::map<std::string, std::vector<float>> new_message_embeddings;
for (const auto& fact : new_retrieved_facts) {
auto message_embedding = embedder->embed(fact);
new_message_embeddings[fact] = message_embedding;
auto existing_memories = vector_store->search(
message_embedding,
5,
filters
)
for (const auto& memory : existing_memories) {
retrieved_old_memory.push_back({
{"id", memory.id},
{"text", memory.payload["data"]}
});
}
}
}
};
} // namespace mem0
} // namespace humanus
#endif // HUMANUS_MEMORY_MEM0_H

View File

@ -2,10 +2,9 @@
#define HUMANUS_MEMORY_MEM0_STORAGE_H
#include <sqlite3.h>
+#include <mutex>
-namespace humanus {
-namespace mem0 {
+namespace humanus::mem0 {
struct SQLiteManager {
std::shared_ptr<sqlite3> db;
@ -17,7 +16,7 @@ struct SQLiteManager {
throw std::runtime_error("Failed to open database: " + std::string(sqlite3_errmsg(db)));
}
_migrate_history_table();
-_create_history_table()
+_create_history_table();
}
void _migrate_history_table() {
@ -142,8 +141,6 @@ struct SQLiteManager {
}
};
-} // namespace mem0
-} // namespace humanus
+} // namespace humanus::mem0
#endif // HUMANUS_MEMORY_MEM0_STORAGE_H

memory/mem0/utils.h 100644
View File

@ -0,0 +1,110 @@
#ifndef HUMANUS_MEMORY_MEM0_UTILS_H
#define HUMANUS_MEMORY_MEM0_UTILS_H
#include "schema.h"
#include "llm.h"
namespace humanus::mem0 {
static size_t uuid() {
const std::string chars = "0123456789abcdef";
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(0, chars.size() - 1);
unsigned long long int uuid_int = 0;
for (int i = 0; i < 16; ++i) {
uuid_int = (uuid_int << 4) | dis(gen);
}
// RFC 4122 variant
uuid_int &= ~(0xc000ULL << 48);
uuid_int |= 0x8000ULL << 48;
// version 4, random UUID
int version = 4;
uuid_int &= ~(0xfULL << 12);
uuid_int |= static_cast<unsigned long long>(version) << 12;
return uuid_int;
}
std::string get_update_memory_messages(const json& retrieved_old_memory, const json& new_retrieved_facts, const std::string fact_extraction_prompt, const std::string& update_memory_prompt) {
std::stringstream ss;
ss << fact_extraction_prompt << "\n\n";
ss << "Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n";
ss << "```" + retrieved_old_memory.dump(2) + "```\n\n";
ss << "The new retrieved facts are mentioned in the triple backticks. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n";
ss << "```" + new_retrieved_facts.dump(2) + "```\n\n";
ss << "You must return your response in the following JSON structure only:\n\n";
ss << R"json({
"memory" : [
{
"id" : "<ID of the memory>", # Use existing ID for updates/deletes, or new ID for additions
"text" : "<Content of the memory>", # Content of the memory
"event" : "<Operation to be performed>", # Must be "ADD", "UPDATE", "DELETE", or "NONE"
"old_memory" : "<Old memory content>" # Required only if the event is "UPDATE"
},
...
]
})json" << "\n\n";
ss << "Follow the instruction mentioned below:\n"
<< "- Do not return anything from the custom few shot prompts provided above.\n"
<< "- If the current memory is empty, then you have to add the new retrieved facts to the memory.\n"
<< "- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.\n"
<< "- If there is an addition, generate a new key and add the new memory corresponding to it.\n"
<< "- If there is a deletion, the memory key-value pair should be removed from the memory.\n"
<< "- If there is an update, the ID key should remain the same and only the value needs to be updated.\n"
<< "\n";
ss << "Do not return anything except the JSON format.\n";
return ss.str();
}
// Get the description of the image
// image_url should be like: data:{mime_type};base64,{base64_data}
std::string get_image_description(const std::string& image_url, const std::shared_ptr<LLM>& llm, const std::string& vision_details) {
if (!llm) {
return "Here is an image failed to get description due to missing LLM instance.";
}
json content = json::array({
{
{"type", "text"},
{"text", "A user is providing an image. Provide a high level description of the image and do not include any additional text."}
},
{
{"type", "image_url"},
{"image_url", {
{"url", image_url},
{"detail", vision_details}
}}
}
});
return llm->ask(
{Message::user_message(content)}
);
}
// Parse the vision messages from the messages
Message parse_vision_message(const Message& message, const std::shared_ptr<LLM>& llm = nullptr, const std::string& vision_details = "auto") {
Message returned_message = message;
if (returned_message.content.is_array()) {
// Multiple image URLs in content
for (auto& content_item : returned_message.content) {
if (content_item["type"] == "image_url") {
auto description = get_image_description(content_item["image_url"]["url"], llm, vision_details);
content_item = description;
}
}
} else if (returned_message.content.is_object() && returned_message.content["type"] == "image_url") {
auto image_url = returned_message.content["image_url"]["url"];
returned_message.content = get_image_description(image_url, llm, vision_details);
}
return returned_message;
}
}
#endif // HUMANUS_MEMORY_MEM0_UTILS_H

View File

@ -0,0 +1,27 @@
#include "base.h"
#include "hnswlib.h"
namespace humanus::mem0 {
std::unordered_map<std::string, std::shared_ptr<VectorStore>> VectorStore::instances_;
std::shared_ptr<VectorStore> VectorStore::get_instance(const std::string& config_name, const std::shared_ptr<VectorStoreConfig>& config) {
if (instances_.find(config_name) == instances_.end()) {
auto config_ = config;
if (!config_) {
if (Config::get_instance().vector_store().find(config_name) == Config::get_instance().vector_store().end()) {
throw std::invalid_argument("Vector store config not found: " + config_name);
}
config_ = std::make_shared<VectorStoreConfig>(Config::get_instance().vector_store().at(config_name));
}
if (config_->provider == "hnswlib") {
instances_[config_name] = std::make_shared<HNSWLibVectorStore>(config_);
} else {
throw std::invalid_argument("Unsupported vector store provider: " + config_->provider);
}
}
return instances_[config_name];
}
} // namespace humanus::mem0

View File

@ -1,53 +1,49 @@
#ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H
#define HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H
-#include "hnswlib/hnswlib.h"
+#include "config.h"
+#include <memory>
+#include <unordered_map>
+#include <string>
-namespace humanus {
-namespace mem0 {
+namespace humanus::mem0 {
-struct VectorStoreConfig {
-int dim = 16; // Dimension of the elements
-int max_elements = 10000; // Maximum number of elements, should be known beforehand
-int M = 16; // Tightly connected with internal dimensionality of the data
-// strongly affects the memory consumption
-int ef_construction = 200; // Controls index search speed/build speed tradeoff
-enum class Metric {
-L2,
-IP
-};
-Metric metric = Metric::L2;
-};
-struct VectorStoreBase {
-VectorStoreConfig config;
-VectorStoreBase(const VectorStoreConfig& config) : config(config) {
-reset();
-}
+class VectorStore {
+private:
+static std::unordered_map<std::string, std::shared_ptr<VectorStore>> instances_;
+protected:
+std::shared_ptr<VectorStoreConfig> config_;
+// Constructor
+VectorStore(const std::shared_ptr<VectorStoreConfig>& config) : config_(config) {}
+public:
+// Get the singleton instance
+static std::shared_ptr<VectorStore> get_instance(const std::string& config_name = "default", const std::shared_ptr<VectorStoreConfig>& config = nullptr);
+virtual ~VectorStore() = default;
virtual void reset() = 0;
/**
* @brief
-* @param vectors
-* @param payloads
-* @param ids ID
-* @return ID
+* @param vector
+* @param vector_id ID
+* @param metadata
*/
-virtual std::vector<size_t> insert(const std::vector<std::vector<float>>& vectors,
-const std::vector<std::string>& payloads = {},
-const std::vector<size_t>& ids = {}) = 0;
+virtual void insert(const std::vector<float>& vector,
+const size_t vector_id,
+const json& metadata = json::object()) = 0;
/**
* @brief
* @param query
* @param limit
* @param filters
-* @return ID
+* @return MemoryItem
*/
-std::vector<std::pair<size_t, std::vector<float>>> search(const std::vector<float>& query,
+virtual std::vector<MemoryItem> search(const std::vector<float>& query,
int limit = 5,
const std::string& filters = "") = 0;
@ -61,18 +57,18 @@ struct VectorStoreBase {
* @brief
* @param vector_id ID
* @param vector
-* @param payload
+* @param metadata
*/
virtual void update(size_t vector_id,
-const std::vector<float>* vector = nullptr,
-const std::string* payload = nullptr) = 0;
+const std::vector<float> vector = std::vector<float>(),
+const json& metadata = json::object()) = 0;
/**
* @brief ID
* @param vector_id ID
* @return
*/
-virtual std::vector<float> get(size_t vector_id) = 0;
+virtual MemoryItem get(size_t vector_id) = 0;
/**
* @brief
@ -80,12 +76,10 @@ struct VectorStoreBase {
* @param limit
* @return ID
*/
-virtual std::vector<size_t> list(const std::string& filters = "", int limit = 0) = 0;
+virtual std::vector<MemoryItem> list(const std::string& filters = "", int limit = 0) = 0;
};
-}
-}
+} // namespace humanus::mem0
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H

View File

@ -0,0 +1,165 @@
#include "hnswlib/hnswlib.h"
#include "hnswlib.h"
#include <map>
#include <chrono>
namespace humanus::mem0 {
void HNSWLibVectorStore::reset() {
if (hnsw) {
hnsw.reset();
}
if (space) {
space.reset();
}
metadata_store.clear();
if (config_->metric == VectorStoreConfig::Metric::L2) {
space = std::make_shared<hnswlib::L2Space>(config_->dim);
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
} else if (config_->metric == VectorStoreConfig::Metric::IP) {
space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
} else {
throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config_->metric)));
}
}
void HNSWLibVectorStore::insert(const std::vector<float>& vector,
const size_t vector_id,
const json& metadata) {
hnsw->addPoint(vector.data(), vector_id);
// Store the metadata
auto now = std::chrono::system_clock::now().time_since_epoch().count();
json _metadata = metadata;
if (!_metadata.contains("created_at")) {
_metadata["created_at"] = now;
}
if (!_metadata.contains("updated_at")) {
_metadata["updated_at"] = now;
}
metadata_store[vector_id] = _metadata;
}
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query,
int limit = 5,
const std::string& filters = "") {
auto results = hnsw->searchKnn(query.data(), limit);
std::vector<MemoryItem> memory_items;
while (!results.empty()) {
const auto& [id, distance] = results.top();
results.pop();
if (metadata_store.find(id) != metadata_store.end()) {
MemoryItem item;
item.id = id;
if (metadata_store[id].contains("data")) {
item.memory = metadata_store[id]["data"];
}
if (metadata_store[id].contains("hash")) {
item.hash = metadata_store[id]["hash"];
}
item.metadata = metadata_store[id];
item.score = distance;
memory_items.push_back(item);
}
}
return memory_items;
}
void HNSWLibVectorStore::delete_vector(size_t vector_id) {
hnsw->markDelete(vector_id);
metadata_store.erase(vector_id);
}
void HNSWLibVectorStore::update(size_t vector_id,
const std::vector<float> vector = std::vector<float>(),
const json& metadata = json::object()) {
// Check whether the vector needs to be updated
if (!vector.empty()) {
hnsw->markDelete(vector_id);
hnsw->addPoint(vector.data(), vector_id);
}
// Update the metadata
if (metadata_store.find(vector_id) != metadata_store.end()) {
auto now = std::chrono::system_clock::now().time_since_epoch().count();
// Merge the existing metadata with the new metadata
for (auto& [key, value] : metadata.items()) {
metadata_store[vector_id][key] = value;
}
// Update the timestamp
metadata_store[vector_id]["updated_at"] = now;
} else if (!metadata.empty()) {
// If no metadata exists but new metadata is provided, create a new entry
auto now = std::chrono::system_clock::now().time_since_epoch().count();
json new_metadata = metadata;
if (!new_metadata.contains("created_at")) {
new_metadata["created_at"] = now;
}
new_metadata["updated_at"] = now;
metadata_store[vector_id] = new_metadata;
}
}
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
MemoryItem item;
item.id = vector_id;
// Fetch the vector data
std::vector<float> vector_data = hnsw->getDataByLabel<float>(vector_id);
// Fetch the metadata
if (metadata_store.find(vector_id) != metadata_store.end()) {
if (metadata_store[vector_id].contains("data")) {
item.memory = metadata_store[vector_id]["data"];
}
if (metadata_store[vector_id].contains("hash")) {
item.hash = metadata_store[vector_id]["hash"];
}
item.metadata = metadata_store[vector_id];
}
return item;
}
std::vector<MemoryItem> HNSWLibVectorStore::list(const std::string& filters = "", int limit = 0) {
std::vector<MemoryItem> result;
size_t count = hnsw->cur_element_count;
for (size_t i = 0; i < count; i++) {
if (!hnsw->isMarkedDeleted(i)) {
// If filters are provided, check whether the metadata matches
if (!filters.empty() && metadata_store.find(i) != metadata_store.end()) {
// Simple substring-matching filter; can be extended as needed
json metadata_json = metadata_store[i];
std::string metadata_str = metadata_json.dump();
if (metadata_str.find(filters) == std::string::npos) {
continue;
}
}
result.emplace_back(get(i));
if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
break;
}
}
}
return result;
}
};

View File

@ -1,124 +1,37 @@
#ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
#define HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
+#include "base.h"
#include "hnswlib/hnswlib.h"
-namespace humanus {
-namespace mem0 {
+namespace humanus::mem0 {
-struct HNSWLIBVectorStore {
-VectorStoreConfig config;
+class HNSWLibVectorStore : public VectorStore {
+private:
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
+std::shared_ptr<hnswlib::SpaceInterface<float>> space; // Keep a reference to the space object to ensure its lifetime
+std::unordered_map<size_t, json> metadata_store; // Stores the metadata of the vectors
-HNSWLIBVectorStore(const VectorStoreConfig& config) : config(config) {
+public:
+HNSWLibVectorStore(const std::shared_ptr<VectorStoreConfig>& config) : VectorStore(config) {
reset();
}
-void reset() {
-if (hnsw) {
-hnsw.reset();
-}
-if (config.metric == Metric::L2) {
-hnswlib::L2Space space(config.dim);
-hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>(&space, config.max_elements, config.M, config.ef_construction);
-} else if (config.metric == Metric::IP) {
-hnswlib::InnerProductSpace space(config.dim);
-hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>(&space, config.max_elements, config.M, config.ef_construction);
-} else {
-throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config.metric)));
-}
-}
+void reset() override;
-/**
-* @brief
-* @param vectors
-* @param payloads
-* @param ids ID
-* @return ID
-*/
-std::vector<size_t> insert(const std::vector<std::vector<float>>& vectors,
-const std::vector<std::string>& payloads = {},
-const std::vector<size_t>& ids = {}) {
-std::vector<size_t> result_ids;
-for (size_t i = 0; i < vectors.size(); i++) {
-size_t id = ids.size() > i ? ids[i] : hnsw->cur_element_count;
-hnsw->addPoint(vectors[i].data(), id);
-result_ids.push_back(id);
-}
-return result_ids;
-}
+void insert(const std::vector<float>& vector, const size_t vector_id, const json& metadata) override;
-/**
-* @brief
-* @param query
-* @param limit
-* @param filters
-* @return ID
-*/
-std::vector<std::pair<size_t, std::vector<float>>> search(const std::vector<float>& query,
-int limit = 5,
-const std::string& filters = "") {
-return hnsw->searchKnn(query.data(), limit);
-}
+std::vector<MemoryItem> search(const std::vector<float>& query, int limit, const std::string& filters) override;
-/**
-* @brief ID
-* @param vector_id ID
-*/
-void delete_vector(size_t vector_id) {
-hnsw->markDelete(vector_id);
-}
+void delete_vector(size_t vector_id) override;
-/**
-* @brief
-* @param vector_id ID
-* @param vector
-* @param payload
-*/
-void update(size_t vector_id,
-const std::vector<float>* vector = nullptr,
-const std::string* payload = nullptr) {
-if (vector) {
-hnsw->markDelete(vector_id);
-hnsw->addPoint(vector->data(), vector_id);
-}
-}
+void update(size_t vector_id, const std::vector<float> vector, const json& metadata) override;
-/**
-* @brief ID
-* @param vector_id ID
-* @return
-*/
-std::vector<float> get(size_t vector_id) {
-std::vector<float> result(config.dimension);
-hnsw->getDataByLabel(vector_id, result.data());
-return result;
-}
+MemoryItem get(size_t vector_id) override;
-/**
-* @brief
-* @param filters
-* @param limit
-* @return ID
-*/
-std::vector<size_t> list(const std::string& filters = "", int limit = 0) {
-std::vector<size_t> result;
-size_t count = hnsw->cur_element_count;
-for (size_t i = 0; i < count; i++) {
-if (!hnsw->isMarkedDeleted(i)) {
-result.push_back(i);
-if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
-break;
-}
-}
-}
-return result;
-}
+std::vector<MemoryItem> list(const std::string& filters, int limit) override;
};
}
-}
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H

View File

@ -124,7 +124,7 @@ Following is a conversation between the user and the assistant. You have to extr
You should detect the language of the user input and record the facts in the same language.
)";
-const char* DEFAULT_UPDATE_MEMORY_PROMPT = R"(You are a smart memory manager which controls the memory of a system.
+const char* UPDATE_MEMORY_PROMPT = R"(You are a smart memory manager which controls the memory of a system.
You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change.
Based on the above four operations, the memory will change.

View File

@ -26,13 +26,13 @@ extern const char* NEXT_STEP_PROMPT;
extern const char* TOOL_HINT_TEMPLATE;
} // namespace toolcall
-} // namespace prompt
namespace mem0 {
extern const char* FACT_EXTRACTION_PROMPT;
extern const char* UPDATE_MEMORY_PROMPT;
} // namespace mem0
+} // namespace prompt
} // namespace humanus
#endif // HUMANUS_PROMPT_H