Currently, SKLearnClassifier only exposes a predict method. That method takes the model's raw probabilities and immediately converts them back into class labels:
```python
# keras/keras/src/wrappers/sklearn_wrapper.py
class SKLBase(BaseEstimator):
    ...

    def predict(self, X):
        """Predict using the model."""
        from sklearn.utils.validation import check_is_fitted

        check_is_fitted(self)
        X = _validate_data(self, X, reset=False)
        raw_output = self.model_.predict(X)  # <----- raw probabilities here
        return self._reverse_process_target(raw_output)  # <----- transformed back to labels
```
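Once the probabilities have been turned into labels, the confidence information is gone. The numpy sketch below shows this. It assumes a softmax-style output of shape (n_samples, n_classes) and assumes labels are recovered by an argmax over a hypothetical classes array. This only mimics the idea and is not Keras's actual _reverse_process_target.

```python
import numpy as np

# Hypothetical raw output of model_.predict(X): one row per sample, one column per class.
raw_output = np.array([
    [0.51, 0.49],  # barely class 0
    [0.99, 0.01],  # confidently class 0
    [0.20, 0.80],  # class 1
])
classes = np.array(["neg", "pos"])  # hypothetical label set

# Reducing probabilities to labels keeps only the winning class per row.
labels = classes[np.argmax(raw_output, axis=1)]
print(labels)  # ['neg' 'neg' 'pos']
# Rows 0 and 1 get the same label, although the model was far less sure about row 0.
```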
If a user wants the class probabilities themselves, they have to write a custom subclass such as:
```python
from keras.wrappers import SKLearnClassifier


class MyKerasSKLearnClassifier(SKLearnClassifier):
    def predict_proba(self, X):
        if not hasattr(self, "model_"):
            raise RuntimeError("You must fit the model before calling predict_proba.")
        return self.model_.predict(X)  # returns the raw probabilities
```
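The workaround and the proposal differ in how they fail before fitting. The workaround raises RuntimeError. scikit-learn's check_is_fitted raises sklearn.exceptions.NotFittedError instead, which subclasses both ValueError and AttributeError. The minimal sketch below uses a hypothetical FittedCheckSketch estimator that returns uniform dummy probabilities and does not involve Keras.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class FittedCheckSketch(ClassifierMixin, BaseEstimator):
    """Hypothetical estimator: only demonstrates the fitted-check convention."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)  # trailing underscore marks a fitted attribute
        return self

    def predict_proba(self, X):
        # Raises NotFittedError if no attribute ending in "_" has been set yet.
        check_is_fitted(self)
        n_classes = len(self.classes_)
        return np.full((len(X), n_classes), 1.0 / n_classes)  # uniform dummy probabilities


clf = FittedCheckSketch()
try:
    clf.predict_proba(np.zeros((2, 2)))
except NotFittedError as err:
    print(type(err).__name__)  # NotFittedError
```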
Would you consider adding a predict_proba method to the SKLearnClassifier class:
```python
class SKLearnClassifier(ClassifierMixin, SKLBase):
    ...

    def predict_proba(self, X):
        """Predict class probabilities of the input samples X."""
        from sklearn.utils.validation import check_is_fitted

        check_is_fitted(self)
        X = _validate_data(self, X, reset=False)
        return self.model_.predict(X)
```
This would let downstream libraries and users rely on the class shipped with Keras, rather than writing their own wrapper class. predict_proba is the standard scikit-learn method for obtaining class probabilities from a classifier.
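The following self-contained sketch illustrates the proposed contract without Keras. It makes two assumptions. First, a hypothetical _FakeSoftmaxModel stands in for self.model_, and its predict returns softmax probabilities, as the issue assumes the Keras model does. Second, ProbaClassifierSketch is a hypothetical estimator whose predict is derived from predict_proba via argmax. The probabilities are then passed to scikit-learn's roc_auc_score, which needs scores rather than hard labels.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.metrics import roc_auc_score
from sklearn.utils.validation import check_is_fitted


class _FakeSoftmaxModel:
    """Hypothetical stand-in for a fitted Keras model: predict() returns probabilities."""

    def __init__(self, weights):
        self.weights = weights

    def predict(self, X):
        logits = X @ self.weights
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        exp = np.exp(logits)
        return exp / exp.sum(axis=1, keepdims=True)


class ProbaClassifierSketch(ClassifierMixin, BaseEstimator):
    """Hypothetical classifier mirroring the proposed predict/predict_proba split."""

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        # Fixed weights stand in for training; a real wrapper would fit the Keras model.
        self.model_ = _FakeSoftmaxModel(np.array([[1.0, -1.0], [-1.0, 1.0]]))
        return self

    def predict_proba(self, X):
        check_is_fitted(self)
        return self.model_.predict(np.asarray(X, dtype=float))

    def predict(self, X):
        # Derive labels from the same probabilities so both methods stay consistent.
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]


X = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 0.5], [0.2, 1.0]])
y = np.array([0, 1, 0, 1])
clf = ProbaClassifierSketch().fit(X, y)
proba = clf.predict_proba(X)
print(clf.predict(X))                 # [0 1 0 1]
print(roc_auc_score(y, proba[:, 1]))  # 1.0
```

Deriving predict from predict_proba is one way to keep the two methods consistent. The Keras wrapper's predict goes through _reverse_process_target instead, so its exact label mapping may differ from this sketch.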
Comment From: divakaivan
If ok, I can open a PR :)
Comment From: abheesht17
@divakaivan - please go ahead! Do you think we can add a flag to predict instead of defining a new method predict_proba()? Like so: predict(..., return_probabilities=False)
Comment From: divakaivan
Thank you. I think a separate predict_proba method is better, because some scikit-learn functionality and internal implementations require it. Examples are:
These are commonly used, and a user who wants them would again have to write a subclass of SKLearnClassifier.
Also, scikit-learn classifiers conventionally expose predict_proba as a separate method.
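The interoperability argument can be sketched without Keras or scikit-learn. Generic scikit-learn-style code typically detects probability support by looking up predict_proba by name, for example with hasattr(est, "predict_proba"). It does not know about custom keyword arguments, so a return_probabilities flag on predict is invisible to it. FlagStyleClassifier, MethodStyleClassifier and generic_positive_scores are hypothetical names used only for this illustration.

```python
import numpy as np


class FlagStyleClassifier:
    """Hypothetical: probabilities only via a keyword flag on predict()."""

    def predict(self, X, return_probabilities=False):
        proba = np.tile([0.3, 0.7], (len(X), 1))  # dummy probabilities
        return proba if return_probabilities else proba.argmax(axis=1)


class MethodStyleClassifier:
    """Hypothetical: probabilities via a separate predict_proba() method."""

    def predict_proba(self, X):
        return np.tile([0.3, 0.7], (len(X), 1))  # dummy probabilities

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)


def generic_positive_scores(est, X):
    """Hypothetical helper written the way generic code usually is:
    it looks up predict_proba by name and never passes custom flags."""
    if hasattr(est, "predict_proba"):
        return est.predict_proba(X)[:, 1]
    raise AttributeError(f"{type(est).__name__} has no predict_proba")


X = np.zeros((2, 3))
print(generic_positive_scores(MethodStyleClassifier(), X))  # [0.7 0.7]
# generic_positive_scores(FlagStyleClassifier(), X) raises AttributeError,
# even though FlagStyleClassifier can compute the same probabilities.
```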