#ifndef HUMANUS_LLM_H
#define HUMANUS_LLM_H

#include "config.h"
#include "logger.h"
#include "schema.h"
#include "mcp/common/httplib.h"

#include <chrono>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace humanus {

class LLM {
private:
    static std::map<std::string, std::shared_ptr<LLM>> _instances;

    std::unique_ptr<httplib::Client> client_;

    std::shared_ptr<LLMConfig> llm_config_;

public:
    // Constructor
    LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& llm_config = nullptr) : llm_config_(llm_config) {
        if (!llm_config_) {
            if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
                throw std::invalid_argument("Config not found: " + config_name);
            }
            llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
        }
        client_ = std::make_unique<httplib::Client>(llm_config_->base_url);
        client_->set_default_headers({
            {"Authorization", "Bearer " + llm_config_->api_key}
        });
    }

    // Get the shared instance for the given config name, creating it on first use
    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMConfig>& llm_config = nullptr) {
        if (_instances.find(config_name) == _instances.end()) {
            _instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
        }
        return _instances[config_name];
    }

    /**
     * @brief Format the message list to the format that the LLM can accept
     * @param messages Message object message list
     * @return The formatted message list
     * @throws std::invalid_argument If the message format is invalid or missing necessary fields
     * @throws std::runtime_error If the message type is not supported
     */
    static std::vector<json> format_messages(const std::vector<Message>& messages) {
        std::vector<json> formatted_messages;

        for (const auto& message : messages) {
            formatted_messages.push_back(message.to_json());
        }

        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            if (message["content"].empty() && message["tool_calls"].empty()) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }

        return formatted_messages;
    }

    /**
     * @brief Format the message list to the format that the LLM can accept
     * @param messages json object message list
     * @return The formatted message list
     * @throws std::invalid_argument If the message format is invalid or missing necessary fields
     * @throws std::runtime_error If the message type is not supported
     */
    static std::vector<json> format_messages(const std::vector<json>& messages) {
        std::vector<json> formatted_messages;

        for (const auto& message : messages) {
            if (!message.contains("role")) {
                throw std::invalid_argument("Message missing necessary field: role");
            }
            formatted_messages.push_back(message);
        }

        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            if (message["content"].empty() && message["tool_calls"].empty()) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }

        return formatted_messages;
    }

    /**
     * @brief Send a request to the LLM and get the reply
     * @param messages The conversation message list
     * @param system_msgs Optional system messages
     * @param max_retries The maximum number of retries
     * @return The generated assistant content
     * @throws std::invalid_argument If the message is invalid or the reply is empty
     * @throws std::runtime_error If the API call fails
     */
    std::string ask(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        int max_retries = 3
    ) {
        std::vector<json> formatted_messages;

        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }

        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        json body = {
            {"model", llm_config_->model},
            {"messages", formatted_messages},
            {"temperature", llm_config_->temperature},
            {"max_tokens", llm_config_->max_tokens}
        };

        std::string body_str = body.dump();

        int retry = 0;

        while (retry <= max_retries) {
            // send request
            auto res = client_->Post(llm_config_->end_point, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"]["content"].get<std::string>();
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;

            if (retry > max_retries) {
                break;
            }

            // wait for a while before retrying
            std::this_thread::sleep_for(std::chrono::milliseconds(500));

            logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
        }

        throw std::runtime_error("Failed to get response from LLM");
    }
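    // Illustrative usage of get_instance() and ask() (a minimal sketch, not part of
    // this header). The "default" config name and the Message::user_message() factory
    // are assumptions about config.h / schema.h rather than guarantees of this file.
    //
    //   auto llm = LLM::get_instance("default");
    //   std::vector<Message> messages = { Message::user_message("Hello!") };
    //   std::string reply = llm->ask(messages);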
    /**
     * @brief Send a request to the LLM with tool functions
     * @param messages The conversation message list
     * @param system_msgs Optional system messages
     * @param tools The tool list
     * @param tool_choice The tool choice strategy ("none", "auto" or "required")
     * @param timeout The request timeout (seconds)
     * @param max_retries The maximum number of retries
     * @return The generated assistant message (content, tool_calls)
     * @throws std::invalid_argument If the tool, tool choice or message is invalid
     * @throws std::runtime_error If the API call fails
     */
    json ask_tool(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        const std::vector<json> tools = {},
        const std::string& tool_choice = "auto",
        int timeout = 60,
        int max_retries = 3
    ) {
        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
            throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
        }

        std::vector<json> formatted_messages;

        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }

        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        if (!tools.empty()) {
            for (const json& tool : tools) {
                if (!tool.contains("type")) {
                    throw std::invalid_argument("Tool must contain 'type' field");
                }
            }
        }

        json body = {
            {"model", llm_config_->model},
            {"messages", formatted_messages},
            {"temperature", llm_config_->temperature},
            {"max_tokens", llm_config_->max_tokens},
            {"tools", tools},
            {"tool_choice", tool_choice}
        };

        client_->set_read_timeout(timeout);

        std::string body_str = body.dump();

        int retry = 0;

        while (retry <= max_retries) {
            // send request
            auto res = client_->Post(llm_config_->end_point, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"];
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;

            if (retry > max_retries) {
                break;
            }

            // wait for a while before retrying
            std::this_thread::sleep_for(std::chrono::milliseconds(500));

            logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
        }

        throw std::runtime_error("Failed to get response from LLM");
    }
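    // Illustrative usage of ask_tool() (a minimal sketch). The tool definition below
    // follows the OpenAI-compatible chat-completions schema this endpoint is assumed
    // to accept; the "get_weather" tool and its parameters are hypothetical.
    //
    //   std::vector<json> tools = {{
    //       {"type", "function"},
    //       {"function", {
    //           {"name", "get_weather"},
    //           {"description", "Get the current weather for a city"},
    //           {"parameters", {
    //               {"type", "object"},
    //               {"properties", {{"city", {{"type", "string"}}}}},
    //               {"required", {"city"}}
    //           }}
    //       }}
    //   }};
    //   json reply = llm->ask_tool(messages, {}, tools, "auto");
    //   if (reply.contains("tool_calls") && !reply["tool_calls"].empty()) {
    //       // execute each call and append the results as "tool" messages
    //   }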
};

} // namespace humanus

#endif // HUMANUS_LLM_H