// humanus.cpp/memory/mem0/vector_store/hnswlib.cpp
// (Source-viewer capture residue: 122 lines, 4.0 KiB, C++ — "Raw | Normal View | History".)

#include "hnswlib/hnswlib.h"
#include "hnswlib.h"

#include <algorithm>
#include <chrono>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace humanus::mem0 {
/**
 * Drops the current index, its distance space and all stored metadata, then
 * builds a fresh, empty HNSW index from config_ (metric, dim, max_elements,
 * M, ef_construction).
 *
 * @throws std::invalid_argument if config_->metric is neither L2 nor IP; the
 *         index and space are left empty in that case.
 */
void HNSWLibVectorStore::reset() {
    // Release the index before the space: the index is built from a raw
    // pointer into the space (space.get() below).
    hnsw.reset();
    space.reset();
    metadata_store.clear();

    switch (config_->metric) {
        case VectorStoreConfig::Metric::L2:
            space = std::make_shared<hnswlib::L2Space>(config_->dim);
            break;
        case VectorStoreConfig::Metric::IP:
            space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
            break;
        default:
            throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<size_t>(config_->metric)));
    }

    // Same construction parameters regardless of metric.
    hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
}
void HNSWLibVectorStore::insert(const std::vector<float>& vector, const size_t vector_id, const MemoryItem& metadata) {
hnsw->addPoint(vector.data(), vector_id);
// 存储元数据
auto now = std::chrono::system_clock::now().time_since_epoch().count();
2025-03-26 19:28:02 +08:00
MemoryItem _metadata = metadata;
if (_metadata.created_at < 0) {
_metadata.created_at = now;
}
2025-03-26 19:28:02 +08:00
if (_metadata.updated_at < 0) {
_metadata.updated_at = now;
}
metadata_store[vector_id] = _metadata;
}
2025-03-26 19:28:02 +08:00
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query, size_t limit, const FilterFunc& filter) {
auto filte_wrapper = filter ? std::make_unique<HNSWLibFilterFunctorWrapper>(*this, filter) : nullptr;
auto results = hnsw->searchKnn(query.data(), limit, filte_wrapper.get());
std::vector<MemoryItem> memory_items;
while (!results.empty()) {
2025-03-26 19:28:02 +08:00
const auto& [distance, id] = results.top();
results.pop();
if (metadata_store.find(id) != metadata_store.end()) {
2025-03-26 19:28:02 +08:00
MemoryItem item = metadata_store[id];
item.score = distance;
memory_items.push_back(item);
}
}
return memory_items;
}
/**
 * Removes `vector_id` from future searches and drops its metadata.
 *
 * The point is only marked deleted in the HNSW index (markDelete). If that
 * call throws, the metadata is left untouched because the erase never runs.
 */
void HNSWLibVectorStore::delete_vector(size_t vector_id) {
    const size_t label = vector_id;

    hnsw->markDelete(label);      // tombstone in the index first
    metadata_store.erase(label);  // then forget the payload
}
void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vector, const MemoryItem& metadata) {
// 检查向量是否需要更新
if (!vector.empty()) {
hnsw->markDelete(vector_id);
hnsw->addPoint(vector.data(), vector_id);
}
2025-03-26 19:28:02 +08:00
if (!metadata.empty()) {
MemoryItem new_metadata = metadata;
new_metadata.id = vector_id; // Make sure the id is the same as the vector id
auto now = std::chrono::system_clock::now().time_since_epoch().count();
2025-03-26 19:28:02 +08:00
if (metadata_store.find(vector_id) != metadata_store.end()) {
MemoryItem old_metadata = metadata_store[vector_id];
if (new_metadata.hash == old_metadata.hash) {
new_metadata.created_at = old_metadata.created_at;
} else {
new_metadata.created_at = now;
}
}
2025-03-26 19:28:02 +08:00
if (new_metadata.created_at < 0) {
new_metadata.created_at = now;
}
2025-03-26 19:28:02 +08:00
new_metadata.updated_at = now;
metadata_store[vector_id] = new_metadata;
}
}
/**
 * Returns a copy of the metadata stored under `vector_id`.
 *
 * @throws std::out_of_range if no metadata is stored for `vector_id`
 *         (propagated from metadata_store.at()).
 */
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
    const auto& stored = metadata_store.at(vector_id);
    return stored;
}
/**
 * Lists stored items in the index's internal-id order, skipping points that
 * are marked deleted.
 *
 * @param limit   Maximum number of items to return; 0 means no limit.
 * @param filter  Optional predicate; items it rejects are skipped.
 * @return Copies of the stored metadata. Index entries with no stored
 *         metadata are skipped.
 */
std::vector<MemoryItem> HNSWLibVectorStore::list(size_t limit, const FilterFunc& filter) {
    std::vector<MemoryItem> result;
    const size_t count = hnsw->cur_element_count;

    for (size_t i = 0; i < count; ++i) {
        // `i` is an hnswlib internal id (slot index), NOT the label passed to
        // insert(); translate it before looking up metadata.
        const auto internal_id = static_cast<hnswlib::tableint>(i);
        if (hnsw->isMarkedDeleted(internal_id)) {
            continue;
        }

        const auto it = metadata_store.find(hnsw->getExternalLabel(internal_id));
        if (it == metadata_store.end()) {
            continue;  // no payload stored for this label
        }

        // Filter a copy, as before, so a filter can never mutate stored data.
        MemoryItem memory_item = it->second;
        if (filter && !filter(memory_item)) {
            continue;
        }

        result.push_back(std::move(memory_item));
        if (limit > 0 && result.size() >= limit) {
            break;
        }
    }
    return result;
}
};