add HF input pipeline #592

Closed · wants to merge 1 commit
6 changes: 5 additions & 1 deletion .github/workflows/UnitTests.yml
@@ -75,7 +75,11 @@ jobs:
- name: Test with pytest
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest'
- name: Test train.py with c4
- name: Test train.py with HF c4
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet dataset_name=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large enable_checkpointing=false'
Collaborator:

I don't understand this command. Is this loading a single parquet file or all parquet files in that directory?

Collaborator (Author):

Yes, this is only using a single file. Using the whole dataset adds the overhead of resolving every file in the dataset, which is too much for a unit test.
I've added more details in this doc (go/maxtext-input-pipeline), under "Issues during implementation and solutions" - "HF hub or local dataset?".
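For reference, a minimal sketch of the difference (illustrative paths; the parquet builder and streaming flag mirror the CI command above):

from datasets import load_dataset

# Single shard: only this one parquet file is resolved and streamed,
# which keeps the unit test fast.
single_file_ds = load_dataset(
    "parquet",
    data_files="gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet",
    split="train",
    streaming=True,
)

# Whole dataset: a glob (hypothetical pattern) forces the loader to resolve
# every shard under the directory before streaming can start.
full_ds = load_dataset(
    "parquet",
    data_files="gs://maxtext-dataset/hf/c4/*.parquet",
    split="train",
    streaming=True,
)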

- name: Test train.py with TFDS c4
run: |
docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false'
21 changes: 16 additions & 5 deletions MaxText/configs/base.yml
@@ -155,19 +155,30 @@ ici_autoregressive_parallelism: 1
# you should set compile_toplogy_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
num_slices: -1

# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
# When using the HF pipeline, set tokenizer_path to a tokenizer on the HF hub, e.g. "google-t5/t5-large", or to a local folder containing an HF tokenizer
Collaborator:

The HF pipeline isn't compatible with all of our tokenizers? Why not?

Collaborator (Author):

The current tokenizers are in SentencePiece .model format, which the HF tokenizer loader, AutoTokenizer, does not support; HF expects tokenizers in its JSON format.
I've added more details in this doc (go/maxtext-input-pipeline), under "Issues during implementation and solutions" - "tokenizer".
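As a minimal sketch, this is how an HF-format tokenizer is loaded (the tokenizer name comes from the CI command above; whether a given tokenizer class honors add_bos_token/add_eos_token is tokenizer-dependent):

from transformers import AutoTokenizer

# Works for tokenizers published on the HF hub in JSON format, or for a local
# folder containing tokenizer.json / tokenizer_config.json. A bare
# SentencePiece .model file cannot be loaded this way.
tokenizer = AutoTokenizer.from_pretrained(
    "google-t5/t5-large",
    add_bos_token=True,
    add_eos_token=True,
    model_max_length=512,
)
print(tokenizer("hello world")["input_ids"])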

tokenizer_path: "assets/tokenizer.llama2"
# provide access token if you are using a huggingface tokenizer from a gated model
hf_access_token: ''

# Dataset
# Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
# For the TFDS pipeline, replace with your path given as an argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
Collaborator:

I'm very concerned by the level of complexity here -- it feels very un-MaxTexty and needlessly confusing for our reference implementation. It is critical that we take action to make this simple.

# For Grain pipeline, set dataset_path to match the MOUNT_PATH you used in setup_gcsfuse.sh
# For HF pipeline, this is the data_files field ("gs://" path supported) in load_dataset (https://huggingface.co/docs/datasets/en/loading)
dataset_path: ""
vocab_size: 32_000 # powers of 2 for sharding
tokenizer_path: "assets/tokenizer.llama2"
dataset_name: 'c4/en:3.0.1'
# For TFDS pipeline, dataset_name is the name of TFDS dataset
# For Grain pipeline, dataset_name is the subfolder path under dataset_path
# For HF pipeline, dataset_name corresponds to the path field in load_dataset (https://huggingface.co/docs/datasets/en/loading)
dataset_name: 'c4/en:3.0.1' # e.g. 'allenai/c4' for HF pipeline
dataset_dir_hf: '' # e.g. 'en', used by HF pipeline only, corresponds to the data_dir field in load_dataset (https://huggingface.co/docs/datasets/en/loading)
eval_dataset_name: 'c4/en:3.0.1'
eval_split: 'validation'
per_device_batch_size: 12.0
expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
eval_per_device_batch_size: 0
max_corpus_chars: 10_000_000
dataset_type: c4 # must be c4 or synthetic
dataset_type: c4 # must be c4, c4-array_record, c4_mlperf, hf or synthetic

# Setting for grain
grain_worker_count: 4
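The HF-related fields added to base.yml above map onto datasets.load_dataset arguments, as get_datasets() does in the new file below; a sketch with illustrative values (not defaults):

from datasets import load_dataset

# dataset_name    -> path        (e.g. 'allenai/c4', per the comment above)
# dataset_dir_hf  -> data_dir    (e.g. 'en')
# dataset_path    -> data_files  ("gs://" paths supported)
# hf_access_token -> token       (for gated datasets/tokenizers)
train_ds = load_dataset(
    "allenai/c4",
    data_dir="en",
    data_files=None,
    split="train",
    streaming=True,
    token=None,
)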
126 changes: 126 additions & 0 deletions MaxText/input_pipeline/_hf_data_processing.py
@@ -0,0 +1,126 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Input pipeline using Huggingface datasets."""

from typing import Optional

import ml_collections
import jax

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from transformers import AutoTokenizer
from torchdata.datapipes.iter import IterableWrapper

from input_pipeline import _hf_operations
import multihost_dataloading

def get_datasets(
    config: ml_collections.ConfigDict
):
  """Load huggingface dataset"""
  train_ds = load_dataset(config.dataset_name,
                          data_dir=config.dataset_dir_hf,
                          data_files=config.dataset_path,
                          split="train",
                          streaming=True,
                          token=config.hf_access_token)
  return train_ds, None

def preprocess_dataset(config: ml_collections.ConfigDict,
                       global_mesh,
                       train_ds,
                       add_bos=True,
                       add_eos=True,
                       ):
  """preprocess dataset"""
  # Set global batch size.
  global_batch_size_to_load = config.global_batch_size_to_load

  train_iter = preprocessing_pipeline(
      dataset=train_ds,
      tokenizer_path=config.tokenizer_path,
      add_bos=add_bos,
      add_eos=add_eos,
      batch_size=global_batch_size_to_load,
      global_mesh=global_mesh,
      shuffle=config.enable_data_shuffling,
      num_epochs=1,
      pack_examples=True,
      max_length=config.max_target_length,
      data_shuffle_seed=config.data_shuffle_seed,
      access_token=config.hf_access_token,)

  return train_iter, None, None

def preprocessing_pipeline(
    dataset,
    tokenizer_path,
    add_bos: bool,
    add_eos: bool,
    batch_size: int,
    global_mesh,
    shuffle: bool,
    num_epochs: Optional[int] = 1,  # only num_epochs=1 is supported for now
    pack_examples: bool = True,
    max_length: int = 512,
    shift: bool = True,
    drop_remainder: bool = True,  # drop_remainder is not supported yet
    data_shuffle_seed=0,
    access_token: Optional[str] = None,
    prefetch_buffer_size: int = 100,
):
  """pipeline for preprocessing"""
  assert (
      batch_size % global_mesh.size == 0
  ), 'Batch size should be divisible by the number of global devices.'

  dataset = split_dataset_by_node(dataset, world_size=jax.process_count(), rank=jax.process_index())
Collaborator:

@khatwanimohit make sure to sync with Mohit here, this is no longer right.

Collaborator (Author):

This implementation does not support "expansion_factor_real_data" yet, as I noted in the doc input_pipeline.md; I will add support for it in a later PR. (Currently "expansion_factor_real_data" is ignored when dataset_type=hf.) This should work for now. @khatwanimohit, please also take a look.

Collaborator:

I don't see a reason why it can't be done in this PR. It's a small change. You'll just have to pass in

dataloading_host_index = process_indices.index(jax.process_index()),
dataloading_host_count = len(process_indices)

to _hf_data_processing.preprocess_dataset
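A minimal sketch of that suggestion (shard_for_this_host is a hypothetical helper; process_indices is assumed to come from MaxText's existing expansion_factor_real_data logic):

import jax
from datasets.distributed import split_dataset_by_node

def shard_for_this_host(dataset, process_indices):
  # Only the hosts listed in process_indices load real data; each of them
  # takes one shard of the streamed dataset.
  dataloading_host_index = process_indices.index(jax.process_index())
  dataloading_host_count = len(process_indices)
  return split_dataset_by_node(
      dataset,
      rank=dataloading_host_index,
      world_size=dataloading_host_count,
  )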


  tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                            add_bos_token=add_bos,
                                            add_eos_token=add_eos,
                                            model_max_length=max_length,
                                            token=access_token)

  dataset = dataset.map(_hf_operations.tokenization, batched=True,
                        fn_kwargs={"tokenizer": tokenizer, "max_length": max_length})

  dataset = dataset.map(_hf_operations.normalize_features, batched=True,
                        fn_kwargs={"key": "input_ids"})

  dataset = dataset.select_columns(['inputs', 'targets'])
  dataset = dataset.with_format("np")

  if shuffle:
    dataset = dataset.shuffle(seed=data_shuffle_seed)

  if pack_examples:
    pack_op = _hf_operations.PackAndBatchOperation(
        batch_size=batch_size // jax.process_count(),
Collaborator:

This is probably also no longer right; talk to Mohit!

        length_struct={"inputs": max_length, "targets": max_length},
        shift_inputs=shift,
    )
    dataset = _hf_operations.TransformedDataset(pack_op, dataset)

  dataset = IterableWrapper(iter(dataset))
  dataset = dataset.prefetch(prefetch_buffer_size)

  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, global_mesh)

  # Return multi-host jax.Array prep iterator
  return multihost_gen
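For completeness, a rough sketch of how a caller would wire up the two entry points in this file (config and mesh are stand-ins for MaxText's config object and jax Mesh; the actual hookup in train.py may differ):

# dataset_type=hf is assumed in `config`.
train_ds, _ = get_datasets(config)
train_iter, _, _ = preprocess_dataset(config, mesh, train_ds)

batch = next(train_iter)  # expected to yield a dict with 'inputs' and 'targets', per select_columns above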