#ifndef HUMANUS_LLM_H
#define HUMANUS_LLM_H

#include "config.h"
#include "logger.h"
#include "schema.h"
#include "mcp/common/httplib.h"

#include <map>
#include <string>
#include <memory>
#include <vector>
#include <functional>
#include <stdexcept>
#include <future>
#include <thread>   // std::this_thread::sleep_for
#include <chrono>   // std::chrono::milliseconds
namespace humanus {

class LLM {
private:
    static std::map<std::string, std::shared_ptr<LLM>> _instances;
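    // NOTE: a non-inline static data member needs exactly one out-of-line
    // definition in a source file, e.g. (filename assumed):
    //   std::map<std::string, std::shared_ptr<LLM>> LLM::_instances;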

    std::unique_ptr<httplib::Client> client_;

    LLMSettings llm_config_;

    // Private constructor to prevent direct instantiation (use get_instance)
    LLM(const std::string& config_name, const LLMSettings& llm_config) : llm_config_(llm_config) {
        client_ = std::make_unique<httplib::Client>(llm_config.base_url);
        client_->set_default_headers({
            {"Authorization", "Bearer " + llm_config_.api_key}
        });
    }

public:
    // Singleton-style accessor: one shared LLM instance per config name
    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const LLMSettings& llm_config = LLMSettings()) {
        if (_instances.find(config_name) == _instances.end()) {
            // std::make_shared cannot reach the private constructor, so construct directly
            _instances[config_name] = std::shared_ptr<LLM>(new LLM(config_name, llm_config));
        }
        return _instances[config_name];
    }
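
    // Usage sketch (the LLMSettings fields shown are those this header
    // actually reads; in practice they are populated from config.h):
    //   LLMSettings settings;
    //   settings.base_url = "https://api.openai.com";
    //   settings.api_key  = "sk-...";
    //   auto llm = LLM::get_instance("default", settings);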

    /**
     * @brief Format a list of Message objects into the shape the LLM API accepts
     * @param messages List of Message objects
     * @return The formatted message list
     * @throws std::invalid_argument If a message has an unsupported role or lacks both 'content' and 'tool_calls'
     */
    static std::vector<json> format_messages(const std::vector<Message>& messages) {
        std::vector<json> formatted_messages;

        for (const auto& message : messages) {
            formatted_messages.push_back(message.to_json());
        }

        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            // Guard with contains(): operator[] on a missing key of a const json is undefined
            if ((!message.contains("content") || message["content"].empty()) &&
                (!message.contains("tool_calls") || message["tool_calls"].empty())) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }

        return formatted_messages;
    }
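
    // A formatted message is a plain json object, e.g. (shape inferred from
    // the checks above):
    //   {"role": "assistant", "content": "...", "tool_calls": [...]}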

    /**
     * @brief Format a list of raw json messages into the shape the LLM API accepts
     * @param messages List of json message objects
     * @return The formatted message list
     * @throws std::invalid_argument If a message is missing 'role', has an unsupported role, or lacks both 'content' and 'tool_calls'
     */
    static std::vector<json> format_messages(const std::vector<json>& messages) {
        std::vector<json> formatted_messages;

        for (const auto& message : messages) {
            if (!message.contains("role")) {
                throw std::invalid_argument("Message missing required field: role");
            }
            formatted_messages.push_back(message);
        }

        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            // Guard with contains(): operator[] on a missing key of a const json is undefined
            if ((!message.contains("content") || message["content"].empty()) &&
                (!message.contains("tool_calls") || message["tool_calls"].empty())) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }

        return formatted_messages;
    }

    /**
     * @brief Send a request to the LLM and return the reply
     * @param messages List of conversation messages
     * @param system_msgs Optional system messages, prepended to the conversation
     * @param max_retries Maximum number of retries
     * @return The generated assistant content
     * @throws std::invalid_argument If the messages are invalid
     * @throws std::runtime_error If the API call fails after all retries
     */
    std::string ask(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        int max_retries = 3
    ) {
        std::vector<json> formatted_messages;

        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }

        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        json body = {
            {"model", llm_config_.model},
            {"messages", formatted_messages},
            {"temperature", llm_config_.temperature},
            {"max_tokens", llm_config_.max_tokens}
        };
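
        // The body mirrors an OpenAI-compatible /chat/completions request;
        // llm_config_.end_point presumably points at such an endpoint.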

        std::string body_str = body.dump();

        int retry = 0;

        while (retry <= max_retries) {
            // Send the request
            auto res = client_->Post(llm_config_.end_point, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"]["content"].get<std::string>();
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;

            if (retry <= max_retries) {
                // Wait briefly before retrying; skip the wait and the
                // misleading "Retrying" log once the budget is exhausted
                std::this_thread::sleep_for(std::chrono::milliseconds(500));
                logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
            }
        }

        throw std::runtime_error("Failed to get response from LLM");
    }
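
    // Usage sketch ("conversation" is a hypothetical std::vector<Message>
    // built via schema.h):
    //   auto llm = LLM::get_instance();
    //   std::string reply = llm->ask(conversation);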

    /**
     * @brief Send a request to the LLM with tool (function-calling) support
     * @param messages List of conversation messages
     * @param system_msgs Optional system messages, prepended to the conversation
     * @param tools List of tool definitions
     * @param tool_choice Tool choice strategy ("none", "auto" or "required")
     * @param timeout Request read timeout in seconds
     * @param max_retries Maximum number of retries
     * @return The generated assistant message (content, tool_calls)
     * @throws std::invalid_argument If the tools, tool choice or messages are invalid
     * @throws std::runtime_error If the API call fails after all retries
     */
    json ask_tool(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        const std::vector<json>& tools = {},
        const std::string& tool_choice = "auto",
        int timeout = 60,
        int max_retries = 3
    ) {
        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
            throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
        }

        std::vector<json> formatted_messages;

        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }

        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        if (!tools.empty()) {
            for (const json& tool : tools) {
                if (!tool.contains("type")) {
                    throw std::invalid_argument("Tool must contain 'type' field");
                }
            }
        }

        json body = {
            {"model", llm_config_.model},
            {"messages", formatted_messages},
            {"temperature", llm_config_.temperature},
            {"max_tokens", llm_config_.max_tokens},
            {"tools", tools},
            {"tool_choice", tool_choice}
        };

        client_->set_read_timeout(timeout);

        std::string body_str = body.dump();

        int retry = 0;

        while (retry <= max_retries) {
            // Send the request
            auto res = client_->Post(llm_config_.end_point, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"];
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;

            if (retry <= max_retries) {
                // Wait briefly before retrying; skip the wait and the
                // misleading "Retrying" log once the budget is exhausted
                std::this_thread::sleep_for(std::chrono::milliseconds(500));
                logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
            }
        }

        throw std::runtime_error("Failed to get response from LLM");
    }
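
    // Usage sketch with an OpenAI-style function tool (this exact schema is an
    // assumption; only the presence of a "type" field is enforced above):
    //   json tool = {
    //       {"type", "function"},
    //       {"function", {
    //           {"name", "get_weather"},
    //           {"description", "..."},
    //           {"parameters", { /* JSON Schema */ }}
    //       }}
    //   };
    //   json assistant_message = llm->ask_tool(messages, {}, {tool});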

};

} // namespace humanus

#endif // HUMANUS_LLM_H