humanus.cpp/memory/mem0/vector_store/hnswlib.h

125 lines
3.9 KiB
C
Raw Normal View History

2025-03-23 14:35:54 +08:00
#ifndef HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
#define HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "hnswlib/hnswlib.h"
namespace humanus {
namespace mem0 {
struct HNSWLIBVectorStore {
VectorStoreConfig config;
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
HNSWLIBVectorStore(const VectorStoreConfig& config) : config(config) {
reset();
}
void reset() {
if (hnsw) {
hnsw.reset();
}
if (config.metric == Metric::L2) {
hnswlib::L2Space space(config.dim);
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>(&space, config.max_elements, config.M, config.ef_construction);
} else if (config.metric == Metric::IP) {
hnswlib::InnerProductSpace space(config.dim);
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>(&space, config.max_elements, config.M, config.ef_construction);
} else {
throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config.metric)));
}
}
/**
* @brief
* @param vectors
* @param payloads
* @param ids ID
* @return ID
*/
std::vector<size_t> insert(const std::vector<std::vector<float>>& vectors,
const std::vector<std::string>& payloads = {},
const std::vector<size_t>& ids = {}) {
std::vector<size_t> result_ids;
for (size_t i = 0; i < vectors.size(); i++) {
size_t id = ids.size() > i ? ids[i] : hnsw->cur_element_count;
hnsw->addPoint(vectors[i].data(), id);
result_ids.push_back(id);
}
return result_ids;
}
/**
* @brief
* @param query
* @param limit
* @param filters
* @return ID
*/
std::vector<std::pair<size_t, std::vector<float>>> search(const std::vector<float>& query,
int limit = 5,
const std::string& filters = "") {
return hnsw->searchKnn(query.data(), limit);
}
/**
* @brief ID
* @param vector_id ID
*/
void delete_vector(size_t vector_id) {
hnsw->markDelete(vector_id);
}
/**
* @brief
* @param vector_id ID
* @param vector
* @param payload
*/
void update(size_t vector_id,
const std::vector<float>* vector = nullptr,
const std::string* payload = nullptr) {
if (vector) {
hnsw->markDelete(vector_id);
hnsw->addPoint(vector->data(), vector_id);
}
}
/**
* @brief ID
* @param vector_id ID
* @return
*/
std::vector<float> get(size_t vector_id) {
std::vector<float> result(config.dimension);
hnsw->getDataByLabel(vector_id, result.data());
return result;
}
/**
* @brief
* @param filters
* @param limit
* @return ID
*/
std::vector<size_t> list(const std::string& filters = "", int limit = 0) {
std::vector<size_t> result;
size_t count = hnsw->cur_element_count;
for (size_t i = 0; i < count; i++) {
if (!hnsw->isMarkedDeleted(i)) {
result.push_back(i);
if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
break;
}
}
}
return result;
}
};
}
}
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H