# failed_discord_bot/llama_tokenizer_lite.py
# (Recovered from a web paste; the original listing reported 225 lines /
# 644 KiB, so the embedded vocab/merge data strings below are truncated
# in this copy and the file's indentation was lost.)

import base64
import struct
import heapq
import codecs
import time
from queue import PriorityQueue
class LlamaTokenizerLite:
def __init__(self):
self.vocab_by_id = {}
self.vocab_by_string = {}
self.merges = {}
vocab_base64 = "PHVuaz4KPHM+Cjwvcz4KPDB4MDA+CjwweDAxPgo8MHgwMj4KPDB4MDM+CjwweDA0Pgo8MHgwNT4KPDB4MDY+CjwweDA3Pgo8MHgwOD4KPDB4MDk+CjwweDBBPgo8MHgwQj4KPDB4MEM+CjwweDBEPgo8MHgwRT4KPDB4MEY+CjwweDEwPgo8MHgxMT4KPDB4MTI+CjwweDEzPgo8MHgxND4KPDB4MTU+CjwweDE2Pgo8MHgxNz4KPDB4MTg+CjwweDE5Pgo8MHgxQT4KPDB4MUI+CjwweDFDPgo8MHgxRD4KPDB4MUU+CjwweDFGPgo8MHgyMD4KPDB4MjE+CjwweDIyPgo8MHgyMz4KPDB4MjQ+CjwweDI1Pgo8MHgyNj4KPDB4Mjc+CjwweDI4Pgo8MHgyOT4KPDB4MkE+CjwweDJCPgo8MHgyQz4KPDB4MkQ+CjwweDJFPgo8MHgyRj4KPDB4MzA+CjwweDMxPgo8MHgzMj4KPDB4MzM+CjwweDM0Pgo8MHgzNT4KPDB4MzY+CjwweDM3Pgo8MHgzOD4KPDB4Mzk+CjwweDNBPgo8MHgzQj4KPDB4M0M+CjwweDNEPgo8MHgzRT4KPDB4M0Y+CjwweDQwPgo8MHg0MT4KPDB4NDI+CjwweDQzPgo8MHg0ND4KPDB4NDU+CjwweDQ2Pgo8MHg0Nz4KPDB4NDg+CjwweDQ5Pgo8MHg0QT4KPDB4NEI+CjwweDRDPgo8MHg0RD4KPDB4NEU+CjwweDRGPgo8MHg1MD4KPDB4NTE+CjwweDUyPgo8MHg1Mz4KPDB4NTQ+CjwweDU1Pgo8MHg1Nj4KPDB4NTc+CjwweDU4Pgo8MHg1OT4KPDB4NUE+CjwweDVCPgo8MHg1Qz4KPDB4NUQ+CjwweDVFPgo8MHg1Rj4KPDB4NjA+CjwweDYxPgo8MHg2Mj4KPDB4NjM+CjwweDY0Pgo8MHg2NT4KPDB4NjY+CjwweDY3Pgo8MHg2OD4KPDB4Njk+CjwweDZBPgo8MHg2Qj4KPDB4NkM+CjwweDZEPgo8MHg2RT4KPDB4NkY+CjwweDcwPgo8MHg3MT4KPDB4NzI+CjwweDczPgo8MHg3ND4KPDB4NzU+CjwweDc2Pgo8MHg3Nz4KPDB4Nzg+CjwweDc5Pgo8MHg3QT4KPDB4N0I+CjwweDdDPgo8MHg3RD4KPDB4N0U+CjwweDdGPgo8MHg4MD4KPDB4ODE+CjwweDgyPgo8MHg4Mz4KPDB4ODQ+CjwweDg1Pgo8MHg4Nj4KPDB4ODc+CjwweDg4Pgo8MHg4OT4KPDB4OEE+CjwweDhCPgo8MHg4Qz4KPDB4OEQ+CjwweDhFPgo8MHg4Rj4KPDB4OTA+CjwweDkxPgo8MHg5Mj4KPDB4OTM+CjwweDk0Pgo8MHg5NT4KPDB4OTY+CjwweDk3Pgo8MHg5OD4KPDB4OTk+CjwweDlBPgo8MHg5Qj4KPDB4OUM+CjwweDlEPgo8MHg5RT4KPDB4OUY+CjwweEEwPgo8MHhBMT4KPDB4QTI+CjwweEEzPgo8MHhBND4KPDB4QTU+CjwweEE2Pgo8MHhBNz4KPDB4QTg+CjwweEE5Pgo8MHhBQT4KPDB4QUI+CjwweEFDPgo8MHhBRD4KPDB4QUU+CjwweEFGPgo8MHhCMD4KPDB4QjE+CjwweEIyPgo8MHhCMz4KPDB4QjQ+CjwweEI1Pgo8MHhCNj4KPDB4Qjc+CjwweEI4Pgo8MHhCOT4KPDB4QkE+CjwweEJCPgo8MHhCQz4KPDB4QkQ+CjwweEJFPgo8MHhCRj4KPDB4QzA+CjwweEMxPgo8MHhDMj4KPDB4QzM+CjwweEM0Pgo8MHhDNT4KPDB4QzY+CjwweEM3Pgo8MHhDOD4KPDB4Qzk+CjwweENBPgo8MHhDQj4KPDB4Q0M+CjwweENEPgo8MHhDRT4KPDB4Q0Y+CjwweEQwPgo8MHhEMT4KPDB4
RDI+CjwweEQzPgo8MHhEND4KPDB4RDU+CjwweEQ2Pgo8MHhENz4KPDB4RDg+CjwweEQ5Pgo8MHhEQT4KPDB4REI+CjwweERDPgo8MHhERD4KPDB4REU+CjwweERGPgo8MHhFMD4KPDB4RTE+CjwweEUyPgo8MHhFMz4KPDB4RTQ+CjwweEU1Pgo8MHhFNj4KPDB4RTc+CjwweEU4Pgo8MHhFOT4KPDB4RUE+CjwweEVCPgo8MHhFQz4KPDB4RUQ+CjwweEVFPgo8MHhFRj4KPDB4RjA+CjwweEYxPgo8MHhGMj4KPDB4RjM+CjwweEY0Pgo8MHhGNT4KPDB4RjY+CjwweEY3Pgo8MHhGOD4KPDB4Rjk+CjwweEZBPgo8MHhGQj4KPDB4RkM+CjwweEZEPgo8MHhGRT4KPDB4RkY+CuKWgeKWgQriloF0CmVyCmluCuKWgWEKZW4Kb24K4paBdGgKZXMK4paB4paB4paB4paBCuKWgXMK4paBZAphdApvcgphbgriloFjCmlzCnJlCml0CuKWgXRoZQphcgpsZQriloF3CuKWgXAKb3UKYWwK4paBZgriloFtCmVkCuKWgW8K4paBYgpvbQppb24KaW5nCmljCmFzCmVsCmVudAriloFpbgriloFoCm5kCmV0CuKWgWwK4paBbgpzdAriloF0bwpjaAriloFJCnJvCuKWgeKWgeKWgeKWgeKWgeKWgeKWgeKWgQppbAriloFvZgpkZQpjdAriloEoCmFtCuKWgUMK4paBZGUK4paBUwriloF1CuKWgUEK4paBXAriloFlCuKWgWFuZAriloFUCm9sCuKWgXYKaW0Kb3QKYWQKdXQK4paBZwplbQp1cgppZAriloEqCmlnCnJhCuKWgXJlCuKWgWlzCnF1Cm93CuKWgU0KZXN0CuKWgXkKc2UKdmUKY2UKaWUKdW4K4paBUAriloFCCmFnCnVsCuKWgT0KaGUKZW5kCm9kZQp0ZXIKbWVudApvcwriloFECmlmCmF0aW9uCuKWgWZvcgriloFyCuKWgUwK4paBeW91CuKWgWJlCmx5CnZlcgphYgp0ZQriloFpdAriloFvbgpyaQp1cwriloEiCuKWgXdoCuKWgWNvbgriloFICuKWgXN0CmlyCuKWgUUK4paBRgpjawriloFhbgp0aAplZwpheQppdGgK4paBUgppc3QKYW5kCuKWgXRoYXQK4paBYWwK4paBJAriloEjCm9kCnVtCuKWgVcKaHQKY29kZQriloFHCmF0ZQplc3MK4paBTgplcmUKcHAK4paBYXMK4paBc2UK4paBcHJvCuKWgXdpdGgKcGUK4paBawplcnMKcHQKKTsKbG8K4paB4paB4paB4paB4paBCuKWgWNvbQphbWUK4paBYAriloFDb20KaWEKYW50CuKWgWxhCuKWgXsK4paBZW4KY3Rpb24K4paBZXgKbGQKdWIK4paBagpsYQp1ZQriloFKCmljaAriloFkbwriloFPCuKWgXF1Cml2Cm9ydAphcnQK4paBdW4K4paBIyMK4paBdGhpcwprZQriloFoYQriloEtCm91dAriloFUaGUK4paBbm90CuKWgW5lCmlsbAriloFsZQpjaQpyb20KaW5lCi8vCm9wCmVnaW4K4paBQ29tbWVudAriloHiloHiloHiloHiloHiloHiloHiloHiloHiloHiloHiloHiloHiloHiloHiloEKYmVnaW4K0YHRggphc3MKaXoKKS4Kb2cK4paB0L8K4paBb3IK4paBd2FzCuKWgWF0Cm91cgriloFpCmFpbgriloFLCtC90LAK4paBVgpnZQriloFzdQphcAphZ2UKb3VsZApuZQphdgp4dApvcmUKaWxlCi0tCuKWgdCyCuKWgWJ5CmxpCmF0aArRgNCwCmJlcgphY2gKYWxsCuKWgVRoCnVsdAriloF9CuKWgVUK4paBdXMK4paBegp1c3QK4paBaGF2
ZQpsaWMK0L3QuAriloFjYW4KdHIKY29tCiksCuKWgUluCmluZAplbGwK4paBZnJvbQrQvtCyCnRvCuKWgVsKYWJ
merges_binary = "r3SxdLB0tnSzdLR0r3SydLB0tHS1dLR0BAG6dK90ggGwdLd0r3S3dK90uXSydLF0tXS2dLJ0tHSvdLt0s3S3dLZ0sHSzdLF0BAFiAQoBsHSvdBAGsnS2dLh0sHSvdMV0r3S+dLV0vHSydLh0r3TAdK90vXSwdLl0r3S1dK90wnS1dL10WQK0dLN0CQEGAb90s3RhA7N0u3SydLd0sHS4dAgBsXSwdFEC2gG0dK90BgGvdLp0tHS5dLB0sXSvdLh0r3S0dLd0sXQEAbV0r3QFArt0unSvdM50tnS1dLN0uHQgAcB0r3TOA7l0sHS7dLF0r3TKdLJ0vXSvdNN0DgGwdK90NwGvdM90r3S8dK901XSvdNF0r3SwdAcBKwGBAbl0r3SIAa9013S1dLh0r3TGdLN0vXS1dLF0snS5dLx0sXSvdL90sHS9dLx0tnSzdLl0r3TqdLN0v3S2dLJ0bAGwdK90FAHaAbd0r3QTAfN0vHS1dMV0r3TkdAsBsXSwdC8Br3TDdLd0sHTGdLB0u3SwdLN0sHS8dLR0r3TldK907XSydL90vHS4dK904nS6dLB0CAG5dLB0KwGNAbB0tXQ3AXMBtnSxdAUB7ANRAvcFsXS9dCgBtXS3dK906HSzdMB0DwEjAasICQFTUbR0snTVMR0BEAGlBrZ0r3SuBa90tnSvdOt0VwEbAeBgvHSvdJUYIQGwdK90kwO4dMN0WQG2dMZ0BQGydMJ0sXSwdNoBsXSvdBUBIAG0dK90CQG2dLN0vHS3dK901HQZAbp0r3Q0BRIBCQEWBbR0r3QXAq90/nQNAbF0r3QvAbN0tnSvdON0r3T3dLt0x3QHAbR0r3QRAbF0unSwdL90snTDdBUBunSzdIIBr3TudBMBsXSzdC8BEQG5dLJ0KwEEAZcJCgEPAa90cxYHAbh0r3QcAa908nSvdPF0tXS5dLx0vXSvdAR1unSxdFcENwFUS7B0u3RkAa90AnUPAbB0snRzAQsBt3SwdH0Dr3T0dAUBsHSwdBQBvnS+dAcBt3SvdCYBDQGwdK90WAEaATMBIAK1dK90AwMZAYUBFTC6dK907Qm+dLB0r3THdAUBt3SwdPAIvnSxdMl08HS4dLV0EgEiARYFvXSvdP4BOgGwdLJ07AOvdAB1OwEiAXUMvXSvdPMFs3SydBEBsXSydFECLQGydK90sQGvdNh0QQG0dK90CAE4ASMBu3TVMUEB3HSvdN8CuHS5dLx0wnSvdOZ0uHSydLx0sHSvdA91JQG6dLN0MQEOAbV0r3RLB690/XQPD7x0r3RTAbN0xnQQAbF0tXThCBcBsXSydOEIPgG0dK90XAGMAfF0r3TlCAQBYFkKARMBr3RWBcd0sHQqAbJ0r3QuCa90y3QbAbF0tXRJAUMBYgHyAbB0r3QoBi4BRwG2ArF0r3Q1BS4BsHSvdOQBNQG4dLN0hQItAbB0r3QYAbt0s3QzAb10tnQiAQYBsHSzdOQB2nTadLV0vnSDAQYBsHTvFKYBZgFUBygBr3QhTpMD7xScbQYBwnTMAe907HQmAbd0snR9A7N04HTJdMF0tXS/dK90DHUgAbZ0r3QQARkBJgE8LLd0r3ReKwcBsXSvdA8BGwG2dLV0TAGvdLN0Age0dLJ0BgGvdBB14XTWdK90C3W/dLB0DQG8dK90YgiydL50XwGwdLJ03wEbAa4BwxO5dLV0UkG0dLB0snTGdNx0sXQQAbB0tXQUATUBsHSzdBgBy3TLdK909nQhAcN0r3RJBrh0s3QPAbp0snSCAed01nSTA7Z0wnQFATICunSydDEBHAG4dLJ0hQJDAbp0r3RFBWABsXS8dGgHr3TZdK90DXU+Abd0r3R3Aa904HR3AbF0vHQvASoBRwW/AVkBI1GwdK90fETsAbt0uHQlAeF03XQSAREBkRa0dK90/AuxdLZ0VwS9dLt0IgHJdMR0MgG0dK90HQMGAbl0s3QrAScBuHSwdIU
CHQHIAZAFIgFPOb10r3ReDNB09nSxdLV0r3QJdXIBGAG+MLB0snQ5AmcBsXS1dC8BEgG6dK90MQGvArF0sHQ4AU8BkAF1BLF0BgGxdLN0UQKvdNt0BwEUATQCsHSvdFYC2gG9dK90RgENAbp0r3RNA690DnU/AbR0r3S4Cq9073QPAbJ0snStA30BsHSzdBQBBAG2dK90/QFXBLR0u3QJARABuXS1dHsVFQHDdLN0+QMXAbl0snR7FSoBsHSvdGIBIQFJASkFsXSvdFsQtXS7dOJ01HQaAbZ0r3QuAkwBsHS8dBQBnAG2dL50BQEyAsd0snSAARABx3QJAb90tXRhAxEBt3SydL8H/3TQdDwCsHS+dBgBDgELATwBt3SvdN8KtXTHdBABvXS1dN4GwQK2dMV0BQGydMd0vnS2dCYBsHSydFgBQQG4dK90JwG+dLp0snS7dD4BKwG7Abl0r3RmAwcBtnSvdBcB2gHAdK90aQG8dLl0vnS3dBUBsHSzdHMBnAiwdMJ0GAHhdNB0vQa2dMB0BQG+dLh0uAGwdLN0WQERAb90snRhAwgBt3SwdL8H53TQdA0BtXSvdEICt3S1dCYBsXSydC8BynTJdNYKBQG3dCwCtnS8dFsBt3SzdAsBr3TNdLJ0vHS1dMZ053TfdBF10HQOAQUBPAG2dK90oAIeAcN0r3RNBRkBsHSvdMECHgGwdK907AO0dLF0BwG5dK90SAFMAbR0vHSTa1cB2QFuAbZ04GBMAa90gh/jINp0zXTKARcBsHSydBQBBwGFAooBuHSvdPEBwHTAdLN0tXQLAdUxVgEjAQEtCQFGAbB0s3TsA0EBtnSvdAUBsQF9AwUNt3S4dNEBr3TddHkBtAGvdGwSIgGwdLV07AMJAbF0tXRRAhoBFwHkDLZ0r3ReAx4BsnSvdI8Cr3QbddR0xHSvdNB0wHSxdKcBuHSzdBwBu3S7dBsBKwEdBbl0tXRmAy0Bs3SvdOwBUQG3dGwBCwGvdLICLAG6dLB0ggEsBTgB5nQKAgcBlwEXDL50r3SkAz0BsXSvdFcDJQGwdLN0WgEHAb10r3Q6ATICsXSydDgBDgEnATwBuHSvdAAYv3S2dA8BHwGTAbl0snTmJFsBtnSzdAUBBwHCdK90cgFBAbF0r3QsARwBcAHxAcN0wXTBdGkE4QizB7F0vnS5AbN0x3QaAQUB1AS2dK90IgISAWECegGxdBYFUQKvdAUF53TddP901nRYAbZ0t3QFAfh03XS4dLh0WwHFdLN0KQNPAbR0s3QSFN502HSiDrF0vnRJAQkBsHS1dOQBXAGsATUS1TGBRiMBDgGzdK90oxcXAcN0snTPAhUBIwHMEgkBs3TVMb10snTfdOF03wGxdL90LAEtAbV0r3ShAUUBHAGDC7h0r3QXA690IHVQAbR0tnQRAa90BXUIAVoBiQ+wdBkBJALdGMd0r3QHBTgb1nSvdN0Bs3S+dBUBSwE4Ar10s3RMBcYFsHTDdJwBr3QjdSoBEwFTHLd0r3RgWT4BWAH2AbB0r3TlBTcBtnS5dAUBEgJFAnVVLAKvdBZYGQHFAYRuuHSvdNs20gGwdLN0jwPsdNZ0oQHFdLh0VAE7Abp0r3SrBUoBLAHJBrF0r3SRAk0BsHSzdDcBGwG3dLV0dwEGAcd0nwEjAb501TH4dNZ0ySOTa7F0UwJcAb90vHRhA7B0u3S8dL90nQTeBq4FvXTAdCsCFAG3dLZ0CwGQAbF0unS6AhsBv3S1dLAC+HQadS4BtXSvdMEEu3S4dGwBtXSvdDMBIAHkAXUBsHSvdIoCsXSxdHMcs3S7dHYBuXS8dD4BvnSvdBID7HTQdMp01HQgAcJ0r3THAsV0sHQQAcN0tXTPAkEBLwE/A7F0r3RWAQUBw3SwdM8CWwG4dLN0JwEvAbZ0t3T9AbV0wnS3AbB0Dw+yAa90IgOnAbR0s3QRASABSQFLCLF0r3TBARoBuHSvdDwCLgEpA8QBxXSvdMo
F/3TddK90FnW2dMN0RwG6dLV0ggGCAQUBEAa2dLF05AhFARcBgwu2dK90qwYZAeMBryJSQQ0BBQGZAbZ0r3SDAm
self.initialize_llama_tokenizer(vocab_base64, merges_binary)
@staticmethod
def base64decode(encoded_string):
    """Decode a base64-encoded string into its raw bytes."""
    decoded_bytes = base64.b64decode(encoded_string)
    return decoded_bytes
