#ifndef HUMANUS_LLM_H
#define HUMANUS_LLM_H

#include "config.h"
#include "logger.h"
#include "schema.h"
#include "mcp/common/httplib.h"

#include <chrono>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace humanus {

class LLM {
private:
    static std::map<std::string, std::shared_ptr<LLM>> _instances;

    std::string model;
    std::string api_key;
    int max_tokens;
    double temperature;

    std::unique_ptr<httplib::Client> client_ = nullptr;

    int max_retries = 3;

    LLMSettings llm_config_;

    // Private constructor: instances may only be created through get_instance()
    LLM(const std::string& config_name, const LLMSettings& llm_config) : llm_config_(llm_config) {
        model = llm_config.model;
        api_key = llm_config.api_key;
        max_tokens = llm_config.max_tokens;
        temperature = llm_config.temperature;
        client_ = std::make_unique<httplib::Client>(llm_config.base_url);
    }

public:
    // Singleton access: one shared LLM instance per configuration name
    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const LLMSettings& llm_config = LLMSettings()) {
        if (_instances.find(config_name) == _instances.end()) {
            // std::make_shared cannot reach the private constructor, so construct directly
            _instances[config_name] = std::shared_ptr<LLM>(new LLM(config_name, llm_config));
        }
        return _instances[config_name];
    }

    /**
     * @brief Format a list of messages into the structure the LLM API accepts
     * @param messages List of Message objects
     * @return The formatted message list
     * @throws std::invalid_argument If a message is malformed or missing a required field
     * @throws std::runtime_error If the message type is not supported
     */
    static std::vector<json> format_messages(const std::vector<Message>& messages) {
        std::vector<json> formatted_messages;

        for (const auto& message : messages) {
            formatted_messages.push_back(message.to_json());
        }

        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            if (!message.contains("content") && !message.contains("tool_calls")) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }

        return formatted_messages;
    }

    /**
     * @brief Format a list of messages into the structure the LLM API accepts
     * @param messages List of json message objects
     * @return The formatted message list
     * @throws std::invalid_argument If a message is malformed or missing a required field
     * @throws std::runtime_error If the message type is not supported
     */
    static std::vector<json> format_messages(const std::vector<json>& messages) {
        std::vector<json> formatted_messages;

        for (const auto& message : messages) {
            if (!message.contains("role")) {
                throw std::invalid_argument("Message is missing required field: role");
            }
            formatted_messages.push_back(message);
        }

        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            if (!message.contains("content") && !message.contains("tool_calls")) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }

        return formatted_messages;
    }

    /**
     * @brief Send a request to the LLM and return the reply
     * @param messages List of conversation messages
     * @param system_msgs Optional system messages
     * @param max_retries Maximum number of retries
     * @return The generated assistant content
     * @throws std::invalid_argument If the messages are invalid or the reply is empty
     * @throws std::runtime_error If the API call fails
     */
    std::string ask(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        int max_retries = 3
    ) {
        std::vector<json> formatted_messages;

        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }

        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        json body = {
            {"model", model},
            {"messages", formatted_messages},
            {"temperature", temperature},
            {"max_tokens", max_tokens}
        };
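        // For reference, body.dump() yields an OpenAI-compatible
        // chat-completions payload; the values below are illustrative only:
        //
        //   {
        //     "model": "gpt-4o",
        //     "messages": [
        //       {"role": "system", "content": "You are a helpful assistant."},
        //       {"role": "user", "content": "Hello!"}
        //     ],
        //     "temperature": 1.0,
        //     "max_tokens": 4096
        //   }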
        std::string body_str = body.dump();

        httplib::Headers headers = {
            {"Authorization", "Bearer " + api_key}
        };

        int retry = 0;

        while (retry <= max_retries) {
            // Send the request
            auto res = client_->Post(llm_config_.end_point, headers, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"]["content"].get<std::string>();
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;

            // Wait for a while before retrying
            std::this_thread::sleep_for(std::chrono::milliseconds(500));

            logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
        }

        throw std::runtime_error("Failed to get response from LLM");
    }

    /**
     * @brief Send a request to the LLM with tool-calling support
     * @param messages List of conversation messages
     * @param system_msgs Optional system messages
     * @param tools List of tools
     * @param tool_choice Tool choice strategy ("none", "auto" or "required")
     * @param timeout Request timeout in seconds
     * @return The generated assistant message (content, tool_calls)
     * @throws std::invalid_argument If the tools, tool choice, or messages are invalid
     * @throws std::runtime_error If the API call fails
     */
    json ask_tool(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        const std::vector<json>& tools = {},
        const std::string& tool_choice = "auto",
        int timeout = 60
    ) {
        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
            throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
        }

        std::vector<json> formatted_messages;

        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }

        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        if (!tools.empty()) {
            for (const json& tool : tools) {
                if (!tool.contains("type")) {
                    throw std::invalid_argument("Tool must contain 'type' field");
                }
            }
        }

        json body = {
            {"model", model},
            {"messages", formatted_messages},
            {"temperature", temperature},
            {"max_tokens", max_tokens},
            {"tools", tools},
            {"tool_choice", tool_choice}
        };

        client_->set_read_timeout(timeout);

        std::string body_str = body.dump();

        httplib::Headers headers = {
            {"Authorization", "Bearer " + api_key}
        };

        int retry = 0;

        while (retry <= max_retries) {
            // Send the request
            auto res = client_->Post(llm_config_.end_point, headers, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"];
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;

            // Wait for a while before retrying
            std::this_thread::sleep_for(std::chrono::milliseconds(500));

            logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
        }

        throw std::runtime_error("Failed to get response from LLM");
    }
};

} // namespace humanus

#endif // HUMANUS_LLM_H
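
// ---------------------------------------------------------------------------
// Example usage (a minimal sketch, kept in a comment so it is not compiled as
// part of this header). It assumes the LLMSettings fields used above
// (base_url, end_point, model, api_key, max_tokens, temperature) are publicly
// assignable; Message::user_message(...) is a hypothetical factory standing in
// for however schema.h constructs a user Message:
//
//   using namespace humanus;
//
//   LLMSettings settings;
//   settings.base_url    = "https://api.openai.com";
//   settings.end_point   = "/v1/chat/completions";
//   settings.model       = "gpt-4o";
//   settings.api_key     = "sk-...";
//   settings.max_tokens  = 4096;
//   settings.temperature = 1.0;
//
//   auto llm = LLM::get_instance("default", settings);
//
//   // Plain completion
//   std::string reply = llm->ask({Message::user_message("Hello!")});
//
//   // Tool calling, using the OpenAI-style "function" tool shape
//   json weather_tool = {
//       {"type", "function"},
//       {"function", {
//           {"name", "get_weather"},
//           {"description", "Look up the weather for a city"},
//           {"parameters", {
//               {"type", "object"},
//               {"properties", {{"city", {{"type", "string"}}}}},
//               {"required", {"city"}}
//           }}
//       }}
//   };
//   json msg = llm->ask_tool(
//       {Message::user_message("Weather in Paris?")},
//       {}, {weather_tool}, "auto");
// ---------------------------------------------------------------------------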