humanus.cpp/tokenizer/utils.cpp

99 lines
3.7 KiB
C++

#include "utils.h"
namespace humanus {
int num_tokens_from_messages(const BaseTokenizer& tokenizer, const json& messages) {
// TODO: configure the magic number
static const int tokens_per_message = 3;
static const int tokens_per_name = 1;
static const int tokens_per_image = 1024;
int num_tokens = 0;
json messages_;
if (messages.is_object()) {
messages_ = json::array({messages});
} else {
messages_ = messages;
}
for (const auto& message : messages_) {
num_tokens += tokens_per_message;
for (const auto& [key, value] : message.items()) {
if (value.is_string()) {
num_tokens += tokenizer.encode(value.get<std::string>()).size();
} else if (value.is_array()) {
for (const auto& item : value) {
if (item.contains("text")) {
num_tokens += tokenizer.encode(item.at("text").get<std::string>()).size();
} else if (item.contains("image_url")) {
num_tokens += tokens_per_image;
}
}
}
if (key == "name") {
num_tokens += tokens_per_name;
}
}
}
num_tokens += 3; // every reply is primed with <|start|>assistant<|message|>
return num_tokens;
}
int num_tokens_for_tools(const BaseTokenizer& tokenizer, const json& tools, const json& messages) {
// TODO: configure the magic number
static const int tool_init = 10;
static const int prop_init = 3;
static const int prop_key = 3;
static const int enum_init = 3;
static const int enum_item = 3;
static const int tool_end = 12;
int tool_token_count = 0;
if (!tools.empty()) {
for (const auto& tool : tools) {
tool_token_count += tool_init; // Add tokens for start of each tool
auto function = tool["function"];
auto f_name = function["name"].get<std::string>();
auto f_desc = function["description"].get<std::string>();
if (f_desc.back() == '.') {
f_desc.pop_back();
}
auto line = f_name + ":" + f_desc;
tool_token_count += tokenizer.encode(line).size();
if (function["parameters"].contains("properties")) {
tool_token_count += prop_init; // Add tokens for start of each property
for (const auto& [key, value] : function["parameters"]["properties"].items()) {
tool_token_count += prop_key; // Add tokens for each set property
auto p_name = key;
auto p_type = value["type"].get<std::string>();
auto p_desc = value["description"].get<std::string>();
if (value.contains("enum")) {
tool_token_count += enum_init; // Add tokens if property has enum list
for (const auto& item : value["enum"]) {
tool_token_count += enum_item;
tool_token_count += tokenizer.encode(item.get<std::string>()).size();
}
}
if (p_desc.back() == '.') {
p_desc.pop_back();
}
auto line = p_name + ":" + p_type + ":" + p_desc;
tool_token_count += tokenizer.encode(line).size();
}
}
}
tool_token_count += tool_end;
}
auto messages_token_count = num_tokens_from_messages(tokenizer, messages);
auto total_token_count = tool_token_count + messages_token_count;
return total_token_count;
}
}