def get_merge_identifier_string(self, first_token_id, second_token_id):
    """Build the space-separated merge-table key for a pair of token ids."""
    return f"{self.vocab_by_id[first_token_id]} {self.vocab_by_id[second_token_id]}"
def decompress_merges(self, merges_binary):
    """Rebuild the BPE merge table from its base64/packed-uint16 form.

    The decoded bytes are a flat stream of little-endian uint16 token
    ids; each consecutive pair names the two tokens of one merge, and
    the pair's 1-based position in the stream is stored as that merge's
    priority value.
    """
    raw = bytearray(self.base64decode(merges_binary))
    flat_ids = [value for (value,) in struct.iter_unpack("<H", raw)]
    merges = {}
    for pos in range(0, len(flat_ids), 2):
        pair_key = self.get_merge_identifier_string(flat_ids[pos], flat_ids[pos + 1])
        merges[pair_key] = pos + 1
    return merges
@staticmethod
def decode_vocabulary(vocab_base64):
    """Decode the base64 vocabulary blob into a list of token strings.

    The blob is UTF-8 text with one token per newline-separated entry;
    the list index of each entry is its token id.
    """
    text = str(base64.b64decode(vocab_base64), "utf-8")
    return text.split("\n")
@staticmethod
def utf8_byte_to_hex(c):
    """Format the byte value *c* as its fallback token, e.g. 195 -> "<0xC3>"."""
    return "<0x%02X>" % c
