Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disables token sampling when temperature set to 0 #200

Closed
wants to merge 10 commits into from

Conversation

akelch11
Copy link

@akelch11 akelch11 commented Jan 22, 2024

Disable token sampling when temperature = 0, PR addressing Issue #197

Problem/Issue: This PR turns off token sampling in the next token chooser classes (NextTokenChooser, HeterogeneousNextTokenChooser) when initialized with a temperature parameter set to 0, as outlined in #197 . LLM output should be deterministic and use the Greedy token choice system, outputting tokens that have the highest log probability in the logit distribution.

Solution: This involves setting sampling and do_sample flags in the next token chooser classes (NextTokenChooser, HeterogeneousNextTokenChooser) to False when the temperature is set to 0, so thatGreedy token choosing is enabled, creating deterministic token choices/results. Also, there are some changes to input validation that previously didn't account/check for a 0 temperature parameter so they no longer treat this as invalid.

Testing: A new test calling the default CausalLLM with deterministic next token choosers initialized with temperature = 0 was added, checking that each generated token's log probability is the maximum from its distribution. As of 1/22, the changes have all server tests passing that do not involve authenticating with HuggingFace token/login to access Llama2. Assistance with setting the login information would be appreciated.

Fixes #(#197)

Before submitting

  • [] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a Github issue or the discord / slack channel? Please add a link
    to it if that's the case.: #(Disable sampling when temperature=0 #197)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@akelch11
Copy link
Author

akelch11 commented Jan 22, 2024

@tgaddair Could a maintainer please approve the running of the test workflows?

Also, I'd want to go ahead and write a test in test_tokens.py and shows that output is deterministic, and matches the token with highest logits from the distribution? I see some tests in `server/tests/models/test_causal_lm.py that initialized models and generate tokens, but don't see an easy place to set the temperature parameters. Would anyone familiar with this be able to provide input to help write this test?

@tgaddair
Copy link
Contributor

Approved! :)

@akelch11
Copy link
Author

The failed tests have to do with not being able to log into HuggingFace and use the Llama model -- how would one go about setting up the tests to access the huggingface login token that the repo uses?

@akelch11 akelch11 mentioned this pull request Jan 23, 2024
3 tasks
@tgaddair
Copy link
Contributor

Hey @akelch11, apologies for the failing test. What I need to do is disable those particular tests when run from a forked repo. It's mostly a GitHub Action change. For now, though, we can ignore those tests, as the others will have run.

Copy link
Contributor

@tgaddair tgaddair left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

Can you also update clients/python/lorax/types.py:91 to:

if v is not None and v < 0:
            raise ValidationError("`temperature` must be non-negative")

from lorax_server.utils.lora import AdapterBatchData
from lorax_server.pb import generate_pb2
from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
from tests.models.test_causal_lm import default_causal_lm, default_causal_lm_batch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imports from tests can be a little weird. The recommended thing to do here, since these are fixtures, would be to move them into conftest.py. Then you can pass them as arguments to your test functions without needing to import them (see usage of default_pb_parameters as an example).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, looks like the current failing test is failing due to this issue.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying, will move test there and refactor it

@@ -88,7 +88,7 @@ def valid_seed(cls, v):

@validator("temperature")
def valid_temp(cls, v):
if v is not None and v <= 0:
if v is not None and v < 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: please change validation error to read "temperature must be non-negative".

@tgaddair
Copy link
Contributor

@akelch11, server tests look good (just the expected permissions failures). Looks like one of the Python client tests needs to be updated to account for the new constraints.

@tgaddair
Copy link
Contributor

Hey @akelch11, are you able to fix the remaining client test?

@akelch11
Copy link
Author

akelch11 commented Jan 25, 2024 via email

@tgaddair
Copy link
Contributor

Thanks @akelch11! No rush :)

@prd-tuong-nguyen
Copy link

prd-tuong-nguyen commented May 13, 2024

@tgaddair any update on this PR? I really need this feature :( Do you have another way to disable sampling in current version>

@tgaddair
Copy link
Contributor

Hey @prd-tuong-nguyen, the contributor to this one went dark, but I can definitely pick this up and close it out.

For now, setting temperature to 1 (default) and keeping do_sample=False should make results deterministic.

@prd-tuong-nguyen
Copy link

@tgaddair cool, thank u

@tgaddair
Copy link
Contributor

@prd-tuong-nguyen this has now landed in #467.

@tgaddair tgaddair closed this May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants