How can I make this run faster?

The script trains a tokenizer (the minbpe byte-pair-encoding algorithm used for LLM tokenization) on the input text and saves the vocab to disk for visualization. How can I make this script run faster? I am new to Numba.

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()
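For context, the snippet calls `get_stats` and `merge` without showing them; assuming they follow minbpe's conventions, a minimal sketch of what they presumably do is:

```python
def get_stats(ids, counts=None):
    # count consecutive pairs in `ids`, optionally accumulating
    # into an existing dict (as the training loop above does)
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # return a new list with every occurrence of `pair` replaced by `idx`
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
```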

Numba may not be the right tool for this problem. Numba excels at numeric computation and doesn’t really target text processing.

Yeah, as Nelson said, and as far as I know, Numba's unicode implementation can actually be slower than pure CPython in most cases when the code depends heavily on unicode operations.

Then what other ways are there to make this script run faster?

I can also confirm that you should not normally use Numba for string-operation-heavy tasks. However, if you know your craft well, you can use Numba to work on the raw bytes, which can be very efficient. I once wrote a parser with Numba and was more than happy with its performance. Unfortunately, you only posted a non-self-contained part of your code, so it's hard to say whether that's an option for you.

If you're really constrained by performance and you think moving away from native CPython could help, looking into PyPy may be an option. I vaguely remember it really shining for regex-heavy workloads at some point. But I'm not sure PyPy works with the popular auto-differentiation engines you may be using, since you asked about an LLM tokenizer …
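Before reaching for another runtime, though, there may be an algorithmic win available in pure Python: the regex split produces many duplicate chunks (common words recur constantly), so you can deduplicate them with `collections.Counter` and weight the pair counts by multiplicity, merging each distinct chunk only once per iteration. A rough sketch (the `train_dedup` and `_merge` names here are mine, just mirroring your snippet; up to tie-breaking it should learn the same merges):

```python
from collections import Counter

def _merge(ids, pair, idx):
    # replace every occurrence of `pair` with token `idx`
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_dedup(text_chunks, vocab_size):
    # deduplicate: each distinct chunk is counted and merged once,
    # with its pair counts weighted by how often the chunk occurred
    chunk_counts = Counter(text_chunks)
    ids = [list(ch.encode("utf-8")) for ch in chunk_counts]
    weights = list(chunk_counts.values())

    merges = {}
    for i in range(vocab_size - 256):
        stats = Counter()
        for chunk_ids, w in zip(ids, weights):
            for pair in zip(chunk_ids, chunk_ids[1:]):
                stats[pair] += w
        if not stats:
            break  # nothing left to merge
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = [_merge(chunk_ids, pair, idx) for chunk_ids in ids]
        merges[pair] = idx
    return merges
```

On natural-language text where the same words repeat thousands of times, this cuts the per-merge work from "number of chunks" to "number of distinct chunks", which can be a large constant factor with no change to the algorithm's output.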