Phi-3 error on Kaggle: RuntimeError: Internal Triton PTX codegen error #403

Closed
thewebscraping opened this issue Apr 30, 2024 · 5 comments

Comments

@thewebscraping

Hi,
Could you add support for the notebook on Kaggle?
It does not work on the P100.
On the T4 GPU it fails with this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 trainer_stats = trainer.train()

File /opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:361, in SFTTrainer.train(self, *args, **kwargs)
    358 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
    359     self.model = self._trl_activate_neftune(self.model)
--> 361 output = super().train(*args, **kwargs)
    363 # After training we make sure to retrieve back the original forward pass method
    364 # for the embedding layer by removing the forward post hook.
    365 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:1780, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1778         hf_hub_utils.enable_progress_bars()
   1779 else:
-> 1780     return inner_training_loop(
   1781         args=args,
   1782         resume_from_checkpoint=resume_from_checkpoint,
   1783         trial=trial,
   1784         ignore_keys_for_eval=ignore_keys_for_eval,
   1785     )

File <string>:355, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3036, in Trainer.training_step(self, model, inputs)
   3033     return loss_mb.reduce_mean().detach().to(self.args.device)
   3035 with self.compute_loss_context_manager():
-> 3036     loss = self.compute_loss(model, inputs)
   3038 if self.args.n_gpu > 1:
   3039     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3059, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3057 else:
   3058     labels = None
-> 3059 outputs = model(**inputs)
   3060 # Save past state if it exists
   3061 # TODO: this needs to be fixed and made cleaner later.
   3062 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:825, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    824 def forward(*args, **kwargs):
--> 825     return model_forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:813, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    812 def __call__(self, *args, **kwargs):
--> 813     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:825, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    824 def forward(*args, **kwargs):
--> 825     return model_forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:813, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    812 def __call__(self, *args, **kwargs):
--> 813     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py:882, in PeftModelForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
    869 def PeftModelForCausalLM_fast_forward(
    870     self,
    871     input_ids=None,
   (...)
    880     **kwargs,
    881 ):
--> 882     return self.base_model(
    883         input_ids=input_ids,
    884         causal_mask=causal_mask,
    885         attention_mask=attention_mask,
    886         inputs_embeds=inputs_embeds,
    887         labels=labels,
    888         output_attentions=output_attentions,
    889         output_hidden_states=output_hidden_states,
    890         return_dict=return_dict,
    891         **kwargs,
    892     )

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/mistral.py:213, in MistralForCausalLM_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    205     outputs = LlamaModel_fast_forward_inference(
    206         self,
    207         input_ids,
   (...)
    210         attention_mask = attention_mask,
    211     )
    212 else:
--> 213     outputs = self.model(
    214         input_ids=input_ids,
    215         causal_mask=causal_mask,
    216         attention_mask=attention_mask,
    217         position_ids=position_ids,
    218         past_key_values=past_key_values,
    219         inputs_embeds=inputs_embeds,
    220         use_cache=use_cache,
    221         output_attentions=output_attentions,
    222         output_hidden_states=output_hidden_states,
    223         return_dict=return_dict,
    224     )
    225 pass
    227 hidden_states = outputs[0]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py:650, in LlamaModel_fast_forward(self, input_ids, causal_mask, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, *args, **kwargs)
    647 past_key_value = past_key_values[idx] if past_key_values is not None else None
    649 if offloaded_gradient_checkpointing:
--> 650     hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply(
    651         decoder_layer,
    652         hidden_states,
    653         causal_mask,
    654         attention_mask,
    655         position_ids,
    656         past_key_values,
    657         output_attentions,
    658         use_cache,
    659     )
    661 elif gradient_checkpointing:
    662     def create_custom_forward(module):

File /opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         "https://pytorch.org/docs/master/notes/extending.func.html"
    606     )

File /opt/conda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py:115, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
    113 if cast_inputs is None:
    114     args[0]._fwd_used_autocast = torch.is_autocast_enabled()
--> 115     return fwd(*args, **kwargs)
    116 else:
    117     autocast_context = torch.is_autocast_enabled()

File /opt/conda/lib/python3.10/site-packages/unsloth/models/_utils.py:333, in Unsloth_Offloaded_Gradient_Checkpointer.forward(ctx, forward_function, hidden_states, *args)
    331 saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
    332 with torch.no_grad():
--> 333     (output,) = forward_function(hidden_states, *args)
    334 ctx.save_for_backward(saved_hidden_states)
    335 ctx.forward_function = forward_function

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/unsloth/models/llama.py:432, in LlamaDecoderLayer_fast_forward(self, hidden_states, causal_mask, attention_mask, position_ids, past_key_value, output_attentions, use_cache, padding_mask, *args, **kwargs)
    430 else:
    431     residual = hidden_states
--> 432     hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states)
    433     hidden_states, self_attn_weights, present_key_value = self.self_attn(
    434         hidden_states=hidden_states,
    435         causal_mask=causal_mask,
   (...)
    441         padding_mask=padding_mask,
    442     )
    443     hidden_states = residual + hidden_states

File /opt/conda/lib/python3.10/site-packages/unsloth/kernels/rms_layernorm.py:190, in fast_rms_layernorm(layernorm, X, gemma)
    188 W   = layernorm.weight
    189 eps = layernorm.variance_epsilon
--> 190 out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
    191 return out

File /opt/conda/lib/python3.10/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         "https://pytorch.org/docs/master/notes/extending.func.html"
    606     )

File /opt/conda/lib/python3.10/site-packages/unsloth/kernels/rms_layernorm.py:144, in Fast_RMS_Layernorm.forward(ctx, X, W, eps, gemma)
    141 r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
    143 fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
--> 144 fx[(n_rows,)](
    145     Y, Y.stride(0),
    146     X, X.stride(0),
    147     W, W.stride(0),
    148     r, r.stride(0),
    149     n_cols, eps,
    150     BLOCK_SIZE = BLOCK_SIZE,
    151     num_warps  = num_warps,
    152 )
    153 ctx.eps = eps
    154 ctx.BLOCK_SIZE = BLOCK_SIZE

File /opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py:167, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    161 def __getitem__(self, grid) -> T:
    162     """
    163     A JIT function is launched with: fn[grid](*args, **kwargs).
    164     Hence JITFunction.__getitem__ returns a callable proxy that
    165     memorizes the grid.
    166     """
--> 167     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/triton/runtime/jit.py:416, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    414     # compile the kernel
    415     src = ASTSource(self, signature, constants, configs[0])
--> 416     self.cache[device][key] = compile(
    417         src,
    418         target=target,
    419         options=options.__dict__,
    420     )
    422 kernel = self.cache[device][key]
    423 if not warmup:

File /opt/conda/lib/python3.10/site-packages/triton/compiler/compiler.py:193, in compile(src, target, options)
    191 module = src.make_ir(options)
    192 for ext, compile_ir in list(stages.items())[first_stage:]:
--> 193     next_module = compile_ir(module, metadata)
    194     metadata_group[f"{src.name}.{ext}"] = fn_cache_manager.put(next_module, f"{src.name}.{ext}")
    195     module = next_module

File /opt/conda/lib/python3.10/site-packages/triton/compiler/backends/cuda.py:201, in CUDABackend.add_stages.<locals>.<lambda>(src, metadata)
    199 stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
    200 stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
--> 201 stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)

File /opt/conda/lib/python3.10/site-packages/triton/compiler/backends/cuda.py:194, in CUDABackend.make_cubin(src, metadata, opt, capability)
    192 metadata["name"] = get_kernel_name(src, pattern='// .globl')
    193 ptxas, _ = path_to_ptxas()
--> 194 return compile_ptx_to_cubin(src, ptxas, capability, opt.enable_fp_fusion)

RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/compile-ptx-src-a191d3, line 100; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 100; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 102; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 102; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 104; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 104; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 106; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 106; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 108; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 108; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 110; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 110; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 112; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 112; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 114; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 114; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 116; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 116; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 118; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 118; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 120; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 120; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 122; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 122; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 124; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 124; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 126; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 126; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 128; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 128; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 130; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 130; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 316; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 316; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 318; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 318; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 320; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 320; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 322; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 322; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 324; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 324; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 326; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 326; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 328; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 328; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 330; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 330; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 332; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 332; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 334; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 334; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 336; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 336; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 338; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 338; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 340; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 340; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 342; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 342; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 344; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 344; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 346; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 346; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 350; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 350; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 354; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 354; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 358; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 358; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 362; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 362; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 366; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 366; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 370; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 370; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 374; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 374; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 378; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 378; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 382; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 382; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 386; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 386; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 390; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 390; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 394; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 394; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 398; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 398; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 402; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 402; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 406; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 406; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 410; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 410; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors
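
The ptxas lines above repeat the same few messages for many PTX line numbers. A minimal sketch for collapsing such a log into its distinct messages follows; `summarize_ptxas_errors` is a hypothetical helper name, and the regular expression only assumes the line layout seen in this traceback.

```python
import re
from collections import Counter

# Matches lines shaped like the ones in the traceback, e.g.
# ptxas /tmp/compile-ptx-src-a191d3, line 100; error   : Feature '.bf16' requires ...
PTXAS_ERROR = re.compile(r"^ptxas\s+\S+,\s+line\s+(\d+);\s+error\s*:\s*(.+)$")


def summarize_ptxas_errors(log_text):
    """Count each distinct ptxas error message, ignoring file and line number."""
    counts = Counter()
    for line in log_text.splitlines():
        match = PTXAS_ERROR.match(line.strip())
        if match:
            counts[match.group(2).strip()] += 1
    return counts


# A few lines copied from the traceback above.
sample = """\
ptxas /tmp/compile-ptx-src-a191d3, line 100; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 100; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 316; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a191d3, line 350; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors
"""

for message, count in summarize_ptxas_errors(sample).items():
    print(f"{count:3d}x  {message}")
```

Every distinct message names a bf16 feature and the same `sm_80` requirement, so the whole log points at one cause.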
@danielhanchen
Contributor

@thewebscraping Oh yep, I updated all Kaggle notebooks! Change the install instructions to:

%%capture
!pip install -U "xformers<0.0.26" --index-url https://download.pytorch.org/whl/cu121
!pip install "unsloth[kaggle-new] @ git+https://github.com/unslothai/unsloth.git"

# Temporary fix for https://github.com/huggingface/datasets/issues/6753
!pip install datasets==2.16.0 fsspec==2023.10.0 gcsfs==2023.10.0

import os
os.environ["WANDB_DISABLED"] = "true"

@difonjohaiv

@danielhanchen the first solution works for me!!! Thanks bro

@danielhanchen
Contributor

Great!

@TheGhoul21

What is the actual problem? I'm getting the same error on Hugging Face Spaces, but I can't use the same solution there.

@thewebscraping
Author

Thanks, it's working
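
On the question of what the actual problem is: the ptxas lines in the traceback say that the `.bf16` instructions require `.target sm_80 or higher`. The T4 has compute capability 7.5 and the P100 has 6.0, so a kernel compiled with bfloat16 inputs cannot be assembled on either card. The thread does not say why the updated install instructions avoid this. The sketch below only shows how to check what the GPU supports before training. It assumes PyTorch is installed, and `supports_bf16` and `pick_dtype_name` are hypothetical helper names, not part of unsloth.

```python
# Sketch: decide whether bf16 kernels can be compiled for the current GPU.
# The sm_80 threshold is taken from the ptxas messages in the traceback.

BF16_MIN_CAPABILITY = (8, 0)  # "requires .target sm_80 or higher"


def supports_bf16(capability):
    """Return True if a (major, minor) compute capability is sm_80 or newer."""
    return tuple(capability) >= BF16_MIN_CAPABILITY


def pick_dtype_name(capability):
    """Pick a 16-bit dtype name: bfloat16 on sm_80+, float16 otherwise."""
    return "bfloat16" if supports_bf16(capability) else "float16"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None  # allow the pure helpers to be used without PyTorch

    if torch is not None and torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()  # (major, minor) of device 0
        print(torch.cuda.get_device_name(), cap, pick_dtype_name(cap))
    else:
        print("No CUDA device visible; a T4 (7, 5) would use:", pick_dtype_name((7, 5)))
```

On a T4 or P100 this selects float16, so the model and trainer would need to be configured for float16 rather than bfloat16 on those cards.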
