Support S3 data loading #729
base: main
Conversation
Draft implementation for #698
Left some comments. Thanks a bunch for the MR! I would like to see some consolidation in S3IndexedDataset before this is merged.
try:
    response = s3_client.head_object(Bucket=parsed_s3_path.bucket, Key=parsed_s3_path.key)
    return True
except ClientError:
A ClientError will be thrown in the case of any failure, correct, even if the object exists? For example, we should differentiate between bad permissions and a missing object.
Good point. I changed it to raise the exception if the error code is anything besides a 404.
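For reference, a minimal sketch of that behavior (the helper name `object_exists` and its signature are assumptions, not necessarily what the PR uses):

```python
from botocore.exceptions import ClientError

def object_exists(s3_client, bucket: str, key: str) -> bool:
    """Return True if s3://bucket/key exists; re-raise any non-404 error."""
    try:
        s3_client.head_object(Bucket=bucket, Key=key)
        return True
    except ClientError as e:
        # A 404 means the object is missing; anything else (e.g., a 403
        # from bad permissions) is a real failure and should propagate.
        if e.response["Error"]["Code"] == "404":
            return False
        raise
```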
assert path.startswith('s3://')
path = path[len('s3://') :]
self._bucket, self._key = path.split('/', 1)
Use the globals _S3_PREFIX and _parse_s3_path.
This logic is now moved into _S3BinReader. I changed it to use the parse_s3_path method.
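For context, a sketch of what those globals might look like (the `_ParsedS3Path` container is an assumption; only `_S3_PREFIX` and `_parse_s3_path` are named in the review):

```python
from typing import NamedTuple

_S3_PREFIX = "s3://"

class _ParsedS3Path(NamedTuple):
    bucket: str
    key: str

def _parse_s3_path(path: str) -> _ParsedS3Path:
    """Split 's3://bucket/key/with/slashes' into its bucket and key."""
    assert path.startswith(_S3_PREFIX)
    bucket, _, key = path[len(_S3_PREFIX):].partition("/")
    return _ParsedS3Path(bucket=bucket, key=key)
```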
return False
class _S3Agent: |
This class, though private, should have arg and return typing for all functions
This logic is moved to _S3BinReader. Added arg and return typing for all functions.
self._client.close()
def _download_file(s3_client, s3_path, local_path):
Needs typing and a docstring.
Done.
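A sketch of what the typed, documented version might look like (the loose `Any` hint for the boto3 client is a simplification, and the directory creation is an assumption):

```python
import os
from typing import Any

def _download_file(s3_client: Any, s3_path: str, local_path: str) -> None:
    """Download the S3 object at `s3_path` to `local_path` on local disk."""
    # Assumption: the destination directory may not exist yet.
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    parsed_s3_path = _parse_s3_path(s3_path)
    s3_client.download_file(parsed_s3_path.bucket, parsed_s3_path.key, local_path)
```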
s3_client.download_file(parsed_s3_path.bucket, parsed_s3_path.key, local_path)
def _maybe_download_file(s3_path, local_path):
Needs typing and a docstring, in particular a description of when "maybe" means no and when it means yes.
Done.
if not os.path.exists(local_path):
    _download_file(s3_client, s3_path, local_path)
Will this redownload on all local rank 0s?
No, because local_path will already exist on local rank 0s.
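A sketch of how that plays out; the rank bookkeeping is an assumption about the surrounding setup (one process per GPU, torch.distributed initialized), since the source only shows the existence check:

```python
import os
import boto3
import torch

def _maybe_download_file(s3_path: str, local_path: str) -> None:
    """Download `s3_path` to `local_path`, unless the file already exists.

    Only local rank 0 on each node downloads; all ranks then wait on a
    barrier and assert the file is present before proceeding.
    """
    if torch.distributed.is_initialized():
        # Assumption: one process per GPU, so local rank is rank mod GPUs.
        local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    else:
        local_rank = 0
    if local_rank == 0 and not os.path.exists(local_path):
        s3_client = boto3.client("s3")
        _download_file(s3_client, s3_path, local_path)
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    # "Maybe" resolved: either this rank downloaded the file, another rank
    # did, or it was already on disk from a previous run.
    assert os.path.exists(local_path)
```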
assert os.path.exists(local_path)
class S3IndexedDataset(torch.utils.data.Dataset):
Can we make this class inherit from MMapIndexedDataset, keeping the length, get-item, and properties, but overriding the S3-specific functions?
__setstate__ and __getstate__ could even be inherited if the state were presumed defined in __init__.
After some consideration and after seeing this refactor, it seemed like the best way to integrate S3 data loading into IndexedDataset was to introduce an abstract _BinReader class and concrete _MMapBinReader, _FileBinReader and _S3BinReader subclasses, but let me know your thoughts on the design. I'm open to feedback.
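A sketch of the shape of that abstraction (the `read` signature is a best guess from the surrounding discussion, not the PR's exact interface):

```python
from abc import ABC, abstractmethod

import numpy

class _BinReader(ABC):
    """Reads items of a given dtype from the .bin file at a byte offset."""

    @abstractmethod
    def read(self, dtype: numpy.dtype, count: int, offset: int) -> numpy.ndarray:
        """Return `count` items of `dtype` starting at byte `offset`."""

class _FileBinReader(_BinReader):
    """Opens the local .bin file and seeks on every read."""

    def __init__(self, bin_path: str) -> None:
        self._bin_path = bin_path

    def read(self, dtype: numpy.dtype, count: int, offset: int) -> numpy.ndarray:
        with open(self._bin_path, mode="rb") as f:
            f.seek(offset)
            return numpy.fromfile(f, dtype=dtype, count=count)

# _MMapBinReader would slice a numpy memmap of the local .bin file, and
# _S3BinReader would download byte ranges from S3, caching a block around
# each miss; both implement the same `read` interface (bodies elided).
```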
def mock_s3_client(*args, **kwargs):
    return MOCK_S3_CLIENT

monkeypatch.setattr("boto3.client", mock_s3_client)
Why would the following not work?
monkeypatch.setattr("boto3.client", MockS3Client)
It doesn't work, because it creates a new MockS3Client each time boto3.client is called. Instead, I want the state (i.e., self._data) to be maintained across calls to boto3.client, because it reflects the state of the mock S3.
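A sketch of the difference being discussed; the class and fixture names follow the snippet above, and the `_data` contents are an assumption about what state the mock accumulates:

```python
class MockS3Client:
    def __init__(self):
        # Stands in for the contents of S3: maps (bucket, key) -> bytes.
        self._data = {}

# One shared instance, so uploads in one call are visible to later calls.
MOCK_S3_CLIENT = MockS3Client()

def mock_s3_client(*args, **kwargs):
    return MOCK_S3_CLIENT

# By contrast, monkeypatch.setattr("boto3.client", MockS3Client) would
# construct a fresh, empty MockS3Client on every boto3.client(...) call,
# discarding any state built up by earlier calls.
```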
s3_client.upload_file(path, bucket_name, path[1:])
assert path_to_data.startswith("/")
prefix_for_path_prefix = "s3://" + bucket_name
IndexedDataset = S3IndexedDataset
What purpose does this line serve?
Removed after the refactor.
@@ -67,7 +71,7 @@ def merge_datasets(idir):
    merge_main()

-def do_test_preprocess_data(temp_dir, extra_args=[]):
+def do_test_preprocess_data(temp_dir, extra_args=[], indexed_dataset_type=MMapIndexedDataset):
No need for a default; let's make it a positional arg.
This argument is now a boolean called s3 that defaults to False. Do you still want to make the new argument a positional one?
This commit introduces the S3IndexedDataset, which supports loading a dataset stored in S3 in the same format as the MMapIndexedDataset. In particular, the .idx file is downloaded to a local directory at initialization so that we can memory-map it, and the .bin file is streamed into memory block by block.
This commit removes the S3IndexedDataset class and integrates its functionality into the IndexedDataset class. It does so by introducing an abstract _BinReader class. Instead of IndexedDataset switching between file pointers and memory mapping using if/else statements, the class switches between file pointers, memory mapping and S3 data loading based on the specific _BinReader used (i.e., _FileBinReader, _MMapBinReader, and _S3BinReader).
@jkamalu Thank you for the comments! Based on those comments and after seeing this refactor, I refactored S3 data loading to integrate it into IndexedDataset and address your feedback.
@jkamalu The _S3BinReader will have poor performance when using a global random shuffle over samples (which is what GPTDataset currently does). I need either to implement "block shuffling" in GPTDataset as described in the "Example" section here (that section also explains why _S3BinReader performs poorly under a global shuffle) or to add an option to disable shuffling in GPTDataset (the user is then responsible for preshuffling their data). I'm inclined to start by just adding the option to disable shuffling, because it's simpler. What do you think?
Moving the "Example" section from the old NeMo PR into this comment.

In NeMo, a sample consists of a fixed number of tokens; suppose each sample takes 100 bytes, and suppose we have a dataset with 12 samples. Sample index 0 is stored in bytes [0, 100), sample index 1 is stored in bytes [100, 200), ..., and sample index 11 is stored in bytes [1100, 1200).

Currently, NeMo takes the list of sample indices:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

And produces a shuffled version of it:

[11, 3, 0, 6, 1, 9, 10, 2, 5, 7, 4, 8]

The samples are then requested in this shuffled order, so consecutive reads land on byte ranges scattered across the .bin file. We could have the reader issue one S3 request per sample, but a separate request for every 100 bytes is far too slow.

Let's try to introduce an in-memory cache. In particular, suppose the IndexedDataset does this: when the bytes for a sample are requested, it first checks whether they are already in the cache; on a miss, it refreshes the cache by downloading a larger fixed-size range starting at the requested offset and then extracts the requested bytes from the cache.

Suppose the cache size is 400 bytes. We actually made the problem worse. For most samples, we have to refresh the cache, so we have not reduced the number of requests much. We've just made the requests have to download a larger number of bytes. The issue is that the bytes needed for a sample index are probably not next to the bytes needed for the previous sample index.

To use the cache effectively, we have to introduce some correlation in the shuffle. In particular, we divide the original list of sample indices into blocks like:

[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]

We then shuffle within the blocks like:

[3, 0, 2, 1], [4, 6, 5, 7], [11, 10, 8, 9]

We then shuffle the order of the blocks like:

[11, 10, 8, 9], [4, 6, 5, 7], [3, 0, 2, 1]

And we construct the block-shuffled list of sample indices:

[11, 10, 8, 9, 4, 6, 5, 7, 3, 0, 2, 1]

We also have to change which bytes we download on a cache miss. In particular, we download the bytes [(offset // cache_size) * cache_size, (offset // cache_size) * cache_size + cache_size), where offset is the starting byte of the requested sample.

The first request would be for the bytes [1100, 1200). The cache is initially empty, so we refresh the cache by downloading the bytes [800, 1200) and then extract the requested byte range from that cache. The second request would be for the bytes [1000, 1100). We extract those bytes from the cache. The third request would be for the bytes [800, 900). We extract those bytes from the cache. And so on. In this way, we only have to refresh the cache at the start of each new block.
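For concreteness, a sketch of the block shuffle and the aligned cache-miss download described above, using the numbers from the example (both function names are mine, not the PR's):

```python
import numpy

def build_block_shuffle_idx(num_samples: int, block_size: int,
                            rng: numpy.random.Generator) -> numpy.ndarray:
    """Shuffle within fixed-size blocks, then shuffle the block order."""
    idx = numpy.arange(num_samples)
    blocks = [idx[i:i + block_size] for i in range(0, num_samples, block_size)]
    for block in blocks:
        rng.shuffle(block)  # e.g., [8, 9, 10, 11] -> [11, 10, 8, 9]
    rng.shuffle(blocks)     # e.g., move the last block to the front
    return numpy.concatenate(blocks)

def cache_miss_range(offset: int, cache_nbytes: int) -> tuple[int, int]:
    """Return the aligned byte range to download on a cache miss."""
    start = (offset // cache_nbytes) * cache_nbytes
    return start, start + cache_nbytes

# With 100-byte samples and a 400-byte cache, the first request in the
# example (sample 11, bytes [1100, 1200)) misses the empty cache and
# downloads cache_miss_range(1100, 400) == (800, 1200).
```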
This PR introduces S3 data loading to IndexedDataset. In particular, the .idx file is downloaded to a local directory at initialization so that we can memory-map it, and the .bin file is streamed into memory block by block.