mem0: a workable version
parent
4258a71d7a
commit
e9544ec16f
10
agent/base.h
10
agent/base.h
|
@ -123,9 +123,11 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
||||||
|
|
||||||
results.push_back("Step " + std::to_string(current_step) + ": " + step_result);
|
results.push_back("Step " + std::to_string(current_step) + ": " + step_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (current_step >= max_steps) {
|
if (current_step >= max_steps) {
|
||||||
results.push_back("Terminated: Reached max steps (" + std::to_string(max_steps) + ")");
|
results.push_back("Terminated: Reached max steps (" + std::to_string(max_steps) + ")");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state != AgentState::FINISHED) {
|
if (state != AgentState::FINISHED) {
|
||||||
results.push_back("Terminated: Agent state is " + agent_state_map[state]);
|
results.push_back("Terminated: Agent state is " + agent_state_map[state]);
|
||||||
} else {
|
} else {
|
||||||
|
@ -149,7 +151,9 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
||||||
* Execute a single step in the agent's workflow.
|
* Execute a single step in the agent's workflow.
|
||||||
* Must be implemented by subclasses to define specific behavior.
|
* Must be implemented by subclasses to define specific behavior.
|
||||||
*/
|
*/
|
||||||
virtual std::string step() = 0;
|
virtual std::string step() {
|
||||||
|
return "No step executed";
|
||||||
|
}
|
||||||
|
|
||||||
// Handle stuck state by adding a prompt to change strategy
|
// Handle stuck state by adding a prompt to change strategy
|
||||||
void handle_stuck_state() {
|
void handle_stuck_state() {
|
||||||
|
@ -197,10 +201,6 @@ struct BaseAgent : std::enable_shared_from_this<BaseAgent> {
|
||||||
memory->clear();
|
memory->clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_messages(const std::vector<Message>& messages) {
|
|
||||||
memory->add_messages(messages);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace humanus
|
} // namespace humanus
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
#ifndef HUMANUS_AGENT_CHATBOT_H
|
||||||
|
#define HUMANUS_AGENT_CHATBOT_H
|
||||||
|
|
||||||
|
#include "base.h"
|
||||||
|
|
||||||
|
namespace humanus {
|
||||||
|
|
||||||
|
struct Chatbot : BaseAgent {
|
||||||
|
|
||||||
|
Chatbot(
|
||||||
|
const std::string& name = "chatbot",
|
||||||
|
const std::string& description = "A chatbot agent",
|
||||||
|
const std::string& system_prompt = "You are a helpful assistant.",
|
||||||
|
const std::shared_ptr<LLM>& llm = nullptr,
|
||||||
|
const std::shared_ptr<BaseMemory>& memory = nullptr
|
||||||
|
) : BaseAgent(
|
||||||
|
name,
|
||||||
|
description,
|
||||||
|
system_prompt,
|
||||||
|
"",
|
||||||
|
llm,
|
||||||
|
memory
|
||||||
|
) {}
|
||||||
|
|
||||||
|
std::string run(const std::string& request = "") override {
|
||||||
|
if (!request.empty()) {
|
||||||
|
update_memory("user", request);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get response with tool options
|
||||||
|
auto response = llm->ask(
|
||||||
|
memory->get_messages(request),
|
||||||
|
system_prompt
|
||||||
|
);
|
||||||
|
|
||||||
|
// Update memory with response
|
||||||
|
update_memory("assistant", response);
|
||||||
|
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HUMANUS_AGENT_CHATBOT_H
|
|
@ -6,7 +6,7 @@ namespace humanus {
|
||||||
bool ToolCallAgent::think() {
|
bool ToolCallAgent::think() {
|
||||||
// Get response with tool options
|
// Get response with tool options
|
||||||
auto response = llm->ask_tool(
|
auto response = llm->ask_tool(
|
||||||
memory->get_messages(),
|
memory->get_messages(current_request),
|
||||||
system_prompt,
|
system_prompt,
|
||||||
next_step_prompt,
|
next_step_prompt,
|
||||||
available_tools.to_params(),
|
available_tools.to_params(),
|
||||||
|
@ -81,7 +81,7 @@ std::string ToolCallAgent::act() {
|
||||||
for (const auto& tool_call : tool_calls) {
|
for (const auto& tool_call : tool_calls) {
|
||||||
auto result = execute_tool(tool_call);
|
auto result = execute_tool(tool_call);
|
||||||
logger->info(
|
logger->info(
|
||||||
"🎯 Tool '" + tool_call.function.name + "' completed its mission! Result: " + result.substr(0, 500) + (result.size() > 500 ? "..." : "")
|
"🎯 Tool `" + tool_call.function.name + "` completed its mission! Result: " + result.substr(0, 500) + (result.size() > 500 ? "..." : "")
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add tool response to memory
|
// Add tool response to memory
|
||||||
|
@ -108,7 +108,7 @@ std::string ToolCallAgent::execute_tool(ToolCall tool_call) {
|
||||||
|
|
||||||
std::string name = tool_call.function.name;
|
std::string name = tool_call.function.name;
|
||||||
if (available_tools.tools_map.find(name) == available_tools.tools_map.end()) {
|
if (available_tools.tools_map.find(name) == available_tools.tools_map.end()) {
|
||||||
return "Error: Unknown tool '" + name + "'. Please use one of the following tools: " +
|
return "Error: Unknown tool `" + name + "`. Please use one of the following tools: " +
|
||||||
std::accumulate(available_tools.tools_map.begin(), available_tools.tools_map.end(), std::string(),
|
std::accumulate(available_tools.tools_map.begin(), available_tools.tools_map.end(), std::string(),
|
||||||
[](const std::string& a, const auto& b) {
|
[](const std::string& a, const auto& b) {
|
||||||
return a + (a.empty() ? "" : ", ") + b.first;
|
return a + (a.empty() ? "" : ", ") + b.first;
|
||||||
|
@ -124,7 +124,7 @@ std::string ToolCallAgent::execute_tool(ToolCall tool_call) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the tool
|
// Execute the tool
|
||||||
logger->info("🔧 Activating tool: '" + name + "'...");
|
logger->info("🔧 Activating tool: `" + name + "`...");
|
||||||
ToolResult result = available_tools.execute(name, args);
|
ToolResult result = available_tools.execute(name, args);
|
||||||
|
|
||||||
// Format result for display
|
// Format result for display
|
||||||
|
@ -139,11 +139,11 @@ std::string ToolCallAgent::execute_tool(ToolCall tool_call) {
|
||||||
} catch (const json::exception& /* e */) {
|
} catch (const json::exception& /* e */) {
|
||||||
std::string error_msg = "Error parsing arguments for " + name + ": Invalid JSON format";
|
std::string error_msg = "Error parsing arguments for " + name + ": Invalid JSON format";
|
||||||
logger->error(
|
logger->error(
|
||||||
"📝 Oops! The arguments for '" + name + "' don't make sense - invalid JSON"
|
"📝 Oops! The arguments for `" + name + "` don't make sense - invalid JSON"
|
||||||
);
|
);
|
||||||
return "Error: " + error_msg;
|
return "Error: " + error_msg;
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
std::string error_msg = "⚠️ Tool '" + name + "' encountered a problem: " + std::string(e.what());
|
std::string error_msg = "⚠️ Tool `" + name + "` encountered a problem: " + std::string(e.what());
|
||||||
logger->error(error_msg);
|
logger->error(error_msg);
|
||||||
return "Error: " + error_msg;
|
return "Error: " + error_msg;
|
||||||
}
|
}
|
||||||
|
@ -156,7 +156,7 @@ void ToolCallAgent::_handle_special_tool(const std::string& name, const ToolResu
|
||||||
}
|
}
|
||||||
|
|
||||||
if (_should_finish_execution(name, result, kwargs)) {
|
if (_should_finish_execution(name, result, kwargs)) {
|
||||||
logger->info("🏁 Special tool '" + name + "' has completed the task!");
|
logger->info("🏁 Special tool `" + name + "` has completed the task!");
|
||||||
state = AgentState::FINISHED;
|
state = AgentState::FINISHED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
18
config.h
18
config.h
|
@ -39,9 +39,9 @@ struct LLMConfig {
|
||||||
const std::string& base_url = "https://api.deepseek.com",
|
const std::string& base_url = "https://api.deepseek.com",
|
||||||
const std::string& endpoint = "/v1/chat/completions",
|
const std::string& endpoint = "/v1/chat/completions",
|
||||||
const std::string& vision_details = "auto",
|
const std::string& vision_details = "auto",
|
||||||
int max_tokens = 4096,
|
int max_tokens = -1, // -1 for default
|
||||||
int timeout = 120,
|
int timeout = 120,
|
||||||
double temperature = 1.0,
|
double temperature = -1, // -1 for default
|
||||||
bool enable_vision = false,
|
bool enable_vision = false,
|
||||||
bool oai_tool_support = true
|
bool oai_tool_support = true
|
||||||
) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
|
) : model(model), api_key(api_key), base_url(base_url), endpoint(endpoint), vision_details(vision_details),
|
||||||
|
@ -191,9 +191,8 @@ namespace mem0 {
|
||||||
|
|
||||||
struct MemoryConfig {
|
struct MemoryConfig {
|
||||||
// Base config
|
// Base config
|
||||||
int max_messages = 5; // Short-term memory capacity
|
int max_messages = 16; // Short-term memory capacity
|
||||||
int limit = 5; // Number of results to retrive from long-term memory
|
int retrieval_limit = 8; // Number of results to retrive from long-term memory
|
||||||
std::string filters = ""; // Filters to apply to search results
|
|
||||||
|
|
||||||
// Prompt config
|
// Prompt config
|
||||||
std::string fact_extraction_prompt = prompt::mem0::FACT_EXTRACTION_PROMPT;
|
std::string fact_extraction_prompt = prompt::mem0::FACT_EXTRACTION_PROMPT;
|
||||||
|
@ -207,19 +206,12 @@ struct MemoryConfig {
|
||||||
|
|
||||||
// Vector store config
|
// Vector store config
|
||||||
std::shared_ptr<VectorStoreConfig> vector_store_config = nullptr;
|
std::shared_ptr<VectorStoreConfig> vector_store_config = nullptr;
|
||||||
|
FilterFunc filter = nullptr; // Filter to apply to search results
|
||||||
|
|
||||||
// Optional: LLM config
|
// Optional: LLM config
|
||||||
std::shared_ptr<LLMConfig> llm_config = nullptr;
|
std::shared_ptr<LLMConfig> llm_config = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MemoryItem {
|
|
||||||
size_t id; // The unique identifier for the text data
|
|
||||||
std::string memory; // The memory deduced from the text data
|
|
||||||
std::string hash; // The hash of the memory
|
|
||||||
json metadata; // Any additional metadata associated with the memory, like 'created_at' or 'updated_at'
|
|
||||||
float score; // The score associated with the text data, used for ranking and sorting
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace mem0
|
} // namespace mem0
|
||||||
|
|
||||||
struct AppConfig {
|
struct AppConfig {
|
||||||
|
|
|
@ -1,9 +1,5 @@
|
||||||
[default]
|
[default]
|
||||||
model = "deepseek-reasoner"
|
model = "qwen-max"
|
||||||
base_url = "https://api.deepseek.com"
|
base_url = "https://dashscope.aliyuncs.com"
|
||||||
endpoint = "/v1/chat/completions"
|
endpoint = "/compatible-mode/v1/chat/completions"
|
||||||
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
|
api_key = "sk-cb1bb2a240d84182bb93f6dd0fe03600"
|
||||||
max_tokens = 8192
|
|
||||||
oai_tool_support = false
|
|
||||||
tool_start = "<tool_call>"
|
|
||||||
tool_end = "</tool_call>"
|
|
|
@ -25,4 +25,12 @@ base_url = "https://api.deepseek.com"
|
||||||
endpoint = "/v1/chat/completions"
|
endpoint = "/v1/chat/completions"
|
||||||
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
|
api_key = "sk-93c5bfcb920c4a8aa345791d429b8536"
|
||||||
max_tokens = 8192
|
max_tokens = 8192
|
||||||
|
oai_tool_support = false
|
||||||
|
|
||||||
|
[llm]
|
||||||
|
model = "claude-3-5-sonnet-20241022"
|
||||||
|
base_url = "https://gpt.soruxgpt.com"
|
||||||
|
endpoint = "/api/api/v1/chat/completions"
|
||||||
|
api_key = "sk-o38PVgxNjzt8bYsfruSlKq9DqoPeiOwKytlOzN7fakJ4YRDF"
|
||||||
|
max_tokens = 8192
|
||||||
oai_tool_support = false
|
oai_tool_support = false
|
|
@ -0,0 +1,12 @@
|
||||||
|
set(target humanus_chat_mem0)
|
||||||
|
|
||||||
|
add_executable(${target} chat_mem0.cpp)
|
||||||
|
|
||||||
|
# 链接到核心库
|
||||||
|
target_link_libraries(${target} PRIVATE humanus)
|
||||||
|
|
||||||
|
# 设置输出目录
|
||||||
|
set_target_properties(${target}
|
||||||
|
PROPERTIES
|
||||||
|
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
|
||||||
|
)
|
|
@ -0,0 +1,115 @@
|
||||||
|
#include "agent/chatbot.h"
|
||||||
|
#include "logger.h"
|
||||||
|
#include "prompt.h"
|
||||||
|
#include "flow/flow_factory.h"
|
||||||
|
#include "memory/mem0/base.h"
|
||||||
|
|
||||||
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||||
|
#include <signal.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#elif defined (_WIN32)
|
||||||
|
#define WIN32_LEAN_AND_MEAN
|
||||||
|
#ifndef NOMINMAX
|
||||||
|
#define NOMINMAX
|
||||||
|
#endif
|
||||||
|
#include <windows.h>
|
||||||
|
#include <signal.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
using namespace humanus;
|
||||||
|
|
||||||
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||||
|
static void sigint_handler(int signo) {
|
||||||
|
if (signo == SIGINT) {
|
||||||
|
logger->info("Interrupted by user\n");
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
static bool readline_utf8(std::string & line, bool multiline_input) {
|
||||||
|
#if defined(_WIN32)
|
||||||
|
std::wstring wline;
|
||||||
|
if (!std::getline(std::wcin, wline)) {
|
||||||
|
// Input stream is bad or EOF received
|
||||||
|
line.clear();
|
||||||
|
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
||||||
|
line.resize(size_needed);
|
||||||
|
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
||||||
|
#else
|
||||||
|
if (!std::getline(std::cin, line)) {
|
||||||
|
// Input stream is bad or EOF received
|
||||||
|
line.clear();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
if (!line.empty()) {
|
||||||
|
char last = line.back();
|
||||||
|
if (last == '/') { // Always return control on '/' symbol
|
||||||
|
line.pop_back();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (last == '\\') { // '\\' changes the default action
|
||||||
|
line.pop_back();
|
||||||
|
multiline_input = !multiline_input;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
line += '\n';
|
||||||
|
|
||||||
|
// By default, continue input if multiline_input is set
|
||||||
|
return multiline_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
|
||||||
|
// ctrl+C handling
|
||||||
|
{
|
||||||
|
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||||
|
struct sigaction sigint_action;
|
||||||
|
sigint_action.sa_handler = sigint_handler;
|
||||||
|
sigemptyset (&sigint_action.sa_mask);
|
||||||
|
sigint_action.sa_flags = 0;
|
||||||
|
sigaction(SIGINT, &sigint_action, NULL);
|
||||||
|
#elif defined (_WIN32)
|
||||||
|
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||||
|
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||||
|
};
|
||||||
|
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||||
|
SetConsoleCP(CP_UTF8);
|
||||||
|
SetConsoleOutputCP(CP_UTF8);
|
||||||
|
_setmode(_fileno(stdin), _O_WTEXT); // wide character input mode
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
auto memory_config = mem0::MemoryConfig();
|
||||||
|
|
||||||
|
memory_config.max_messages = 1;
|
||||||
|
memory_config.retrieval_limit = 10;
|
||||||
|
|
||||||
|
auto memory = std::make_shared<mem0::Memory>(memory_config);
|
||||||
|
memory->current_request = "Chat with the user";
|
||||||
|
|
||||||
|
Chatbot chatbot{
|
||||||
|
"chat_mem0", // name
|
||||||
|
"A chatbot agent that uses memory to remember conversation history", // description
|
||||||
|
"You are a helpful assistant.", // system_prompt
|
||||||
|
nullptr, // llm
|
||||||
|
memory // memory
|
||||||
|
};
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
std::cout << "> ";
|
||||||
|
std::string prompt;
|
||||||
|
readline_utf8(prompt, false);
|
||||||
|
if (prompt == "exit" || prompt == "exit\n") {
|
||||||
|
logger->info("Goodbye!");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
auto response = chatbot.run(prompt);
|
||||||
|
std::cout << response << std::endl;
|
||||||
|
}
|
||||||
|
}
|
|
@ -85,6 +85,8 @@ int main() {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto memory = std::make_shared<mem0::Memory>(mem0::MemoryConfig());
|
||||||
|
|
||||||
std::shared_ptr<BaseAgent> agent_ptr = std::make_shared<Humanus>(
|
std::shared_ptr<BaseAgent> agent_ptr = std::make_shared<Humanus>(
|
||||||
ToolCollection( // Add general-purpose tools to the tool collection
|
ToolCollection( // Add general-purpose tools to the tool collection
|
||||||
{
|
{
|
||||||
|
@ -101,7 +103,7 @@ int main() {
|
||||||
prompt::humanus::SYSTEM_PROMPT,
|
prompt::humanus::SYSTEM_PROMPT,
|
||||||
prompt::humanus::NEXT_STEP_PROMPT,
|
prompt::humanus::NEXT_STEP_PROMPT,
|
||||||
nullptr,
|
nullptr,
|
||||||
std::make_shared<mem0::Memory>(mem0::MemoryConfig())
|
memory
|
||||||
);
|
);
|
||||||
|
|
||||||
std::map<std::string, std::shared_ptr<BaseAgent>> agents;
|
std::map<std::string, std::shared_ptr<BaseAgent>> agents;
|
||||||
|
@ -139,6 +141,7 @@ int main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cout << "Processing your request..." << std::endl;
|
std::cout << "Processing your request..." << std::endl;
|
||||||
|
memory->current_request = prompt;
|
||||||
auto result = flow->execute(prompt);
|
auto result = flow->execute(prompt);
|
||||||
std::cout << result << std::endl;
|
std::cout << result << std::endl;
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ std::string PlanningFlow::execute(const std::string& input) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refactor memory
|
// Refactor memory
|
||||||
std::string prefix_sum = _summarize_plan(executor->memory->get_messages());
|
std::string prefix_sum = _summarize_plan(executor->memory->get_messages(step_result));
|
||||||
executor->reset(true); // TODO: More fine-grained memory reset?
|
executor->reset(true); // TODO: More fine-grained memory reset?
|
||||||
executor->update_memory("assistant", prefix_sum);
|
executor->update_memory("assistant", prefix_sum);
|
||||||
if (!input.empty()) {
|
if (!input.empty()) {
|
||||||
|
|
313
llm.cpp
313
llm.cpp
|
@ -1,6 +1,313 @@
|
||||||
#include "llm.h"
|
#include "llm.h"
|
||||||
|
|
||||||
namespace humanus {
|
namespace humanus {
|
||||||
// 定义静态成员变量
|
|
||||||
std::unordered_map<std::string, std::shared_ptr<LLM>> LLM::instances_;
|
std::unordered_map<std::string, std::shared_ptr<LLM>> LLM::instances_;
|
||||||
}
|
|
||||||
|
/**
|
||||||
|
* @brief Format the message list to the format that LLM can accept
|
||||||
|
* @param messages Message object message list
|
||||||
|
* @return The formatted message list
|
||||||
|
* @throws std::invalid_argument If the message format is invalid or missing necessary fields
|
||||||
|
* @throws std::runtime_error If the message type is not supported
|
||||||
|
*/
|
||||||
|
json LLM::format_messages(const std::vector<Message>& messages) {
|
||||||
|
json formatted_messages = json::array();
|
||||||
|
|
||||||
|
auto concat_content = [](const json& lhs, const json& rhs) -> json {
|
||||||
|
if (lhs.is_string() && rhs.is_string()) {
|
||||||
|
return lhs.get<std::string>() + "\n" + rhs.get<std::string>(); // Maybe other delimiter?
|
||||||
|
}
|
||||||
|
json res = json::array();
|
||||||
|
if (lhs.is_string()) {
|
||||||
|
res.push_back({
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", lhs.get<std::string>()}
|
||||||
|
});
|
||||||
|
} else if (lhs.is_array()) {
|
||||||
|
res.insert(res.end(), lhs.begin(), lhs.end());
|
||||||
|
}
|
||||||
|
if (rhs.is_string()) {
|
||||||
|
res.push_back({
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", rhs.get<std::string>()}
|
||||||
|
});
|
||||||
|
} else if (rhs.is_array()) {
|
||||||
|
res.insert(res.end(), rhs.begin(), rhs.end());
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const auto& message : messages) {
|
||||||
|
if (message.content.empty() && message.tool_calls.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
formatted_messages.push_back(message.to_json());
|
||||||
|
if (!llm_config_->oai_tool_support) {
|
||||||
|
if (formatted_messages.back()["role"] == "tool") {
|
||||||
|
std::string tool_results_str = formatted_messages.back().dump(2);
|
||||||
|
formatted_messages.back() = {
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", tool_results_str}
|
||||||
|
};
|
||||||
|
} else if (!formatted_messages.back()["tool_calls"].empty()) {
|
||||||
|
if (formatted_messages.back()["content"].is_null()) {
|
||||||
|
formatted_messages.back()["content"] = "";
|
||||||
|
}
|
||||||
|
std::string tool_calls_str = tool_parser_->dump(formatted_messages.back()["tool_calls"]);
|
||||||
|
formatted_messages.back().erase("tool_calls");
|
||||||
|
formatted_messages.back()["content"] = concat_content(formatted_messages.back()["content"], tool_calls_str);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& message : formatted_messages) {
|
||||||
|
if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
|
||||||
|
throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t i = 0, j = -1;
|
||||||
|
for (; i < formatted_messages.size(); i++) {
|
||||||
|
if (i == 0 || formatted_messages[i]["role"] != formatted_messages[j]["role"]) {
|
||||||
|
formatted_messages[++j] = formatted_messages[i];
|
||||||
|
} else {
|
||||||
|
formatted_messages[j]["content"] = concat_content(formatted_messages[j]["content"], formatted_messages[i]["content"]);
|
||||||
|
if (!formatted_messages[i]["tool_calls"].empty()) {
|
||||||
|
formatted_messages[j]["tool_calls"] = concat_content(formatted_messages[j]["tool_calls"], formatted_messages[i]["tool_calls"]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
formatted_messages.erase(formatted_messages.begin() + j + 1, formatted_messages.end());
|
||||||
|
|
||||||
|
return formatted_messages;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string LLM::ask(
|
||||||
|
const std::vector<Message>& messages,
|
||||||
|
const std::string& system_prompt,
|
||||||
|
const std::string& next_step_prompt,
|
||||||
|
int max_retries
|
||||||
|
) {
|
||||||
|
json formatted_messages = json::array();
|
||||||
|
|
||||||
|
if (!system_prompt.empty()) {
|
||||||
|
formatted_messages.push_back({
|
||||||
|
{"role", "system"},
|
||||||
|
{"content", system_prompt}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
json _formatted_messages = format_messages(messages);
|
||||||
|
formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());
|
||||||
|
|
||||||
|
if (!next_step_prompt.empty()) {
|
||||||
|
if (formatted_messages.empty() || formatted_messages.back()["role"] != "user") {
|
||||||
|
formatted_messages.push_back({
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", next_step_prompt}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
if (formatted_messages.back()["content"].is_string()) {
|
||||||
|
formatted_messages.back()["content"] = formatted_messages.back()["content"].get<std::string>() + "\n\n" + next_step_prompt;
|
||||||
|
} else if (formatted_messages.back()["content"].is_array()) {
|
||||||
|
formatted_messages.back()["content"].push_back({
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", next_step_prompt}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
json body = {
|
||||||
|
{"model", llm_config_->model},
|
||||||
|
{"messages", formatted_messages}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (llm_config_->temperature > 0) {
|
||||||
|
body["temperature"] = llm_config_->temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llm_config_->max_tokens > 0) {
|
||||||
|
body["max_tokens"] = llm_config_->max_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string body_str = body.dump();
|
||||||
|
|
||||||
|
int retry = 0;
|
||||||
|
|
||||||
|
while (retry <= max_retries) {
|
||||||
|
// send request
|
||||||
|
auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
|
||||||
|
|
||||||
|
if (!res) {
|
||||||
|
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
||||||
|
} else if (res->status == 200) {
|
||||||
|
try {
|
||||||
|
json json_data = json::parse(res->body);
|
||||||
|
return json_data["choices"][0]["message"]["content"].get<std::string>();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
logger->error("Failed to parse response: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||||
|
}
|
||||||
|
|
||||||
|
retry++;
|
||||||
|
|
||||||
|
if (retry > max_retries) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for a while before retrying
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
|
||||||
|
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::runtime_error("Failed to get response from LLM");
|
||||||
|
}
|
||||||
|
|
||||||
|
json LLM::ask_tool(
|
||||||
|
const std::vector<Message>& messages,
|
||||||
|
const std::string& system_prompt,
|
||||||
|
const std::string& next_step_prompt,
|
||||||
|
const json& tools,
|
||||||
|
const std::string& tool_choice,
|
||||||
|
int max_retries
|
||||||
|
) {
|
||||||
|
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
|
||||||
|
throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
|
||||||
|
}
|
||||||
|
|
||||||
|
json formatted_messages = json::array();
|
||||||
|
|
||||||
|
if (!system_prompt.empty()) {
|
||||||
|
formatted_messages.push_back({
|
||||||
|
{"role", "system"},
|
||||||
|
{"content", system_prompt}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
json _formatted_messages = format_messages(messages);
|
||||||
|
formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());
|
||||||
|
|
||||||
|
if (!next_step_prompt.empty()) {
|
||||||
|
if (formatted_messages.empty() || formatted_messages.back()["role"] != "user") {
|
||||||
|
formatted_messages.push_back({
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", next_step_prompt}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
if (formatted_messages.back()["content"].is_string()) {
|
||||||
|
formatted_messages.back()["content"] = formatted_messages.back()["content"].get<std::string>() + "\n\n" + next_step_prompt;
|
||||||
|
} else if (formatted_messages.back()["content"].is_array()) {
|
||||||
|
formatted_messages.back()["content"].push_back({
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", next_step_prompt}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tools.empty()) {
|
||||||
|
for (const json& tool : tools) {
|
||||||
|
if (!tool.contains("type")) {
|
||||||
|
throw std::invalid_argument("Tool must contain 'type' field but got: " + tool.dump(2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tool_choice == "required" && tools.empty()) {
|
||||||
|
throw std::invalid_argument("No tool available for required tool choice");
|
||||||
|
}
|
||||||
|
if (!tools.is_array()) {
|
||||||
|
throw std::invalid_argument("Tools must be an array");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
json body = {
|
||||||
|
{"model", llm_config_->model},
|
||||||
|
{"messages", formatted_messages}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (llm_config_->temperature > 0) {
|
||||||
|
body["temperature"] = llm_config_->temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llm_config_->max_tokens > 0) {
|
||||||
|
body["max_tokens"] = llm_config_->max_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llm_config_->oai_tool_support) {
|
||||||
|
body["tools"] = tools;
|
||||||
|
body["tool_choice"] = tool_choice;
|
||||||
|
} else {
|
||||||
|
if (body["messages"].empty() || body["messages"].back()["role"] != "user") {
|
||||||
|
body["messages"].push_back({
|
||||||
|
{"role", "user"},
|
||||||
|
{"content", tool_parser_->hint(tools.dump(2))}
|
||||||
|
});
|
||||||
|
} else if (body["messages"].back()["content"].is_string()) {
|
||||||
|
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + tool_parser_->hint(tools.dump(2));
|
||||||
|
} else if (body["messages"].back()["content"].is_array()) {
|
||||||
|
body["messages"].back()["content"].push_back({
|
||||||
|
{"type", "text"},
|
||||||
|
{"text", tool_parser_->hint(tools.dump(2))}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string body_str = body.dump();
|
||||||
|
|
||||||
|
int retry = 0;
|
||||||
|
|
||||||
|
while (retry <= max_retries) {
|
||||||
|
// send request
|
||||||
|
auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
|
||||||
|
|
||||||
|
if (!res) {
|
||||||
|
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
||||||
|
} else if (res->status == 200) {
|
||||||
|
try {
|
||||||
|
json json_data = json::parse(res->body);
|
||||||
|
json message = json_data["choices"][0]["message"];
|
||||||
|
if (!llm_config_->oai_tool_support && message["content"].is_string()) {
|
||||||
|
message = tool_parser_->parse(message["content"].get<std::string>());
|
||||||
|
}
|
||||||
|
return message;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
logger->error("Failed to parse response: " + std::string(e.what()));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
||||||
|
}
|
||||||
|
|
||||||
|
retry++;
|
||||||
|
|
||||||
|
if (retry > max_retries) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for a while before retrying
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||||
|
|
||||||
|
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the logger has a file sink, log the request body
|
||||||
|
if (logger->sinks().size() > 1) {
|
||||||
|
auto file_sink = std::dynamic_pointer_cast<spdlog::sinks::basic_file_sink_mt>(logger->sinks()[1]);
|
||||||
|
if (file_sink) {
|
||||||
|
file_sink->log(spdlog::details::log_msg(
|
||||||
|
spdlog::source_loc{},
|
||||||
|
logger->name(),
|
||||||
|
spdlog::level::debug,
|
||||||
|
"Failed to get response from LLM. Full request body: " + body_str
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw std::runtime_error("Failed to get response from LLM");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace humanus
|
284
llm.h
284
llm.h
|
@ -55,6 +55,14 @@ public:
|
||||||
}
|
}
|
||||||
return instances_[config_name];
|
return instances_[config_name];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool enable_vision() const {
|
||||||
|
return llm_config_->enable_vision;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string vision_details() const {
|
||||||
|
return llm_config_->vision_details;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Format the message list to the format that LLM can accept
|
* @brief Format the message list to the format that LLM can accept
|
||||||
|
@ -63,78 +71,7 @@ public:
|
||||||
* @throws std::invalid_argument If the message format is invalid or missing necessary fields
|
* @throws std::invalid_argument If the message format is invalid or missing necessary fields
|
||||||
* @throws std::runtime_error If the message type is not supported
|
* @throws std::runtime_error If the message type is not supported
|
||||||
*/
|
*/
|
||||||
json format_messages(const std::vector<Message>& messages) {
|
json format_messages(const std::vector<Message>& messages);
|
||||||
json formatted_messages = json::array();
|
|
||||||
|
|
||||||
auto concat_content = [](const json& lhs, const json& rhs) -> json {
|
|
||||||
if (lhs.is_string() && rhs.is_string()) {
|
|
||||||
return lhs.get<std::string>() + "\n" + rhs.get<std::string>(); // Maybe other delimiter?
|
|
||||||
}
|
|
||||||
json res = json::array();
|
|
||||||
if (lhs.is_string()) {
|
|
||||||
res.push_back({
|
|
||||||
{"type", "text"},
|
|
||||||
{"text", lhs.get<std::string>()}
|
|
||||||
});
|
|
||||||
} else if (lhs.is_array()) {
|
|
||||||
res.insert(res.end(), lhs.begin(), lhs.end());
|
|
||||||
}
|
|
||||||
if (rhs.is_string()) {
|
|
||||||
res.push_back({
|
|
||||||
{"type", "text"},
|
|
||||||
{"text", rhs.get<std::string>()}
|
|
||||||
});
|
|
||||||
} else if (rhs.is_array()) {
|
|
||||||
res.insert(res.end(), rhs.begin(), rhs.end());
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const auto& message : messages) {
|
|
||||||
if (message.content.empty() && message.tool_calls.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
formatted_messages.push_back(message.to_json());
|
|
||||||
if (!llm_config_->oai_tool_support) {
|
|
||||||
if (formatted_messages.back()["role"] == "tool") {
|
|
||||||
std::string tool_results_str = formatted_messages.back().dump(2);
|
|
||||||
formatted_messages.back() = {
|
|
||||||
{"role", "user"},
|
|
||||||
{"content", tool_results_str}
|
|
||||||
};
|
|
||||||
} else if (!formatted_messages.back()["tool_calls"].empty()) {
|
|
||||||
if (formatted_messages.back()["content"].is_null()) {
|
|
||||||
formatted_messages.back()["content"] = "";
|
|
||||||
}
|
|
||||||
std::string tool_calls_str = tool_parser_->dump(formatted_messages.back()["tool_calls"]);
|
|
||||||
formatted_messages.back().erase("tool_calls");
|
|
||||||
formatted_messages.back()["content"] = concat_content(formatted_messages.back()["content"], tool_calls_str);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto& message : formatted_messages) {
|
|
||||||
if (message["role"] != "user" && message["role"] != "assistant" && message["role"] != "system" && message["role"] != "tool") {
|
|
||||||
throw std::invalid_argument("Invalid role: " + message["role"].get<std::string>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t i = 0, j = -1;
|
|
||||||
for (; i < formatted_messages.size(); i++) {
|
|
||||||
if (i == 0 || formatted_messages[i]["role"] != formatted_messages[j]["role"]) {
|
|
||||||
formatted_messages[++j] = formatted_messages[i];
|
|
||||||
} else {
|
|
||||||
formatted_messages[j]["content"] = concat_content(formatted_messages[j]["content"], formatted_messages[i]["content"]);
|
|
||||||
if (!formatted_messages[i]["tool_calls"].empty()) {
|
|
||||||
formatted_messages[j]["tool_calls"] = concat_content(formatted_messages[j]["tool_calls"], formatted_messages[i]["tool_calls"]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted_messages.erase(formatted_messages.begin() + j + 1, formatted_messages.end());
|
|
||||||
|
|
||||||
return formatted_messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Send a request to the LLM and get the reply
|
* @brief Send a request to the LLM and get the reply
|
||||||
|
@ -151,86 +88,13 @@ public:
|
||||||
const std::string& system_prompt = "",
|
const std::string& system_prompt = "",
|
||||||
const std::string& next_step_prompt = "",
|
const std::string& next_step_prompt = "",
|
||||||
int max_retries = 3
|
int max_retries = 3
|
||||||
) {
|
);
|
||||||
json formatted_messages = json::array();
|
|
||||||
|
|
||||||
if (!system_prompt.empty()) {
|
|
||||||
formatted_messages.push_back({
|
|
||||||
{"role", "system"},
|
|
||||||
{"content", system_prompt}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
json _formatted_messages = format_messages(messages);
|
|
||||||
formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());
|
|
||||||
|
|
||||||
if (!next_step_prompt.empty()) {
|
|
||||||
if (formatted_messages.empty() || formatted_messages.back()["role"] != "user") {
|
|
||||||
formatted_messages.push_back({
|
|
||||||
{"role", "user"},
|
|
||||||
{"content", next_step_prompt}
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
if (formatted_messages.back()["content"].is_string()) {
|
|
||||||
formatted_messages.back()["content"] = formatted_messages.back()["content"].get<std::string>() + "\n\n" + next_step_prompt;
|
|
||||||
} else if (formatted_messages.back()["content"].is_array()) {
|
|
||||||
formatted_messages.back()["content"].push_back({
|
|
||||||
{"type", "text"},
|
|
||||||
{"text", next_step_prompt}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
json body = {
|
|
||||||
{"model", llm_config_->model},
|
|
||||||
{"messages", formatted_messages},
|
|
||||||
{"temperature", llm_config_->temperature},
|
|
||||||
{"max_tokens", llm_config_->max_tokens}
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string body_str = body.dump();
|
|
||||||
|
|
||||||
int retry = 0;
|
|
||||||
|
|
||||||
while (retry <= max_retries) {
|
|
||||||
// send request
|
|
||||||
auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
|
|
||||||
|
|
||||||
if (!res) {
|
|
||||||
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
|
||||||
} else if (res->status == 200) {
|
|
||||||
try {
|
|
||||||
json json_data = json::parse(res->body);
|
|
||||||
return json_data["choices"][0]["message"]["content"].get<std::string>();
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
logger->error("Failed to parse response: " + std::string(e.what()));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
|
||||||
}
|
|
||||||
|
|
||||||
retry++;
|
|
||||||
|
|
||||||
if (retry > max_retries) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for a while before retrying
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
|
||||||
|
|
||||||
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
|
|
||||||
}
|
|
||||||
|
|
||||||
throw std::runtime_error("Failed to get response from LLM");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Send a request to the LLM with tool functions
|
* @brief Send a request to the LLM with tool functions
|
||||||
* @param messages The conversation message list
|
* @param messages The conversation message list
|
||||||
* @param system_prompt Optional system message
|
* @param system_prompt Optional system message
|
||||||
* @param next_step_prompt Optinonal prompt message for the next step
|
* @param next_step_prompt Optinonal prompt message for the next step
|
||||||
* @param timeout The request timeout (seconds)
|
|
||||||
* @param tools The tool list
|
* @param tools The tool list
|
||||||
* @param tool_choice The tool choice strategy
|
* @param tool_choice The tool choice strategy
|
||||||
* @param max_retries The maximum number of retries
|
* @param max_retries The maximum number of retries
|
||||||
|
@ -245,133 +109,7 @@ public:
|
||||||
const json& tools = {},
|
const json& tools = {},
|
||||||
const std::string& tool_choice = "auto",
|
const std::string& tool_choice = "auto",
|
||||||
int max_retries = 3
|
int max_retries = 3
|
||||||
) {
|
);
|
||||||
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
|
|
||||||
throw std::invalid_argument("Invalid tool_choice: " + tool_choice);
|
|
||||||
}
|
|
||||||
|
|
||||||
json formatted_messages = json::array();
|
|
||||||
|
|
||||||
if (!system_prompt.empty()) {
|
|
||||||
formatted_messages.push_back({
|
|
||||||
{"role", "system"},
|
|
||||||
{"content", system_prompt}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
json _formatted_messages = format_messages(messages);
|
|
||||||
formatted_messages.insert(formatted_messages.end(), _formatted_messages.begin(), _formatted_messages.end());
|
|
||||||
|
|
||||||
if (!next_step_prompt.empty()) {
|
|
||||||
if (formatted_messages.empty() || formatted_messages.back()["role"] != "user") {
|
|
||||||
formatted_messages.push_back({
|
|
||||||
{"role", "user"},
|
|
||||||
{"content", next_step_prompt}
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
if (formatted_messages.back()["content"].is_string()) {
|
|
||||||
formatted_messages.back()["content"] = formatted_messages.back()["content"].get<std::string>() + "\n\n" + next_step_prompt;
|
|
||||||
} else if (formatted_messages.back()["content"].is_array()) {
|
|
||||||
formatted_messages.back()["content"].push_back({
|
|
||||||
{"type", "text"},
|
|
||||||
{"text", next_step_prompt}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!tools.empty()) {
|
|
||||||
for (const json& tool : tools) {
|
|
||||||
if (!tool.contains("type")) {
|
|
||||||
throw std::invalid_argument("Tool must contain 'type' field but got: " + tool.dump(2));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (tool_choice == "required" && tools.empty()) {
|
|
||||||
throw std::invalid_argument("No tool available for required tool choice");
|
|
||||||
}
|
|
||||||
if (!tools.is_array()) {
|
|
||||||
throw std::invalid_argument("Tools must be an array");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
json body = {
|
|
||||||
{"model", llm_config_->model},
|
|
||||||
{"messages", formatted_messages},
|
|
||||||
{"temperature", llm_config_->temperature},
|
|
||||||
{"max_tokens", llm_config_->max_tokens},
|
|
||||||
{"tool_choice", tool_choice}
|
|
||||||
};
|
|
||||||
|
|
||||||
if (llm_config_->oai_tool_support) {
|
|
||||||
body["tools"] = tools;
|
|
||||||
} else {
|
|
||||||
if (body["messages"].empty() || body["messages"].back()["role"] != "user") {
|
|
||||||
body["messages"].push_back({
|
|
||||||
{"role", "user"},
|
|
||||||
{"content", tool_parser_->hint(tools.dump(2))}
|
|
||||||
});
|
|
||||||
} else if (body["messages"].back()["content"].is_string()) {
|
|
||||||
body["messages"].back()["content"] = body["messages"].back()["content"].get<std::string>() + "\n\n" + tool_parser_->hint(tools.dump(2));
|
|
||||||
} else if (body["messages"].back()["content"].is_array()) {
|
|
||||||
body["messages"].back()["content"].push_back({
|
|
||||||
{"type", "text"},
|
|
||||||
{"text", tool_parser_->hint(tools.dump(2))}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string body_str = body.dump();
|
|
||||||
|
|
||||||
int retry = 0;
|
|
||||||
|
|
||||||
while (retry <= max_retries) {
|
|
||||||
// send request
|
|
||||||
auto res = client_->Post(llm_config_->endpoint, body_str, "application/json");
|
|
||||||
|
|
||||||
if (!res) {
|
|
||||||
logger->error("Failed to send request: " + httplib::to_string(res.error()));
|
|
||||||
} else if (res->status == 200) {
|
|
||||||
try {
|
|
||||||
json json_data = json::parse(res->body);
|
|
||||||
json message = json_data["choices"][0]["message"];
|
|
||||||
if (!llm_config_->oai_tool_support && message["content"].is_string()) {
|
|
||||||
message = tool_parser_->parse(message["content"].get<std::string>());
|
|
||||||
}
|
|
||||||
return message;
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
logger->error("Failed to parse response: " + std::string(e.what()));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger->error("Failed to send request: status=" + std::to_string(res->status) + ", body=" + res->body);
|
|
||||||
}
|
|
||||||
|
|
||||||
retry++;
|
|
||||||
|
|
||||||
if (retry > max_retries) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// wait for a while before retrying
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
|
||||||
|
|
||||||
logger->info("Retrying " + std::to_string(retry) + "/" + std::to_string(max_retries));
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the logger has a file sink, log the request body
|
|
||||||
if (logger->sinks().size() > 1) {
|
|
||||||
auto file_sink = std::dynamic_pointer_cast<spdlog::sinks::basic_file_sink_mt>(logger->sinks()[1]);
|
|
||||||
if (file_sink) {
|
|
||||||
file_sink->log(spdlog::details::log_msg(
|
|
||||||
spdlog::source_loc{},
|
|
||||||
logger->name(),
|
|
||||||
spdlog::level::debug,
|
|
||||||
"Failed to get response from LLM. Full request body: " + body_str
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
throw std::runtime_error("Failed to get response from LLM");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace humanus
|
} // namespace humanus
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
namespace humanus {
|
namespace humanus {
|
||||||
|
|
||||||
struct BaseMemory {
|
struct BaseMemory {
|
||||||
std::vector<Message> messages;
|
std::deque<Message> messages;
|
||||||
|
|
||||||
// Add a message to the memory
|
// Add a message to the memory
|
||||||
virtual void add_message(const Message& message) {
|
virtual void add_message(const Message& message) {
|
||||||
|
@ -21,18 +21,16 @@ struct BaseMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear all messages
|
// Clear all messages
|
||||||
void clear() {
|
virtual void clear() {
|
||||||
messages.clear();
|
messages.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::vector<Message> get_messages(const std::string& query = "") const {
|
virtual std::vector<Message> get_messages(const std::string& query = "") const {
|
||||||
return messages;
|
std::vector<Message> result;
|
||||||
}
|
for (const auto& message : messages) {
|
||||||
|
result.push_back(message);
|
||||||
// Get the last n messages
|
}
|
||||||
virtual std::vector<Message> get_recent_messages(int n) const {
|
return result;
|
||||||
n = std::min(n, static_cast<int>(messages.size()));
|
|
||||||
return std::vector<Message>(messages.end() - n, messages.end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert messages to list of dicts
|
// Convert messages to list of dicts
|
||||||
|
@ -48,13 +46,13 @@ struct BaseMemory {
|
||||||
struct Memory : BaseMemory {
|
struct Memory : BaseMemory {
|
||||||
int max_messages;
|
int max_messages;
|
||||||
|
|
||||||
Memory(int max_messages = 100) : max_messages(max_messages) {}
|
Memory(int max_messages = 30) : max_messages(max_messages) {}
|
||||||
|
|
||||||
void add_message(const Message& message) override {
|
void add_message(const Message& message) override {
|
||||||
BaseMemory::add_message(message);
|
messages.push_back(message);
|
||||||
while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
|
while (!messages.empty() && (messages.size() > max_messages || messages.front().role == "assistant" || messages.front().role == "tool")) {
|
||||||
// Ensure the first message is always a user or system message
|
// Ensure the first message is always a user or system message
|
||||||
messages.erase(messages.begin());
|
messages.pop_front();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -9,112 +9,196 @@
|
||||||
#include "httplib.h"
|
#include "httplib.h"
|
||||||
#include "llm.h"
|
#include "llm.h"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
|
#include "tool/fact_extract.h"
|
||||||
|
|
||||||
namespace humanus::mem0 {
|
namespace humanus::mem0 {
|
||||||
|
|
||||||
struct Memory : BaseMemory {
|
struct Memory : BaseMemory {
|
||||||
MemoryConfig config;
|
MemoryConfig config;
|
||||||
|
|
||||||
|
std::string current_request;
|
||||||
|
|
||||||
std::string fact_extraction_prompt;
|
std::string fact_extraction_prompt;
|
||||||
std::string update_memory_prompt;
|
std::string update_memory_prompt;
|
||||||
int max_messages;
|
int max_messages;
|
||||||
int limit;
|
int retrieval_limit;
|
||||||
std::string filters;
|
FilterFunc filter;
|
||||||
|
|
||||||
std::shared_ptr<EmbeddingModel> embedding_model;
|
std::shared_ptr<EmbeddingModel> embedding_model;
|
||||||
std::shared_ptr<VectorStore> vector_store;
|
std::shared_ptr<VectorStore> vector_store;
|
||||||
std::shared_ptr<LLM> llm;
|
std::shared_ptr<LLM> llm;
|
||||||
// std::shared_ptr<SQLiteManager> db;
|
// std::shared_ptr<SQLiteManager> db;
|
||||||
|
|
||||||
|
std::shared_ptr<FactExtract> fact_extract_tool;
|
||||||
|
|
||||||
Memory(const MemoryConfig& config) : config(config) {
|
Memory(const MemoryConfig& config) : config(config) {
|
||||||
fact_extraction_prompt = config.fact_extraction_prompt;
|
fact_extraction_prompt = config.fact_extraction_prompt;
|
||||||
update_memory_prompt = config.update_memory_prompt;
|
update_memory_prompt = config.update_memory_prompt;
|
||||||
max_messages = config.max_messages;
|
max_messages = config.max_messages;
|
||||||
limit = config.limit;
|
retrieval_limit = config.retrieval_limit;
|
||||||
filters = config.filters;
|
filter = config.filter;
|
||||||
|
|
||||||
embedding_model = EmbeddingModel::get_instance("mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this)), config.embedding_model_config);
|
size_t pos = fact_extraction_prompt.find("{current_date}");
|
||||||
vector_store = VectorStore::get_instance("mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this)), config.vector_store_config);
|
if (pos != std::string::npos) {
|
||||||
|
// %Y-%d-%m
|
||||||
|
auto current_date = std::chrono::system_clock::now();
|
||||||
|
auto in_time_t = std::chrono::system_clock::to_time_t(current_date);
|
||||||
|
std::stringstream ss;
|
||||||
|
std::tm tm_info = *std::localtime(&in_time_t);
|
||||||
|
ss << std::put_time(&tm_info, "%Y-%m-%d");
|
||||||
|
std::string formatted_date = ss.str(); // YYYY-MM-DD
|
||||||
|
fact_extraction_prompt.replace(pos, 14, formatted_date);
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding_model = EmbeddingModel::get_instance("default", config.embedding_model_config);
|
||||||
|
vector_store = VectorStore::get_instance("default", config.vector_store_config);
|
||||||
|
|
||||||
llm = LLM::get_instance("mem0_" + std::to_string(reinterpret_cast<uintptr_t>(this)), config.llm_config);
|
llm = LLM::get_instance("default", config.llm_config);
|
||||||
// db = std::make_shared<SQLiteManager>(config.history_db_path);
|
// db = std::make_shared<SQLiteManager>(config.history_db_path);
|
||||||
|
|
||||||
|
fact_extract_tool = std::make_shared<FactExtract>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_message(const Message& message) override {
|
void add_message(const Message& message) override {
|
||||||
while (!messages.empty() && (messages.size() > max_messages || messages.begin()->role == "assistant" || messages.begin()->role == "tool")) {
|
|
||||||
// Ensure the first message is always a user or system message
|
|
||||||
Message front_message = *messages.begin();
|
|
||||||
messages.erase(messages.begin());
|
|
||||||
|
|
||||||
if (config.llm_config->enable_vision) {
|
|
||||||
front_message = parse_vision_message(front_message, llm, config.llm_config->vision_details);
|
|
||||||
} else {
|
|
||||||
front_message = parse_vision_message(front_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
_add_to_vector_store(front_message);
|
|
||||||
}
|
|
||||||
|
|
||||||
messages.push_back(message);
|
messages.push_back(message);
|
||||||
|
// Message message_to_memory = message;
|
||||||
|
// if (llm->enable_vision()) {
|
||||||
|
// message_to_memory = parse_vision_message(message_to_memory, llm, llm->vision_details());
|
||||||
|
// } else {
|
||||||
|
// message_to_memory = parse_vision_message(message_to_memory);
|
||||||
|
// }
|
||||||
|
// _add_to_vector_store({message_to_memory});
|
||||||
|
// while (!messages.empty() && (messages.size() > max_messages || messages.front().role == "assistant" || messages.front().role == "tool")) {
|
||||||
|
// // Ensure the first message is always a user or system message
|
||||||
|
// messages.pop_front();
|
||||||
|
// }
|
||||||
|
std::vector<Message> messages_to_memory;
|
||||||
|
while (!messages.empty() && (messages.size() > max_messages || messages.front().role == "assistant" || messages.front().role == "tool")) {
|
||||||
|
// Ensure the first message is always a user or system message
|
||||||
|
messages_to_memory.push_back(messages.front());
|
||||||
|
messages.pop_front();
|
||||||
|
}
|
||||||
|
if (!messages_to_memory.empty()) {
|
||||||
|
if (llm->enable_vision()) {
|
||||||
|
for (auto& m : messages_to_memory) {
|
||||||
|
m = parse_vision_message(m, llm, llm->vision_details());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (auto& m : messages_to_memory) {
|
||||||
|
m = parse_vision_message(m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_add_to_vector_store(messages_to_memory);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Message> get_messages(const std::string& query = "") const override {
|
std::vector<Message> get_messages(const std::string& query = "") const override {
|
||||||
auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH);
|
std::vector<Message> messages_with_memory;
|
||||||
std::vector<MemoryItem> memories;
|
|
||||||
|
|
||||||
// 检查vector_store是否已初始化
|
|
||||||
if (vector_store) {
|
|
||||||
memories = vector_store->search(embeddings, limit, filters);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string memory_prompt;
|
if (!query.empty()) {
|
||||||
for (const auto& memory_item : memories) {
|
auto embeddings = embedding_model->embed(query, EmbeddingType::SEARCH);
|
||||||
memory_prompt += "<memory>" + memory_item.memory + "</memory>";
|
std::vector<MemoryItem> memories;
|
||||||
}
|
|
||||||
|
// 检查vector_store是否已初始化
|
||||||
|
if (vector_store) {
|
||||||
|
memories = vector_store->search(embeddings, retrieval_limit, filter);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<Message> messages_with_memory{Message::user_message(memory_prompt)};
|
if (!memories.empty()) {
|
||||||
|
sort(memories.begin(), memories.end(), [](const MemoryItem& a, const MemoryItem& b) {
|
||||||
|
return a.updated_at < b.updated_at;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::string memory_prompt;
|
||||||
|
for (const auto& memory_item : memories) {
|
||||||
|
memory_prompt += "<memory>" + memory_item.memory + "</memory>";
|
||||||
|
}
|
||||||
|
|
||||||
|
messages_with_memory.push_back(Message::user_message(memory_prompt));
|
||||||
|
|
||||||
|
logger->info("📤 Total retreived memories: " + std::to_string(memories.size()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end());
|
messages_with_memory.insert(messages_with_memory.end(), messages.begin(), messages.end());
|
||||||
|
|
||||||
return messages_with_memory;
|
return messages_with_memory;
|
||||||
}
|
}
|
||||||
|
|
||||||
void _add_to_vector_store(const Message& message) {
|
void clear() override {
|
||||||
|
if (messages.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<Message> messages_to_memory(messages.begin(), messages.end());
|
||||||
|
_add_to_vector_store(messages_to_memory);
|
||||||
|
messages.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void _add_to_vector_store(const std::vector<Message>& messages) {
|
||||||
// 检查vector_store是否已初始化
|
// 检查vector_store是否已初始化
|
||||||
if (!vector_store) {
|
if (!vector_store) {
|
||||||
logger->warn("Vector store is not initialized, skipping memory operation");
|
logger->warn("Vector store is not initialized, skipping memory operation");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string parsed_message = message.role + ": " + (message.content.is_string() ? message.content.get<std::string>() : message.content.dump());
|
std::string parsed_message;
|
||||||
|
|
||||||
|
for (const auto& message : messages) {
|
||||||
|
parsed_message += message.role + ": " + (message.content.is_string() ? message.content.get<std::string>() : message.content.dump()) + "\n";
|
||||||
|
|
||||||
for (const auto& tool_call : message.tool_calls) {
|
for (const auto& tool_call : message.tool_calls) {
|
||||||
parsed_message += "<tool_call>" + tool_call.to_json().dump() + "</tool_call>";
|
parsed_message += "<tool_call>" + tool_call.to_json().dump() + "</tool_call>\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string system_prompt = fact_extraction_prompt;
|
||||||
|
|
||||||
|
size_t pos = system_prompt.find("{current_request}");
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
system_prompt = system_prompt.replace(pos, 17, current_request);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string system_prompt = fact_extraction_prompt;
|
|
||||||
std::string user_prompt = "Input:\n" + parsed_message;
|
std::string user_prompt = "Input:\n" + parsed_message;
|
||||||
|
|
||||||
Message user_message = Message::user_message(user_prompt);
|
Message user_message = Message::user_message(user_prompt);
|
||||||
|
|
||||||
std::string response = llm->ask(
|
json response = llm->ask_tool(
|
||||||
{user_message},
|
{user_message},
|
||||||
system_prompt
|
system_prompt,
|
||||||
|
"",
|
||||||
|
json::array({fact_extract_tool->to_param()}),
|
||||||
|
"required"
|
||||||
);
|
);
|
||||||
|
|
||||||
json new_retrieved_facts; // ["fact1", "fact2", "fact3"]
|
std::vector<std::string> new_facts; // ["fact1", "fact2", "fact3"]
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// response = remove_code_blocks(response);
|
auto tool_calls = ToolCall::from_json_list(response["tool_calls"]);
|
||||||
new_retrieved_facts = json::parse(response)["facts"];
|
for (const auto& tool_call : tool_calls) {
|
||||||
|
// Parse arguments
|
||||||
|
json args = tool_call.function.arguments;
|
||||||
|
|
||||||
|
if (args.is_string()) {
|
||||||
|
args = json::parse(args.get<std::string>());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto facts = fact_extract_tool->execute(args).output.get<std::vector<std::string>>();
|
||||||
|
new_facts.insert(new_facts.end(), facts.begin(), facts.end());
|
||||||
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
logger->error("Error in new_retrieved_facts: " + std::string(e.what()));
|
logger->error("Error in new_facts: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<json> retrieved_old_memory;
|
if (new_facts.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
logger->info("📫 New facts to remember: " + json(new_facts).dump());
|
||||||
|
|
||||||
|
std::vector<json> old_memories;
|
||||||
std::map<std::string, std::vector<float>> new_message_embeddings;
|
std::map<std::string, std::vector<float>> new_message_embeddings;
|
||||||
|
|
||||||
for (const auto& fact : new_retrieved_facts) {
|
for (const auto& fact : new_facts) {
|
||||||
auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD);
|
auto message_embedding = embedding_model->embed(fact, EmbeddingType::ADD);
|
||||||
new_message_embeddings[fact] = message_embedding;
|
new_message_embeddings[fact] = message_embedding;
|
||||||
auto existing_memories = vector_store->search(
|
auto existing_memories = vector_store->search(
|
||||||
|
@ -122,29 +206,29 @@ struct Memory : BaseMemory {
|
||||||
5
|
5
|
||||||
);
|
);
|
||||||
for (const auto& memory : existing_memories) {
|
for (const auto& memory : existing_memories) {
|
||||||
retrieved_old_memory.push_back({
|
old_memories.push_back({
|
||||||
{"id", memory.id},
|
{"id", memory.id},
|
||||||
{"text", memory.metadata["data"]}
|
{"text", memory.memory}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// sort and unique by id
|
// sort and unique by id
|
||||||
std::sort(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) {
|
std::sort(old_memories.begin(), old_memories.end(), [](const json& a, const json& b) {
|
||||||
return a["id"] < b["id"];
|
return a["id"] < b["id"];
|
||||||
});
|
});
|
||||||
retrieved_old_memory.resize(std::unique(retrieved_old_memory.begin(), retrieved_old_memory.end(), [](const json& a, const json& b) {
|
old_memories.resize(std::unique(old_memories.begin(), old_memories.end(), [](const json& a, const json& b) {
|
||||||
return a["id"] == b["id"];
|
return a["id"] == b["id"];
|
||||||
}) - retrieved_old_memory.begin());
|
}) - old_memories.begin());
|
||||||
logger->info("Total existing memories: " + std::to_string(retrieved_old_memory.size()));
|
logger->info("🧠 Existing memories about new facts: " + std::to_string(old_memories.size()));
|
||||||
|
|
||||||
// mapping UUIDs with integers for handling UUID hallucinations
|
// mapping UUIDs with integers for handling ID hallucinations
|
||||||
std::vector<size_t> temp_uuid_mapping;
|
std::vector<size_t> temp_id_mapping;
|
||||||
for (size_t idx = 0; idx < retrieved_old_memory.size(); ++idx) {
|
for (size_t idx = 0; idx < old_memories.size(); ++idx) {
|
||||||
temp_uuid_mapping.push_back(retrieved_old_memory[idx]["id"].get<size_t>());
|
temp_id_mapping.push_back(old_memories[idx]["id"].get<size_t>());
|
||||||
retrieved_old_memory[idx]["id"] = idx;
|
old_memories[idx]["id"] = idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, fact_extraction_prompt, update_memory_prompt);
|
std::string function_calling_prompt = get_update_memory_messages(old_memories, new_facts, update_memory_prompt);
|
||||||
|
|
||||||
std::string new_memories_with_actions_str;
|
std::string new_memories_with_actions_str;
|
||||||
json new_memories_with_actions = json::array();
|
json new_memories_with_actions = json::array();
|
||||||
|
@ -153,28 +237,36 @@ struct Memory : BaseMemory {
|
||||||
new_memories_with_actions_str = llm->ask(
|
new_memories_with_actions_str = llm->ask(
|
||||||
{Message::user_message(function_calling_prompt)}
|
{Message::user_message(function_calling_prompt)}
|
||||||
);
|
);
|
||||||
new_memories_with_actions = json::parse(new_memories_with_actions_str);
|
new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// new_memories_with_actions_str = remove_code_blocks(new_memories_with_actions_str);
|
|
||||||
new_memories_with_actions = json::parse(new_memories_with_actions_str);
|
new_memories_with_actions = json::parse(new_memories_with_actions_str);
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
logger->error("Invalid JSON response: " + std::string(e.what()));
|
logger->error("Invalid JSON response: " + std::string(e.what()));
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for (const auto& resp : new_memories_with_actions.value("memory", json::array())) {
|
for (const auto& resp : new_memories_with_actions["memory"]) {
|
||||||
logger->info("Processing memory: " + resp.dump(2));
|
logger->debug("Processing memory: " + resp.dump(2));
|
||||||
try {
|
try {
|
||||||
if (!resp.contains("text")) {
|
if (!resp.contains("text")) {
|
||||||
logger->info("Skipping memory entry because of empty `text` field.");
|
logger->warn("Skipping memory entry because of empty `text` field.");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string event = resp.value("event", "NONE");
|
std::string event = resp.value("event", "NONE");
|
||||||
size_t memory_id = resp.contains("id") ? temp_uuid_mapping[resp["id"].get<size_t>()] : uuid();
|
size_t memory_id;
|
||||||
|
try {
|
||||||
|
if (event != "ADD") {
|
||||||
|
memory_id = temp_id_mapping.at(resp["id"].get<size_t>());
|
||||||
|
} else {
|
||||||
|
memory_id = get_uuid_64();
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
memory_id = get_uuid_64();
|
||||||
|
}
|
||||||
if (event == "ADD") {
|
if (event == "ADD") {
|
||||||
_create_memory(
|
_create_memory(
|
||||||
memory_id,
|
memory_id,
|
||||||
|
@ -189,8 +281,6 @@ struct Memory : BaseMemory {
|
||||||
);
|
);
|
||||||
} else if (event == "DELETE") {
|
} else if (event == "DELETE") {
|
||||||
_delete_memory(memory_id);
|
_delete_memory(memory_id);
|
||||||
} else if (event == "NONE") {
|
|
||||||
logger->info("NOOP for Memory.");
|
|
||||||
}
|
}
|
||||||
} catch (const std::exception& e) {
|
} catch (const std::exception& e) {
|
||||||
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
logger->error("Error in new_memories_with_actions: " + std::string(e.what()));
|
||||||
|
@ -214,11 +304,10 @@ struct Memory : BaseMemory {
|
||||||
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto created_at = std::chrono::system_clock::now();
|
MemoryItem metadata{
|
||||||
json metadata = {
|
memory_id,
|
||||||
{"data", data},
|
data,
|
||||||
{"hash", httplib::detail::MD5(data)},
|
httplib::detail::MD5(data)
|
||||||
{"created_at", std::chrono::system_clock::now().time_since_epoch().count()}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
vector_store->insert(
|
vector_store->insert(
|
||||||
|
@ -252,15 +341,14 @@ struct Memory : BaseMemory {
|
||||||
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
embedding = embedding_model->embed(data, EmbeddingType::ADD);
|
||||||
}
|
}
|
||||||
|
|
||||||
json metadata = existing_memory.metadata;
|
existing_memory.memory = data;
|
||||||
metadata["data"] = data;
|
existing_memory.hash = httplib::detail::MD5(data);
|
||||||
metadata["hash"] = httplib::detail::MD5(data);
|
existing_memory.updated_at = std::chrono::system_clock::now().time_since_epoch().count();
|
||||||
metadata["updated_at"] = std::chrono::system_clock::now().time_since_epoch().count();
|
|
||||||
|
|
||||||
vector_store->update(
|
vector_store->update(
|
||||||
memory_id,
|
memory_id,
|
||||||
embedding,
|
embedding,
|
||||||
metadata
|
existing_memory
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,22 @@
|
||||||
|
|
||||||
namespace humanus::mem0 {
|
namespace humanus::mem0 {
|
||||||
|
|
||||||
static size_t uuid() {
|
// Removes enclosing code block markers ```[language] and ``` from a given string.
|
||||||
|
//
|
||||||
|
// Remarks:
|
||||||
|
// - The function uses a regex pattern to match code blocks that may start with ``` followed by an optional language tag (letters or numbers) and end with ```.
|
||||||
|
// - If a code block is detected, it returns only the inner content, stripping out the markers.
|
||||||
|
// - If no code block markers are found, the original content is returned as-is.
|
||||||
|
std::string remove_code_blocks(const std::string& text) {
|
||||||
|
static const std::regex pattern(R"(^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$)");
|
||||||
|
std::smatch match;
|
||||||
|
if (std::regex_search(text, match, pattern)) {
|
||||||
|
return match[1].str();
|
||||||
|
}
|
||||||
|
return text;
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t get_uuid_64() {
|
||||||
const std::string chars = "0123456789abcdef";
|
const std::string chars = "0123456789abcdef";
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
std::mt19937 gen(rd());
|
std::mt19937 gen(rd());
|
||||||
|
@ -29,33 +44,33 @@ static size_t uuid() {
|
||||||
return uuid_int;
|
return uuid_int;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_update_memory_messages(const json& retrieved_old_memory, const json& new_retrieved_facts, const std::string fact_extraction_prompt, const std::string& update_memory_prompt) {
|
std::string get_update_memory_messages(const json& old_memories, const json& new_facts, const std::string& update_memory_prompt) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << fact_extraction_prompt << "\n\n";
|
ss << update_memory_prompt << "\n\n";
|
||||||
ss << "Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n";
|
ss << "Below is the current content of my memory which I have collected till now. You have to update it in the following format only:\n\n";
|
||||||
ss << "```" + retrieved_old_memory.dump(2) + "```\n\n";
|
ss << old_memories.dump(2) + "\n\n";
|
||||||
ss << "The new retrieved facts are mentioned in the triple backticks. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n";
|
ss << "The new retrieved facts are mentioned below. You have to analyze the new retrieved facts and determine whether these facts should be added, updated, or deleted in the memory.\n\n";
|
||||||
ss << "```" + new_retrieved_facts.dump(2) + "```\n\n";
|
ss << new_facts.dump(2) + "\n\n";
|
||||||
ss << "You must return your response in the following JSON structure only:\n\n";
|
ss << "You must return your response in the following JSON structure only:\n\n";
|
||||||
ss << R"json({
|
ss << R"json({
|
||||||
"memory" : [
|
"memory" : [
|
||||||
{
|
{
|
||||||
"id" : "<ID of the memory>", # Use existing ID for updates/deletes, or new ID for additions
|
"id" : <interger ID of the memory>, # Use existing ID for updates/deletes, or new ID for additions
|
||||||
"text" : "<Content of the memory>", # Content of the memory
|
"text" : "<Content of the memory>", # Content of the memory
|
||||||
"event" : "<Operation to be performed>", # Must be "ADD", "UPDATE", "DELETE", or "NONE"
|
"event" : "<Operation to be performed>", # Must be "ADD", "UPDATE", "DELETE", or "NONE"
|
||||||
"old_memory" : "<Old memory content>" # Required only if the event is "UPDATE"
|
"old_memory" : "<Old memory content>" # Required only if the event is "UPDATE"
|
||||||
},
|
},
|
||||||
...
|
...
|
||||||
]
|
]
|
||||||
})json" << "\n\n";
|
})json" << "\n\n";
|
||||||
ss << "Follow the instruction mentioned below:\n"
|
ss << "Follow the instruction mentioned below:\n"
|
||||||
<< "- Do not return anything from the custom few shot prompts provided above.\n"
|
<< "- Do not return anything from the custom few shot prompts provided above.\n"
|
||||||
<< "- If the current memory is empty, then you have to add the new retrieved facts to the memory.\n"
|
<< "- If the current memory is empty, then you have to add the new retrieved facts to the memory.\n"
|
||||||
<< "- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.\n"
|
<< "- You should return the updated memory in only JSON format as shown below. The memory key should be the same if no changes are made.\n"
|
||||||
<< "- If there is an addition, generate a new key and add the new memory corresponding to it.\n"
|
<< "- If there is an addition, generate a new key and add the new memory corresponding to it.\n"
|
||||||
<< "- If there is a deletion, the memory key-value pair should be removed from the memory.\n"
|
<< "- If there is a deletion, the memory key-value pair should be removed from the memory.\n"
|
||||||
<< "- If there is an update, the ID key should remain the same and only the value needs to be updated.\n"
|
<< "- If there is an update, the ID key should remain the same and only the value needs to be updated.\n"
|
||||||
<< "\n";
|
<< "\n";
|
||||||
ss << "Do not return anything except the JSON format.\n";
|
ss << "Do not return anything except the JSON format.\n";
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,22 +30,22 @@ public:
|
||||||
* @brief 插入向量到集合中
|
* @brief 插入向量到集合中
|
||||||
* @param vector 向量数据
|
* @param vector 向量数据
|
||||||
* @param vector_id 向量ID
|
* @param vector_id 向量ID
|
||||||
* @param metadata 可选的元数据
|
* @param metadata 元数据
|
||||||
*/
|
*/
|
||||||
virtual void insert(const std::vector<float>& vector,
|
virtual void insert(const std::vector<float>& vector,
|
||||||
const size_t vector_id,
|
const size_t vector_id,
|
||||||
const json& metadata = json::object()) = 0;
|
const MemoryItem& metadata) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief 搜索相似向量
|
* @brief 搜索相似向量
|
||||||
* @param query 查询向量
|
* @param query 查询向量
|
||||||
* @param limit 返回结果数量限制
|
* @param limit 返回结果数量限制
|
||||||
* @param filters 可选的过滤条件
|
* @param filter 可选的过滤条件
|
||||||
* @return 相似向量的MemoryItem列表
|
* @return 相似向量的MemoryItem列表
|
||||||
*/
|
*/
|
||||||
virtual std::vector<MemoryItem> search(const std::vector<float>& query,
|
virtual std::vector<MemoryItem> search(const std::vector<float>& query,
|
||||||
int limit = 5,
|
size_t limit = 5,
|
||||||
const std::string& filters = "") = 0;
|
const FilterFunc& filter = nullptr) = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief 通过ID删除向量
|
* @brief 通过ID删除向量
|
||||||
|
@ -56,12 +56,10 @@ public:
|
||||||
/**
|
/**
|
||||||
* @brief 更新向量及其负载
|
* @brief 更新向量及其负载
|
||||||
* @param vector_id 向量ID
|
* @param vector_id 向量ID
|
||||||
* @param vector 可选的新向量数据
|
* @param vector 新向量数据
|
||||||
* @param metadata 可选的新负载数据
|
* @param metadata 新负载数据
|
||||||
*/
|
*/
|
||||||
virtual void update(size_t vector_id,
|
virtual void update(size_t vector_id, const std::vector<float>& vector = std::vector<float>(), const MemoryItem& metadata = MemoryItem()) = 0;
|
||||||
const std::vector<float> vector = std::vector<float>(),
|
|
||||||
const json& metadata = json::object()) = 0;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief 通过ID获取向量
|
* @brief 通过ID获取向量
|
||||||
|
@ -72,11 +70,11 @@ public:
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief 列出所有记忆
|
* @brief 列出所有记忆
|
||||||
* @param filters 可选的过滤条件
|
|
||||||
* @param limit 可选的结果数量限制
|
* @param limit 可选的结果数量限制
|
||||||
|
* @param filter 可选的过滤条件(isIDAllowed)
|
||||||
* @return 记忆ID列表
|
* @return 记忆ID列表
|
||||||
*/
|
*/
|
||||||
virtual std::vector<MemoryItem> list(const std::string& filters = "", int limit = 0) = 0;
|
virtual std::vector<MemoryItem> list(size_t limit = 0, const FilterFunc& filter = nullptr) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace humanus::mem0
|
} // namespace humanus::mem0
|
||||||
|
|
|
@ -22,54 +22,39 @@ void HNSWLibVectorStore::reset() {
|
||||||
space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
|
space = std::make_shared<hnswlib::InnerProductSpace>(config_->dim);
|
||||||
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
|
hnsw = std::make_shared<hnswlib::HierarchicalNSW<float>>(space.get(), config_->max_elements, config_->M, config_->ef_construction);
|
||||||
} else {
|
} else {
|
||||||
throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<int>(config_->metric)));
|
throw std::invalid_argument("Unsupported metric: " + std::to_string(static_cast<size_t>(config_->metric)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void HNSWLibVectorStore::insert(const std::vector<float>& vector,
|
void HNSWLibVectorStore::insert(const std::vector<float>& vector, const size_t vector_id, const MemoryItem& metadata) {
|
||||||
const size_t vector_id,
|
|
||||||
const json& metadata) {
|
|
||||||
hnsw->addPoint(vector.data(), vector_id);
|
hnsw->addPoint(vector.data(), vector_id);
|
||||||
|
|
||||||
// 存储元数据
|
// 存储元数据
|
||||||
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||||||
json _metadata = metadata;
|
MemoryItem _metadata = metadata;
|
||||||
if (!_metadata.contains("created_at")) {
|
if (_metadata.created_at < 0) {
|
||||||
_metadata["created_at"] = now;
|
_metadata.created_at = now;
|
||||||
}
|
}
|
||||||
if (!_metadata.contains("updated_at")) {
|
if (_metadata.updated_at < 0) {
|
||||||
_metadata["updated_at"] = now;
|
_metadata.updated_at = now;
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata_store[vector_id] = _metadata;
|
metadata_store[vector_id] = _metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query,
|
std::vector<MemoryItem> HNSWLibVectorStore::search(const std::vector<float>& query, size_t limit, const FilterFunc& filter) {
|
||||||
int limit = 5,
|
auto filte_wrapper = filter ? std::make_unique<HNSWLibFilterFunctorWrapper>(*this, filter) : nullptr;
|
||||||
const std::string& filters = "") {
|
auto results = hnsw->searchKnn(query.data(), limit, filte_wrapper.get());
|
||||||
auto results = hnsw->searchKnn(query.data(), limit);
|
|
||||||
std::vector<MemoryItem> memory_items;
|
std::vector<MemoryItem> memory_items;
|
||||||
|
|
||||||
while (!results.empty()) {
|
while (!results.empty()) {
|
||||||
const auto& [id, distance] = results.top();
|
const auto& [distance, id] = results.top();
|
||||||
|
|
||||||
results.pop();
|
results.pop();
|
||||||
|
|
||||||
if (metadata_store.find(id) != metadata_store.end()) {
|
if (metadata_store.find(id) != metadata_store.end()) {
|
||||||
MemoryItem item;
|
MemoryItem item = metadata_store[id];
|
||||||
item.id = id;
|
|
||||||
|
|
||||||
if (metadata_store[id].contains("data")) {
|
|
||||||
item.memory = metadata_store[id]["data"];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (metadata_store[id].contains("hash")) {
|
|
||||||
item.hash = metadata_store[id]["hash"];
|
|
||||||
}
|
|
||||||
|
|
||||||
item.metadata = metadata_store[id];
|
|
||||||
item.score = distance;
|
item.score = distance;
|
||||||
|
|
||||||
memory_items.push_back(item);
|
memory_items.push_back(item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -82,79 +67,50 @@ void HNSWLibVectorStore::delete_vector(size_t vector_id) {
|
||||||
metadata_store.erase(vector_id);
|
metadata_store.erase(vector_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void HNSWLibVectorStore::update(size_t vector_id,
|
void HNSWLibVectorStore::update(size_t vector_id, const std::vector<float>& vector, const MemoryItem& metadata) {
|
||||||
const std::vector<float> vector = std::vector<float>(),
|
|
||||||
const json& metadata = json::object()) {
|
|
||||||
// 检查向量是否需要更新
|
// 检查向量是否需要更新
|
||||||
if (!vector.empty()) {
|
if (!vector.empty()) {
|
||||||
hnsw->markDelete(vector_id);
|
hnsw->markDelete(vector_id);
|
||||||
hnsw->addPoint(vector.data(), vector_id);
|
hnsw->addPoint(vector.data(), vector_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新元数据
|
if (!metadata.empty()) {
|
||||||
if (metadata_store.find(vector_id) != metadata_store.end()) {
|
MemoryItem new_metadata = metadata;
|
||||||
|
new_metadata.id = vector_id; // Make sure the id is the same as the vector id
|
||||||
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||||||
|
if (metadata_store.find(vector_id) != metadata_store.end()) {
|
||||||
// 合并现有元数据和新元数据
|
MemoryItem old_metadata = metadata_store[vector_id];
|
||||||
for (auto& [key, value] : metadata.items()) {
|
if (new_metadata.hash == old_metadata.hash) {
|
||||||
metadata_store[vector_id][key] = value;
|
new_metadata.created_at = old_metadata.created_at;
|
||||||
|
} else {
|
||||||
|
new_metadata.created_at = now;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if (new_metadata.created_at < 0) {
|
||||||
// 更新时间戳
|
new_metadata.created_at = now;
|
||||||
metadata_store[vector_id]["updated_at"] = now;
|
|
||||||
} else if (!metadata.empty()) {
|
|
||||||
// 如果元数据不存在但提供了新的元数据,则创建新条目
|
|
||||||
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
|
||||||
json new_metadata = metadata;
|
|
||||||
if (!new_metadata.contains("created_at")) {
|
|
||||||
new_metadata["created_at"] = now;
|
|
||||||
}
|
}
|
||||||
new_metadata["updated_at"] = now;
|
new_metadata.updated_at = now;
|
||||||
metadata_store[vector_id] = new_metadata;
|
metadata_store[vector_id] = new_metadata;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
|
MemoryItem HNSWLibVectorStore::get(size_t vector_id) {
|
||||||
MemoryItem item;
|
return metadata_store.at(vector_id);
|
||||||
item.id = vector_id;
|
|
||||||
|
|
||||||
// 获取向量数据
|
|
||||||
std::vector<float> vector_data = hnsw->getDataByLabel<float>(vector_id);
|
|
||||||
|
|
||||||
// 获取元数据
|
|
||||||
if (metadata_store.find(vector_id) != metadata_store.end()) {
|
|
||||||
if (metadata_store[vector_id].contains("data")) {
|
|
||||||
item.memory = metadata_store[vector_id]["data"];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (metadata_store[vector_id].contains("hash")) {
|
|
||||||
item.hash = metadata_store[vector_id]["hash"];
|
|
||||||
}
|
|
||||||
|
|
||||||
item.metadata = metadata_store[vector_id];
|
|
||||||
}
|
|
||||||
|
|
||||||
return item;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MemoryItem> HNSWLibVectorStore::list(const std::string& filters = "", int limit = 0) {
|
std::vector<MemoryItem> HNSWLibVectorStore::list(size_t limit, const FilterFunc& filter) {
|
||||||
std::vector<MemoryItem> result;
|
std::vector<MemoryItem> result;
|
||||||
size_t count = hnsw->cur_element_count;
|
size_t count = hnsw->cur_element_count;
|
||||||
|
|
||||||
for (size_t i = 0; i < count; i++) {
|
for (size_t i = 0; i < count; i++) {
|
||||||
if (!hnsw->isMarkedDeleted(i)) {
|
if (!hnsw->isMarkedDeleted(i)) {
|
||||||
// 如果有过滤条件,检查元数据是否匹配
|
// 如果有过滤条件,检查元数据是否匹配
|
||||||
if (!filters.empty() && metadata_store.find(i) != metadata_store.end()) {
|
auto memory_item = get(i);
|
||||||
// 简单的字符串匹配过滤,可以根据需要扩展
|
if (filter && !filter(memory_item)) {
|
||||||
json metadata_json = metadata_store[i];
|
continue;
|
||||||
std::string metadata_str = metadata_json.dump();
|
|
||||||
if (metadata_str.find(filters) == std::string::npos) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
result.emplace_back(memory_item);
|
||||||
result.emplace_back(get(i));
|
if (limit > 0 && result.size() >= limit) {
|
||||||
if (limit > 0 && result.size() >= static_cast<size_t>(limit)) {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ class HNSWLibVectorStore : public VectorStore {
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
|
std::shared_ptr<hnswlib::HierarchicalNSW<float>> hnsw;
|
||||||
std::shared_ptr<hnswlib::SpaceInterface<float>> space; // 保持space对象的引用以确保其生命周期
|
std::shared_ptr<hnswlib::SpaceInterface<float>> space; // 保持space对象的引用以确保其生命周期
|
||||||
std::unordered_map<size_t, json> metadata_store; // 存储向量的元数据
|
std::unordered_map<size_t, MemoryItem> metadata_store; // 存储向量的元数据
|
||||||
|
|
||||||
public:
|
public:
|
||||||
HNSWLibVectorStore(const std::shared_ptr<VectorStoreConfig>& config) : VectorStore(config) {
|
HNSWLibVectorStore(const std::shared_ptr<VectorStoreConfig>& config) : VectorStore(config) {
|
||||||
|
@ -19,19 +19,41 @@ public:
|
||||||
|
|
||||||
void reset() override;
|
void reset() override;
|
||||||
|
|
||||||
void insert(const std::vector<float>& vector, const size_t vector_id, const json& metadata) override;
|
void insert(const std::vector<float>& vector, const size_t vector_id, const MemoryItem& metadata) override;
|
||||||
|
|
||||||
std::vector<MemoryItem> search(const std::vector<float>& query, int limit, const std::string& filters) override;
|
std::vector<MemoryItem> search(const std::vector<float>& query, size_t limit, const FilterFunc& filter = nullptr) override;
|
||||||
|
|
||||||
void delete_vector(size_t vector_id) override;
|
void delete_vector(size_t vector_id) override;
|
||||||
|
|
||||||
void update(size_t vector_id, const std::vector<float> vector, const json& metadata) override;
|
void update(size_t vector_id, const std::vector<float>& vector = std::vector<float>(), const MemoryItem& metadata = MemoryItem()) override;
|
||||||
|
|
||||||
MemoryItem get(size_t vector_id) override;
|
MemoryItem get(size_t vector_id) override;
|
||||||
|
|
||||||
std::vector<MemoryItem> list(const std::string& filters, int limit) override;
|
std::vector<MemoryItem> list(size_t limit, const FilterFunc& filter = nullptr) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HNSWLibFilterFunctorWrapper : public hnswlib::BaseFilterFunctor {
|
||||||
|
private:
|
||||||
|
HNSWLibVectorStore& vector_store;
|
||||||
|
FilterFunc filter_func;
|
||||||
|
|
||||||
|
public:
|
||||||
|
HNSWLibFilterFunctorWrapper(HNSWLibVectorStore& store, const FilterFunc& filter_func)
|
||||||
|
: vector_store(store), filter_func(filter_func) {}
|
||||||
|
|
||||||
|
bool operator()(hnswlib::labeltype id) override {
|
||||||
|
if (filter_func == nullptr) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
return filter_func(vector_store.get(id));
|
||||||
|
} catch (...) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
#endif // HUMANUS_MEMORY_MEM0_VECTOR_STORE_HNSWLIB_H
|
113
prompt.cpp
113
prompt.cpp
|
@ -18,7 +18,11 @@ const char* NEXT_STEP_PROMPT = R"(You can interact with the computer using pytho
|
||||||
|
|
||||||
Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it.
|
Based on user needs, proactively select the most appropriate tool or combination of tools. For complex tasks, you can break down the problem and use different tools step by step to solve it.
|
||||||
|
|
||||||
After using each tool, clearly explain the execution results and suggest the next steps. If you finish the current step, call `terminate` to switch to next step.)";
|
After using each tool, clearly explain the execution results and suggest the next steps. If you finish the current step, call `terminate` to switch to next step.
|
||||||
|
|
||||||
|
Unless required by user, you should always at most use one tool at a time, observe the result and then choose the next tool or action.
|
||||||
|
|
||||||
|
Detect the language of the user input and respond in the same language for thoughts.)";
|
||||||
} // namespace humanus
|
} // namespace humanus
|
||||||
|
|
||||||
namespace planning {
|
namespace planning {
|
||||||
|
@ -83,44 +87,23 @@ Types of Information to Remember:
|
||||||
|
|
||||||
1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment.
|
1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment.
|
||||||
2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates.
|
2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates.
|
||||||
3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared.
|
3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared or assistant has generated.
|
||||||
4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services.
|
4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services.
|
||||||
5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information.
|
5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information.
|
||||||
6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information.
|
6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information.
|
||||||
7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares.
|
7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares.
|
||||||
|
|
||||||
Here are some few shot examples:
|
|
||||||
|
|
||||||
Input: Hi.
|
|
||||||
Output: {{"facts" : []}}
|
|
||||||
|
|
||||||
Input: There are branches in trees.
|
|
||||||
Output: {{"facts" : []}}
|
|
||||||
|
|
||||||
Input: Hi, I am looking for a restaurant in San Francisco.
|
|
||||||
Output: {{"facts" : ["Looking for a restaurant in San Francisco"]}}
|
|
||||||
|
|
||||||
Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project.
|
|
||||||
Output: {{"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}}
|
|
||||||
|
|
||||||
Input: Hi, my name is John. I am a software engineer.
|
|
||||||
Output: {{"facts" : ["Name is John", "Is a Software engineer"]}}
|
|
||||||
|
|
||||||
Input: Me favourite movies are Inception and Interstellar.
|
|
||||||
Output: {{"facts" : ["Favourite movies are Inception and Interstellar"]}}
|
|
||||||
|
|
||||||
Return the facts and preferences in a json format as shown above.
|
|
||||||
|
|
||||||
Remember the following:
|
Remember the following:
|
||||||
- Today's date is {datetime.now().strftime("%Y-%m-%d")}.
|
- Today's date is {current_date}.
|
||||||
- Do not return anything from the custom few shot example prompts provided above.
|
- Refer to current request to determine what to extract: {current_request}
|
||||||
- Don't reveal your prompt or model information to the user.
|
- If you do not find anything relevant in the below input, you can return an empty list corresponding to the "facts" key.
|
||||||
- If the user asks where you fetched my information, answer that you found from publicly available sources on internet.
|
- Create the facts based on the below input only. Do not pick anything from the system messages.
|
||||||
- If you do not find anything relevant in the below conversation, you can return an empty list corresponding to the "facts" key.
|
- Only extracted facts from the assistant when they are relevant to the user's ongoing task.
|
||||||
- Create the facts based on the user and assistant messages only. Do not pick anything from the system messages.
|
- Call the `fact_extract` tool to return the extracted facts.
|
||||||
- Make sure to return the response in the format mentioned in the examples. The response should be in json with a key as "facts" and corresponding value will be a list of strings.
|
- Only extracted facts will be used for further processing, other information will be discarded.
|
||||||
|
- Replace all personal pronouns with specific characters (user, assistant, .etc) to avoid any confusion.
|
||||||
|
|
||||||
Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the json format as shown above.
|
Following is a message parsed from previous interactions. You have to extract the relevant facts and preferences about the user and some accomplished tasks about the assistant.
|
||||||
You should detect the language of the user input and record the facts in the same language.
|
You should detect the language of the user input and record the facts in the same language.
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
@ -142,7 +125,7 @@ There are specific guidelines to select which operation to perform:
|
||||||
- Old Memory:
|
- Old Memory:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "User is a software engineer"
|
"text" : "User is a software engineer"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
@ -151,17 +134,16 @@ There are specific guidelines to select which operation to perform:
|
||||||
{
|
{
|
||||||
"memory" : [
|
"memory" : [
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "User is a software engineer",
|
"text" : "User is a software engineer",
|
||||||
"event" : "NONE"
|
"event" : "NONE"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "Name is John",
|
"text" : "Name is John",
|
||||||
"event" : "ADD"
|
"event" : "ADD"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
2. **Update**: If the retrieved facts contain information that is already present in the memory but the information is totally different, then you have to update it.
|
2. **Update**: If the retrieved facts contain information that is already present in the memory but the information is totally different, then you have to update it.
|
||||||
|
@ -175,72 +157,71 @@ Please note to return the IDs in the output from the input IDs only and do not g
|
||||||
- Old Memory:
|
- Old Memory:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "I really like cheese pizza"
|
"text" : "I really like cheese pizza"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "User is a software engineer"
|
"text" : "User is a software engineer"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "2",
|
"id" : 2,
|
||||||
"text" : "User likes to play cricket"
|
"text" : "User likes to play cricket"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
- Retrieved facts: ["Loves chicken pizza", "Loves to play cricket with friends"]
|
- Retrieved facts: ["Loves chicken pizza", "Loves to play cricket with friends"]
|
||||||
- New Memory:
|
- New Memory:
|
||||||
{
|
{
|
||||||
"memory" : [
|
"memory" : [
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "Loves cheese and chicken pizza",
|
"text" : "User loves cheese and chicken pizza",
|
||||||
"event" : "UPDATE",
|
"event" : "UPDATE",
|
||||||
"old_memory" : "I really like cheese pizza"
|
"old_memory" : "I really like cheese pizza"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "User is a software engineer",
|
"text" : "User is a software engineer",
|
||||||
"event" : "NONE"
|
"event" : "NONE"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "2",
|
"id" : 2,
|
||||||
"text" : "Loves to play cricket with friends",
|
"text" : "User loves to play cricket with friends",
|
||||||
"event" : "UPDATE",
|
"event" : "UPDATE",
|
||||||
"old_memory" : "User likes to play cricket"
|
"old_memory" : "User likes to play cricket"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
3. **Delete**: If the retrieved facts contain information that contradicts the information present in the memory, then you have to delete it. Or if the direction is to delete the memory, then you have to delete it.
|
3. **Delete**: If the retrieved facts contain information that contradicts the information present in the memory, then you have to delete it. Or if the direction is to delete the memory, then you have to delete it.
|
||||||
Please note to return the IDs in the output from the input IDs only and do not generate any new ID.
|
Please note to return the IDs in the output from the input IDs only and do not generate any new ID.
|
||||||
- **Example**:
|
- **Example**:
|
||||||
- Old Memory:
|
- Old Memory:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "Name is John"
|
"text" : "User's name is John"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "Loves cheese pizza"
|
"text" : "User loves cheese pizza"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
- Retrieved facts: ["Dislikes cheese pizza"]
|
- Retrieved facts: ["Dislikes cheese pizza"]
|
||||||
- New Memory:
|
- New Memory:
|
||||||
{
|
{
|
||||||
"memory" : [
|
"memory" : [
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "Name is John",
|
"text" : "User's name is John",
|
||||||
"event" : "NONE"
|
"event" : "NONE"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "Loves cheese pizza",
|
"text" : "User loves cheese pizza",
|
||||||
"event" : "DELETE"
|
"event" : "DELETE"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
4. **No Change**: If the retrieved facts contain information that is already present in the memory, then you do not need to make any changes.
|
4. **No Change**: If the retrieved facts contain information that is already present in the memory, then you do not need to make any changes.
|
||||||
|
@ -248,26 +229,26 @@ Please note to return the IDs in the output from the input IDs only and do not g
|
||||||
- Old Memory:
|
- Old Memory:
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "Name is John"
|
"text" : "User's name is John"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "Loves cheese pizza"
|
"text" : "User loves cheese pizza"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
- Retrieved facts: ["Name is John"]
|
- Retrieved facts: ["User's name is John"]
|
||||||
- New Memory:
|
- New Memory:
|
||||||
{
|
{
|
||||||
"memory" : [
|
"memory" : [
|
||||||
{
|
{
|
||||||
"id" : "0",
|
"id" : 0,
|
||||||
"text" : "Name is John",
|
"text" : "User's name is John",
|
||||||
"event" : "NONE"
|
"event" : "NONE"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id" : "1",
|
"id" : 1,
|
||||||
"text" : "Loves cheese pizza",
|
"text" : "User loves cheese pizza",
|
||||||
"event" : "NONE"
|
"event" : "NONE"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
27
schema.h
27
schema.h
|
@ -156,6 +156,33 @@ struct Message {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace mem0 {
|
||||||
|
|
||||||
|
struct MemoryItem {
|
||||||
|
size_t id; // The unique identifier for the text data
|
||||||
|
std::string memory; // The memory deduced from the text data
|
||||||
|
std::string hash; // The hash of the memory
|
||||||
|
long long created_at; // The creation time of the memory,
|
||||||
|
long long updated_at; // The last update time of the memory
|
||||||
|
float score; // The score associated with the text data, used for ranking and sorting
|
||||||
|
|
||||||
|
MemoryItem(size_t id = -1, const std::string& memory = "", const std::string& hash = "")
|
||||||
|
: id(id), memory(memory), hash(hash) {
|
||||||
|
auto now = std::chrono::system_clock::now().time_since_epoch().count();
|
||||||
|
created_at = now;
|
||||||
|
updated_at = now;
|
||||||
|
score = -1.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool empty() const {
|
||||||
|
return memory.empty();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef std::function<bool(const MemoryItem&)> FilterFunc;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace humanus
|
} // namespace humanus
|
||||||
|
|
||||||
#endif // HUMANUS_SCHEMA_H
|
#endif // HUMANUS_SCHEMA_H
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
import pygame
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
|
||||||
|
# Initialize Pygame
|
||||||
|
pygame.init()
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
SCREEN_WIDTH, SCREEN_HEIGHT = 640, 480
|
||||||
|
GRID_SIZE = 20
|
||||||
|
FPS = 10
|
||||||
|
|
||||||
|
# Colors
|
||||||
|
WHITE = (255, 255, 255)
|
||||||
|
BLACK = (0, 0, 0)
|
||||||
|
GREEN = (0, 255, 0)
|
||||||
|
RED = (255, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class Game:
|
||||||
|
def __init__(self):
|
||||||
|
# Set up the display
|
||||||
|
self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
|
||||||
|
self.clock = pygame.time.Clock()
|
||||||
|
self.snake = Snake([(SCREEN_WIDTH // 2, SCREEN_HEIGHT // 2)], (0, -1), 1)
|
||||||
|
self.food = Food(self.generate_food_position())
|
||||||
|
|
||||||
|
def generate_food_position(self):
|
||||||
|
return (random.randint(0, (SCREEN_WIDTH // GRID_SIZE) - 1) * GRID_SIZE,
|
||||||
|
random.randint(0, (SCREEN_HEIGHT // GRID_SIZE) - 1) * GRID_SIZE)
|
||||||
|
|
||||||
|
def main_loop(self):
|
||||||
|
while True:
|
||||||
|
for event in pygame.event.get():
|
||||||
|
if event.type == pygame.QUIT:
|
||||||
|
pygame.quit()
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
# Handle user input and update game state
|
||||||
|
self.handle_input()
|
||||||
|
self.update_game_state()
|
||||||
|
|
||||||
|
# Draw everything
|
||||||
|
self.screen.fill(BLACK)
|
||||||
|
self.draw_snake()
|
||||||
|
self.draw_food()
|
||||||
|
pygame.display.flip()
|
||||||
|
|
||||||
|
# Cap the frame rate
|
||||||
|
self.clock.tick(FPS)
|
||||||
|
|
||||||
|
def handle_input(self):
|
||||||
|
keys = pygame.key.get_pressed()
|
||||||
|
if keys[pygame.K_UP] and self.snake.direction != (0, 1):
|
||||||
|
self.snake.change_direction((0, -1))
|
||||||
|
elif keys[pygame.K_DOWN] and self.snake.direction != (0, -1):
|
||||||
|
self.snake.change_direction((0, 1))
|
||||||
|
elif keys[pygame.K_LEFT] and self.snake.direction != (1, 0):
|
||||||
|
self.snake.change_direction((-1, 0))
|
||||||
|
elif keys[pygame.K_RIGHT] and self.snake.direction != (-1, 0):
|
||||||
|
self.snake.change_direction((1, 0))
|
||||||
|
|
||||||
|
def update_game_state(self):
|
||||||
|
self.snake.move()
|
||||||
|
if self.snake.position[0] == self.food.position:
|
||||||
|
self.snake.grow()
|
||||||
|
self.food.position = self.generate_food_position()
|
||||||
|
|
||||||
|
def draw_snake(self):
|
||||||
|
for segment in self.snake.position:
|
||||||
|
pygame.draw.rect(self.screen, GREEN, (segment[0], segment[1], GRID_SIZE, GRID_SIZE))
|
||||||
|
|
||||||
|
def draw_food(self):
|
||||||
|
pygame.draw.rect(self.screen, RED, (self.food.position[0], self.food.position[1], GRID_SIZE, GRID_SIZE))
|
||||||
|
|
||||||
|
|
||||||
|
class Snake:
|
||||||
|
def __init__(self, position, direction, length):
|
||||||
|
self.position = position
|
||||||
|
self.direction = direction
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def move(self):
|
||||||
|
head_x, head_y = self.position[0]
|
||||||
|
new_dir_x, new_dir_y = self.direction
|
||||||
|
new_head = (head_x + new_dir_x * GRID_SIZE, head_y + new_dir_y * GRID_SIZE)
|
||||||
|
self.position.insert(0, new_head)
|
||||||
|
if len(self.position) > self.length:
|
||||||
|
self.position.pop()
|
||||||
|
|
||||||
|
def change_direction(self, new_direction):
|
||||||
|
self.direction = new_direction
|
||||||
|
|
||||||
|
def grow(self):
|
||||||
|
self.length += 1
|
||||||
|
|
||||||
|
|
||||||
|
class Food:
|
||||||
|
def __init__(self, position):
|
||||||
|
self.position = position
|
||||||
|
|
||||||
|
def spawn(self, board_size):
|
||||||
|
pass # This method is not used in this implementation
|
|
@ -0,0 +1,61 @@
|
||||||
|
import pygame
|
||||||
|
import sys
|
||||||
|
from pygame.locals import *
|
||||||
|
|
||||||
|
class Snake:
|
||||||
|
def __init__(self, start_position, length=3):
|
||||||
|
self.body = [start_position] * length
|
||||||
|
self.direction = (0, 1) # Initially, the snake is moving to the right
|
||||||
|
|
||||||
|
def move(self):
|
||||||
|
head_x, head_y = self.body[0]
|
||||||
|
dir_x, dir_y = self.direction
|
||||||
|
new_head = (head_x + dir_x, head_y + dir_y)
|
||||||
|
self.body.insert(0, new_head)
|
||||||
|
self.body.pop()
|
||||||
|
|
||||||
|
def change_direction(self, new_direction):
|
||||||
|
if (new_direction[0] == -self.direction[0] and new_direction[1] == -self.direction[1]):
|
||||||
|
return
|
||||||
|
self.direction = new_direction
|
||||||
|
|
||||||
|
def grow(self):
|
||||||
|
self.body.append(self.body[-1])
|
||||||
|
|
||||||
|
# Initialize Pygame
|
||||||
|
pygame.init()
|
||||||
|
window_size = (400, 400)
|
||||||
|
screen = pygame.display.set_mode(window_size)
|
||||||
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
|
# Create a snake at position (20, 20)
|
||||||
|
snake = Snake((20, 20))
|
||||||
|
|
||||||
|
# Main game loop
|
||||||
|
while True:
|
||||||
|
for event in pygame.event.get():
|
||||||
|
if event.type == QUIT:
|
||||||
|
pygame.quit()
|
||||||
|
sys.exit()
|
||||||
|
elif event.type == KEYDOWN:
|
||||||
|
if event.key == K_UP:
|
||||||
|
snake.change_direction((0, -1))
|
||||||
|
elif event.key == K_DOWN:
|
||||||
|
snake.change_direction((0, 1))
|
||||||
|
elif event.key == K_LEFT:
|
||||||
|
snake.change_direction((-1, 0))
|
||||||
|
elif event.key == K_RIGHT:
|
||||||
|
snake.change_direction((1, 0))
|
||||||
|
|
||||||
|
# Move the snake
|
||||||
|
snake.move()
|
||||||
|
|
||||||
|
# Clear screen
|
||||||
|
screen.fill((0, 0, 0))
|
||||||
|
# Draw the snake (for simplicity, we will just print its position)
|
||||||
|
for part in snake.body:
|
||||||
|
print(part) # In a real game, you would draw the snake on the screen here
|
||||||
|
|
||||||
|
# Update display
|
||||||
|
pygame.display.update()
|
||||||
|
clock.tick(10) # Limit the frame rate to 10 FPS
|
|
@ -0,0 +1,34 @@
|
||||||
|
Snake Game Architecture
|
||||||
|
======================
|
||||||
|
|
||||||
|
- Game Class
|
||||||
|
- Responsibilities:
|
||||||
|
- Initialize game objects (snake, food, and board)
|
||||||
|
- Main game loop
|
||||||
|
- Handle user input for snake direction
|
||||||
|
- Update the game state (move snake, check for collisions, etc.)
|
||||||
|
- Keep track of the score
|
||||||
|
- Display the game on the screen
|
||||||
|
|
||||||
|
- Snake Class
|
||||||
|
- Attributes:
|
||||||
|
- Position (list of tuples representing the coordinates of each segment)
|
||||||
|
- Direction (tuple representing the current direction of movement)
|
||||||
|
- Length (integer representing the length of the snake)
|
||||||
|
- Methods:
|
||||||
|
- Move (update the position of each segment)
|
||||||
|
- Change direction (based on user input)
|
||||||
|
- Grow (increase the length of the snake when it eats food)
|
||||||
|
|
||||||
|
- Food Class
|
||||||
|
- Attributes:
|
||||||
|
- Position (tuple representing the coordinates of the food)
|
||||||
|
- Methods:
|
||||||
|
- Spawn (randomly generate a new position for the food)
|
||||||
|
|
||||||
|
- Board Class (Optional)
|
||||||
|
- Attributes:
|
||||||
|
- Size (width and height of the game board)
|
||||||
|
- Methods:
|
||||||
|
- Draw (draw the boundaries of the game board)
|
||||||
|
- Check collision (check if the snake has collided with the walls or itself)
|
|
@ -168,7 +168,7 @@ struct ToolError : ToolResult {
|
||||||
|
|
||||||
// Execute the tool with given parameters.
|
// Execute the tool with given parameters.
|
||||||
struct BaseTool {
|
struct BaseTool {
|
||||||
inline static std::set<std::string> special_tool_name = {"terminate", "planning"};
|
inline static std::set<std::string> special_tool_name = {"terminate", "planning", "fact_extract", "memory_update"};
|
||||||
|
|
||||||
std::string name;
|
std::string name;
|
||||||
std::string description;
|
std::string description;
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
#ifndef HUMANUS_TOOL_FACT_EXTRACT_H
|
||||||
|
#define HUMANUS_TOOL_FACT_EXTRACT_H
|
||||||
|
|
||||||
|
namespace humanus {
|
||||||
|
|
||||||
|
struct FactExtract : BaseTool {
|
||||||
|
inline static const std::string name_ = "fact_extract";
|
||||||
|
inline static const std::string description_ = "Extract facts and store them in a long-term memory.";
|
||||||
|
inline static const json parameters_ = json::parse(R"json({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"facts": {
|
||||||
|
"description": "List of facts to extract and store.",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["facts"],
|
||||||
|
"additionalProperties": false
|
||||||
|
})json");
|
||||||
|
|
||||||
|
FactExtract() : BaseTool(name_, description_, parameters_) {}
|
||||||
|
|
||||||
|
ToolResult execute(const json& arguments) override {
|
||||||
|
return ToolResult(arguments["facts"]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // HUMANUS_TOOL_FACT_EXTRACT_H
|
|
@ -0,0 +1,59 @@
|
||||||
|
#ifndef HUMANUS_TOOL_MEMORY_UPDATE_H
|
||||||
|
#define HUMANUS_TOOL_MEMORY_UPDATE_H
|
||||||
|
|
||||||
|
namespace humanus {
|
||||||
|
|
||||||
|
struct MemoryUpdate : BaseTool {
|
||||||
|
inline static const std::string name_ = "memory_update";
|
||||||
|
inline static const std::string description_ = "Compare newly retrieved facts with the existing memory. For each new fact, decide whether to:\n- ADD: Add it to the memory as a new element\n- UPDATE: Update an existing memory element\n- DELETE: Delete an existing memory element\n- NONE: Make no change (if the fact is already present or irrelevant)";
|
||||||
|
inline static const json parameters = json::parse(R"json({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"memory": {
|
||||||
|
"description": "List of memory operations.",
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"id": {
|
||||||
|
"description": "Unique integer ID of the memory item, required by event UPDATE and DELETE",
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"text": {
|
||||||
|
"description": "Plain text fact to ADD, UPDATE or DELETE",
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"event": {
|
||||||
|
"description": "The type of the operation",
|
||||||
|
"type": "string",
|
||||||
|
"enum": [
|
||||||
|
"ADD",
|
||||||
|
"UPDATE",
|
||||||
|
"DELETE",
|
||||||
|
"NONE"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"text",
|
||||||
|
"event"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"memory"
|
||||||
|
],
|
||||||
|
"additionalProperties": false
|
||||||
|
})json");
|
||||||
|
|
||||||
|
MemoryUpdate() : BaseTool(name_, description_, parameters) {}
|
||||||
|
|
||||||
|
ToolResult execute(const json& arguments) override {
|
||||||
|
return ToolResult(arguments["memory"]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace humanus
|
||||||
|
|
||||||
|
#endif // HUMANUS_TOOL_MEMORY_UPDATE_H
|
Loading…
Reference in New Issue