This script trains a tokenizer (the minbpe algorithm for LLM tokenization) on the input text and saves the vocab to disk for visualization. How can I make this script run faster? I am new to Numba.
def train(self, text, vocab_size, verbose=False):
    """Learn byte-pair merges from ``text`` and store them on the instance.

    The text is split into chunks with ``self.compiled_pattern``; each chunk
    becomes a list of its UTF-8 byte values. Then, up to
    ``vocab_size - 256`` times, the most frequent adjacent pair of ids is
    replaced everywhere by a newly minted token id (256, 257, ...).

    Args:
        text: Training text (``str``).
        vocab_size: Target vocabulary size. Must be at least 256, the number
            of single-byte base tokens.
        verbose: When true, print one line per merge performed.

    Raises:
        AssertionError: If ``vocab_size`` is below 256.

    Side effects:
        Sets ``self.merges`` (``(int, int) -> int``, used in ``encode()``)
        and ``self.vocab`` (``int -> bytes``, used in ``decode()``).

    If the text runs out of adjacent pairs before ``vocab_size`` is reached,
    training stops early and the resulting vocabulary is smaller than
    requested.
    """
    assert vocab_size >= 256
    num_merges = vocab_size - 256

    # Split the text into chunks; merges never cross a chunk boundary
    # because every chunk is processed as its own id list.
    text_chunks = re.findall(self.compiled_pattern, text)
    ids = [list(ch.encode("utf-8")) for ch in text_chunks]

    merges = {}  # (int, int) -> int
    vocab = {idx: bytes([idx]) for idx in range(256)}  # idx -> bytes

    for i in range(num_merges):
        # Count every consecutive pair across all chunks. get_stats updates
        # `stats` in place, so the counts accumulate over the chunks.
        # NOTE(review): the counts are rebuilt from scratch on every merge;
        # this loop is the first place to look when optimizing for speed.
        stats = {}
        for chunk_ids in ids:
            get_stats(chunk_ids, stats)

        # Nothing left to merge (empty text, or every chunk is down to a
        # single id). max() on an empty dict would raise ValueError and
        # discard all merges learned so far, so stop cleanly instead.
        if not stats:
            break

        # Most frequent pair becomes the next token.
        pair = max(stats, key=stats.get)
        idx = 256 + i

        # Replace all occurrences of `pair` in every chunk with `idx`.
        ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]

        merges[pair] = idx
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        if verbose:
            print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

    self.merges = merges  # used in encode()
    self.vocab = vocab  # used in decode()