I am trying to write backend-agnostic code, and I have layers (like a VQ layer) that compute losses in their forward pass. This is usually implemented with a self.add_loss call in the layer's call method. However, the current handling of this extra loss differs between the JAX backend and the others. After reading the relevant code, here are my observations:

  • In TF/torch, self.losses is handled inside the generic compute_loss function, so the value it returns is the final loss to use.
  • In JAX, self.losses is always an empty list inside the generic compute_loss and compute_metrics functions. Instead, the losses are handled in the JAX-only compute_loss_and_updates method, which adds the extracted losses to the partial loss returned by compute_loss.
  • As a direct consequence, the backend-agnostic compute_loss returns different values on different backends when add_loss is used, and under JAX we cannot access or use the layer's losses in compute_loss and compute_metrics because of scoping.
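To make the discrepancy concrete, here is a framework-free toy model of the two code paths described in the bullets above. It is a sketch, not Keras code: the function names (compute_loss_tf_torch, compute_loss_jax, jax_compute_loss_and_updates) are hypothetical stand-ins for the described behaviour, and the numbers mirror the example below (an activity loss of mean([1, 2, 3]) * 100 plus the MSE against [4, 5, 6]).

```python
# Toy model of the loss accounting described above.
# All function names are hypothetical stand-ins, not Keras internals.

def mean(values):
    return sum(values) / len(values)

def mse(y_true, y_pred):
    return mean([(t - p) ** 2 for t, p in zip(y_true, y_pred)])

x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]
y_pred = x                      # the model is an identity map...
layer_losses = [mean(x) * 100]  # ...plus what add_loss records: 200.0

def compute_loss_tf_torch(y, y_pred, losses):
    # TF/torch: self.losses is visible inside compute_loss and summed in.
    return mse(y, y_pred) + sum(losses)

def compute_loss_jax(y, y_pred):
    # JAX: self.losses is empty here, so only the compiled loss remains.
    return mse(y, y_pred)

def jax_compute_loss_and_updates(y, y_pred, losses):
    # JAX-only wrapper: adds the collected layer losses afterwards.
    return compute_loss_jax(y, y_pred) + sum(losses)

print(compute_loss_tf_torch(y, y_pred, layer_losses))         # 209.0
print(compute_loss_jax(y, y_pred))                            # 9.0
print(jax_compute_loss_and_updates(y, y_pred, layer_losses))  # 209.0
```

Both paths end up optimizing 209.0; what differs is the value compute_loss itself returns (209.0 vs 9.0), which is exactly what an overridden compute_loss or compute_metrics observes.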

Given the situation, is it still possible to write backend-agnostic code when add_loss is used? Or is there a newer design pattern that can replace add_loss?

Below is a simplified example demonstrating the difference, followed by the relevant outputs. Since TF and torch behave the same way, I am only pasting the PyTorch output.

```python
import os
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import keras
from keras import models, layers, ops


# helper function to print for debugging
match keras.backend.backend():
    case "jax":
        import jax
        print_fn = jax.debug.print
    case "tensorflow":
        import tensorflow as tf
        tf.config.run_functions_eagerly(True)
        def print_fn(template, **kwargs):
            print(template.format(**kwargs))
    case _:
        def print_fn(template, **kwargs):
            print(template.format(**kwargs))


class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(ops.mean(inputs) * 100)
        print_fn("(1) layer.loss after layer call {x}", x=self.losses)
        return inputs


class TestModel(keras.Model):
    def __init__(self):
        super().__init__()
        # self.layer1 = layers.Dense(3, activation="relu", name="dense_1")
        self.activity = ActivityRegularizationLayer()
        self.big_loss_tracker = keras.metrics.Mean(name="big_loss")
        self.mse_loss_tracker = keras.metrics.Mean(name="mse_loss")

    @property
    def metrics(self):
        return [
            self._loss_tracker,
            self.big_loss_tracker,
            self.mse_loss_tracker,
        ]

    def reset_metrics(self):
        self.big_loss_tracker.reset_state()
        self.mse_loss_tracker.reset_state()

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False):
        print_fn("(3.1) self.losses in compute_loss {x}", x=self.losses)
        res = super().compute_loss(x, y, y_pred, sample_weight, allow_empty)
        print_fn("(3.2) compute_loss returns {x}", x=res)
        return res

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        print_fn("(4) self.losses in compute_metrics {x}", x=self.losses)
        total_loss = ops.sum(self.activity.losses)
        self.big_loss_tracker.update_state(total_loss)
        self.mse_loss_tracker.update_state(keras.losses.MeanSquaredError()(y, y_pred))
        return self.get_metrics_result()

    def call(self, inputs):
        x = inputs
        # x = self.layer1(inputs)
        x = self.activity(x)
        print_fn("(2) self.losses after model call {x}", x=self.losses)
        return x

model = TestModel()
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.MeanSquaredError(),
)

print(f"the current backend is {keras.backend.backend()}")
x = np.array([[1., 2, 3]])
y = np.array([[4., 5, 6]])
model.fit(x, y, batch_size=10, epochs=1)
```

PyTorch outputs:

```
the current backend is torch
...
(1) layer.loss after layer call [tensor(200., device='cuda:0')]
(2) self.losses after model call [tensor(200., device='cuda:0')]
(3.1) self.losses in compute_loss [tensor(200., device='cuda:0')]
(3.2) compute_loss returns 209.0
(4) self.losses in compute_metrics [tensor(200., device='cuda:0')]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - big_loss: 200.0000 - loss: 209.0000 - mse_loss: 9.0000
```

jax outputs:

```
the current backend is jax
(4) self.losses in compute_metrics []
(3.1) self.losses in compute_loss []
(3.2) compute_loss returns 9.0
(2) self.losses after model call [array(200.00002, dtype=float32)]
(1) layer.loss after layer call [array(200.00002, dtype=float32)]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 84ms/step - big_loss: 0.0000e+00 - loss: 209.0000 - mse_loss: 9.0000
```

Comment From: SuryanarayanaY

Hi @water-vapor,

I have replicated the reported behaviour and attached a gist here. The 'jax' behaviour differs from that of torch and tensorflow.

Comment From: fchollet

This is a significant discrepancy indeed, thanks for pointing it out. We're making a few changes now that will allow you to write your model like this, which is entirely backend-agnostic:

```python
class TestModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.activity = ActivityRegularizationLayer()
        self.big_loss_tracker = keras.metrics.Mean(name="big_loss")
        self.mse_loss_tracker = keras.metrics.Mean(name="mse_loss")

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # Note: use `sum` and not `ops.sum`, because `self.losses` is a list, not a tensor.
        # Also note: you have access to `self.losses`, which behaves the same in all backends.
        # However sublayers (like self.activity) will have no losses in JAX.
        total_loss = sum(self.losses)
        self.big_loss_tracker.update_state(total_loss)
        mse_loss = keras.losses.MeanSquaredError()(y, y_pred)
        self.mse_loss_tracker.update_state(mse_loss)
        return mse_loss + total_loss

    def call(self, inputs):
        x = inputs
        x = self.activity(x)
        return x
```

Note that you don't need to list the metrics explicitly or implement reset_metrics, because that's handled automatically. You can still do it if you want granular control, though.

Comment From: fchollet

This is now available at HEAD.

Comment From: RyanSaxe

I know this is closed, but it's a bit of a "gotcha" that should probably be documented.

I ran into the same issue, which was caused by doing a calculation inside compute_metrics instead of compute_loss. This worked fine on every other backend, but I wrote a matrix that runs my tests across all backends, and it caught this specifically for JAX.

The calculation was a function of my losses, but it wasn't meant to be added to the loss; it was only meant to be displayed as a metric.
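A cross-backend test matrix like the one described can be approximated with a small runner. This is a minimal sketch under a few assumptions: run_under_backend and run_matrix are hypothetical helper names, the backend list assumes all three backends are installed, and the target script must not hard-code KERAS_BACKEND (the repro in the original post sets it to "jax" via os.environ, which would override the variable passed here). Each backend gets its own interpreter because Keras selects its backend when it is first imported.

```python
import os
import subprocess
import sys
import tempfile

# Assumed matrix: adjust to the backends actually installed.
BACKENDS = ["jax", "tensorflow", "torch"]


def run_under_backend(backend, script_path):
    """Hypothetical helper: run script_path in a fresh interpreter with KERAS_BACKEND set.

    A new process is used because Keras picks its backend when it is first imported.
    """
    env = dict(os.environ, KERAS_BACKEND=backend)
    return subprocess.run(
        [sys.executable, script_path],
        env=env,
        capture_output=True,
        text=True,
    )


def run_matrix(script_path, backends=BACKENDS):
    """Hypothetical helper: map each backend name to its CompletedProcess."""
    return {backend: run_under_backend(backend, script_path) for backend in backends}


if __name__ == "__main__":
    # Placeholder target that only echoes the backend it received;
    # in practice, point run_matrix at a real test script.
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write("import os\nprint(os.environ['KERAS_BACKEND'])\n")
        script = f.name
    try:
        for backend, result in run_matrix(script).items():
            status = "ok" if result.returncode == 0 else "FAILED"
            print(f"{backend}: {status} ({result.stdout.strip()})")
    finally:
        os.remove(script)
```

Running each backend in a separate process also keeps failures isolated, so a JAX-only discrepancy like the one in this issue shows up as a single failing entry rather than aborting the whole run.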