In addition to this example:

```python
callback = keras.callbacks.EarlyStopping(monitor='val_loss')
```

allow monitoring of multiple metrics, as in this example:

```python
callback = keras.callbacks.EarlyStopping(monitor=['val_loss', 'val_accuracy', 'val_f1measure'])
```

This way, training would not stop as long as any of these metrics is still improving, not just one of them.
Comment From: rlcauvin
One factor to consider when using multiple metrics for early stopping: how do you determine which epoch is "best" and should be used for restoring best weights (when `restore_best_weights` is `True`) after early stopping?
Comment From: fabriciorsf
The best epoch will be the one with the best `val_loss` value, but patience should not consider just one metric, but rather the list of metrics passed in the `monitor` parameter.
Another possibility would be to determine the best epoch by the value of the first metric in the list passed to the `monitor` parameter and, in case of a tie, look at the second, and so on.
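The tie-breaking rule described above could be sketched as a lexicographic comparison of the monitored metrics. This is a minimal illustration only; `better_epoch` and its arguments are hypothetical, not part of Keras:

```python
def better_epoch(logs_a, logs_b, monitor, modes):
    """True if the epoch with logs_a beats the one with logs_b,
    comparing metrics in the order given by `monitor`; a tie on one
    metric falls through to the next ("min" means lower is better)."""
    for name, mode in zip(monitor, modes):
        a, b = logs_a[name], logs_b[name]
        if a == b:
            continue  # tie on this metric: compare the next one
        return a < b if mode == "min" else a > b
    return False  # complete tie: keep the incumbent epoch

# Epoch A ties on val_loss but has higher val_accuracy, so it wins.
print(better_epoch(
    {"val_loss": 0.50, "val_accuracy": 0.91},
    {"val_loss": 0.50, "val_accuracy": 0.88},
    monitor=["val_loss", "val_accuracy"],
    modes=["min", "max"]))  # True
```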
Thank you for your quick response.
Comment From: rlcauvin
Another consideration is that we want to minimize some metrics (e.g. `val_loss`) but maximize others (e.g. `val_accuracy`).
It might be instructive to look at this implementation of a composite of early stopping callbacks. The `iterable_condition` parameter enables the caller to determine whether all, any, or some other combination of early stopping conditions must hold. It assumes the last early stopping callback is the "conductor", meaning it determines which epoch is "best".
```python
from typing import Callable, Dict, Iterable

import keras


class CompoundEarlyStopping(keras.callbacks.Callback):

    def __init__(
        self,
        callbacks: Iterable[keras.callbacks.Callback],
        iterable_condition: Callable[[Iterable[bool]], bool] = all):
        super().__init__()
        self.callbacks = callbacks
        self.stopped_epoch = 0
        self.iterable_condition = iterable_condition

    def on_train_begin(self, logs: Dict = None):
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: Dict = None):
        if self.model.stop_training:
            conductor = next(reversed(self.callbacks))
            conductor.stopped_epoch = self.stopped_epoch
            conductor.on_train_end(logs)

    def on_epoch_begin(self, epoch: int, logs: Dict = None):
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: Dict = None):
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)
        if self.iterable_condition(
                [callback.stopped_epoch >= max(1, epoch) for callback in self.callbacks]):
            self.stopped_epoch = epoch
            self.model.stop_training = True
        else:
            self.model.stop_training = False

    def set_model(self, model: keras.Model):
        super().set_model(model)
        for callback in self.callbacks:
            callback.set_model(model)
```
Here is an example of how I've used the `CompoundEarlyStopping` class:

```python
patience = 3

early_stopping_loss = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=patience,
    mode="min",
    verbose=1)

early_stopping_accuracy = keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    patience=patience,
    mode="max",
    verbose=1)

early_stopping_auc = keras.callbacks.EarlyStopping(
    monitor="val_auc",
    patience=patience,
    mode="max",
    restore_best_weights=True,
    verbose=1)

early_stopping = CompoundEarlyStopping(
    callbacks=[early_stopping_loss, early_stopping_accuracy, early_stopping_auc],
    iterable_condition=all)
```
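Since `iterable_condition` receives the list of per-callback "stalled" flags, it can be any callable over booleans, not just `all` or `any`. For instance (an illustrative threshold, not from the original code):

```python
# Stop training once at least two of the monitored metrics have stalled.
def at_least_two(flags):
    return sum(flags) >= 2

print(at_least_two([True, True, False]))   # True
print(at_least_two([True, False, False]))  # False

# It would be passed just like `all`:
#   CompoundEarlyStopping(callbacks=[...], iterable_condition=at_least_two)
```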
Comment From: SamanehSaadat
Hi @fabriciorsf
As @rlcauvin mentioned, you can achieve this by creating a custom callback. Closing this issue. Please feel free to re-open if that's not what you're looking for.
Comment From: fabriciorsf
I tried this solution, but I had some problems:

- Intermittently, at the end of the first epoch the `fit` method freezes before calculating the val metrics;
- When the previous problem does not occur, this error occurs:
```
...
Epoch 5/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 98ms/step - loss: 1.3509 - sumsqeuc_dist: 1.3499
Epoch 5: saving model to ./saved_models/my_model_epoch_005.keras
23/23 ━━━━━━━━━━━━━━━━━━━━ 5s 185ms/step - loss: 1.3383 - sumsqeuc_dist: 1.3369 - val_loss: 0.7643 - val_sumsqeuc_dist: 0.7621
Traceback (most recent call last):
...
  File "myscript.py", line 293, in train_autonem
    history = self.my_model.fit(self.train_inputs,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "....../python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "..../mypackage/utils/keras_extensions.py", line 268, in on_train_end
    conductor = self.callbacks[-1]
                ~~~~~~~~~~~~~~^^^^
TypeError: 'dict_values' object is not subscriptable
```
The `CompoundEarlyStopping` class is in the file `keras_extensions.py`:
```python
class CompoundEarlyStopping(keras.callbacks.Callback):

    def __init__(
        self, callbacks: Iterable[Callback],
        iterable_condition: Callable[[], bool] = all):
        super().__init__()
        self.callbacks = callbacks
        self.stopped_epoch = 0
        self.iterable_condition = iterable_condition

    def on_train_begin(self, logs: dict = None):
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: dict = None):
        if self.model.stop_training:
            conductor = self.callbacks[-1]  ## intermittent ERROR here
            conductor.stopped_epoch = self.stopped_epoch
            conductor.on_train_end(logs)

    def on_epoch_begin(self, epoch: int, logs: dict = None):
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)
        if self.iterable_condition([callback.stopped_epoch >= max(1, epoch)
                                    for callback in self.callbacks]):
            self.stopped_epoch = epoch
            self.model.stop_training = True
        else:
            self.model.stop_training = False

    def set_model(self, model: Model):
        super().set_model(model)
        for callback in self.callbacks:
            callback.set_model(model)
```
Note: there is an error with `Callable[Iterable, bool]` (the argument types must themselves be given as a list), so I replaced it with `Callable[[], bool]`.

I instantiate the callbacks like this:

```python
dict_early_stopping = {val_metric_name: EarlyStopping(monitor=val_metric_name, patience=5,
                                                      start_from_epoch=1, verbose=self.verbose)
                       for val_metric_name in val_metrics_name}

early_stopping = CompoundEarlyStopping(
    callbacks=dict_early_stopping.values(),
    iterable_condition=all)
```

As a baseline, `early_stopping = dict_early_stopping[val_metrics_name[0]]` works fine.
Comment From: rlcauvin
@fabriciorsf The original implementation of CompoundEarlyStopping.on_train_end
incorrectly assumed that self.callbacks
is subscriptable. In your code, dict_early_stopping.values()
does not produce a subscriptable result.
I have edited the CompoundEarlyStopping
code in my earlier comment to work with instances of Iterable
that are not subscriptable.
I changed
conductor = self.callbacks[-1]
to
conductor = next(reversed(self.callbacks))
Let us know if it works for you.
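The difference is easy to reproduce outside Keras with plain dict views (Python 3.8+, where dict views support `reversed()`):

```python
# dict_values is an Iterable but not a Sequence, so it has no __getitem__;
# reversed() does work on dict views, so next(reversed(...)) gets the last item.
callbacks = {"loss": "cb_loss", "auc": "cb_auc"}.values()

try:
    callbacks[-1]
except TypeError as err:
    print(err)  # 'dict_values' object is not subscriptable

print(next(reversed(callbacks)))  # cb_auc
```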
Comment From: fabriciorsf
I tested it, and the second problem doesn't occur, but the first problem remains:

```
Epoch 1/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 100ms/step - loss: 1.3480 - sumsqeuc_dist: 1.3470
```

And it stays frozen like that.
Comment From: fabriciorsf
I also tested without multithreading and everything worked fine, so I suspect there is some problem with `CompoundEarlyStopping` when training with multithreading.

For the record: without `CompoundEarlyStopping`, the training with multiprocessing works fine.
Comment From: rlcauvin
@fabriciorsf What do you mean by "multithreading" in this context?
Comment From: fabriciorsf
It means using a `PyDataset` instance to load data during `fit` with `use_multiprocessing=True` and `workers > 1`.

Again: without `CompoundEarlyStopping`, the training with multiprocessing works fine.
Ctrl+C while freezing:

```
Epoch 1/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 100ms/step - loss: 1.3480 - sumsqeuc_dist: 1.3470^CProcess Keras_worker_ForkPoolWorker-28:
Process Keras_worker_ForkPoolWorker-36:
Process Keras_worker_ForkPoolWorker-39:
Process Keras_worker_ForkPoolWorker-25:
Process Keras_worker_ForkPoolWorker-35:
Process Keras_worker_ForkPoolWorker-29:
Process Keras_worker_ForkPoolWorker-31:
Process Keras_worker_ForkPoolWorker-26:
Process Keras_worker_ForkPoolWorker-32:
Process Keras_worker_ForkPoolWorker-37:
Traceback (most recent call last):
  File ".........../myscript.py", line 853, in <module>
Process Keras_worker_ForkPoolWorker-23:
Process Keras_worker_ForkPoolWorker-27:
Process Keras_worker_ForkPoolWorker-30:
Process Keras_worker_ForkPoolWorker-40:
Process Keras_worker_ForkPoolWorker-34:
Process Keras_worker_ForkPoolWorker-41:
Process Keras_worker_ForkPoolWorker-33:
Process Keras_worker_ForkPoolWorker-44:
Process Keras_worker_ForkPoolWorker-24:
Traceback (most recent call last):
Traceback (most recent call last):
...
  File "....../python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "....../python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "....../python3.12/multiprocessing/pool.py", line 114, in worker
    task = get()
           ^^^^^
  File "....../python3.12/multiprocessing/queues.py", line 386, in get
    with self._rlock:
  File "....../python3.12/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
.....
```
Comment From: rlcauvin
@fabriciorsf Thanks for clarifying how you're doing multiprocessing. I guess the `CompoundEarlyStopping` class is not thread safe. Perhaps someone else can suggest how to make it thread safe, whether by using `threading.Lock()`, `threading.local()`, or some other means.
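A minimal sketch of the `threading.Lock()` idea, assuming the shared state (`stopped_epoch`, the stop flag) is what needs guarding; the class below is a toy stand-in to show the pattern, not the actual fix:

```python
import threading

class GuardedState:
    """Toy stand-in for a callback whose hooks mutate shared state:
    the lock serializes the read-modify-write so concurrent calls
    from worker threads cannot interleave."""

    def __init__(self):
        self._lock = threading.Lock()
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch):
        with self._lock:
            # critical section: update shared state atomically
            self.stopped_epoch = max(self.stopped_epoch, epoch)

state = GuardedState()
threads = [threading.Thread(target=state.on_epoch_end, args=(i,))
           for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(state.stopped_epoch)  # 7
```

Whether a lock alone resolves the freeze is unverified; the hang may instead come from how the forked workers interact with the compound callback's state.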
Comment From: fabriciorsf
> Hi @fabriciorsf
> As @rlcauvin mentioned, you can achieve this by creating a custom callback. Closing this issue. Please feel free to re-open if that's not what you're looking for.

Hi @SamanehSaadat, I don't have permissions to reopen this issue.
Comment From: SamanehSaadat
Have you tested switching to the JAX backend? (JAX is thread-safe but TensorFlow is not.)
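For reference, the backend is selected via the `KERAS_BACKEND` environment variable, which must be set before `keras` is first imported (a config fragment, assuming `jax` is installed):

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # must run before `import keras`

import keras  # Keras 3 now dispatches to the JAX backend
```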
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Comment From: github-actions[bot]
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.