humanus.cpp/memory/mem0/vector_store/hnswlib.cpp

166 lines
5.4 KiB
C++
Raw Normal View History

#include "hnswlib/hnswlib.h"
#include "hnswlib.h"
#include <algorithm>
#include <chrono>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace humanus::mem0 {
void HNSWLibVectorStore::reset() {
if (hnsw) {
hnsw.reset();
}
if (space) {
space.reset();
}
metadata_store.clear();
if (config_->metric == VectorStoreConfig::Metric::L2) {
space = std::make_shared<hnswlib::L2Space>(config_->dim);
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
} else if (config_->metric == VectorStoreConfig::Metric::IP) {
space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
} else {
throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config_->metric)));
}
}
void HNSWLibVectorStore::insert(const std::vector<float>& vector,
const size_t vector_id,
const json& metadata) {
hnsw->addPoint(vector.data(), vector_id);
// 存储元数据
auto now = std::chrono::system_clock::now().time_since_epoch().count();
json _metadata = metadata;
if (!_metadata.contains("created_at")) {
_metadata["created_at"] = now;
}
if (!_metadata.contains("updated_at")) {
_metadata["updated_at"] = now;
}
metadata_store[vector_id] = _metadata;
}
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query,
int limit = 5,
const std::string& filters = "") {
auto results = hnsw->searchKnn(query.data(), limit);
std::vector<MemoryItem> memory_items;
while (!results.empty()) {
const auto& [id, distance] = results.top();
results.pop();
if (metadata_store.find(id) != metadata_store.end()) {
MemoryItem item;
item.id = id;
if (metadata_store[id].contains("data")) {
item.memory = metadata_store[id]["data"];
}
if (metadata_store[id].contains("hash")) {
item.hash = metadata_store[id]["hash"];
}
item.metadata = metadata_store[id];
item.score = distance;
memory_items.push_back(item);
}
}
return memory_items;
}
void HNSWLibVectorStore::delete_vector(size_t vector_id) {
hnsw->markDelete(vector_id);
metadata_store.erase(vector_id);
}
void HNSWLibVectorStore::update(size_t vector_id,
const std::vector<float> vector = std::vector<float>(),
const json& metadata = json::object()) {
// 检查向量是否需要更新
if (!vector.empty()) {
hnsw->markDelete(vector_id);
hnsw->addPoint(vector.data(), vector_id);
}
// 更新元数据
if (metadata_store.find(vector_id) != metadata_store.end()) {
auto now = std::chrono::system_clock::now().time_since_epoch().count();
// 合并现有元数据和新元数据
for (auto& [key, value] : metadata.items()) {
metadata_store[vector_id][key] = value;
}
// 更新时间戳
metadata_store[vector_id]["updated_at"] = now;
} else if (!metadata.empty()) {
// 如果元数据不存在但提供了新的元数据,则创建新条目
auto now = std::chrono::system_clock::now().time_since_epoch().count();
json new_metadata = metadata;
if (!new_metadata.contains("created_at")) {
new_metadata["created_at"] = now;
}
new_metadata["updated_at"] = now;
metadata_store[vector_id] = new_metadata;
}
}
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
MemoryItem item;
item.id = vector_id;
// 获取向量数据
std::vector<float> vector_data = hnsw->getDataByLabel<float>(vector_id);
// 获取元数据
if (metadata_store.find(vector_id) != metadata_store.end()) {
if (metadata_store[vector_id].contains("data")) {
item.memory = metadata_store[vector_id]["data"];
}
if (metadata_store[vector_id].contains("hash")) {
item.hash = metadata_store[vector_id]["hash"];
}
item.metadata = metadata_store[vector_id];
}
return item;
}
std::vector<MemoryItem> HNSWLibVectorStore::list(const std::string& filters = "", int limit = 0) {
std::vector<MemoryItem> result;
size_t count = hnsw->cur_element_count;
for (size_t i = 0; i < count; i++) {
if (!hnsw->isMarkedDeleted(i)) {
// 如果有过滤条件,检查元数据是否匹配
if (!filters.empty() && metadata_store.find(i) != metadata_store.end()) {
// 简单的字符串匹配过滤,可以根据需要扩展
json metadata_json = metadata_store[i];
std::string metadata_str = metadata_json.dump();
if (metadata_str.find(filters) == std::string::npos) {
continue;
}
}
result.emplace_back(get(i));
if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
break;
}
}
}
return result;
}
};