humanus.cpp/tokenizer/bpe.h

246 lines
8.6 KiB
C
Raw Normal View History

#ifndef HUMANUS_TOKENIZER_BPE_H
#define HUMANUS_TOKENIZER_BPE_H
#include "base.h"
#include "../mcp/common/base64.hpp"
#include <unordered_map>
#include <queue>
#include <fstream>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <memory>
#include <utility>
#include <functional>
#include <limits>
namespace humanus {
/**
* @brief BPE (Byte Pair Encoding) Tokenizer
*
* BPE (Byte Pair Encoding)
* 使tiktoken
* 使BPE
*/
class BPETokenizer : public BaseTokenizer {
private:
// 辅助结构用于哈希pair
struct PairHash {
template <class T1, class T2>
std::size_t operator() (const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};
// 字节对到对应token ID的映射
std::unordered_map<std::string, size_t> encoder;
// token ID到字节对的映射
std::unordered_map<size_t, std::string> decoder;
// 合并优先级映射,优先级越小越优先合并
std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash> merge_ranks;
/**
* @brief base64
* @param encoded base64
* @return
*/
std::string base64_decode(const std::string& encoded) const {
if (encoded.empty()) return "";
return base64::decode(encoded);
}
/**
* @brief
*
* merge_rankstoken
* rank
*/
struct MergeComparator {
const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks;
MergeComparator(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& r)
: ranks(r) {}
bool operator()(const std::pair<std::pair<std::string, std::string>, size_t>& a,
const std::pair<std::pair<std::string, std::string>, size_t>& b) const {
// 首先按照merge_ranks比较如果不存在则使用最大值
size_t rank_a = ranks.count(a.first) ? ranks.at(a.first) : std::numeric_limits<size_t>::max();
size_t rank_b = ranks.count(b.first) ? ranks.at(b.first) : std::numeric_limits<size_t>::max();
// 优先队列是最大堆所以我们需要反向比较较小的rank优先级更高
return rank_a > rank_b;
}
};
public:
/**
* @brief tiktokenBPE tokenizer
* @param tokenizer_path tiktoken
*
* base64tokentoken ID
* : "IQ== 0""IQ=="base64token0ID
*/
BPETokenizer(const std::string& tokenizer_path) {
std::ifstream file(tokenizer_path);
if (!file.is_open()) {
throw std::runtime_error("无法打开tokenizer文件: " + tokenizer_path);
}
std::string line;
while (std::getline(file, line)) {
std::istringstream iss(line);
std::string token_base64;
size_t rank;
if (iss >> token_base64 >> rank) {
std::string token = base64_decode(token_base64);
// 存储token和其ID的映射关系
encoder[token] = rank;
decoder[rank] = token;
}
}
// 构建merge_ranks
build_merge_ranks();
}
/**
* @brief
*
* tokens
* 1tokens
*
*/
void build_merge_ranks() {
// 对于tiktoken格式我们可以利用编码中长度>1的token来构建合并规则
for (const auto& [token, id] : encoder) {
if (token.length() <= 1) continue;
// 尝试所有可能的分割点
for (size_t i = 1; i < token.length(); ++i) {
std::string first = token.substr(0, i);
std::string second = token.substr(i);
// 如果两个部分都在词汇表中,假设这是一个有效的合并规则
if (encoder.count(first) && encoder.count(second)) {
// 使用ID作为优先级 - 较小的ID表示更高的优先级
merge_ranks[{first, second}] = id;
}
}
}
}
/**
* @brief
* @param ranks
*/
void set_merge_ranks(const std::unordered_map<std::pair<std::string, std::string>, size_t, PairHash>& ranks) {
merge_ranks = ranks;
}
/**
* @brief BPE
* @param text
* @return token IDs
*
* 使BPE
* 1.
* 2. 使merge_rankstoken
* 3. tokensIDs
*/
std::vector<size_t> encode(const std::string& text) const override {
if (text.empty()) {
return {};
}
// 将文本分解为单个字符tokens
std::vector<std::string> tokens;
for (unsigned char c : text) {
tokens.push_back(std::string(1, c));
}
// 使用优先队列执行BPE合并
while (tokens.size() > 1) {
// 构建优先队列,用于选择最高优先级的合并对
using MergePair = std::pair<std::pair<std::string, std::string>, size_t>;
MergeComparator comparator(merge_ranks);
std::priority_queue<MergePair, std::vector<MergePair>, MergeComparator> merge_candidates(comparator);
// 查找所有可能的合并对
for (size_t i = 0; i < tokens.size() - 1; ++i) {
std::pair<std::string, std::string> pair = {tokens[i], tokens[i+1]};
if (merge_ranks.count(pair)) {
merge_candidates.push({pair, i});
}
}
// 如果没有可合并的对,退出循环
if (merge_candidates.empty()) {
break;
}
// 执行优先级最高的合并(优先队列中排在最前面的)
auto top_merge = merge_candidates.top();
auto pair = top_merge.first; // 要合并的token对
size_t pos = top_merge.second; // 要合并的位置
// 合并token
std::string merged_token = pair.first + pair.second;
tokens[pos] = merged_token;
tokens.erase(tokens.begin() + pos + 1);
}
// 将tokens转换为IDs
std::vector<size_t> ids;
ids.reserve(tokens.size());
for (const auto& token : tokens) {
auto it = encoder.find(token);
if (it != encoder.end()) {
ids.push_back(it->second);
}
// 未知token将被跳过
}
return ids;
}
/**
* @brief BPE tokens
* @param tokens token IDs
* @return
*
* token IDs
* IDtoken
*/
std::string decode(const std::vector<size_t>& tokens) const override {
std::string result;
result.reserve(tokens.size() * 2); // 预估大小,避免频繁重新分配
for (size_t id : tokens) {
auto it = decoder.find(id);
if (it != decoder.end()) {
result += it->second;
}
// 未知ID将被跳过
}
return result;
}
/**
* @brief tiktokenBPE
* @param file_path tiktoken
* @return BPE tokenizer
*/
static std::shared_ptr<BPETokenizer> load_from_tiktoken(const std::string& file_path) {
return std::make_shared<BPETokenizer>(file_path);
}
};
} // namespace humanus
#endif // HUMANUS_TOKENIZER_BPE_H