Why does cache=None produce different outputs? #88

Open · andsteing opened this issue Dec 15, 2023 · 0 comments · May be fixed by #89
When computing the log probabilities for

prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

I get very different values depending on whether I use cache=RotatingBufferCache(...) or cache=None (each line below shows a token transition t1 -> t2, its log-probability, and the corresponding probability exp(logprob) as a percentage):

use_cache=False:

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [  5045] ▁blue   :   -0.81 44.68%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [ 12937] ▁pink   :   -2.45  8.59%

[     1] <s>      -> [   272] ▁the    :   -7.41  0.06%
[   272] ▁the     -> [  7212] ▁sky    :   -0.98 37.41%
[  7212] ▁sky     -> [   349] ▁is     :   -0.19 82.40%
[   349] ▁is      -> [   287] ▁b      :   -5.00  0.67%
[   287] ▁b       -> [ 10364] acon    :   -0.04 96.17%

use_cache=True:

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [  5045] ▁blue   :   -2.39  9.13%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [ 12937] ▁pink   :   -4.82  0.81%

[     1] <s>      -> [   272] ▁the    :   -9.24  0.01%
[   272] ▁the     -> [  7212] ▁sky    :   -7.37  0.06%
[  7212] ▁sky     -> [   349] ▁is     :   -1.16 31.46%
[   349] ▁is      -> [   287] ▁b      :   -7.59  0.05%
[   287] ▁b       -> [ 10364] acon    :   -4.41  1.21%

The values without a cache do not make any sense (the values with the cache seem reasonable, though).

Why is this? How can I use the model without a cache?

Full code is in this Colab: https://colab.research.google.com/drive/1lNk_JgFFAakTRtEVkpxQ42jlGCygwfSb

Code from the Colab:
import numpy as np
import torch

import mistral.cache  # RotatingBufferCache from the mistral reference code

# Note: `model` and `tokenizer` are loaded earlier in the Colab (not shown here).


def get_logprobs(model, tokenizer, prompts, *, use_cache):
  """Returns `(encoded_prompts, logprobs)`, optionally using the cache."""

  # Encode every prompt with a BOS token and concatenate all tokens into a
  # single flat tensor, keeping track of the individual sequence lengths.
  encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts[:3]]
  seqlens = [len(x) for x in encoded_prompts]
  concatenated_prompts = torch.tensor(sum(encoded_prompts, []), device=model.device, dtype=torch.long)

  if use_cache:
    # Size the rotating buffer to the longest prompt, capped at the model's
    # sliding window, so every prompt fits into the cache in a single pass.
    sliding_window = model.args.sliding_window
    sliding_window = min(max(seqlens), sliding_window)

    cache = mistral.cache.RotatingBufferCache(
        model.args.n_layers,
        model.args.max_batch_size,
        sliding_window,
        model.args.n_kv_heads,
        model.args.head_dim,
    )
    cache.to(device=model.device, dtype=model.dtype)
    cache.reset()
  else:
    cache = None

  # Single forward pass over the concatenated prompts; `seqlens` tells the
  # model where each prompt starts and ends.
  prelogits = model.forward(
      concatenated_prompts,
      seqlens=seqlens,
      cache=cache,
  )

  # Log-softmax over the vocabulary gives log-probabilities; for each position,
  # pick out the log-probability assigned to the actual next token.
  logprobs_all = torch.log_softmax(prelogits, dim=-1)
  logprobs = [[] for _ in range(len(prompts))]
  offset = 0
  for i_seq, sequence in enumerate(encoded_prompts):
    logprobs[i_seq].extend(
        [logprobs_all[offset + i, sequence[i + 1]].item() for i in range(len(sequence) - 1)])
    offset += len(sequence)

  return encoded_prompts, logprobs


def print_logprobs(id2token, encoded_prompts, logprobs):
  """Prints the `(encoded_prompts, logprobs)` tokens / transition probabilities."""
  for i, t in enumerate(encoded_prompts):
    # For every transition t1 -> t2, print the token ids, the token pieces, the
    # log-probability, and the corresponding probability as a percentage.
    for j, (t1, t2) in enumerate(zip(t, t[1:])):
      logprob = float(logprobs[i][j])
      print(
          f'[{t1:6}] {id2token(t1):8} '
          f'-> [{t2:6}] {id2token(t2):8}: '
          f'{logprob:7.2f} '
          f'{np.exp(logprob):6.2%}'
      )
    print()


prompts = (
    'the sky is blue',
    'the sky is pink',
    'the sky is bacon',
)

for use_cache in (False, True):
  print(f'use_cache={use_cache}:\n')
  print_logprobs(tokenizer._model.id_to_piece, *get_logprobs(model, tokenizer, prompts, use_cache=use_cache))
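
As an additional data point, here is a small sketch (re-using get_logprobs and print_logprobs from above) that feeds the prompts one at a time, still with cache=None. If the problem only shows up when several prompts are concatenated into one forward call, I would expect these per-prompt values to differ from the batched cache=None output above:

# Sketch: run each prompt in its own forward call, without a cache,
# re-using the helpers defined above.
for prompt in prompts:
  encoded, logprobs = get_logprobs(model, tokenizer, (prompt,), use_cache=False)
  print_logprobs(tokenizer._model.id_to_piece, encoded, logprobs)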