From d64fc0359a92687230e569d7bb40a85d344e7a58 Mon Sep 17 00:00:00 2001
From: hkr04
Date: Thu, 20 Mar 2025 16:03:26 +0800
Subject: [PATCH] started to refactor memory

---
 CMakeLists.txt                                |   5 +
 agent/base.h                                  |   8 +-
 agent/humanus.h                               |   2 +-
 agent/planning.h                              |   2 +-
 agent/react.h                                 |   2 +-
 agent/swe.h                                   |   2 +-
 agent/toolcall.h                              |   2 +-
 config.cpp                                    | 123 +++++++++---------
 config.h                                      |  14 +-
 config/{config.toml => config_llm.toml}       |   6 +-
 .../{config.toml.bak => config_llm.toml.bak}  |  10 +-
 llm.h                                         |   8 +-
 memory/base.h                                 |  46 +++++++
 memory/mem0.h                                 |  34 +++++
 memory/simple.h                               |  24 ++++
 schema.h                                      |  43 ------
 16 files changed, 203 insertions(+), 128 deletions(-)
 rename config/{config.toml => config_llm.toml} (65%)
 rename config/{config.toml.bak => config_llm.toml.bak} (72%)
 create mode 100644 memory/base.h
 create mode 100644 memory/mem0.h
 create mode 100644 memory/simple.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6b41c88..e20fc04 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -81,6 +81,11 @@ file(GLOB FLOW_SOURCES
     "flow/*.cc"
 )
 
+file(GLOB MEMORY_SOURCES
+    "memory/*.cpp"
+    "memory/*.cc"
+)
+
 add_executable(humanus_cli
     main.cpp
     config.cpp
diff --git a/agent/base.h b/agent/base.h
index 814db7a..cdd7cb4 100644
--- a/agent/base.h
+++ b/agent/base.h
@@ -4,6 +4,8 @@
 #include "llm.h"
 #include "schema.h"
 #include "logger.h"
+#include "memory/base.h"
+#include "memory/simple.h"
 #include <memory>
 #include <string>
 #include <vector>
@@ -28,7 +30,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
 
     // Dependencies
     std::shared_ptr<LLM> llm;           // Language model instance
-    std::shared_ptr<Memory> memory;     // Agent's memory store
+    std::shared_ptr<MemoryBase> memory; // Agent's memory store
     AgentState state;                   // Current state of the agent
 
     // Execution control
@@ -45,7 +47,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
         const std::string& system_prompt,
         const std::string& next_step_prompt,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 10,
         int current_step = 0,
@@ -69,7 +71,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
             llm = LLM::get_instance("default");
         }
         if (!memory) {
-            memory = std::make_shared<Memory>(max_steps);
+            memory = std::make_shared<MemorySimple>(max_steps);
         }
     }
 
diff --git a/agent/humanus.h b/agent/humanus.h
index 970257e..f94fa50 100644
--- a/agent/humanus.h
+++ b/agent/humanus.h
@@ -38,7 +38,7 @@ struct Humanus : ToolCallAgent {
         const std::string& system_prompt = prompt::humanus::SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 30,
         int current_step = 0,
diff --git a/agent/planning.h b/agent/planning.h
index 32479e3..2546b32 100644
--- a/agent/planning.h
+++ b/agent/planning.h
@@ -34,7 +34,7 @@ struct PlanningAgent : ToolCallAgent {
         const std::string& system_prompt = prompt::planning::PLANNING_SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::planning::NEXT_STEP_PROMPT,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 20,
         int current_step = 0,
diff --git a/agent/react.h b/agent/react.h
index 06960f7..a160895 100644
--- a/agent/react.h
+++ b/agent/react.h
@@ -12,7 +12,7 @@ struct ReActAgent : BaseAgent {
         const std::string& system_prompt,
         const std::string& next_step_prompt,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 10,
         int current_step = 0,
diff --git a/agent/swe.h b/agent/swe.h
index a342fc1..3e908f8 100644
--- a/agent/swe.h
+++ b/agent/swe.h
@@ -30,7 +30,7 @@ struct SweAgent : ToolCallAgent {
         const std::string& system_prompt = prompt::swe::SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::swe::NEXT_STEP_TEMPLATE,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 100,
         int current_step = 0,
diff --git a/agent/toolcall.h b/agent/toolcall.h
index 894d276..a9f6280 100644
--- a/agent/toolcall.h
+++ b/agent/toolcall.h
@@ -30,7 +30,7 @@ struct ToolCallAgent : ReActAgent {
         const std::string& system_prompt = prompt::toolcall::SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::toolcall::NEXT_STEP_PROMPT,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 30,
         int current_step = 0,
diff --git a/config.cpp b/config.cpp
index cb21aef..95638a8 100644
--- a/config.cpp
+++ b/config.cpp
@@ -18,78 +18,79 @@ void Config::_load_initial_config() {
         const auto& data = toml::parse_file(config_path.string());
 
         // Load LLM configuration
+        for (const auto& [key, value] : data) {
+            const auto& llm_table = *value.as_table();
 
-        // Check if llm configuration exists
-        if (!data.contains("llm") || !data["llm"].is_table()) {
-            throw std::runtime_error("Config file does not contain `llm` table");
-        }
-
-        const auto& llm_table = *data["llm"].as_table();
+            LLMConfig llm_config;
 
-        LLMSettings llm_settings;
-
-        if (llm_table.contains("model") && llm_table["model"].is_string()) {
-            llm_settings.model = llm_table["model"].as_string()->get();
-        } else {
-            throw std::runtime_error("Invalid `model` configuration");
-        }
-
-        if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
-            llm_settings.api_key = llm_table["api_key"].as_string()->get();
-        } else {
-            throw std::runtime_error("Invalid `api_key` configuration");
-        }
-
-        if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
-            llm_settings.base_url = llm_table["base_url"].as_string()->get();
-        } else {
-            throw std::runtime_error("Invalid `base_url` configuration");
-        }
-
-        if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) {
-            llm_settings.end_point = llm_table["end_point"].as_string()->get();
-        }
-
-        if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
-            llm_settings.max_tokens = llm_table["max_tokens"].as_integer()->get();
-        }
-
-        if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
-            llm_settings.timeout = llm_table["timeout"].as_integer()->get();
-        }
-
-        if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
-            llm_settings.temperature = llm_table["temperature"].as_floating_point()->get();
-        }
-
-        if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
-            llm_settings.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
-        }
-
-        _config.llm["default"] = llm_settings;
-
-        // Load tool helper configurations
-        ToolHelper tool_helper;
-        if (data.contains("tool_helper") && data["tool_helper"].is_table()) {
-            const auto& tool_helper_table = *data["tool_helper"].as_table();
-
-            if (tool_helper_table.contains("tool_start")) {
-                tool_helper.tool_start = tool_helper_table["tool_start"].as_string()->get();
+            if (llm_table.contains("model") && llm_table["model"].is_string()) {
+                llm_config.model = llm_table["model"].as_string()->get();
             }
 
-            if (tool_helper_table.contains("tool_end")) {
-                tool_helper.tool_end = tool_helper_table["tool_end"].as_string()->get();
+            if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
+                llm_config.api_key = llm_table["api_key"].as_string()->get();
             }
 
-            if (tool_helper_table.contains("tool_hint_template")) {
-                tool_helper.tool_hint_template = tool_helper_table["tool_hint_template"].as_string()->get();
+            if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
+                llm_config.base_url = llm_table["base_url"].as_string()->get();
+            }
+
+            if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) {
+                llm_config.end_point = llm_table["end_point"].as_string()->get();
+            }
+
+            if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
+                llm_config.max_tokens = llm_table["max_tokens"].as_integer()->get();
+            }
+
+            if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
+                llm_config.timeout = llm_table["timeout"].as_integer()->get();
+            }
+
+            if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
+                llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
+            }
+
+            if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
+                llm_config.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
+            }
+
+            _config.llm[std::string(key.str())] = llm_config;
+
+            if (!llm_config.oai_tool_support) {
+                // Load tool helper configuration
+                ToolHelper tool_helper;
+                if (llm_table.contains("tool_helper") && llm_table["tool_helper"].is_table()) {
+                    const auto& tool_helper_table = *llm_table["tool_helper"].as_table();
+                    if (tool_helper_table.contains("tool_start")) {
+                        tool_helper.tool_start = tool_helper_table["tool_start"].as_string()->get();
+                    }
+
+                    if (tool_helper_table.contains("tool_end")) {
+                        tool_helper.tool_end = tool_helper_table["tool_end"].as_string()->get();
+                    }
+
+                    if (tool_helper_table.contains("tool_hint_template")) {
+                        tool_helper.tool_hint_template = tool_helper_table["tool_hint_template"].as_string()->get();
+                    }
+                }
+                _config.tool_helper[std::string(key.str())] = tool_helper;
+            }
         }
-        _config.tool_helper["default"] = tool_helper;
+
+        if (_config.llm.empty()) {
+            throw std::runtime_error("No LLM configuration found");
+        } else if (_config.llm.find("default") == _config.llm.end()) {
+            _config.llm["default"] = _config.llm.begin()->second;
+        }
+
+        if (_config.tool_helper.find("default") == _config.tool_helper.end()) {
+            _config.tool_helper["default"] = ToolHelper();
+        }
     } catch (const std::exception& e) {
         std::cerr << "Loading config file failed: " << e.what() << std::endl;
         // Set default configuration
-        _config.llm["default"] = LLMSettings();
+        _config.llm["default"] = LLMConfig();
         _config.tool_helper["default"] = ToolHelper();
     }
 }
diff --git a/config.h b/config.h
index 8b1f067..56f669c 100644
--- a/config.h
+++ b/config.h
@@ -21,7 +21,7 @@ static std::filesystem::path get_project_root() {
 
 static const std::filesystem::path PROJECT_ROOT = get_project_root();
 
-struct LLMSettings {
+struct LLMConfig {
     std::string model;
     std::string api_key;
     std::string base_url;
@@ -31,7 +31,7 @@ struct LLMSettings {
     double temperature;
     bool oai_tool_support;
 
-    LLMSettings(
+    LLMConfig(
         const std::string& model = "deepseek-chat",
         const std::string& api_key = "sk-",
         const std::string& base_url = "https://api.deepseek.com",
@@ -154,7 +154,7 @@ struct ToolHelper {
 };
 
 struct AppConfig {
-    std::map<std::string, LLMSettings> llm;
+    std::map<std::string, LLMConfig> llm;
     std::map<std::string, ToolHelper> tool_helper;
 };
 
@@ -179,14 +179,10 @@ private:
      */
    static std::filesystem::path _get_config_path() {
         auto root = PROJECT_ROOT;
-        auto config_path = root / "config" / "config.toml";
+        auto config_path = root / "config" / "config_llm.toml";
         if (std::filesystem::exists(config_path)) {
             return config_path;
         }
-        auto example_path = root / "config" / "config.example.toml";
-        if (std::filesystem::exists(example_path)) {
-            return example_path;
-        }
         throw std::runtime_error("Config file not found");
     }
 
@@ -214,7 +210,7 @@ public:
     /**
      * @brief Get the LLM settings
      * @return The LLM settings map
      */
-    const std::map<std::string, LLMSettings>& llm() const {
+    const std::map<std::string, LLMConfig>& llm() const {
         return _config.llm;
     }
 
diff --git a/config/config.toml b/config/config_llm.toml
similarity index 65%
rename from config/config.toml
rename to config/config_llm.toml
index 9144f6b..0a6b838 100644
--- a/config/config.toml
+++ b/config/config_llm.toml
@@ -1,7 +1,9 @@
-[llm]
+[default]
 model = "deepseek-reasoner"
 base_url = "https://api.deepseek.com"
 end_point = "/v1/chat/completions"
 api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
 max_tokens = 8192
-oai_tool_support = false
\ No newline at end of file
+oai_tool_support = false
+tool_start = "<tool_call>"
+tool_end = "</tool_call>"
\ No newline at end of file
diff --git a/config/config.toml.bak b/config/config_llm.toml.bak
similarity index 72%
rename from config/config.toml.bak
rename to config/config_llm.toml.bak
index 8a12cb0..7e495c2 100644
--- a/config/config.toml.bak
+++ b/config/config_llm.toml.bak
@@ -17,4 +17,12 @@ model = "qwen-max"
 base_url = "https://dashscope.aliyuncs.com"
 end_point = "/compatible-mode/v1/chat/completions"
 api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
-max_tokens = 8192
\ No newline at end of file
+max_tokens = 8192
+
+[llm]
+model = "deepseek-reasoner"
+base_url = "https://api.deepseek.com"
+end_point = "/v1/chat/completions"
+api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
+max_tokens = 8192
+oai_tool_support = false
\ No newline at end of file
diff --git a/llm.h b/llm.h
index 473f878..f4c2d47 100644
--- a/llm.h
+++ b/llm.h
@@ -21,18 +21,18 @@ private:
 
     std::unique_ptr<httplib::Client> client_;
 
-    std::shared_ptr<LLMSettings> llm_config_;
+    std::shared_ptr<LLMConfig> llm_config_;
 
     std::shared_ptr<ToolHelper> tool_helper_;
 
 public:
     // Constructor
-    LLM(const std::string& config_name, const std::shared_ptr<LLMSettings>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
+    LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
         if (!llm_config_) {
             if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
                 throw std::invalid_argument("LLM config not found: " + config_name);
             }
-            llm_config_ = std::make_shared<LLMSettings>(Config::get_instance().llm().at(config_name));
+            llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
         }
         if (!llm_config_->oai_tool_support && !tool_helper_) {
             if (Config::get_instance().tool_helper().find(config_name) == Config::get_instance().tool_helper().end()) {
@@ -48,7 +48,7 @@ public:
     }
 
     // Get the singleton instance
-    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMSettings>& llm_config = nullptr) {
+    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMConfig>& llm_config = nullptr) {
         if (_instances.find(config_name) == _instances.end()) {
             _instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
         }
diff --git a/memory/base.h b/memory/base.h
new file mode 100644
index 0000000..8afe16f
--- /dev/null
+++ b/memory/base.h
@@ -0,0 +1,46 @@
+#ifndef HUMANUS_MEMORY_BASE_H
+#define HUMANUS_MEMORY_BASE_H
+
+#include "schema.h"
+
+namespace humanus {
+
+struct MemoryBase {
+    std::vector<Message> messages;
+
+    // Add a message to the memory
+    virtual void add_message(const Message& message) {
+        messages.push_back(message);
+    }
+
+    // Add multiple messages to the memory
+    void add_messages(const std::vector<Message>& messages) {
+        for (const auto& message : messages) {
+            add_message(message);
+        }
+    }
+
+    // Clear all messages
+    void clear() {
+        messages.clear();
+    }
+
+    // Get the last n messages
+    virtual std::vector<Message> get_recent_messages(int n) const {
+        n = std::min(n, static_cast<int>(messages.size()));
+        return std::vector<Message>(messages.end() - n, messages.end());
+    }
+
+    // Convert messages to list of dicts
+    json to_json_list() const {
+        json memory = json::array();
+        for (const auto& message : messages) {
+            memory.push_back(message.to_json());
+        }
+        return memory;
+    }
+};
+
+}
+
+#endif // HUMANUS_MEMORY_BASE_H
\ No newline at end of file
diff --git a/memory/mem0.h b/memory/mem0.h
new file mode 100644
index 0000000..183d3e3
--- /dev/null
+++ b/memory/mem0.h
@@ -0,0 +1,34 @@
+#ifndef HUMANUS_MEMORY_MEM0_H
+#define HUMANUS_MEMORY_MEM0_H
+
+#include "base.h"
+
+namespace humanus {
+
+struct MemoryConfig {
+    // Database config
+    std::string history_db_path = ":memory:";
+
+    // Embedder config
+    struct {
+        std::string provider = "llama_cpp";
+        EmbedderConfig config;
+    } embedder;
+
+    // Vector store config
+    struct {
+        std::string provider = "hnswlib";
+        VectorStoreConfig config;
+    } vector_store;
+
+    // Optional: LLM config
+    struct {
+        std::string provider = "openai";
+        LLMConfig config;
+    } llm;
+};
+
+
+}
+
+#endif // HUMANUS_MEMORY_MEM0_H
\ No newline at end of file
diff --git a/memory/simple.h b/memory/simple.h
new file mode 100644
index 0000000..5f00238
--- /dev/null
+++ b/memory/simple.h
@@ -0,0 +1,24 @@
+#ifndef HUMANUS_MEMORY_SIMPLE_H
+#define HUMANUS_MEMORY_SIMPLE_H
+
+#include "base.h"
+
+namespace humanus {
+
+struct MemorySimple : MemoryBase {
+    int max_messages;
+
+    MemorySimple(int max_messages = 100) : max_messages(max_messages) {}
+
+    void add_message(const Message& message) override {
+        MemoryBase::add_message(message);
+        while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
+            // Ensure the first message is always a user or system message
+            messages.erase(messages.begin());
+        }
+    }
+};
+
+} // namespace humanus
+
+#endif // HUMANUS_MEMORY_SIMPLE_H
\ No newline at end of file
diff --git a/schema.h b/schema.h
index 3d7dd72..b7fd6d1 100644
--- a/schema.h
+++ b/schema.h
@@ -156,49 +156,6 @@ struct Message {
     }
 };
 
-struct Memory {
-    std::vector<Message> messages;
-    int max_messages;
-
-    Memory(int max_messages = 100) : max_messages(max_messages) {}
-
-    // Add a message to the memory
-    void add_message(const Message& message) {
-        messages.push_back(message);
-        while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
-            // Ensure the first message is always a user or system message
-            messages.erase(messages.begin());
-        }
-    }
-
-    // Add multiple messages to the memory
-    void add_messages(const std::vector<Message>& messages) {
-        for (const auto& message : messages) {
-            add_message(message);
-        }
-    }
-
-    // Clear all messages
-    void clear() {
-        messages.clear();
-    }
-
-    // Get the last n messages
-    std::vector<Message> get_recent_messages(int n) const {
-        n = std::min(n, static_cast<int>(messages.size()));
-        return std::vector<Message>(messages.end() - n, messages.end());
-    }
-
-    // Convert messages to list of dicts
-    json to_json_list() const {
-        json memory = json::array();
-        for (const auto& message : messages) {
-            memory.push_back(message.to_json());
-        }
-        return memory;
-    }
-};
-
 } // namespace humanus
 
 #endif // HUMANUS_SCHEMA_H
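
A minimal usage sketch of the refactored memory interface introduced by this patch (illustrative only, not part of the patch itself). It assumes a `Message::user_message` factory helper on the `Message` struct in schema.h; any `MemoryBase` subclass could be injected in place of `MemorySimple`:

    #include <memory>
    #include "memory/simple.h"
    #include "schema.h"

    using namespace humanus;

    int main() {
        // MemorySimple is the drop-in replacement for the old schema.h Memory:
        // same eviction policy (trim to max_messages and keep a user/system
        // message first), now behind the virtual MemoryBase interface.
        std::shared_ptr<MemoryBase> memory = std::make_shared<MemorySimple>(/*max_messages=*/100);
        memory->add_message(Message::user_message("Hello")); // assumed helper, mirroring schema.h conventions
        auto recent = memory->get_recent_messages(5);        // virtual, so other stores can override retrieval
        return recent.empty() ? 1 : 0;
    }

Because `add_message` and `get_recent_messages` are virtual on `MemoryBase`, agents hold a `std::shared_ptr<MemoryBase>` and a future vector-store-backed memory (see the MemoryConfig scaffolding in memory/mem0.h) can slot in without touching agent code.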