
torch.compile with a custom Triton kernel reports: RuntimeError: Inference tensors do not track version counter #125989

Open
arthursunbao opened this issue May 11, 2024 · 5 comments
Assignees: bdhirsh
Labels: module: aotdispatch · module: pt2-dispatcher · module: user triton · oncall: pt2 · triaged

Comments


arthursunbao commented May 11, 2024

🐛 Describe the bug

torch.compile with a custom Triton kernel reports: RuntimeError: Inference tensors do not track version counter.
I tried to use a Triton layer norm kernel in the ComfyUI SVD UNet to accelerate inference.
The code is as follows:

import torch
from contextlib import contextmanager
import comfy.model_management
import triton
import triton.language as tl

@triton.jit
def _layer_norm_fwd_fused(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        B,  # pointer to the biases
        Mean,  # pointer to the mean
        Rstd,  # pointer to the 1/std
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)


class disable_weight_init:
    
    class LayerNorm(torch.nn.LayerNorm):

        def triton_forward(self, x, normalized_shape, weight, bias, eps):
            # allocate output
            y = torch.empty_like(x)
            # reshape input data into 2D tensor
            x_arg = x.reshape(-1, x.shape[-1])
            M, N = x_arg.shape
            mean = torch.empty((M,), dtype=torch.float32, device='cuda')
            rstd = torch.empty((M,), dtype=torch.float32, device='cuda')
            # Less than 64KB per feature: enqueue fused kernel
            MAX_FUSED_SIZE = 65536 // x.element_size()
            BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
            if N > BLOCK_SIZE:
                raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
            # heuristics for number of warps
            num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
            # enqueue kernel
            _layer_norm_fwd_fused[(M,)](  #
                x_arg, y, weight, bias, mean, rstd,  #
                x_arg.stride(0), N, eps,  #
                BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
            return y

        def layernorm_forward_wrapper(self, input):
            return self.triton_forward(input, self.normalized_shape, self.weight, self.bias, self.eps)

        def forward(self, *args, **kwargs):
            return self.layernorm_forward_wrapper(*args, **kwargs)
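
For reference, here is a minimal sketch of how this failure can be triggered. This is a reconstruction rather than code from the ComfyUI stack; the torch.inference_mode() context is an assumption based on the error message, since only tensors created under inference mode are "inference tensors":

import torch

# assumes the disable_weight_init.LayerNorm class defined above
ln = disable_weight_init.LayerNorm(512, device='cuda')
compiled_ln = torch.compile(ln)

with torch.inference_mode():  # assumption: sampling runs under inference mode
    x = torch.randn(8, 512, device='cuda')
    # fails while compiling with:
    # RuntimeError: Inference tensors do not track version counter
    y = compiled_ln(x)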

Error logs

!!! Exception during processing !!!
Traceback (most recent call last):
File "/data/ComfyUI/execution.py", line 151, in recursive_execute
output_data, output_ui = get_output_data(obj, input_data_all)
File "/data/ComfyUI/execution.py", line 81, in get_output_data
return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
File "/data/ComfyUI/execution.py", line 74, in map_node_over_list
results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
File "/data/ComfyUI/nodes.py", line 1384, in sample
return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
File "/data/ComfyUI/nodes.py", line 1350, in common_ksampler
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
File "/data/ComfyUI/custom_nodes/ComfyUI-AnimateDiff-Evolved/animatediff/sampling.py", line 248, in motion_sample
return orig_comfy_sample(model, noise, *args, **kwargs)
File "/data/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 22, in informative_sample
raise e
File "/data/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 9, in informative_sample
return original_sample(*args, **kwargs) # This code helps interpret error messages that occur within exceptions but does not have any impact on other operations.
File "/data/ComfyUI/comfy/sample.py", line 100, in sample
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
File "/data/ComfyUI/comfy/samplers.py", line 754, in sample
return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
File "/data/ComfyUI/comfy/samplers.py", line 659, in sample
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
File "/data/ComfyUI/comfy/samplers.py", line 598, in sample
samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/data/ComfyUI/comfy/k_diffusion/sampling.py", line 580, in sample_dpmpp_2m
denoised = model(x, sigmas[i] * s_in, **extra_args)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/ComfyUI/comfy/samplers.py", line 336, in forward
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/ComfyUI/comfy/samplers.py", line 323, in forward
return self.apply_model(*args, **kwargs)
File "/data/ComfyUI/comfy/samplers.py", line 320, in apply_model
out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
File "/data/ComfyUI/comfy/samplers.py", line 292, in sampling_function
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
File "/data/ComfyUI/comfy/samplers.py", line 266, in calc_cond_uncond_batch
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
File "/data/ComfyUI/comfy/model_base.py", line 100, in apply_model
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 414, in _fn
return fn(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 986, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 827, in _convert_frame
result = inner_convert(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
return _compile(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_utils_internal.py", line 70, in wrapper_function
return function(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 543, in compile_inner
out_code = transform_code_object(code, transform)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
transformations(instructions, code_options)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 172, in _fn
return fn(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 490, in transform
tracer.run()
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2234, in run
super().run()
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 884, in run
while self.step():
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 799, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2423, in RETURN_VALUE
self._return(inst)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2408, in _return
self.output.compile_subgraph(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1108, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1300, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/dynamo/output_graph.py", line 1372, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/init.py", line 1747, in call
return compile_fx(model
, inputs
, config_patches=self.config)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1478, in compile_fx
return aot_autograd(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 65, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 956, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 273, in time_wrapper
r = func(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 683, in create_aot_dispatcher_function
compiled_fn = compiler_fn(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 567, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 769, in aot_wrapper_synthetic_base
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 141, in aot_dispatch_base
fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc]
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 134, in aot_dispatch_base_graph
fw_module = _create_graph(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 43, in _create_graph
fx_g = make_fx(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 1271, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 548, in _fn
return fn(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 653, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 548, in _fn
return fn(*args, **kwargs)
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 820, in trace
(self.create_arg(fn(*args)),),
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 671, in wrapped
out = f(*tensors)
File "", line 1, in
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 506, in _functionalized_f_helper
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 390, in init
self.prev_version = tensor._version
File "/data/miniconda3/envs/env-novelai/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 716, in torch_function
return func(*args, **kwargs)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Inference tensors do not track version counter.

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

Minified repro

minifier_launcher.txt

Versions

Collecting environment information...
PyTorch version: 2.4.0.dev20240510+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Tencent tlinux 2.2 (Final) (x86_64)
GCC version: (GCC) 11.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.17

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.4.241-1-tlinux4-0017.7-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L40
Nvidia driver version: 525.125.06
cuDNN version: Probably one of the following:
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Thread(s) per core: 2
Core(s) per socket: 96
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 25
Model: 17
Model name: AMD EPYC 9K84 96-Core Processor
Stepping: 0
CPU MHz: 2600.012
BogoMIPS: 5200.02
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 32768K
NUMA node0 CPU(s): 0-191
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid amd_dcm tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core invpcid_single ibpb vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 avx512_bf16 clzero xsaveerptr wbnoinvd arat avx512vbmi umip avx512_vbmi2 vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm

Versions of relevant libraries:
[pip3] lion-pytorch==0.1.4
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.2
[pip3] onnx==1.14.1
[pip3] onnx-graphsurgeon==0.4.0
[pip3] onnxruntime==1.16.3
[pip3] onnxruntime-gpu==1.16.1
[pip3] open-clip-torch==2.20.0
[pip3] pytorch-lightning==1.9.4
[pip3] pytorch_optimizer==2.12.0
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.4.0.dev20240510+cu121
[pip3] torchaudio==2.3.0
[pip3] torchdiffeq==0.2.3
[pip3] torchmetrics==1.3.2
[pip3] torchsde==0.2.6
[pip3] torchvision==0.19.0.dev20240510+cu121
[pip3] triton==2.3.0
[conda] libopenvino-pytorch-frontend 2023.0.2 h59595ed_0 conda-forge
[conda] lion-pytorch 0.1.4 pypi_0 pypi
[conda] numpy 1.26.2 pypi_0 pypi
[conda] open-clip-torch 2.20.0 pypi_0 pypi
[conda] pytorch-lightning 1.9.4 pypi_0 pypi
[conda] pytorch-optimizer 2.12.0 pypi_0 pypi
[conda] pytorch-triton 3.0.0+45fff310c8 pypi_0 pypi
[conda] torch 2.4.0.dev20240510+cu121 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi
[conda] torchdiffeq 0.2.3 pypi_0 pypi
[conda] torchmetrics 1.3.2 pypi_0 pypi
[conda] torchsde 0.2.6 pypi_0 pypi
[conda] torchvision 0.19.0.dev20240510+cu121 pypi_0 pypi
[conda] triton 2.3.0 pypi_0 pypi

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @zou3519 @oulgen @aakhundov

Chillee (Contributor) commented May 11, 2024

cc: @oulgen

oulgen (Contributor) commented May 11, 2024

@bdhirsh I think you had a workaround/solution for this?

bdhirsh (Contributor) commented May 13, 2024

@arthursunbao I think I fixed this a few weeks ago - I just tried your example locally and it runs successfully for me. Can you try it on a nightly?
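
(For anyone reproducing this: a recent nightly matching the CUDA 12.1 environment above can be installed with

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121

and torchvision/torchaudio may need the same treatment to keep versions compatible.)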

bdhirsh added the module: aotdispatch, module: pt2-dispatcher, and module: user triton labels on May 13, 2024
arthursunbao (Author) commented:

> @arthursunbao I think I fixed this a few weeks ago - I just tried your example locally and it runs successfully for me. Can you try it on a nightly?

I tried it on a nightly; the version is PyTorch 2.4.0.dev20240510+cu121.

Could you give a hint about what causes this problem and how to fix it manually? This code runs inside the ComfyUI framework, and there is a lot of code wrapping this Triton kernel and the torch.compile call.
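
One possible manual workaround, sketched below, is to convert inference tensors back into regular tensors before they reach the compiled module. This is an assumption rather than a confirmed fix: the error indicates the inputs are inference tensors (which only exist if something upstream ran under torch.inference_mode()), and a clone taken with inference mode disabled does track a version counter. compiled_model is a hypothetical stand-in for whatever module torch.compile wraps in the ComfyUI stack:

import torch

def to_regular(t):
    # Cloning with inference mode explicitly disabled yields a normal
    # tensor that tracks a version counter; no-op for ordinary tensors.
    if isinstance(t, torch.Tensor) and t.is_inference():
        with torch.inference_mode(False):
            return t.clone()
    return t

x = to_regular(x)        # hypothetical call site, before the compiled model
out = compiled_model(x)

Whether this composes cleanly with the rest of the ComfyUI sampling code is untested.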

bdhirsh self-assigned this on May 14, 2024
@bdhirsh
Copy link
Contributor

bdhirsh commented May 14, 2024

Ah @arthursunbao, this PR should fix the error; I just need to get around to fixing CI and landing it: #124489

bdhirsh added the triaged label on May 14, 2024