Keras internals expose a method on layers for registering custom call arguments: `_register_call_context_args`.
The example in the documentation looks like this:
NOTE: the documentation example below, taken from the current version of Keras, has a typo that should also be fixed: `seq.register_call_context_args("foo_mode")` should be `seq._register_call_context_args("foo_mode")`. Without the leading underscore, the example does not run.
```python
import numpy as np
from keras import layers, models, ops


class Inner(layers.Layer):
    def __init__(self):
        super().__init__()
        # Register `foo_mode` as a call-context arg
        self._register_call_context_args("foo_mode")

    def call(self, x, foo_mode=False):
        # If foo_mode=True add 1, otherwise add 0
        add_val = ops.where(foo_mode, 1.0, 0.0)
        return x + add_val


class Outer(layers.Layer):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def call(self, x):
        # We don't explicitly pass foo_mode here; the base Layer.__call__
        # should inject it into `self.inner`
        return self.inner(x)


sample_input = np.array([[1.0], [2.0]])

# Sequential model
seq = models.Sequential([Outer()])

# Tell the Sequential model to propagate foo_mode down the call stack.
# (The documentation writes `seq.register_call_context_args`, without
# the leading underscore, which does not run.)
seq._register_call_context_args("foo_mode")

# foo_mode=True -> input + 1
out_true = seq(sample_input, foo_mode=True)
```
However, it seems that `foo_mode` can only be specified when calling the model directly via `__call__`. If I try any of the commonly exposed APIs, such as `predict`, I get an error:
```python
seq.predict(sample_input, foo_mode=True)
```

```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[10], line 1
----> 1 seq.predict(sample_input, foo_mode=True)
File ~/.../.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.__traceback__)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
File ~/.../.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:117, in filter_traceback.<locals>.error_handler(*args, **kwargs)
115 filtered_tb = None
116 try:
--> 117 return fn(*args, **kwargs)
118 except Exception as e:
119 filtered_tb = _process_traceback_frames(e.__traceback__)
TypeError: TensorFlowTrainer.predict() got an unexpected keyword argument 'foo_mode'
```
The documentation of the `_register_call_context_args` method states: "This is useful for propagating custom arguments from top-level layers/models to sublayers." However, the most useful entry points on a top-level model do not accept these arguments. As a result, I can't actually forward custom arguments down to sublayers when using the recommended methods, `fit` or `predict`.