
Add CUDA Graph and AOT Autograd support #1271

Open · wants to merge 5 commits into base: main

Conversation

@xwang233 (Contributor) commented May 24, 2022

Add CUDA Graph support with --cuda-graph and AOT Autograd support with --aot-autograd to benchmark.py and train.py

The workflow for cuda graph in train.py might be a bit overcomplicated.

Related: #1244
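For illustration, invocations using the new flags might look roughly like this (a sketch; the model name, data path, and other arguments are placeholders, and the `--fuser nvfuser` pairing follows the recommendation in the `--aot-autograd` help text shown later in the review):

```
# hypothetical examples; most arguments elided
python benchmark.py --model resnet50 --cuda-graph
python train.py /path/to/imagenet --model resnet50 --aot-autograd --fuser nvfuser
```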

@xwang233 (Contributor, Author)

I'm still working on extra benchmarks and accuracy tests for the new options at the moment.

@csarofeen

FYI I intend to review (can't set myself as a reviewer).

@rwightman (Collaborator)

> FYI I intend to review (can't set myself as a reviewer)

Seems I can't add you as a formal reviewer either; it might require the reviewer to be added as a collaborator. Hmm, I thought only read-only access was needed...


if not args.distributed:
    losses_m.update(loss.item(), input.size(0))

torch.cuda.synchronize()

In order to amortize CPU overhead by allowing the CPU to run ahead during the previous step, I wanted to suggest not syncing on every step, and instead recording the time at the beginning and end of the run and dividing by the number of steps to get an average step time.

Agreed, this is typically a better approach and more representative of the user experience.
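For reference, a minimal sketch of the whole-run timing approach (`num_steps` and the `_step()` helper are illustrative names, not the PR's exact code):

```python
import time
import torch

# Let the CPU queue work ahead; only synchronize at the boundaries of the timed region.
torch.cuda.synchronize()
t_start = time.perf_counter()
for _ in range(num_steps):
    _step()                      # no per-step torch.cuda.synchronize() here
torch.cuda.synchronize()         # drain all queued GPU work before stopping the clock
elapsed = time.perf_counter() - t_start
step_time = elapsed / num_steps  # average time per step
```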

@csarofeen left a comment

Looks pretty good, some comments. Not 100% sure we want the complexity of CUDA graphs in training, as it won't necessarily give benefits unless running inference-sized batch sizes per GPU (<16 is probably where you'll start seeing benefit).

Have you run some correctness testing with AOTAutograd with and without CUDA Graphs?

parser.add_argument('--cuda-graph', default=False, action='store_true',
                    help="Enable CUDA Graph support")
parser.add_argument('--aot-autograd', default=False, action='store_true',
                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")

Didn't they change it so AOTAutograd defaults to nvFuser now? Can you add an assert in the script to make sure torchscript and aot-autograd are not on at the same time, just to give a cleaner error message stating not to use the two options together?

@xwang233 (Contributor, Author)

I think torchdynamo with aot_autograd_speedup_strategy has nvfuser enabled by default now, but not memory_efficient_fusion? Not 100% sure.

Will add a mutually exclusive check on torchscript and aot_autograd.

Okay, no worries, checking they're mutually exclusive would be great. Thanks.
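A minimal sketch of such a check, assuming the flags land on `args.torchscript` and `args.aot_autograd` (argparse's default dest naming for the options shown in this review):

```python
# Fail early with a clear message instead of letting the two graph modes collide downstream.
assert not (args.torchscript and args.aot_autograd), \
    "Cannot use --torchscript and --aot-autograd at the same time"
```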



@@ -265,12 +279,28 @@ def _step():
for _ in range(self.num_warm_iter):
    _step()

if self.cuda_graph:

Can you add some comments simply mentioning that we need to sync the stream, do warm-up iterations, then capture the graph?
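For context, the flow being described looks roughly like this (a sketch using the PyTorch CUDA Graphs API; `_step()` and `self.num_warm_iter` are names from the surrounding benchmark code, `g` is illustrative):

```python
# Warm up on a side stream, then capture one step into a CUDA graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())   # order the side stream after prior work
with torch.cuda.stream(s):
    for _ in range(self.num_warm_iter):      # warm-up iterations before capture
        _step()
torch.cuda.current_stream().wait_stream(s)   # sync back to the default stream

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):                    # record a single step
    _step()
# afterwards, g.replay() re-runs the captured work on the same static tensors
```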

@@ -367,7 +397,21 @@ def _step(detail=False):
for _ in range(self.num_warm_iter):
    _step()

t_run_start = self.time_fn()
if self.cuda_graph:

CUDA Graphs can be helpful in training, but since larger batches are often used there isn't typically a perf gain from using them. @rwightman do you think we should just keep CUDA Graphs for inference and leave them out of training? I don't suspect you'd see much of a tangible benefit above something like batch size 8.

@@ -17,7 +17,7 @@
class ApexScaler:

@rwightman do you mind if we remove the Apex variant of AMP and just use native?

delta_fwd = _step()
total_step += delta_fwd
if self.cuda_graph:
    g.replay()

This isn't 100% representative of inference with CUDA Graphs, as we're not including the copy of new inputs into the input buffer. I'm wondering if we want to confer with @mcarilli to see if we can come up with a reasonable async mechanism to copy the data in, e.g. something like keeping two CUDA Graphs going at the same time, so we can async-copy an input into one graph while running the other graph and just ping-pong back and forth.

This would be important to do if we're actually doing inference and wanting the highest perf possible. For benchmarking we could just keep it simple with running the same inputs over and over.

@xwang233 (Contributor, Author)

We can add input copying like static_inputs.copy_(self.example_inputs) before every replay.
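For reference, a minimal sketch of that pattern (`loader` is an illustrative source of new batches; `static_inputs` is the tensor the graph was captured with):

```python
# Refresh the graph's static input buffer in place, then replay the captured work.
for batch in loader:
    static_inputs.copy_(batch)   # copy runs on the current stream, serialized with the replay
    g.replay()
```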

That'll really slow down the benchmarks as it will be 100% serialized. I'm hoping @mcarilli could suggest a mechanism (though maybe somewhat complex) to be able to overlap it. If we really want to, keeping two CUDA Graphs around to ping-pong between would probably work fine if we throw them in different streams. Not an absolute must, but I wanted to make sure we mentioned this added complexity.

@@ -682,7 +745,9 @@ def main():
 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
-        loss_scaler=None, model_ema=None, mixup_fn=None):
+        loss_scaler=None, model_ema=None, mixup_fn=None,
+        cuda_graph=None, cg_stage=None,

Can you add some comments about the newly added arguments? Explicitly what they are and why they're important. You can reference the CUDA Graphs docs if that makes it easier.

    loss = _step(input, target)
    torch.cuda.current_stream().wait_stream(s)
    return (input, target, loss)
elif cg_stage == 'capture':

Again some comments in here would be helpful.

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(11):

Any particular reason for 11? In benchmarking you're doing 3 warmups.

@xwang233 (Contributor, Author)

@mcarilli mentioned in the CUDA Graphs doc that cuda graph warmup with DDP needs 11 iterations instead of 3: https://pytorch.org/docs/stable/notes/cuda.html#id5 🤔

The benchmark script is single-process without DDP, but the training script may have DDP enabled.

Hah, okay no problem. If Carilli says so, Carilli says so ;-)
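For reference, the side-stream warm-up described in the linked doc looks roughly like this (a sketch, not the PR's exact code; `model`, `loss_fn`, `optimizer`, `static_input`, and `static_target` are illustrative names):

```python
# Per the PyTorch CUDA Graphs notes, warm up in a side stream before capture;
# with DDP, ~11 eager iterations are needed (3 is enough in the single-process case).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(11):
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(static_input), static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)
```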

@xwang233 (Contributor, Author) commented May 25, 2022

I have some ResNet50 AMP + channels-last training results with cuda graph and nvfuser that verify the training accuracy (loss, val acc) here: https://gist.github.com/xwang233/f3b5b4818762b08d716f969899b6d263

After 10 epochs:

V100x8, BS = 128

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 4183.51 | 43.13 |
| CUDA graph | 4141.28 | 42.66 |
| CUDA graph + nvfuser | 4180.31 | 42.74 |

A100x8, BS = 32

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 5665.04 | 59.296 |
| CUDA graph | 6630.37 | 59.2 |
| CUDA graph + nvfuser | 6672.33 | 59.274 |

TL;DR: the training accuracies are the same for eager mode, cuda graph, and cuda graph + nvfuser. CUDA graph keeps training throughput the same at large batch sizes, but gives out-of-the-box improvements at small batch sizes. For example, in the results above, ResNet50 on A100x8 with batch size 32 saw training throughput improve from ~5600 to ~6600 img/s.

I'm also checking training accuracy with aot_autograd.

@xwang233 (Contributor, Author) commented May 26, 2022

ResNet50 FP32 training results with eager, cuda graph, TS + nvfuser, and AOT_autograd + nvfuser: https://gist.github.com/xwang233/d5136facb3361af54693081da346fd33

After 10 epochs:

A100x8, BS = 128

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 6499.87 | 38.21 |
| CUDA graph | 6608.18 | 43.19 |
| TorchScript + nvfuser | 6453.39 | 38.77 |
| AOT_autograd + nvfuser | 6887.96 | 38.33 |

A100x8, BS = 32

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 5130.81 | 59.50 |
| CUDA graph | 5833.69 | 57.36 |
| TorchScript + nvfuser | 4986.46 | 59.50 |
| AOT_autograd + nvfuser | 5228.49 | 59.39 |

V100x8, BS = 64

| mode | throughput (img/s) | eval top1 |
| --- | --- | --- |
| Eager | 2573.18 | 51.29 |
| CUDA graph | 2653.16 | 53.00 |
| TorchScript + nvfuser | 2581.52 | 51.19 |
| AOT_autograd + nvfuser | 2687.08 | 51.24 |

@rwightman (Collaborator)

@csarofeen @kevinstephano @xwang233 putting a few comments down here that relate to the whole PR.
One of the reasons I haven't put time into exploring graph replay in the train script up until now is that it was clear it would be a fair bit of very specific code that will make quite a mess of the train loop and setup code...

It's nice to see it together, but I'm not sure it's worth it just yet; it really needs to be pushed into a model / task wrapper. I had a plan to work it into the bits_and_tpu branch (https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits) that I've been using for PT XLA TPU training (to be merged some day to master). There are GPU (CUDA) and XLA specific interfaces for device, distributed primitives, and optimizer / step updates... I need to further refine it to cover DeepSpeed though. I feel graph mode would make sense as a state machine within the class wrapping the optimizer / step updates (Updater).

I have to think about whether there's a way to have the graph code in this train script that better separates the extra code (even if it adds redundancy)...
