#ifndef HUMANUS_TOKENIZER_BPE_H #define HUMANUS_TOKENIZER_BPE_H #include "base.h" #include "../mcp/common/base64.hpp" #include #include #include #include #include #include #include #include #include #include namespace humanus { /** * @brief BPE (Byte Pair Encoding) Tokenizer实现 * * 这个类实现了基于BPE (Byte Pair Encoding)算法的分词器, * 使用tiktoken格式的词汇表和合并规则文件。 * 使用优先队列进行高效的BPE合并操作。 */ class BPETokenizer : public BaseTokenizer { private: // 辅助结构,用于哈希pair struct PairHash { template std::size_t operator() (const std::pair& pair) const { return std::hash()(pair.first) ^ std::hash()(pair.second); } }; // 字节对到对应token ID的映射 std::unordered_map encoder; // token ID到字节对的映射 std::unordered_map decoder; // 合并优先级映射,优先级越小越优先合并 std::unordered_map, size_t, PairHash> merge_ranks; /** * @brief 解码base64编码的字符串 * @param encoded base64编码的字符串 * @return 解码后的字符串 */ std::string base64_decode(const std::string& encoded) const { if (encoded.empty()) return ""; return base64::decode(encoded); } /** * @brief 用于优先队列的比较函数 * * 根据merge_ranks中的优先级比较两个token对。 * 优先级较低的值(即较小的rank值)具有更高的优先级。 */ struct MergeComparator { const std::unordered_map, size_t, PairHash>& ranks; MergeComparator(const std::unordered_map, size_t, PairHash>& r) : ranks(r) {} bool operator()(const std::pair, size_t>& a, const std::pair, size_t>& b) const { // 首先按照merge_ranks比较,如果不存在则使用最大值 size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits::max(); size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits::max(); // 优先队列是最大堆,所以我们需要反向比较(较小的rank优先级更高) return rank_a > rank_b; } }; public: /** * @brief 从tiktoken格式文件构造BPE tokenizer * @param tokenizer_path tiktoken格式词汇表文件的路径 * * 文件格式:每行包含一个base64编码的token和对应的token ID * 例如: "IQ== 0",其中"IQ=="是base64编码的token,0是对应的ID */ BPETokenizer(const std::string& tokenizer_path) { std::ifstream file(tokenizer_path); if (!file.is_open()) { throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path); } std::string line; while (std::getline(file, line)) { std::istringstream iss(line); std::string token_base64; size_t rank; if (iss >> token_base64 >> rank) { std::string token = base64_decode(token_base64); // 存储token和其ID的映射关系 encoder[token] = rank; decoder[rank] = token; } } // 构建merge_ranks build_merge_ranks(); } /** * @brief 构建合并优先级映射 * * 利用词汇表中的tokens推断可能的合并规则。 * 对于长度大于1的tokens,尝试所有可能的分割,如果分割后的两个部分也在词汇表中, * 则假设这是一个有效的合并规则。 */ void build_merge_ranks() { // 对于tiktoken格式,我们可以利用编码中长度>1的token来构建合并规则 for (const auto& [token, id] : encoder) { if (token.length() <= 1) continue; // 尝试所有可能的分割点 for (size_t i = 1; i < token.length(); ++i) { std::string first = token.substr(0, i); std::string second = token.substr(i); // 如果两个部分都在词汇表中,假设这是一个有效的合并规则 if (encoder.count(first) && encoder.count(second)) { // 使用ID作为优先级 - 较小的ID表示更高的优先级 merge_ranks[{first, second}] = id; } } } } /** * @brief 设置合并优先级 * @param ranks 新的合并优先级映射 */ void set_merge_ranks(const std::unordered_map, size_t, PairHash>& ranks) { merge_ranks = ranks; } /** * @brief 执行BPE编码 * @param text 要编码的文本 * @return 编码后的token IDs * * 该方法使用BPE算法对输入文本进行编码。 * 1. 首先将文本分解为单个字节 * 2. 使用优先队列根据merge_ranks中的优先级对相邻token进行合并 * 3. 将最终的tokens转换为对应的IDs */ std::vector encode(const std::string& text) const override { if (text.empty()) { return {}; } // 将文本分解为单个字符tokens std::vector tokens; for (unsigned char c : text) { tokens.push_back(std::string(1, c)); } // 使用优先队列执行BPE合并 while (tokens.size() > 1) { // 构建优先队列,用于选择最高优先级的合并对 using MergePair = std::pair, size_t>; MergeComparator comparator(merge_ranks); std::priority_queue, MergeComparator> merge_candidates(comparator); // 查找所有可能的合并对 for (size_t i = 0; i < tokens.size() - 1; ++i) { std::pair pair = {tokens[i], tokens[i+1]}; if (merge_ranks.count(pair)) { merge_candidates.push({pair, i}); } } // 如果没有可合并的对,退出循环 if (merge_candidates.empty()) { break; } // 执行优先级最高的合并(优先队列中排在最前面的) auto top_merge = merge_candidates.top(); auto pair = top_merge.first; // 要合并的token对 size_t pos = top_merge.second; // 要合并的位置 // 合并token std::string merged_token = pair.first + pair.second; tokens[pos] = merged_token; tokens.erase(tokens.begin() + pos + 1); } // 将tokens转换为IDs std::vector ids; ids.reserve(tokens.size()); for (const auto& token : tokens) { auto it = encoder.find(token); if (it != encoder.end()) { ids.push_back(it->second); } // 未知token将被跳过 } return ids; } /** * @brief 解码BPE tokens * @param tokens 要解码的token IDs * @return 解码后的文本 * * 该方法将编码后的token IDs转换回原始文本。 * 简单地将每个ID对应的token字符串连接起来。 */ std::string decode(const std::vector& tokens) const override { std::string result; result.reserve(tokens.size() * 2); // 预估大小,避免频繁重新分配 for (size_t id : tokens) { auto it = decoder.find(id); if (it != decoder.end()) { result += it->second; } // 未知ID将被跳过 } return result; } /** * @brief 从tiktoken文件加载BPE词汇和合并规则 * @param file_path tiktoken格式文件的路径 * @return 共享指针指向创建的BPE tokenizer */ static std::shared_ptr load_from_tiktoken(const std::string& file_path) { return std::make_shared(file_path); } }; } // namespace humanus #endif // HUMANUS_TOKENIZER_BPE_H