Mixtral past_key_values and output_router_logits incompatible #30731

Open

sorgfresser opened this issue May 9, 2024 · 1 comment
Labels
Good Second Issue: Issues that are more difficult to do than "Good First" issues - give it a try if you want!

Comments

@sorgfresser
Contributor

System Info

transformers==4.40.2
Python 3.11.8

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import MixtralConfig, MixtralForCausalLM, AutoTokenizer
import torch
# Initializing a smaller version of Mixtral for faster execution
configuration = MixtralConfig(
    hidden_size=256,
    intermediate_size=896,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_local_experts=4,
    num_experts_per_tok=1,
)

model = MixtralForCausalLM(configuration)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
prompt = "This is a test"
tokenized = tokenizer(prompt, return_tensors="pt")
output = model(**tokenized, output_router_logits=True)
key_values = output.past_key_values
logits = output.logits
next_token_logits = logits[..., -1, :]
# Softmax
softmaxed = torch.nn.functional.softmax(next_token_logits, dim=-1)
# Sample
sampled = torch.multinomial(softmaxed.squeeze(), num_samples=1)
ids = sampled.item()

attention_mask = torch.cat([tokenized["attention_mask"], torch.tensor([[1]])], dim=-1)
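# Note: with past_key_values and output_router_logits=True, this second forward
# pass hits the shape error described below (the same failure mode as #29087).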
next_output = model(
    torch.tensor([[ids]]),
    attention_mask=attention_mask,
    past_key_values=key_values,
    output_router_logits=True
)

Expected behavior

It seems that this is the same underlying issue as in #29087 - I would expect past_key_values to work with output_router_logits.
So what happens?

  1. Without past key values (and with multiple input ids), all_router_logits has the full sequence length, so in load_balancing_loss_func the line num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) correctly computes the number of hidden layers.
  2. With past key values, all_router_logits has a sequence length of 1, but the attention mask still covers the whole sequence (and sequence_length is inferred from it), so num_hidden_layers evaluates to a small value or 0, leading to the same error as in Mixtral inference breaks when output_router_logits=True #29087.

Instead, I would like load_balancing_loss_func to handle the case where the gate_logits passed in have shape [batch_size x 1, num_experts] instead of [batch_size x sequence_length, num_experts].
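
For concreteness, here is a minimal numeric sketch of that integer division, using the layer and expert counts from the reproduction above (the shapes are illustrative and the variable names mirror only the single line quoted from load_balancing_loss_func, not its full implementation):

import torch

num_hidden_layers, num_experts = 8, 4

# Prefill: 5 prompt tokens, gate logits from all layers concatenated.
concatenated_gate_logits = torch.randn(num_hidden_layers * 1 * 5, num_experts)  # [40, 4]
batch_size, sequence_length = 1, 5  # inferred from the attention mask
print(concatenated_gate_logits.shape[0] // (batch_size * sequence_length))  # 8 -> correct

# Decode step with past_key_values: only the new token goes through the model,
# but the attention mask now covers all 6 tokens.
concatenated_gate_logits = torch.randn(num_hidden_layers * 1 * 1, num_experts)  # [8, 4]
batch_size, sequence_length = 1, 6  # still inferred from the full attention mask
print(concatenated_gate_logits.shape[0] // (batch_size * sequence_length))  # 8 // 6 == 1
# With a longer prompt this division evaluates to 0, producing the error from #29087.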

@ArthurZucker
Collaborator

Hey! The generate function is not supposed to be used for training, which is why we don't test past key values together with output router logits. That said, they aren't really incompatible (you might want to look at the distribution of the router logits during generation).
Do you want to open a PR for a fix?
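
As a rough sketch of what inspecting the router distribution could look like, using the prefill output from the reproduction above (and assuming output.router_logits is a per-layer tuple of [batch_size * sequence_length, num_experts] tensors, as in the current Mixtral model outputs):

# Per-layer routing probabilities over the experts for each prompt token.
for layer_idx, layer_logits in enumerate(output.router_logits):
    probs = torch.nn.functional.softmax(layer_logits, dim=-1)  # [tokens, num_experts]
    print(f"layer {layer_idx}: mean expert load = {probs.mean(dim=0).tolist()}")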

ArthurZucker added the Good Second Issue label on May 15, 2024