add HF input pipeline #592
@@ -155,19 +155,30 @@ ici_autoregressive_parallelism: 1
 # you should set compile_topology_num_slices, which will in turn set this value. For non-TPU environments this is set to 1.
 num_slices: -1

+# Tokenizer
+vocab_size: 32_000 # powers of 2 for sharding
+# When using HF pipeline, set tokenizer_path to a tokenizer on HF hub, e.g. "google-t5/t5-large", or a local folder containing an HF tokenizer

Review comment: HF pipeline isn't compatible with all of our tokenizers? Why not?
Reply: The current tokenizers in sp model format are not supported by the HF tokenizer loader, AutoTokenizer. HF uses tokenizers in json format.

+tokenizer_path: "assets/tokenizer.llama2"
+# provide access token if you are using a huggingface tokenizer from a gated model
+hf_access_token: ''
+
 # Dataset
-# Replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"
+# For TFDS pipeline, replace with your path given as argument in download_dataset.sh, e.g. "gs://my-maxtext-dataset/"

Review comment: I'm very concerned by the level of complexity here -- it feels very un-MaxTexty and needlessly confusing for our reference implementation. It is critical that we take action to make this simple.

+# For Grain pipeline, set dataset_path to match the MOUNT_PATH you used in setup_gcsfuse.sh
+# For HF pipeline, this is the data_files field ("gs://" path supported) in load_dataset (https://huggingface.co/docs/datasets/en/loading)
 dataset_path: ""
-vocab_size: 32_000 # powers of 2 for sharding
-tokenizer_path: "assets/tokenizer.llama2"
-dataset_name: 'c4/en:3.0.1'
+# For TFDS pipeline, dataset_name is the name of the TFDS dataset
+# For Grain pipeline, dataset_name is the subfolder path under dataset_path
+# For HF pipeline, dataset_name corresponds to the path field in load_dataset (https://huggingface.co/docs/datasets/en/loading)
+dataset_name: 'c4/en:3.0.1' # e.g. 'allenai/c4' for HF pipeline
+dataset_dir_hf: '' # e.g. 'en', used by HF pipeline only, corresponds to the data_dir field in load_dataset (https://huggingface.co/docs/datasets/en/loading)
 eval_dataset_name: 'c4/en:3.0.1'
 eval_split: 'validation'
 per_device_batch_size: 12.0
 expansion_factor_real_data: -1 # if -1 then all hosts will load real data, else total_hosts//expansion_factor_real_data will pull data from GCS.
 eval_per_device_batch_size: 0
 max_corpus_chars: 10_000_000
-dataset_type: c4 # must be c4 or synthetic
+dataset_type: c4 # must be c4, c4-array_record, c4_mlperf, hf or synthetic

 # Setting for grain
 grain_worker_count: 4
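For readers mapping these new config fields onto the Hugging Face API, the sketch below lines them up with the corresponding `load_dataset` arguments, mirroring the call the new pipeline makes. The `'allenai/c4'` / `'en'` values are just the examples quoted in the config comments, not required settings.

```python
from datasets import load_dataset

# dataset_name   -> path        (e.g. 'allenai/c4' on the HF hub)
# dataset_dir_hf -> data_dir    (e.g. 'en')
# dataset_path   -> data_files  (optional; "gs://" paths are supported)
train_ds = load_dataset(
    "allenai/c4",      # config.dataset_name
    data_dir="en",     # config.dataset_dir_hf
    split="train",
    streaming=True,    # stream instead of materializing the full corpus
    token=None,        # config.hf_access_token, only needed for gated repos
)
print(next(iter(train_ds))["text"][:80])
```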
New file (the HF input pipeline module):
@@ -0,0 +1,126 @@
""" | ||
Copyright 2023 Google LLC | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
https://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
"""Input pipeline using Huggingface datasets.""" | ||
|
||
from typing import Optional, Union | ||
|
||
import ml_collections | ||
import jax | ||
|
||
from datasets import load_dataset | ||
from datasets.distributed import split_dataset_by_node | ||
from transformers import AutoTokenizer | ||
from torchdata.datapipes.iter import IterableWrapper | ||
|
||
from input_pipeline import _hf_operations | ||
import multihost_dataloading | ||
|
||
def get_datasets( | ||
config: ml_collections.ConfigDict | ||
): | ||
"""Load huggingface dataset""" | ||
train_ds = load_dataset(config.dataset_name, | ||
data_dir=config.dataset_dir_hf, | ||
data_files=config.dataset_path, | ||
split="train", | ||
streaming=True, | ||
token=config.hf_access_token) | ||
return train_ds, None | ||
|
||
def preprocess_dataset(config: ml_collections.ConfigDict, | ||
global_mesh, | ||
train_ds, | ||
add_bos = True, | ||
add_eos = True, | ||
): | ||
"""preprocess dataset""" | ||
# Set global batch size. | ||
global_batch_size_to_load = config.global_batch_size_to_load | ||
|
||
train_iter = preprocessing_pipeline( | ||
dataset=train_ds, | ||
tokenizer_path=config.tokenizer_path, | ||
add_bos=add_bos, | ||
add_eos=add_eos, | ||
batch_size=global_batch_size_to_load, | ||
global_mesh=global_mesh, | ||
shuffle=config.enable_data_shuffling, | ||
num_epochs=1, | ||
pack_examples=True, | ||
max_length=config.max_target_length, | ||
data_shuffle_seed=config.data_shuffle_seed, | ||
access_token=config.hf_access_token,) | ||
|
||
return train_iter, None, None | ||
|
||
def preprocessing_pipeline( | ||
dataset, | ||
tokenizer_path, | ||
add_bos: bool, | ||
add_eos: bool, | ||
batch_size: int, | ||
global_mesh, | ||
shuffle: bool, | ||
num_epochs: Optional[int] = 1, # only support num_epoch=1 for now | ||
pack_examples: bool = True, | ||
max_length: int = 512, | ||
shift: bool = True, | ||
drop_remainder: bool = True, # does not support drop_remainder | ||
data_shuffle_seed = 0, | ||
access_token: Union[str | None] = None, | ||
prefetch_buffer_size: int = 100, | ||
): | ||
"""pipeline for preprocessing""" | ||
assert ( | ||
batch_size % global_mesh.size == 0 | ||
), 'Batch size should be divisible number of global devices.' | ||
|
||
dataset = split_dataset_by_node(dataset, world_size=jax.process_count(), rank=jax.process_index()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @khatwanimohit make sure to sync with Mohit here, this is no longer right. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implemention does not support "expansion_factor_real_data" yet, as I noted in the doc input_pipeline.md. I will add support for expansion_factor_real_data in later PR. (Currently "expansion_factor_real_data" is ignored when dataset_type=hf). This should work for now, @khatwanimohit please also take a look. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see a reason why it can't be done in this PR. Its a small change. You'll just have to pass in
to |
||
|
||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, | ||
add_bos_token=add_bos, | ||
add_eos_token=add_eos, | ||
model_max_length=max_length, | ||
token=access_token) | ||
|
||
dataset = dataset.map(_hf_operations.tokenization, batched=True, | ||
fn_kwargs={"tokenizer": tokenizer, "max_length": max_length}) | ||
|
||
dataset = dataset.map(_hf_operations.normalize_features, batched=True, | ||
fn_kwargs={"key":"input_ids"}) | ||
|
||
dataset = dataset.select_columns(['inputs', 'targets']) | ||
dataset = dataset.with_format("np") | ||
|
||
if shuffle: | ||
dataset = dataset.shuffle(seed=data_shuffle_seed) | ||
|
||
if pack_examples: | ||
pack_op = _hf_operations.PackAndBatchOperation( | ||
batch_size=batch_size // jax.process_count(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. probably this is also no longer right, talk to mohit! |
||
length_struct={"inputs": max_length, "targets":max_length}, | ||
shift_inputs=shift, | ||
) | ||
dataset = _hf_operations.TransformedDataset(pack_op, dataset) | ||
|
||
dataset = IterableWrapper(iter(dataset)) | ||
dataset = dataset.prefetch(prefetch_buffer_size) | ||
|
||
multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, global_mesh) | ||
|
||
# Return multi-host jax.Array prep iterator | ||
return multihost_gen |
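Since the module above only runs inside MaxText, here is a minimal standalone sketch of the two Hugging Face calls it leans on: per-host sharding via `split_dataset_by_node` and tokenization via an `AutoTokenizer`-driven `map`. The toy data, the rank/world_size values, and the choice of google-t5/t5-large are purely illustrative.

```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from transformers import AutoTokenizer

# Toy in-memory dataset standing in for a streaming corpus.
ds = Dataset.from_dict({"text": ["hello world", "maxtext input pipeline"] * 4})

# In the pipeline above each JAX process passes jax.process_index() /
# jax.process_count(); here we emulate host 0 of 2 hosts.
host_ds = split_dataset_by_node(ds, rank=0, world_size=2)

tok = AutoTokenizer.from_pretrained("google-t5/t5-large")
host_ds = host_ds.map(
    lambda ex: tok(ex["text"], truncation=True, max_length=16),
    batched=True,
)
print(next(iter(host_ds))["input_ids"])
```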
Review comment: I don't understand this command. Is this loading a single parquet file or all parquet files in that directory?
Reply: Yes, this is only using a single file, because using the whole dataset adds the overhead of resolving every file, which is too much for unit tests. I've added more details in this doc (go/maxtext-input-pipeline) - Issues during implementation and solutions - HF hub or local dataset?
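To make the single-file choice concrete, the hedged sketch below shows the difference: pointing `data_files` at exactly one parquet shard makes `load_dataset` resolve only that file, while a glob pattern would pull in every matching shard. The bucket path is a placeholder, not the file used in this PR's tests.

```python
from datasets import load_dataset

# Single shard: only this one file is resolved and streamed.
test_ds = load_dataset(
    "parquet",
    data_files="gs://<your-bucket>/c4/<one-shard>.parquet",  # placeholder path
    split="train",
    streaming=True,
)

# Glob: every matching shard would be resolved, which is what the unit test avoids.
# full_ds = load_dataset("parquet", data_files="gs://<your-bucket>/c4/*.parquet",
#                        split="train", streaming=True)
```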