Byte Pair Encoding and Data Structures

Tokenization of input strings into sequences of words or sub-tokens is a central concept for modern Natural Language Processing (NLP) techniques. This article focuses on a classic tokenization algorithm: Byte Pair Encoding (BPE) [1]. While resources describing the working principle of the algorithm are widely available, this article concentrates on its implementation, illustrating how the choice of data structures impacts the performance of a real-world NLP component. This article is complemented by a working Rust-based implementation, available on GitHub.

Introduction

Tokenization consists of splitting an input sequence (e.g., a sentence) into a sequence of tokens drawn from a vocabulary of finite size (for example words or individual characters). These tokens can then be encoded and processed further by machine learning components such as neural networks.

Halfway between word-based and character-based tokenization, sub-word tokenization aims at combining the benefits of both approaches.

Sub-word tokenization works by splitting rare words into sub-word components instead of ignoring them or assigning them a seldom-used dedicated representation. Ideally, these sub-word representations are semantically or syntactically meaningful, for example by splitting prefixes or suffixes from words (unexpected may become [un, ##expected]). The resulting vocabulary can be an order of magnitude smaller than for word-based segmentation while keeping common words as single entries, which allows the downstream model to learn semantically rich word embeddings.

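These models usually have two algorithmic components: a training phase, which learns the sub-word vocabulary from a corpus, and a prediction phase, which tokenizes new inputs using the learned vocabulary.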

While proper algorithmic design of the training phase is important to allow scaling to larger corpora, training is a one-off cost. This article focuses on the design and implementation of the prediction phase, which is critical to limit the latency and operating cost of models once they are deployed in production. Following a previous article illustrating the SentencePiece unigram tokenization model, this article focuses on another ubiquitous algorithm: Byte Pair Encoding (BPE).

1. The Byte Pair Encoding (BPE) tokenizer

BPE is a morphological tokenizer that merges adjacent byte pairs based on their frequency in a training corpus. Based on a compression algorithm with the same name, BPE has been adapted to sub-word tokenization and can be thought of as a clustering algorithm [2]. A starting sequence of individual characters is aggregated bottom-up based on frequencies learned during a training phase. When no further aggregation is possible, the process ends and the generated sub-words are returned.

No unknown tokens can occur as an output of this process as long as all individual characters are present in the vocabulary (in the worst case, the sequence of individual characters will be returned). With a sufficiently large vocabulary, the most common words will be fully aggregated and the entire sequence of input characters will be returned as a single word token. Rare words will be returned as a list of sub-tokens that can no longer be aggregated.

Prediction phase (tokenization)

Even though the BPE model needs to be trained before it can be used, this article first describes how it tokenizes an input sequence at prediction time. This helps build an intuition of how the algorithm decomposes words into sub-tokens.

The tokenization algorithm is as follows:

merges is a mapping of symbol pairs to scores, learned during a training phase (see next section). A Symbol is a sub-token of a given input, which may be a single character, several consecutive characters that have been merged, or the entire word. The following example illustrates how a merges vocabulary (mapping the symbol pairs that may be aggregated to their frequency) is used to tokenize two words. The symbol pairs with higher frequencies are merged first:

{
  "e r": 25,
  "h e": 12,
  "l l": 11,
  "l o": 10,
  "he ll": 4,
  "lo w": 3,
  "hell o": 2 
}
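The merge procedure on this vocabulary can be reproduced with a minimal, self-contained Rust sketch. It assumes the merges are stored as space-separated symbol pairs mapped to their frequency, exactly as in the example above, and always merges the most frequent adjacent pair first. The helper names `example_merges` and `tokenize_word`, as well as the words `hello` and `lower`, are hypothetical choices for illustration, not part of the article's implementation.

```rust
use std::collections::HashMap;

/// Hypothetical helper: the example merges vocabulary above,
/// mapping a space-separated symbol pair to its training frequency.
pub fn example_merges() -> HashMap<String, u32> {
    [
        ("e r", 25),
        ("h e", 12),
        ("l l", 11),
        ("l o", 10),
        ("he ll", 4),
        ("lo w", 3),
        ("hell o", 2),
    ]
    .iter()
    .map(|&(pair, freq)| (pair.to_string(), freq))
    .collect()
}

/// Hypothetical naive BPE tokenization of a single word: repeatedly merge
/// the adjacent pair with the highest frequency until no pair is known.
pub fn tokenize_word(word: &str, merges: &HashMap<String, u32>) -> Vec<String> {
    // Start from individual characters.
    let mut symbols: Vec<String> = word.chars().map(|c| c.to_string()).collect();
    loop {
        // Position and frequency of the best mergeable pair, if any
        // (max_by_key returns the rightmost pair on ties).
        let best = symbols
            .windows(2)
            .enumerate()
            .filter_map(|(pos, pair)| {
                merges
                    .get(&format!("{} {}", pair[0], pair[1]))
                    .map(|&freq| (pos, freq))
            })
            .max_by_key(|&(_, freq)| freq);
        match best {
            Some((pos, _)) => {
                // Replace the pair by the concatenation of its two symbols.
                let right = symbols.remove(pos + 1);
                symbols[pos].push_str(&right);
            }
            None => return symbols,
        }
    }
}
```

With this vocabulary, `hello` is fully merged into a single token (h e, then l l, then he ll, then hell o), while `lower` stops at [low, er] (e r, then l o, then lo w) because low er is not a known pair.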

Training phase

For computational efficiency, BPE training relies on a pre-tokenizer (e.g. a whitespace tokenizer) to generate a dictionary of word frequencies (the original BPE algorithm does not allow cross-word merges). This word counter is used to initialize a counter of adjacent symbol pairs that may be merged. After the count completes, the most frequent “symbol pair” is merged to create a new symbol in the dictionary, the symbol pair counts are updated, and the process repeats until the target vocabulary size is reached. As such, BPE grows its vocabulary at each iteration (in contrast with SentencePiece's unigram model, which prunes a large vocabulary at each iteration).
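The training loop described above can be sketched in Rust as follows. This is a naive version that recomputes all pair counts from scratch at every iteration (the incremental data structure updates are left out). The function name `learn_merges`, the input format (a word-to-frequency map) and the tie-breaking rule are assumptions made for this illustration.

```rust
use std::collections::HashMap;

/// Hypothetical helper: learns up to `num_merges` merges from a
/// word -> frequency dictionary, recomputing pair counts every iteration.
pub fn learn_merges(word_counts: &HashMap<String, usize>, num_merges: usize) -> Vec<(String, String)> {
    // Each word starts as a sequence of single-character symbols.
    let mut words: Vec<(Vec<String>, usize)> = word_counts
        .iter()
        .map(|(word, &count)| (word.chars().map(|c| c.to_string()).collect(), count))
        .collect();
    let mut merges = Vec::new();
    for _ in 0..num_merges {
        // Count adjacent symbol pairs, weighted by word frequency.
        let mut pair_counts: HashMap<(String, String), usize> = HashMap::new();
        for (symbols, count) in &words {
            for pair in symbols.windows(2) {
                *pair_counts
                    .entry((pair[0].clone(), pair[1].clone()))
                    .or_insert(0) += *count;
            }
        }
        // Most frequent pair; ties go to the lexicographically smallest pair
        // so that the result does not depend on HashMap iteration order.
        let best = pair_counts.into_iter().max_by(|(pair_a, count_a), (pair_b, count_b)| {
            count_a.cmp(count_b).then_with(|| pair_b.cmp(pair_a))
        });
        let (left, right) = match best {
            Some((pair, _)) => pair,
            None => break, // every word is already a single symbol
        };
        // Replace every occurrence of the pair by the merged symbol.
        let merged_symbol = format!("{}{}", left, right);
        for (symbols, _) in words.iter_mut() {
            let mut merged = Vec::with_capacity(symbols.len());
            let mut i = 0;
            while i < symbols.len() {
                if i + 1 < symbols.len() && symbols[i] == left && symbols[i + 1] == right {
                    merged.push(merged_symbol.clone());
                    i += 2;
                } else {
                    merged.push(symbols[i].clone());
                    i += 1;
                }
            }
            *symbols = merged;
        }
        merges.push((left, right));
    }
    merges
}
```

On a toy dictionary such as {low: 5, lower: 2, newest: 6, widest: 3}, the first two merges learned by this sketch are e s and then es t, since both pairs appear in newest and widest (9 occurrences each).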

The entire training algorithm is rather straightforward and is given below (reproduced from [1]).

The algorithm above can be improved as described in [1] by updating data structures instead of re-computing the “merges” from scratch at each iteration. This article will however focus on the tokenization procedure called at prediction time.

2. Rust implementation(s) of the BPE algorithm

The rest of this article consists of a walk-through of a number of working implementations of the BPE tokenization algorithm in Rust. Starting from a “naive” implementation approach, improvements will be made to highlight pros and cons of some common data structures within the scope of BPE tokenization.

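Four implementations of BPE tokenization will be presented: a naive implementation, a naive implementation with pre-splitting, a priority queue + binary search tree implementation, and a priority queue + linked list implementation.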

a. Naive implementation

Let’s begin with a direct implementation of [algorithm 1], which consists of two main procedures within a loop:

  1. Find the best merge
  2. Apply the best merge

Let’s examine the algorithm complexity, where N represents the number of characters of an input text. Assuming a lookup in the merges mapping can be done in constant time (a fair assumption for a standard hashmap implementation), the FindBestPair procedure has $O(N)$ complexity (N-1 pair candidates need to be checked in the first iteration). Assuming the Merge operation itself runs in constant time, MergeBestPair has $O(N)$ complexity if all occurrences of the best pair are merged in the same iteration, and $O(1)$ if only a single merge is performed (note that removing elements from a contiguous array, as in the Vec-based implementation below, costs $O(N)$ because the following elements are shifted). The main Tokenize procedure iterates until no further merge can be found, which results in N-1 merges in the “worst” case (if the entire word exists in the vocabulary). This results in a combined complexity of $O(N^{2})$. This motivates the pre-tokenization step (e.g. whitespace splitting) that usually precedes BPE tokenization, so that N is the typical length of a word rather than of a full sentence or document.
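The quadratic bound follows from summing the cost of FindBestPair over all iterations, assuming the worst case where each iteration performs exactly one merge: at iteration $k$, $N-k+1$ symbols remain, hence $N-k$ candidate pairs need to be checked.

```latex
\sum_{k=1}^{N-1} (N-k) \;=\; \sum_{j=1}^{N-1} j \;=\; \frac{N(N-1)}{2} \;=\; O(N^{2})
```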

Let’s look at a Rust implementation of this algorithm. The following code has been edited to focus on the critical parts of the algorithm. The full working version can be found at [3]. Let’s start by defining Symbols, the data structure representing a sub-token:

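// PartialOrd/Ord order symbols by start_byte (then end_byte); they are
// required to store Symbols in the BTreeSet used later on.
#[derive(Debug, Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]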
pub struct Symbol {
    pub start_byte: usize,
    pub end_byte: usize,
}

A symbol contains the start and end byte positions of the sub-token. This is lighter than working with string slices (and significantly faster than cloning strings). It is also convenient for the operations we expect on Symbols (merges and lookups). We will not manipulate Symbols on their own, but rather work on a collection of Symbols. Since the naive implementation loops over the list of symbols, a Rust Vec is a natural choice:

pub struct SymbolArray {
    pub symbols: Vec<Symbol>,
}

The symbols are initialized from the characters of the input text (see algorithm 1, line 21). We will create a method from_text that populates the initial SymbolArray from a string slice. Note that we look up character byte indices so that we can handle characters that span over multiple UTF-8 bytes correctly.

We also implement a method that finds the best pair to merge (given a tokenizer/merges dictionary) and returns its optional position in the SymbolArray. Note that it selects the pair with the lowest score returned by get_merge_score, i.e. scores are treated as merge ranks (lower is better) rather than raw frequencies. When this method returns None, the array can no longer be merged. Finally, we implement a method that merges a pair of symbols, mutating the SymbolArray in place. It takes the best pair position as input and replaces the two parent symbols with the merged symbol.

use itertools::Itertools; // provides `tuple_windows` (itertools crate)

impl SymbolArray {
    pub fn from_text(input_text: &str) -> Self {
        let mut symbols = Vec::new();
        for (character_start, character) in input_text.char_indices() {
            symbols.push(Symbol {
                start_byte: character_start,
                end_byte: character_start + character.len_utf8(),
            });
        }
        Self { symbols }
    }

    pub fn find_best_merge<T>(&self, input_text: &str, tokenizer: &T) -> Option<usize>
    where
        T: BpeTokenizer,
    {
        self.symbols
            .iter()
            .tuple_windows::<(&Symbol, &Symbol)>()
            .enumerate()
            .filter_map(|(pos, (first, second))| {
                tokenizer
                    .get_merge_score(first, second, input_text)
                    .map(|rank| (pos, rank))
            })
            .min_by_key(|(_, rank)| *rank)
            .map(|(pos, _)| pos)
    }

    pub fn merge_symbols(&mut self, best_pair_index: usize) -> Symbol {
        let new_symbol = Symbol {
            start_byte: self.symbols[best_pair_index].start_byte,
            end_byte: self.symbols[best_pair_index + 1].end_byte,
        };
        // Overwrite the left symbol in place and drop the right one: a single
        // element shift instead of the three caused by remove/remove/insert.
        self.symbols[best_pair_index] = new_symbol;
        self.symbols.remove(best_pair_index + 1);
        new_symbol
    }
}

We still need to implement the actual tokenizer and its tokenize procedure. It pre-processes the input text (replacing whitespaces by the corresponding encoding symbol of the pre-trained vocabulary), pre-populates a SymbolArray and identifies/merges the best symbol pairs until no further merge is possible. It then returns the sub-tokens as string slices of the original input. Apart from the pre-processed text, the tokenizer never copies the input: it relies solely on byte positions, a mapping from pre-processed to original byte offsets, and string slices.

pub struct NaiveBpeTokenizer {
    merges_vocab: MergesVocab,
}

impl BpeTokenizer for NaiveBpeTokenizer {

    fn tokenize<'a>(&self, input_text: &'a str) -> Vec<&'a str> {
        let (text, byte_mapping) = self.pre_process_text(input_text, '▁');

        let mut symbols = SymbolArray::from_text(text.as_str());
        while let Some(best_pair_index) = symbols.find_best_merge(text.as_str(), self) {
            symbols.merge_symbols(best_pair_index);
        }
        let mut output = Vec::new();
        for symbol in symbols.symbols {
            output.push(
                &input_text[byte_mapping[&symbol.start_byte]..byte_mapping[&symbol.end_byte]],
            );
        }
        output
    }
}

As mentioned previously, this algorithm has $O(N^{2})$ complexity, where N represents the number of characters in the input. This makes it impractical for long inputs such as paragraphs or documents. A common work-around consists in limiting the size of the input passed to the BPE tokenizer, for example by applying a whitespace-splitting pre-tokenizer.

b. Pre-splitting naive implementation

The previous implementation can easily be extended to include a pre-tokenization step. This is actually the standard BPE implementation in several widely used packages, such as subword-nmt (Python) [4] or fastBPE (C++) [5]. By pre-tokenizing the sequence (for example with whitespace splitting), one can effectively limit the average size of the inputs passed to BPE tokenization: whitespace splitting passes single words for processing. For most languages, the expected length of a word M is much smaller than the number of characters N in an average sentence or document (although after living in Germany for 10 years, the author realizes this hypothesis may be optimistic at times). Tokenizing a single word has $O(M^{2})$ complexity. Splitting the sequence into words using a simple whitespace/punctuation rule has $O(N)$ complexity and yields roughly N/M words, so the overall complexity is $O(N) + O(\frac{N}{M} M^{2}) = O(NM)$. This is linear in N as long as the word length M is bounded and does not grow with the input size.

To implement this tokenizer, the Tokenizer has an additional method for whitespace/punctuation splitting. Punctuation marks are returned as single-character words, and other words include their leading whitespace character if applicable. The method returns a vector of string slices:

impl NaivePreSplitBpeTokenizer {
    fn split_whitespace_punctuation<'a>(
        &self,
        input_string: &'a str,
        whitespace_token: char,
    ) -> Vec<&'a str> {
        let mut output: Vec<&str> = Vec::new();
        let mut start: usize = 0;

        for (c_pos, c) in input_string.char_indices() {
            if c == whitespace_token {
                if start < c_pos {
                    output.push(&input_string[start..c_pos]);
                }
                start = c_pos;
            } else if c.is_ascii_punctuation() {
                if start < c_pos {
                    output.push(&input_string[start..c_pos]);
                }
                output.push(&input_string[c_pos..c_pos + c.len_utf8()]);
                start = c_pos + c.len_utf8();
            }
        }
        if start < input_string.len() {
            output.push(&input_string[start..]);
        }
        output
    }
}

The previous tokenize method can be modified to include the pre-tokenization step. This is essentially identical to the previous method, with the addition of the whitespace splitting and additional logic to keep track of the correct character offsets to return the sub-tokens:

impl BpeTokenizer for NaivePreSplitBpeTokenizer {
    fn get_merges_vocab(&self) -> &MergesVocab {
        &self.merges_vocab
    }

    fn tokenize<'a>(&self, input_text: &'a str) -> Vec<&'a str> {
        let whitespace_token = '▁';

        let (text, byte_mapping) = self.pre_process_text(input_text, whitespace_token);
        let split_texts = self.split_whitespace_punctuation(text.as_str(), whitespace_token);

        let mut output = Vec::new();
        let mut offset = 0;
        for split_text in split_texts {
            let mut symbols = SymbolArray::from_text(split_text);
            while let Some(best_pair_index) = symbols.find_best_merge(split_text, self) {
                symbols.merge_symbols(best_pair_index);
            }
            for symbol in symbols.symbols {
                output.push(
                    &input_text[byte_mapping[&(offset + symbol.start_byte)]
                        ..byte_mapping[&(offset + symbol.end_byte)]],
                );
            }
            offset += split_text.len();
        }
        output
    }
}

This implementation brings the expected complexity down from $O(N^{2})$ to $O(N)$, which is the best that can be achieved since the input needs to be scanned at least once to perform tokenization. It does, however, have drawbacks: the result is only an approximation of BPE applied to the full input, since merges can no longer cross pre-tokenization boundaries, and the quadratic worst case remains for long words or for inputs that are not separated by whitespace.

The following sections present two additional implementations that aim to reduce the computational cost of the naive algorithm without relying on a pre-tokenization step.

c. Priority Queue + Binary Search Tree implementation

The naive algorithm suffers from a worst-case $O(N^{2})$ complexity, caused by up to N iterations of FindBestPair and MergeBestPair. In general, an input may contain up to N-1 valid pairs that will be merged, so the number of iterations of the outer loop cannot be reduced.

If a single merge is performed at every iteration, the cost of MergeBestPair reduces to the time required to locate the symbols to merge from the best pair information. Symbols can be compared and ordered by their position, and storing the Symbols sequence in a binary search tree allows finding, removing and inserting elements in $O(\log(N))$ time, whereas removing or inserting an element in an array takes linear time.

FindBestPair can be optimized by noting that the symbol pair information does not need to be re-computed from scratch at each iteration, but can instead be updated after every merge. A merge only has a local impact on the merge information: only the symbols immediately preceding and following the merged pair are affected. We need to check whether the preceding symbol and the newly merged symbol form a valid pair, and similarly whether the newly merged symbol and the next symbol form a valid pair. We still need a data structure to store the symbol pair information, which must efficiently support retrieving (and removing) the best pair and inserting newly created pairs.

These requirements effectively describe a priority queue which will be used to build and maintain the set of SymbolPairs. Using a Min Heap allows performing both extract_min and insert operations in $O(\log(N))$ time.
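Rust's standard `BinaryHeap` is a max-heap, so min-heap behaviour requires reversing the ordering, either by wrapping elements in `std::cmp::Reverse` or by implementing `Ord` in reverse. The sketch below takes the second route on a hypothetical `RankedPair` type (the ordering of the article's own `SymbolPair` is defined in the full code at [3] and may differ), assuming lower scores should be merged first and breaking ties towards the leftmost pair:

```rust
use std::cmp::Ordering;

/// Hypothetical pair type: `score` is a merge rank (lower = merged first).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RankedPair {
    pub left: usize,
    pub right: usize,
    pub score: i64,
}

impl Ord for RankedPair {
    fn cmp(&self, other: &Self) -> Ordering {
        // Comparing `other` to `self` inverts the order, so the max-heap
        // `BinaryHeap` pops the lowest score first. Ties go to the leftmost
        // pair; comparing `right` last keeps the order consistent with `Eq`.
        other
            .score
            .cmp(&self.score)
            .then_with(|| other.left.cmp(&self.left))
            .then_with(|| other.right.cmp(&self.right))
    }
}

impl PartialOrd for RankedPair {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
```

Pushing pairs with scores 5, 2, 2 and 9 and popping them all returns the two score-2 pairs first (leftmost first), then the pairs with scores 5 and 9. Both `push` and `pop` run in $O(\log(N))$ time.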

By implementing the Symbols sequence as a binary search tree and the SymbolPairs as a priority queue, a worst-case complexity of $O(N(\log(N) + \log(N)))$ = $O(N\log(N))$ can be achieved, significantly better than the initial $O(N^{2})$ (each merge adds at most two pairs to the queue, so it never holds more than $O(N)$ pairs). The maintenance of a priority queue for the SymbolPairs to process is mentioned in the SentencePiece article [6] and is used in the optimized BPE implementation of the C++ SentencePiece library [7]. Algorithm 1 (BPE tokenization) becomes:

Lines 17 to 21 initialize both the Symbols and SymbolPairs data structures. At each iteration, the best SymbolPair is popped from the priority queue (line 23). The check on line 27 is required because pairs invalidated by a merge are not removed eagerly (if the first two elements of a triplet are merged, the pair formed by the last two elements is no longer valid); instead, pairs are popped from the SymbolPairs queue and their validity is checked at that point. If the pair is still valid, the symbols are merged and we check whether new pairs should be added (lines 31 and 32).

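Similarly to the SymbolArray, the Rust implementation of the Symbols binary search tree contains a method that mutates it to merge two symbols (strictly speaking, Rust's BTreeSet is a B-tree rather than a binary search tree, but it offers the same logarithmic lookups, insertions and removals):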

pub struct SymbolBTree {
    pub symbols: BTreeSet<Symbol>,
}

impl SymbolBTree {
    pub fn merge_symbols(&mut self, symbol_1: &Symbol, symbol_2: &Symbol) -> Symbol {
        self.symbols.remove(symbol_1);
        self.symbols.remove(symbol_2);
        let new_symbol = Symbol {
            start_byte: symbol_1.start_byte,
            end_byte: symbol_2.end_byte,
        };
        self.symbols.insert(new_symbol);
        new_symbol
    }
}

The Tokenizer contains methods to build and maintain a priority queue of SymbolPairs called “agenda”:

impl PriorityQueueBpeTokenizer {
    fn maybe_add_pair(
        &self,
        left_symbol: &Symbol,
        right_symbol: &Symbol,
        input_text: &str,
        agenda: &mut BinaryHeap<SymbolPair>,
    ) {
        let merged_text = &input_text[left_symbol.start_byte..right_symbol.end_byte];
        if let Some(&score) = self.merges_vocab.get(merged_text) {
            agenda.push(SymbolPair {
                left: *left_symbol,
                right: *right_symbol,
                score,
            })
        }
    }
}

impl BpeTokenizer for PriorityQueueBpeTokenizer {
    fn tokenize<'a>(&self, input_text: &'a str) -> Vec<&'a str> {
        let (text, byte_mapping) = self.pre_process_text(input_text, '▁');

        let mut symbols = SymbolBTree::from_text(text.as_str());
        let mut agenda: BinaryHeap<SymbolPair> = BinaryHeap::new();

        for (left_symbol, right_symbol) in symbols.iter().tuple_windows::<(&Symbol, &Symbol)>() {
            self.maybe_add_pair(left_symbol, right_symbol, text.as_str(), &mut agenda);
        }
        
        while let Some(symbol_pair) = agenda.pop() {
            let left_symbol = symbols.get(&symbol_pair.left).cloned();
            let right_symbol = symbols.get(&symbol_pair.right).cloned();

            if let (Some(left_symbol), Some(right_symbol)) = (left_symbol, right_symbol) {
                let new_symbol = symbols.merge_symbols(&left_symbol, &right_symbol);
                if let Some(next) = symbols.symbols.range(new_symbol..).nth(1) {
                    self.maybe_add_pair(&new_symbol, next, text.as_str(), &mut agenda);
                }
                if let Some(prev) = symbols.symbols.range(..new_symbol).next_back() {
                    self.maybe_add_pair(prev, &new_symbol, text.as_str(), &mut agenda);
                }
            }
        }

        let mut output = Vec::new();
        for symbol in symbols {
            output.push(
                &input_text[byte_mapping[&symbol.start_byte]..byte_mapping[&symbol.end_byte]],
            );
        }
        output
    }
}
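To experiment with this approach outside the full library code, the algorithm can be condensed into a standalone sketch. It makes several simplifying assumptions: no whitespace pre-processing, merges given as a hypothetical map from the merged string to its rank (lower ranks merged first), and agenda entries stored as `Reverse` tuples instead of a dedicated SymbolPair type. The names `Span`, `push_if_mergeable_btree` and `tokenize_pq_btree` are illustrative. Invalid pairs are discarded lazily when popped, as described above:

```rust
use std::cmp::Reverse;
use std::collections::{BTreeSet, BinaryHeap, HashMap};

/// A sub-token identified by its byte span; the derived `Ord` compares
/// `start_byte` first, which orders non-overlapping symbols by position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Span {
    start_byte: usize,
    end_byte: usize,
}

/// Queues (left, right) if their concatenation is a known merge. `Reverse`
/// makes the max-heap pop the lowest rank first (ties: leftmost pair).
fn push_if_mergeable_btree(
    left: Span,
    right: Span,
    text: &str,
    merges: &HashMap<&str, i64>,
    agenda: &mut BinaryHeap<Reverse<(i64, Span, Span)>>,
) {
    if let Some(&rank) = merges.get(&text[left.start_byte..right.end_byte]) {
        agenda.push(Reverse((rank, left, right)));
    }
}

fn tokenize_pq_btree<'a>(text: &'a str, merges: &HashMap<&str, i64>) -> Vec<&'a str> {
    // One symbol per character, kept ordered by position.
    let mut symbols: BTreeSet<Span> = text
        .char_indices()
        .map(|(start, c)| Span { start_byte: start, end_byte: start + c.len_utf8() })
        .collect();
    let mut agenda = BinaryHeap::new();
    let initial: Vec<Span> = symbols.iter().copied().collect();
    for pair in initial.windows(2) {
        push_if_mergeable_btree(pair[0], pair[1], text, merges, &mut agenda);
    }
    while let Some(Reverse((_, left, right))) = agenda.pop() {
        // Lazy invalidation: skip pairs whose symbols were merged away.
        if !symbols.contains(&left) || !symbols.contains(&right) {
            continue;
        }
        symbols.remove(&left);
        symbols.remove(&right);
        let merged = Span { start_byte: left.start_byte, end_byte: right.end_byte };
        symbols.insert(merged);
        // Check the new symbol against its predecessor and successor.
        if let Some(&prev) = symbols.range(..merged).next_back() {
            push_if_mergeable_btree(prev, merged, text, merges, &mut agenda);
        }
        if let Some(&next) = symbols.range(merged..).nth(1) {
            push_if_mergeable_btree(merged, next, text, merges, &mut agenda);
        }
    }
    symbols.iter().map(|s| &text[s.start_byte..s.end_byte]).collect()
}
```

With the ranks er=0, he=1, ll=2, lo=3, hell=4, low=5, hello=6, `hello` is tokenized as a single token and `lower` as [low, er]. The stale l o pair queued at the start is skipped once l l has been merged, because one of its symbols no longer exists in the set.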

While this is a significant improvement over the naive implementation, accessing the left and right symbols of the best pair for merge (and their predecessor/successor) could be optimized further: a total of 2 get operations on the binary search tree and the construction of 2 “throwaway” iterators to find the predecessor and successor likely impact the constant factor of the algorithm. The following section proposes a further optimization keeping the same asymptotic complexity but reducing the number of operations performed at each iteration.

d. Priority Queue + Linked List implementation

The previous implementation solves the problem of identifying the best symbol pair to merge without recalculating the entire list at each iteration, but is inefficient when selecting the symbols to merge. A merge involves the following operations:

  1. Select the left symbol
  2. Select the right symbol
  3. Create a new symbol, combining left and right
  4. Remove the left symbol
  5. Remove the right symbol
  6. Insert the new symbol
  7. Select the left symbol’s predecessor, check if it forms a valid pair with the new token
  8. Select the right symbol’s successor, check if it forms a valid pair with the new token

These operations are a natural fit for a linked list: one can easily access predecessors and successors and replace a sequence of consecutive nodes by a new element. If the SymbolPair element contains pointers to its left and right nodes, one can even skip scanning the linked list to find them, providing $O(1)$ complexity for all of the operations above.

Linked lists are, however, deceptively simple: they are notoriously difficult to implement in Rust while both satisfying the borrow checker and offering a high level of performance. An excellent article [8] highlights the challenges of implementing linked lists in Rust. Conversely, Rust's Vec growable array has been thoroughly optimized, and the official Rust documentation [9] recommends it over linked lists because array-based containers make better use of the CPU cache.

The following implementation emulates a linked list by storing the Symbol nodes in a Rust Vec. This is feasible because, although merges create new nodes, each new node replaces two adjacent ones, so the list only shrinks after its construction. Our Symbol data structure is modified to contain pointers to the previous and next nodes (effectively just positions in the “linked list/vec”). These position pointers are of type isize, where a value of -1 indicates that a given Symbol has no predecessor/successor (first or last node in the list).

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct Symbol {
    pub start_byte: usize,
    pub end_byte: usize,
    pub prev: isize,
    pub next: isize,
    pub size: usize,
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct SymbolPair {
    pub left: isize,
    pub right: isize,
    pub score: i64,
    pub pair_size: usize,
}

We have also added a size field to both Symbol and SymbolPair. Following a merge, we delete the right Symbol and update the left Symbol in place to represent the combined Symbol (so that the list never grows), and the neighbouring pointers are updated accordingly. However, there is no way to find all SymbolPairs that contain a given Symbol (the lookup only works in the other direction). This means that after a merge, some SymbolPairs still in the agenda may point to a position whose Symbol has changed: its size increased as a result of combining two symbols. The size information in Symbol and SymbolPair serves as a validation step when popping a SymbolPair from the agenda: if the size recorded in the SymbolPair is smaller than the sum of the sizes of its current left and right Symbols, the SymbolPair is stale and is ignored (the next SymbolPair is popped). Since the size stored at a given position can only grow, checking for inequality, as the code below does, is equivalent. The size of every Symbol is initialized to 1 when the linked list is constructed from the characters of the text to tokenize.

[Figure: Linked list symbol merge]

The data structure holding the Symbols is now a list implemented on top of a Rust Vec, in which removed nodes are replaced by None. It implements a method that merges two symbols in place, given their positions and the expected pair size; the merge is executed only if the size validation described above succeeds.

pub struct SymbolList {
    symbols: Vec<Option<Symbol>>,
}

impl SymbolList {
    pub fn merge_symbols(
        &mut self,
        symbol_1_index: usize,
        symbol_2_index: usize,
        size_validation: usize,
    ) -> Option<Symbol> {
        if let (Some(left_symbol), Some(right_symbol)) =
            (self.symbols[symbol_1_index], self.symbols[symbol_2_index])
        {
            if left_symbol.size + right_symbol.size != size_validation {
                return None;
            }
            if right_symbol.next != -1 {
                if let Some(next_next) = self.symbols.get_mut(right_symbol.next as usize).unwrap() {
                    next_next.prev = symbol_1_index as isize;
                }
            }
            let new_symbol = Symbol {
                start_byte: left_symbol.start_byte,
                end_byte: right_symbol.end_byte,
                prev: left_symbol.prev,
                next: right_symbol.next,
                size: left_symbol.size + right_symbol.size,
            };
            self.symbols[symbol_2_index] = None;
            self.symbols[symbol_1_index] = Some(new_symbol);
            Some(new_symbol)
        } else {
            None
        }
    }
}

The tokenize method looks very similar to the previous implementations. The only change to the maybe_add_pair method is that it computes the size of a SymbolPair as the sum of the sizes of its Symbols; it is omitted below (see the full code at [3]). Accessing the left and right symbols of a SymbolPair, along with their predecessor and successor, is now done by directly reading the left and right fields of the SymbolPair and the prev and next fields of the Symbols.

impl BpeTokenizer for PriorityQueueBpeLLTokenizer {
    fn tokenize<'a>(&self, input_text: &'a str) -> Vec<&'a str> {
        let (text, byte_mapping) = self.pre_process_text(input_text, '▁');

        let mut symbols = SymbolList::from_text(text.as_str());
        let mut agenda: BinaryHeap<SymbolPair> = BinaryHeap::new();

        for symbol_index in 1..symbols.len() {
            self.maybe_add_pair(
                symbol_index as isize - 1,
                symbol_index as isize,
                text.as_str(),
                &symbols,
                &mut agenda,
            );
        }

        while let Some(symbol_pair) = agenda.pop() {
            let left_symbol_index = symbol_pair.left;
            let right_symbol_index = symbol_pair.right;
            if left_symbol_index != -1 && right_symbol_index != -1 {
                let new_symbol = symbols.merge_symbols(
                    left_symbol_index as usize,
                    right_symbol_index as usize,
                    symbol_pair.pair_size,
                );
                if let Some(new_symbol) = new_symbol {
                    self.maybe_add_pair(
                        new_symbol.prev,
                        left_symbol_index,
                        text.as_str(),
                        &symbols,
                        &mut agenda,
                    );
                    self.maybe_add_pair(
                        left_symbol_index,
                        new_symbol.next,
                        text.as_str(),
                        &symbols,
                        &mut agenda,
                    );
                }
            }
        }

        let mut output = Vec::new();
        for symbol in symbols {
            if let Some(symbol) = symbol {
                output.push(
                    &input_text[byte_mapping[&symbol.start_byte]..byte_mapping[&symbol.end_byte]],
                );
            }
        }
        output
    }
}
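A condensed, self-contained sketch of the same idea is given below. It departs from the code above in a few hypothetical ways: `Option<usize>` replaces the `-1` sentinel for missing neighbours, agenda entries are `Reverse` tuples `(rank, left, right, pair_size)`, merges are given as a map from the merged string to its rank, and no whitespace pre-processing is applied. The names `Node`, `push_if_mergeable_list` and `tokenize_pq_list` are illustrative. The size check plays the same role as `size_validation` in `merge_symbols`:

```rust
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};

/// Node of the Vec-backed doubly linked list; `None` replaces the -1 sentinel.
#[derive(Debug, Clone, Copy)]
struct Node {
    start_byte: usize,
    end_byte: usize,
    prev: Option<usize>,
    next: Option<usize>,
    size: usize, // number of characters merged into this node
}

/// Agenda entry: (rank, left index, right index, expected pair size).
type Entry = Reverse<(i64, usize, usize, usize)>;

/// Queues (left, right) if both nodes exist and their concatenation is a known merge.
fn push_if_mergeable_list(
    left: Option<usize>,
    right: Option<usize>,
    nodes: &[Option<Node>],
    text: &str,
    merges: &HashMap<&str, i64>,
    agenda: &mut BinaryHeap<Entry>,
) {
    if let (Some(l), Some(r)) = (left, right) {
        if let (Some(left_node), Some(right_node)) = (nodes[l], nodes[r]) {
            let merged = &text[left_node.start_byte..right_node.end_byte];
            if let Some(&rank) = merges.get(merged) {
                agenda.push(Reverse((rank, l, r, left_node.size + right_node.size)));
            }
        }
    }
}

fn tokenize_pq_list<'a>(text: &'a str, merges: &HashMap<&str, i64>) -> Vec<&'a str> {
    let chars: Vec<(usize, char)> = text.char_indices().collect();
    let n = chars.len();
    // One node per character, linked to its neighbours by index.
    let mut nodes: Vec<Option<Node>> = chars
        .iter()
        .enumerate()
        .map(|(i, &(start, c))| {
            Some(Node {
                start_byte: start,
                end_byte: start + c.len_utf8(),
                prev: i.checked_sub(1),
                next: if i + 1 < n { Some(i + 1) } else { None },
                size: 1,
            })
        })
        .collect();
    let mut agenda = BinaryHeap::new();
    for i in 1..n {
        push_if_mergeable_list(Some(i - 1), Some(i), &nodes, text, merges, &mut agenda);
    }
    while let Some(Reverse((_, l, r, pair_size))) = agenda.pop() {
        let (left, right) = match (nodes[l], nodes[r]) {
            (Some(left), Some(right)) => (left, right),
            _ => continue, // one side was removed by an earlier merge
        };
        // Stale entry: one of the nodes has grown since the pair was queued.
        if left.size + right.size != pair_size {
            continue;
        }
        // The successor of `right` now points back to `l`.
        if let Some(next) = right.next {
            if let Some(node) = nodes[next].as_mut() {
                node.prev = Some(l);
            }
        }
        // The left slot holds the merged node, the right slot is emptied.
        nodes[l] = Some(Node {
            start_byte: left.start_byte,
            end_byte: right.end_byte,
            prev: left.prev,
            next: right.next,
            size: left.size + right.size,
        });
        nodes[r] = None;
        push_if_mergeable_list(left.prev, Some(l), &nodes, text, merges, &mut agenda);
        push_if_mergeable_list(Some(l), right.next, &nodes, text, merges, &mut agenda);
    }
    nodes.iter().flatten().map(|node| &text[node.start_byte..node.end_byte]).collect()
}
```

With merges bc=0 and ab=1, tokenizing `abc` first merges b c; the queued a b entry is then rejected by the size check (1 + 2 differs from the recorded 2), giving [a, bc]. Using `Option<usize>` lets the compiler enforce the first/last node checks that the `-1` convention leaves to the programmer, at the cost of 8 extra bytes per pointer on 64-bit targets compared with an `isize`.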

3. Benchmarks

So far, we have compared the implementations at the level of theoretical complexity. In practice, the constants hidden in the asymptotic behaviour can have a significant impact. This section reports experimental results on samples of varying size taken from Shakespeare’s Hamlet [10]. The time taken to tokenize the first 1, 10, 100 or 1000 lines of the play is recorded for the four implementations presented previously:

| Input size (lines) | Naive | Naive (pre-split) | Priority Queue + Binary Search Tree | Priority Queue + Linked List |
|---|---|---|---|---|
| 1 | 27.1 µs | 9.1 µs | 14.7 µs | 8.6 µs |
| 10 | 162 µs | 26.5 µs | 54.4 µs | 24.7 µs |
| 100 | 85 ms | 0.57 ms | 1.76 ms | 0.68 ms |
| 1000 | 18.6 s | 8.6 ms | 33.9 ms | 12.8 ms |

This confirms the expectations derived from the analysis of the algorithms: the naive approach becomes impractical for inputs longer than 100 lines (18.6 s for 1000 lines). Meanwhile, all other approaches remain in the range of a few to a few tens of milliseconds even for inputs that are 1000 lines long, confirming the asymptotic benefits of the data structures investigated. The two exact solutions leveraging priority queues also achieve execution times within the same order of magnitude as the pre-split approximation. Finally, even though they share the same asymptotic complexity, the linked-list implementation for the Symbols outperforms the binary search tree version (12.8 ms versus 33.9 ms for 1000 lines).

These results, along with asymptotic trend-lines, can be seen in the figure below:

[Figure: BPE implementations benchmark]

Conclusion

Byte Pair Encoding is a tokenization method that is in essence very simple and effective as a pre-processing step for modern machine learning pipelines. Although it is widely used in production libraries, its actual implementation can vary significantly from one source to another. This article gives an overview of some key implementations of the algorithm that the reader may encounter and provides a high-level intuition behind their design. It illustrates the impact that the choice of a data structure can have on a real NLP application used every day by thousands of data scientists and machine learning engineers. The priority queue / linked list variant of Byte Pair Encoding has been implemented in the rust-tokenizers library [11], along with other modern tokenization algorithms.

References