openmanus.cpp/include/tool_call_agent.h

179 lines
5.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#ifndef OPENMANUS_TOOL_CALL_AGENT_H
#define OPENMANUS_TOOL_CALL_AGENT_H
#include <string>
#include <vector>
#include <memory>
#include "agent_base.h"
#include "tool_call.h"
#include "config.h"
#include "mcp/common/httplib.h"
namespace openmanus {
// 消息类型枚举
enum class MessageRole {
SYSTEM,
USER,
ASSISTANT,
TOOL
};
// 消息结构体
struct Message {
MessageRole role;
std::string content;
std::string name; // 用于工具消息的工具名称
std::string tool_call_id; // 用于工具消息的工具调用ID
// 创建系统消息
static Message systemMessage(const std::string& content) {
return {MessageRole::SYSTEM, content, "", ""};
}
// 创建用户消息
static Message userMessage(const std::string& content) {
return {MessageRole::USER, content, "", ""};
}
// 创建助手消息
static Message assistantMessage(const std::string& content) {
return {MessageRole::ASSISTANT, content, "", ""};
}
// 创建工具消息
static Message toolMessage(const std::string& content, const std::string& name, const std::string& tool_call_id) {
return {MessageRole::TOOL, content, name, tool_call_id};
}
};
/**
* @class ToolCallAgent
* @brief 工具调用代理,能够执行工具调用的代理
*/
class ToolCallAgent : public AgentBase {
public:
ToolCallAgent(const std::string& name, const std::string& description, const std::string& config_file = "config.toml");
virtual ~ToolCallAgent() = default;
/**
* @brief 运行代理,处理用户输入的提示
* @param prompt 用户输入的提示
* @return 处理结果
*/
virtual std::string run(const std::string& prompt) override;
/**
* @brief 设置LLM API参数
* @param host LLM API主机
* @param port LLM API端口
* @param endpoint LLM API端点
*/
void setLLMAPIParams(const std::string& host, int port, const std::string& endpoint);
/**
* @brief 设置云服务提供商的LLM API参数
* @param provider 提供商名称(如"deepseek"
* @param base_url 基础URL
* @param endpoint API端点
*/
void setCloudLLMAPIParams(const std::string& provider, const std::string& base_url, const std::string& endpoint);
/**
* @brief 设置LLM API密钥
* @param api_key LLM API密钥
*/
void setLLMAPIKey(const std::string& api_key);
/**
* @brief 清空消息历史
*/
void clearMessages();
/**
* @brief 添加消息到历史
* @param message 消息
*/
void addMessage(const Message& message);
/**
* @brief 获取消息历史
* @return 消息历史
*/
const std::vector<Message>& getMessages() const { return messages_; }
protected:
/**
* @brief 思考下一步行动
* @return 是否继续执行
*/
virtual bool think();
/**
* @brief 执行行动
* @return 执行结果
*/
virtual std::string act();
/**
* @brief 执行工具
* @param tool_call 工具调用
* @return 执行结果
*/
virtual std::string executeToolCall(const ToolCall& tool_call);
/**
* @brief 处理特殊工具
* @param name 工具名称
* @param result 执行结果
* @return 是否是特殊工具
*/
virtual bool handleSpecialTool(const std::string& name, const std::string& result);
/**
* @brief 判断是否应该结束执行
* @return 是否结束执行
*/
virtual bool shouldFinishExecution();
/**
* @brief 判断是否是特殊工具
* @param name 工具名称
* @return 是否是特殊工具
*/
virtual bool isSpecialTool(const std::string& name);
/**
* @brief 调用LLM API获取下一步行动
* @param messages 消息历史
* @return 响应JSON
*/
virtual mcp::json callLLMAPI(const std::vector<Message>& messages);
protected:
std::vector<Message> messages_; // 消息历史
std::vector<ToolCall> tool_calls_; // 工具调用列表
int max_steps_; // 最大步骤数
int current_step_; // 当前步骤
bool should_clear_history_; // 是否应该清空历史
int max_consecutive_tool_calls_; // 最大连续工具调用次数
int consecutive_tool_calls_; // 当前连续工具调用次数
// 配置
Config config_; // 配置
// LLM API相关
std::string llm_api_host_; // LLM API主机
int llm_api_port_; // LLM API端口
std::string llm_api_base_url_; // LLM API基础URL
std::string llm_api_endpoint_; // LLM API端点
std::string llm_api_key_; // LLM API密钥
std::string llm_model_; // LLM模型名称
int llm_max_tokens_; // LLM最大生成token数
std::string llm_provider_; // LLM提供商
std::unique_ptr<httplib::Client> http_client_; // HTTP客户端
};
} // namespace openmanus
#endif // OPENMANUS_TOOL_CALL_AGENT_H