@staticmethod
def hex_to_utf8_byte(hex):
    """Parse a byte-fallback token such as "<0xC3>" back to its integer value."""
    digits = hex.replace("<0x", "").replace(">", "")
    return int(digits, 16)
class PriorityQueue:
    """A small min-heap wrapper around heapq.

    NOTE(review): this helper appears to be dead code -- the enclosing
    tokenizer's methods that write `PriorityQueue()` resolve the bare
    name to the module-level `from queue import PriorityQueue` import
    (Python does not expose enclosing-class-body names to method
    scopes), and they call .put()/.get()/.empty(), which this class
    does not provide.

    NOTE(review): `comparator` is stored but never consulted; heapq
    always orders by the elements' natural `<` (a min-heap), so the
    advertised default ordering (`a > b`, i.e. a max-heap) is not
    actually implemented.
    """
    def __init__(self, comparator=lambda a, b: a > b):
        self._heap = []
        self._comparator = comparator  # unused -- see class docstring
    def size(self):
        """Return the number of stored items."""
        return len(self._heap)
    def is_empty(self):
        """Return True when the heap holds no items."""
        return self.size() == 0
    def peek(self):
        """Return the smallest item without removing it, or None when empty."""
        return self._heap[0] if len(self._heap) > 0 else None
    def push(self, *values):
        """Insert one or more values, preserving the heap invariant."""
        for value in values:
            heapq.heappush(self._heap, value)
    def pop(self):
        """Remove and return the smallest item, or None when empty."""
        return heapq.heappop(self._heap) if len(self._heap) > 0 else None
    def replace(self, value):
        """Swap the smallest item for *value* and return the replaced item.

        Returns None (and inserts nothing) when the queue is empty.
        """
        replaced_value = self.peek()
        if replaced_value is not None:
            self._heap[0] = value
            heapq.heapify(self._heap)
        return replaced_value
