29 lines
1.2 KiB
C++
29 lines
1.2 KiB
C++
#include "base.h"
|
|
#include "oai.h"
|
|
|
|
namespace humanus {
|
|
|
|
// Cache of embedding model instances shared across callers, keyed by the config
// name passed to EmbeddingModel::get_instance(), which fills it on first request.
std::unordered_map<std::string, std::shared_ptr<EmbeddingModel>> EmbeddingModel::instances_{};
|
|
|
|
std::shared_ptr<EmbeddingModel> EmbeddingModel::get_instance(const std::string& config_name, const std::shared_ptr<EmbeddingModelConfig>& config) {
|
|
if (instances_.find(config_name) == instances_.end()) {
|
|
auto config_ = config;
|
|
if (!config_) {
|
|
if (Config::get_instance().embedding_model().find(config_name) == Config::get_instance().embedding_model().end()) {
|
|
logger->warn("Embedding model config not found: " + config_name + ", falling back to default config");
|
|
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at("default"));
|
|
} else {
|
|
config_ = std::make_shared<EmbeddingModelConfig>(Config::get_instance().embedding_model().at(config_name));
|
|
}
|
|
}
|
|
|
|
if (config_->provider == "oai") {
|
|
instances_[config_name] = std::make_shared<OAIEmbeddingModel>(config_);
|
|
} else {
|
|
throw std::invalid_argument("Unsupported embedding model provider: " + config_->provider);
|
|
}
|
|
}
|
|
return instances_[config_name];
|
|
}
|
|
|
|
} // namespace humanus
|