246 lines
8.6 KiB
C++
246 lines
8.6 KiB
C++
#ifndef HUMANUS_TOKENIZER_BPE_H
|
||
#define HUMANUS_TOKENIZER_BPE_H
|
||
|
||
#include "base.h"
|
||
#include "../mcp/common/base64.hpp"
|
||
#include <unordered_map>
|
||
#include <queue>
|
||
#include <fstream>
|
||
#include <sstream>
|
||
#include <iostream>
|
||
#include <algorithm>
|
||
#include <memory>
|
||
#include <utility>
|
||
#include <functional>
|
||
#include <limits>
|
||
|
||
namespace humanus {
|
||
|
||
/**
|
||
* @brief BPE (Byte Pair Encoding) Tokenizer实现
|
||
*
|
||
* 这个类实现了基于BPE (Byte Pair Encoding)算法的分词器,
|
||
* 使用tiktoken格式的词汇表和合并规则文件。
|
||
* 使用优先队列进行高效的BPE合并操作。
|
||
*/
|
||
class BPETokenizer : public BaseTokenizer {
|
||
private:
|
||
// 辅助结构,用于哈希pair
|
||
struct PairHash {
|
||
template <class T1, class T2>
|
||
std::size_t operator() (const std::pair<T1, T2>& pair) const {
|
||
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
|
||
}
|
||
};
|
||
|
||
// 字节对到对应token ID的映射
|
||
std::unordered_map<std::string, size_t> encoder;
|
||
// token ID到字节对的映射
|
||
std::unordered_map<size_t, std::string> decoder;
|
||
|
||
// 合并优先级映射,优先级越小越优先合并
|
||
std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash> merge_ranks;
|
||
|
||
/**
|
||
* @brief 解码base64编码的字符串
|
||
* @param encoded base64编码的字符串
|
||
* @return 解码后的字符串
|
||
*/
|
||
std::string base64_decode(const std::string& encoded) const {
|
||
if (encoded.empty()) return "";
|
||
return base64::decode(encoded);
|
||
}
|
||
|
||
/**
|
||
* @brief 用于优先队列的比较函数
|
||
*
|
||
* 根据merge_ranks中的优先级比较两个token对。
|
||
* 优先级较低的值(即较小的rank值)具有更高的优先级。
|
||
*/
|
||
struct MergeComparator {
|
||
const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks;
|
||
|
||
MergeComparator(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& r)
|
||
: ranks(r) {}
|
||
|
||
bool operator()(const std::pair<std::pair<std::string, std::string>, size_t>& a,
|
||
const std::pair<std::pair<std::string, std::string>, size_t>& b) const {
|
||
// 首先按照merge_ranks比较,如果不存在则使用最大值
|
||
size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits<size_t>::max();
|
||
size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits<size_t>::max();
|
||
|
||
// 优先队列是最大堆,所以我们需要反向比较(较小的rank优先级更高)
|
||
return rank_a > rank_b;
|
||
}
|
||
};
|
||
|
||
public:
|
||
/**
|
||
* @brief 从tiktoken格式文件构造BPE tokenizer
|
||
* @param tokenizer_path tiktoken格式词汇表文件的路径
|
||
*
|
||
* 文件格式:每行包含一个base64编码的token和对应的token ID
|
||
* 例如: "IQ== 0",其中"IQ=="是base64编码的token,0是对应的ID
|
||
*/
|
||
BPETokenizer(const std::string& tokenizer_path) {
|
||
std::ifstream file(tokenizer_path);
|
||
if (!file.is_open()) {
|
||
throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path);
|
||
}
|
||
|
||
std::string line;
|
||
while (std::getline(file, line)) {
|
||
std::istringstream iss(line);
|
||
std::string token_base64;
|
||
size_t rank;
|
||
|
||
if (iss >> token_base64 >> rank) {
|
||
std::string token = base64_decode(token_base64);
|
||
|
||
// 存储token和其ID的映射关系
|
||
encoder[token] = rank;
|
||
decoder[rank] = token;
|
||
}
|
||
}
|
||
|
||
// 构建merge_ranks
|
||
build_merge_ranks();
|
||
}
|
||
|
||
/**
|
||
* @brief 构建合并优先级映射
|
||
*
|
||
* 利用词汇表中的tokens推断可能的合并规则。
|
||
* 对于长度大于1的tokens,尝试所有可能的分割,如果分割后的两个部分也在词汇表中,
|
||
* 则假设这是一个有效的合并规则。
|
||
*/
|
||
void build_merge_ranks() {
|
||
// 对于tiktoken格式,我们可以利用编码中长度>1的token来构建合并规则
|
||
for (const auto& [token, id] : encoder) {
|
||
if (token.length() <= 1) continue;
|
||
|
||
// 尝试所有可能的分割点
|
||
for (size_t i = 1; i < token.length(); ++i) {
|
||
std::string first = token.substr(0, i);
|
||
std::string second = token.substr(i);
|
||
|
||
// 如果两个部分都在词汇表中,假设这是一个有效的合并规则
|
||
if (encoder.count(first) && encoder.count(second)) {
|
||
// 使用ID作为优先级 - 较小的ID表示更高的优先级
|
||
merge_ranks[{first, second}] = id;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/**
|
||
* @brief 设置合并优先级
|
||
* @param ranks 新的合并优先级映射
|
||
*/
|
||
void set_merge_ranks(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks) {
|
||
merge_ranks = ranks;
|
||
}
|
||
|
||
/**
|
||
* @brief 执行BPE编码
|
||
* @param text 要编码的文本
|
||
* @return 编码后的token IDs
|
||
*
|
||
* 该方法使用BPE算法对输入文本进行编码。
|
||
* 1. 首先将文本分解为单个字节
|
||
* 2. 使用优先队列根据merge_ranks中的优先级对相邻token进行合并
|
||
* 3. 将最终的tokens转换为对应的IDs
|
||
*/
|
||
std::vector<size_t> encode(const std::string& text) const override {
|
||
if (text.empty()) {
|
||
return {};
|
||
}
|
||
|
||
// 将文本分解为单个字符tokens
|
||
std::vector<std::string> tokens;
|
||
for (unsigned char c : text) {
|
||
tokens.push_back(std::string(1, c));
|
||
}
|
||
|
||
// 使用优先队列执行BPE合并
|
||
while (tokens.size() > 1) {
|
||
// 构建优先队列,用于选择最高优先级的合并对
|
||
using MergePair = std::pair<std::pair<std::string, std::string>, size_t>;
|
||
MergeComparator comparator(merge_ranks);
|
||
std::priority_queue<MergePair, std::vector<MergePair>, MergeComparator> merge_candidates(comparator);
|
||
|
||
// 查找所有可能的合并对
|
||
for (size_t i = 0; i < tokens.size() - 1; ++i) {
|
||
std::pair<std::string, std::string> pair = {tokens[i], tokens[i+1]};
|
||
if (merge_ranks.count(pair)) {
|
||
merge_candidates.push({pair, i});
|
||
}
|
||
}
|
||
|
||
// 如果没有可合并的对,退出循环
|
||
if (merge_candidates.empty()) {
|
||
break;
|
||
}
|
||
|
||
// 执行优先级最高的合并(优先队列中排在最前面的)
|
||
auto top_merge = merge_candidates.top();
|
||
auto pair = top_merge.first; // 要合并的token对
|
||
size_t pos = top_merge.second; // 要合并的位置
|
||
|
||
// 合并token
|
||
std::string merged_token = pair.first + pair.second;
|
||
tokens[pos] = merged_token;
|
||
tokens.erase(tokens.begin() + pos + 1);
|
||
}
|
||
|
||
// 将tokens转换为IDs
|
||
std::vector<size_t> ids;
|
||
ids.reserve(tokens.size());
|
||
for (const auto& token : tokens) {
|
||
auto it = encoder.find(token);
|
||
if (it != encoder.end()) {
|
||
ids.push_back(it->second);
|
||
}
|
||
// 未知token将被跳过
|
||
}
|
||
|
||
return ids;
|
||
}
|
||
|
||
/**
|
||
* @brief 解码BPE tokens
|
||
* @param tokens 要解码的token IDs
|
||
* @return 解码后的文本
|
||
*
|
||
* 该方法将编码后的token IDs转换回原始文本。
|
||
* 简单地将每个ID对应的token字符串连接起来。
|
||
*/
|
||
std::string decode(const std::vector<size_t>& tokens) const override {
|
||
std::string result;
|
||
result.reserve(tokens.size() * 2); // 预估大小,避免频繁重新分配
|
||
|
||
for (size_t id : tokens) {
|
||
auto it = decoder.find(id);
|
||
if (it != decoder.end()) {
|
||
result += it->second;
|
||
}
|
||
// 未知ID将被跳过
|
||
}
|
||
|
||
return result;
|
||
}
|
||
|
||
/**
|
||
* @brief 从tiktoken文件加载BPE词汇和合并规则
|
||
* @param file_path tiktoken格式文件的路径
|
||
* @return 共享指针指向创建的BPE tokenizer
|
||
*/
|
||
static std::shared_ptr<BPETokenizer> load_from_tiktoken(const std::string& file_path) {
|
||
return std::make_shared<BPETokenizer>(file_path);
|
||
}
|
||
};
|
||
|
||
} // namespace humanus
|
||
|
||
#endif // HUMANUS_TOKENIZER_BPE_H
|