def map_characters_to_token_ids(self, prompt, add_bos_token, add_preceding_space):
    """Map each character of *prompt* to an initial (pre-merge) token id list.

    Characters found directly in the vocabulary become their token id;
    any other character falls back to its UTF-8 bytes encoded as
    "<0xNN>" byte tokens. Bytes whose fallback token is missing from
    the vocabulary are reported and mapped to id 0 (<unk>).

    :param prompt: text to map.
    :param add_bos_token: prepend BOS (token id 1) when True.
    :param add_preceding_space: prepend a space to the prompt when True.
    :returns: list of token ids.
    """
    token_ids = []
    if add_bos_token:
        token_ids.append(1)  # BOS token id
    if add_preceding_space:
        prompt = " " + prompt
    # Spaces are represented by the token at id 29871 -- presumably the
    # SentencePiece U+2581 "lower one eighth block"; TODO confirm once the
    # (truncated) vocab data is restored.
    prompt_altered = prompt.replace(" ", self.vocab_by_id[29871])
    for c in prompt_altered:
        if c in self.vocab_by_string:
            token_ids.append(self.vocab_by_string[c])
        else:
            # Character not in the vocabulary: fall back to byte tokens.
            for byte_value in c.encode('utf-8'):
                byte_token_id = self.vocab_by_string.get(self.utf8_byte_to_hex(byte_value))
                if byte_token_id is None:
                    # Bug fix: the original appended None and then evaluated
                    # `not hex >= 0`, raising TypeError (None >= 0). Warn and
                    # substitute <unk> (id 0) instead.
                    print('Encountered unknown character ' + c
                          + " (partial UTF-8 byte " + str(byte_value)
                          + ", hex " + self.utf8_byte_to_hex(byte_value) + ")")
                    byte_token_id = 0
                token_ids.append(byte_token_id)
    return token_ids
