// humanus.cpp/llm.h
//
// 343 lines
// 14 KiB
// C++

#ifndef HUMANUS_LLM_H
#define HUMANUS_LLM_H
#include "config.h"
#include "logger.h"
#include "schema.h"
#include "mcp/common/httplib.h"
#include <chrono>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
namespace humanus {
class LLM {
private:
static std::map<std::string, std::shared_ptr<LLM>> _instances;
std::unique_ptr<httplib::Client> client_;
std::shared_ptr<LLMSettings> llm_config_;
std::shared_ptr<ToolHelper> tool_helper_;
public:
// Constructor
LLM(const std::string& config_name, const std::shared_ptr<LLMSettings>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
if (!llm_config_) {
if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
throw std::invalid_argument("LLM config not found: " + config_name);
}
llm_config_ = std::make_shared<LLMSettings>(Config::get_instance().llm().at(config_name));
}
if (!llm_config_->oai_tool_support && !tool_helper_) {
if (Config::get_instance().tool_helper().find(config_name) == Config::get_instance().tool_helper().end()) {
throw std::invalid_argument("Tool helper config not found: " + config_name);
}
tool_helper_ = std::make_shared<ToolHelper>(Config::get_instance().tool_helper().at(config_name));
}
client_ = std::make_unique<httplib::Client>(llm_config_->base_url);
client_->set_default_headers({
{"Authorization", "Bearer " + llm_config_->api_key}
});
client_->set_read_timeout(llm_config_->timeout);
}
// Get the singleton instance
static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMSettings>& llm_config = nullptr) {
if (_instances.find(config_name) == _instances.end()) {
_instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
}
return _instances[config_name];
}
/**
* @brief Format the message list to the format that LLM can accept
* @param messages Message object message list
* @return The formatted message list
* @throws std::invalid_argument If the message format is invalid or missing necessary fields
* @throws std::runtime_error If the message type is not supported
*/
json format_messages(const std::vector<Message>& messages) {
json formatted_messages = json::array();
auto concat_content = [](const json& lhs, const json& rhs) -> json {
if (lhs.is_string() && rhs.is_string()) {
return lhs.get<std::string>() + "\n" + rhs.get<std::string>(); // Maybe other delimiter?
}
json res = json::array();
if (lhs.is_string()) {
res.push_back({
{"type", "text"},
{"text", lhs.get<std::string>()}
});
} else if (lhs.is_array()) {
res.insert(res.end(), lhs.begin(), lhs.end());
}
if (rhs.is_string()) {
res.push_back({
{"type", "text"},
{"text", rhs.get<std::string>()}
});
} else if (rhs.is_array()) {
res.insert(res.end(), rhs.begin(), rhs.end());
}
return res;
};
for (const auto& message : messages) {
if (message.content.empty() && message.tool_calls.empty()) {
continue;
}
formatted_messages.push_back(message.to_json());
if (!llm_config_->oai_tool_support) {
if (formatted_messages.back()["role"] == "tool") {
std::string tool_results_str = formatted_messages.back().dump(2);
formatted_messages.back() = {
{"role", "user"},
{"content", tool_results_str}
};
} else if (!formatted_messages.back()["tool_calls"].empty()) {
if (formatted_messages.back()["content"].is_null()) {
formatted_messages.back()["content"] = "";
}
std::string tool_calls_str = tool_helper_->dump(formatted_messages.back()["tool_calls"]);
formatted_messages.back().erase("tool_calls");
formatted_messages.back()["content"] = concat_content(formatted_messages.back()["content"], tool_calls_str);
}
}
}
for (const auto& message : formatted_messages) {
if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
}
}
size_t i = 0, j = -1;
for (; i < formatted_messages.size(); i++) {
if (i == 0 || formatted_messages[i]["role"] != formatted_messages[j]["role"]) {
formatted_messages[++j] = formatted_messages[i];
} else {
formatted_messages[j]["content"] = concat_content(formatted_messages[j]["content"], formatted_messages[i]["content"]);
}
}
formatted_messages.erase(formatted_messages.begin() + j + 1, formatted_messages.end());
return formatted_messages;
}
/**
* @brief Send a request to the LLM and get the reply
* @param messages The conversation message list
* @param system_prompt Optional system message
* @param max_retries The maximum number of retries
* @return The generated assistant content
* @throws std::invalid_argument If the message is invalid or the reply is empty
* @throws std::runtime_error If the API call fails
*/
std::string ask(
const std::vector<Message>& messages,
const std::string& system_prompt = "",
int max_retries = 3
) {
json formatted_messages = json::array();
if (!system_prompt.empty()) {
formatted_messages.push_back({
{"role", "system"},
{"content", system_prompt}
});
}
json _formatted_messages = format_messages(messages);
formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());
json body = {
{"model", llm_config_->model},
{"messages", formatted_messages},
{"temperature", llm_config_->temperature},
{"max_tokens", llm_config_->max_tokens}
};
std::string body_str = body.dump();
int retry = 0;
while (retry <= max_retries) {
// send request
auto res = client_->Post(llm_config_->end_point, body_str, "application/json");
if (!res) {
logger->error("Failed to send request: " + httplib::to_string(res.error()));
} else if (res->status == 200) {
try {
json json_data = json::parse(res->body);
return json_data["choices"][0]["message"]["content"].get<std::string>();
} catch (const std::exception& e) {
logger->error("Failed to parse response: " + std::string(e.what()));
}
} else {
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
}
retry++;
if (retry > max_retries) {
break;
}
// wait for a while before retrying
std::this_thread::sleep_for(std::chrono::milliseconds(500));
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
}
throw std::runtime_error("Failed to get response from LLM");
}
/**
* @brief Send a request to the LLM with tool functions
* @param messages The conversation message list
* @param system_prompt Optional system message
* @param next_step_prompt Optinonal prompt message for the next step
* @param timeout The request timeout (seconds)
* @param tools The tool list
* @param tool_choice The tool choice strategy
* @param max_retries The maximum number of retries
* @return The generated assistant message (content, tool_calls)
* @throws std::invalid_argument If the tool, tool choice or message is invalid
* @throws std::runtime_error If the API call fails
*/
json ask_tool(
const std::vector<Message>& messages,
const std::string& system_prompt = "",
const std::string& next_step_prompt = "",
const json& tools = {},
const std::string& tool_choice = "auto",
int max_retries = 3
) {
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
}
json formatted_messages = json::array();
if (!system_prompt.empty()) {
formatted_messages.push_back({
{"role", "system"},
{"content", system_prompt}
});
}
json _formatted_messages = format_messages(messages);
formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());
if (!next_step_prompt.empty()) {
if (formatted_messages.empty() || formatted_messages.back()["role"] != "user") {
formatted_messages.push_back({
{"role", "user"},
{"content", next_step_prompt}
});
} else {
if (formatted_messages.back()["content"].is_string()) {
formatted_messages.back()["content"] = formatted_messages.back()["content"].get<std::string>() + "\n\n" + next_step_prompt;
} else if (formatted_messages.back()["content"].is_array()) {
formatted_messages.back()["content"].push_back({
{"type", "text"},
{"text", next_step_prompt}
});
}
}
}
if (!tools.empty()) {
for (const json& tool : tools) {
if (!tool.contains("type")) {
throw std::invalid_argument("Tool must contain 'type' field but got: " + tool.dump(2));
}
}
if (tool_choice == "required" && tools.empty()) {
throw std::invalid_argument("No tool available for required tool choice");
}
if (!tools.is_array()) {
throw std::invalid_argument("Tools must be an array");
}
}
json body = {
{"model", llm_config_->model},
{"messages", formatted_messages},
{"temperature", llm_config_->temperature},
{"max_tokens", llm_config_->max_tokens},
{"tool_choice", tool_choice}
};
if (llm_config_->oai_tool_support) {
body["tools"] = tools;
} else {
if (body["messages"].empty() || body["messages"].back()["role"] != "user") {
body["messages"].push_back({
{"role", "user"},
{"content", tool_helper_->hint(tools.dump(2))}
});
} else if (body["messages"].back()["content"].is_string()) {
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + tool_helper_->hint(tools.dump(2));
} else if (body["messages"].back()["content"].is_array()) {
body["messages"].back()["content"].push_back({
{"type", "text"},
{"text", tool_helper_->hint(tools.dump(2))}
});
}
}
std::string body_str = body.dump();
int retry = 0;
while (retry <= max_retries) {
// send request
auto res = client_->Post(llm_config_->end_point, body_str, "application/json");
if (!res) {
logger->error("Failed to send request: " + httplib::to_string(res.error()));
} else if (res->status == 200) {
try {
json json_data = json::parse(res->body);
json message = json_data["choices"][0]["message"];
if (!llm_config_->oai_tool_support && message["content"].is_string()) {
message = tool_helper_->parse(message["content"].get<std::string>());
}
return message;
} catch (const std::exception& e) {
logger->error("Failed to parse response: " + std::string(e.what()));
}
} else {
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
}
retry++;
if (retry > max_retries) {
break;
}
// wait for a while before retrying
std::this_thread::sleep_for(std::chrono::milliseconds(500));
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
}
throw std::runtime_error("Failed to get response from LLM");
}
};
} // namespace humanus
#endif // HUMANUS_LLM_H