Keras exposes an internal method on layers, `_register_call_context_args`, for registering custom call-context arguments.

The example in the documentation looks like this:

NOTE: the example in the current Keras documentation has a typo that should also be fixed. It calls `seq.register_call_context_args("foo_mode")`, which should be `seq._register_call_context_args("foo_mode")`. Without the leading underscore, the example does not run. The copy below includes that fix, plus the imports it needs.

```python
import numpy as np
from keras import layers, models, ops


class Inner(layers.Layer):

    def __init__(self):
        super().__init__()
        # Register `foo_mode` as a call-context arg
        self._register_call_context_args("foo_mode")

    def call(self, x, foo_mode=False):
        # If foo_mode=True add 1, otherwise add 0
        add_val = ops.where(foo_mode, 1.0, 0.0)
        return x + add_val


class Outer(layers.Layer):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def call(self, x):
        # We don't explicitly pass foo_mode here; the base Layer.__call__
        # should inject it into `self.inner`
        return self.inner(x)


sample_input = np.array([[1.0], [2.0]])

# Sequential model
seq = models.Sequential([Outer()])

# Tell the Sequential model to propagate foo_mode down
# the call stack (note the leading underscore)
seq._register_call_context_args("foo_mode")

# foo_mode=True -> input + 1
out_true = seq(sample_input, foo_mode=True)
```

However, it seems that `foo_mode` can only be specified when calling the model directly via `__call__`. Passing it to any of the commonly used model APIs, such as `predict`, raises an error:

```
seq.predict(sample_input, foo_mode=True)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 1
----> 1 seq.predict(sample_input, foo_mode=True)

File ~/.../.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/.../.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:117, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    115 filtered_tb = None
    116 try:
--> 117     return fn(*args, **kwargs)
    118 except Exception as e:
    119     filtered_tb = _process_traceback_frames(e.__traceback__)

TypeError: TensorFlowTrainer.predict() got an unexpected keyword argument 'foo_mode'
```

The docstring of `_register_call_context_args` states: "This is useful for propagating custom arguments from top-level layers/models to sublayers." However, the main entry points of a top-level model do not accept these arguments. As a result, there is no way to forward a custom argument to sublayers when using the recommended training and inference methods, `fit` and `predict`.
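The gap can be reduced to a few lines of plain Python. In the hypothetical `TinyModel` below, `predict` has a fixed keyword signature, as `TensorFlowTrainer.predict()` does in the traceback, so `foo_mode` is rejected before the model is ever called. `predict_with_context` is a hypothetical illustration (not an existing Keras API) of a batch loop that accepts only registered call-context args and forwards them to every batch's `__call__`.

```python
class TinyModel:
    """Hypothetical stand-in for a model whose __call__ accepts foo_mode."""

    def __init__(self):
        # Names that are allowed to travel through predict (assumption).
        self._call_context_args = {"foo_mode"}

    def __call__(self, x, foo_mode=False):
        return [v + (1.0 if foo_mode else 0.0) for v in x]

    def predict(self, x, batch_size=2):
        # Fixed signature: there is no way to pass foo_mode through here.
        out = []
        for i in range(0, len(x), batch_size):
            out.extend(self(x[i:i + batch_size]))
        return out

    def predict_with_context(self, x, batch_size=2, **context):
        # Hypothetical: accept only registered call-context args and
        # forward them unchanged to each batch's __call__.
        unknown = set(context) - self._call_context_args
        if unknown:
            raise TypeError(f"unexpected keyword arguments: {sorted(unknown)}")
        out = []
        for i in range(0, len(x), batch_size):
            out.extend(self(x[i:i + batch_size], **context))
        return out


model = TinyModel()
print(model.predict([1.0, 2.0, 3.0]))                              # [1.0, 2.0, 3.0]
print(model.predict_with_context([1.0, 2.0, 3.0], foo_mode=True))  # [2.0, 3.0, 4.0]
```

Calling `model.predict([1.0], foo_mode=True)` raises the same kind of `TypeError` as the traceback, while the forwarding variant reaches `__call__` with the argument intact.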