
Supported features #571

Open
peregilk opened this issue Mar 30, 2024 · 4 comments

Comments

@peregilk

I mainly wanted to start by thanking you for making MaxText available. I have been using it for a few days, and the first impression is fantastic: getting started was really easy, it seems very stable, the performance is excellent, and it appears to scale very nicely.

There are a few things I have not been able to figure out yet; that might be because of a lack of documentation, or simply because they are not implemented:

  • Is there any support for Flash attention, or any plans to implement it? This is a major area where GPUs have been ahead of TPUs. I have noticed that there is now at least an experimental implementation from the JAX team: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py (a rough sketch of invoking it follows this list).

  • Training directly from TFDS seemed straightforward. However, I was a bit confused about how to implement more advanced data-loader features, for instance probability sampling as explained here. This can be somewhat tricky to do efficiently on multiple TPUs. What is the sensible approach here? Manually sampling into a TFDS dataset does not seem very efficient. Are there external libraries that are compatible with MaxText?

  • Are there plans for implementing DPO/RLHF?
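
For reference, this is roughly what invoking that experimental Pallas kernel looks like; it is only a sketch, and the shapes, dtype, and the causal flag below are my own guesses, not MaxText code:

```python
# Hypothetical sketch of calling JAX's experimental Pallas TPU flash-attention
# kernel (the module linked above). Not MaxText integration code.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

batch, heads, seq_len, head_dim = 2, 8, 1024, 128
q = jax.random.normal(jax.random.PRNGKey(0), (batch, heads, seq_len, head_dim), dtype=jnp.float32)
k = jax.random.normal(jax.random.PRNGKey(1), (batch, heads, seq_len, head_dim), dtype=jnp.float32)
v = jax.random.normal(jax.random.PRNGKey(2), (batch, heads, seq_len, head_dim), dtype=jnp.float32)

# causal=True gives decoder-style masking; this needs to run on an actual TPU.
out = flash_attention(q, k, v, causal=True)
print(out.shape)  # (2, 8, 1024, 128)
```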

I also shamelessly wanted to point you to my own repo: https://github.com/peregilk/ttconnect. It is a very simple bash script that ideally should be run on a VM in the same zone. It automatically opens synchronised tmux windows to all the VMs in the pod and lets you type the same command into all of them, which makes it even easier to go from a single TPU to pods.

@rwitten
Collaborator

rwitten commented Mar 31, 2024

Thank you for the comments!

(1) Fused attention is on by default for training! We use "splash attention", which is a custom and faster version. (And we're working on accelerated attention for inference.)
(2) We don't implement more advanced data loaders, though I think they can be implemented in TFDS (see the sampling sketch after this list). It is also easy to plug in your own data loader. Is there a specific data-loading solution you'd like us to use?
(3) Yes, DPO is underway!
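
For (2), here is a minimal sketch of probability sampling expressed in plain tf.data; the dataset names and weights are placeholders for illustration, not something MaxText ships:

```python
# Sketch: mixing two TFDS sources with fixed sampling probabilities.
import tensorflow as tf
import tensorflow_datasets as tfds

ds_a = tfds.load("lm1b", split="train")        # placeholder source A
ds_b = tfds.load("wiki40b/en", split="train")  # placeholder source B

# Draw ~80% of examples from ds_a and ~20% from ds_b.
mixed = tf.data.Dataset.sample_from_datasets(
    [ds_a.repeat(), ds_b.repeat()], weights=[0.8, 0.2], seed=42)

mixed = mixed.shuffle(10_000).batch(8).prefetch(tf.data.AUTOTUNE)
```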

ttconnect is super cool, thanks for sending!

@peregilk
Author

peregilk commented Apr 1, 2024

Thanks for the answer. Looking forward to the DPO support.

It would of course be fantastic if HuggingFace datasets could be supported natively. I have never really been able to run large non-streaming datasets from HF on the TPUs (disk-size issues on the VMs), but we have been able to wrap the HF datasets in torch.split_dataset_by_node to stream on multiple TPUs. I'm not sure I would be able to implement something like this in MaxText myself, and I am not really sure at what level it should be implemented.

Any chance you will support HF datasets in the future?

In any case, any way of preprocessing the data before it is split across the TPUs would be extremely useful for running experiments on dataset building, both for sampling and for filtering on a field in the dataset.
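
For what it's worth, this is roughly the streaming pattern I mean; the dataset name, the filtered field, and the rank/world-size wiring are just placeholders, not an actual MaxText integration:

```python
# Sketch: stream a HuggingFace dataset, filter on a field, and shard it per host.
import jax
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)  # placeholder corpus

# Filter on a metadata field before the data ever reaches the accelerators.
ds = ds.filter(lambda ex: ex.get("timestamp", "") >= "2020")

# Give each host its own shard of the stream.
ds = split_dataset_by_node(ds, rank=jax.process_index(), world_size=jax.process_count())

for example in ds.take(2):
    print(example["text"][:80])
```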

@A9isha
Collaborator

A9isha commented May 6, 2024

Yes, support for HF datasets in MaxText is on the way.
@aireenmei

@aireenmei
Collaborator

Thank you for tagging me on this. Yes, supporting HuggingFace datasets is in our plan. We have some implementations and are running performance evaluations to understand them better. I will update here when we have it out.
