Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA D2 Brier Score #28971

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open

Conversation

OmarManzoor
Copy link
Contributor

Reference Issues/PRs

Closes #20943

What does this implement/fix? Explain your changes.

  • Adds the D2 Brier score which is the D2 score for brier_score_loss

Any other comments?

CC: @lorentzenchr

Copy link

github-actions bot commented May 7, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5bc5037. Link to the linter CI: here

@OmarManzoor OmarManzoor changed the title D2 Brier Score FEA D2 Brier Score May 7, 2024
doc/modules/model_evaluation.rst Outdated Show resolved Hide resolved
},
prefer_skip_nested_validation=True,
)
def d2_brier_score(y_true, y_proba, *, sample_weight=None, pos_label=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def d2_brier_score(y_true, y_proba, *, sample_weight=None, pos_label=None):
def d2_brier_score(y_true, y_proba, *, sample_weight=None, labels=None):

This should be able to handle multiclass, too. See log_loss.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are we supposed to extract the pos_label to pass to brier_score_loss?

Comment on lines +3415 to +3419
y_true = column_or_1d(y_true)
positive_label = _get_positive_label_for_brier_score(y_true, pos_label)
weights = _check_sample_weight(sample_weight, y_true)
positive_prob = np.sum((y_true == positive_label) * weights) / np.sum(weights)
y_proba_ref = np.full(y_true.shape, positive_prob)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be the same code as in d2_log_loss.

Comment on lines +3430 to +3442
def _get_positive_label_for_brier_score(y_true, pos_label=None):
try:
pos_label = _check_pos_label_consistency(pos_label, y_true)
except ValueError:
classes = np.unique(y_true)
if classes.dtype.kind not in ("O", "U", "S"):
# for backward compatibility, if classes are not string then
# `pos_label` will correspond to the greater label
pos_label = classes[-1]
else:
raise

return pos_label
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _get_positive_label_for_brier_score(y_true, pos_label=None):
try:
pos_label = _check_pos_label_consistency(pos_label, y_true)
except ValueError:
classes = np.unique(y_true)
if classes.dtype.kind not in ("O", "U", "S"):
# for backward compatibility, if classes are not string then
# `pos_label` will correspond to the greater label
pos_label = classes[-1]
else:
raise
return pos_label

sklearn/metrics/tests/test_classification.py Outdated Show resolved Hide resolved
@lorentzenchr
Copy link
Member

#22046 seems like a blocker

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add more D2 scores
2 participants