diff --git a/CMakeLists.txt b/CMakeLists.txt index d1156c3..99bc6ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,6 +84,10 @@ file(GLOB FLOW_SOURCES file(GLOB MEMORY_SOURCES "memory/*.cpp" "memory/*.cc" + "memory/*/*.cpp" + "memory/*/*.cc" + "memory/*/*/*.cpp" + "memory/*/*/*.cc" ) # 创建humanus核心库,包含所有共享组件 diff --git a/config.cpp b/config.cpp index f70ff87..8fb092e 100644 --- a/config.cpp +++ b/config.cpp @@ -10,10 +10,10 @@ namespace humanus { Config* Config::_instance = nullptr; std::mutex Config::_mutex; -void Config::_load_initial_config() { +void Config::_load_initial_llm_config() { try { - auto config_path = _get_config_path(); - std::cout << "Loading config file from: " << config_path.string() << std::endl; + auto config_path = _get_llm_config_path(); + std::cout << "Loading LLM config file from: " << config_path.string() << std::endl; const auto& data = toml::parse_file(config_path.string()); @@ -35,8 +35,12 @@ void Config::_load_initial_config() { llm_config.base_url = llm_table["base_url"].as_string()->get(); } - if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) { - llm_config.end_point = llm_table["end_point"].as_string()->get(); + if (llm_table.contains("endpoint") && llm_table["endpoint"].is_string()) { + llm_config.endpoint = llm_table["endpoint"].as_string()->get(); + } + + if (llm_table.contains("vision_details") && llm_table["vision_details"].is_string()) { + llm_config.vision_details = llm_table["vision_details"].as_string()->get(); } if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) { @@ -51,6 +55,10 @@ void Config::_load_initial_config() { llm_config.temperature = llm_table["temperature"].as_floating_point()->get(); } + if (llm_table.contains("enable_vision") && llm_table["enable_vision"].is_boolean()) { + llm_config.enable_vision = llm_table["enable_vision"].as_boolean()->get(); + } + if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) { llm_config.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get(); } @@ -95,4 +103,119 @@ void Config::_load_initial_config() { } } +void Config::_load_initial_embedding_model_config() { + try { + auto config_path = _get_embedding_model_config_path(); + std::cout << "Loading embedding model config file from: " << config_path.string() << std::endl; + + const auto& data = toml::parse_file(config_path.string()); + + // Load embedding model configuration + for (const auto& [key, value] : data) { + const auto& embd_table = *value.as_table(); + + EmbeddingModelConfig embd_config; + + if (embd_table.contains("provider") && embd_table["provider"].is_string()) { + embd_config.provider = embd_table["provider"].as_string()->get(); + } + + if (embd_table.contains("base_url") && embd_table["base_url"].is_string()) { + embd_config.base_url = embd_table["base_url"].as_string()->get(); + } + + if (embd_table.contains("endpoint") && embd_table["endpoint"].is_string()) { + embd_config.endpoint = embd_table["endpoint"].as_string()->get(); + } + + if (embd_table.contains("model") && embd_table["model"].is_string()) { + embd_config.model = embd_table["model"].as_string()->get(); + } + + if (embd_table.contains("api_key") && embd_table["api_key"].is_string()) { + embd_config.api_key = embd_table["api_key"].as_string()->get(); + } + + if (embd_table.contains("embedding_dims") && embd_table["embedding_dims"].is_integer()) { + embd_config.embedding_dims = embd_table["embedding_dims"].as_integer()->get(); + } + + if (embd_table.contains("max_retries") && embd_table["max_retries"].is_integer()) { + embd_config.max_retries = embd_table["max_retries"].as_integer()->get(); + } + + _config.embedding_model[std::string(key.str())] = embd_config; + } + + if (_config.embedding_model.empty()) { + throw std::runtime_error("No embedding model configuration found"); + } else if (_config.embedding_model.find("default") == _config.embedding_model.end()) { + _config.embedding_model["default"] = _config.embedding_model.begin()->second; + } + } catch (const std::exception& e) { + std::cerr << "Loading embedding model config file failed: " << e.what() << std::endl; + // Set default configuration + _config.embedding_model["default"] = EmbeddingModelConfig(); + } +} + +void Config::_load_initial_vector_store_config() { + try { + auto config_path = _get_vector_store_config_path(); + std::cout << "Loading vector store config file from: " << config_path.string() << std::endl; + + const auto& data = toml::parse_file(config_path.string()); + + // Load vector store configuration + for (const auto& [key, value] : data) { + const auto& vs_table = *value.as_table(); + + VectorStoreConfig vs_config; + + if (vs_table.contains("provider") && vs_table["provider"].is_string()) { + vs_config.provider = vs_table["provider"].as_string()->get(); + } + + if (vs_table.contains("dim") && vs_table["dim"].is_integer()) { + vs_config.dim = vs_table["dim"].as_integer()->get(); + } + + if (vs_table.contains("max_elements") && vs_table["max_elements"].is_integer()) { + vs_config.max_elements = vs_table["max_elements"].as_integer()->get(); + } + + if (vs_table.contains("M") && vs_table["M"].is_integer()) { + vs_config.M = vs_table["M"].as_integer()->get(); + } + + if (vs_table.contains("ef_construction") && vs_table["ef_construction"].is_integer()) { + vs_config.ef_construction = vs_table["ef_construction"].as_integer()->get(); + } + + if (vs_table.contains("metric") && vs_table["metric"].is_string()) { + const auto& metric_str = vs_table["metric"].as_string()->get(); + if (metric_str == "L2") { + vs_config.metric = VectorStoreConfig::Metric::L2; + } else if (metric_str == "IP") { + vs_config.metric = VectorStoreConfig::Metric::IP; + } else { + throw std::runtime_error("Invalid metric: " + metric_str); + } + } + + _config.vector_store[std::string(key.str())] = vs_config; + } + + if (_config.vector_store.empty()) { + throw std::runtime_error("No vector store configuration found"); + } else if (_config.vector_store.find("default") == _config.vector_store.end()) { + _config.vector_store["default"] = _config.vector_store.begin()->second; + } + } catch (const std::exception& e) { + std::cerr << "Loading vector store config file failed: " << e.what() << std::endl; + // Set default configuration + _config.vector_store["default"] = VectorStoreConfig(); + } +} + } // namespace humanus \ No newline at end of file diff --git a/config.h b/config.h index 43f83f0..c174b7c 100644 --- a/config.h +++ b/config.h @@ -25,30 +25,34 @@ struct LLMConfig { std::string model; std::string api_key; std::string base_url; - std::string end_point; + std::string endpoint; + std::string vision_details; int max_tokens; int timeout; double temperature; + bool enable_vision; bool oai_tool_support; LLMConfig( const std::string& model = "deepseek-chat", const std::string& api_key = "sk-", const std::string& base_url = "https://api.deepseek.com", - const std::string& end_point = "/v1/chat/completions", + const std::string& endpoint = "/v1/chat/completions", + const std::string& vision_details = "auto", int max_tokens = 4096, int timeout = 120, double temperature = 1.0, + bool enable_vision = false, bool oai_tool_support = true - ) : model(model), api_key(api_key), base_url(base_url), end_point(end_point), - max_tokens(max_tokens), timeout(timeout), temperature(temperature), oai_tool_support(oai_tool_support) {} + ) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details), + max_tokens(max_tokens), timeout(timeout), temperature(temperature), enable_vision(enable_vision), oai_tool_support(oai_tool_support) {} json to_json() const { json j; j["model"] = model; j["api_key"] = api_key; j["base_url"] = base_url; - j["end_point"] = end_point; + j["endpoint"] = endpoint; j["max_tokens"] = max_tokens; j["temperature"] = temperature; return j; @@ -153,9 +157,76 @@ struct ToolParser { } }; +enum class EmbeddingType { + ADD = 0, + SEARCH = 1, + UPDATE = 2 +}; + +struct EmbeddingModelConfig { + std::string provider = "oai"; + std::string base_url = "http://localhost:8080"; + std::string endpoint = "/v1/embeddings"; + std::string model = "nomic-embed-text-v1.5.f16.gguf"; + std::string api_key = ""; + int embedding_dims = 768; + int max_retries = 3; +}; + +struct VectorStoreConfig { + std::string provider = "hnswlib"; + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + enum class Metric { + L2, + IP + }; + Metric metric = Metric::L2; +}; + +namespace mem0 { + +struct MemoryConfig { + // Base config + int max_messages = 5; // Short-term memory capacity + int limit = 5; // Number of results to retrive from long-term memory + std::string filters = ""; // Filters to apply to search results + + // Prompt config + std::string fact_extraction_prompt = prompt::mem0::FACT_EXTRACTION_PROMPT; + std::string update_memory_prompt = prompt::mem0::UPDATE_MEMORY_PROMPT; + + // Database config + // std::string history_db_path = ":memory:"; + + // EmbeddingModel config + std::shared_ptr embedding_model_config = nullptr; + + // Vector store config + std::shared_ptr vector_store_config = nullptr; + + // Optional: LLM config + std::shared_ptr llm_config = nullptr; +}; + +struct MemoryItem { + size_t id; // The unique identifier for the text data + std::string memory; // The memory deduced from the text data + std::string hash; // The hash of the memory + json metadata; // Any additional metadata associated with the memory, like 'created_at' or 'updated_at' + float score; // The score associated with the text data, used for ranking and sorting +}; + +} // namespace mem0 + struct AppConfig { - std::map llm; - std::map tool_parser; + std::unordered_map llm; + std::unordered_map tool_parser; + std::unordered_map embedding_model; + std::unordered_map vector_store; }; class Config { @@ -166,7 +237,9 @@ private: AppConfig _config; Config() { - _load_initial_config(); + _load_initial_llm_config(); + _load_initial_embedding_model_config(); + _load_initial_vector_store_config(); _initialized = true; } @@ -177,19 +250,47 @@ private: * @brief Get the config path * @return The config path */ - static std::filesystem::path _get_config_path() { + static std::filesystem::path _get_llm_config_path() { auto root = PROJECT_ROOT; auto config_path = root / "config" / "config_llm.toml"; if (std::filesystem::exists(config_path)) { return config_path; } - throw std::runtime_error("Config file not found"); + throw std::runtime_error("LLM Config file not found"); + } + + static std::filesystem::path _get_embedding_model_config_path() { + auto root = PROJECT_ROOT; + auto config_path = root / "config" / "config_embd.toml"; + if (std::filesystem::exists(config_path)) { + return config_path; + } + throw std::runtime_error("Embedding Model Config file not found"); + } + + static std::filesystem::path _get_vector_store_config_path() { + auto root = PROJECT_ROOT; + auto config_path = root / "config" / "config_vec.toml"; + if (std::filesystem::exists(config_path)) { + return config_path; + } + throw std::runtime_error("Vector Store Config file not found"); } /** * @brief Load the initial config */ - void _load_initial_config(); + void _load_initial_llm_config(); + + /** + * @brief Load the initial embedding model config + */ + void _load_initial_embedding_model_config(); + + /** + * @brief Load the initial vector store config + */ + void _load_initial_vector_store_config(); public: /** @@ -210,7 +311,7 @@ public: * @brief Get the LLM settings * @return The LLM settings map */ - const std::map& llm() const { + const std::unordered_map& llm() const { return _config.llm; } @@ -218,9 +319,25 @@ public: * @brief Get the tool helpers * @return The tool helpers map */ - const std::map& tool_parser() const { + const std::unordered_map& tool_parser() const { return _config.tool_parser; } + + /** + * @brief Get the embedding model settings + * @return The embedding model settings map + */ + const std::unordered_map& embedding_model() const { + return _config.embedding_model; + } + + /** + * @brief Get the vector store settings + * @return The vector store settings map + */ + const std::unordered_map& vector_store() const { + return _config.vector_store; + } /** * @brief Get the app config diff --git a/config/config_embd.toml b/config/config_embd.toml new file mode 100644 index 0000000..34f2d4e --- /dev/null +++ b/config/config_embd.toml @@ -0,0 +1,8 @@ +[default] +provider = "oai" +base_url = "http://localhost:8080" +endpoint = "/v1/embeddings" +model = "nomic-embed-text-v1.5.f16.gguf" +api_key = "" +embeddings_dim = 768 +max_retries = 3 \ No newline at end of file diff --git a/config/config_llm.toml b/config/config_llm.toml index 0a6b838..acc76dc 100644 --- a/config/config_llm.toml +++ b/config/config_llm.toml @@ -1,7 +1,7 @@ [default] model = "deepseek-reasoner" base_url = "https://api.deepseek.com" -end_point = "/v1/chat/completions" +endpoint = "/v1/chat/completions" api_key = "sk-93c5bfcb920c4a8aa345791d429b8536" max_tokens = 8192 oai_tool_support = false diff --git a/config/config_llm.toml.bak b/config/config_llm.toml.bak index 7e495c2..0a99bdf 100644 --- a/config/config_llm.toml.bak +++ b/config/config_llm.toml.bak @@ -1,28 +1,28 @@ [llm] model = "anthropic/claude-3.7-sonnet" base_url = "https://openrouter.ai" -end_point = "/api/v1/chat/completions" +endpoint = "/api/v1/chat/completions" api_key = "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" max_tokens = 8192 [llm] model = "deepseek-chat" base_url = "https://api.deepseek.com" -end_point = "/v1/chat/completions" +endpoint = "/v1/chat/completions" api_key = "sk-93c5bfcb920c4a8aa345791d429b8536" max_tokens = 8192 [llm] model = "qwen-max" base_url = "https://dashscope.aliyuncs.com" -end_point = "/compatible-mode/v1/chat/completions" +endpoint = "/compatible-mode/v1/chat/completions" api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600" max_tokens = 8192 [llm] model = "deepseek-reasoner" base_url = "https://api.deepseek.com" -end_point = "/v1/chat/completions" +endpoint = "/v1/chat/completions" api_key = "sk-93c5bfcb920c4a8aa345791d429b8536" max_tokens = 8192 oai_tool_support = false \ No newline at end of file diff --git a/config/config_vec.toml b/config/config_vec.toml new file mode 100644 index 0000000..e6ed389 --- /dev/null +++ b/config/config_vec.toml @@ -0,0 +1,8 @@ +[default] +provider = "hnswlib" +dim = 768 # Dimension of the elements +max_elements = 100 # Maximum number of elements, should be known beforehand +M = 16 # Tightly connected with internal dimensionality of the data + # strongly affects the memory consumption +ef_construction = 200 # Controls index search speed/build speed tradeoff +metric = "L2" # Distance metric to use, can be L2 or IP \ No newline at end of file diff --git a/examples/plan_mem0/CMakeLists.txt b/examples/plan_mem0/CMakeLists.txt new file mode 100644 index 0000000..52c8c15 --- /dev/null +++ b/examples/plan_mem0/CMakeLists.txt @@ -0,0 +1,12 @@ +set(target humanus_cli_plan_mem0) + +add_executable(${target} humanus_plan_mem0.cpp) + +# 链接到核心库 +target_link_libraries(${target} PRIVATE humanus) + +# 设置输出目录 +set_target_properties(${target} + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" +) \ No newline at end of file diff --git a/examples/plan_mem0/humanus_plan_mem0.cpp b/examples/plan_mem0/humanus_plan_mem0.cpp new file mode 100644 index 0000000..430da4b --- /dev/null +++ b/examples/plan_mem0/humanus_plan_mem0.cpp @@ -0,0 +1,145 @@ +#include "agent/humanus.h" +#include "logger.h" +#include "prompt.h" +#include "flow/flow_factory.h" +#include "memory/mem0/base.h" + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#endif + +using namespace humanus; + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) +static void sigint_handler(int signo) { + if (signo == SIGINT) { + logger->info("Interrupted by user\n"); + exit(0); + } +} +#endif + +static bool readline_utf8(std::string & line, bool multiline_input) { +#if defined(_WIN32) + std::wstring wline; + if (!std::getline(std::wcin, wline)) { + // Input stream is bad or EOF received + line.clear(); + GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0); + return false; + } + + int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL); + line.resize(size_needed); + WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL); +#else + if (!std::getline(std::cin, line)) { + // Input stream is bad or EOF received + line.clear(); + return false; + } +#endif + if (!line.empty()) { + char last = line.back(); + if (last == '/') { // Always return control on '/' symbol + line.pop_back(); + return false; + } + if (last == '\\') { // '\\' changes the default action + line.pop_back(); + multiline_input = !multiline_input; + } + } + line += '\n'; + + // By default, continue input if multiline_input is set + return multiline_input; +} + +int main() { + + // ctrl+C handling + { +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); + SetConsoleCP(CP_UTF8); + SetConsoleOutputCP(CP_UTF8); + _setmode(_fileno(stdin), _O_WTEXT); // wide character input mode +#endif + } + + std::shared_ptr agent_ptr = std::make_shared( + ToolCollection( // Add general-purpose tools to the tool collection + { + std::make_shared(), + std::make_shared(), // for web browsing + std::make_shared(), + std::make_shared() + } + ), + "auto", + std::set{"terminate"}, + "humanus_mem0", + "A versatile agent that can solve various tasks using multiple tools", + prompt::humanus::SYSTEM_PROMPT, + prompt::humanus::NEXT_STEP_PROMPT, + nullptr, + std::make_shared(mem0::MemoryConfig()) + ); + + std::map> agents; + agents["default"] = agent_ptr; + + auto flow = FlowFactory::create_flow( + FlowType::PLANNING, + nullptr, // llm + nullptr, // planning_tool + std::vector{}, // executor_keys + "", // active_plan_id + agents, // agents + std::vector>{}, // tools + "default" // primary_agent_key + ); + + while (true) { + if (agent_ptr->current_step == agent_ptr->max_steps) { + std::cout << "Automatically paused after " << agent_ptr->current_step << " steps." << std::endl; + std::cout << "Enter your prompt (enter an empty line to resume or 'exit' to quit): "; + agent_ptr->reset(false); + } else { + std::cout << "Enter your prompt (or 'exit' to quit): "; + } + + if (agent_ptr->state != AgentState::IDLE) { + break; + } + + std::string prompt; + readline_utf8(prompt, false); + if (prompt == "exit" || prompt == "exit\n") { + logger->info("Goodbye!"); + break; + } + + std::cout << "Processing your request..." << std::endl; + auto result = flow->execute(prompt); + std::cout << result << std::endl; + } +} \ No newline at end of file diff --git a/llm.cpp b/llm.cpp index df423af..34ff12a 100644 --- a/llm.cpp +++ b/llm.cpp @@ -2,5 +2,5 @@ namespace humanus { // 定义静态成员变量 - std::map> LLM::_instances; + std::unordered_map> LLM::instances_; } \ No newline at end of file diff --git a/llm.h b/llm.h index db6e955..3db3f56 100644 --- a/llm.h +++ b/llm.h @@ -17,7 +17,7 @@ namespace humanus { class LLM { private: - static std::map> _instances; + static std::unordered_map> instances_; std::unique_ptr client_; @@ -27,13 +27,7 @@ private: public: // Constructor - LLM(const std::string& config_name, const std::shared_ptr& llm_config = nullptr, const std::shared_ptr& tool_parser = nullptr) : llm_config_(llm_config), tool_parser_(tool_parser) { - if (!llm_config_) { - if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) { - throw std::invalid_argument("LLM config not found: " + config_name); - } - llm_config_ = std::make_shared(Config::get_instance().llm().at(config_name)); - } + LLM(const std::string& config_name, const std::shared_ptr& config = nullptr, const std::shared_ptr& tool_parser = nullptr) : llm_config_(config), tool_parser_(tool_parser) { if (!llm_config_->oai_tool_support && !tool_parser_) { if (Config::get_instance().tool_parser().find(config_name) == Config::get_instance().tool_parser().end()) { throw std::invalid_argument("Tool helper config not found: " + config_name); @@ -46,13 +40,20 @@ public: }); client_->set_read_timeout(llm_config_->timeout); } - + // Get the singleton instance static std::shared_ptr get_instance(const std::string& config_name = "default", const std::shared_ptr& llm_config = nullptr) { - if (_instances.find(config_name) == _instances.end()) { - _instances[config_name] = std::make_shared(config_name, llm_config); + if (instances_.find(config_name) == instances_.end()) { + auto llm_config_ = llm_config; + if (!llm_config_) { + if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) { + throw std::invalid_argument("LLM config not found: " + config_name); + } + llm_config_ = std::make_shared(Config::get_instance().llm().at(config_name)); + } + instances_[config_name] = std::make_shared(config_name, llm_config_); } - return _instances[config_name]; + return instances_[config_name]; } /** @@ -194,7 +195,7 @@ public: while (retry <= max_retries) { // send request - auto res = client_->Post(llm_config_->end_point, body_str, "application/json"); + auto res = client_->Post(llm_config_->endpoint, body_str, "application/json"); if (!res) { logger->error("Failed to send request: " + httplib::to_string(res.error())); @@ -325,7 +326,7 @@ public: while (retry <= max_retries) { // send request - auto res = client_->Post(llm_config_->end_point, body_str, "application/json"); + auto res = client_->Post(llm_config_->endpoint, body_str, "application/json"); if (!res) { logger->error("Failed to send request: " + httplib::to_string(res.error())); diff --git a/memory/base.h b/memory/base.h index f11d72a..88f373e 100644 --- a/memory/base.h +++ b/memory/base.h @@ -25,7 +25,7 @@ struct BaseMemory { messages.clear(); } - virtual std::vector get_messages() const { + virtual std::vector get_messages(const std::string& query = "") const { return messages; } diff --git a/memory/mem0/base.h b/memory/mem0/base.h new file mode 100644 index 0000000..835f390 --- /dev/null +++ b/memory/mem0/base.h @@ -0,0 +1,280 @@ +#ifndef HUMANUS_MEMORY_MEM0_H +#define HUMANUS_MEMORY_MEM0_H + +#include "memory/base.h" +#include "vector_store/base.h" +#include "embedding_model/base.h" +#include "schema.h" +#include "prompt.h" +#include "httplib.h" +#include "llm.h" +#include "utils.h" + +namespace humanus::mem0 { + +struct Memory : BaseMemory { + MemoryConfig config; + + std::string fact_extraction_prompt; + std::string update_memory_prompt; + int max_messages; + int limit; + std::string filters; + + std::shared_ptr embedding_model; + std::shared_ptr vector_store; + std::shared_ptr llm; + // std::shared_ptr db; + + Memory(const MemoryConfig& config) : config(config) { + fact_extraction_prompt = config.fact_extraction_prompt; + update_memory_prompt = config.update_memory_prompt; + max_messages = config.max_messages; + limit = config.limit; + filters = config.filters; + + embedding_model = EmbeddingModel::get_instance("mem0_" + std::to_string(reinterpret_cast(this)), config.embedding_model_config); + vector_store = VectorStore::get_instance("mem0_" + std::to_string(reinterpret_cast(this)), config.vector_store_config); + + llm = LLM::get_instance("mem0_" + std::to_string(reinterpret_cast(this)), config.llm_config); + // db = std::make_shared(config.history_db_path); + } + + void add_message(const Message& message) override { + while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) { + // Ensure the first message is always a user or system message + Message front_message = *messages.begin(); + messages.erase(messages.begin()); + + if (config.llm_config->enable_vision) { + front_message = parse_vision_message(front_message, llm, config.llm_config->vision_details); + } else { + front_message = parse_vision_message(front_message); + } + + _add_to_vector_store(front_message); + } + + messages.push_back(message); + } + + std::vector get_messages(const std::string& query = "") const override { + auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH); + std::vector memories; + + // 检查vector_store是否已初始化 + if (vector_store) { + memories = vector_store->search(embeddings, limit, filters); + } + + std::string memory_prompt; + for (const auto& memory_item : memories) { + memory_prompt += "" + memory_item.memory + ""; + } + + std::vector messages_with_memory{Message::user_message(memory_prompt)}; + + messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end()); + + return messages_with_memory; + } + + void _add_to_vector_store(const Message& message) { + // 检查vector_store是否已初始化 + if (!vector_store) { + logger->warn("Vector store is not initialized, skipping memory operation"); + return; + } + + std::string parsed_message = message.role + ": " + (message.content.is_string() ? message.content.get() : message.content.dump()); + + for (const auto& tool_call : message.tool_calls) { + parsed_message += "" + tool_call.to_json().dump() + ""; + } + + std::string system_prompt = fact_extraction_prompt; + std::string user_prompt = "Input:\n" + parsed_message; + + Message user_message = Message::user_message(user_prompt); + + std::string response = llm->ask( + {user_message}, + system_prompt + ); + + json new_retrieved_facts; // ["fact1", "fact2", "fact3"] + + try { + // response = remove_code_blocks(response); + new_retrieved_facts = json::parse(response)["facts"]; + } catch (const std::exception& e) { + logger->error("Error in new_retrieved_facts: " + std::string(e.what())); + } + + std::vector retrieved_old_memory; + std::map> new_message_embeddings; + + for (const auto& fact : new_retrieved_facts) { + auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD); + new_message_embeddings[fact] = message_embedding; + auto existing_memories = vector_store->search( + message_embedding, + 5 + ); + for (const auto& memory : existing_memories) { + retrieved_old_memory.push_back({ + {"id", memory.id}, + {"text", memory.metadata["data"]} + }); + } + } + // sort and unique by id + std::sort(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) { + return a["id"] < b["id"]; + }); + retrieved_old_memory.resize(std::unique(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) { + return a["id"] == b["id"]; + }) - retrieved_old_memory.begin()); + logger->info("Total existing memories: " + std::to_string(retrieved_old_memory.size())); + + // mapping UUIDs with integers for handling UUID hallucinations + std::vector temp_uuid_mapping; + for (size_t idx = 0; idx < retrieved_old_memory.size(); ++idx) { + temp_uuid_mapping.push_back(retrieved_old_memory[idx]["id"].get()); + retrieved_old_memory[idx]["id"] = idx; + } + + std::string function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, fact_extraction_prompt, update_memory_prompt); + + std::string new_memories_with_actions_str; + json new_memories_with_actions = json::array(); + + try { + new_memories_with_actions_str = llm->ask( + {Message::user_message(function_calling_prompt)} + ); + new_memories_with_actions = json::parse(new_memories_with_actions_str); + } catch (const std::exception& e) { + logger->error("Error in new_memories_with_actions: " + std::string(e.what())); + } + + try { + // new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str); + new_memories_with_actions = json::parse(new_memories_with_actions_str); + } catch (const std::exception& e) { + logger->error("Invalid JSON response: " + std::string(e.what())); + } + + try { + for (const auto& resp : new_memories_with_actions.value("memory", json::array())) { + logger->info("Processing memory: " + resp.dump(2)); + try { + if (!resp.contains("text")) { + logger->info("Skipping memory entry because of empty `text` field."); + continue; + } + std::string event = resp.value("event", "NONE"); + size_t memory_id = resp.contains("id") ? temp_uuid_mapping[resp["id"].get()] : uuid(); + if (event == "ADD") { + _create_memory( + memory_id, + resp["text"], // data + new_message_embeddings // existing_embeddings + ); + } else if (event == "UPDATE") { + _update_memory( + memory_id, + resp["text"], // data + new_message_embeddings // existing_embeddings + ); + } else if (event == "DELETE") { + _delete_memory(memory_id); + } else if (event == "NONE") { + logger->info("NOOP for Memory."); + } + } catch (const std::exception& e) { + logger->error("Error in new_memories_with_actions: " + std::string(e.what())); + } + } + } catch (const std::exception& e) { + logger->error("Error in new_memories_with_actions: " + std::string(e.what())); + } + } + + void _create_memory(const size_t& memory_id, const std::string& data, const std::map>& existing_embeddings) { + if (!vector_store) { + logger->warn("Vector store is not initialized, skipping create memory"); + return; + } + + std::vector embedding; + if (existing_embeddings.find(data) != existing_embeddings.end()) { + embedding = existing_embeddings.at(data); + } else { + embedding = embedding_model->embed(data, EmbeddingType::ADD); + } + + auto created_at = std::chrono::system_clock::now(); + json metadata = { + {"data", data}, + {"hash", httplib::detail::MD5(data)}, + {"created_at", std::chrono::system_clock::now().time_since_epoch().count()} + }; + + vector_store->insert( + embedding, + memory_id, + metadata + ); + } + + void _update_memory(const size_t& memory_id, const std::string& data, const std::map>& existing_embeddings) { + if (!vector_store) { + logger->warn("Vector store is not initialized, skipping update memory"); + return; + } + + logger->info("Updating memory with " + data); + + MemoryItem existing_memory; + + try { + existing_memory = vector_store->get(memory_id); + } catch (const std::exception& e) { + logger->error("Error fetching existing memory: " + std::string(e.what())); + return; + } + + std::vector embedding; + if (existing_embeddings.find(data) != existing_embeddings.end()) { + embedding = existing_embeddings.at(data); + } else { + embedding = embedding_model->embed(data, EmbeddingType::ADD); + } + + json metadata = existing_memory.metadata; + metadata["data"] = data; + metadata["hash"] = httplib::detail::MD5(data); + metadata["updated_at"] = std::chrono::system_clock::now().time_since_epoch().count(); + + vector_store->update( + memory_id, + embedding, + metadata + ); + } + + void _delete_memory(const size_t& memory_id) { + if (!vector_store) { + logger->warn("Vector store is not initialized, skipping delete memory"); + return; + } + + logger->info("Deleting memory: " + std::to_string(memory_id)); + vector_store->delete_vector(memory_id); + } +}; + +} // namespace humanus::mem0 + +#endif // HUMANUS_MEMORY_MEM0_H \ No newline at end of file diff --git a/memory/mem0/embedding_model/base.cpp b/memory/mem0/embedding_model/base.cpp new file mode 100644 index 0000000..818d5ad --- /dev/null +++ b/memory/mem0/embedding_model/base.cpp @@ -0,0 +1,27 @@ +#include "base.h" +#include "oai.h" + +namespace humanus::mem0 { + +std::unordered_map> EmbeddingModel::instances_; + +std::shared_ptr EmbeddingModel::get_instance(const std::string& config_name, const std::shared_ptr& config) { + if (instances_.find(config_name) == instances_.end()) { + auto config_ = config; + if (!config_) { + if (Config::get_instance().embedding_model().find(config_name) == Config::get_instance().embedding_model().end()) { + throw std::invalid_argument("Embedding model config not found: " + config_name); + } + config_ = std::make_shared(Config::get_instance().embedding_model().at(config_name)); + } + + if (config_->provider == "oai") { + instances_[config_name] = std::make_shared(config_); + } else { + throw std::invalid_argument("Unsupported embedding model provider: " + config_->provider); + } + } + return instances_[config_name]; +} + +} // namespace humanus::mem0 \ No newline at end of file diff --git a/memory/mem0/embedding_model/base.h b/memory/mem0/embedding_model/base.h new file mode 100644 index 0000000..2f464f7 --- /dev/null +++ b/memory/mem0/embedding_model/base.h @@ -0,0 +1,33 @@ +#ifndef HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H +#define HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H + +#include "httplib.h" +#include "logger.h" +#include +#include +#include + +namespace humanus::mem0 { + +class EmbeddingModel { +private: + static std::unordered_map> instances_; + +protected: + std::shared_ptr config_; + + // Constructor + EmbeddingModel(const std::shared_ptr& config) : config_(config) {} + +public: + // Get the singleton instance + static std::shared_ptr get_instance(const std::string& config_name = "default", const std::shared_ptr& config = nullptr); + + virtual ~EmbeddingModel() = default; + + virtual std::vector embed(const std::string& text, EmbeddingType type) = 0; +}; + +} // namespace humanus::mem0 + +#endif // HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_BASE_H \ No newline at end of file diff --git a/memory/mem0/embedding_model/oai.cpp b/memory/mem0/embedding_model/oai.cpp new file mode 100644 index 0000000..da5261a --- /dev/null +++ b/memory/mem0/embedding_model/oai.cpp @@ -0,0 +1,48 @@ +#include "oai.h" + +namespace humanus::mem0 { + +std::vector OAIEmbeddingModel::embed(const std::string& text, EmbeddingType /* type */) { + json body = { + {"model", config_->model}, + {"input", text}, + {"encoding_format", "float"} + }; + + std::string body_str = body.dump(); + + int retry = 0; + + while (retry <= config_->max_retries) { + // send request + auto res = client_->Post(config_->endpoint, body_str, "application/json"); + + if (!res) { + logger->error("Failed to send request: " + httplib::to_string(res.error())); + } else if (res->status == 200) { + try { + json json_data = json::parse(res->body); + return json_data["data"][0]["embedding"].get>(); + } catch (const std::exception& e) { + logger->error("Failed to parse response: " + std::string(e.what())); + } + } else { + logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body); + } + + retry++; + + if (retry > config_->max_retries) { + break; + } + + // wait for a while before retrying + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(config_->max_retries)); + } + + throw std::runtime_error("Failed to get embedding from: " + config_->base_url + " " + config_->model); +} + +} // namespace humanus::mem0 \ No newline at end of file diff --git a/memory/mem0/embedding_model/oai.h b/memory/mem0/embedding_model/oai.h new file mode 100644 index 0000000..71fee6d --- /dev/null +++ b/memory/mem0/embedding_model/oai.h @@ -0,0 +1,25 @@ +#ifndef HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H +#define HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H + +#include "base.h" + +namespace humanus::mem0 { + +class OAIEmbeddingModel : public EmbeddingModel { +private: + std::unique_ptr client_; + +public: + OAIEmbeddingModel(const std::shared_ptr& config) : EmbeddingModel(config) { + client_ = std::make_unique(config_->base_url); + client_->set_default_headers({ + {"Authorization", "Bearer " + config_->api_key} + }); + } + + std::vector embed(const std::string& text, EmbeddingType type) override; +}; + +} // namespace humanus::mem0 + +#endif // HUMANUS_MEMORY_MEM0_EMBEDDING_MODEL_OAI_H \ No newline at end of file diff --git a/memory/mem0/mem0.h b/memory/mem0/mem0.h deleted file mode 100644 index 1f39c03..0000000 --- a/memory/mem0/mem0.h +++ /dev/null @@ -1,116 +0,0 @@ -#ifndef HUMANUS_MEMORY_MEM0_H -#define HUMANUS_MEMORY_MEM0_H - -#include "memory/base.h" -#include "storage.h" -#include "vector_store.h" -#include "prompt.h" - -namespace humanus { - -namespace mem0 { - -struct Config { - // Prompt config - std::string fact_extraction_prompt; - std::string update_memory_prompt; - - // Database config - // std::string history_db_path = ":memory:"; - - // Embedder config - EmbedderConfig embedder_config; - - // Vector store config - VectorStoreConfig vector_store_config; - - // Optional: LLM config - LLMConfig llm_config; -}; - -struct Memory : BaseMemory { - Config config; - std::string fact_extraction_prompt; - std::string update_memory_prompt; - - std::shared_ptr embedder; - std::shared_ptr vector_store; - std::shared_ptr llm; - // std::shared_ptr db; - - Memory(const Config& config) : config(config) { - fact_extraction_prompt = config.fact_extraction_prompt; - update_memory_prompt = config.update_memory_prompt; - - embedder = std::make_shared(config.embedder_config); - vector_store = std::make_shared(config.vector_store_config); - llm = std::make_shared(config.llm_config); - // db = std::make_shared(config.history_db_path); - } - - void add_message(const Message& message) override { - if (config.llm_config.enable_vision) { - message = parse_vision_messages(message, llm, config.llm_config.vision_details); - } else { - message = parse_vision_messages(message); - } - - _add_to_vector_store(message); - } - - void _add_to_vector_store(const Message& message) { - std::string parsed_message = parse_message(message); - - std::string system_prompt; - std::string user_prompt = "Input:\n" + parsed_message; - - if (!fact_extraction_prompt.empty()) { - system_prompt = fact_extraction_prompt; - } else { - system_prompt = FACT_EXTRACTION_PROMPT; - } - - Message user_message = Message::user_message(user_prompt); - - std::string response = llm->ask( - {user_message}, - system_prompt - ); - - std::vector new_retrieved_facts; - - try { - response = remove_code_blocks(response); - new_retrieved_facts = json::parse(response)["facts"].get>(); - } catch (const std::exception& e) { - LOG_ERROR("Error in new_retrieved_facts: " + std::string(e.what())); - } - - std::vector retrieved_old_memory; - std::map> new_message_embeddings; - - for (const auto& fact : new_retrieved_facts) { - auto message_embedding = embedder->embed(fact); - new_message_embeddings[fact] = message_embedding; - auto existing_memories = vector_store->search( - message_embedding, - 5, - filters - ) - for (const auto& memory : existing_memories) { - retrieved_old_memory.push_back({ - {"id", memory.id}, - {"text", memory.payload["data"]} - }); - } - } - - - } -}; - -} // namespace mem0 - -} // namespace humanus - -#endif // HUMANUS_MEMORY_MEM0_H \ No newline at end of file diff --git a/memory/mem0/storage.h b/memory/mem0/storage.h index 94e44ab..b302d6e 100644 --- a/memory/mem0/storage.h +++ b/memory/mem0/storage.h @@ -2,10 +2,9 @@ #define HUMANUS_MEMORY_MEM0_STORAGE_H #include +#include -namespace humanus { - -namespace mem0 { +namespace humanus::mem0 { struct SQLiteManager { std::shared_ptr db; @@ -17,7 +16,7 @@ struct SQLiteManager { throw std::runtime_error("Failed to open database: " + std::string(sqlite3_errmsg(db))); } _migrate_history_table(); - _create_history_table() + _create_history_table(); } void _migrate_history_table() { @@ -142,8 +141,6 @@ struct SQLiteManager { } }; -} // namespace mem0 - -} // namespace humanus +} // namespace humanus::mem0 #endif // HUMANUS_MEMORY_MEM0_STORAGE_H diff --git a/memory/mem0/utils.h b/memory/mem0/utils.h new file mode 100644 index 0000000..e1308d0 --- /dev/null +++ b/memory/mem0/utils.h @@ -0,0 +1,110 @@ +#ifndef HUMANUS_MEMORY_MEM0_UTILS_H +#define HUMANUS_MEMORY_MEM0_UTILS_H + +#include "schema.h" +#include "llm.h" + +namespace humanus::mem0 { + +static size_t uuid() { + const std::string chars = "0123456789abcdef"; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, chars.size() - 1); + + unsigned long long int uuid_int = 0; + for (int i = 0; i < 16; ++i) { + uuid_int = (uuid_int << 4) | dis(gen); + } + + // RFC 4122 variant + uuid_int &= ~(0xc000ULL << 48); + uuid_int |= 0x8000ULL << 48; + + // version 4, random UUID + int version = 4; + uuid_int &= ~(0xfULL << 12); + uuid_int |= static_cast(version) << 12; + + return uuid_int; +} + +std::string get_update_memory_messages(const json& retrieved_old_memory, const json& new_retrieved_facts, const std::string fact_extraction_prompt, const std::string& update_memory_prompt) { + std::stringstream ss; + ss << fact_extraction_prompt << "\n\n"; + ss << "Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n"; + ss << "```" + retrieved_old_memory.dump(2) + "```\n\n"; + ss << "The new retrieved facts are mentioned in the triple backticks. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n"; + ss << "```" + new_retrieved_facts.dump(2) + "```\n\n"; + ss << "You must return your response in the following JSON structure only:\n\n"; + ss << R"json({ +"memory" : [ +{ + "id" : "", # Use existing ID for updates/deletes, or new ID for additions + "text" : "", # Content of the memory + "event" : "", # Must be "ADD", "UPDATE", "DELETE", or "NONE" + "old_memory" : "" # Required only if the event is "UPDATE" +}, +... +] + })json" << "\n\n"; + ss << "Follow the instruction mentioned below:\n" + << "- Do not return anything from the custom few shot prompts provided above.\n" + << "- If the current memory is empty, then you have to add the new retrieved facts to the memory.\n" + << "- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.\n" + << "- If there is an addition, generate a new key and add the new memory corresponding to it.\n" + << "- If there is a deletion, the memory key-value pair should be removed from the memory.\n" + << "- If there is an update, the ID key should remain the same and only the value needs to be updated.\n" + << "\n"; + ss << "Do not return anything except the JSON format.\n"; + return ss.str(); +} + +// Get the description of the image +// image_url should be like: data:{mime_type};base64,{base64_data} +std::string get_image_description(const std::string& image_url, const std::shared_ptr& llm, const std::string& vision_details) { + if (!llm) { + return "Here is an image failed to get description due to missing LLM instance."; + } + + json content = json::array({ + { + {"type", "text"}, + {"text", "A user is providing an image. Provide a high level description of the image and do not include any additional text."} + }, + { + {"type", "image_url"}, + {"image_url", { + {"url", image_url}, + {"detail", vision_details} + }} + } + }); + return llm->ask( + {Message::user_message(content)} + ); +} + +// Parse the vision messages from the messages +Message parse_vision_message(const Message& message, const std::shared_ptr& llm = nullptr, const std::string& vision_details = "auto") { + Message returned_message = message; + + if (returned_message.content.is_array()) { + // Multiple image URLs in content + for (auto& content_item : returned_message.content) { + if (content_item["type"] == "image_url") { + auto description = get_image_description(content_item["image_url"]["url"], llm, vision_details); + content_item = description; + } + } + } else if (returned_message.content.is_object() && returned_message.content["type"] == "image_url") { + auto image_url = returned_message.content["image_url"]["url"]; + returned_message.content = get_image_description(image_url, llm, vision_details); + } + + return returned_message; +} + +} + +#endif // HUMANUS_MEMORY_MEM0_UTILS_H \ No newline at end of file diff --git a/memory/mem0/vector_store/base.cpp b/memory/mem0/vector_store/base.cpp new file mode 100644 index 0000000..e936220 --- /dev/null +++ b/memory/mem0/vector_store/base.cpp @@ -0,0 +1,27 @@ +#include "base.h" +#include "hnswlib.h" + +namespace humanus::mem0 { + +std::unordered_map> VectorStore::instances_; + +std::shared_ptr VectorStore::get_instance(const std::string& config_name, const std::shared_ptr& config) { + if (instances_.find(config_name) == instances_.end()) { + auto config_ = config; + if (!config_) { + if (Config::get_instance().vector_store().find(config_name) == Config::get_instance().vector_store().end()) { + throw std::invalid_argument("Vector store config not found: " + config_name); + } + config_ = std::make_shared(Config::get_instance().vector_store().at(config_name)); + } + + if (config_->provider == "hnswlib") { + instances_[config_name] = std::make_shared(config_); + } else { + throw std::invalid_argument("Unsupported embedding model provider: " + config_->provider); + } + } + return instances_[config_name]; +} + +} // namespace humanus::mem0 \ No newline at end of file diff --git a/memory/mem0/vector_store/base.h b/memory/mem0/vector_store/base.h index f9b7e24..0f5e916 100644 --- a/memory/mem0/vector_store/base.h +++ b/memory/mem0/vector_store/base.h @@ -1,55 +1,51 @@ #ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H #define HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H -#include "hnswlib/hnswlib.h" +#include "config.h" +#include +#include +#include -namespace humanus { +namespace humanus::mem0 { -namespace mem0 { +class VectorStore { +private: + static std::unordered_map> instances_; -struct VectorStoreConfig { - int dim = 16; // Dimension of the elements - int max_elements = 10000; // Maximum number of elements, should be known beforehand - int M = 16; // Tightly connected with internal dimensionality of the data - // strongly affects the memory consumption - int ef_construction = 200; // Controls index search speed/build speed tradeoff - enum class Metric { - L2, - IP - }; - Metric metric = Metric::L2; -}; +protected: + std::shared_ptr config_; -struct VectorStoreBase { - VectorStoreConfig config; + // Constructor + VectorStore(const std::shared_ptr& config) : config_(config) {} - VectorStoreBase(const VectorStoreConfig& config) : config(config) { - reset(); - } +public: + // Get the singleton instance + static std::shared_ptr get_instance(const std::string& config_name = "default", const std::shared_ptr& config = nullptr); - virtual void reset() = 0; + virtual ~VectorStore() = default; + + virtual void reset() = 0; /** * @brief 插入向量到集合中 - * @param vectors 向量数据 - * @param payloads 可选的负载数据 - * @param ids 可选的ID列表 - * @return 插入的向量ID列表 + * @param vector 向量数据 + * @param vector_id 向量ID + * @param metadata 可选的元数据 */ - virtual std::vector insert(const std::vector>& vectors, - const std::vector& payloads = {}, - const std::vector& ids = {}) = 0; + virtual void insert(const std::vector& vector, + const size_t vector_id, + const json& metadata = json::object()) = 0; /** * @brief 搜索相似向量 * @param query 查询向量 * @param limit 返回结果数量限制 * @param filters 可选的过滤条件 - * @return 相似向量的ID和距离 + * @return 相似向量的MemoryItem列表 */ - std::vector>> search(const std::vector& query, - int limit = 5, - const std::string& filters = "") = 0; + virtual std::vector search(const std::vector& query, + int limit = 5, + const std::string& filters = "") = 0; /** * @brief 通过ID删除向量 @@ -61,18 +57,18 @@ struct VectorStoreBase { * @brief 更新向量及其负载 * @param vector_id 向量ID * @param vector 可选的新向量数据 - * @param payload 可选的新负载数据 + * @param metadata 可选的新负载数据 */ virtual void update(size_t vector_id, - const std::vector* vector = nullptr, - const std::string* payload = nullptr) = 0; + const std::vector vector = std::vector(), + const json& metadata = json::object()) = 0; /** * @brief 通过ID获取向量 * @param vector_id 向量ID * @return 向量数据 */ - virtual std::vector get(size_t vector_id) = 0; + virtual MemoryItem get(size_t vector_id) = 0; /** * @brief 列出所有记忆 @@ -80,12 +76,10 @@ struct VectorStoreBase { * @param limit 可选的结果数量限制 * @return 记忆ID列表 */ - virtual std::vector list(const std::string& filters = "", int limit = 0) = 0; + virtual std::vector list(const std::string& filters = "", int limit = 0) = 0; }; -} - -} +} // namespace humanus::mem0 #endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_BASE_H diff --git a/memory/mem0/vector_store/hnswlib.cpp b/memory/mem0/vector_store/hnswlib.cpp new file mode 100644 index 0000000..f15cda3 --- /dev/null +++ b/memory/mem0/vector_store/hnswlib.cpp @@ -0,0 +1,165 @@ +#include "hnswlib/hnswlib.h" +#include "hnswlib.h" +#include +#include + +namespace humanus::mem0 { + +void HNSWLibVectorStore::reset() { + if (hnsw) { + hnsw.reset(); + } + if (space) { + space.reset(); + } + + metadata_store.clear(); + + if (config_->metric == VectorStoreConfig::Metric::L2) { + space = std::make_shared(config_->dim); + hnsw = std::make_shared>(space.get(), config_->max_elements, config_->M, config_->ef_construction); + } else if (config_->metric == VectorStoreConfig::Metric::IP) { + space = std::make_shared(config_->dim); + hnsw = std::make_shared>(space.get(), config_->max_elements, config_->M, config_->ef_construction); + } else { + throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast(config_->metric))); + } +} + +void HNSWLibVectorStore::insert(const std::vector& vector, + const size_t vector_id, + const json& metadata) { + hnsw->addPoint(vector.data(), vector_id); + + // 存储元数据 + auto now = std::chrono::system_clock::now().time_since_epoch().count(); + json _metadata = metadata; + if (!_metadata.contains("created_at")) { + _metadata["created_at"] = now; + } + if (!_metadata.contains("updated_at")) { + _metadata["updated_at"] = now; + } + + metadata_store[vector_id] = _metadata; +} + +std::vector HNSWLibVectorStore::search(const std::vector& query, + int limit = 5, + const std::string& filters = "") { + auto results = hnsw->searchKnn(query.data(), limit); + std::vector memory_items; + + while (!results.empty()) { + const auto& [id, distance] = results.top(); + + results.pop(); + + if (metadata_store.find(id) != metadata_store.end()) { + MemoryItem item; + item.id = id; + + if (metadata_store[id].contains("data")) { + item.memory = metadata_store[id]["data"]; + } + + if (metadata_store[id].contains("hash")) { + item.hash = metadata_store[id]["hash"]; + } + + item.metadata = metadata_store[id]; + item.score = distance; + + memory_items.push_back(item); + } + } + + return memory_items; +} + +void HNSWLibVectorStore::delete_vector(size_t vector_id) { + hnsw->markDelete(vector_id); + metadata_store.erase(vector_id); +} + +void HNSWLibVectorStore::update(size_t vector_id, + const std::vector vector = std::vector(), + const json& metadata = json::object()) { + // 检查向量是否需要更新 + if (!vector.empty()) { + hnsw->markDelete(vector_id); + hnsw->addPoint(vector.data(), vector_id); + } + + // 更新元数据 + if (metadata_store.find(vector_id) != metadata_store.end()) { + auto now = std::chrono::system_clock::now().time_since_epoch().count(); + + // 合并现有元数据和新元数据 + for (auto& [key, value] : metadata.items()) { + metadata_store[vector_id][key] = value; + } + + // 更新时间戳 + metadata_store[vector_id]["updated_at"] = now; + } else if (!metadata.empty()) { + // 如果元数据不存在但提供了新的元数据,则创建新条目 + auto now = std::chrono::system_clock::now().time_since_epoch().count(); + json new_metadata = metadata; + if (!new_metadata.contains("created_at")) { + new_metadata["created_at"] = now; + } + new_metadata["updated_at"] = now; + metadata_store[vector_id] = new_metadata; + } +} + +MemoryItem HNSWLibVectorStore::get(size_t vector_id) { + MemoryItem item; + item.id = vector_id; + + // 获取向量数据 + std::vector vector_data = hnsw->getDataByLabel(vector_id); + + // 获取元数据 + if (metadata_store.find(vector_id) != metadata_store.end()) { + if (metadata_store[vector_id].contains("data")) { + item.memory = metadata_store[vector_id]["data"]; + } + + if (metadata_store[vector_id].contains("hash")) { + item.hash = metadata_store[vector_id]["hash"]; + } + + item.metadata = metadata_store[vector_id]; + } + + return item; +} + +std::vector HNSWLibVectorStore::list(const std::string& filters = "", int limit = 0) { + std::vector result; + size_t count = hnsw->cur_element_count; + + for (size_t i = 0; i < count; i++) { + if (!hnsw->isMarkedDeleted(i)) { + // 如果有过滤条件,检查元数据是否匹配 + if (!filters.empty() && metadata_store.find(i) != metadata_store.end()) { + // 简单的字符串匹配过滤,可以根据需要扩展 + json metadata_json = metadata_store[i]; + std::string metadata_str = metadata_json.dump(); + if (metadata_str.find(filters) == std::string::npos) { + continue; + } + } + + result.emplace_back(get(i)); + if (limit > 0 && result.size() >= static_cast(limit)) { + break; + } + } + } + + return result; +} +}; diff --git a/memory/mem0/vector_store/hnswlib.h b/memory/mem0/vector_store/hnswlib.h index 5b00f7e..c6822d2 100644 --- a/memory/mem0/vector_store/hnswlib.h +++ b/memory/mem0/vector_store/hnswlib.h @@ -1,124 +1,37 @@ #ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H #define HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H +#include "base.h" #include "hnswlib/hnswlib.h" -namespace humanus { +namespace humanus::mem0 { -namespace mem0 { - -struct HNSWLIBVectorStore { - VectorStoreConfig config; +class HNSWLibVectorStore : public VectorStore { +private: std::shared_ptr> hnsw; + std::shared_ptr> space; // 保持space对象的引用以确保其生命周期 + std::unordered_map metadata_store; // 存储向量的元数据 - HNSWLIBVectorStore(const VectorStoreConfig& config) : config(config) { +public: + HNSWLibVectorStore(const std::shared_ptr& config) : VectorStore(config) { reset(); } - void reset() { - if (hnsw) { - hnsw.reset(); - } - if (config.metric == Metric::L2) { - hnswlib::L2Space space(config.dim); - hnsw = std::make_shared(&space, config.max_elements, config.M, config.ef_construction); - } else if (config.metric == Metric::IP) { - hnswlib::InnerProductSpace space(config.dim); - hnsw = std::make_shared(&space, config.max_elements, config.M, config.ef_construction); - } else { - throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast(config.metric))); - } - } + void reset() override; - /** - * @brief 插入向量到集合中 - * @param vectors 向量数据 - * @param payloads 可选的负载数据 - * @param ids 可选的ID列表 - * @return 插入的向量ID列表 - */ - std::vector insert(const std::vector>& vectors, - const std::vector& payloads = {}, - const std::vector& ids = {}) { - std::vector result_ids; - for (size_t i = 0; i < vectors.size(); i++) { - size_t id = ids.size() > i ? ids[i] : hnsw->cur_element_count; - hnsw->addPoint(vectors[i].data(), id); - result_ids.push_back(id); - } - return result_ids; - } + void insert(const std::vector& vector, const size_t vector_id, const json& metadata) override; - /** - * @brief 搜索相似向量 - * @param query 查询向量 - * @param limit 返回结果数量限制 - * @param filters 可选的过滤条件 - * @return 相似向量的ID和距离 - */ - std::vector>> search(const std::vector& query, - int limit = 5, - const std::string& filters = "") { - return hnsw->searchKnn(query.data(), limit); - } + std::vector search(const std::vector& query, int limit, const std::string& filters) override; - /** - * @brief 通过ID删除向量 - * @param vector_id 向量ID - */ - void delete_vector(size_t vector_id) { - hnsw->markDelete(vector_id); - } + void delete_vector(size_t vector_id) override; - /** - * @brief 更新向量及其负载 - * @param vector_id 向量ID - * @param vector 可选的新向量数据 - * @param payload 可选的新负载数据 - */ - void update(size_t vector_id, - const std::vector* vector = nullptr, - const std::string* payload = nullptr) { - if (vector) { - hnsw->markDelete(vector_id); - hnsw->addPoint(vector->data(), vector_id); - } - } + void update(size_t vector_id, const std::vector vector, const json& metadata) override; - /** - * @brief 通过ID获取向量 - * @param vector_id 向量ID - * @return 向量数据 - */ - std::vector get(size_t vector_id) { - std::vector result(config.dimension); - hnsw->getDataByLabel(vector_id, result.data()); - return result; - } + MemoryItem get(size_t vector_id) override; - /** - * @brief 列出所有记忆 - * @param filters 可选的过滤条件 - * @param limit 可选的结果数量限制 - * @return 记忆ID列表 - */ - std::vector list(const std::string& filters = "", int limit = 0) { - std::vector result; - size_t count = hnsw->cur_element_count; - for (size_t i = 0; i < count; i++) { - if (!hnsw->isMarkedDeleted(i)) { - result.push_back(i); - if (limit > 0 && result.size() >= static_cast(limit)) { - break; - } - } - } - return result; - } + std::vector list(const std::string& filters, int limit) override; }; - + } -} - -#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H +#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H \ No newline at end of file diff --git a/prompt.cpp b/prompt.cpp index ae88fb9..1319e1b 100644 --- a/prompt.cpp +++ b/prompt.cpp @@ -124,7 +124,7 @@ Following is a conversation between the user and the assistant. You have to extr You should detect the language of the user input and record the facts in the same language. )"; -const char* DEFAULT_UPDATE_MEMORY_PROMPT = R"(You are a smart memory manager which controls the memory of a system. +const char* UPDATE_MEMORY_PROMPT = R"(You are a smart memory manager which controls the memory of a system. You can perform four operations: (1) add into the memory, (2) update the memory, (3) delete from the memory, and (4) no change. Based on the above four operations, the memory will change. diff --git a/prompt.h b/prompt.h index 3afd755..7730b38 100644 --- a/prompt.h +++ b/prompt.h @@ -26,13 +26,13 @@ extern const char* NEXT_STEP_PROMPT; extern const char* TOOL_HINT_TEMPLATE; } // namespace toolcall -} // namespace prompt - namespace mem0 { extern const char* FACT_EXTRACTION_PROMPT; extern const char* UPDATE_MEMORY_PROMPT; } // namespace mem0 +} // namespace prompt + } // namespace humanus #endif // HUMANUS_PROMPT_H