humanus.cpp/llm.h

#ifndef HUMANUS_LLM_H
#define HUMANUS_LLM_H
#include "config.h"
#include "logger.h"
#include "schema.h"
#include "mcp/common/httplib.h"
#include <map>
#include <string>
#include <memory>
#include <vector>
#include <functional>
#include <stdexcept>
#include <future>
#include <thread>
#include <chrono>
namespace humanus {
class LLM {
private:
    static std::map<std::string, std::shared_ptr<LLM>> _instances;

    std::unique_ptr<httplib::Client> client_;

    std::shared_ptr<LLMSettings> llm_config_;

public:
    // Constructor: loads the named config entry unless an explicit config is supplied
    LLM(const std::string& config_name, const std::shared_ptr<LLMSettings>& llm_config = nullptr) : llm_config_(llm_config) {
        if (!llm_config_) {
            if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
                throw std::invalid_argument("Config not found: " + config_name);
            }
            llm_config_ = std::make_shared<LLMSettings>(Config::get_instance().llm().at(config_name));
        }
        client_ = std::make_unique<httplib::Client>(llm_config_->base_url);
        client_->set_default_headers({
            {"Authorization", "Bearer " + llm_config_->api_key}
        });
    }
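
    // Construction sketch, bypassing the config file. The LLMSettings field
    // names below are taken from their uses in this header; default
    // constructibility of LLMSettings and the endpoint/key/model values are
    // assumptions (placeholders only):
    //
    //   auto settings = std::make_shared<LLMSettings>();
    //   settings->base_url = "https://api.openai.com";
    //   settings->end_point = "/v1/chat/completions";
    //   settings->api_key = "sk-...";
    //   settings->model = "gpt-4o-mini";
    //   settings->temperature = 0.7;
    //   settings->max_tokens = 1024;
    //   LLM llm("custom", settings);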

    // Get a shared instance per config name (singleton per configuration)
    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMSettings>& llm_config = nullptr) {
        if (_instances.find(config_name) == _instances.end()) {
            _instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
        }
        return _instances[config_name];
    }
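
    // Usage sketch: instances are cached per config name, so repeated calls
    // with the same name return the same object ("default" assumes such an
    // entry exists in the loaded config). Note the cache itself is not
    // synchronized, so concurrent first calls should be avoided:
    //
    //   auto llm = LLM::get_instance("default");
    //   auto same = LLM::get_instance("default"); // same shared_ptr as above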
    /**
     * @brief Format a list of Message objects into the JSON shape the LLM API expects
     * @param messages list of Message objects
     * @return list of JSON message objects
     * @throws std::invalid_argument if a message has an invalid role or lacks both content and tool_calls
     */
    static std::vector<json> format_messages(const std::vector<Message>& messages) {
        std::vector<json> formatted_messages;
        for (const auto& message : messages) {
            formatted_messages.push_back(message.to_json());
        }
        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            // check contains() first: const json::operator[] requires the key to exist
            if ((!message.contains("content") || message["content"].empty()) &&
                (!message.contains("tool_calls") || message["tool_calls"].empty())) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }
        return formatted_messages;
    }
    /**
     * @brief Validate raw JSON messages for the LLM API
     * @param messages list of JSON message objects
     * @return validated list of JSON message objects
     * @throws std::invalid_argument if a message lacks a role, has an invalid role, or lacks both content and tool_calls
     */
    static std::vector<json> format_messages(const std::vector<json>& messages) {
        std::vector<json> formatted_messages;
        for (const auto& message : messages) {
            if (!message.contains("role")) {
                throw std::invalid_argument("Message missing required field: role");
            }
            formatted_messages.push_back(message);
        }
        for (const auto& message : formatted_messages) {
            if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
                throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
            }
            // check contains() first: const json::operator[] requires the key to exist
            if ((!message.contains("content") || message["content"].empty()) &&
                (!message.contains("tool_calls") || message["tool_calls"].empty())) {
                throw std::invalid_argument("Message must contain either 'content' or 'tool_calls'");
            }
        }
        return formatted_messages;
    }
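
    // Example of the message shape the validators above accept (a sketch;
    // the OpenAI-compatible chat format is inferred from the role and
    // tool_calls checks):
    //
    //   json msg = {
    //       {"role", "user"},
    //       {"content", "Summarize this repository."}
    //   };
    //   auto formatted = LLM::format_messages(std::vector<json>{msg});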
    /**
     * @brief Send a chat request to the LLM
     * @param messages conversation messages
     * @param system_msgs optional system messages prepended to the conversation
     * @param max_retries maximum number of retries on request failure
     * @return the assistant's reply content
     * @throws std::invalid_argument if the messages are invalid
     * @throws std::runtime_error if the API request fails after all retries
     */
    std::string ask(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        int max_retries = 3
    ) {
        std::vector<json> formatted_messages;
        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }
        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        json body = {
            {"model", llm_config_->model},
            {"messages", formatted_messages},
            {"temperature", llm_config_->temperature},
            {"max_tokens", llm_config_->max_tokens}
        };
        std::string body_str = body.dump();

        int retry = 0;
        while (retry <= max_retries) {
            // send request
            auto res = client_->Post(llm_config_->end_point, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"]["content"].get<std::string>();
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;
            if (retry > max_retries) {
                break;
            }

            // wait for a while before retrying
            std::this_thread::sleep_for(std::chrono::milliseconds(500));
            logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
        }
        throw std::runtime_error("Failed to get response from LLM");
    }
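
    // Usage sketch for ask(). The Message construction shown is hypothetical;
    // see schema.h for the real interface:
    //
    //   auto llm = LLM::get_instance();
    //   std::string reply = llm->ask({Message("user", "Hello!")});
    //   // throws std::runtime_error after max_retries failed attempts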
    /**
     * @brief Send a chat request to the LLM with tool calling enabled
     * @param messages conversation messages
     * @param system_msgs optional system messages prepended to the conversation
     * @param tools tool definitions made available to the model
     * @param tool_choice tool choice policy: "none", "auto" or "required"
     * @param timeout read timeout for the HTTP client, in seconds
     * @param max_retries maximum number of retries on request failure
     * @return the assistant message as JSON (content, tool_calls)
     * @throws std::invalid_argument if the messages, tools or tool_choice are invalid
     * @throws std::runtime_error if the API request fails after all retries
     */
    json ask_tool(
        const std::vector<Message>& messages,
        const std::vector<Message>& system_msgs = {},
        const std::vector<json>& tools = {},
        const std::string& tool_choice = "auto",
        int timeout = 60,
        int max_retries = 3
    ) {
        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
            throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
        }

        std::vector<json> formatted_messages;
        if (!system_msgs.empty()) {
            auto system_formatted_messages = format_messages(system_msgs);
            formatted_messages.insert(formatted_messages.end(), system_formatted_messages.begin(), system_formatted_messages.end());
        }
        auto _formatted_messages = format_messages(messages);
        formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());

        if (!tools.empty()) {
            for (const json& tool : tools) {
                if (!tool.contains("type")) {
                    throw std::invalid_argument("Tool must contain 'type' field");
                }
            }
        }

        json body = {
            {"model", llm_config_->model},
            {"messages", formatted_messages},
            {"temperature", llm_config_->temperature},
            {"max_tokens", llm_config_->max_tokens},
            {"tools", tools},
            {"tool_choice", tool_choice}
        };

        client_->set_read_timeout(timeout); // seconds

        std::string body_str = body.dump();

        int retry = 0;
        while (retry <= max_retries) {
            // send request
            auto res = client_->Post(llm_config_->end_point, body_str, "application/json");

            if (!res) {
                logger->error("Failed to send request: " + httplib::to_string(res.error()));
            } else if (res->status == 200) {
                try {
                    json json_data = json::parse(res->body);
                    return json_data["choices"][0]["message"];
                } catch (const std::exception& e) {
                    logger->error("Failed to parse response: " + std::string(e.what()));
                }
            } else {
                logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
            }

            retry++;
            if (retry > max_retries) {
                break;
            }

            // wait for a while before retrying
            std::this_thread::sleep_for(std::chrono::milliseconds(500));
            logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
        }
        throw std::runtime_error("Failed to get response from LLM");
    }
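
    // Example tool definition for ask_tool() (a sketch; the OpenAI-style
    // function-tool schema is assumed, and get_weather is hypothetical):
    //
    //   json weather_tool = {
    //       {"type", "function"},
    //       {"function", {
    //           {"name", "get_weather"},
    //           {"description", "Get the current weather for a city"},
    //           {"parameters", {
    //               {"type", "object"},
    //               {"properties", {{"city", {{"type", "string"}}}}},
    //               {"required", {"city"}}
    //           }}
    //       }}
    //   };
    //   json reply = llm->ask_tool(messages, {}, {weather_tool}, "auto");
    //   if (reply.contains("tool_calls")) { /* dispatch tool calls */ }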
};
} // namespace humanus
#endif // HUMANUS_LLM_H