Skip to content

Commit

Permalink
Merge pull request #408 from GispoCoding/407-evaluationvalidation-met…
Browse files Browse the repository at this point in the history
…rics-invalid-when-using-cli-tools

Fix invalid scoring for classifier metrics
  • Loading branch information
nmaarnio committed Jun 18, 2024
2 parents 8b8fccb + 40f7c1f commit 813362c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion eis_toolkit/evaluation/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ def score_predictions(
def _score_predictions(
y_true: Union[np.ndarray, pd.Series], y_pred: Union[np.ndarray, pd.Series], metric: str
) -> Number:
num_classes = len(np.unique(y_true))

# Multiclass classification
if len(y_true) > 2:
if num_classes > 2:
average_method = "micro"
# Binary classification
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/evaluation/scoring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

X, y = make_classification(n_samples=200, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
rf_model, history = random_forest_classifier_train(X_train, y_train)
rf_model, history = random_forest_classifier_train(X_train, y_train, random_state=42)
y_pred = predict_classifier(X_test, rf_model, include_probabilities=False)


Expand Down

0 comments on commit 813362c

Please sign in to comment.