
PaliGemma 4bit Quantization broken and Inference issues. #783

Closed
Blaizzy opened this issue May 16, 2024 · 27 comments
@Blaizzy
Contributor

Blaizzy commented May 16, 2024

@awni I have PaliGemma working on MLX. In most cases works great.

But there are 4 issues I don't see in the transformers implementation:

  • 4-bit quantization just doesn't work, while other models I have with the same architecture quantize fine.
  • The model refuses to answer most natural language questions: “Sorry, as a base VLM...”
  • I can't seem to get it to output bounding boxes and segmentation masks.

The last two work fine on transformers.

The architecture is correct and the model does work, but it seems like MLX is either changing the precision or doing something else that keeps the model from behaving 100% normally.

Resources:

@awni
Member

awni commented May 16, 2024

Usually the best way to debug these sorts of issues is to have a reference implementation (like the HF one) and the MLX implementation side-by-side.

Then write some wrapper to step through both implementations layer by layer and compare the activations. Keep zooming in until you find the first place where the activations are substantially different. That should pinpoint the part of the model file that is mismatched.

If you get to a point where the ops seem identical but are producing different results given the same inputs, then that is suggestive of a deeper issue in MLX. But in most cases the problem is some mismatch in config settings or operations in the high-level implementation.
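For example, here is a minimal sketch of that kind of side-by-side comparison (hf_model, hf_inputs, and the mlx_activations dict are placeholders for however your setup exposes the two models, not real mlx-vlm APIs):

import numpy as np
import torch

# Capture per-layer outputs from the HF reference model with forward hooks.
captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        captured[name] = out.detach().cpu().numpy()
    return hook

for i, layer in enumerate(hf_model.vision_tower.vision_model.encoder.layers):
    layer.register_forward_hook(make_hook(f"vision.encoder.{i}"))

with torch.no_grad():
    hf_model(**hf_inputs)

# Run the MLX model on the same inputs, collect the matching layer outputs into
# mlx_activations (a dict of mx.array), then compare layer by layer.
for name, ref in captured.items():
    diff = np.abs(ref - np.array(mlx_activations[name]))
    print(name, diff.max(), diff.mean())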

@Blaizzy
Contributor Author

Blaizzy commented May 16, 2024

Thanks!

I will do that and let you know.

@lucasb-eyer

For box/seg: are you sure you use our extended tokenizer with a total vocab of 256000 + 1024 + 128?

For the “Sorry...” responses: they come from the safety tuning we had to do on the mix model. I don't like it, and we tried hard not to over-trigger, but it's not perfect. I think if you have some discrepancies in the model, this answer might appear more often than it should. I recommend doing what @awni said. Ideally compare against our reference implementation in JAX, but I believe the HF implementation was verified against it, so it should be good too.

About 4-bit, I have no idea.
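For the extended-tokenizer point above, a quick sanity check is to look at what the HF processor's tokenizer actually contains (a sketch, assuming the extra tokens use PaliGemma's <locXXXX>/<segXXX> naming):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-224")
tok = processor.tokenizer
vocab = tok.get_vocab()
print(len(tok))                  # should cover 256000 + 1024 + 128 (plus any padded slots)
print("<loc0000>" in vocab)      # detection / bounding-box tokens
print("<seg000>" in vocab)       # segmentation tokens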

@Blaizzy
Contributor Author

Blaizzy commented May 17, 2024

Thanks a lot @lucasb-eyer!

How do I get the extended tokenizer? I don't see anything standing out in the Hugging Face implementation.

After investigating the model as @awni suggested, I found a big difference in the vision model encoder activations between the two models: it was over 100K. This was caused by using nn.GELU(approx="fast") on MLX. When I changed to approx="precise" or implemented the transformers FastGELUActivation, the difference came down to ~3, with precise having slightly lower scores overall. This is strange, because the original implementation uses approx="fast".

However, despite the lower activation difference between the two, the model still refuses to answer most natural language questions.

@awni
Member

awni commented May 17, 2024

when I changed to approx="precise" or implemented the transformers FastGELUActivation the difference came down to ~3

What does that number mean? ~3 would be a large value for the max abs-diff between two activations. Or is it more like a sum over all the diffs?

Note there are three GELU options in MLX:

  1. none: the full GELU using erf
  2. precise: a slightly faster approximation
  3. fast: an also slightly faster but less accurate approximation

none and precise correspond to the options in PyTorch. I'm not entirely sure what is meant by approx="fast" in your case, but it might not match MLX's fast, so that is good to double-check.
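A quick way to see how far apart the three MLX variants are is to compare them directly on random inputs (a sketch; the shape is arbitrary):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((256, 1152))

gelu_none = nn.GELU(approx="none")
gelu_precise = nn.GELU(approx="precise")
gelu_fast = nn.GELU(approx="fast")

# Max absolute deviation of each approximation from the exact erf-based GELU.
print(mx.abs(gelu_precise(x) - gelu_none(x)).max())
print(mx.abs(gelu_fast(x) - gelu_none(x)).max())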

@Blaizzy
Contributor Author

Blaizzy commented May 17, 2024

It's the sum over all the diffs.
np.abs(ref_model_layer - target_model_layer).sum()

Yes, I did. The original implementation of SigLIP (PaliGemma's vision model) uses fast, and this setup works fine on a nearly identical model we currently support called NanoLlava, but for PaliGemma the fast approximation creates a big divergence in the activations.

By approx='fast', I mean the MLX configuration:
nn.GELU(approx='fast')

@Blaizzy
Contributor Author

Blaizzy commented May 17, 2024

I even copied the transformers GELU activation in NumPy to compare, but I get results similar to the precise approximation in MLX.

@awni
Member

awni commented May 17, 2024

I'm not quite following the GELU story. But I think the safest call is to find the GELU implementation used in the reference implementation (presumably the JAX code) and use that. You can check the MLX GELU implementations here to see if one matches.

@Blaizzy
Contributor Author

Blaizzy commented May 17, 2024

I did exactly that.

Here are the implementations I tried, both of which are identical to the ones used in transformers and JAX:

import numpy as np
import mlx.core as mx
import mlx.nn as nn

# Variant 1: matches transformers' FastGELUActivation constant-for-constant
class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def __call__(self, input: mx.array) -> mx.array:
        return 0.5 * input * (1.0 + mx.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


# Variant 2: the same tanh approximation written with np.sqrt(2 / np.pi)
class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def __call__(self, input: mx.array) -> mx.array:
        return 0.5 * input * (1.0 + mx.tanh(np.sqrt(2 / np.pi) * (input + 0.044715 * (input ** 3))))

Transformers: https://github.com/huggingface/transformers/blob/3d7d3a87a0bf4d0bb9346beb9419b1d76b5b988f/src/transformers/activations.py#L81
JAX: https://github.com/google/jax/blob/e93f36aa7c5cf329b517cd652777eb14ca35e8c0/jax/_src/nn/functions.py#L424
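Those two variants should be numerically indistinguishable, since 0.7978845608 is just sqrt(2/pi) and the tanh arguments are algebraically the same; a quick check (a sketch):

import numpy as np
import mlx.core as mx

x = mx.random.normal((1024,))
a = 0.5 * x * (1.0 + mx.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
b = 0.5 * x * (1.0 + mx.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3))))
print(mx.abs(a - b).max())  # should be at the level of float rounding error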

@Blaizzy
Contributor Author

Blaizzy commented May 17, 2024

Yet the sum of abs-diff is close, around 2.39 to 3.77 on the vision path, and the model still refuses a lot.

From the start until the first MLP, everything is close to 0.
The only part where they start to differ significantly is the MLP, at around 0.15 on the first vision encoder layer.

The layer before that is at 0.08.

Here is my implementation of the MLP:

class MLP(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.activation_fn = FastGELUActivation()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.fc2(x)
        return x

Here is the transformers implementation:

class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act] # uses same FastGELUActivation
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
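One way to isolate that MLP difference is to load the same weights into both versions and feed them the same input (a sketch; hf_mlp and mlx_mlp stand for already-constructed SiglipMLP and MLP instances, and the shapes below are just the SigLIP dimensions from the config):

import numpy as np
import torch
import mlx.core as mx

# Copy the HF weights into the MLX module; both frameworks store Linear weights
# as (out_features, in_features), so no transpose is needed.
mlx_mlp.fc1.weight = mx.array(hf_mlp.fc1.weight.detach().numpy())
mlx_mlp.fc1.bias = mx.array(hf_mlp.fc1.bias.detach().numpy())
mlx_mlp.fc2.weight = mx.array(hf_mlp.fc2.weight.detach().numpy())
mlx_mlp.fc2.bias = mx.array(hf_mlp.fc2.bias.detach().numpy())

x_np = np.random.randn(1, 256, 1152).astype(np.float32)
with torch.no_grad():
    out_hf = hf_mlp(torch.from_numpy(x_np)).numpy()
out_mx = np.array(mlx_mlp(mx.array(x_np)))

print(np.abs(out_hf - out_mx).max())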

@Blaizzy
Contributor Author

Blaizzy commented May 17, 2024

I'm not sure what I am missing here.

Let me go for a walk 🚶🏾‍♂️...

@awni
Member

awni commented May 18, 2024

Any luck getting to the bottom of this?

FWIW it's expected that there are some numerical differences between the MLX / PyTorch versions. Rather than looking at the sum (which is hard to reason about since it depends on the number of weights), maybe check something like a relative difference, e.g. ((x - y).abs() / x.abs()).max() or some variation of that. It should be pretty small, especially for float32.

@Blaizzy
Contributor Author

Blaizzy commented May 18, 2024

Not yet.

Yesterday, I tried using the Hugging Face VLM class in my implementation, but that didn't change the results.

Let me check the relative difference and I'll let you know.

@Blaizzy
Contributor Author

Blaizzy commented May 18, 2024

@awni here are the results:

Language Model (Embedding output)

Relative Distance (using norms): 0.0
Max Absolute Relative Difference: 0.0
Are Matrices Close (np.allclose): True

Vision Model (Patch_embedding output):

Relative Distance (using norms): 5.3614764e-07
Max Absolute Relative Difference: 0.0685524
Are Matrices Close (np.allclose): False

Vision Model (Embeddings Layer output):

Relative Distance (using norms): 2.7392096e-07
Max Absolute Relative Difference: 0.05940594
Are Matrices Close (np.allclose): False

Vision Model (Encoder Layer 1 output):

Layer 1
Relative Distance (using norms): 1.8757571e-06
Max Absolute Relative Difference: 6.3747888
Are Matrices Close (np.allclose): False

Layer 2
Relative Distance (using norms): 2.0738366e-06
Max Absolute Relative Difference: 3.7936087
Are Matrices Close (np.allclose): False

Layer 3
Relative Distance (using norms): 1.9333681e-06
Max Absolute Relative Difference: 3.062366
Are Matrices Close (np.allclose): False

Vision model (Post layerNorm output):

Relative Distance (using norms): 2.5038335e-05
Max Absolute Relative Difference: 1.6712433
Are Matrices Close (np.allclose): False

Multi-modal projector (Linear layer output):

Relative Distance (using norms): 1.9119347e-05
Max Absolute Relative Difference: 3.1223333
Are Matrices Close (np.allclose): False

For context here is the model architecture:

PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear(in_features=4304, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=1152, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(257216, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=257216, bias=False)
  )
)

@awni
Member

awni commented May 18, 2024

What are the formulas for these?

Relative Distance (using norms): 2.5038335e-05
Max Absolute Relative Difference: 1.6712433

Yesterday, I tried using the huggingface VLM class in my implementation but that didn't change the results.

Does that not suggest the problem is outside the model itself?

@Blaizzy
Contributor Author

Blaizzy commented May 18, 2024

What are the formulas for these?

import numpy as np

def relative_diff(x1, x2):
    assert x1.shape == x2.shape, "Matrices must have the same dimensions"

    if x1.ndim > 2 or x2.ndim > 2:
        x1 = x1.reshape(-1)
        x2 = x2.reshape(-1)

    print("Relative Distance (using norms):", (np.linalg.norm(x1 - x2) / np.linalg.norm(x1)).max())
    print("Max Absolute Relative Difference:", (abs(x1 - x2) / abs(x1)).max())
    print("Are Matrices Close (np.allclose):", np.allclose(x1, x2))

Does that not suggest the problem is outside the model itself?

Yes, but where exactly? The language model is Gemma-2B, and I double- and triple-checked it before and it works fine.

Language Model (only)

 python -m mlx_vlm.generate --model google/paligemma-3b-pt-224 \ 
--prompt "Hi"

Prompt: Hi

I'm a very good friend of yours
==========
Prompt: 6.013 tokens-per-sec
Generation: 26.163 tokens-per-sec

@Blaizzy
Contributor Author

Blaizzy commented May 18, 2024

OK, after some deeper debugging, I think the issue is in the multimodal feature merging and/or masking.

I'll update you once I have it working.

@Blaizzy
Contributor Author

Blaizzy commented May 22, 2024

@awni @lucasb-eyer

I did everything by the book, but the model still doesn't behave properly.

It seems to behave better only when using the multimodal features from the transformers model. But that doesn't make sense, because I have a 1:1 copy of that in MLX.

Could you please give this a look:
Blaizzy/mlx-vlm#24

@Blaizzy
Contributor Author

Blaizzy commented May 22, 2024

@awni this weird behaviour also happened with Idefics2 in the past.

The only thing these have in common is that they are using F32 precision.

@Blaizzy
Contributor Author

Blaizzy commented May 23, 2024

@awni any thoughts?

@awni
Member

awni commented May 23, 2024

I couldn't say what the issue is. I'll try to take a deeper look in the next few days.

@Blaizzy
Contributor Author

Blaizzy commented May 23, 2024

Thanks! Looking forward to it :)

@JosefAlbers
Contributor

I just started digging in, but I think the problem may actually be in mlx_vlm's LanguageModel implementation for PaliGemma. To demonstrate this, I replaced mlx_vlm PaliGemma's LanguageModel with the corresponding implementation from Hugging Face Transformers in the code below.

from huggingface_hub import login
import os
login(token=os.getenv('HF_TOKEN'))
model_id = "google/paligemma-3b-mix-224"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
prompt = 'Caption: '

import mlx.core as mx
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import glob
from huggingface_hub import snapshot_download
import json
from PIL import Image
import requests
import torch
import importlib
import numpy as np

def sanitize(weights):
    sanitized_weights = {}
    for k, v in weights.items():
        if "patch_embedding.weight" in k:
            sanitized_weights[k] = v.transpose(0, 2, 3, 1)
        else:
            sanitized_weights[k] = v
    return sanitized_weights

def load_model(model_id):
    model_path = snapshot_download(
        repo_id=model_id,
        revision=None,
        allow_patterns=[
            "*.json",
            "*.safetensors",
            "*.py",
            "tokenizer.model",
            "*.tiktoken",
            "*.txt",
        ],
    )
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)

    weights = {}
    weight_files = glob.glob(str(f"{model_path}/*.safetensors"))
    for wf in weight_files:
        weights.update(mx.load(wf))
    weights = sanitize(weights)

    model_class = importlib.import_module("mlx_vlm.models.paligemma")
    model_config = model_class.ModelConfig.from_dict(config)
    model_config.vision_config = model_class.VisionConfig.from_dict(config["vision_config"])
    model_config.text_config = model_class.TextConfig.from_dict(config["text_config"])
    model = model_class.Model(model_config)

    model.load_weights(list(weights.items()))
    mx.eval(model.parameters())
    model.eval()
    return model

model_mx = load_model(model_id)
processor = AutoProcessor.from_pretrained(model_id)

prompt_tokens = mx.array(processor.tokenizer.encode(prompt))

inputs = processor(prompt, Image.open(requests.get(img_url, stream=True).raw), return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
mask = mx.array(inputs["attention_mask"])

inputs_embeds = model_mx.language_model.model.embed_tokens(input_ids)
hidden_state, _, _ = model_mx.vision_tower(
    pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype),
    output_hidden_states=True,
)
image_features = hidden_state[None, :].astype(pixel_values.dtype)
image_features = model_mx.multi_modal_projector(image_features)

input_embeddings, final_attention_mask_4d = (
    model_mx._prepare_inputs_for_multimodal(
        image_features, inputs_embeds, input_ids, mask
    )
)

# # `<<< mx language
# logits, cache = model_mx.language_model(
#     inputs=input_ids,
#     cache=None,
#     inputs_embeds=input_embeddings,
#     mask=final_attention_mask_4d,
# )
# # `>>> mx language

# `<<< hf language
model_hf = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
final_attention_mask_4d = torch.from_numpy(np.array(final_attention_mask_4d, dtype=np.float32))
input_embeddings = torch.from_numpy(np.array(input_embeddings))
outputs = model_hf.language_model(
            attention_mask=final_attention_mask_4d,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=input_embeddings,
            use_cache=False,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cache_position=None,
        )
logits = outputs.logits
logits = mx.array(logits.detach().numpy())
# `>>> hf language

logits = logits[:, -1, :]
token = mx.argmax(logits, axis=-1)
print(token, processor.tokenizer.decode(token.tolist()))

The modification immediately improves the output text quality. The output from the commented-out part of the above code (mlx_vlm's version of the LanguageModel) is array([12156], dtype=uint32) Sorry (the full output from mlx_vlm is "Sorry, as a base VLM I am not trained to answer this question."), while the output from the Hugging Face one is array([886], dtype=uint32) In (the full output from Hugging Face's PaliGemma is "In this image we can see a car on the road. In the background there is a wall, door, trees and sky.").

@Blaizzy
Contributor Author

Blaizzy commented May 24, 2024

Thanks @JosefAlbers!

Found the bug and fixed it :)

@Blaizzy
Contributor Author

Blaizzy commented May 24, 2024

@awni @lucasb-eyer it's fixed ✅

After my changes, I hadn't applied the Gemma embedding scaling to all inputs (text and multimodal); it was only scaling the text embeddings.

That's why the language model worked when I unit tested it on its own but failed with multimodal inputs.
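For anyone hitting the same thing, this is roughly the shape of the fix (a sketch, not the actual mlx-vlm code): Gemma scales its input embeddings by sqrt(hidden_size), and once the image features have been merged into the embedding sequence, that scaling has to cover the full merged embeddings rather than only the text-token embeddings.

import mlx.core as mx

def scale_merged_embeddings(merged_embeds: mx.array, hidden_size: int) -> mx.array:
    # Apply the Gemma embedding normalizer to every position in the merged
    # (text + image) sequence, not just the text tokens.
    return merged_embeds * (hidden_size ** 0.5)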

Blaizzy closed this as completed May 24, 2024
@Blaizzy
Contributor Author

Blaizzy commented May 24, 2024

@JosefAlbers could you share your X handle?

I want to tag you on the release :)

@JosefAlbers
Contributor

Great, I'm just glad I could help!
