// humanus.cpp/memory/mem0/base.h
// (viewer metadata: 280 lines · 11 KiB · C++)

#ifndef HUMANUS_MEMORY_MEM0_H
#define HUMANUS_MEMORY_MEM0_H
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "memory/base.h"
#include "vector_store/base.h"
#include "embedding_model/base.h"
#include "schema.h"
#include "prompt.h"
#include "httplib.h"
#include "llm.h"
#include "utils.h"
namespace humanus::mem0 {
// mem0-style long-term memory.
//
// Keeps a short window of recent messages in RAM. When the window
// overflows, evicted messages are distilled into "facts" by the LLM,
// embedded, and reconciled against the vector store; a second LLM call
// decides an ADD / UPDATE / DELETE / NONE action per memory.
struct Memory : BaseMemory {
    MemoryConfig config;
    std::string fact_extraction_prompt; // system prompt used to extract facts
    std::string update_memory_prompt;   // prompt used to reconcile old vs. new facts
    int max_messages;                   // in-context window size before eviction
    int limit;                          // max memories returned per search
    std::string filters;                // filter expression forwarded to the vector store
    std::shared_ptr<EmbeddingModel> embedding_model;
    std::shared_ptr<VectorStore> vector_store;
    std::shared_ptr<LLM> llm;
    // std::shared_ptr<SQLiteManager> db;

    Memory(const MemoryConfig& config) : config(config) {
        fact_extraction_prompt = config.fact_extraction_prompt;
        update_memory_prompt = config.update_memory_prompt;
        max_messages = config.max_messages;
        limit = config.limit;
        filters = config.filters;
        // Instance-unique key so several Memory objects can coexist without
        // sharing backend singletons.
        const std::string instance_key = "mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this));
        embedding_model = EmbeddingModel::get_instance(instance_key, config.embedding_model_config);
        vector_store = VectorStore::get_instance(instance_key, config.vector_store_config);
        llm = LLM::get_instance(instance_key, config.llm_config);
        // db = std::make_shared<SQLiteManager>(config.history_db_path);
    }

    // Append `message` to the in-context window. Before appending, evict
    // from the front until the window fits AND starts with a user/system
    // message; each evicted message is archived into the vector store.
    void add_message(const Message& message) override {
        // Cast avoids a signed/unsigned comparison; max_messages is
        // assumed non-negative (a config value).
        while (!messages.empty() && (messages.size() > static_cast<size_t>(max_messages) || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
            // Ensure the first message is always a user or system message
            Message front_message = *messages.begin();
            messages.erase(messages.begin());
            if (config.llm_config->enable_vision) {
                front_message = parse_vision_message(front_message, llm, config.llm_config->vision_details);
            } else {
                front_message = parse_vision_message(front_message);
            }
            _add_to_vector_store(front_message);
        }
        messages.push_back(message);
    }

    // Return the current window, prefixed with a user message carrying the
    // memories most relevant to `query`. The preamble is omitted entirely
    // when no memory matches (avoids sending an empty user message).
    std::vector<Message> get_messages(const std::string& query = "") const override {
        std::vector<MemoryItem> memories;
        if (vector_store) { // vector store may be unconfigured
            // Only pay for the query embedding when there is a store to search.
            auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH);
            memories = vector_store->search(embeddings, limit, filters);
        }
        std::string memory_prompt;
        for (const auto& memory_item : memories) {
            memory_prompt += "<memory>" + memory_item.memory + "</memory>";
        }
        std::vector<Message> messages_with_memory;
        messages_with_memory.reserve(messages.size() + 1);
        if (!memory_prompt.empty()) {
            messages_with_memory.push_back(Message::user_message(memory_prompt));
        }
        messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end());
        return messages_with_memory;
    }

