Keras internals expose a method on layers to register custom arguments: _register_call_context_args

The example in the documentation looks like this:

NOTE: the code shared below, from the current version of Keras, has a typo that you should also consider fixing: seq.register_call_context_args("foo_mode") should be seq._register_call_context_args("foo_mode"). The missing underscore means the code example does not run.

class Inner(layers.Layer):

    def __init__(self):
        super().__init__()
        # Register `foo_mode` as a call-context arg
        self._register_call_context_args("foo_mode")

    def call(self, x, foo_mode=False):
        # If foo_mode=True add 1, otherwise add 0
        add_val = ops.where(foo_mode, 1.0, 0.0)
        return x + add_val

class Outer(layers.Layer):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def call(self, x):
        # We don't explicitly pass foo_mode here—Base Layer.__call__
        # should inject it into `self.inner`
        return self.inner(x)

sample_input = np.array([[1.0], [2.0]])

# Sequential model
seq = models.Sequential([Outer()])

# Tell the Sequential model to propagate foo_mode down
# the call-stack
seq.register_call_context_args("foo_mode")

# foo_mode=True -> input + 1
out_true = seq(sample_input, foo_mode=True)

However, it seems that foo_mode can only be specified when calling the model directly via __call__. If I try any of the commonly exposed APIs instead, I get an error:

seq.predict(sample_input, foo_mode=True)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[10], line 1
----> 1 seq.predict(sample_input, foo_mode=True)

File ~/.../.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ~/.../.venv/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py:117, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    115 filtered_tb = None
    116 try:
--> 117     return fn(*args, **kwargs)
    118 except Exception as e:
    119     filtered_tb = _process_traceback_frames(e.__traceback__)

TypeError: TensorFlowTrainer.predict() got an unexpected keyword argument 'foo_mode'

The documentation of the _register_call_context_args method states "This is useful for propagating custom arguments from top-level layers/models to sublayers." However the main useful pathways for doing so at the top-level model do not work. So I can't actually forward custom arguments downstream to sublayers during the recommended methods of fit or predict.

Comment From: hertschuh

NOTE: the code shared below, from the current version of Keras, has a typo that you should also consider fixing: seq.register_call_context_args("foo_mode") should be seq._register_call_context_args("foo_mode"). The missing underscore means the code example does not run.

Oh, true, this can easily be fixed.

The documentation of the _register_call_context_args method states "This is useful for propagating custom arguments from top-level layers/models to sublayers." However the main useful pathways for doing so at the top-level model do not work. So I can't actually forward custom arguments downstream to sublayers during the recommended methods of fit or predict.

The goal of this feature was not to propagate from fit to call. The training argument for instance is implied by fit.

To be honest, this feature is very niche and was not intended to be used publicly.

What is your use case? There might be a way to make it work, with or without _register_call_context_args.

Comment From: RyanSaxe

The goal of this feature was not to propagate from fit to call. The training argument for instance is implied by fit.

To be honest, this feature is very niche and was not intended to be used publicly.

What is your use case? There might be a way to make it work, with or without _register_call_context_args.

In the Subclassing API, unlike the functional API, debugging and inspecting intermediary objects can be quite difficult, because you don't have access to them. And returning giant structures from each layer can introduce inefficiencies in compilation and computation.

I am trying to solve this with a wrapper on top of the keras.layers.Layer abstraction that effectively registers a debugging flag such that call functions end up looking like below (NOTE: this is an oversimplification as I cannot share the actual specifics beyond this, but hopefully it illustrates how I ended up looking into _register_call_context_args)

def call(self, inputs, training=None, debugging=False):
    # do whatever the layer does
    return required_output_dict if not debugging else required_output_dict | intermediary_dict

This lets me optionally return the giant structures when I am debugging something about my system, and even lets me add model-level hooks that calculate gradients between intermediaries and so on. This works totally fine with _register_call_context_args as long as I only debug through the __call__ route. While that's okay, it could miss issues caused by the Keras internals of fit, because I can't inspect intermediaries under that condition.

Let's look at the training step for TensorFlow to explain what feels reasonable to me:

    def train_step(self, data):
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

        # Forward pass
        with tf.GradientTape() as tape:
            if self._call_has_training_arg:
                y_pred = self(x, training=True)
            else:
                y_pred = self(x)
        # the rest below is irrelevant

Why can users not have more flexible call signatures, where the internals look more like this?

    def train_step(self, data, **kwargs):
        # you could also inject training into kwargs here, and use warnings.warn if the value was changed or something
        if not kwargs.get("training"):
            raise ValueError(f"train_step was called with training={kwargs.get('training')}, which is not truthy")
        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)

        # Forward pass
        with tf.GradientTape() as tape:
            # hypothetical helper: keep only the kwargs that `call` accepts
            kwargs = self._filter_kwargs_if_call_has_them(kwargs)
            y_pred = self(x, **kwargs)
        # the rest below is irrelevant

Basically, given the documentation of _register_call_context_args, I thought the only reason for this to exist as a generalization (since if it's just for the training flag, why even have this be some abstraction at all) is because something like the above would be supported. If it is not intended to be supported, you should update the docs accordingly. If it is not meant to be used, state that instead of providing an example of how to use it to create some arbitrary foo_mode argument.
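For illustration only, the hypothetical `_filter_kwargs_if_call_has_them` helper from the proposed train_step could be sketched with the standard library's `inspect.signature`. Nothing here is a Keras API: `filter_call_kwargs` and `ToyLayer` are made-up names, and the toy class only stands in for a layer's `call` method.

```python
import inspect


def filter_call_kwargs(call_fn, kwargs):
    """Keep only the kwargs that `call_fn` accepts (hypothetical helper)."""
    params = inspect.signature(call_fn).parameters
    # If `call_fn` itself takes **kwargs, everything can be forwarded as-is.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


class ToyLayer:
    # Stand-in for a layer; this is not a real Keras class.
    def call(self, inputs, training=None, debugging=False):
        return {"out": inputs, "training": training, "debugging": debugging}


layer = ToyLayer()
extra = {"training": True, "debugging": True, "unknown_flag": 1}

# `unknown_flag` is dropped because `call` does not declare it.
accepted = filter_call_kwargs(layer.call, extra)
result = layer.call([1.0], **accepted)
```

Filtering by signature means a step function could accept arbitrary keyword arguments and forward only those the model's `call` declares, which is roughly what the existing `_call_has_training_arg` check does for `training` alone.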

Comment From: hertschuh

In the Subclassing API, unlike the functional API, debugging and inspecting intermediary objects can be quite difficult, because you don't have access to them. And returning giant structures from each layer can introduce inefficiencies in compilation and computation.

Agreed. We may be working on a different solution for this issue. @abheesht17

I am trying to solve this with a wrapper on top of the keras.layers.Layer abstraction that effectively registers a debugging flag such that call functions end up looking like below

Why not use a global flag or a scope (ContextManager)? There is no need to actually pass the flag from layer to layer, is there?
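A minimal sketch of what such a scope could look like, using only the standard library. The names `debugging_scope`, `is_debugging`, and `ToyLayer` are assumptions made for illustration and are not part of Keras; the toy class stands in for a real layer.

```python
import contextlib
import contextvars

# Hypothetical module-level flag; not part of Keras.
_DEBUGGING = contextvars.ContextVar("debugging", default=False)


def is_debugging():
    return _DEBUGGING.get()


@contextlib.contextmanager
def debugging_scope(enabled=True):
    """Turn the flag on (or off) for the duration of a `with` block."""
    token = _DEBUGGING.set(enabled)
    try:
        yield
    finally:
        # Restore whatever value was active before entering the scope.
        _DEBUGGING.reset(token)


class ToyLayer:
    # Stand-in for a layer; `call` reads the flag instead of taking an argument.
    def call(self, inputs):
        hidden = [2.0 * v for v in inputs]
        outputs = {"output": sum(hidden)}
        if is_debugging():
            outputs = {**outputs, "hidden": hidden}
        return outputs


layer = ToyLayer()
plain = layer.call([1.0, 2.0])
with debugging_scope():
    detailed = layer.call([1.0, 2.0])
after = layer.call([1.0, 2.0])
```

With this approach no layer needs to forward the flag to its sublayers, but the flag is invisible to anything that caches behavior based on the call signature.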

If it is not meant to be used, state that instead of providing an example of how to use it to create some arbitrary foo_mode argument.

Some of the private internals of Keras are documented not because they're meant for public use, but for the maintainers of Keras. We strongly discourage using private APIs because the behavior and APIs can change.

Comment From: RyanSaxe

In the Subclassing API, unlike the functional API, debugging and inspecting intermediary objects can be quite difficult, because you don't have access to them. And returning giant structures from each layer can introduce inefficiencies in compilation and computation.

Agreed. We may be working on a different solution for this issue. @abheesht17

Glad to hear this. It has been my biggest pain point. I'm also happy to discuss/help/contribute on this side if I can.

I am trying to solve this with a wrapper on top of the keras.layers.Layer abstraction that effectively registers a debugging flag such that call functions end up looking like below

Why not use a global flag or a scope (ContextManager)? There is no need to actually pass the flag from layer to layer, is there?

Won't this mess with function compilation? The return of call is different according to the state of that global flag (or scope). So won't keras internals not know that the return structure has changed unless it is an aspect of the signature? At least that's how tf.function, torch.compile, and jax.jit work if I remember correctly.

If it is not meant to be used, state that instead of providing an example of how to use it to create some arbitrary foo_mode argument.

Some of the private internals of Keras are documented not because they're meant for public use, but for the maintainers of Keras. We strongly discourage using private APIs because the behavior and APIs can change.

I understand this. I almost always avoid using internals where possible for this reason. I originally was doing this without calling the private function (since I don't need to have things skip layers via call context). I only came across this because I didn't want to have to overwrite all the normal paths (e.g. fit, predict, evaluate) just for the purpose of forwarding a boolean flag. And seeing that the docs were extensive, with code examples that called a non-private version of the function (even though it doesn't exist, which I point out in this issue), led me to believe this was something worth looking into.

Comment From: hertschuh

Glad to hear this. It has been my biggest pain point. I'm also happy to discuss/help/contribute on this side if I can.

What backend do you use?

Won't this mess with function compilation? The return of call is different according to the state of that global flag (or scope). So won't keras internals not know that the return structure has changed unless it is an aspect of the signature? At least that's how tf.function, torch.compile, and jax.jit work if I remember correctly.

Correct, this will only work if you call fit, evaluate and predict at most once each.

I was assuming that calling fit(debugging=True) and then right after fit(debugging=False) was not a use-case (either you're debugging or you're not).

Comment From: RyanSaxe

What backend do you use?

Mostly TensorFlow, but I have used all three. Hoping to be able to do a bit more with backend agnostic code here though.

Won't this mess with function compilation? The return of call is different according to the state of that global flag (or scope). So won't keras internals not know that the return structure has changed unless it is an aspect of the signature? At least that's how tf.function, torch.compile, and jax.jit work if I remember correctly.

Correct, this will only work if you call fit, evaluate and predict at most once each.

I was assuming that calling fit(debugging=True) and then right after fit(debugging=False) was not a use-case (either you're debugging or you're not).

You're right, calling the same one twice is not a use case (unless I want to ensure that, say, evaluate(debugging=True) == evaluate(debugging=False)).

But evaluate(debugging=False) followed by predict(debugging=True) could come up if I want to inspect model internals to understand something related to the model performance.

My understanding is that, under the hood, these both go through routes that lead to compilation of __call__. Hence, if I use a global scope, the model call gets compiled with the equivalent of debugging=False, and then I cannot call predict(debugging=True) or even model(debugging=True), because it was compiled with the wrong global scope and there is no signature change (since we can't pass debugging=True if we use a global scope).

Sorry if this is needlessly confusing. My use case mostly works by using this private function and calling the model directly, but it just misses some particular edge cases is all.

Comment From: hertschuh

Mostly TensorFlow, but I have used all three. Hoping to be able to do a bit more with backend agnostic code here though.

If you're open to doing the debugging on JAX only, you can enable NNX and then use Module.sow to transparently pass intermediate values.

But evaluate(debugging=False) followed by predict(debugging=True) could come up if I want to inspect model internals to understand something related to the model performance.

That would be fine.

these both go through routes that lead to compilation of __call__

Yes, but what's compiled is a wrapper of __call__ that is different for fit, predict and evaluate, so they each get compiled separately.

model(debugging=True)

This is run eagerly, not compiled, so turning the global debugging flag on or off and then calling model(x) will work.
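The behavior described in this exchange can be illustrated with a toy model of tracing. This is not how Keras or any backend actually compiles functions: `Tracer`, `toy_compile`, and `FLAGS` are made-up names, and the sketch only assumes that a compiled wrapper runs the Python body once and then replays the recorded operations.

```python
FLAGS = {"debugging": False}  # hypothetical global flag read inside `call`


def call(x):
    # The returned structure depends on the global flag, not on the arguments.
    out = {"output": x + 1}
    if FLAGS["debugging"]:
        out["hidden"] = x * 2
    return out


class Tracer:
    """Records arithmetic instead of performing it."""

    def __init__(self, ops=()):
        self.ops = tuple(ops)

    def __add__(self, c):
        return Tracer(self.ops + (("add", c),))

    def __mul__(self, c):
        return Tracer(self.ops + (("mul", c),))


def replay(ops, x):
    for name, c in ops:
        x = x + c if name == "add" else x * c
    return x


def toy_compile(fn):
    """Toy tracing compiler: runs `fn` once, then replays the recording."""
    cache = {}

    def compiled(x):
        if "graph" not in cache:
            # Python control flow, including the flag check, only runs here.
            cache["graph"] = {k: v.ops for k, v in fn(Tracer()).items()}
        return {k: replay(ops, x) for k, ops in cache["graph"].items()}

    return compiled


evaluate_fn = toy_compile(call)  # stand-in for the evaluate wrapper
predict_fn = toy_compile(call)   # stand-in for the predict wrapper

eval_out = evaluate_fn(3)        # traced with the flag off
FLAGS["debugging"] = True
predict_out = predict_fn(3)      # separate wrapper, traced with the flag on
eval_again = evaluate_fn(3)      # reuses the old trace, so still no "hidden"
eager_out = call(3)              # eager call reads the flag live
```

In this toy, each wrapper freezes the flag at its first call, which matches the "at most once each" caveat, while separate wrappers and the eager call each pick up the current value of the flag.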