125 lines
3.9 KiB
C++
125 lines
3.9 KiB
C++
#ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
|
#define HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
|
|
|
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "hnswlib/hnswlib.h"
|
|
|
|
namespace humanus {
|
|
|
|
namespace mem0 {
|
|
|
|
struct HNSWLIBVectorStore {
|
|
VectorStoreConfig config;
|
|
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
|
|
|
|
HNSWLIBVectorStore(const VectorStoreConfig& config) : config(config) {
|
|
reset();
|
|
}
|
|
|
|
void reset() {
|
|
if (hnsw) {
|
|
hnsw.reset();
|
|
}
|
|
if (config.metric == Metric::L2) {
|
|
hnswlib::L2Space space(config.dim);
|
|
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>(&space, config.max_elements, config.M, config.ef_construction);
|
|
} else if (config.metric == Metric::IP) {
|
|
hnswlib::InnerProductSpace space(config.dim);
|
|
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>(&space, config.max_elements, config.M, config.ef_construction);
|
|
} else {
|
|
throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config.metric)));
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief 插入向量到集合中
|
|
* @param vectors 向量数据
|
|
* @param payloads 可选的负载数据
|
|
* @param ids 可选的ID列表
|
|
* @return 插入的向量ID列表
|
|
*/
|
|
std::vector<size_t> insert(const std::vector<std::vector<float>>& vectors,
|
|
const std::vector<std::string>& payloads = {},
|
|
const std::vector<size_t>& ids = {}) {
|
|
std::vector<size_t> result_ids;
|
|
for (size_t i = 0; i < vectors.size(); i++) {
|
|
size_t id = ids.size() > i ? ids[i] : hnsw->cur_element_count;
|
|
hnsw->addPoint(vectors[i].data(), id);
|
|
result_ids.push_back(id);
|
|
}
|
|
return result_ids;
|
|
}
|
|
|
|
/**
|
|
* @brief 搜索相似向量
|
|
* @param query 查询向量
|
|
* @param limit 返回结果数量限制
|
|
* @param filters 可选的过滤条件
|
|
* @return 相似向量的ID和距离
|
|
*/
|
|
std::vector<std::pair<size_t, std::vector<float>>> search(const std::vector<float>& query,
|
|
int limit = 5,
|
|
const std::string& filters = "") {
|
|
return hnsw->searchKnn(query.data(), limit);
|
|
}
|
|
|
|
/**
|
|
* @brief 通过ID删除向量
|
|
* @param vector_id 向量ID
|
|
*/
|
|
void delete_vector(size_t vector_id) {
|
|
hnsw->markDelete(vector_id);
|
|
}
|
|
|
|
/**
|
|
* @brief 更新向量及其负载
|
|
* @param vector_id 向量ID
|
|
* @param vector 可选的新向量数据
|
|
* @param payload 可选的新负载数据
|
|
*/
|
|
void update(size_t vector_id,
|
|
const std::vector<float>* vector = nullptr,
|
|
const std::string* payload = nullptr) {
|
|
if (vector) {
|
|
hnsw->markDelete(vector_id);
|
|
hnsw->addPoint(vector->data(), vector_id);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief 通过ID获取向量
|
|
* @param vector_id 向量ID
|
|
* @return 向量数据
|
|
*/
|
|
std::vector<float> get(size_t vector_id) {
|
|
std::vector<float> result(config.dimension);
|
|
hnsw->getDataByLabel(vector_id, result.data());
|
|
return result;
|
|
}
|
|
|
|
/**
|
|
* @brief 列出所有记忆
|
|
* @param filters 可选的过滤条件
|
|
* @param limit 可选的结果数量限制
|
|
* @return 记忆ID列表
|
|
*/
|
|
std::vector<size_t> list(const std::string& filters = "", int limit = 0) {
|
|
std::vector<size_t> result;
|
|
size_t count = hnsw->cur_element_count;
|
|
for (size_t i = 0; i < count; i++) {
|
|
if (!hnsw->isMarkedDeleted(i)) {
|
|
result.push_back(i);
|
|
if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
};
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|