    // Distill `message` into facts and reconcile them with the vector store.
    // Pipeline: LLM fact extraction -> embed each fact -> fetch similar
    // existing memories -> LLM decides the action per memory -> apply.
    // Best-effort: every stage logs and degrades instead of throwing.
    void _add_to_vector_store(const Message& message) {
        if (!vector_store) { // vector store may be unconfigured
            logger->warn("Vector store is not initialized, skipping memory operation");
            return;
        }
        // Flatten the message (including tool calls) into plain text.
        std::string parsed_message = message.role + ": " + (message.content.is_string() ? message.content.get<std::string>() : message.content.dump());
        for (const auto& tool_call : message.tool_calls) {
            parsed_message += "<tool_call>" + tool_call.to_json().dump() + "</tool_call>";
        }

        // 1. Extract facts, expected shape: {"facts": ["fact1", "fact2", ...]}.
        //    The LLM call is inside the try so a transport failure cannot
        //    abort the whole eviction path.
        json new_retrieved_facts = json::array();
        try {
            std::string response = llm->ask(
                {Message::user_message("Input:\n" + parsed_message)},
                fact_extraction_prompt // system prompt
            );
            // response = remove_code_blocks(response);
            new_retrieved_facts = json::parse(response)["facts"];
        } catch (const std::exception& e) {
            logger->error("Error in new_retrieved_facts: " + std::string(e.what()));
        }

        // 2. Embed each fact and collect the existing memories it may
        //    collide with (top-5 nearest neighbours per fact).
        std::vector<json> retrieved_old_memory;
        std::map<std::string, std::vector<float>> new_message_embeddings;
        for (const auto& fact : new_retrieved_facts) {
            if (!fact.is_string()) {
                // Malformed LLM output; a non-string key would throw below.
                logger->warn("Skipping non-string fact: " + fact.dump());
                continue;
            }
            auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD);
            new_message_embeddings[fact] = message_embedding;
            auto existing_memories = vector_store->search(
                message_embedding,
                5
            );
            for (const auto& memory : existing_memories) {
                retrieved_old_memory.push_back({
                    {"id", memory.id},
                    {"text", memory.metadata["data"]}
                });
            }
        }
        // Deduplicate by id: the same memory may match several facts.
        std::sort(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) {
            return a["id"] < b["id"];
        });
        retrieved_old_memory.erase(std::unique(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) {
            return a["id"] == b["id"];
        }), retrieved_old_memory.end());
        logger->info("Total existing memories: " + std::to_string(retrieved_old_memory.size()));

        // 3. Replace real ids with small integers so id hallucinations by
        //    the LLM can be detected and mapped back safely.
        std::vector<size_t> temp_uuid_mapping;
        for (size_t idx = 0; idx < retrieved_old_memory.size(); ++idx) {
            temp_uuid_mapping.push_back(retrieved_old_memory[idx]["id"].get<size_t>());
            retrieved_old_memory[idx]["id"] = idx;
        }

        // 4. Ask the LLM which actions to take. Parse exactly once; on any
        //    failure fall back to "no actions". (The original parsed the
        //    same string twice in two try blocks.)
        json new_memories_with_actions = json::object();
        try {
            std::string function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, fact_extraction_prompt, update_memory_prompt);
            std::string raw_response = llm->ask(
                {Message::user_message(function_calling_prompt)}
            );
            // raw_response = remove_code_blocks(raw_response);
            new_memories_with_actions = json::parse(raw_response);
        } catch (const std::exception& e) {
            logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
        }

        // 5. Apply each action; one bad entry must not stop the rest.
        try {
            for (const auto& resp : new_memories_with_actions.value("memory", json::array())) {
                logger->info("Processing memory: " + resp.dump(2));
                try {
                    if (!resp.contains("text")) {
                        logger->info("Skipping memory entry because of empty `text` field.");
                        continue;
                    }
                    const std::string event = resp.value("event", "NONE");
                    size_t memory_id = uuid(); // fresh id for ADD without (or with bogus) id
                    if (resp.contains("id")) {
                        const size_t mapped_idx = resp["id"].get<size_t>();
                        if (mapped_idx < temp_uuid_mapping.size()) {
                            memory_id = temp_uuid_mapping[mapped_idx];
                        } else if (event != "ADD") {
                            // Hallucinated id with no bounds check was UB in the
                            // original; there is no such memory to touch.
                            logger->warn("Skipping memory entry with out-of-range id: " + std::to_string(mapped_idx));
                            continue;
                        }
                    }
                    if (event == "ADD") {
                        _create_memory(
                            memory_id,
                            resp["text"],           // data
                            new_message_embeddings  // existing_embeddings
                        );
                    } else if (event == "UPDATE") {
                        _update_memory(
                            memory_id,
                            resp["text"],           // data
                            new_message_embeddings  // existing_embeddings
                        );
                    } else if (event == "DELETE") {
                        _delete_memory(memory_id);
                    } else if (event == "NONE") {
                        logger->info("NOOP for Memory.");
                    }
                } catch (const std::exception& e) {
                    logger->error("Error applying memory action: " + std::string(e.what()));
                }
            }
        } catch (const std::exception& e) {
            logger->error("Error iterating memory actions: " + std::string(e.what()));
        }
    }

    // Insert a brand-new memory under `memory_id`. Reuses an embedding
    // computed during fact extraction when available, otherwise embeds
    // `data` now.
    void _create_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
        if (!vector_store) {
            logger->warn("Vector store is not initialized, skipping create memory");
            return;
        }
        std::vector<float> embedding;
        const auto it = existing_embeddings.find(data);
        if (it != existing_embeddings.end()) {
            embedding = it->second; // single lookup instead of find + at
        } else {
            embedding = embedding_model->embed(data, EmbeddingType::ADD);
        }
        json metadata = {
            {"data", data},
            {"hash", httplib::detail::MD5(data)},
            {"created_at", std::chrono::system_clock::now().time_since_epoch().count()}
        };
        vector_store->insert(
            embedding,
            memory_id,
            metadata
        );
    }

    // Overwrite the memory stored under `memory_id` with `data`, keeping
    // its existing metadata (created_at etc.) and stamping updated_at.
    void _update_memory(const size_t& memory_id, const std::string& data, const std::map<std::string, std::vector<float>>& existing_embeddings) {
        if (!vector_store) {
            logger->warn("Vector store is not initialized, skipping update memory");
            return;
        }
        logger->info("Updating memory with " + data);
        MemoryItem existing_memory;
        try {
            existing_memory = vector_store->get(memory_id);
        } catch (const std::exception& e) {
            logger->error("Error fetching existing memory: " + std::string(e.what()));
            return;
        }
        std::vector<float> embedding;
        const auto it = existing_embeddings.find(data);
        if (it != existing_embeddings.end()) {
            embedding = it->second;
        } else {
            embedding = embedding_model->embed(data, EmbeddingType::ADD);
        }
        json metadata = existing_memory.metadata;
        metadata["data"] = data;
        metadata["hash"] = httplib::detail::MD5(data);
        metadata["updated_at"] = std::chrono::system_clock::now().time_since_epoch().count();
        vector_store->update(
            memory_id,
            embedding,
            metadata
        );
    }

    // Remove the memory stored under `memory_id` from the vector store.
    void _delete_memory(const size_t& memory_id) {
        if (!vector_store) {
            logger->warn("Vector store is not initialized, skipping delete memory");
            return;
        }
        logger->info("Deleting memory: " + std::to_string(memory_id));
        vector_store->delete_vector(memory_id);
    }
};
} // namespace humanus::mem0
#endif // HUMANUS_MEMORY_MEM0_H