add HF input pipeline
aireenmei committed Apr 12, 2024
1 parent 28a3279 commit 390b9ae
Showing 13 changed files with 545 additions and 20 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/UnitTests.yml
@@ -75,7 +75,11 @@ jobs:
       - name: Test with pytest
         run: |
           docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c 'cd MaxText;python3 -m pytest'
-      - name: Test train.py with c4
+      - name: Test train.py with HF c4
+        run: |
+          docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
+            'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet dataset_name=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large enable_checkpointing=false'
+      - name: Test train.py with TFDS c4
         run: |
           docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
             'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false'
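The new HF step above exercises the pipeline added in this commit by streaming a single C4 parquet shard straight from GCS. Roughly, the dataset side of that command reduces to the following sketch, assuming gcsfs is installed so `datasets` can read gs:// paths:

# Rough equivalent of the dataset/tokenizer setup behind the HF CI step above
# (a sketch; assumes gcsfs is installed so gs:// paths are readable).
from datasets import load_dataset
from transformers import AutoTokenizer

train_ds = load_dataset("parquet",  # dataset_name=parquet
                        data_files="gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet",  # dataset_path
                        split="train",
                        streaming=True)
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-large")  # tokenizer_path
example = next(iter(train_ds))  # a raw C4 record with a "text" field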
21 changes: 16 additions & 5 deletions MaxText/configs/base.yml
@@ -155,19 +155,30 @@ ici_autoregressive_parallelism: 1
 # you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
 num_slices: -1

+# Tokenizer
+vocab_size: 32_000 # powers of 2 for sharding
+# When using the HF pipeline, set tokenizer_path to a tokenizer in the HF hub, e.g. "google-t5/t5-large", or a local folder containing an HF tokenizer
+tokenizer_path: "assets/tokenizer.llama2"
+# provide an access token if you are using a Hugging Face tokenizer from a gated model
+hf_access_token: ''
+
 # Dataset
-# Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
+# For the TFDS pipeline, replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
+# For the Grain pipeline, set dataset_path to match the MOUNT_PATH you used in setup_gcsfuse.sh
+# For the HF pipeline, this is the data_files field ("gs://" path supported) in load_dataset (https://huggingface.co/docs/datasets/en/loading)
 dataset_path: ""
-vocab_size: 32_000 # powers of 2 for sharding
-tokenizer_path: "assets/tokenizer.llama2"
-dataset_name: 'c4/en:3.0.1'
+# For the TFDS pipeline, dataset_name is the name of the TFDS dataset
+# For the Grain pipeline, dataset_name is the subfolder path under dataset_path
+# For the HF pipeline, dataset_name corresponds to the path field in load_dataset (https://huggingface.co/docs/datasets/en/loading)
+dataset_name: 'c4/en:3.0.1' # e.g. 'allenai/c4' for the HF pipeline
+dataset_dir_hf: '' # e.g. 'en'; used by the HF pipeline only, corresponds to the data_dir field in load_dataset (https://huggingface.co/docs/datasets/en/loading)
 eval_dataset_name: 'c4/en:3.0.1'
 eval_split: 'validation'
 per_device_batch_size: 12.0
 expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
 eval_per_device_batch_size: 0
 max_corpus_chars: 10_000_000
-dataset_type: c4 # must be c4 or synthetic
+dataset_type: c4 # must be c4, c4-array_record, c4_mlperf, hf or synthetic

 # Setting for grain
 grain_worker_count: 4
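For the HF pipeline, the dataset fields above map directly onto load_dataset arguments: dataset_name to path, dataset_dir_hf to data_dir, and dataset_path to data_files, which is how get_datasets in the new file below uses them. A minimal sketch based on the 'allenai/c4' example from the comments (the token is only needed for gated repos):

# Sketch of how the HF-related config fields feed load_dataset
# (mirrors get_datasets in MaxText/input_pipeline/_hf_data_processing.py).
from datasets import load_dataset

train_ds = load_dataset("allenai/c4",     # dataset_name
                        data_dir="en",    # dataset_dir_hf
                        data_files=None,  # dataset_path, left unset when loading from the Hub
                        split="train",
                        streaming=True,
                        token=None)       # hf_access_token, only needed for gated models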
126 changes: 126 additions & 0 deletions MaxText/input_pipeline/_hf_data_processing.py
@@ -0,0 +1,126 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Input pipeline using Huggingface datasets."""

from typing import Optional

import ml_collections
import jax

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from transformers import AutoTokenizer
from torchdata.datapipes.iter import IterableWrapper

from input_pipeline import _hf_operations
import multihost_dataloading

def get_datasets(
    config: ml_collections.ConfigDict
):
  """Load huggingface dataset"""
  # Stream the dataset lazily; dataset_name, dataset_dir_hf and dataset_path map to
  # load_dataset's path, data_dir and data_files arguments.
  train_ds = load_dataset(config.dataset_name,
                          data_dir=config.dataset_dir_hf,
                          data_files=config.dataset_path,
                          split="train",
                          streaming=True,
                          token=config.hf_access_token)
  # Only the train split is loaded; None is returned in place of an eval dataset.
  return train_ds, None

def preprocess_dataset(config: ml_collections.ConfigDict,
                       global_mesh,
                       train_ds,
                       add_bos=True,
                       add_eos=True,
                       ):
  """preprocess dataset"""
  # Set global batch size.
  global_batch_size_to_load = config.global_batch_size_to_load

  train_iter = preprocessing_pipeline(
      dataset=train_ds,
      tokenizer_path=config.tokenizer_path,
      add_bos=add_bos,
      add_eos=add_eos,
      batch_size=global_batch_size_to_load,
      global_mesh=global_mesh,
      shuffle=config.enable_data_shuffling,
      num_epochs=1,
      pack_examples=True,
      max_length=config.max_target_length,
      data_shuffle_seed=config.data_shuffle_seed,
      access_token=config.hf_access_token,)

  # Only the train iterator is implemented; the remaining return values are placeholders.
  return train_iter, None, None

def preprocessing_pipeline(
    dataset,
    tokenizer_path,
    add_bos: bool,
    add_eos: bool,
    batch_size: int,
    global_mesh,
    shuffle: bool,
    num_epochs: Optional[int] = 1,  # only num_epochs=1 is supported for now
    pack_examples: bool = True,
    max_length: int = 512,
    shift: bool = True,
    drop_remainder: bool = True,  # not supported yet; the value is currently ignored
    data_shuffle_seed=0,
    access_token: Optional[str] = None,
    prefetch_buffer_size: int = 100,
):
  """pipeline for preprocessing"""
  assert (
      batch_size % global_mesh.size == 0
  ), 'Batch size should be divisible by the number of global devices.'

  # Give each JAX process (host) its own shard of the streaming dataset.
  dataset = split_dataset_by_node(dataset, world_size=jax.process_count(), rank=jax.process_index())

  tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,
                                            add_bos_token=add_bos,
                                            add_eos_token=add_eos,
                                            model_max_length=max_length,
                                            token=access_token)

  # Tokenize the raw text, then reshape the token ids into the 'inputs'/'targets'
  # features consumed downstream.
  dataset = dataset.map(_hf_operations.tokenization, batched=True,
                        fn_kwargs={"tokenizer": tokenizer, "max_length": max_length})

  dataset = dataset.map(_hf_operations.normalize_features, batched=True,
                        fn_kwargs={"key": "input_ids"})

  dataset = dataset.select_columns(['inputs', 'targets'])
  dataset = dataset.with_format("np")

  if shuffle:
    dataset = dataset.shuffle(seed=data_shuffle_seed)

  if pack_examples:
    # Pack short examples into fixed-length sequences and batch per host.
    pack_op = _hf_operations.PackAndBatchOperation(
        batch_size=batch_size // jax.process_count(),
        length_struct={"inputs": max_length, "targets": max_length},
        shift_inputs=shift,
    )
    dataset = _hf_operations.TransformedDataset(pack_op, dataset)

  # Prefetch upcoming examples so host-side preprocessing overlaps with training.
  dataset = IterableWrapper(iter(dataset))
  dataset = dataset.prefetch(prefetch_buffer_size)

  multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, global_mesh)

  # Return multi-host jax.Array prep iterator
  return multihost_gen

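Taken together, a training setup would use the two entry points roughly as follows. This is a sketch with hypothetical values: in MaxText the config object and the mesh come from the existing startup code, and global_batch_size_to_load is normally derived from per_device_batch_size rather than set by hand.

# Hypothetical end-to-end wiring of get_datasets and preprocess_dataset.
# Config field names follow base.yml; the mesh setup is illustrative only.
import jax
import numpy as np
import ml_collections
from jax.sharding import Mesh

from input_pipeline import _hf_data_processing

config = ml_collections.ConfigDict()
config.dataset_name = "allenai/c4"
config.dataset_dir_hf = "en"
config.dataset_path = None
config.hf_access_token = None
config.tokenizer_path = "google-t5/t5-large"
config.global_batch_size_to_load = 4 * jax.device_count()  # normally derived by MaxText
config.max_target_length = 512
config.enable_data_shuffling = True
config.data_shuffle_seed = 0

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

train_ds, _ = _hf_data_processing.get_datasets(config)
train_iter, _, _ = _hf_data_processing.preprocess_dataset(config, mesh, train_ds)
batch = next(train_iter)  # one packed batch of "inputs"/"targets" arrays sharded over the mesh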