openmanus.cpp/include/tool_call_agent.h

179 lines
5.1 KiB
C
Raw Permalink Normal View History

2025-03-10 02:38:39 +08:00
#ifndef OPENMANUS_TOOL_CALL_AGENT_H
#define OPENMANUS_TOOL_CALL_AGENT_H
#include <string>
#include <vector>
#include <memory>
#include "agent_base.h"
#include "tool_call.h"
#include "config.h"
#include "mcp/common/httplib.h"
namespace openmanus {
// 消息类型枚举
enum class MessageRole {
SYSTEM,
USER,
ASSISTANT,
TOOL
};
// 消息结构体
struct Message {
MessageRole role;
std::string content;
std::string name; // 用于工具消息的工具名称
std::string tool_call_id; // 用于工具消息的工具调用ID
// 创建系统消息
static Message systemMessage(const std::string& content) {
return {MessageRole::SYSTEM, content, "", ""};
}
// 创建用户消息
static Message userMessage(const std::string& content) {
return {MessageRole::USER, content, "", ""};
}
// 创建助手消息
static Message assistantMessage(const std::string& content) {
return {MessageRole::ASSISTANT, content, "", ""};
}
// 创建工具消息
static Message toolMessage(const std::string& content, const std::string& name, const std::string& tool_call_id) {
return {MessageRole::TOOL, content, name, tool_call_id};
}
};
/**
* @class ToolCallAgent
* @brief
*/
class ToolCallAgent : public AgentBase {
public:
ToolCallAgent(const std::string& name, const std::string& description, const std::string& config_file = "config.toml");
virtual ~ToolCallAgent() = default;
/**
* @brief
* @param prompt
* @return
*/
virtual std::string run(const std::string& prompt) override;
/**
* @brief LLM API
* @param host LLM API
* @param port LLM API
* @param endpoint LLM API
*/
void setLLMAPIParams(const std::string& host, int port, const std::string& endpoint);
/**
* @brief LLM API
* @param provider "deepseek"
* @param base_url URL
* @param endpoint API
*/
void setCloudLLMAPIParams(const std::string& provider, const std::string& base_url, const std::string& endpoint);
/**
* @brief LLM API
* @param api_key LLM API
*/
void setLLMAPIKey(const std::string& api_key);
/**
* @brief
*/
void clearMessages();
/**
* @brief
* @param message
*/
void addMessage(const Message& message);
/**
* @brief
* @return
*/
const std::vector<Message>& getMessages() const { return messages_; }
protected:
/**
* @brief
* @return
*/
virtual bool think();
/**
* @brief
* @return
*/
virtual std::string act();
/**
* @brief
* @param tool_call
* @return
*/
virtual std::string executeToolCall(const ToolCall& tool_call);
/**
* @brief
* @param name
* @param result
* @return
*/
virtual bool handleSpecialTool(const std::string& name, const std::string& result);
/**
* @brief
* @return
*/
virtual bool shouldFinishExecution();
/**
* @brief
* @param name
* @return
*/
virtual bool isSpecialTool(const std::string& name);
/**
* @brief LLM API
* @param messages
* @return JSON
*/
virtual mcp::json callLLMAPI(const std::vector<Message>& messages);
protected:
std::vector<Message> messages_; // 消息历史
std::vector<ToolCall> tool_calls_; // 工具调用列表
int max_steps_; // 最大步骤数
int current_step_; // 当前步骤
bool should_clear_history_; // 是否应该清空历史
int max_consecutive_tool_calls_; // 最大连续工具调用次数
int consecutive_tool_calls_; // 当前连续工具调用次数
// 配置
Config config_; // 配置
// LLM API相关
std::string llm_api_host_; // LLM API主机
int llm_api_port_; // LLM API端口
std::string llm_api_base_url_; // LLM API基础URL
std::string llm_api_endpoint_; // LLM API端点
std::string llm_api_key_; // LLM API密钥
std::string llm_model_; // LLM模型名称
int llm_max_tokens_; // LLM最大生成token数
std::string llm_provider_; // LLM提供商
std::unique_ptr<httplib::Client> http_client_; // HTTP客户端
};
} // namespace openmanus
#endif // OPENMANUS_TOOL_CALL_AGENT_H