166 lines
5.4 KiB
C++
166 lines
5.4 KiB
C++
|
#include "hnswlib/hnswlib.h"
#include "hnswlib.h"

#include <chrono>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
||
|
|
||
|
namespace humanus::mem0 {
|
||
|
|
||
|
// Drops the current index, distance space and all stored metadata, then
// builds a fresh, empty index from config_.
//
// The distance space is selected by config_->metric (L2 or inner product) and
// created with config_->dim; the index is created with config_->max_elements,
// config_->M and config_->ef_construction.
//
// Throws std::invalid_argument when config_->metric is neither L2 nor IP. By
// then the previous index and metadata have already been discarded.
void HNSWLibVectorStore::reset() {
    // Release the index before the space: the index is constructed from the
    // raw pointer space.get() (see below).
    hnsw.reset();
    space.reset();

    metadata_store.clear();

    switch (config_->metric) {
    case VectorStoreConfig::Metric::L2:
        space = std::make_shared<hnswlib::L2Space>(config_->dim);
        break;
    case VectorStoreConfig::Metric::IP:
        space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
        break;
    default:
        throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config_->metric)));
    }

    // Both supported metrics build the index the same way; only the space differs.
    hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
}
|
||
|
|
||
|
void HNSWLibVectorStore::insert(const std::vector<float>& vector,
|
||
|
const size_t vector_id,
|
||
|
const json& metadata) {
|
||
|
hnsw->addPoint(vector.data(), vector_id);
|
||
|
|
||
|
// 存储元数据
|
||
|
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||
|
json _metadata = metadata;
|
||
|
if (!_metadata.contains("created_at")) {
|
||
|
_metadata["created_at"] = now;
|
||
|
}
|
||
|
if (!_metadata.contains("updated_at")) {
|
||
|
_metadata["updated_at"] = now;
|
||
|
}
|
||
|
|
||
|
metadata_store[vector_id] = _metadata;
|
||
|
}
|
||
|
|
||
|
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query,
|
||
|
int limit = 5,
|
||
|
const std::string& filters = "") {
|
||
|
auto results = hnsw->searchKnn(query.data(), limit);
|
||
|
std::vector<MemoryItem> memory_items;
|
||
|
|
||
|
while (!results.empty()) {
|
||
|
const auto& [id, distance] = results.top();
|
||
|
|
||
|
results.pop();
|
||
|
|
||
|
if (metadata_store.find(id) != metadata_store.end()) {
|
||
|
MemoryItem item;
|
||
|
item.id = id;
|
||
|
|
||
|
if (metadata_store[id].contains("data")) {
|
||
|
item.memory = metadata_store[id]["data"];
|
||
|
}
|
||
|
|
||
|
if (metadata_store[id].contains("hash")) {
|
||
|
item.hash = metadata_store[id]["hash"];
|
||
|
}
|
||
|
|
||
|
item.metadata = metadata_store[id];
|
||
|
item.score = distance;
|
||
|
|
||
|
memory_items.push_back(item);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return memory_items;
|
||
|
}
|
||
|
|
||
|
// Marks the vector stored under `vector_id` as deleted in the index and
// removes its metadata entry, if one exists.
//
// The index is updated first, so anything thrown by markDelete leaves the
// metadata untouched.
void HNSWLibVectorStore::delete_vector(size_t vector_id) {
    hnsw->markDelete(vector_id);

    const auto entry = metadata_store.find(vector_id);
    if (entry != metadata_store.end()) {
        metadata_store.erase(entry);
    }
}
|
||
|
|
||
|
void HNSWLibVectorStore::update(size_t vector_id,
|
||
|
const std::vector<float> vector = std::vector<float>(),
|
||
|
const json& metadata = json::object()) {
|
||
|
// 检查向量是否需要更新
|
||
|
if (!vector.empty()) {
|
||
|
hnsw->markDelete(vector_id);
|
||
|
hnsw->addPoint(vector.data(), vector_id);
|
||
|
}
|
||
|
|
||
|
// 更新元数据
|
||
|
if (metadata_store.find(vector_id) != metadata_store.end()) {
|
||
|
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||
|
|
||
|
// 合并现有元数据和新元数据
|
||
|
for (auto& [key, value] : metadata.items()) {
|
||
|
metadata_store[vector_id][key] = value;
|
||
|
}
|
||
|
|
||
|
// 更新时间戳
|
||
|
metadata_store[vector_id]["updated_at"] = now;
|
||
|
} else if (!metadata.empty()) {
|
||
|
// 如果元数据不存在但提供了新的元数据,则创建新条目
|
||
|
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||
|
json new_metadata = metadata;
|
||
|
if (!new_metadata.contains("created_at")) {
|
||
|
new_metadata["created_at"] = now;
|
||
|
}
|
||
|
new_metadata["updated_at"] = now;
|
||
|
metadata_store[vector_id] = new_metadata;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Builds the MemoryItem for `vector_id` from the stored metadata.
//
// The returned item always carries `id`. When metadata exists for the id,
// `memory` is taken from its "data" field, `hash` from its "hash" field (each
// only if present), and `metadata` is a copy of the whole entry.
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
    MemoryItem item;
    item.id = vector_id;

    // Fetch the vector data. The result itself is not used.
    // NOTE(review): the call is kept because it may throw for labels the index
    // does not know - confirm whether callers rely on that.
    static_cast<void>(hnsw->getDataByLabel<float>(vector_id));

    // Fetch the metadata.
    const auto entry = metadata_store.find(vector_id);
    if (entry == metadata_store.end()) {
        return item;
    }

    auto& stored = entry->second;
    if (stored.contains("data")) {
        item.memory = stored["data"];
    }
    if (stored.contains("hash")) {
        item.hash = stored["hash"];
    }
    item.metadata = stored;

    return item;
}
|
||
|
|
||
|
std::vector<MemoryItem> HNSWLibVectorStore::list(const std::string& filters = "", int limit = 0) {
|
||
|
std::vector<MemoryItem> result;
|
||
|
size_t count = hnsw->cur_element_count;
|
||
|
|
||
|
for (size_t i = 0; i < count; i++) {
|
||
|
if (!hnsw->isMarkedDeleted(i)) {
|
||
|
// 如果有过滤条件,检查元数据是否匹配
|
||
|
if (!filters.empty() && metadata_store.find(i) != metadata_store.end()) {
|
||
|
// 简单的字符串匹配过滤,可以根据需要扩展
|
||
|
json metadata_json = metadata_store[i];
|
||
|
std::string metadata_str = metadata_json.dump();
|
||
|
if (metadata_str.find(filters) == std::string::npos) {
|
||
|
continue;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
result.emplace_back(get(i));
|
||
|
if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return result;
|
||
|
}
|
||
|
};
|