
[compile time] AOT Autograd is taking long time in tracing #125977

Open
anijain2305 opened this issue May 10, 2024 · 7 comments
Labels
high priority, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@anijain2305
Contributor

anijain2305 commented May 10, 2024

🐛 Describe the bug

profile - https://fburl.com/scuba/pyperf_experimental/on_demand/ghc8lpjr
xref - https://fb.workplace.com/groups/1075192433118967/permalink/1425136311457909/
Repro - D57090987

AOT Autograd is taking a large amount of time. This is the rough breakdown:

Finished loading.
Ran eager model.
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] TorchDynamo compilation metrics:
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] Function                                  Runtimes (s)
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] --------------------------------------  --------------
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] _compile.<locals>.compile_inner               962.355
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] OutputGraph.call_user_compiler                923.198
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] create_aot_dispatcher_function                911.898
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] compile_fx.<locals>.fw_compiler_base          377.886
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] compile_fx_inner                              375.881
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] GraphLowering.run                             186.831
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] GraphLowering.compile_to_module               129.836
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] Scheduler.__init__                             60.484
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] Scheduler.codegen                              33.0063
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] WrapperCodeGen.generate                        27.8892
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] cudagraphify                                    0.0035
I0510 16:25:43.544000 139884208832960 torch/_dynamo/utils.py:331] CachingAutotuner.benchmark_all_configs          3.9676

create_aot_dispatcher_function takes ~912 seconds, while compile_fx_inner takes ~376 seconds. So roughly 536 seconds are spent in AOT Autograd itself.
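For reference, a minimal sketch of how to dump the same per-function compile-time table on a toy model (the repro itself is internal; the helper name/signature may differ across PyTorch versions):

```python
import torch
import torch._dynamo.utils as dynamo_utils

# Hypothetical stand-in for the internal repro model.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())
x = torch.randn(8, 64)

compiled = torch.compile(model)
compiled(x).sum().backward()

# Prints the TorchDynamo compilation metrics table seen in the log above.
print(dynamo_utils.compile_times(repr="str"))
```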

To generate a Strobelight profile, use:

TORCH_COMPILE_STROBELIGHT=TRUE buck2 run ...

For more info on how to navigate the profile, see
https://fb.workplace.com/groups/257735836456307/posts/669969978566222

Error logs

No response

Minified repro

No response

Versions

N/A

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @chauhang

@ezyang
Contributor

ezyang commented May 11, 2024

This one high prio?

@anijain2305 changed the title from "[compile time] AOT Autograd is taking log time in tracing" to "[compile time] AOT Autograd is taking long time in tracing" on May 11, 2024
@anijain2305
Contributor Author

Making it high priority, mostly to get an owner to take a look and check for low-hanging fruit, if any.

@bdhirsh self-assigned this on May 14, 2024
@bdhirsh added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review label on May 14, 2024
@bdhirsh
Contributor

bdhirsh commented May 21, 2024

I ran the repro locally, switching the backend to aot_eager_decomp_partition so that only the non-Inductor parts are included.
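A minimal sketch of that kind of run, with a hypothetical stand-in model since the actual repro is internal:

```python
import torch

# Hypothetical stand-in for the internal repro model.
model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU())
x = torch.randn(32, 128)

# aot_eager_decomp_partition runs Dynamo + AOT Autograd (with decompositions
# and the partitioner) but skips Inductor codegen, so the compile time measured
# here is dominated by the AOT Autograd side.
compiled = torch.compile(model, backend="aot_eager_decomp_partition")
compiled(x).sum().backward()
```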

I found:

(1) The total time was 390s; 132s of that came from _extract_graph_with_inputs_outputs.

(2) I noticed that @aorenste recently had a PR that speeds up this function by fixing inefficient lookups of graph inputs (PR), which hadn't made it into fbcode yet.

When I patch in that PR, the total drops to 251s. Still not great, but a >50% speedup.
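For context, the class of fix involved is the classic one sketched below (illustrative only, with hypothetical names; not the actual PR): membership tests against a Python list are O(n), so doing one per graph node blows up quadratically, while building a set once makes each lookup O(1).

```python
# Illustrative sketch, not the real partitioner code.
def extract_outputs_slow(nodes, graph_inputs):
    # `graph_inputs` is a list -> every `in` check scans the whole list.
    return [n for n in nodes if n not in graph_inputs]

def extract_outputs_fast(nodes, graph_inputs):
    # One-time set construction -> O(1) membership check per node.
    input_set = set(graph_inputs)
    return [n for n in nodes if n not in input_set]
```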

I now see:

  • 74s in min_cut_rematerialization_partition

  • 129s in FunctionalTensorMode.__torch_dispatch__. This also includes everything below functionalization (fake tensor compute), although functionalization is definitely not fast here.

@ezyang
Contributor

ezyang commented May 21, 2024

Cc @Chillee partitioner

@bdhirsh
Contributor

bdhirsh commented May 21, 2024

Trying to think more about where the low-hanging fruit in FunctionalTensor.__torch_dispatch__ is:

(1) 6s: I tried making return_and_correct_aliasing a no-op (here), and it shaved ~6 seconds off the repro (251s -> 245s). It's possible that @jbschlosser's changes in #126552 will help (the fn shouldn't have to do a lot of work, just returning the right argument and swapping storages for view ops).

(2) 21s: removing custom size dispatch on FunctionalTensor here. Right now, every call to .shape goes through __torch_dispatch__ of FunctionalTensor, then FunctionalTensorWrapper, then FakeTensor. I don't... think this is actually necessary? When I changed it to None, I got a drop from 245s -> 224s.

(3) ~4s. When I print out any other metadata calls that are getting plumbed through FunctionalTensorMode here, I see many calls to prim.device.default. Looking at the stacktrace from one of them, it looks like every time we construct a fresh FunctionalStorageWrapper here, the inner tensor is a FakeTensor so we go through the full set of mode dispatch. I got the 4s just from hardcoding that value to Device(DeviceType::CUDA).

I'm not sure if (3) is worth the work immediately, but (2) definitely seems worth attempting to fix (PR incoming).
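As an aside, one rough way to see how many ops flow through a dispatch mode like this is the generic sketch below, which uses the public TorchDispatchMode hook rather than the internal FunctionalTensorMode itself:

```python
from collections import Counter

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpCounter(TorchDispatchMode):
    """Counts every ATen op that passes through __torch_dispatch__."""
    def __init__(self):
        super().__init__()
        self.counts = Counter()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.counts[str(func)] += 1
        return func(*args, **(kwargs or {}))

x = torch.randn(16, 16)
with OpCounter() as counter:
    (x @ x).relu().sum()

# Shows which ops dominate dispatch traffic under the mode.
for op, n in counter.counts.most_common(5):
    print(op, n)
```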

@Chillee
Contributor

Chillee commented May 21, 2024

How do I look at the SVG for this?

@laithsakka
Contributor

How do I look at the SVG for this?

Curious, is there anything specific about the SVG that makes it easier to navigate than the Strobelight graph profile (https://fburl.com/scuba/pyperf_experimental/on_demand/n5umzh5x) or the icicle profile (https://fburl.com/scuba/pyperf_experimental/on_demand/ghc8lpjr)?

Also: you can get an SVG-like visualization by clicking "Visualize" from the graph profiler view:
https://fburl.com/scuba/pyperf_experimental/on_demand/04q2l2de

[Screenshot 2024-05-22 at 10.10.44 AM]
