#ifndef HUMANUS_MEMORY_MEM0_H
#define HUMANUS_MEMORY_MEM0_H

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "memory/base.h"
#include "storage.h"
#include "vector_store.h"
#include "prompt.h"

namespace humanus {

namespace mem0 {

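// Configuration for the mem0-style memory: the fact-extraction and
// memory-update prompts plus the embedder, vector store, and LLM backends.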
struct Config {
    // Prompt config
    std::string fact_extraction_prompt;
    std::string update_memory_prompt;

    // Database config
    // std::string history_db_path = ":memory:";

    // Embedder config
    EmbedderConfig embedder_config;

    // Vector store config
    VectorStoreConfig vector_store_config;

    // Optional: LLM config
    LLMConfig llm_config;
};

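// Long-term memory in the style of mem0: each incoming message is distilled
// into discrete facts by the LLM, embedded, and matched against existing
// entries in the vector store.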
struct Memory : BaseMemory {
    Config config;
    std::string fact_extraction_prompt;
    std::string update_memory_prompt;

    std::shared_ptr<Embedder> embedder;
    std::shared_ptr<VectorStore> vector_store;
    std::shared_ptr<LLM> llm;
    // std::shared_ptr<SQLiteManager> db;

    Memory(const Config& config) : config(config) {
        fact_extraction_prompt = config.fact_extraction_prompt;
        update_memory_prompt = config.update_memory_prompt;

        embedder = std::make_shared<Embedder>(config.embedder_config);
        vector_store = std::make_shared<VectorStore>(config.vector_store_config);
        llm = std::make_shared<LLM>(config.llm_config);
        // db = std::make_shared<SQLiteManager>(config.history_db_path);
    }

    void add_message(const Message& message) override {
        // `message` is a const reference, so the vision-parsed result goes
        // into a local copy rather than being assigned back to the parameter.
        Message processed_message = config.llm_config.enable_vision
            ? parse_vision_messages(message, llm, config.llm_config.vision_details)
            : parse_vision_messages(message);

        _add_to_vector_store(processed_message);
    }

    void _add_to_vector_store(const Message& message) {
        std::string parsed_message = parse_message(message);

        std::string system_prompt;
        std::string user_prompt = "Input:\n" + parsed_message;

        if (!fact_extraction_prompt.empty()) {
            system_prompt = fact_extraction_prompt;
        } else {
            system_prompt = FACT_EXTRACTION_PROMPT;
        }

        Message user_message = Message::user_message(user_prompt);

        std::string response = llm->ask(
            {user_message},
            system_prompt
        );

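        // The model is expected to reply with JSON of the form
        // {"facts": ["fact 1", "fact 2", ...]}; any markdown code fences
        // around the JSON are stripped before parsing.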
        std::vector<json> new_retrieved_facts;

        try {
            response = remove_code_blocks(response);
            new_retrieved_facts = json::parse(response)["facts"].get<std::vector<json>>();
        } catch (const std::exception& e) {
            LOG_ERROR("Error in new_retrieved_facts: " + std::string(e.what()));
        }

        std::vector<json> retrieved_old_memory;
        std::map<std::string, std::vector<float>> new_message_embeddings;

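        // For each newly extracted fact: embed it, keep the embedding for the
        // eventual insert, and pull the closest existing memories so that old
        // and new information can be reconciled.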
        // Assumption in this sketch: no metadata filters are applied, and an
        // empty JSON object is treated as "search across all stored memories".
        json filters = json::object();

        for (const auto& fact : new_retrieved_facts) {
            const std::string fact_str = fact.get<std::string>();
            auto message_embedding = embedder->embed(fact_str);
            new_message_embeddings[fact_str] = message_embedding;
            auto existing_memories = vector_store->search(
                message_embedding,
                5, // top-k nearest neighbours per fact
                filters
            );
            for (const auto& memory : existing_memories) {
                retrieved_old_memory.push_back({
                    {"id", memory.id},
                    {"text", memory.payload["data"]}
                });
            }
        }
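
        // Not shown in this excerpt: mem0's follow-up step, which presents
        // retrieved_old_memory together with the new facts to the LLM (via
        // update_memory_prompt) so it can decide which memories to add,
        // update, or delete in the vector store.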
    }
};

} // namespace mem0
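
// Usage sketch (hypothetical wiring; the concrete fields of EmbedderConfig,
// VectorStoreConfig, and LLMConfig are defined elsewhere in this project):
//
//   humanus::mem0::Config cfg;
//   cfg.embedder_config = /* ... */;
//   cfg.vector_store_config = /* ... */;
//   cfg.llm_config = /* ... */;
//   humanus::mem0::Memory memory(cfg);
//   memory.add_message(Message::user_message("I moved to Berlin last month."));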

} // namespace humanus

#endif // HUMANUS_MEMORY_MEM0_H