
Support S3 data loading #729

Open
jrocmar wants to merge 3 commits into main

Conversation

jrocmar commented Mar 11, 2024

This PR adds S3 data loading to IndexedDataset: it supports loading a dataset stored in S3 in the same format as the MMapIndexedDataset. In particular, the .idx file is downloaded to a local directory at initialization so that we can memory map it, and the .bin file is streamed into memory block-by-block.

Author
jrocmar commented Mar 11, 2024

Draft implementation for #698

jrocmar marked this pull request as draft March 11, 2024 19:22
jrocmar marked this pull request as ready for review March 11, 2024 23:34
jkamalu left a comment

Left some comments. Thanks a bunch for the MR! Would like to see some consolidation in S3IndexedDataset before being merged.

try:
    response = s3_client.head_object(Bucket=parsed_s3_path.bucket, Key=parsed_s3_path.key)
    return True
except ClientError:

A ClientError will be thrown in the case of any failure, correct, even if the object exists? For example, we should differentiate between bad permissions and a missing object.

Author

Good point. I changed it to raise the exception if the error code is anything besides a 404.
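
For illustration, a minimal sketch of that pattern (the surrounding function name, object_exists, is an assumption; only the error handling is the point):

    from botocore.exceptions import ClientError

    def object_exists(s3_client, parsed_s3_path) -> bool:
        try:
            s3_client.head_object(Bucket=parsed_s3_path.bucket, Key=parsed_s3_path.key)
            return True
        except ClientError as e:
            # A missing object surfaces as a 404; anything else (e.g., bad
            # permissions) is re-raised instead of being treated as "not found".
            if e.response["Error"]["Code"] != "404":
                raise
            return False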

Comment on lines 52 to 54
assert path.startswith('s3://')
path = path[len('s3://') :]
self._bucket, self._key = path.split('/', 1)

Use globals: _S3_PREFIX and _parse_s3_path

Author

This logic is now moved into _S3BinReader. I changed it to use the parse_s3_path method.
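
A sketch of what such a parser can look like (the names _S3_PREFIX, parse_s3_path, and S3Path here are assumptions based on the comments above):

    from typing import NamedTuple

    _S3_PREFIX = "s3://"

    class S3Path(NamedTuple):
        bucket: str
        key: str

    def parse_s3_path(path: str) -> S3Path:
        """Split an "s3://bucket/key..." path into its bucket and key components."""
        assert path.startswith(_S3_PREFIX), f"not an S3 path: {path}"
        bucket, _, key = path[len(_S3_PREFIX):].partition("/")
        return S3Path(bucket=bucket, key=key)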

return False


class _S3Agent:

This class, though private, should have arg and return typing for all functions

Author

This logic is moved to _S3BinReader. Added arg and return typing for all functions.

self._client.close()


def _download_file(s3_client, s3_path, local_path):

Needs typing and a docstring

Author

Done.
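
For context, a minimal sketch of what the typed, documented helper could look like (parse_s3_path is the assumed helper sketched earlier):

    def _download_file(s3_client, s3_path: str, local_path: str) -> None:
        """Download the object at the s3:// path `s3_path` to `local_path`."""
        parsed_s3_path = parse_s3_path(s3_path)
        s3_client.download_file(parsed_s3_path.bucket, parsed_s3_path.key, local_path)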

s3_client.download_file(parsed_s3_path.bucket, parsed_s3_path.key, local_path)


def _maybe_download_file(s3_path, local_path):

Needs typing and a docstring, in particular a description of when "maybe" means no and when it means yes

Author

Done.

Comment on lines 145 to 146
if not os.path.exists(local_path):
    _download_file(s3_client, s3_path, local_path)

Will this redownload on all local rank 0s?

Author
jrocmar Mar 20, 2024

No, because local_path will already exist on local rank 0s.
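
A rough sketch of the staged download this thread describes, assuming torch.distributed is used for coordination and LOCAL_RANK follows the torchrun convention (the PR's actual implementation may differ):

    import os

    import boto3
    import torch

    def _maybe_download_file(s3_path: str, local_path: str) -> None:
        """Download `s3_path` to `local_path` unless a file is already there."""
        dist_ok = torch.distributed.is_available() and torch.distributed.is_initialized()
        rank = torch.distributed.get_rank() if dist_ok else 0
        local_rank = int(os.environ.get("LOCAL_RANK", 0)) if dist_ok else 0

        s3_client = boto3.client("s3")

        # Stage 1: global rank 0 downloads first (covers a shared file system).
        if rank == 0 and not os.path.exists(local_path):
            _download_file(s3_client, s3_path, local_path)
        if dist_ok:
            torch.distributed.barrier()

        # Stage 2: if local_path is node-local, each node's local rank 0 downloads.
        # On rank 0's node the file already exists by now, so nothing is redownloaded.
        if local_rank == 0 and not os.path.exists(local_path):
            _download_file(s3_client, s3_path, local_path)
        if dist_ok:
            torch.distributed.barrier()

        assert os.path.exists(local_path)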

assert os.path.exists(local_path)


class S3IndexedDataset(torch.utils.data.Dataset):

Can we make this class inherit from MMapIndexedDataset, keeping the length, get-item, and property methods, but overloading the S3-specific functions?

__setstate__ and __getstate__ could even be inherited if state were presumed defined in __init__

Author

After some consideration and after seeing this refactor, it seemed like the best way to integrate S3 data loading into IndexedDataset was to introduce an abstract _BinReader class and concrete _MMapBinReader, _FileBinReader and _S3BinReader subclasses, but let me know your thoughts on the design. I'm open to feedback.
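
For reference, a rough sketch of the shape this refactor suggests; the method signature and class internals here are illustrative guesses, not the exact code in the PR:

    from abc import ABC, abstractmethod

    import boto3
    import numpy

    class _BinReader(ABC):
        """Reads items out of a .bin file, however that file is stored."""

        @abstractmethod
        def read(self, dtype: numpy.dtype, count: int, offset: int) -> numpy.ndarray:
            """Return `count` items of type `dtype` starting at byte `offset`."""

    class _S3BinReader(_BinReader):
        """Streams byte ranges of a .bin object from S3 on demand."""

        def __init__(self, bin_path: str) -> None:
            self._client = boto3.client("s3")
            self._path = parse_s3_path(bin_path)  # assumed helper sketched earlier

        def read(self, dtype: numpy.dtype, count: int, offset: int) -> numpy.ndarray:
            nbytes = count * numpy.dtype(dtype).itemsize
            # Ranged GET: only the requested block is transferred.
            response = self._client.get_object(
                Bucket=self._path.bucket,
                Key=self._path.key,
                Range=f"bytes={offset}-{offset + nbytes - 1}",
            )
            return numpy.frombuffer(response["Body"].read(), dtype=dtype, count=count)

A _FileBinReader and _MMapBinReader would implement the same read interface over a file pointer and a memory map, respectively, so IndexedDataset itself only depends on _BinReader.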

def mock_s3_client(*args, **kwargs):
    return MOCK_S3_CLIENT

monkeypatch.setattr("boto3.client", mock_s3_client)

Why would the following not work?

monkeypatch.setattr("boto3.client", MockS3Client)

Author

It doesn't work, because it creates a new MockS3Client each time boto3.client is called. Instead, I want the state (i.e., self._data) to be maintained across calls to boto3.client, because it reflects the state of the mock S3.
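
A minimal sketch of the fixture pattern being described (MockS3Client and its _data attribute are assumptions based on this thread):

    import pytest

    class MockS3Client:
        def __init__(self):
            # Maps (bucket, key) -> bytes; this is the state that has to persist
            # across calls to boto3.client for the mock to behave like one S3 service.
            self._data = {}

    # One shared instance stands in for the S3 service across the whole test.
    MOCK_S3_CLIENT = MockS3Client()

    @pytest.fixture
    def s3_client(monkeypatch):
        def mock_s3_client(*args, **kwargs):
            # Always return the same instance, so objects uploaded through one
            # boto3.client(...) call are visible to later calls.
            return MOCK_S3_CLIENT

        monkeypatch.setattr("boto3.client", mock_s3_client)
        return MOCK_S3_CLIENT

Patching boto3.client with MockS3Client itself would instead construct a fresh, empty client on every call, which is why the indirection through a module-level instance is needed.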

s3_client.upload_file(path, bucket_name, path[1:])
assert path_to_data.startswith("/")
prefix_for_path_prefix = "s3://" + bucket_name
IndexedDataset = S3IndexedDataset

What purpose does this line serve?

Author

Removed after the refactor.

@@ -67,7 +71,7 @@ def merge_datasets(idir):
merge_main()


def do_test_preprocess_data(temp_dir, extra_args=[]):
def do_test_preprocess_data(temp_dir, extra_args=[], indexed_dataset_type=MMapIndexedDataset):

No need for a default, let's make it a positional arg

Author

This argument is now a boolean called s3 that defaults to False. Do you still want to make the new argument a positional one?

Jake Marcus added 3 commits March 20, 2024 09:11
This commit introduces the S3IndexedDataset, which supports
loading a dataset stored in S3 in the same format as the
MMapIndexedDataset. In particular, the .idx file is downloaded
to a local directory at initialization so that we can memory map
it and the .bin file is streamed into memory block-by-block.
This commit removes the S3IndexedDataset class and integrates
its functionality into the IndexedDataset class. It does so by
introducing an abstract _BinReader class. Instead of
IndexedDataset switching between file pointers and memory
mapping using if/else statements, the class switches between
file pointers, memory mapping and S3 data loading based on the
specific _BinReader used (i.e., _FileBinReader, _MMapBinReader,
and _S3BinReader).
Author
jrocmar commented Mar 20, 2024

> Left some comments. Thanks a bunch for the MR! Would like to see some consolidation in S3IndexedDataset before being merged.

@jkamalu Thank you for the comments! Based on those comments and after seeing this refactor, I refactored S3 data loading to integrate it into IndexedDataset and address your feedback.

Author
jrocmar commented Mar 20, 2024

@jkamalu The _S3BinReader will have poor performance when using a global random shuffle over samples (which is what GPTDataset currently does). I need to either implement "block shuffling" in GPTDataset as described in the "Example" section here (that section also describes why _S3BinReader will have poor performance) or I need to add an option to disable shuffling in GPTDataset (the user then has to be responsible for preshuffling their data). I'm inclined to just add the option to disable shuffling to start, because it's simpler. What do you think?

Author
jrocmar commented Apr 1, 2024

> @jkamalu The _S3BinReader will have poor performance when using a global random shuffle over samples (which is what GPTDataset currently does). I need to either implement "block shuffling" in GPTDataset as described in the "Example" section here (that section also describes why _S3BinReader will have poor performance) or I need to add an option to disable shuffling in GPTDataset (the user then has to be responsible for preshuffling their data). I'm inclined to just add the option to disable shuffling to start, because it's simpler. What do you think?

Moving the "Example" section from the old NeMo PR into this comment.

In NeMo, a sample consists of seq_length tokens. For simplicity, suppose each token is 1 byte and seq_length is 100.

Each sample then takes 100 bytes.

Suppose we have a dataset with 12 samples.

Sample index 0 is stored in bytes [0, 100), sample index 1 is stored in bytes [100, 200), ..., and sample index 11 is stored in bytes [1100, 1200).

Currently, NeMo takes the list of sample indices:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

And produces a shuffle_idx, which is just a permutation of those sample indices like:

[11, 3, 0, 6, 1, 9, 10, 2, 5, 7, 4, 8]

The shuffle_idx determines the order in which NeMo processes samples.

We could have the IndexedDataset just grab the bytes for one sample at a time. The first request would be for the bytes [1100, 1200), the second request would be for the bytes [300, 400), the third request would be for the bytes [0, 100) and so on in the order determined by shuffle_idx. That works, but it's slow, because you're making one request for each sample.

Let's try to introduce an in-memory cache. In particular, suppose the IndexedDataset does this:

  • If the requested bytes range [start, end) is in the cache, then extract the requested bytes range from the cache.
  • Otherwise, first refresh the cache by downloading the bytes range [start, start + cache_nbytes) and then extract the requested bytes range from the cache.

Suppose the cache_nbytes is 400. The first request would be for the bytes [1100, 1200). The cache is initially empty, so we refresh the cache by downloading the bytes [1100, 1500) and then extract the requested bytes range from the cache. The second request would be for the bytes [300, 400). Those bytes are not in the cache, so we refresh the cache by downloading the bytes [300, 700) and then extract the requested bytes range from that cache. And so on.

We actually made the problem worse. For most samples, we have to refresh the cache, so we have not reduced the number of requests much. We've just made the requests have to download a larger number of bytes. The issue is that the bytes needed for a sample index are probably not next to the bytes needed for the previous sample index.

To use the cache effectively, we have to introduce some correlation in the shuffle. In particular, we divide the original list of sample indices into blocks like:

  • [0, 1, 2, 3]
  • [4, 5, 6, 7]
  • [8, 9, 10, 11]

We then shuffle within the blocks like:

  • [3, 0, 2, 1]
  • [4, 6, 5, 7]
  • [11, 10, 8, 9]

We then shuffle the order of the blocks like:

  • [11, 10, 8, 9]
  • [4, 6, 5, 7]
  • [3, 0, 2, 1]

And we construct the block-shuffled shuffle_idx like:

[11, 10, 8, 9, 4, 6, 5, 7, 3, 0, 2, 1]

We also have to change which bytes we download on a cache miss. In particular, we download the bytes [cache_start, cache_start + cache_nbytes), where cache_start is (start//cache_nbytes) * cache_nbytes.

The first request would be for the bytes [1100, 1200). The cache is initially empty, so we refresh the cache by downloading the bytes [800, 1200) and then extract the requested bytes range from that cache. The second request would be for the bytes [1000, 1100). We extract those bytes from the cache. The third request would be for the bytes [800, 900). We extract those bytes from the cache. And so on. In this way, we only have to refresh the cache at the start of each new block.
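
A compact sketch of the block shuffle and the block-aligned cache refresh described above (block_size and cache_nbytes are illustrative parameter names, not names from the PR):

    import numpy

    def build_block_shuffle_idx(num_samples: int, block_size: int, seed: int) -> numpy.ndarray:
        """Shuffle sample indices within fixed-size blocks, then shuffle the block order."""
        rng = numpy.random.RandomState(seed)
        blocks = [
            rng.permutation(numpy.arange(start, min(start + block_size, num_samples)))
            for start in range(0, num_samples, block_size)
        ]
        block_order = rng.permutation(len(blocks))
        return numpy.concatenate([blocks[i] for i in block_order])

    def cache_window(start: int, cache_nbytes: int) -> tuple:
        """Byte range to download on a cache miss, aligned to cache_nbytes boundaries."""
        cache_start = (start // cache_nbytes) * cache_nbytes
        return cache_start, cache_start + cache_nbytes

    # With 12 samples, blocks of 4, and a 400-byte cache (100 bytes per sample), each
    # shuffled block maps onto exactly one 400-byte window, so the cache is only
    # refreshed at the start of each new block.
    shuffle_idx = build_block_shuffle_idx(num_samples=12, block_size=4, seed=0)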
