Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
from transformers import MixtralConfig, MixtralForCausalLM, AutoTokenizer
import torch

# Initializing a smaller version of Mixtral for faster execution
configuration = MixtralConfig(
hidden_size=256,
intermediate_size=896,
num_hidden_layers=8,
num_attention_heads=8,
num_key_value_heads=8,
num_local_experts=4,
num_experts_per_tok=1,
)
model = MixtralForCausalLM(configuration)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "This is a test"
tokenized = tokenizer(prompt, return_tensors="pt")

output = model(**tokenized, output_router_logits=True)
key_values = output.past_key_values
logits = output.logits
next_token_logits = logits[..., -1, :]

# Softmax
softmaxed = torch.nn.functional.softmax(next_token_logits, dim=-1)
# Sample
sampled = torch.multinomial(softmaxed.squeeze(), num_samples=1)
ids = sampled.item()

attention_mask = torch.cat([tokenized["attention_mask"], torch.tensor([[1]])], dim=-1)
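# The second forward pass, with past_key_values, is where the error occurs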
next_output = model(
    torch.tensor([[ids]]),
    attention_mask=attention_mask,
    past_key_values=key_values,
    output_router_logits=True,
)
Expected behavior
It seems that this is the same underlying issue as in #29087 - I would expect past_key_values to work with output_router_logits.
So what happens?
Without past key values (and with multiple input ids), all_router_logits has the full sequence length, so in load_balancing_loss_func the line num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) correctly evaluates to the number of hidden layers.
If past key values are used, all_router_logits has a sequence length of 1, but the attention mask still covers the whole sequence (and sequence_length is inferred from it), so num_hidden_layers evaluates to a small value or 0, leading to the same error as in #29087 (Mixtral inference breaks when output_router_logits=True).
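To make the arithmetic concrete, here is a minimal sketch; the layer, batch, and mask sizes below are illustrative assumptions, not values taken from the reproduction above:

num_layers, batch_size, mask_len = 8, 1, 6

# Full forward pass: one gate-logit row per layer and per token.
gate_rows_full = num_layers * batch_size * mask_len   # 48
print(gate_rows_full // (batch_size * mask_len))      # 8 -> correct layer count

# Cached decoding step: only the single new token is routed per layer.
gate_rows_cached = num_layers * batch_size * 1        # 8
print(gate_rows_cached // (batch_size * mask_len))    # 1 -> wrong, breaks downstream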
Instead, I would like load_balancing_loss_func to handle the case where the gate_logits passed have shape [batch_size x 1, num_experts] instead of [batch_size x sequence_length, num_experts].
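One possible direction, sketched below as an illustrative reimplementation (not the actual transformers code, though the names mirror it): infer the routed sequence length from the gate logits themselves and slice the attention mask down to the positions that were actually routed.

import torch

def load_balancing_loss_func(gate_logits, num_experts, top_k=1, attention_mask=None):
    # Simplified sketch of Mixtral's auxiliary-loss helper.
    concatenated = torch.cat(list(gate_logits), dim=0)  # [layers * B * S, E]
    routing_weights = torch.softmax(concatenated, dim=-1)
    _, selected = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected, num_experts).float()

    if attention_mask is None:
        tokens_per_expert = expert_mask.mean(dim=0)           # [top_k, E]
        router_prob_per_expert = routing_weights.mean(dim=0)  # [E]
    else:
        batch_size = attention_mask.shape[0]
        # Key change: take the routed length from the logits, not from the
        # mask. During cached decoding it is 1 while the mask spans all tokens.
        seq_len = gate_logits[0].shape[0] // batch_size
        mask = attention_mask[:, -seq_len:].reshape(-1).float()
        mask = mask.repeat(len(gate_logits))                  # [layers * B * S]
        tokens_per_expert = (expert_mask * mask[:, None, None]).sum(dim=0) / mask.sum()
        router_prob_per_expert = (routing_weights * mask[:, None]).sum(dim=0) / mask.sum()

    return torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts

With this change the same code path covers both the full-sequence case ([batch_size x sequence_length, num_experts]) and the cached single-token case ([batch_size x 1, num_experts]), because the mask is always sliced to match the logits.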
Hey! The generate function is not supposed to be used for training, which is why we don't test past key values together with output router logits. They are not actually that incompatible, though (you might want to look at the distribution of the router logits during generation).
Do you want to open a PR for a fix?
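For reference, a sketch of that use case, reusing the small model and inputs from the reproduction above. The full-sequence pass shown here works today; doing the same on the cached single-token step is exactly what currently fails:

# Inspect how tokens are spread over the experts, per layer.
with torch.no_grad():
    out = model(**tokenized, output_router_logits=True)

for layer_idx, layer_logits in enumerate(out.router_logits):
    # layer_logits: [batch_size * seq_len, num_local_experts]
    probs = torch.softmax(layer_logits, dim=-1)
    print(f"layer {layer_idx}: mean expert probabilities {probs.mean(dim=0).tolist()}")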
System Info
transformers==4.40.2
Python 3.11.8
Who can help?
@ArthurZucker