humanus.cpp/tokenizer/bpe.h

246 lines
8.6 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#ifndef HUMANUS_TOKENIZER_BPE_H
#define HUMANUS_TOKENIZER_BPE_H
#include "base.h"
#include "../mcp/common/base64.hpp"
#include <unordered_map>
#include <queue>
#include <fstream>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <memory>
#include <utility>
#include <functional>
#include <limits>
namespace humanus {
/**
* @brief BPE (Byte Pair Encoding) Tokenizer实现
*
* 这个类实现了基于BPE (Byte Pair Encoding)算法的分词器,
* 使用tiktoken格式的词汇表和合并规则文件。
* 使用优先队列进行高效的BPE合并操作。
*/
class BPETokenizer : public BaseTokenizer {
private:
// 辅助结构用于哈希pair
struct PairHash {
template <class T1, class T2>
std::size_t operator() (const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};
// 字节对到对应token ID的映射
std::unordered_map<std::string, size_t> encoder;
// token ID到字节对的映射
std::unordered_map<size_t, std::string> decoder;
// 合并优先级映射,优先级越小越优先合并
std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash> merge_ranks;
/**
* @brief 解码base64编码的字符串
* @param encoded base64编码的字符串
* @return 解码后的字符串
*/
std::string base64_decode(const std::string& encoded) const {
if (encoded.empty()) return "";
return base64::decode(encoded);
}
/**
* @brief 用于优先队列的比较函数
*
* 根据merge_ranks中的优先级比较两个token对。
* 优先级较低的值即较小的rank值具有更高的优先级。
*/
struct MergeComparator {
const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks;
MergeComparator(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& r)
: ranks(r) {}
bool operator()(const std::pair<std::pair<std::string, std::string>, size_t>& a,
const std::pair<std::pair<std::string, std::string>, size_t>& b) const {
// 首先按照merge_ranks比较如果不存在则使用最大值
size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits<size_t>::max();
size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits<size_t>::max();
// 优先队列是最大堆所以我们需要反向比较较小的rank优先级更高
return rank_a > rank_b;
}
};
public:
/**
* @brief 从tiktoken格式文件构造BPE tokenizer
* @param tokenizer_path tiktoken格式词汇表文件的路径
*
* 文件格式每行包含一个base64编码的token和对应的token ID
* 例如: "IQ== 0",其中"IQ=="是base64编码的token0是对应的ID
*/
BPETokenizer(const std::string& tokenizer_path) {
std::ifstream file(tokenizer_path);
if (!file.is_open()) {
throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path);
}
std::string line;
while (std::getline(file, line)) {
std::istringstream iss(line);
std::string token_base64;
size_t rank;
if (iss >> token_base64 >> rank) {
std::string token = base64_decode(token_base64);
// 存储token和其ID的映射关系
encoder[token] = rank;
decoder[rank] = token;
}
}
// 构建merge_ranks
build_merge_ranks();
}
/**
* @brief 构建合并优先级映射
*
* 利用词汇表中的tokens推断可能的合并规则。
* 对于长度大于1的tokens尝试所有可能的分割如果分割后的两个部分也在词汇表中
* 则假设这是一个有效的合并规则。
*/
void build_merge_ranks() {
// 对于tiktoken格式我们可以利用编码中长度>1的token来构建合并规则
for (const auto& [token, id] : encoder) {
if (token.length() <= 1) continue;
// 尝试所有可能的分割点
for (size_t i = 1; i < token.length(); ++i) {
std::string first = token.substr(0, i);
std::string second = token.substr(i);
// 如果两个部分都在词汇表中,假设这是一个有效的合并规则
if (encoder.count(first) && encoder.count(second)) {
// 使用ID作为优先级 - 较小的ID表示更高的优先级
merge_ranks[{first, second}] = id;
}
}
}
}
/**
* @brief 设置合并优先级
* @param ranks 新的合并优先级映射
*/
void set_merge_ranks(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks) {
merge_ranks = ranks;
}
/**
* @brief 执行BPE编码
* @param text 要编码的文本
* @return 编码后的token IDs
*
* 该方法使用BPE算法对输入文本进行编码。
* 1. 首先将文本分解为单个字节
* 2. 使用优先队列根据merge_ranks中的优先级对相邻token进行合并
* 3. 将最终的tokens转换为对应的IDs
*/
std::vector<size_t> encode(const std::string& text) const override {
if (text.empty()) {
return {};
}
// 将文本分解为单个字符tokens
std::vector<std::string> tokens;
for (unsigned char c : text) {
tokens.push_back(std::string(1, c));
}
// 使用优先队列执行BPE合并
while (tokens.size() > 1) {
// 构建优先队列,用于选择最高优先级的合并对
using MergePair = std::pair<std::pair<std::string, std::string>, size_t>;
MergeComparator comparator(merge_ranks);
std::priority_queue<MergePair, std::vector<MergePair>, MergeComparator> merge_candidates(comparator);
// 查找所有可能的合并对
for (size_t i = 0; i < tokens.size() - 1; ++i) {
std::pair<std::string, std::string> pair = {tokens[i], tokens[i+1]};
if (merge_ranks.count(pair)) {
merge_candidates.push({pair, i});
}
}
// 如果没有可合并的对,退出循环
if (merge_candidates.empty()) {
break;
}
// 执行优先级最高的合并(优先队列中排在最前面的)
auto top_merge = merge_candidates.top();
auto pair = top_merge.first; // 要合并的token对
size_t pos = top_merge.second; // 要合并的位置
// 合并token
std::string merged_token = pair.first + pair.second;
tokens[pos] = merged_token;
tokens.erase(tokens.begin() + pos + 1);
}
// 将tokens转换为IDs
std::vector<size_t> ids;
ids.reserve(tokens.size());
for (const auto& token : tokens) {
auto it = encoder.find(token);
if (it != encoder.end()) {
ids.push_back(it->second);
}
// 未知token将被跳过
}
return ids;
}
/**
* @brief 解码BPE tokens
* @param tokens 要解码的token IDs
* @return 解码后的文本
*
* 该方法将编码后的token IDs转换回原始文本。
* 简单地将每个ID对应的token字符串连接起来。
*/
std::string decode(const std::vector<size_t>& tokens) const override {
std::string result;
result.reserve(tokens.size() * 2); // 预估大小,避免频繁重新分配
for (size_t id : tokens) {
auto it = decoder.find(id);
if (it != decoder.end()) {
result += it->second;
}
// 未知ID将被跳过
}
return result;
}
/**
* @brief 从tiktoken文件加载BPE词汇和合并规则
* @param file_path tiktoken格式文件的路径
* @return 共享指针指向创建的BPE tokenizer
*/
static std::shared_ptr<BPETokenizer> load_from_tiktoken(const std::string& file_path) {
return std::make_shared<BPETokenizer>(file_path);
}
};
} // namespace humanus
#endif // HUMANUS_TOKENIZER_BPE_H