Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support S3 checkpointing for the torch strategy in distributed checkpointing #748

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jrocmar
Copy link

@jrocmar jrocmar commented Mar 22, 2024

This PR adds support for saving checkpoints to cloud storage (e.g., S3) and loading checkpoints from cloud storage for the torch strategy in distributed checkpointing. It does so by replacing pathlib.Path with cloudpathlib.AnyPath, FileSystemReader with FsspecSystemReader, and FileSytemWriter with FsspecSystemWriter. The PR enables cloud checkpointing, but makes little attempt to optimize it.

This commit adds support for saving checkpoints to cloud storage
(e.g., S3) and loading checkpoints from cloud storage for the
torch strategy in distributed checkpointing. It does so by
replacing pathlib.Path with cloudpathlib.AnyPath, FileSystemReader
with FsspecSystemReader, and FileSytemWriter with
FsspecSystemWriter.

The commit enables cloud checkpointing, but makes little attempt
to optimize it.
@jrocmar
Copy link
Author

jrocmar commented Mar 25, 2024

@mikolajblaz Would you be the right one to review this? Thanks in advance!

Copy link
Contributor

@mikolajblaz mikolajblaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good in general, thanks.

Non-optimized cloud upload is understandable, but I want to make sure that the baseline scenario does not regress in terms of performance (only worried about 2-stage torch.save) and usability (added cloudpathlib requirement for MCore).

@@ -345,7 +343,7 @@ class TorchDistSaveShardedStrategy(SaveShardedStrategy):
"""

def __init__(
self, backend: str, version: int, keep_only_main_replica: bool = True, thread_count: int = 2
self, backend: str, version: int, keep_only_main_replica: bool = True, thread_count: int = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 1 as a default?

@@ -382,27 +380,27 @@ def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
# Use PyT saving mechanism
save_state_dict(
pyt_state_dict,
FileSystemWriter(checkpoint_dir, thread_count=self.thread_count),
FsspecWriter(path=checkpoint_dir, thread_count=self.thread_count),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This need to be configurable, because I think Fsspec is slower than FilesystemWriter. Also, please note that there is a change on main which uses a custom save_state_dict and a custom FileSystemAsyncWriter for performance.

@@ -3,8 +3,9 @@
""" Module for managing distributed checkpoints metadata. """

import json
import os
from cloudpathlib import AnyPath
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume it requires extra dependencies.
Can we make it optional?

@@ -57,6 +59,22 @@
logger = logging.getLogger(__name__)


def _save(obj: object, path: AnyPath):
fileobj = io.BytesIO()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried 2-stage writing might affect performance in the baseline scenario (no cloud).
Can we make it optional as well?

@jrocmar
Copy link
Author

jrocmar commented Apr 15, 2024

Looks good in general, thanks.

Non-optimized cloud upload is understandable, but I want to make sure that the baseline scenario does not regress in terms of performance (only worried about 2-stage torch.save) and usability (added cloudpathlib requirement for MCore).

Thanks for the review! Quick question: do you have a suite of benchmarks that you run to detect performance regressions? If not, I can run some of my own benchmarks and report back.

@mikolajblaz
Copy link
Contributor

Looks good in general, thanks.
Non-optimized cloud upload is understandable, but I want to make sure that the baseline scenario does not regress in terms of performance (only worried about 2-stage torch.save) and usability (added cloudpathlib requirement for MCore).

Thanks for the review! Quick question: do you have a suite of benchmarks that you run to detect performance regressions? If not, I can run some of my own benchmarks and report back.

Please run your own benchmarks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants