started to refactor memory

main
hkr04 2025-03-20 16:03:26 +08:00
parent 1d87a2f4a6
commit d64fc0359a
16 changed files with 203 additions and 128 deletions


@@ -81,6 +81,11 @@ file(GLOB FLOW_SOURCES
"flow/*.cc"
)
+file(GLOB MEMORY_SOURCES
+"memory/*.cpp"
+"memory/*.cc"
+)
add_executable(humanus_cli
main.cpp
config.cpp


@@ -4,6 +4,8 @@
#include "llm.h"
#include "schema.h"
#include "logger.h"
+#include "memory/base.h"
+#include "memory/simple.h"
#include <memory>
#include <stdexcept>
#include <string>
@@ -28,7 +30,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
// Dependencies
std::shared_ptr<LLM> llm; // Language model instance
-std::shared_ptr<Memory> memory; // Agent's memory store
+std::shared_ptr<MemoryBase> memory; // Agent's memory store
AgentState state; // Current state of the agent
// Execution control
@@ -45,7 +47,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
const std::string& system_prompt,
const std::string& next_step_prompt,
const std::shared_ptr<LLM>& llm = nullptr,
-const std::shared_ptr<Memory>& memory = nullptr,
+const std::shared_ptr<MemoryBase>& memory = nullptr,
AgentState state = AgentState::IDLE,
int max_steps = 10,
int current_step = 0,
@@ -69,7 +71,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
llm = LLM::get_instance("default");
}
if (!memory) {
-memory = std::make_shared<Memory>(max_steps);
+memory = std::make_shared<MemorySimple>(max_steps);
}
}
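With this change the memory dependency is typed as MemoryBase, so any implementation can be injected through the constructor, and a null pointer falls back to MemorySimple(max_steps). A minimal sketch of supplying a wider window explicitly (the cap of 256 is an arbitrary illustration, not part of the commit):

    #include <memory>
    #include "memory/simple.h"

    // Held through the MemoryBase interface; can be passed as the `memory`
    // constructor argument of any agent instead of the default.
    std::shared_ptr<humanus::MemoryBase> memory =
        std::make_shared<humanus::MemorySimple>(/*max_messages=*/256);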


@@ -38,7 +38,7 @@ struct Humanus : ToolCallAgent {
const std::string& system_prompt = prompt::humanus::SYSTEM_PROMPT,
const std::string& next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT,
const std::shared_ptr<LLM>& llm = nullptr,
-const std::shared_ptr<Memory>& memory = nullptr,
+const std::shared_ptr<MemoryBase>& memory = nullptr,
AgentState state = AgentState::IDLE,
int max_steps = 30,
int current_step = 0,


@@ -34,7 +34,7 @@ struct PlanningAgent : ToolCallAgent {
const std::string& system_prompt = prompt::planning::PLANNING_SYSTEM_PROMPT,
const std::string& next_step_prompt = prompt::planning::NEXT_STEP_PROMPT,
const std::shared_ptr<LLM>& llm = nullptr,
-const std::shared_ptr<Memory>& memory = nullptr,
+const std::shared_ptr<MemoryBase>& memory = nullptr,
AgentState state = AgentState::IDLE,
int max_steps = 20,
int current_step = 0,


@@ -12,7 +12,7 @@ struct ReActAgent : BaseAgent {
const std::string& system_prompt,
const std::string& next_step_prompt,
const std::shared_ptr<LLM>& llm = nullptr,
-const std::shared_ptr<Memory>& memory = nullptr,
+const std::shared_ptr<MemoryBase>& memory = nullptr,
AgentState state = AgentState::IDLE,
int max_steps = 10,
int current_step = 0,


@@ -30,7 +30,7 @@ struct SweAgent : ToolCallAgent {
const std::string& system_prompt = prompt::swe::SYSTEM_PROMPT,
const std::string& next_step_prompt = prompt::swe::NEXT_STEP_TEMPLATE,
const std::shared_ptr<LLM>& llm = nullptr,
-const std::shared_ptr<Memory>& memory = nullptr,
+const std::shared_ptr<MemoryBase>& memory = nullptr,
AgentState state = AgentState::IDLE,
int max_steps = 100,
int current_step = 0,


@@ -30,7 +30,7 @@ struct ToolCallAgent : ReActAgent {
const std::string& system_prompt = prompt::toolcall::SYSTEM_PROMPT,
const std::string& next_step_prompt = prompt::toolcall::NEXT_STEP_PROMPT,
const std::shared_ptr<LLM>& llm = nullptr,
-const std::shared_ptr<Memory>& memory = nullptr,
+const std::shared_ptr<MemoryBase>& memory = nullptr,
AgentState state = AgentState::IDLE,
int max_steps = 30,
int current_step = 0,


@@ -18,78 +18,79 @@ void Config::_load_initial_config() {
const auto& data = toml::parse_file(config_path.string());
// Load LLM configuration
+for (const auto& [key, value] : data) {
+const auto& llm_table = *value.as_table();
-// Check if llm configuration exists
-if (!data.contains("llm") || !data["llm"].is_table()) {
-throw std::runtime_error("Config file does not contain `llm` table");
-}
-const auto& llm_table = *data["llm"].as_table();
+LLMConfig llm_config;
-LLMSettings llm_settings;
-if (llm_table.contains("model") && llm_table["model"].is_string()) {
-llm_settings.model = llm_table["model"].as_string()->get();
-} else {
-throw std::runtime_error("Invalid `model` configuration");
-}
-if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
-llm_settings.api_key = llm_table["api_key"].as_string()->get();
-} else {
-throw std::runtime_error("Invalid `api_key` configuration");
-}
-if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
-llm_settings.base_url = llm_table["base_url"].as_string()->get();
-} else {
-throw std::runtime_error("Invalid `base_url` configuration");
-}
-if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) {
-llm_settings.end_point = llm_table["end_point"].as_string()->get();
-}
-if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
-llm_settings.max_tokens = llm_table["max_tokens"].as_integer()->get();
-}
-if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
-llm_settings.timeout = llm_table["timeout"].as_integer()->get();
-}
-if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
-llm_settings.temperature = llm_table["temperature"].as_floating_point()->get();
-}
-if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
-llm_settings.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
-}
-_config.llm["default"] = llm_settings;
-// Load tool helper configurations
-ToolHelper tool_helper;
-if (data.contains("tool_helper") && data["tool_helper"].is_table()) {
-const auto& tool_helper_table = *data["tool_helper"].as_table();
-if (tool_helper_table.contains("tool_start")) {
-tool_helper.tool_start = tool_helper_table["tool_start"].as_string()->get();
+if (llm_table.contains("model") && llm_table["model"].is_string()) {
+llm_config.model = llm_table["model"].as_string()->get();
}
-if (tool_helper_table.contains("tool_end")) {
-tool_helper.tool_end = tool_helper_table["tool_end"].as_string()->get();
+if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
+llm_config.api_key = llm_table["api_key"].as_string()->get();
}
-if (tool_helper_table.contains("tool_hint_template")) {
-tool_helper.tool_hint_template = tool_helper_table["tool_hint_template"].as_string()->get();
+if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
+llm_config.base_url = llm_table["base_url"].as_string()->get();
}
+if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) {
+llm_config.end_point = llm_table["end_point"].as_string()->get();
+}
+if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
+llm_config.max_tokens = llm_table["max_tokens"].as_integer()->get();
+}
+if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
+llm_config.timeout = llm_table["timeout"].as_integer()->get();
+}
+if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
+llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
+}
+if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
+llm_config.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
+}
+_config.llm[std::string(key.str())] = llm_config;
+if (!llm_config.oai_tool_support) {
+// Load tool helper configuration
+ToolHelper tool_helper;
+if (llm_table.contains("tool_helper") && llm_table["tool_helper"].is_table()) {
+const auto& tool_helper_table = *llm_table["tool_helper"].as_table();
+if (tool_helper_table.contains("tool_start")) {
+tool_helper.tool_start = tool_helper_table["tool_start"].as_string()->get();
+}
+if (tool_helper_table.contains("tool_end")) {
+tool_helper.tool_end = tool_helper_table["tool_end"].as_string()->get();
+}
+if (tool_helper_table.contains("tool_hint_template")) {
+tool_helper.tool_hint_template = tool_helper_table["tool_hint_template"].as_string()->get();
+}
+}
+_config.tool_helper[std::string(key.str())] = tool_helper;
+}
}
-_config.tool_helper["default"] = tool_helper;
+if (_config.llm.empty()) {
+throw std::runtime_error("No LLM configuration found");
+} else if (_config.llm.find("default") == _config.llm.end()) {
+_config.llm["default"] = _config.llm.begin()->second;
+}
+if (_config.tool_helper.find("default") == _config.tool_helper.end()) {
+_config.tool_helper["default"] = ToolHelper();
+}
} catch (const std::exception& e) {
std::cerr << "Loading config file failed: " << e.what() << std::endl;
// Set default configuration
-_config.llm["default"] = LLMSettings();
+_config.llm["default"] = LLMConfig();
_config.tool_helper["default"] = ToolHelper();
}
}
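A short usage sketch of the refactored loader above: each top-level TOML table now becomes a named LLMConfig entry keyed by its section name, and "default" always resolves (an explicit [default] table, the first parsed section, or the built-in LLMConfig fallback). The section name "qwen" and the assumption that llm.h pulls in the Config declarations are illustrative only:

    #include <iostream>
    #include "llm.h"   // assumed to transitively provide Config and LLMConfig

    using namespace humanus;

    int main() {
        // All parsed sections, keyed by their TOML table name.
        const auto& llms = Config::get_instance().llm();
        for (const auto& [name, cfg] : llms) {
            std::cout << name << " -> " << cfg.model << std::endl;
        }
        // "default" is guaranteed after loading.
        auto default_llm = LLM::get_instance("default");
        // Any other named section (hypothetical "qwen") can be requested too.
        if (llms.count("qwen")) {
            auto qwen_llm = LLM::get_instance("qwen");
        }
        return 0;
    }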


@@ -21,7 +21,7 @@ static std::filesystem::path get_project_root() {
static const std::filesystem::path PROJECT_ROOT = get_project_root();
-struct LLMSettings {
+struct LLMConfig {
std::string model;
std::string api_key;
std::string base_url;
@@ -31,7 +31,7 @@ struct LLMSettings {
double temperature;
bool oai_tool_support;
-LLMSettings(
+LLMConfig(
const std::string& model = "deepseek-chat",
const std::string& api_key = "sk-",
const std::string& base_url = "https://api.deepseek.com",
@@ -154,7 +154,7 @@ struct ToolHelper {
};
struct AppConfig {
-std::map<std::string, LLMSettings> llm;
+std::map<std::string, LLMConfig> llm;
std::map<std::string, ToolHelper> tool_helper;
};
@@ -179,14 +179,10 @@ private:
*/
static std::filesystem::path _get_config_path() {
auto root = PROJECT_ROOT;
-auto config_path = root / "config" / "config.toml";
+auto config_path = root / "config" / "config_llm.toml";
if (std::filesystem::exists(config_path)) {
return config_path;
}
-auto example_path = root / "config" / "config.example.toml";
-if (std::filesystem::exists(example_path)) {
-return example_path;
-}
throw std::runtime_error("Config file not found");
}
@@ -214,7 +210,7 @@
* @brief Get the LLM settings
* @return The LLM settings map
*/
-const std::map<std::string, LLMSettings>& llm() const {
+const std::map<std::string, LLMConfig>& llm() const {
return _config.llm;
}


@@ -1,7 +1,9 @@
-[llm]
+[default]
model = "deepseek-reasoner"
base_url = "https://api.deepseek.com"
end_point = "/v1/chat/completions"
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
max_tokens = 8192
-oai_tool_support = false
+oai_tool_support = false
+tool_start = "<tool_call>"
+tool_end = "</tool_call>"


@@ -17,4 +17,12 @@ model = "qwen-max"
base_url = "https://dashscope.aliyuncs.com"
end_point = "/compatible-mode/v1/chat/completions"
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
-max_tokens = 8192
+max_tokens = 8192
+
+[llm]
+model = "deepseek-reasoner"
+base_url = "https://api.deepseek.com"
+end_point = "/v1/chat/completions"
+api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
+max_tokens = 8192
+oai_tool_support = false

llm.h (8 changed lines)

@@ -21,18 +21,18 @@ private:
std::unique_ptr<httplib::Client> client_;
-std::shared_ptr<LLMSettings> llm_config_;
+std::shared_ptr<LLMConfig> llm_config_;
std::shared_ptr<ToolHelper> tool_helper_;
public:
// Constructor
-LLM(const std::string& config_name, const std::shared_ptr<LLMSettings>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
+LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
if (!llm_config_) {
if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
throw std::invalid_argument("LLM config not found: " + config_name);
}
-llm_config_ = std::make_shared<LLMSettings>(Config::get_instance().llm().at(config_name));
+llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
}
if (!llm_config_->oai_tool_support && !tool_helper_) {
if (Config::get_instance().tool_helper().find(config_name) == Config::get_instance().tool_helper().end()) {
@@ -48,7 +48,7 @@ public:
}
// Get the singleton instance
-static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMSettings>& llm_config = nullptr) {
+static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMConfig>& llm_config = nullptr) {
if (_instances.find(config_name) == _instances.end()) {
_instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
}
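A sketch of the second construction path above: when an explicit LLMConfig is handed to get_instance, the config-file lookup inside the constructor is skipped entirely. The instance name "custom" and all field values below are placeholders for illustration, not part of the commit:

    #include <memory>
    #include "llm.h"

    using namespace humanus;

    int main() {
        auto cfg = std::make_shared<LLMConfig>();   // starts from the struct's defaults
        cfg->model = "deepseek-chat";
        cfg->base_url = "https://api.deepseek.com";
        cfg->end_point = "/v1/chat/completions";
        cfg->api_key = "sk-...";                    // placeholder key
        cfg->oai_tool_support = true;               // avoids the ToolHelper lookup path

        // Cached under the name "custom"; later calls with the same name reuse it.
        auto llm = LLM::get_instance("custom", cfg);
        return 0;
    }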

memory/base.h (new file, mode 100644, 46 lines)

@@ -0,0 +1,46 @@
#ifndef HUMANUS_MEMORY_BASE_H
#define HUMANUS_MEMORY_BASE_H

#include "schema.h"

namespace humanus {

struct MemoryBase {
    std::vector<Message> messages;

    // Add a message to the memory
    virtual void add_message(const Message& message) {
        messages.push_back(message);
    }

    // Add multiple messages to the memory
    void add_messages(const std::vector<Message>& messages) {
        for (const auto& message : messages) {
            add_message(message);
        }
    }

    // Clear all messages
    void clear() {
        messages.clear();
    }

    // Get the last n messages
    virtual std::vector<Message> get_recent_messages(int n) const {
        n = std::min(n, static_cast<int>(messages.size()));
        return std::vector<Message>(messages.end() - n, messages.end());
    }

    // Convert messages to list of dicts
    json to_json_list() const {
        json memory = json::array();
        for (const auto& message : messages) {
            memory.push_back(message.to_json());
        }
        return memory;
    }
};

}

#endif // HUMANUS_MEMORY_BASE_H
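Because add_message and get_recent_messages are virtual, alternative stores can drop in behind the same interface. A hypothetical sketch, not part of the commit, of a subclass that trims by an approximate character budget instead of a message count:

    #include "memory/base.h"

    namespace humanus {

    // Keeps messages whose serialized size fits a rough character budget.
    struct MemoryCharBudget : MemoryBase {
        size_t max_chars;

        MemoryCharBudget(size_t max_chars = 16384) : max_chars(max_chars) {}

        void add_message(const Message& message) override {
            MemoryBase::add_message(message);
            size_t total = 0;
            for (const auto& m : messages) {
                total += m.to_json().dump().size();
            }
            // Drop the oldest messages until the budget is met.
            while (messages.size() > 1 && total > max_chars) {
                total -= messages.begin()->to_json().dump().size();
                messages.erase(messages.begin());
            }
        }
    };

    } // namespace humanus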

memory/mem0.h (new file, mode 100644, 34 lines)

@@ -0,0 +1,34 @@
#ifndef HUMANUS_MEMORY_MEM0_H
#define HUMANUS_MEMORY_MEM0_H

#include "base.h"

namespace humanus {

struct MemoryConfig {
    // Database config
    std::string history_db_path = ":memory:";

    // Embedder config
    struct {
        std::string provider = "llama_cpp";
        EmbedderConfig config;
    } embedder;

    // Vector store config
    struct {
        std::string provider = "hnswlib";
        VectorStoreConfig config;
    } vector_store;

    // Optional: LLM config
    struct {
        std::string provider = "openai";
        LLMConfig config;
    } llm;
};

}

#endif // HUMANUS_MEMORY_MEM0_H

memory/simple.h (new file, mode 100644, 24 lines)

@@ -0,0 +1,24 @@
#ifndef HUMANUS_MEMORY_SIMPLE_H
#define HUMANUS_MEMORY_SIMPLE_H

#include "base.h"

namespace humanus {

struct MemorySimple : MemoryBase {
    int max_messages;

    MemorySimple(int max_messages = 100) : max_messages(max_messages) {}

    void add_message(const Message& message) override {
        MemoryBase::add_message(message);
        while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
            // Ensure the first message is always a user or system message
            messages.erase(messages.begin());
        }
    }
};

} // namespace humanus

#endif // HUMANUS_MEMORY_SIMPLE_H


@@ -156,49 +156,6 @@ struct Message {
}
};
-struct Memory {
-std::vector<Message> messages;
-int max_messages;
-Memory(int max_messages = 100) : max_messages(max_messages) {}
-// Add a message to the memory
-void add_message(const Message& message) {
-messages.push_back(message);
-while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
-// Ensure the first message is always a user or system message
-messages.erase(messages.begin());
-}
-}
-// Add multiple messages to the memory
-void add_messages(const std::vector<Message>& messages) {
-for (const auto& message : messages) {
-add_message(message);
-}
-}
-// Clear all messages
-void clear() {
-messages.clear();
-}
-// Get the last n messages
-std::vector<Message> get_recent_messages(int n) const {
-n = std::min(n, static_cast<int>(messages.size()));
-return std::vector<Message>(messages.end() - n, messages.end());
-}
-// Convert messages to list of dicts
-json to_json_list() const {
-json memory = json::array();
-for (const auto& message : messages) {
-memory.push_back(message.to_json());
-}
-return memory;
-}
-};
} // namespace humanus
#endif // HUMANUS_SCHEMA_H