#include "config.h"
#include "logger.h"
#include "toml.hpp"
#include <iostream>
#include <filesystem>

namespace humanus {

// Initialize static members
Config* Config::_instance = nullptr;
std::mutex Config::_mutex;
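
// For reference, a sketch of the TOML shape _load_initial_llm_config() (below)
// expects: each top-level table becomes one named LLM entry in _config.llm.
// Field names come from the parsing code; the section name and values are
// hypothetical examples, not taken from the project:
//
//   [qwen]                       # hypothetical section name -> key in _config.llm
//   model = "qwen2.5-72b-instruct"
//   api_key = "sk-..."
//   base_url = "https://example.com"
//   endpoint = "/v1/chat/completions"
//   vision_details = "auto"
//   max_tokens = 4096
//   timeout = 120
//   temperature = 0.7
//   enable_vision = false
//   oai_tool_support = false
//
//   # Read only when oai_tool_support = false:
//   [qwen.tool_parser]
//   tool_start = "<tool_call>"
//   tool_end = "</tool_call>"
//   tool_hint_template = "..."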
void Config::_load_initial_llm_config() {
    try {
        auto config_path = _get_llm_config_path();
        std::cout << "Loading LLM config file from: " << config_path.string() << std::endl;

        const auto& data = toml::parse_file(config_path.string());

        // Load LLM configuration
        for (const auto& [key, value] : data) {
            const auto& llm_table = *value.as_table();

            LLMConfig llm_config;

            if (llm_table.contains("model") && llm_table["model"].is_string()) {
                llm_config.model = llm_table["model"].as_string()->get();
            }

            if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
                llm_config.api_key = llm_table["api_key"].as_string()->get();
            }

            if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
                llm_config.base_url = llm_table["base_url"].as_string()->get();
            }

            if (llm_table.contains("endpoint") && llm_table["endpoint"].is_string()) {
                llm_config.endpoint = llm_table["endpoint"].as_string()->get();
            }

            if (llm_table.contains("vision_details") && llm_table["vision_details"].is_string()) {
                llm_config.vision_details = llm_table["vision_details"].as_string()->get();
            }

            if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
                llm_config.max_tokens = llm_table["max_tokens"].as_integer()->get();
            }

            if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
                llm_config.timeout = llm_table["timeout"].as_integer()->get();
            }

            if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
                llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
            }

            if (llm_table.contains("enable_vision") && llm_table["enable_vision"].is_boolean()) {
                llm_config.enable_vision = llm_table["enable_vision"].as_boolean()->get();
            }

            if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
                llm_config.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
            }

            _config.llm[std::string(key.str())] = llm_config;

            if (!llm_config.oai_tool_support) {
                // Load tool helper configuration
                ToolParser tool_parser;
                if (llm_table.contains("tool_parser") && llm_table["tool_parser"].is_table()) {
                    const auto& tool_parser_table = *llm_table["tool_parser"].as_table();
                    if (tool_parser_table.contains("tool_start") && tool_parser_table["tool_start"].is_string()) {
                        tool_parser.tool_start = tool_parser_table["tool_start"].as_string()->get();
                    }
                    if (tool_parser_table.contains("tool_end") && tool_parser_table["tool_end"].is_string()) {
                        tool_parser.tool_end = tool_parser_table["tool_end"].as_string()->get();
                    }
                    if (tool_parser_table.contains("tool_hint_template") && tool_parser_table["tool_hint_template"].is_string()) {
                        tool_parser.tool_hint_template = tool_parser_table["tool_hint_template"].as_string()->get();
                    }
                }
                _config.tool_parser[std::string(key.str())] = tool_parser;
            }
        }

        if (_config.llm.empty()) {
            throw std::runtime_error("No LLM configuration found");
        } else if (_config.llm.find("default") == _config.llm.end()) {
            _config.llm["default"] = _config.llm.begin()->second;
        }

        if (_config.tool_parser.find("default") == _config.tool_parser.end()) {
            _config.tool_parser["default"] = ToolParser();
        }
    } catch (const std::exception& e) {
        std::cerr << "Loading config file failed: " << e.what() << std::endl;
        // Set default configuration
        _config.llm["default"] = LLMConfig();
        _config.tool_parser["default"] = ToolParser();
    }
}
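
// For reference, a sketch of the TOML shape _load_initial_embedding_model_config()
// (below) expects. Field names come from the parsing code; the section name and
// values are hypothetical examples, not taken from the project:
//
//   [bge-m3]                     # hypothetical section name -> key in _config.embedding_model
//   provider = "oai"
//   base_url = "http://localhost:8080"
//   endpoint = "/v1/embeddings"
//   model = "bge-m3"
//   api_key = ""
//   embedding_dims = 1024
//   max_retries = 3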
void Config::_load_initial_embedding_model_config() {
    try {
        auto config_path = _get_embedding_model_config_path();
        std::cout << "Loading embedding model config file from: " << config_path.string() << std::endl;

        const auto& data = toml::parse_file(config_path.string());

        // Load embedding model configuration
        for (const auto& [key, value] : data) {
            const auto& embd_table = *value.as_table();

            EmbeddingModelConfig embd_config;

            if (embd_table.contains("provider") && embd_table["provider"].is_string()) {
                embd_config.provider = embd_table["provider"].as_string()->get();
            }

            if (embd_table.contains("base_url") && embd_table["base_url"].is_string()) {
                embd_config.base_url = embd_table["base_url"].as_string()->get();
            }

            if (embd_table.contains("endpoint") && embd_table["endpoint"].is_string()) {
                embd_config.endpoint = embd_table["endpoint"].as_string()->get();
            }

            if (embd_table.contains("model") && embd_table["model"].is_string()) {
                embd_config.model = embd_table["model"].as_string()->get();
            }

            if (embd_table.contains("api_key") && embd_table["api_key"].is_string()) {
                embd_config.api_key = embd_table["api_key"].as_string()->get();
            }

            if (embd_table.contains("embedding_dims") && embd_table["embedding_dims"].is_integer()) {
                embd_config.embedding_dims = embd_table["embedding_dims"].as_integer()->get();
            }

            if (embd_table.contains("max_retries") && embd_table["max_retries"].is_integer()) {
                embd_config.max_retries = embd_table["max_retries"].as_integer()->get();
            }

            _config.embedding_model[std::string(key.str())] = embd_config;
        }

        if (_config.embedding_model.empty()) {
            throw std::runtime_error("No embedding model configuration found");
        } else if (_config.embedding_model.find("default") == _config.embedding_model.end()) {
            _config.embedding_model["default"] = _config.embedding_model.begin()->second;
        }
    } catch (const std::exception& e) {
        std::cerr << "Loading embedding model config file failed: " << e.what() << std::endl;
        // Set default configuration
        _config.embedding_model["default"] = EmbeddingModelConfig();
    }
}
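
// For reference, a sketch of the TOML shape _load_initial_vector_store_config()
// (below) expects. Only "L2" and "IP" are accepted for metric (see the parsing
// code); the section name and values are hypothetical examples:
//
//   [hnswlib]                    # hypothetical section name -> key in _config.vector_store
//   provider = "hnswlib"
//   dim = 1024                   # vector dimension (typically the embedding model's embedding_dims)
//   max_elements = 10000
//   M = 16
//   ef_construction = 200
//   metric = "L2"                # "L2" or "IP"; anything else throws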
void Config::_load_initial_vector_store_config() {
    try {
        auto config_path = _get_vector_store_config_path();
        std::cout << "Loading vector store config file from: " << config_path.string() << std::endl;

        const auto& data = toml::parse_file(config_path.string());

        // Load vector store configuration
        for (const auto& [key, value] : data) {
            const auto& vs_table = *value.as_table();

            VectorStoreConfig vs_config;

            if (vs_table.contains("provider") && vs_table["provider"].is_string()) {
                vs_config.provider = vs_table["provider"].as_string()->get();
            }

            if (vs_table.contains("dim") && vs_table["dim"].is_integer()) {
                vs_config.dim = vs_table["dim"].as_integer()->get();
            }

            if (vs_table.contains("max_elements") && vs_table["max_elements"].is_integer()) {
                vs_config.max_elements = vs_table["max_elements"].as_integer()->get();
            }

            if (vs_table.contains("M") && vs_table["M"].is_integer()) {
                vs_config.M = vs_table["M"].as_integer()->get();
            }

            if (vs_table.contains("ef_construction") && vs_table["ef_construction"].is_integer()) {
                vs_config.ef_construction = vs_table["ef_construction"].as_integer()->get();
            }

            if (vs_table.contains("metric") && vs_table["metric"].is_string()) {
                const auto& metric_str = vs_table["metric"].as_string()->get();
                if (metric_str == "L2") {
                    vs_config.metric = VectorStoreConfig::Metric::L2;
                } else if (metric_str == "IP") {
                    vs_config.metric = VectorStoreConfig::Metric::IP;
                } else {
                    throw std::runtime_error("Invalid metric: " + metric_str);
                }
            }

            _config.vector_store[std::string(key.str())] = vs_config;
        }

        if (_config.vector_store.empty()) {
            throw std::runtime_error("No vector store configuration found");
        } else if (_config.vector_store.find("default") == _config.vector_store.end()) {
            _config.vector_store["default"] = _config.vector_store.begin()->second;
        }
    } catch (const std::exception& e) {
        std::cerr << "Loading vector store config file failed: " << e.what() << std::endl;
        // Set default configuration
        _config.vector_store["default"] = VectorStoreConfig();
    }
}

} // namespace humanus