def encode(self, prompt, add_bos_token=True, add_preceding_space=True, log_performance=False):
    """Tokenize *prompt* into a list of Llama token ids via BPE merging.

    Characters are first mapped to base token ids, then adjacent pairs are
    repeatedly merged in priority order (lowest merge value first) using a
    doubly linked list of token nodes plus a priority queue of candidate
    merges.

    :param prompt: text to encode.
    :param add_bos_token: prepend BOS (token id 1) when True.
    :param add_preceding_space: prepend a space to the prompt when True.
    :param log_performance: print elapsed wall-clock time when True.
    :returns: list of token ids; [] for an empty prompt; None (with a
        message) when the tokenizer tables were never initialized.
    """
    start_time = None
    if log_performance:
        start_time = time.time()
    if not self.vocab_by_id or not self.vocab_by_string or not self.merges:
        print('Tokenizer not initialized properly!')
        return
    if len(prompt) == 0:
        return []
    token_ids = self.map_characters_to_token_ids(prompt, add_bos_token, add_preceding_space)
    # NOTE(review): this bare name resolves to queue.PriorityQueue from the
    # module-level import (enclosing-class-body names are not visible in
    # method scopes), NOT the PriorityQueue class defined in this file --
    # hence the .put()/.get()/.empty() API used below.
    merge_queue = PriorityQueue()
    def add_to_merge_queue(left_node):
        # Queue a candidate merge of left_node with its right neighbour,
        # if that pair exists in the merges table.
        merge_identifier_string = self.get_merge_identifier_string(left_node['token_id'], left_node['next']['token_id'])
        merge_prio = self.merges.get(merge_identifier_string)
        if merge_prio is not None:
            # Add a fractional position bias (< 1) so that, among merges of
            # equal rank, the leftmost occurrence is applied first.
            merge_prio += left_node['orig_pos'] / len(prompt)
            left_node['merge_prio'] = merge_prio
            left_node['merge_to_string'] = merge_identifier_string.replace(" ", "")
            # NOTE(review): if two entries ever tie on merge_prio, the heap
            # falls back to comparing the dict nodes and raises TypeError;
            # presumably ties cannot occur -- verify.
            merge_queue.put((merge_prio, left_node))
    # Build a doubly linked list of token nodes so merges can splice nodes
    # in place without re-scanning the whole sequence.
    first_token_node = {
        'orig_pos': 0,
        'token_id': token_ids[0],
        'prev': None,
        'next': None,
    }
    prev_token_node = first_token_node
    for i in range(1, len(token_ids)):
        curr_token_node = {
            'orig_pos': i,
            'token_id': token_ids[i],
            'prev': prev_token_node,
            'next': None
        }
        prev_token_node['next'] = curr_token_node
        add_to_merge_queue(prev_token_node)
        prev_token_node = curr_token_node
    # Main loop: repeatedly apply the best-ranked candidate merge.
    while not merge_queue.empty():
        _, left_of_merge = merge_queue.get()
        # Skip stale queue entries: either node was consumed by an earlier
        # merge, or the pair no longer exists.
        if left_of_merge.get('deleted'):
            continue
        if not left_of_merge.get('next'):
            continue
        if left_of_merge.get('next').get('deleted'):
            continue
        # Consume both halves of the pair.
        left_of_merge['deleted'] = True
        left_of_merge['next']['deleted'] = True
        if left_of_merge.get('prev'):
            # Replace the left neighbour with a fresh copy so queue entries
            # still referencing the old node are invalidated through its
            # 'deleted' flag.
            old_prev = left_of_merge['prev']
            old_prev['deleted'] = True
            new_prev = {
                'orig_pos': old_prev['orig_pos'],
                'token_id': old_prev['token_id'],
                'prev': old_prev['prev'],
                'next': old_prev['next']
            }
            left_of_merge['prev'] = new_prev
            if new_prev.get('prev'):
                new_prev['prev']['next'] = new_prev
            else:
                first_token_node = new_prev
        # Splice the merged node in place of the consumed pair.
        result_of_merge = {
            'orig_pos': left_of_merge['orig_pos'],
            'token_id': self.vocab_by_string.get(left_of_merge['merge_to_string']),
            'prev': left_of_merge['prev'],
            'next': left_of_merge['next']['next']
        }
        if result_of_merge.get('prev'):
            result_of_merge['prev']['next'] = result_of_merge
            # The new left-neighbour pair may itself be mergeable now.
            add_to_merge_queue(result_of_merge['prev'])
        else:
            first_token_node = result_of_merge
        if result_of_merge.get('next'):
            result_of_merge['next']['prev'] = result_of_merge
            # ...and so may the new right-neighbour pair.
            add_to_merge_queue(result_of_merge)
    # Walk the surviving linked list to collect the final token ids.
    merged_token_ids = []
    curr_token_node = first_token_node
    while curr_token_node is not None:
        merged_token_ids.append(curr_token_node['token_id'])
        curr_token_node = curr_token_node['next']
    if log_performance:
        end_time = time.time()
        print('Tokenizer running time: ' + str(end_time - start_time) + " seconds")
    return merged_token_ids
