179 lines
5.1 KiB
C++
179 lines
5.1 KiB
C++
#ifndef OPENMANUS_TOOL_CALL_AGENT_H
|
||
#define OPENMANUS_TOOL_CALL_AGENT_H
|
||
|
||
#include <string>
|
||
#include <vector>
|
||
#include <memory>
|
||
#include "agent_base.h"
|
||
#include "tool_call.h"
|
||
#include "config.h"
|
||
#include "mcp/common/httplib.h"
|
||
|
||
namespace openmanus {
|
||
|
||
// 消息类型枚举
|
||
enum class MessageRole {
|
||
SYSTEM,
|
||
USER,
|
||
ASSISTANT,
|
||
TOOL
|
||
};
|
||
|
||
// 消息结构体
|
||
struct Message {
|
||
MessageRole role;
|
||
std::string content;
|
||
std::string name; // 用于工具消息的工具名称
|
||
std::string tool_call_id; // 用于工具消息的工具调用ID
|
||
|
||
// 创建系统消息
|
||
static Message systemMessage(const std::string& content) {
|
||
return {MessageRole::SYSTEM, content, "", ""};
|
||
}
|
||
|
||
// 创建用户消息
|
||
static Message userMessage(const std::string& content) {
|
||
return {MessageRole::USER, content, "", ""};
|
||
}
|
||
|
||
// 创建助手消息
|
||
static Message assistantMessage(const std::string& content) {
|
||
return {MessageRole::ASSISTANT, content, "", ""};
|
||
}
|
||
|
||
// 创建工具消息
|
||
static Message toolMessage(const std::string& content, const std::string& name, const std::string& tool_call_id) {
|
||
return {MessageRole::TOOL, content, name, tool_call_id};
|
||
}
|
||
};
|
||
|
||
/**
|
||
* @class ToolCallAgent
|
||
* @brief 工具调用代理,能够执行工具调用的代理
|
||
*/
|
||
class ToolCallAgent : public AgentBase {
|
||
public:
|
||
ToolCallAgent(const std::string& name, const std::string& description, const std::string& config_file = "config.toml");
|
||
virtual ~ToolCallAgent() = default;
|
||
|
||
/**
|
||
* @brief 运行代理,处理用户输入的提示
|
||
* @param prompt 用户输入的提示
|
||
* @return 处理结果
|
||
*/
|
||
virtual std::string run(const std::string& prompt) override;
|
||
|
||
/**
|
||
* @brief 设置LLM API参数
|
||
* @param host LLM API主机
|
||
* @param port LLM API端口
|
||
* @param endpoint LLM API端点
|
||
*/
|
||
void setLLMAPIParams(const std::string& host, int port, const std::string& endpoint);
|
||
|
||
/**
|
||
* @brief 设置云服务提供商的LLM API参数
|
||
* @param provider 提供商名称(如"deepseek")
|
||
* @param base_url 基础URL
|
||
* @param endpoint API端点
|
||
*/
|
||
void setCloudLLMAPIParams(const std::string& provider, const std::string& base_url, const std::string& endpoint);
|
||
|
||
/**
|
||
* @brief 设置LLM API密钥
|
||
* @param api_key LLM API密钥
|
||
*/
|
||
void setLLMAPIKey(const std::string& api_key);
|
||
|
||
/**
|
||
* @brief 清空消息历史
|
||
*/
|
||
void clearMessages();
|
||
|
||
/**
|
||
* @brief 添加消息到历史
|
||
* @param message 消息
|
||
*/
|
||
void addMessage(const Message& message);
|
||
|
||
/**
|
||
* @brief 获取消息历史
|
||
* @return 消息历史
|
||
*/
|
||
const std::vector<Message>& getMessages() const { return messages_; }
|
||
|
||
protected:
|
||
/**
|
||
* @brief 思考下一步行动
|
||
* @return 是否继续执行
|
||
*/
|
||
virtual bool think();
|
||
|
||
/**
|
||
* @brief 执行行动
|
||
* @return 执行结果
|
||
*/
|
||
virtual std::string act();
|
||
|
||
/**
|
||
* @brief 执行工具
|
||
* @param tool_call 工具调用
|
||
* @return 执行结果
|
||
*/
|
||
virtual std::string executeToolCall(const ToolCall& tool_call);
|
||
|
||
/**
|
||
* @brief 处理特殊工具
|
||
* @param name 工具名称
|
||
* @param result 执行结果
|
||
* @return 是否是特殊工具
|
||
*/
|
||
virtual bool handleSpecialTool(const std::string& name, const std::string& result);
|
||
|
||
/**
|
||
* @brief 判断是否应该结束执行
|
||
* @return 是否结束执行
|
||
*/
|
||
virtual bool shouldFinishExecution();
|
||
|
||
/**
|
||
* @brief 判断是否是特殊工具
|
||
* @param name 工具名称
|
||
* @return 是否是特殊工具
|
||
*/
|
||
virtual bool isSpecialTool(const std::string& name);
|
||
|
||
/**
|
||
* @brief 调用LLM API获取下一步行动
|
||
* @param messages 消息历史
|
||
* @return 响应JSON
|
||
*/
|
||
virtual mcp::json callLLMAPI(const std::vector<Message>& messages);
|
||
|
||
protected:
|
||
std::vector<Message> messages_; // 消息历史
|
||
std::vector<ToolCall> tool_calls_; // 工具调用列表
|
||
int max_steps_; // 最大步骤数
|
||
int current_step_; // 当前步骤
|
||
bool should_clear_history_; // 是否应该清空历史
|
||
int max_consecutive_tool_calls_; // 最大连续工具调用次数
|
||
int consecutive_tool_calls_; // 当前连续工具调用次数
|
||
|
||
// 配置
|
||
Config config_; // 配置
|
||
|
||
// LLM API相关
|
||
std::string llm_api_host_; // LLM API主机
|
||
int llm_api_port_; // LLM API端口
|
||
std::string llm_api_base_url_; // LLM API基础URL
|
||
std::string llm_api_endpoint_; // LLM API端点
|
||
std::string llm_api_key_; // LLM API密钥
|
||
std::string llm_model_; // LLM模型名称
|
||
int llm_max_tokens_; // LLM最大生成token数
|
||
std::string llm_provider_; // LLM提供商
|
||
std::unique_ptr<httplib::Client> http_client_; // HTTP客户端
|
||
};
|
||
|
||
} // namespace openmanus
|
||
|
||
#endif // OPENMANUS_TOOL_CALL_AGENT_H
|