#include "hnswlib/hnswlib.h"
#include "hnswlib.h"

#include <chrono>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// NOTE(review): the reviewed text of this file had lost everything written
// between angle brackets (system includes and template arguments). They are
// restored here as: float vectors, hnswlib::L2Space / hnswlib::InnerProductSpace,
// hnswlib::HierarchicalNSW<float>, std::vector<MemoryItem>. Confirm against the
// member declarations in hnswlib.h, and re-add any system include that is
// missing from the list above.

namespace humanus::mem0 {

namespace {

/// Raw tick count of std::chrono::system_clock since its epoch. This is the
/// value stored in the "created_at" / "updated_at" metadata fields.
auto now_ticks() {
    return std::chrono::system_clock::now().time_since_epoch().count();
}

/// Copies the stored metadata into `item`: the "data" field becomes
/// item.memory, the "hash" field becomes item.hash (each only if present),
/// and the complete metadata object becomes item.metadata.
void apply_metadata(MemoryItem& item, const json& meta) {
    if (meta.contains("data")) {
        item.memory = meta.at("data");
    }
    if (meta.contains("hash")) {
        item.hash = meta.at("hash");
    }
    item.metadata = meta;
}

} // namespace

/// Discards the current index, distance space and all metadata, then builds
/// an empty index from config_ (dim, max_elements, M, ef_construction).
/// @throws std::invalid_argument if config_->metric is neither L2 nor IP.
///         In that case the store is left without an index.
void HNSWLibVectorStore::reset() {
    // The index is handed a raw pointer to the space (space.get()), so the
    // index is released first.
    hnsw.reset();
    space.reset();
    metadata_store.clear();

    if (config_->metric == VectorStoreConfig::Metric::L2) {
        space = std::make_shared<hnswlib::L2Space>(config_->dim);
    } else if (config_->metric == VectorStoreConfig::Metric::IP) {
        space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
    } else {
        throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config_->metric)));
    }

    hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(
        space.get(), config_->max_elements, config_->M, config_->ef_construction);
}

/// Adds `vector` to the index under label `vector_id` and stores a copy of
/// `metadata` for it. "created_at" and "updated_at" are filled in with the
/// current time unless the caller already supplied them.
/// NOTE(review): `vector` is passed to hnswlib as a raw pointer; presumably it
/// must hold config_->dim elements — confirm callers guarantee this.
void HNSWLibVectorStore::insert(const std::vector<float>& vector, const size_t vector_id, const json& metadata) {
    hnsw->addPoint(vector.data(), vector_id);

    // Store the metadata alongside the vector.
    const auto now = now_ticks();
    json stored = metadata;
    if (!stored.contains("created_at")) {
        stored["created_at"] = now;
    }
    if (!stored.contains("updated_at")) {
        stored["updated_at"] = now;
    }
    metadata_store[vector_id] = stored;
}

/// Runs a k-nearest-neighbour query and returns one MemoryItem per hit that
/// has stored metadata; hits without metadata are skipped.
/// @param query   Query vector.
/// @param limit   Maximum number of neighbours requested; values <= 0 yield
///                an empty result.
/// @param filters Currently not applied by this function.
/// @return Items in the order they are popped from hnswlib's result queue,
///         which puts the farthest hit on top; item.score is the distance
///         reported by hnswlib.
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query, int limit = 5, const std::string& filters = "") {
    std::vector<MemoryItem> memory_items;
    if (limit <= 0) {
        // A negative int would otherwise be converted to a huge unsigned k.
        return memory_items;
    }

    auto results = hnsw->searchKnn(query.data(), static_cast<size_t>(limit));

    while (!results.empty()) {
        // hnswlib stores (distance, label) pairs. Copy the pair: a reference
        // obtained from top() dangles as soon as pop() runs.
        const auto [distance, label] = results.top();
        results.pop();

        const auto found = metadata_store.find(label);
        if (found == metadata_store.end()) {
            continue;
        }

        MemoryItem item;
        item.id = label;
        apply_metadata(item, found->second);
        item.score = distance;
        memory_items.push_back(item);
    }
    return memory_items;
}

/// Marks the vector labelled `vector_id` as deleted in the index and drops
/// its metadata. If markDelete throws, the metadata is left in place.
void HNSWLibVectorStore::delete_vector(size_t vector_id) {
    hnsw->markDelete(vector_id);
    metadata_store.erase(vector_id);
}

/// Updates the vector and/or the metadata stored under `vector_id`.
/// @param vector   New vector data; an empty vector leaves the index untouched.
/// @param metadata Fields to merge into the existing metadata. If no metadata
///                 exists yet and this is non-empty, a new entry is created.
/// "updated_at" is refreshed whenever an existing entry is touched, even when
/// `metadata` is empty.
void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float> vector = std::vector<float>(), const json& metadata = json::object()) {
    // Replace the vector only when new data was supplied.
    // NOTE(review): markDelete presumably fails for a label that is not in the
    // index, so this path cannot be used to insert a new vector — confirm.
    if (!vector.empty()) {
        hnsw->markDelete(vector_id);
        hnsw->addPoint(vector.data(), vector_id);
    }

    const auto now = now_ticks();
    const auto existing = metadata_store.find(vector_id);

    if (existing != metadata_store.end()) {
        // Merge the new fields over the existing ones, then bump the timestamp.
        for (auto& [key, value] : metadata.items()) {
            existing->second[key] = value;
        }
        existing->second["updated_at"] = now;
    } else if (!metadata.empty()) {
        // No entry yet but metadata was provided: create a new entry.
        json new_metadata = metadata;
        if (!new_metadata.contains("created_at")) {
            new_metadata["created_at"] = now;
        }
        new_metadata["updated_at"] = now;
        metadata_store[vector_id] = new_metadata;
    }
}

/// Returns the MemoryItem for label `vector_id`. item.id is always set;
/// memory, hash and metadata are filled only when metadata is stored for it.
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
    MemoryItem item;
    item.id = vector_id;

    // The vector data is not copied into the item. The lookup is kept so that
    // whatever hnswlib does for a label it cannot serve (presumably throwing)
    // still happens here.
    (void)hnsw->getDataByLabel<float>(vector_id);

    const auto found = metadata_store.find(vector_id);
    if (found != metadata_store.end()) {
        apply_metadata(item, found->second);
    }
    return item;
}

/// Lists the items that are not marked deleted in the index.
/// @param filters If non-empty, an item that has metadata is kept only when
///                the dumped metadata JSON contains this substring; items
///                without metadata are always kept.
/// @param limit   Maximum number of items to return; 0 or less means no limit.
std::vector<MemoryItem> HNSWLibVectorStore::list(const std::string& filters = "", int limit = 0) {
    std::vector<MemoryItem> result;
    const size_t count = hnsw->cur_element_count;

    for (size_t internal_id = 0; internal_id < count; ++internal_id) {
        if (hnsw->isMarkedDeleted(internal_id)) {
            continue;
        }

        // The loop walks internal slots; metadata and get() are keyed by the
        // external label the caller supplied on insert.
        const size_t label = hnsw->getExternalLabel(internal_id);

        if (!filters.empty()) {
            const auto found = metadata_store.find(label);
            // Simple substring match on the serialized metadata; extend as needed.
            if (found != metadata_store.end() &&
                found->second.dump().find(filters) == std::string::npos) {
                continue;
            }
        }

        result.emplace_back(get(label));
        if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
            break;
        }
    }
    return result;
}

} // namespace humanus::mem0