def decode(self, token_ids, add_bos_token=True, add_preceding_space=True):
    """Convert a list of token ids back into text (inverse of encode).

    "<0xNN>" byte-fallback tokens contribute raw byte values; all other
    tokens contribute their UTF-8 bytes. The space-marker token (id
    29871) is turned back into a regular space.

    :param token_ids: ids to decode.
    :param add_bos_token: skip the leading BOS token when True.
    :param add_preceding_space: drop the leading space that encode added.
    :returns: the decoded string.
    """
    payload = token_ids[1:] if add_bos_token else token_ids
    byte_values = []
    for token_id in payload:
        token_string = self.vocab_by_id[token_id]
        if token_string.startswith("<0x") and token_string.endswith(">"):
            byte_values.append(self.hex_to_utf8_byte(token_string))
        else:
            byte_values.extend(token_string.encode('utf-8'))
    text = bytearray(byte_values).decode('utf-8')
    text = text.replace(self.vocab_by_id[29871], " ")
    if add_preceding_space:
        return text[1:]
    return text
def initialize_llama_tokenizer(self, vocab_base64, merges_binary):
    """Populate the tokenizer tables from their packed representations.

    Stores the raw blobs, decodes the vocabulary (id -> token string),
    derives the reverse lookup (token string -> id), and rebuilds the
    merge-priority table.
    """
    self.vocab_base64 = vocab_base64
    self.merges_binary = merges_binary
    self.vocab_by_id = self.decode_vocabulary(vocab_base64)
    self.vocab_by_string = {token: token_id for token_id, token in enumerate(self.vocab_by_id)}
    self.merges = self.decompress_merges(merges_binary)