PaliGemma 4bit Quantization broken and Inference issues. #783
Usually the best way to debug these sorts of issues is to have a reference implementation (like the HF one) and the MLX implementation side by side. Then write some wrapper to step through both implementations layer by layer and compare the activations. Keep zooming in until you find the first place where the activations are substantially different. That should pinpoint the part of the model file that is mismatched. If you get to a point where the ops seem identical but are producing different results given the same inputs, then that is suggestive of a deeper issue in MLX. But in most cases the problem is some mismatch in config settings or operations in the high-level implementation.
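(For illustration, a minimal sketch of that comparison loop; the capture of per-layer activations is left to the reader, and the helper below is hypothetical, not part of either codebase:)

```python
import numpy as np

def first_divergence(ref_acts, mlx_acts, tol=1e-2):
    """Given matched lists of per-layer activations (numpy arrays, in the
    same order for both implementations), return the index of the first
    layer where the two substantially disagree, plus the relative diff."""
    for i, (a, b) in enumerate(zip(ref_acts, mlx_acts)):
        rel = np.abs(a - b).max() / (np.abs(a).max() + 1e-12)
        if rel > tol:
            return i, rel
    return None, 0.0
```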
Thanks! I will do that and let you know.
For box/seg: are you sure you use our extended tokenizer with a total vocab of 256000 + 1024 + 128? For the « sorry… », it's from safety tuning we had to do on the mix model. I don't like it, and we tried hard to not over-trigger, but it's not perfect. I think if you have some discrepancies in the model, this answer might appear more often than it should. I recommend doing what @awni said. Ideally comparing with our reference implementation in JAX, but I believe the HF implementation was verified against it, so it should be good too. About 4-bit, I have no idea.
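(A quick way to sanity-check the tokenizer size against those numbers, 256000 + 1024 loc tokens + 128 seg tokens = 257152; this snippet is a suggestion, not from the thread:)

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-224")
# len() includes added special tokens; expect 256000 + 1024 + 128 = 257152
print(len(processor.tokenizer))
```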
Thanks a lot @lucasb-eyer! How do I get the extended tokenizer? I don't see anything standing out in the Hugging Face implementation. After investigating the model as @awni suggested, I found a big difference in vision model encoder activations between the two models (it was over 100K). This was caused by the GELU variant I was using. However, despite the lower activation difference between the two, it still refuses to answer most natural language questions.
What's that number mean? ~3 would be a large value for the max-abs-diff between two activations. Or is it like a sum over all the diffs? Note there are three GELU options in MLX: `nn.gelu`, `nn.gelu_approx`, and `nn.gelu_fast_approx`.
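(For illustration, a quick way to see how the approximations differ from the exact version; this assumes the `mlx.nn` function names above:)

```python
import mlx.core as mx
import mlx.nn as nn

x = mx.linspace(-4, 4, 1001)
exact = nn.gelu(x)
for name, fn in [("gelu_approx", nn.gelu_approx),
                 ("gelu_fast_approx", nn.gelu_fast_approx)]:
    # max-abs-diff of each approximation against the exact (erf-based) GELU
    print(name, mx.abs(fn(x) - exact).max())
```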
It's the sum over all the diffs. Yes, I did. The original implementation of SigLIP (PaliGemma's vision model) uses `gelu_pytorch_tanh`, the tanh approximation (it shows up as `PytorchGELUTanh` in the architecture dump below). Implementing that variant directly brought the difference down.
I even copied the transformers GELU activation in numpy to compare, but I get results similar to the MLX one.
I'm not quite following the GELU story. But I think the safest call is to find the GELU implementation used by the reference (presumably the JAX code) and use that. You can check the MLX GELU implementations here to see if one matches.
I did exactly that. Here are the implementations I tried, all of which are identical to the one used in transformers and JAX:

```python
import mlx.core as mx
import mlx.nn as nn

class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate.
    See: https://github.com/hendrycks/GELUs
    """

    def __call__(self, input: mx.array) -> mx.array:
        return 0.5 * input * (1.0 + mx.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
```

```python
import numpy as np
import mlx.core as mx
import mlx.nn as nn

class FastGELUActivation(nn.Module):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate.
    See: https://github.com/hendrycks/GELUs
    """

    def __call__(self, input: mx.array) -> mx.array:
        return 0.5 * input * (1.0 + mx.tanh(np.sqrt(2 / np.pi) * (input + 0.044715 * (input ** 3))))
```

Transformers: https://github.com/huggingface/transformers/blob/3d7d3a87a0bf4d0bb9346beb9419b1d76b5b988f/src/transformers/activations.py#L81
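(For what it's worth, the two variants above are algebraically the same function, since 0.7978845608 ≈ sqrt(2/π) and x·(1 + 0.044715x²) = x + 0.044715x³. A quick numpy check, added here for illustration:)

```python
import numpy as np

x = np.linspace(-5, 5, 1001)
a = 0.5 * x * (1.0 + np.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
b = 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))
print(np.abs(a - b).max())  # on the order of 1e-16: identical up to float error
```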
Yet the sum of abs-diff stays around 2.39 and 3.77 on the vision path, and the model still refuses a lot. From the start until the first MLP everything is close to 0, and the layer just before that is 0.08. Here is my implementation of the MLP:

```python
class MLP(nn.Module):
    def __init__(self, config: VisionConfig):
        super().__init__()
        self.activation_fn = FastGELUActivation()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.fc2(x)
        return x
```

Here is the transformers implementation:

```python
class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]  # uses the same FastGELUActivation
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states
```
I'm not sure what I'm missing here. Let me go for a walk 🚶🏾♂️...
Any luck getting to the bottom of this? FWIW it's expected that there are some numerical differences between the MLX and PyTorch versions. Rather than looking at the sum (which is hard to reason about since it depends on the number of weights), maybe check something like a relative difference, e.g. `norm(x1 - x2) / norm(x1)`.
Not yet. Yesterday I tried using the Hugging Face VLM class in my implementation, but that didn't change the results. Let me check the relative distance and let you know.
@awni here are the results for the relative differences at each checkpoint:

- Language Model (embedding output)
- Vision Model (patch_embedding output)
- Vision Model (embeddings layer output)
- Vision Model (encoder layer 1 output)
- Vision Model (post layernorm output)
- Multi-modal projector (linear layer output)
For context, here is the model architecture:

```
PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features=1152, out_features=4304, bias=True)
              (fc2): Linear(in_features=4304, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
    )
  )
  (multi_modal_projector): PaliGemmaMultiModalProjector(
    (linear): Linear(in_features=1152, out_features=2048, bias=True)
  )
  (language_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(257216, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
          (post_attention_layernorm): GemmaRMSNorm()
        )
      )
      (norm): GemmaRMSNorm()
    )
    (lm_head): Linear(in_features=2048, out_features=257216, bias=False)
  )
)
```
What are the formulas for these?
Does that not suggest the problem is outside the model itself?
```python
import numpy as np

def relative_diff(x1, x2):
    assert x1.shape == x2.shape, "Matrices must have the same dimensions"
    if x1.ndim > 2 or x2.ndim > 2:
        x1 = x1.reshape(-1)
        x2 = x2.reshape(-1)
    print("Relative Distance (using norms):", np.linalg.norm(x1 - x2) / np.linalg.norm(x1))
    print("Max Absolute Relative Difference:", (np.abs(x1 - x2) / np.abs(x1)).max())
    print("Are Matrices Close (np.allclose):", np.allclose(x1, x2))
```
Yes, but where exactly? The language model is Gemma-2B, and I double- and triple-checked it before; on its own it works fine. Language model only:

```shell
python -m mlx_vlm.generate --model google/paligemma-3b-pt-224 \
    --prompt "Hi"
```

```
Prompt: Hi
I'm a very good friend of yours
==========
Prompt: 6.013 tokens-per-sec
Generation: 26.163 tokens-per-sec
```
OK, after some deeper debugging, I think the issue is in the multimodal feature merging and/or masking. I'll update you once I have it working.
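(For readers following along: PaliGemma's prompt format puts one placeholder token per image patch at the front of the sequence, so the merge step essentially overwrites that prefix of the embedding sequence with the projected vision features. A minimal sketch of the idea, with illustrative names rather than the actual mlx_vlm code:)

```python
import mlx.core as mx

def merge_multimodal(image_features, text_embeds):
    # image_features: (B, num_patches, D) from the multi-modal projector
    # text_embeds:    (B, seq_len, D), whose first num_patches positions
    #                 hold <image> placeholder embeddings
    n = image_features.shape[1]
    return mx.concatenate([image_features, text_embeds[:, n:, :]], axis=1)
```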
I did everything by the book, but the model still doesn't behave properly. It seems like it behaves better only when using multimodal features from the transformers model. But that doesn't make sense, because I have a 1:1 copy of that in MLX. Could you please give this a look:
@awni this weird behaviour also happened with other models. The only thing these have in common is that they are using F32 precision.
@awni any thoughts?
I couldn't say what the issue is... I'll try to take a deeper look in the next few days.
Thanks! Looking forward to it :)
I just started digging in, but I think the problem may actually be in mlx_vlm's LanguageModel implementation for PaliGemma. To demonstrate this, I replaced mlx_vlm PaliGemma's LanguageModel with the corresponding implementation from Hugging Face Transformers in the code below.

```python
from huggingface_hub import login
import os

login(token=os.getenv('HF_TOKEN'))

model_id = "google/paligemma-3b-mix-224"
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
prompt = 'Caption: '

import mlx.core as mx
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import glob
from huggingface_hub import snapshot_download
import json
from PIL import Image
import requests
import torch
import importlib
import numpy as np


def sanitize(weights):
    # MLX convolutions expect channels-last weights
    sanitized_weights = {}
    for k, v in weights.items():
        if "patch_embedding.weight" in k:
            sanitized_weights[k] = v.transpose(0, 2, 3, 1)
        else:
            sanitized_weights[k] = v
    return sanitized_weights


def load_model(model_id):
    model_path = snapshot_download(
        repo_id=model_id,
        revision=None,
        allow_patterns=[
            "*.json",
            "*.safetensors",
            "*.py",
            "tokenizer.model",
            "*.tiktoken",
            "*.txt",
        ],
    )
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)
    weights = {}
    weight_files = glob.glob(str(f"{model_path}/*.safetensors"))
    for wf in weight_files:
        weights.update(mx.load(wf))
    weights = sanitize(weights)
    model_class = importlib.import_module("mlx_vlm.models.paligemma")
    model_config = model_class.ModelConfig.from_dict(config)
    model_config.vision_config = model_class.VisionConfig.from_dict(config["vision_config"])
    model_config.text_config = model_class.TextConfig.from_dict(config["text_config"])
    model = model_class.Model(model_config)
    model.load_weights(list(weights.items()))
    mx.eval(model.parameters())
    model.eval()
    return model


model_mx = load_model(model_id)
processor = AutoProcessor.from_pretrained(model_id)
prompt_tokens = mx.array(processor.tokenizer.encode(prompt))
inputs = processor(prompt, Image.open(requests.get(img_url, stream=True).raw), return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
mask = mx.array(inputs["attention_mask"])

# Run the MLX vision tower and projector, then merge into the text embeddings
inputs_embeds = model_mx.language_model.model.embed_tokens(input_ids)
hidden_state, _, _ = model_mx.vision_tower(
    pixel_values.transpose(0, 2, 3, 1).astype(inputs_embeds.dtype),
    output_hidden_states=True,
)
image_features = hidden_state[None, :].astype(pixel_values.dtype)
image_features = model_mx.multi_modal_projector(image_features)
input_embeddings, final_attention_mask_4d = (
    model_mx._prepare_inputs_for_multimodal(
        image_features, inputs_embeds, input_ids, mask
    )
)

# <<< mx language
# logits, cache = model_mx.language_model(
#     inputs=input_ids,
#     cache=None,
#     inputs_embeds=input_embeddings,
#     mask=final_attention_mask_4d,
# )
# >>> mx language

# <<< hf language
model_hf = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
final_attention_mask_4d = torch.from_numpy(np.array(final_attention_mask_4d, dtype=np.float32))
input_embeddings = torch.from_numpy(np.array(input_embeddings))
outputs = model_hf.language_model(
    attention_mask=final_attention_mask_4d,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=input_embeddings,
    use_cache=False,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True,
    cache_position=None,
)
logits = outputs.logits
logits = mx.array(logits.detach().numpy())
# >>> hf language

logits = logits[:, -1, :]
token = mx.argmax(logits, axis=-1)
print(token, processor.tokenizer.decode(token.tolist()))
```

The modification immediately improves the output text quality compared with the output from the commented part of the above code (the mlx_vlm version of the LanguageModel).
Thanks @JosefAlbers! Found the bug and fixed it :)
@awni @lucasb-eyer it's fixed ✅ After my earlier changes, the Gemma embedding scaling wasn't being applied to all inputs (text and multimodal); it was only scaling the text embeddings. That's why the language model worked when I unit tested it in isolation but failed with multimodal inputs.
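(A minimal sketch of the shape of that fix, with illustrative names rather than the exact mlx_vlm code: Gemma multiplies token embeddings by sqrt(hidden_size), and that normalizer has to cover the merged text+image sequence, not just the text path.)

```python
import mlx.core as mx

def embed_multimodal(embed_tokens, input_ids, image_features, merge_fn, hidden_size):
    inputs_embeds = embed_tokens(input_ids)
    # Merge the projected image features into the token embedding sequence
    # first, then scale, so the multimodal positions are scaled too.
    merged = merge_fn(image_features, inputs_embeds)
    return merged * (hidden_size ** 0.5)
```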
@JosefAlbers could you share your X handle? I want to tag you on the release :)
Great, I'm just glad I could help!
@awni I have PaliGemma working on MLX. In most cases it works great.

But there are 4 issues I don't see in the transformers implementation:

The last two work fine on transformers.

The architecture is correct and the model does work, but it seems like MLX is either changing the precision or doing something else that makes the model not behave 100% normally.

Resources: