
fix constant tagging in mps backend #3503

Closed
wants to merge 1 commit

Commits on May 7, 2024

  1. fix constant tagging in mps backend (pytorch#3503)

    Summary:
    
    Tested with pytorch#3399; this command passes:
    ```
    python -m examples.models.llama2.export_llama -kv --mps
    ```
    Without this diff, it errors out:
    ```
    in _verify_exported_program_signature
        raise SpecViolationError(
    torch._export.verifier.SpecViolationError: Buffer output getitem_1 does not point to a buffer that exists.
    Dict of buffers that are mutated, in order: {'getitem_1': 'layers_0_attention_SDPA_kv_cache_k_cache', 'getitem': 'layers_0_attention_SDPA_kv_cache_v_cache', 'getitem_3': 'layers_1_attention_SDPA_kv_cache_k_cache', 'getitem_2': 'layers_1_attention_SDPA_kv_cache_v_cache', 'getitem_5': 'layers_2_attention_SDPA_kv_cache_k_cache', 'getitem_4': 'layers_2_attention_SDPA_kv_cache_v_cache', 'getitem_7': 'layers_3_attention_SDPA_kv_cache_k_cache', 'getitem_6': 'layers_3_attention_SDPA_kv_cache_v_cache', 'getitem_9': 'layers_4_attention_SDPA_kv_cache_k_cache', 'getitem_8': 'layers_4_attention_SDPA_kv_cache_v_cache'}
    Buffer nodes available: []
    ```
    The root cause is that `is_parameter` tags all data, including mutable buffers.
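
    The distinction the fix draws can be illustrated with a minimal, self-contained sketch (the function and variable names below are illustrative, not the actual ExecuTorch API): a value should be tagged as backend constant data only if it is a parameter or a buffer that is never mutated. Tagging every buffer pulls mutable KV-cache buffers out of the exported program's state, which is what triggers the `SpecViolationError` above.
    ```python
    def should_tag_as_constant(name, inputs_to_parameters, inputs_to_buffers, buffers_to_mutate):
        """Return True only for parameters and non-mutated buffers (illustrative sketch)."""
        if name in inputs_to_parameters:
            return True
        if name in inputs_to_buffers:
            buffer_name = inputs_to_buffers[name]
            # Mutable buffers (e.g. KV caches) must remain program state,
            # so they are excluded from constant tagging.
            return buffer_name not in set(buffers_to_mutate.values())
        return False

    # Hypothetical example mirroring the KV-cache buffer from the error message:
    params = {"p_weight": "linear.weight"}
    buffers = {"b_k_cache": "layers_0_attention_SDPA_kv_cache_k_cache"}
    mutated = {"getitem_1": "layers_0_attention_SDPA_kv_cache_k_cache"}

    print(should_tag_as_constant("p_weight", params, buffers, mutated))   # parameter -> True
    print(should_tag_as_constant("b_k_cache", params, buffers, mutated))  # mutated buffer -> False
    ```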
    
    Reviewed By: larryliu0820
    
    Differential Revision: D56941763
    cccclai authored and facebook-github-bot committed May 7, 2024
    25eae44