Tokenization is extremely slow--am I doing something wrong? #22

Open
andersonbcdefg opened this issue Dec 28, 2023 · 5 comments

@andersonbcdefg
Contributor

Under what circumstances is MLX supposed to provide a speedup over sentencepiece? In a naive test with the same SPM .model file, I can tokenize 1000 batches in 13 seconds with sentencepiece, but the same workload takes over 5 minutes with MLX. Hardware is an M2 MacBook Pro with 64GB of unified memory. Is the CharTrie tokenization only useful when paired with key_transform? Are there plans to add a "tokenize_batch" with better parallelization/concurrency?

Code for reference:

from pathlib import Path
from typing import Literal, Union

import awkward as ak
import numpy as np
import tqdm
from sentencepiece import SentencePieceProcessor


class Tokenizer:
    def __init__(
        self, 
        model_path: str, 
        use_mlx: bool = False,
        mlx_tokenize_shortest: bool = False,
    ):
        assert Path(model_path).exists(), model_path
        self.use_mlx = use_mlx
        self.mlx_tokenize_shortest = mlx_tokenize_shortest
        self.spm_model = SentencePieceProcessor(model_file=model_path)
        assert self.spm_model.vocab_size() == self.spm_model.get_piece_size()
        if self.use_mlx:
            try:
                from mlx.data.core import Tokenizer as MLXTokenizer
                from mlx.data.tokenizer_helpers import read_trie_from_spm
            except ImportError:
                raise ImportError("Please install mlx-data to use the MLX tokenizer")
            trie, weights = read_trie_from_spm(model_path)
            try:
                self.mlx_model = MLXTokenizer(trie, trie_key_scores=weights)
                print(f"Loaded trie with {trie.num_keys()} keys.")
                assert self.spm_model.vocab_size() == trie.num_keys()
            except Exception as e:
                print(f"Unable to load trie into MLX Tokenizer: {e}. Using SentencePiece instead.")
                self.use_mlx = False

    @property
    def vocab_size(self) -> int:
        return self.spm_model.vocab_size()
    
    @property
    def bos_id(self) -> int:
        return self.spm_model.bos_id()
    
    @property
    def eos_id(self) -> int:
        return self.spm_model.eos_id()

    @property
    def pad_id(self) -> int:
        return self.spm_model.pad_id()
    
    @property
    def mask_id(self) -> int:
        # unknown kind of spiritually makes sense as mask
        return self.spm_model.unk_id()
    
    def __call__(
        self, 
        t: Union[str, list[str]],
        max_length: int,
        pack: bool = False,
        use_bos: bool = False,
        use_eos: bool = True,
        format: Literal["np", "mx"] = "mx",
        tokens_only: bool = False
    ) -> np.ndarray:
        if isinstance(t, str):
            t = [t]
        # note: `format` is currently unused; forward tokens_only explicitly
        return self.encode_batch(t, max_length, pack, use_bos, use_eos, tokens_only)

    def _pad(self, tokens: ak.Array, max_length: int):
        tokens = ak.pad_none(tokens, target=max_length, axis=-1, clip=True)
        tokens = ak.fill_none(tokens, self.pad_id)
        return tokens
    
    def _unpad(self, tokens: ak.Array):
        tokens = ak.from_regular(tokens)
        is_pad = tokens == self.pad_id
        return tokens[~is_pad]

    def encode_single_mlx(
        self,
        text: str,
    ):
        if self.mlx_tokenize_shortest:
            tokens = self.mlx_model.tokenize_shortest(text)
        else:
            tokens = self.mlx_model.tokenize_rand(text)
        return tokens
    
    def encode_batch(
        self, 
        batch: list[str],
        max_length: int,
        pack: bool = False,
        use_bos: bool = False,
        use_eos: bool = True,
        tokens_only: bool = False
    ) -> np.ndarray:
        if self.use_mlx:
            tokens = ak.Array([self.encode_single_mlx(text) for text in batch])
        else:
            tokens = ak.Array(self.spm_model.encode(batch))
        if use_bos:
            tokens = ak.concatenate([ak.full_like(tokens[:, :1], self.bos_id), tokens], axis=1)
        if use_eos:
            tokens = ak.concatenate([tokens, ak.full_like(tokens[:, :1], self.eos_id)], axis=1)
        if pack:
            # pack into batches of max_length; tag each token with its source sequence index
            sequence_ids = ak.zeros_like(tokens) + np.arange(len(tokens))
            tokens = ak.flatten(tokens, axis=None)
            sequence_ids = ak.flatten(sequence_ids, axis=None)
            # drop the remainder so the flat stream divides evenly into max_length chunks
            # (guard the slice: [:-0] would drop everything when there is no remainder)
            n_keep = len(tokens) - len(tokens) % max_length
            tokens = np.asarray(tokens)[:n_keep].reshape(-1, max_length)
            sequence_ids = np.asarray(sequence_ids)[:n_keep].reshape(-1, max_length)
            if tokens_only:
                return tokens
            return {
                "input_ids": tokens,
                "sequence_ids": sequence_ids, # these are for block diagonal attention
                "attention_mask": np.zeros_like(tokens)
            }
        else:
            # keep sequences separate; pad or truncate to max_length
            tokens = np.array(self._pad(tokens, max_length))
            mask = np.where(tokens != self.pad_id, 0, float("-inf"))
            if tokens_only:
                return tokens
            return {
                "input_ids": tokens,
                "sequence_ids": np.zeros_like(tokens),
                "attention_mask": mask
            }
    

def test():
    tokenizer_file = "tokenizer.model"
    spm_tokenizer = Tokenizer(tokenizer_file, use_mlx=False)
    mlx_tokenizer = Tokenizer(tokenizer_file, use_mlx=True)

    random_texts = [
        """Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum."""
    ] * 500
    import time
    start = time.time()
    for i in tqdm.tqdm(range(1_000)):
        spm_tokenizer(random_texts, 512, pack=True, use_bos=False, use_eos=True)
    print(f"spm: {time.time() - start:.3f}s")

    start = time.time()
    for i in tqdm.tqdm(range(1_000)):
        mlx_tokenizer(random_texts, 512, pack=True, use_bos=False, use_eos=True)
    print(f"mlx: {time.time() - start:.3f}s")


test()
@angeloskath
Member

Hi @andersonbcdefg, sorry for the extremely late reply.

So the tokenizer in MLX Data is actually quite fast on smallish documents. It optimizes over the whole passed document, so it is quite a bit slower when given a very long text like the one above (where it doesn't really make sense to search the whole tokenization graph).

For example, the wikitext benchmark (https://github.com/ml-explore/mlx-data/blob/c1204bce12ce495add1ed68338543cb4b5c5a595/benchmarks/comparative/wikitext/mlx_data.py) tokenizes a few million tokens per second on my Mac, which should be more than enough for any use case.
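
In case it helps reproduce the short-document case, a minimal single-core timing sketch (assuming the same tokenizer.model file and using only the calls from the snippet above) might look like this:

import time

from sentencepiece import SentencePieceProcessor
from mlx.data.core import Tokenizer as MLXTokenizer
from mlx.data.tokenizer_helpers import read_trie_from_spm

model_path = "tokenizer.model"  # same SPM model file as in the snippet above
doc = "Lorem Ipsum is simply dummy text of the printing and typesetting industry."
docs = [doc] * 10_000

# SentencePiece: encode the whole batch at once
spm = SentencePieceProcessor(model_file=model_path)
start = time.time()
spm.encode(docs)
spm_time = time.time() - start

# MLX Data: per-document shortest tokenization over the SPM trie
trie, weights = read_trie_from_spm(model_path)
mlx = MLXTokenizer(trie, trie_key_scores=weights)
start = time.time()
mlx_tokens = [mlx.tokenize_shortest(d) for d in docs]
mlx_time = time.time() - start

n_tokens = sum(len(t) for t in mlx_tokens)
print(f"spm: {spm_time:.2f}s  mlx: {mlx_time:.2f}s  mlx: {n_tokens / mlx_time:,.0f} tok/s")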

@andersonbcdefg
Contributor Author

Hmm, well, the document in my example is only a few hundred characters. It's a batch of 500 of the same doc, but the doc itself is short, so I'm not sure that optimizing over a large graph explains the disparity in speed.

@angeloskath
Member

Oh sorry, I somewhat misunderstood the code snippet.

Having said that, I wouldn't say it is significantly slower than SPM. Running your benchmark with varying document sizes on my M2 Air laptop, I get the following comparison with SPM:

Doc length (chars) | Batch Tokens | MLX time / SPM time | MLX Tokens/s
-------------------+--------------+---------------------+-------------
              57   |      14336   |              5.40   |  1098130.11
             114   |      28672   |              8.25   |  1125880.12
             172   |      43520   |              9.30   |  1168599.07
             229   |      58368   |              9.89   |  1205666.60
             287   |      72704   |              9.63   |  1182750.15
             344   |      85504   |              11.0   |  1103767.26
             401   |     100864   |              10.9   |  1125994.54
             459   |     113664   |              10.7   |  1130119.24
             516   |     125952   |              10.9   |  1074331.83
             574   |     139264   |              10.5   |  1072074.781

Keep in mind that this is single core, so >1M tok/s per core is, I think, pretty reasonable for almost all use cases. We would of course appreciate PRs that improve this and close the gap with SPM, which is probably somewhere around 2M-3M tok/s per core on my machine.

@andersonbcdefg
Contributor Author

Yeah, I hope it can be sped up! A 10x difference in speed matters a lot, especially for offline data-processing workflows (I understand 1M tok/s is fine if you're feeding an LLM in real time, but tokenization is also important for batch processing!).

@angeloskath
Member

Sure, I understand, and we should work on it. However, this is still single core. Using the following pipeline on my M2 Air, it is about 3x slower than SPM:

# assuming `import mlx.data as dx`, with `trie` and `random_texts` as in the snippets above
dset = (
    dx.stream_python_iterable(lambda: ({"doc": s.encode()} for s in random_texts))
    .tokenize("doc", trie)
    .prefetch(20, 4)
    .sliding_window("doc", 512, 512)
    .shape("doc", "length", 0)
    .batch(128)
    .prefetch(2, 1)
)
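
To time the pipeline end to end, a rough sketch (assuming dset is built as above; the stream is iterated directly and each batch is a dict of arrays) could be:

import time

start = time.time()
n_tokens = 0
for batch in dset:
    # "doc" is a (batch, window) array; padded positions make this an upper bound
    n_tokens += batch["doc"].size
elapsed = time.time() - start
print(f"~{n_tokens / elapsed:,.0f} tok/s across the prefetch workers")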
