Classification with LLMs: Open tasks #989

Open · 9 tasks

BenjaminBossan (Collaborator) opened this issue Jun 29, 2023 · 0 comments

I wanted to create a single issue that collects TODOs around the LLM feature that was added recently, so that everything is in one place.

  • Have a "fast/greedy" option for predict: it is not necessary to calculate probabilities for all classes up to the last token. If class A has final probability p_A and class B has a running probability p_B(t) < p_A at token t, then, since each further token can only shrink the product, p_B can never exceed p_A anymore (see the first sketch at the end of this issue).
  • A small use case where the classifiers are used as a transformer in a bigger pipeline, e.g. to extract structured knowledge from a text ("Does this product description contain the size of the item?")
  • A way to format the text/labels/few-shot samples before they're string-interpolated, maybe with Jinja2? (A template sketch is included at the end of this issue.)
  • Test if this works with a more diverse range of LLMs
  • Enable multi-label classification. This would probably require a sigmoid instead of a softmax and an (empirically determined?) threshold; see the sketch at the end of this issue.
  • Check if it is possible to enable caching for encoder-decoder LLMs like flan-t5.
  • Sampling strategy for few-shot learning:
    Right now, the sampling is hard-coded and basically tries to add each label at least once. This seems reasonable but there are situations where other strategies could make sense. Therefore, I would like to see a feature that allows setting the sampling strategy as a parameter. Options that come to mind:
    • Stratified sampling: roughly what we have now, but not quite
    • Fully random sampling: sample regardless of label
    • Similarity-based sampling: use the current sample to find similar samples from the training data (maybe with a simple tfidf vector?)
    • Custom sampling: Allow users to pass a callable that performs the sampling (a possible signature is sketched at the end of this issue)
  • Fine-tuning:
    Instead of in-context learning via few-shot samples, as in FewShotClassifier, it can often be more performant (both in terms of runtime and of scoring) to fine-tune on the training data. We could consider using peft under the hood, which is agnostic with regard to the training framework, so it should work with skorch. This would be implemented in a separate class; a rough sketch is at the end of this issue.
  • Refactor to use forward instead of generate:
    It is not yet clear whether this change would be an improvement over the existing implementation.
    Right now we rely on the generate method of transformers models, but we could instead use forward: by constructing the whole token sequence (input + label) and returning the corresponding logits, we can calculate the probabilities without having to go through the logits processor and token forcing. Advantages: the code could be simplified, and batching becomes trivial to add (right now we predict one sample + label at a time; with forward we could score all labels in a single batch). Disadvantages: generate does some heavy lifting for encoder-decoder models, which we would have to reproduce, and we lose caching. In practice, which approach is faster depends on many factors: batch size (memory!), length of the input, length and overlap of the labels, etc.
    Here is some sample code, adapted from my colleague Joao, that demonstrates this approach and returns exactly the same probabilities as our existing approach:
import numpy as np

model, tokenizer = ...
result = []
for x in X:
    result.append([])
    input_length = len(tokenizer(get_prompt(x))['input_ids'])
    for label in labels:
        # Build the prompt for this label
        inputs = tokenizer(
            [get_prompt(x) + label],
            return_tensors="pt",
            padding=True,
        ).to(model.device)

        # Run the forward pass for the prompt, extract the logits
        # TODO fails for enc-dec
        # TODO allow setting batch size
        logits = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask).logits

        # Discard the logits that correspond to the input (remember: the logits at index N correspond to the token at
        # index N + 1. The first token has no logits, the last set of logits corresponds to a token not present in the
        # input. We have to shift by 1.)
        logits = logits[:, input_length - 1:-1, :]
        # Compute the probabilities for each label token. Be careful to remove any padding that may exist.
        probas = logits.softmax(dim=-1).cpu().numpy()
        label_ids = inputs.input_ids[:, input_length:]
        padding_mask = ~label_ids.eq(tokenizer.pad_token_id)
        label_token_probas = probas[0, np.arange(0, label_ids.shape[1], dtype=int), label_ids[0].cpu().numpy()]
        label_token_probas = label_token_probas[padding_mask[0].cpu().numpy()]
        label_proba = label_token_probas.prod()
        result[-1].append((label, label_proba.item()))
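
Below are some rough sketches for a few of the points above. None of this is a final API; the names, signatures, and values are placeholders.

The "fast/greedy" predict option boils down to pruning a label as soon as its running product of token probabilities falls below the best complete label probability seen so far, since each additional factor can only shrink the product. A minimal sketch with plain Python floats (greedy_label_proba and its input format are made up for illustration):

def greedy_label_proba(token_probas_per_label):
    # token_probas_per_label: dict mapping each label to the list of
    # per-token probabilities of that label given the prompt
    best_label, best_proba = None, 0.0
    for label, token_probas in token_probas_per_label.items():
        proba = 1.0
        for p in token_probas:
            proba *= p
            if proba < best_proba:
                # the running product can only shrink further, so this label
                # can no longer beat the current best: prune it early
                break
        else:
            if proba > best_proba:
                best_label, best_proba = label, proba
    return best_label, best_proba

# example: "B" is pruned after its second token because 0.8 * 0.5 < 0.9 * 0.8
greedy_label_proba({"A": [0.9, 0.8], "B": [0.8, 0.5, 0.9]})  # -> ("A", ~0.72)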
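
For the prompt formatting task, a Jinja2 template could replace the current string interpolation. A sketch under the assumption that text, labels, and examples are the pieces we want to format (the template wording is arbitrary):

from jinja2 import Template

PROMPT_TEMPLATE = Template(
    "Classify the text into one of these labels: {{ labels | join(', ') }}.\n\n"
    "{% for example in examples %}"
    "Text: {{ example.text }}\nLabel: {{ example.label }}\n\n"
    "{% endfor %}"
    "Text: {{ text }}\nLabel:"
)

prompt = PROMPT_TEMPLATE.render(
    labels=["positive", "negative"],
    examples=[{"text": "Works great", "label": "positive"}],
    text="Broke after two days",
)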
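
For multi-label classification, the sigmoid-plus-threshold idea could look roughly like this; the default threshold of 0.5 is only a stand-in for whatever value we end up determining empirically:

import numpy as np

def predict_multilabel(label_scores, threshold=0.5):
    # label_scores: array of shape (n_samples, n_labels) with raw,
    # unnormalized label scores; each label is judged independently
    probas = 1.0 / (1.0 + np.exp(-label_scores))  # sigmoid instead of softmax
    return probas >= threshold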
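
For the custom sampling strategy, the user-supplied callable could have a signature along these lines (the exact signature is up for discussion), shown here with the fully random strategy as an example:

import numpy as np

def random_few_shot_sampling(X, y, n_samples, random_state=None):
    # fully random sampling: pick few-shot examples regardless of their label
    rng = np.random.default_rng(random_state)
    indices = rng.choice(len(X), size=min(n_samples, len(X)), replace=False)
    return [X[i] for i in indices], [y[i] for i in indices]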
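
For fine-tuning, a rough idea of how peft could be wired up under the hood; the model name and LoRA hyper-parameters are placeholders, and the separate classifier class mentioned above does not exist yet:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-1b1")
lora_config = LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM")
peft_model = get_peft_model(base_model, lora_config)
# peft_model behaves like a regular torch module, so training it on the
# labeled data with skorch should work as usual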