Skip to content

Commit

Permalink
Fix UltraCMTask scoring range and align argilla imports (#201)
Browse files Browse the repository at this point in the history
* Override `to_argilla_dataset` in `UltraCMTask` to use 1-10 scores

* Align `argilla` imports across codebase
  • Loading branch information
alvarobartt committed Dec 27, 2023
1 parent 8d79860 commit 9835760
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
3 changes: 1 addition & 2 deletions src/distilabel/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
from distilabel.tasks.prompt import Prompt

if TYPE_CHECKING:
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla import FeedbackDataset, FeedbackRecord


def get_template(template_name: str) -> str:
Expand Down
21 changes: 20 additions & 1 deletion src/distilabel/tasks/critique/ultracm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@

import re
from dataclasses import dataclass
from typing import Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional

from distilabel.tasks.base import get_template
from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput
from distilabel.tasks.prompt import Prompt

if TYPE_CHECKING:
from argilla import FeedbackDataset

_ULTRACM_TEMPLATE = get_template("ultracm.jinja2")


Expand Down Expand Up @@ -52,3 +55,19 @@ def parse_output(self, output: str) -> CritiqueTaskOutput: # type: ignore
score=float(match.group(1)),
critique=match.group(2).strip(),
)

def to_argilla_dataset(
self,
dataset_row: Dict[str, Any],
generations_column: str = "generations",
score_column: str = "score",
critique_column: str = "critique",
score_values: Optional[List[int]] = None,
) -> "FeedbackDataset":
return super().to_argilla_dataset(
dataset_row=dataset_row,
generations_column=generations_column,
score_column=score_column,
critique_column=critique_column,
score_values=score_values or [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
)
3 changes: 1 addition & 2 deletions src/distilabel/tasks/text_generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
import argilla as rg

if TYPE_CHECKING:
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla import FeedbackDataset, FeedbackRecord


@dataclass
Expand Down
3 changes: 1 addition & 2 deletions src/distilabel/tasks/text_generation/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
import argilla as rg

if TYPE_CHECKING:
from argilla import FeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla import FeedbackDataset, FeedbackRecord

_SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")

Expand Down

0 comments on commit 9835760

Please sign in to comment.