diff --git a/CMakeLists.txt b/CMakeLists.txt index c89fe40..793bd0a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,12 +78,12 @@ if(Python3_FOUND) endif() # 添加简单版本的可执行文件 -add_executable(humanus_simple main_simple.cpp logger.cpp schema.cpp) -target_link_libraries(humanus_simple PRIVATE Threads::Threads ${OPENSSL_LIBRARIES}) -if(Python3_FOUND) - target_link_libraries(humanus_simple PRIVATE ${Python3_LIBRARIES}) -endif() +# add_executable(humanus_simple main_simple.cpp logger.cpp schema.cpp) +# target_link_libraries(humanus_simple PRIVATE Threads::Threads ${OPENSSL_LIBRARIES}) +# if(Python3_FOUND) +# target_link_libraries(humanus_simple PRIVATE ${Python3_LIBRARIES}) +# endif() # 安装目标 install(TARGETS humanus_cpp DESTINATION bin) -install(TARGETS humanus_simple DESTINATION bin) \ No newline at end of file +# install(TARGETS humanus_simple DESTINATION bin) \ No newline at end of file diff --git a/agent/base.h b/agent/base.h index 509a3c3..5f521b5 100644 --- a/agent/base.h +++ b/agent/base.h @@ -85,7 +85,7 @@ struct BaseAgent : std::enable_shared_from_this { // Execute the agent's main loop asynchronously virtual std::string run(const std::string& request = "") { if (state != AgentState::IDLE) { - throw std::runtime_error("Cannot run agent from state" + agent_state_map[state]); + throw std::runtime_error("Cannot run agent from state " + agent_state_map[state]); } if (!request.empty()) { @@ -97,7 +97,15 @@ struct BaseAgent : std::enable_shared_from_this { while (current_step < max_steps && state == AgentState::RUNNING) { current_step++; logger->info("Executing step " + std::to_string(current_step) + "/" + std::to_string(max_steps)); - std::string step_result = step(); + std::string step_result; + + try { + step_result = step(); + } catch (const std::exception& e) { + logger->error("Error executing step " + std::to_string(current_step) + ": " + std::string(e.what())); + state = AgentState::ERROR; + break; + } if (is_stuck()) { this->handle_stuck_state(); @@ -108,7 +116,11 @@ struct BaseAgent : std::enable_shared_from_this { if (current_step >= max_steps) { results.push_back("Terminated: Reached max steps (" + std::to_string(max_steps) + ")"); } - state = AgentState::IDLE; // RUNNING -> IDLE + if (state == AgentState::ERROR) { + results.push_back("Terminated: Agent state is " + agent_state_map[state]); + } else { + state = AgentState::IDLE; // FINISHED -> IDLE + } std::string result_str = ""; diff --git a/agent/planning.cpp b/agent/planning.cpp index c85569b..1e8a6c2 100644 --- a/agent/planning.cpp +++ b/agent/planning.cpp @@ -116,7 +116,7 @@ void PlanningAgent::update_plan_status(const std::string& tool_call_id) { try { // Mark the step as completed - available_tools.execute( + ToolResult result = available_tools.execute( "planning", { {"command", "mark_step"}, @@ -127,6 +127,7 @@ void PlanningAgent::update_plan_status(const std::string& tool_call_id) { ); logger->info( "Marked step " + std::to_string(step_index) + " as completed in plan " + active_plan_id + + "\n\n" + result.to_string() + "\n\n" ); } catch (const std::exception& e) { logger->warn("Failed to update plan status: " + std::string(e.what())); @@ -195,7 +196,7 @@ int PlanningAgent::_get_current_step_index() { // Create an initial plan based on the request. void PlanningAgent::create_initial_plan(const std::string& request) { - logger->info("Creating initial plan with ID: " + request); + logger->info("Creating initial plan with ID: " + active_plan_id); std::vector messages = { Message::user_message( diff --git a/agent/toolcall.cpp b/agent/toolcall.cpp index e143d0f..27e60b1 100644 --- a/agent/toolcall.cpp +++ b/agent/toolcall.cpp @@ -20,9 +20,9 @@ bool ToolCallAgent::think() { tool_calls = ToolCall::from_json_list(response["tool_calls"]); // Log response info - logger->info("✨ " + name + "'s thoughts:" + response["content"].dump()); + logger->info("✨ " + name + "'s thoughts: " + response["content"].dump()); logger->info( - "🛠️ " + name + " selected " + std::to_string(tool_calls.size()) + " tool(s) to use" + "🛠️ " + name + " selected " + std::to_string(tool_calls.size()) + " tool(s) to use" ); if (tool_calls.size() > 0) { std::string tools_str; @@ -114,7 +114,11 @@ std::string ToolCallAgent::execute_tool(ToolCall tool_call) { std::string name = tool_call.function.name; if (available_tools.tools_map.find(name) == available_tools.tools_map.end()) { - return "Error: Unknown tool '" + name + "'"; + return "Error: Unknown tool '" + name + "'. Please use one of the following tools: " + + std::accumulate(available_tools.tools_map.begin(), available_tools.tools_map.end(), std::string(), + [](const std::string& a, const auto& b) { + return a + (a.empty() ? "" : ", ") + b.first; + }); } try { diff --git a/config/config.toml b/config/config.toml index f110f20..febf9bb 100644 --- a/config/config.toml +++ b/config/config.toml @@ -1,6 +1,6 @@ [llm] -model = "deepseek-chat" -base_url = "https://api.deepseek.com" -end_point = "/v1/chat/completions" -api_key = "sk-93c5bfcb920c4a8aa345791d429b8536" +model = "anthropic/claude-3.7-sonnet" +base_url = "https://openrouter.ai" +end_point = "/api/v1/chat/completions" +api_key = "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" max_tokens = 8192 \ No newline at end of file diff --git a/config/config.toml.bak b/config/config.toml.bak index 08a5e6f..87abe58 100644 --- a/config/config.toml.bak +++ b/config/config.toml.bak @@ -3,4 +3,11 @@ model = "anthropic/claude-3.7-sonnet" base_url = "https://openrouter.ai" end_point = "/api/v1/chat/completions" api_key = "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" -max_tokens = 8196 \ No newline at end of file +max_tokens = 8192 + +[llm] +model = "deepseek-chat" +base_url = "https://api.deepseek.com" +end_point = "/v1/chat/completions" +api_key = "sk-93c5bfcb920c4a8aa345791d429b8536" +max_tokens = 8192 \ No newline at end of file diff --git a/flow/base.h b/flow/base.h index e0192d5..4208a9e 100644 --- a/flow/base.h +++ b/flow/base.h @@ -48,6 +48,8 @@ struct BaseFlow { } } + virtual ~BaseFlow() = default; + // Get the primary agent for the flow std::shared_ptr primary_agent() const { return agents.at(primary_agent_key); diff --git a/flow/planning.cpp b/flow/planning.cpp index 1a18051..4299364 100644 --- a/flow/planning.cpp +++ b/flow/planning.cpp @@ -58,7 +58,7 @@ std::string PlanningFlow::execute(const std::string& input) { result += step_result + "\n"; // Check if agent wants to terminate - if (executor->state == AgentState::FINISHED) { + if (executor->state == AgentState::FINISHED || executor->state == AgentState::ERROR) { break; } } @@ -197,7 +197,6 @@ void PlanningFlow::_get_current_step_info(int& current_step_index, json& step_in } current_step_index = i; - step_info = step; return; } current_step_index = -1; @@ -245,7 +244,7 @@ void PlanningFlow::_mark_step_completed() { try { // Mark the step as completed - planning_tool->execute({ + ToolResult result = planning_tool->execute({ {"command", "mark_step"}, {"plan_id", active_plan_id}, {"step_index", current_step_index}, @@ -253,6 +252,7 @@ void PlanningFlow::_mark_step_completed() { }); logger->info( "Marked step " + std::to_string(current_step_index) + " as completed in plan " + active_plan_id + + "\n\n" + result.to_string() + "\n\n" ); } catch (const std::exception& e) { logger->warn("Failed to update plan status: " + std::string(e.what())); diff --git a/flow/planning.h b/flow/planning.h index 2fbc358..353a668 100644 --- a/flow/planning.h +++ b/flow/planning.h @@ -25,7 +25,7 @@ struct PlanningFlow : public BaseFlow { const std::vector& executor_keys = {}, const std::string& active_plan_id = "", const std::map>& agents = {}, - const std::vector>& tools = {}, + const std::vector>& tools = {}, const std::string& primary_agent_key = "" ) : BaseFlow(agents, tools, primary_agent_key), llm(llm), @@ -33,11 +33,14 @@ struct PlanningFlow : public BaseFlow { executor_keys(executor_keys), active_plan_id(active_plan_id) { if (!llm) { - this->llm = LLM::get_instance(); + this->llm = LLM::get_instance("default"); } if (!planning_tool) { this->planning_tool = std::make_shared(); } + if (active_plan_id.empty()) { + this->active_plan_id = "plan_" + std::to_string(std::chrono::system_clock::now().time_since_epoch().count()); + } if (executor_keys.empty()) { for (const auto& [key, agent] : agents) { this->executor_keys.push_back(key); diff --git a/main.cpp b/main.cpp index 77dcf3a..9b1c544 100644 --- a/main.cpp +++ b/main.cpp @@ -1,6 +1,7 @@ #include "agent/manus.h" #include "logger.h" #include "prompt.h" +#include "flow/flow_factory.h" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include @@ -42,16 +43,52 @@ int main() { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif } - Manus agent = Manus(); + // Manus agent = Manus(); + // while (true) { + // std::string prompt; + // std::cout << "Enter your prompt (or 'exit' to quit): "; + // std::getline(std::cin, prompt); + // if (prompt == "exit") { + // logger->info("Goodbye!"); + // break; + // } + // logger->info("Processing your request..."); + // agent.run(prompt); + // } + + std::shared_ptr agent_ptr = std::make_shared(); + std::map> agents; + agents["default"] = agent_ptr; + + auto flow = FlowFactory::create_flow( + FlowType::PLANNING, + nullptr, // llm + nullptr, // planning_tool + std::vector{}, // executor_keys + "", // active_plan_id + agents, // agents + std::vector>{}, // tools + "default" // primary_agent_key + ); + while (true) { + if (agent_ptr->current_step == agent_ptr->max_steps) { + std::cout << "Program automatically paused after " << agent_ptr->current_step << " steps." << std::endl; + std::cout << "Enter your prompt (like 'continue' to continue or 'exit' to quit): "; + agent_ptr->current_step = 0; + } else { + std::cout << "Enter your prompt (or 'exit' to quit): "; + } + std::string prompt; - std::cout << "Enter your prompt (or 'exit' to quit): "; std::getline(std::cin, prompt); if (prompt == "exit") { logger->info("Goodbye!"); break; } - logger->info("Processing your request..."); - agent.run(prompt); + + std::cout << "Processing your request..." << std::endl; + auto result = flow->execute(prompt); + std::cout << result << std::endl; } } \ No newline at end of file diff --git a/server/python_execute.cpp b/server/python_execute.cpp index e4bb820..216b982 100644 --- a/server/python_execute.cpp +++ b/server/python_execute.cpp @@ -13,6 +13,7 @@ #include #include #include +#include // 检查是否找到Python #ifdef PYTHON_FOUND @@ -24,13 +25,31 @@ * @brief Python解释器类,用于执行Python代码 */ class python_interpreter { +private: + // 互斥锁,确保Python解释器的线程安全 + mutable std::mutex py_mutex; + bool is_initialized; + public: /** * @brief 构造函数,初始化Python解释器 */ - python_interpreter() { + python_interpreter() : is_initialized(false) { #ifdef PYTHON_FOUND - Py_Initialize(); + try { + Py_Initialize(); + if (Py_IsInitialized()) { + is_initialized = true; + // 初始化线程支持 + PyEval_InitThreads(); + // 释放GIL,允许其他线程获取 + PyThreadState *_save = PyEval_SaveThread(); + } else { + std::cerr << "Python解释器初始化失败" << std::endl; + } + } catch (const std::exception& e) { + std::cerr << "Python解释器初始化异常: " << e.what() << std::endl; + } #endif } @@ -39,7 +58,11 @@ public: */ ~python_interpreter() { #ifdef PYTHON_FOUND - Py_Finalize(); + if (is_initialized) { + std::lock_guard lock(py_mutex); + Py_Finalize(); + is_initialized = false; + } #endif } @@ -50,77 +73,176 @@ public: */ mcp::json forward(const mcp::json& input) const { #ifdef PYTHON_FOUND - if (input.contains("code") && input["code"].is_string()) { - std::string code = input["code"].get(); - - // Create scope to manage Python objects automatically - PyObject *main_module = PyImport_AddModule("__main__"); - PyObject *main_dict = PyModule_GetDict(main_module); - PyObject *sys_module = PyImport_ImportModule("sys"); - PyObject *io_module = PyImport_ImportModule("io"); - PyObject *string_io = PyObject_GetAttrString(io_module, "StringIO"); - PyObject *sys_stdout = PyObject_CallObject(string_io, nullptr); - PyObject *sys_stderr = PyObject_CallObject(string_io, nullptr); - - // Replace sys.stdout and sys.stderr with our StringIO objects - PySys_SetObject("stdout", sys_stdout); - PySys_SetObject("stderr", sys_stderr); - - // Execute the Python code - PyObject *result = PyRun_String(code.c_str(), Py_file_input, main_dict, main_dict); - if (!result) { - PyErr_Print(); - } - Py_XDECREF(result); - - // Fetch the output and error from the StringIO object - PyObject *out_value = PyObject_CallMethod(sys_stdout, "getvalue", nullptr); - PyObject *err_value = PyObject_CallMethod(sys_stderr, "getvalue", nullptr); - - // Convert Python string to C++ string - std::string output = PyUnicode_AsUTF8(out_value); - std::string error = PyUnicode_AsUTF8(err_value); - - // Restore the original sys.stdout and sys.stderr - PySys_SetObject("stdout", PySys_GetObject("stdout")); - PySys_SetObject("stderr", PySys_GetObject("stderr")); - - // Clean up - Py_DECREF(sys_stdout); - Py_DECREF(sys_stderr); - Py_DECREF(string_io); - Py_DECREF(io_module); - Py_DECREF(sys_module); - - // Prepare JSON output - mcp::json result_json; - if (!output.empty()) { - result_json["output"] = output; - } - if (!error.empty()) { - result_json["error"] = error; - } - - if (result_json.empty()) { - std::string last_line; - std::istringstream code_stream(code); - while (std::getline(code_stream, last_line, '\n')) {} - size_t pos = last_line.find_last_of(';') + 1; - pos = last_line.find("=") + 1; - while (pos < last_line.size() && isblank(last_line[pos])) { - pos++; - } - if (pos != std::string::npos) { - last_line = last_line.substr(pos); - } - - return mcp::json{{"warning", "No output. Maybe try with print(" + last_line + ")?"}}; - } - - return result_json; - } else { - return mcp::json{{"error", "Invalid parameters or code not provided"}}; + if (!is_initialized) { + return mcp::json{{"error", "Python解释器未正确初始化"}}; } + + // 获取GIL锁 + std::lock_guard lock(py_mutex); + PyGILState_STATE gstate = PyGILState_Ensure(); + + mcp::json result_json; + + try { + if (input.contains("code") && input["code"].is_string()) { + std::string code = input["code"].get(); + + // 获取主模块和字典 + PyObject *main_module = PyImport_AddModule("__main__"); + if (!main_module) { + PyGILState_Release(gstate); + return mcp::json{{"error", "无法获取Python主模块"}}; + } + + PyObject *main_dict = PyModule_GetDict(main_module); + if (!main_dict) { + PyGILState_Release(gstate); + return mcp::json{{"error", "无法获取Python主模块字典"}}; + } + + // 导入sys和io模块 + PyObject *sys_module = PyImport_ImportModule("sys"); + if (!sys_module) { + PyErr_Print(); + PyGILState_Release(gstate); + return mcp::json{{"error", "无法导入sys模块"}}; + } + + PyObject *io_module = PyImport_ImportModule("io"); + if (!io_module) { + Py_DECREF(sys_module); + PyErr_Print(); + PyGILState_Release(gstate); + return mcp::json{{"error", "无法导入io模块"}}; + } + + // 获取StringIO类 + PyObject *string_io = PyObject_GetAttrString(io_module, "StringIO"); + if (!string_io) { + Py_DECREF(io_module); + Py_DECREF(sys_module); + PyErr_Print(); + PyGILState_Release(gstate); + return mcp::json{{"error", "无法获取StringIO类"}}; + } + + // 创建StringIO对象 + PyObject *sys_stdout = PyObject_CallObject(string_io, nullptr); + if (!sys_stdout) { + Py_DECREF(string_io); + Py_DECREF(io_module); + Py_DECREF(sys_module); + PyErr_Print(); + PyGILState_Release(gstate); + return mcp::json{{"error", "无法创建stdout StringIO对象"}}; + } + + PyObject *sys_stderr = PyObject_CallObject(string_io, nullptr); + if (!sys_stderr) { + Py_DECREF(sys_stdout); + Py_DECREF(string_io); + Py_DECREF(io_module); + Py_DECREF(sys_module); + PyErr_Print(); + PyGILState_Release(gstate); + return mcp::json{{"error", "无法创建stderr StringIO对象"}}; + } + + // 保存原始的stdout和stderr + PyObject *old_stdout = PySys_GetObject("stdout"); + PyObject *old_stderr = PySys_GetObject("stderr"); + + if (old_stdout) Py_INCREF(old_stdout); + if (old_stderr) Py_INCREF(old_stderr); + + // 替换sys.stdout和sys.stderr + if (PySys_SetObject("stdout", sys_stdout) != 0 || + PySys_SetObject("stderr", sys_stderr) != 0) { + Py_DECREF(sys_stderr); + Py_DECREF(sys_stdout); + Py_DECREF(string_io); + Py_DECREF(io_module); + Py_DECREF(sys_module); + PyErr_Print(); + PyGILState_Release(gstate); + return mcp::json{{"error", "无法设置stdout/stderr重定向"}}; + } + + // 执行Python代码 + PyObject *result = PyRun_String(code.c_str(), Py_file_input, main_dict, main_dict); + if (!result) { + PyErr_Print(); + } + Py_XDECREF(result); + + // 获取输出和错误 + PyObject *out_value = PyObject_CallMethod(sys_stdout, "getvalue", nullptr); + PyObject *err_value = PyObject_CallMethod(sys_stderr, "getvalue", nullptr); + + std::string output, error; + + // 安全地转换Python字符串到C++字符串 + if (out_value && PyUnicode_Check(out_value)) { + output = PyUnicode_AsUTF8(out_value); + } + + if (err_value && PyUnicode_Check(err_value)) { + error = PyUnicode_AsUTF8(err_value); + } + + // 恢复原始的stdout和stderr + if (old_stdout) { + PySys_SetObject("stdout", old_stdout); + Py_DECREF(old_stdout); + } + + if (old_stderr) { + PySys_SetObject("stderr", old_stderr); + Py_DECREF(old_stderr); + } + + // 清理 + Py_XDECREF(out_value); + Py_XDECREF(err_value); + Py_DECREF(sys_stdout); + Py_DECREF(sys_stderr); + Py_DECREF(string_io); + Py_DECREF(io_module); + Py_DECREF(sys_module); + + // 准备JSON输出 + if (!output.empty()) { + result_json["output"] = output; + } + if (!error.empty()) { + result_json["error"] = error; + } + + if (result_json.empty()) { + std::string last_line; + std::istringstream code_stream(code); + while (std::getline(code_stream, last_line, '\n')) {} + size_t pos = last_line.find_last_of(';') + 1; + pos = last_line.find("=") + 1; + while (pos < last_line.size() && isblank(last_line[pos])) { + pos++; + } + if (pos != std::string::npos) { + last_line = last_line.substr(pos); + } + + result_json["warning"] = "No output. Maybe try with print(" + last_line + ")?"; + } + } else { + result_json["error"] = "Invalid parameters or code not provided"; + } + } catch (const std::exception& e) { + result_json["error"] = std::string("Python执行异常: ") + e.what(); + } + + // 释放GIL + PyGILState_Release(gstate); + return result_json; #else return mcp::json{{"error", "Python interpreter not available"}}; #endif @@ -133,7 +255,6 @@ static python_interpreter interpreter; // Python执行工具处理函数 mcp::json python_execute_handler(const mcp::json& args) { if (!args.contains("code")) { - std::cout << args.dump() << std::endl; throw mcp::mcp_exception(mcp::error_code::invalid_params, "缺少'code'参数"); } @@ -141,25 +262,9 @@ mcp::json python_execute_handler(const mcp::json& args) { // 使用Python解释器执行代码 mcp::json result = interpreter.forward(args); - // 格式化输出结果 - std::string output; - if (result.contains("output")) { - output += result["output"].get(); - } - if (result.contains("error")) { - output += result["error"].get(); - } - if (result.contains("warning")) { - output += result["warning"].get(); - } - - if (output.empty()) { - output = "Executed successfully but w/o output. Maybe you should print more information."; - } - return {{ {"type", "text"}, - {"text", output} + {"text", result.dump(2)} }}; } catch (const std::exception& e) { throw mcp::mcp_exception(mcp::error_code::internal_error, @@ -167,12 +272,11 @@ mcp::json python_execute_handler(const mcp::json& args) { } } -// 注册Python执行工具的函数 +// Register the PythonExecute tool void register_python_execute_tool(mcp::server& server) { - // 注册PythonExecute工具 - mcp::tool python_tool = mcp::tool_builder("PythonExecute") - .with_description("执行Python代码并返回结果") - .with_string_param("code", "要执行的Python代码", true) + mcp::tool python_tool = mcp::tool_builder("python_execute") + .with_description("Execute Python code and return the result") + .with_string_param("code", "The Python code to execute", true) .build(); server.register_tool(python_tool, python_execute_handler); diff --git a/tool/base.h b/tool/base.h index 34a166f..71dca05 100644 --- a/tool/base.h +++ b/tool/base.h @@ -149,8 +149,16 @@ struct ToolResult { }; } + static std::string parse_json_content(const json& content) { + if (content.is_string()) { + return content.get(); + } else { + return content.dump(2); + } + } + std::string to_string() const { - return !error.empty() ? "Error: " + error.dump() : output.dump(); + return !error.empty() ? "Error: " + parse_json_content(error) : parse_json_content(output); } }; @@ -161,7 +169,7 @@ struct ToolError : ToolResult { // Execute the tool with given parameters. struct BaseTool { - inline static std::set special_tool_name = {"terminate"}; + inline static std::set special_tool_name = {"terminate", "planning"}; std::string name; std::string description; diff --git a/tool/filesystem.h b/tool/filesystem.h index 77ef840..51fa49b 100644 --- a/tool/filesystem.h +++ b/tool/filesystem.h @@ -62,6 +62,19 @@ struct FileSystem : BaseTool { "required": ["tool"] })json"); + inline static std::set allowed_tools = { + "read_file", + "read_multiple_files", + "write_file", + "edit_file", + "create_directory", + "list_directory", + "move_file", + "search_files", + "get_file_info", + "list_allowed_directories" + }; + FileSystem() : BaseTool(name_, description_, parameters_) {} ToolResult execute(const json& args) override { @@ -82,6 +95,14 @@ struct FileSystem : BaseTool { return ToolError("Tool is required"); } + if (allowed_tools.find(tool) == allowed_tools.end()) { + return ToolError("Unknown tool '" + tool + "'. Please use one of the following tools: " + + std::accumulate(allowed_tools.begin(), allowed_tools.end(), std::string(), + [](const std::string& a, const std::string& b) { + return a + (a.empty() ? "" : ", ") + b; + })); + } + json result = _client->call_tool(tool, args); bool is_error = result.value("isError", false); diff --git a/tool/planning.cpp b/tool/planning.cpp index 661e44d..b063def 100644 --- a/tool/planning.cpp +++ b/tool/planning.cpp @@ -254,7 +254,7 @@ ToolResult PlanningTool::_delete_plan(const std::string& plan_id) { // Format a plan for display. std::string PlanningTool::_format_plan(const json& plan) { std::stringstream output_ss; - output_ss << "Plan ID: " << plan["plan_id"].get() << "\n"; + output_ss << "Plan: " << plan.value("title", "Unknown Plan") << " (ID: " << plan["plan_id"].get() << ")\n"; int current_length = output_ss.str().length(); for (int i = 0; i < current_length; i++) { diff --git a/tool/puppeteer.h b/tool/puppeteer.h index 1038bbf..9cdd26c 100644 --- a/tool/puppeteer.h +++ b/tool/puppeteer.h @@ -58,6 +58,16 @@ struct Puppeteer : BaseTool { "required": ["tool"] })json"); + inline static std::set allowed_tools = { + "navigate", + "screenshot", + "click", + "hover", + "fill", + "select", + "evaluate" + }; + Puppeteer() : BaseTool(name_, description_, parameters_) {} ToolResult execute(const json& args) override { @@ -78,6 +88,14 @@ struct Puppeteer : BaseTool { return ToolError("Tool is required"); } + if (allowed_tools.find(tool) == allowed_tools.end()) { + return ToolError("Unknown tool '" + tool + "'. Please use one of the following tools: " + + std::accumulate(allowed_tools.begin(), allowed_tools.end(), std::string(), + [](const std::string& a, const std::string& b) { + return a + (a.empty() ? "" : ", ") + b; + })); + } + json result = _client->call_tool("puppeteer_" + tool, args); bool is_error = result.value("isError", false);