humanus.cpp/memory/mem0/mem0.h

116 lines
3.3 KiB
C
Raw Normal View History

2025-03-23 14:35:54 +08:00
#ifndef HUMANUS_MEMORY_MEM0_H
#define HUMANUS_MEMORY_MEM0_H
#include "memory/base.h"
#include "storage.h"
#include "vector_store.h"
#include "prompt.h"
namespace humanus {
namespace mem0 {
// Construction-time settings for mem0::Memory; the Memory constructor copies
// this struct and builds its Embedder, VectorStore and LLM from the sub-configs.
struct Config {
// Prompt config
// fact_extraction_prompt: system prompt for fact extraction; when left empty,
// Memory falls back to FACT_EXTRACTION_PROMPT.
std::string fact_extraction_prompt;
// update_memory_prompt: copied into Memory but not yet used by the visible code.
std::string update_memory_prompt;
// Database config
// (history database is currently disabled together with Memory::db)
// std::string history_db_path = ":memory:";
// Embedder config (passed to the Embedder constructor)
EmbedderConfig embedder_config;
// Vector store config (passed to the VectorStore constructor)
VectorStoreConfig vector_store_config;
// Optional: LLM config (passed to the LLM constructor; its enable_vision and
// vision_details fields also control vision parsing in Memory::add_message)
LLMConfig llm_config;
};
struct Memory : BaseMemory {
Config config;
std::string fact_extraction_prompt;
std::string update_memory_prompt;
std::shared_ptr<Embedder> embedder;
std::shared_ptr<VectorStore> vector_store;
std::shared_ptr<LLM> llm;
// std::shared_ptr<SQLiteManager> db;
Memory(const Config& config) : config(config) {
fact_extraction_prompt = config.fact_extraction_prompt;
update_memory_prompt = config.update_memory_prompt;
embedder = std::make_shared<Embedder>(config.embedder_config);
vector_store = std::make_shared<VectorStore>(config.vector_store_config);
llm = std::make_shared<LLM>(config.llm_config);
// db = std::make_shared<SQLiteManager>(config.history_db_path);
}
void add_message(const Message& message) override {
if (config.llm_config.enable_vision) {
message = parse_vision_messages(message, llm, config.llm_config.vision_details);
} else {
message = parse_vision_messages(message);
}
_add_to_vector_store(message);
}
void _add_to_vector_store(const Message& message) {
std::string parsed_message = parse_message(message);
std::string system_prompt;
std::string user_prompt = "Input:\n" + parsed_message;
if (!fact_extraction_prompt.empty()) {
system_prompt = fact_extraction_prompt;
} else {
system_prompt = FACT_EXTRACTION_PROMPT;
}
Message user_message = Message::user_message(user_prompt);
std::string response = llm->ask(
{user_message},
system_prompt
);
std::vector<json> new_retrieved_facts;
try {
response = remove_code_blocks(response);
new_retrieved_facts = json::parse(response)["facts"].get<std::vector<json>>();
} catch (const std::exception& e) {
LOG_ERROR("Error in new_retrieved_facts: " + std::string(e.what()));
}
std::vector<json> retrieved_old_memory;
std::map<std::string, std::vector<float>> new_message_embeddings;
for (const auto& fact : new_retrieved_facts) {
auto message_embedding = embedder->embed(fact);
new_message_embeddings[fact] = message_embedding;
auto existing_memories = vector_store->search(
message_embedding,
5,
filters
)
for (const auto& memory : existing_memories) {
retrieved_old_memory.push_back({
{"id", memory.id},
{"text", memory.payload["data"]}
});
}
}
}
};
} // namespace mem0
} // namespace humanus
#endif // HUMANUS_MEMORY_MEM0_H