#ifndef HUMANUS_MEMORY_MEM0_H #define HUMANUS_MEMORY_MEM0_H #include "memory/base.h" #include "vector_store/base.h" #include "embedding_model/base.h" #include "schema.h" #include "prompt.h" #include "httplib.h" #include "llm.h" #include "utils.h" namespace humanus::mem0 { struct Memory : BaseMemory { MemoryConfig config; std::string fact_extraction_prompt; std::string update_memory_prompt; int max_messages; int limit; std::string filters; std::shared_ptr embedding_model; std::shared_ptr vector_store; std::shared_ptr llm; // std::shared_ptr db; Memory(const MemoryConfig& config) : config(config) { fact_extraction_prompt = config.fact_extraction_prompt; update_memory_prompt = config.update_memory_prompt; max_messages = config.max_messages; limit = config.limit; filters = config.filters; embedding_model = EmbeddingModel::get_instance("mem0_" + std::to_string(reinterpret_cast(this)), config.embedding_model_config); vector_store = VectorStore::get_instance("mem0_" + std::to_string(reinterpret_cast(this)), config.vector_store_config); llm = LLM::get_instance("mem0_" + std::to_string(reinterpret_cast(this)), config.llm_config); // db = std::make_shared(config.history_db_path); } void add_message(const Message& message) override { while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) { // Ensure the first message is always a user or system message Message front_message = *messages.begin(); messages.erase(messages.begin()); if (config.llm_config->enable_vision) { front_message = parse_vision_message(front_message, llm, config.llm_config->vision_details); } else { front_message = parse_vision_message(front_message); } _add_to_vector_store(front_message); } messages.push_back(message); } std::vector get_messages(const std::string& query = "") const override { auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH); std::vector memories; // 检查vector_store是否已初始化 if (vector_store) { memories = vector_store->search(embeddings, limit, filters); } std::string memory_prompt; for (const auto& memory_item : memories) { memory_prompt += "" + memory_item.memory + ""; } std::vector messages_with_memory{Message::user_message(memory_prompt)}; messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end()); return messages_with_memory; } void _add_to_vector_store(const Message& message) { // 检查vector_store是否已初始化 if (!vector_store) { logger->warn("Vector store is not initialized, skipping memory operation"); return; } std::string parsed_message = message.role + ": " + (message.content.is_string() ? message.content.get() : message.content.dump()); for (const auto& tool_call : message.tool_calls) { parsed_message += "" + tool_call.to_json().dump() + ""; } std::string system_prompt = fact_extraction_prompt; std::string user_prompt = "Input:\n" + parsed_message; Message user_message = Message::user_message(user_prompt); std::string response = llm->ask( {user_message}, system_prompt ); json new_retrieved_facts; // ["fact1", "fact2", "fact3"] try { // response = remove_code_blocks(response); new_retrieved_facts = json::parse(response)["facts"]; } catch (const std::exception& e) { logger->error("Error in new_retrieved_facts: " + std::string(e.what())); } std::vector retrieved_old_memory; std::map> new_message_embeddings; for (const auto& fact : new_retrieved_facts) { auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD); new_message_embeddings[fact] = message_embedding; auto existing_memories = vector_store->search( message_embedding, 5 ); for (const auto& memory : existing_memories) { retrieved_old_memory.push_back({ {"id", memory.id}, {"text", memory.metadata["data"]} }); } } // sort and unique by id std::sort(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) { return a["id"] < b["id"]; }); retrieved_old_memory.resize(std::unique(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) { return a["id"] == b["id"]; }) - retrieved_old_memory.begin()); logger->info("Total existing memories: " + std::to_string(retrieved_old_memory.size())); // mapping UUIDs with integers for handling UUID hallucinations std::vector temp_uuid_mapping; for (size_t idx = 0; idx < retrieved_old_memory.size(); ++idx) { temp_uuid_mapping.push_back(retrieved_old_memory[idx]["id"].get()); retrieved_old_memory[idx]["id"] = idx; } std::string function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, fact_extraction_prompt, update_memory_prompt); std::string new_memories_with_actions_str; json new_memories_with_actions = json::array(); try { new_memories_with_actions_str = llm->ask( {Message::user_message(function_calling_prompt)} ); new_memories_with_actions = json::parse(new_memories_with_actions_str); } catch (const std::exception& e) { logger->error("Error in new_memories_with_actions: " + std::string(e.what())); } try { // new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str); new_memories_with_actions = json::parse(new_memories_with_actions_str); } catch (const std::exception& e) { logger->error("Invalid JSON response: " + std::string(e.what())); } try { for (const auto& resp : new_memories_with_actions.value("memory", json::array())) { logger->info("Processing memory: " + resp.dump(2)); try { if (!resp.contains("text")) { logger->info("Skipping memory entry because of empty `text` field."); continue; } std::string event = resp.value("event", "NONE"); size_t memory_id = resp.contains("id") ? temp_uuid_mapping[resp["id"].get()] : uuid(); if (event == "ADD") { _create_memory( memory_id, resp["text"], // data new_message_embeddings // existing_embeddings ); } else if (event == "UPDATE") { _update_memory( memory_id, resp["text"], // data new_message_embeddings // existing_embeddings ); } else if (event == "DELETE") { _delete_memory(memory_id); } else if (event == "NONE") { logger->info("NOOP for Memory."); } } catch (const std::exception& e) { logger->error("Error in new_memories_with_actions: " + std::string(e.what())); } } } catch (const std::exception& e) { logger->error("Error in new_memories_with_actions: " + std::string(e.what())); } } void _create_memory(const size_t& memory_id, const std::string& data, const std::map>& existing_embeddings) { if (!vector_store) { logger->warn("Vector store is not initialized, skipping create memory"); return; } std::vector embedding; if (existing_embeddings.find(data) != existing_embeddings.end()) { embedding = existing_embeddings.at(data); } else { embedding = embedding_model->embed(data, EmbeddingType::ADD); } auto created_at = std::chrono::system_clock::now(); json metadata = { {"data", data}, {"hash", httplib::detail::MD5(data)}, {"created_at", std::chrono::system_clock::now().time_since_epoch().count()} }; vector_store->insert( embedding, memory_id, metadata ); } void _update_memory(const size_t& memory_id, const std::string& data, const std::map>& existing_embeddings) { if (!vector_store) { logger->warn("Vector store is not initialized, skipping update memory"); return; } logger->info("Updating memory with " + data); MemoryItem existing_memory; try { existing_memory = vector_store->get(memory_id); } catch (const std::exception& e) { logger->error("Error fetching existing memory: " + std::string(e.what())); return; } std::vector embedding; if (existing_embeddings.find(data) != existing_embeddings.end()) { embedding = existing_embeddings.at(data); } else { embedding = embedding_model->embed(data, EmbeddingType::ADD); } json metadata = existing_memory.metadata; metadata["data"] = data; metadata["hash"] = httplib::detail::MD5(data); metadata["updated_at"] = std::chrono::system_clock::now().time_since_epoch().count(); vector_store->update( memory_id, embedding, metadata ); } void _delete_memory(const size_t& memory_id) { if (!vector_store) { logger->warn("Vector store is not initialized, skipping delete memory"); return; } logger->info("Deleting memory: " + std::to_string(memory_id)); vector_store->delete_vector(memory_id); } }; } // namespace humanus::mem0 #endif // HUMANUS_MEMORY_MEM0_H