started to refactor memory
parent 1d87a2f4a6
commit d64fc0359a
@@ -81,6 +81,11 @@ file(GLOB FLOW_SOURCES
     "flow/*.cc"
 )
 
+file(GLOB MEMORY_SOURCES
+    "memory/*.cpp"
+    "memory/*.cc"
+)
+
 add_executable(humanus_cli
     main.cpp
     config.cpp
@@ -4,6 +4,8 @@
 #include "llm.h"
 #include "schema.h"
 #include "logger.h"
+#include "memory/base.h"
+#include "memory/simple.h"
 #include <memory>
 #include <stdexcept>
 #include <string>
@@ -28,7 +30,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
 
     // Dependencies
     std::shared_ptr<LLM> llm;           // Language model instance
-    std::shared_ptr<Memory> memory;     // Agent's memory store
+    std::shared_ptr<MemoryBase> memory; // Agent's memory store
     AgentState state;                   // Current state of the agent
 
     // Execution control
@@ -45,7 +47,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
         const std::string& system_prompt,
         const std::string& next_step_prompt,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 10,
         int current_step = 0,
@@ -69,7 +71,7 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
             llm = LLM::get_instance("default");
         }
         if (!memory) {
-            memory = std::make_shared<Memory>(max_steps);
+            memory = std::make_shared<MemorySimple>(max_steps);
         }
     }
 
@@ -38,7 +38,7 @@ struct Humanus : ToolCallAgent {
         const std::string& system_prompt = prompt::humanus::SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::humanus::NEXT_STEP_PROMPT,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 30,
         int current_step = 0,
@@ -34,7 +34,7 @@ struct PlanningAgent : ToolCallAgent {
         const std::string& system_prompt = prompt::planning::PLANNING_SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::planning::NEXT_STEP_PROMPT,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 20,
         int current_step = 0,
@@ -12,7 +12,7 @@ struct ReActAgent : BaseAgent {
         const std::string& system_prompt,
         const std::string& next_step_prompt,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 10,
         int current_step = 0,
@@ -30,7 +30,7 @@ struct SweAgent : ToolCallAgent {
         const std::string& system_prompt = prompt::swe::SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::swe::NEXT_STEP_TEMPLATE,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 100,
         int current_step = 0,
@@ -30,7 +30,7 @@ struct ToolCallAgent : ReActAgent {
         const std::string& system_prompt = prompt::toolcall::SYSTEM_PROMPT,
         const std::string& next_step_prompt = prompt::toolcall::NEXT_STEP_PROMPT,
         const std::shared_ptr<LLM>& llm = nullptr,
-        const std::shared_ptr<Memory>& memory = nullptr,
+        const std::shared_ptr<MemoryBase>& memory = nullptr,
         AgentState state = AgentState::IDLE,
         int max_steps = 30,
         int current_step = 0,
config.cpp
@@ -18,61 +18,50 @@ void Config::_load_initial_config() {
         const auto& data = toml::parse_file(config_path.string());
 
         // Load LLM configuration
-        // Check if llm configuration exists
-        if (!data.contains("llm") || !data["llm"].is_table()) {
-            throw std::runtime_error("Config file does not contain `llm` table");
-        }
-
-        const auto& llm_table = *data["llm"].as_table();
-
-        LLMSettings llm_settings;
+        for (const auto& [key, value] : data) {
+            const auto& llm_table = *value.as_table();
+
+            LLMConfig llm_config;
 
         if (llm_table.contains("model") && llm_table["model"].is_string()) {
-            llm_settings.model = llm_table["model"].as_string()->get();
-        } else {
-            throw std::runtime_error("Invalid `model` configuration");
+            llm_config.model = llm_table["model"].as_string()->get();
         }
 
         if (llm_table.contains("api_key") && llm_table["api_key"].is_string()) {
-            llm_settings.api_key = llm_table["api_key"].as_string()->get();
-        } else {
-            throw std::runtime_error("Invalid `api_key` configuration");
+            llm_config.api_key = llm_table["api_key"].as_string()->get();
         }
 
         if (llm_table.contains("base_url") && llm_table["base_url"].is_string()) {
-            llm_settings.base_url = llm_table["base_url"].as_string()->get();
-        } else {
-            throw std::runtime_error("Invalid `base_url` configuration");
+            llm_config.base_url = llm_table["base_url"].as_string()->get();
         }
 
         if (llm_table.contains("end_point") && llm_table["end_point"].is_string()) {
-            llm_settings.end_point = llm_table["end_point"].as_string()->get();
+            llm_config.end_point = llm_table["end_point"].as_string()->get();
         }
 
         if (llm_table.contains("max_tokens") && llm_table["max_tokens"].is_integer()) {
-            llm_settings.max_tokens = llm_table["max_tokens"].as_integer()->get();
+            llm_config.max_tokens = llm_table["max_tokens"].as_integer()->get();
         }
 
         if (llm_table.contains("timeout") && llm_table["timeout"].is_integer()) {
-            llm_settings.timeout = llm_table["timeout"].as_integer()->get();
+            llm_config.timeout = llm_table["timeout"].as_integer()->get();
         }
 
         if (llm_table.contains("temperature") && llm_table["temperature"].is_floating_point()) {
-            llm_settings.temperature = llm_table["temperature"].as_floating_point()->get();
+            llm_config.temperature = llm_table["temperature"].as_floating_point()->get();
         }
 
         if (llm_table.contains("oai_tool_support") && llm_table["oai_tool_support"].is_boolean()) {
-            llm_settings.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
+            llm_config.oai_tool_support = llm_table["oai_tool_support"].as_boolean()->get();
         }
 
-        _config.llm["default"] = llm_settings;
+        _config.llm[std::string(key.str())] = llm_config;
 
-        // Load tool helper configurations
+        if (!llm_config.oai_tool_support) {
+            // Load tool helper configuration
         ToolHelper tool_helper;
-        if (data.contains("tool_helper") && data["tool_helper"].is_table()) {
-            const auto& tool_helper_table = *data["tool_helper"].as_table();
+
+            if (llm_table.contains("tool_helper") && llm_table["tool_helper"].is_table()) {
+                const auto& tool_helper_table = *llm_table["tool_helper"].as_table();
             if (tool_helper_table.contains("tool_start")) {
                 tool_helper.tool_start = tool_helper_table["tool_start"].as_string()->get();
             }
@@ -85,11 +74,23 @@ void Config::_load_initial_config() {
                 tool_helper.tool_hint_template = tool_helper_table["tool_hint_template"].as_string()->get();
             }
         }
-        _config.tool_helper["default"] = tool_helper;
+            _config.tool_helper[std::string(key.str())] = tool_helper;
+            }
+        }
+
+        if (_config.llm.empty()) {
+            throw std::runtime_error("No LLM configuration found");
+        } else if (_config.llm.find("default") == _config.llm.end()) {
+            _config.llm["default"] = _config.llm.begin()->second;
+        }
+
+        if (_config.tool_helper.find("default") == _config.tool_helper.end()) {
+            _config.tool_helper["default"] = ToolHelper();
+        }
     } catch (const std::exception& e) {
         std::cerr << "Loading config file failed: " << e.what() << std::endl;
         // Set default configuration
-        _config.llm["default"] = LLMSettings();
+        _config.llm["default"] = LLMConfig();
         _config.tool_helper["default"] = ToolHelper();
     }
 }
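The loop above now keys _config.llm by each top-level TOML table name and only synthesizes a "default" entry when none is present. A minimal sketch of the lookup pattern this enables (the helper name settings_for is hypothetical, and it assumes it is compiled where the humanus Config and LLMConfig types are visible):

    // Resolve a named LLM config, falling back to "default";
    // _load_initial_config() guarantees a "default" entry or throws.
    const LLMConfig& settings_for(const std::string& name) {
        const auto& llm_map = Config::get_instance().llm();
        auto it = llm_map.find(name);
        return it != llm_map.end() ? it->second : llm_map.at("default");
    }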
config.h
@@ -21,7 +21,7 @@ static std::filesystem::path get_project_root() {
 
 static const std::filesystem::path PROJECT_ROOT = get_project_root();
 
-struct LLMSettings {
+struct LLMConfig {
     std::string model;
     std::string api_key;
     std::string base_url;
@@ -31,7 +31,7 @@ struct LLMSettings {
     double temperature;
     bool oai_tool_support;
 
-    LLMSettings(
+    LLMConfig(
         const std::string& model = "deepseek-chat",
         const std::string& api_key = "sk-",
         const std::string& base_url = "https://api.deepseek.com",
@@ -154,7 +154,7 @@ struct ToolHelper {
 };
 
 struct AppConfig {
-    std::map<std::string, LLMSettings> llm;
+    std::map<std::string, LLMConfig> llm;
     std::map<std::string, ToolHelper> tool_helper;
 };
 
@@ -179,14 +179,10 @@ private:
      */
     static std::filesystem::path _get_config_path() {
         auto root = PROJECT_ROOT;
-        auto config_path = root / "config" / "config.toml";
+        auto config_path = root / "config" / "config_llm.toml";
         if (std::filesystem::exists(config_path)) {
             return config_path;
         }
-        auto example_path = root / "config" / "config.example.toml";
-        if (std::filesystem::exists(example_path)) {
-            return example_path;
-        }
         throw std::runtime_error("Config file not found");
     }
 
@@ -214,7 +210,7 @@ public:
      * @brief Get the LLM settings
      * @return The LLM settings map
      */
-    const std::map<std::string, LLMSettings>& llm() const {
+    const std::map<std::string, LLMConfig>& llm() const {
         return _config.llm;
     }
 
@@ -1,7 +1,9 @@
-[llm]
+[default]
 model = "deepseek-reasoner"
 base_url = "https://api.deepseek.com"
 end_point = "/v1/chat/completions"
 api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
 max_tokens = 8192
 oai_tool_support = false
+tool_start = "<tool_call>"
+tool_end = "</tool_call>"
@@ -18,3 +18,11 @@ base_url = "https://dashscope.aliyuncs.com"
 end_point = "/compatible-mode/v1/chat/completions"
 api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
 max_tokens = 8192
+
+[llm]
+model = "deepseek-reasoner"
+base_url = "https://api.deepseek.com"
+end_point = "/v1/chat/completions"
+api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
+max_tokens = 8192
+oai_tool_support = false
llm.h
@@ -21,18 +21,18 @@ private:
 
     std::unique_ptr<httplib::Client> client_;
 
-    std::shared_ptr<LLMSettings> llm_config_;
+    std::shared_ptr<LLMConfig> llm_config_;
 
     std::shared_ptr<ToolHelper> tool_helper_;
 
 public:
     // Constructor
-    LLM(const std::string& config_name, const std::shared_ptr<LLMSettings>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
+    LLM(const std::string& config_name, const std::shared_ptr<LLMConfig>& llm_config = nullptr, const std::shared_ptr<ToolHelper>& tool_helper = nullptr) : llm_config_(llm_config), tool_helper_(tool_helper) {
         if (!llm_config_) {
             if (Config::get_instance().llm().find(config_name) == Config::get_instance().llm().end()) {
                 throw std::invalid_argument("LLM config not found: " + config_name);
             }
-            llm_config_ = std::make_shared<LLMSettings>(Config::get_instance().llm().at(config_name));
+            llm_config_ = std::make_shared<LLMConfig>(Config::get_instance().llm().at(config_name));
         }
         if (!llm_config_->oai_tool_support && !tool_helper_) {
             if (Config::get_instance().tool_helper().find(config_name) == Config::get_instance().tool_helper().end()) {
@@ -48,7 +48,7 @@ public:
     }
 
     // Get the singleton instance
-    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMSettings>& llm_config = nullptr) {
+    static std::shared_ptr<LLM> get_instance(const std::string& config_name = "default", const std::shared_ptr<LLMConfig>& llm_config = nullptr) {
         if (_instances.find(config_name) == _instances.end()) {
             _instances[config_name] = std::make_shared<LLM>(config_name, llm_config);
         }
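Callers keep the same get_instance() pattern after the LLMSettings → LLMConfig rename; a minimal usage sketch (the "qwen" section name is illustrative, not taken from this commit):

    auto chat = LLM::get_instance();        // cached instance backed by _config.llm["default"]
    auto qwen = LLM::get_instance("qwen");  // one cached instance per named LLMConfig;
                                            // the constructor throws std::invalid_argument
                                            // if no such config section exists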
@@ -0,0 +1,46 @@
+#ifndef HUMANUS_MEMORY_BASE_H
+#define HUMANUS_MEMORY_BASE_H
+
+#include "schema.h"
+
+namespace humanus {
+
+struct MemoryBase {
+    std::vector<Message> messages;
+
+    // Add a message to the memory
+    virtual void add_message(const Message& message) {
+        messages.push_back(message);
+    }
+
+    // Add multiple messages to the memory
+    void add_messages(const std::vector<Message>& messages) {
+        for (const auto& message : messages) {
+            add_message(message);
+        }
+    }
+
+    // Clear all messages
+    void clear() {
+        messages.clear();
+    }
+
+    // Get the last n messages
+    virtual std::vector<Message> get_recent_messages(int n) const {
+        n = std::min(n, static_cast<int>(messages.size()));
+        return std::vector<Message>(messages.end() - n, messages.end());
+    }
+
+    // Convert messages to list of dicts
+    json to_json_list() const {
+        json memory = json::array();
+        for (const auto& message : messages) {
+            memory.push_back(message.to_json());
+        }
+        return memory;
+    }
+};
+
+}
+
+#endif // HUMANUS_MEMORY_BASE_H
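A short usage sketch of the new interface (the dump_memory helper is illustrative and assumes the json alias pulled in by schema.h is the usual nlohmann type):

    #include <iostream>
    #include "memory/base.h"

    // Serialize whatever a MemoryBase-derived store currently holds;
    // to_json_list() converts each stored Message via Message::to_json().
    void dump_memory(const humanus::MemoryBase& memory) {
        auto snapshot = memory.to_json_list();
        std::cout << snapshot.dump(2) << std::endl;
    }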
@@ -0,0 +1,34 @@
+#ifndef HUMANUS_MEMORY_MEM0_H
+#define HUMANUS_MEMORY_MEM0_H
+
+#include "base.h"
+
+namespace humanus {
+
+struct MemoryConfig {
+    // Database config
+    std::string history_db_path = ":memory:";
+
+    // Embedder config
+    struct {
+        std::string provider = "llama_cpp";
+        EmbedderConfig config;
+    } embedder;
+
+    // Vector store config
+    struct {
+        std::string provider = "hnswlib";
+        VectorStoreConfig config;
+    } vector_store;
+
+    // Optional: LLM config
+    struct {
+        std::string provider = "openai";
+        LLMConfig config;
+    } llm;
+};
+
+
+}
+
+#endif // HUMANUS_MEMORY_MEM0_H
@@ -0,0 +1,24 @@
+#ifndef HUMANUS_MEMORY_SIMPLE_H
+#define HUMANUS_MEMORY_SIMPLE_H
+
+#include "base.h"
+
+namespace humanus {
+
+struct MemorySimple : MemoryBase {
+    int max_messages;
+
+    MemorySimple(int max_messages = 100) : max_messages(max_messages) {}
+
+    void add_message(const Message& message) override {
+        MemoryBase::add_message(message);
+        while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
+            // Ensure the first message is always a user or system message
+            messages.erase(messages.begin());
+        }
+    }
+};
+
+} // namespace humanus
+
+#endif // HUMANUS_MEMORY_SIMPLE_H
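Because the agents now hold a std::shared_ptr<MemoryBase>, alternative backends can be dropped in without touching the agent headers. A minimal sketch of a hypothetical backend (MemoryRing is not part of this commit, only an illustration of the extension point):

    #include "memory/base.h"

    namespace humanus {

    // Hypothetical fixed-capacity store: keeps only the newest `capacity`
    // messages, regardless of role, by reusing the MemoryBase hooks.
    struct MemoryRing : MemoryBase {
        size_t capacity;

        explicit MemoryRing(size_t capacity = 64) : capacity(capacity) {}

        void add_message(const Message& message) override {
            MemoryBase::add_message(message);
            while (messages.size() > capacity) {
                messages.erase(messages.begin());
            }
        }
    };

    } // namespace humanus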
schema.h
@@ -156,49 +156,6 @@ struct Message {
     }
 };
 
-struct Memory {
-    std::vector<Message> messages;
-    int max_messages;
-
-    Memory(int max_messages = 100) : max_messages(max_messages) {}
-
-    // Add a message to the memory
-    void add_message(const Message& message) {
-        messages.push_back(message);
-        while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
-            // Ensure the first message is always a user or system message
-            messages.erase(messages.begin());
-        }
-    }
-
-    // Add multiple messages to the memory
-    void add_messages(const std::vector<Message>& messages) {
-        for (const auto& message : messages) {
-            add_message(message);
-        }
-    }
-
-    // Clear all messages
-    void clear() {
-        messages.clear();
-    }
-
-    // Get the last n messages
-    std::vector<Message> get_recent_messages(int n) const {
-        n = std::min(n, static_cast<int>(messages.size()));
-        return std::vector<Message>(messages.end() - n, messages.end());
-    }
-
-    // Convert messages to list of dicts
-    json to_json_list() const {
-        json memory = json::array();
-        for (const auto& message : messages) {
-            memory.push_back(message.to_json());
-        }
-        return memory;
-    }
-};
-
 } // namespace humanus
 
 #endif // HUMANUS_SCHEMA_H