I am trying to write backend-agnostic code, and I have layers (like a VQ layer) that compute losses in their forward pass. This is usually implemented with a `self.add_loss` call in the `call` function. However, the current handling of this extra loss differs between the JAX backend and the others. After reading the relevant code, here are my observations:

- In TF/torch, the handling of `self.losses` happens in the generic `compute_loss` function, and the returned loss is the final loss value to use.
- In JAX, `self.losses` is always an empty list in the generic `compute_loss` and `compute_metrics` functions. Instead, the handling happens in the JAX-exclusive `compute_loss_and_updates`, which adds the extracted losses to the partial loss returned from `compute_loss`.
- The direct implication is that the backend-agnostic function `compute_loss` will return different values on different backends when we use `add_loss`, and we cannot access or use the layer's losses in `compute_loss` and `compute_metrics` because of scoping.

Given the situation, is it still possible to write backend-agnostic code when `add_loss` is used? Or is there a newer design pattern that can replace `add_loss`?
Below is a simplified example demonstrating the difference. I will post the code and the relevant outputs (since TF and torch have similar behavior, I am only pasting the PyTorch output).
```python
import os

os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
import keras
from keras import models, layers, ops

# helper function to print for debugging
match keras.backend.backend():
    case "jax":
        import jax

        print_fn = jax.debug.print
    case "tensorflow":
        import tensorflow as tf

        tf.config.run_functions_eagerly(True)

        def print_fn(template, **kwargs):
            print(template.format(**kwargs))
    case _:
        def print_fn(template, **kwargs):
            print(template.format(**kwargs))


class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(ops.mean(inputs) * 100)
        print_fn("(1) layer.loss after layer call {x}", x=self.losses)
        return inputs


class TestModel(keras.Model):
    def __init__(self):
        super().__init__()
        # self.layer1 = layers.Dense(3, activation="relu", name="dense_1")
        self.activity = ActivityRegularizationLayer()
        self.big_loss_tracker = keras.metrics.Mean(name="big_loss")
        self.mse_loss_tracker = keras.metrics.Mean(name="mse_loss")

    @property
    def metrics(self):
        return [
            self._loss_tracker,
            self.big_loss_tracker,
            self.mse_loss_tracker,
        ]

    def reset_metrics(self):
        self.big_loss_tracker.reset_state()
        self.mse_loss_tracker.reset_state()

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        print_fn("(3.1) self.losses in compute_loss {x}", x=self.losses)
        res = super().compute_loss(x, y, y_pred, sample_weight, allow_empty)
        print_fn("(3.2) compute_loss returns {x}", x=res)
        return res

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        print_fn("(4) self.losses in compute_metrics {x}", x=self.losses)
        total_loss = ops.sum(self.activity.losses)
        self.big_loss_tracker.update_state(total_loss)
        self.mse_loss_tracker.update_state(keras.losses.MeanSquaredError()(y, y_pred))
        return self.get_metrics_result()

    def call(self, inputs):
        x = inputs
        # x = self.layer1(inputs)
        x = self.activity(x)
        print_fn("(2) self.losses after model call {x}", x=self.losses)
        return x


model = TestModel()
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.MeanSquaredError(),
)
print(f"the current backend is {keras.backend.backend()}")
x = np.array([[1., 2, 3]])
y = np.array([[4., 5, 6]])
model.fit(x, y, batch_size=10, epochs=1)
```
PyTorch outputs:

```
the current backend is torch
...
(1) layer.loss after layer call [tensor(200., device='cuda:0')]
(2) self.losses after model call [tensor(200., device='cuda:0')]
(3.1) self.losses in compute_loss [tensor(200., device='cuda:0')]
(3.2) compute_loss returns 209.0
(4) self.losses in compute_metrics [tensor(200., device='cuda:0')]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - big_loss: 200.0000 - loss: 209.0000 - mse_loss: 9.0000
```
JAX outputs:

```
the current backend is jax
(4) self.losses in compute_metrics []
(3.1) self.losses in compute_loss []
(3.2) compute_loss returns 9.0
(2) self.losses after model call [array(200.00002, dtype=float32)]
(1) layer.loss after layer call [array(200.00002, dtype=float32)]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 84ms/step - big_loss: 0.0000e+00 - loss: 209.0000 - mse_loss: 9.0000
```
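To reproduce both outputs in one go, one option (a sketch; `repro.py` is a hypothetical copy of the example above with the hard-coded `os.environ` line removed) is to launch a fresh subprocess per backend, since `KERAS_BACKEND` must be set before `keras` is imported:

```python
# Sketch of a cross-backend runner. "repro.py" is a hypothetical copy of
# the example above with the hard-coded os.environ line removed, so the
# backend is taken from the environment instead.
import os
import subprocess
import sys

for backend in ["tensorflow", "torch", "jax"]:
    print(f"=== {backend} ===")
    subprocess.run(
        [sys.executable, "repro.py"],
        env={**os.environ, "KERAS_BACKEND": backend},
        check=True,
    )
```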
Comment From: SuryanarayanaY
Hi @water-vapor,
I have replicated the reported behaviour and attached a gist here. The JAX behaviour is indeed different from the torch and TensorFlow backends.
Comment From: fchollet
This is a significant discrepancy indeed, thanks for pointing it out. We're making a few changes now that will allow you to write your model like this, which is entirely backend-agnostic:
```python
class TestModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.activity = ActivityRegularizationLayer()
        self.big_loss_tracker = keras.metrics.Mean(name="big_loss")
        self.mse_loss_tracker = keras.metrics.Mean(name="mse_loss")

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # Note: use `sum` and not `ops.sum`, because `self.losses` is a list, not a tensor.
        # Also note: you have access to `self.losses`, which behaves the same in all backends.
        # However sublayers (like self.activity) will have no losses in JAX.
        total_loss = sum(self.losses)
        self.big_loss_tracker.update_state(total_loss)
        mse_loss = keras.losses.MeanSquaredError()(y, y_pred)
        self.mse_loss_tracker.update_state(mse_loss)
        return mse_loss + total_loss

    def call(self, inputs):
        x = inputs
        x = self.activity(x)
        return x
```
Note that you don't need to explicitly implement the `metrics` property and `reset_metrics`, because that's handled automatically. You can still do it if you want granular control, though.
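For completeness, here is a minimal usage sketch of this pattern (reusing the `ActivityRegularizationLayer` and toy data from the example above). Since this `compute_loss` already returns `mse_loss + total_loss`, no `loss` argument is passed to `compile`:

```python
# Minimal usage sketch, reusing ActivityRegularizationLayer and the toy
# data from the example above.
import numpy as np

model = TestModel()
# compute_loss returns the full mse_loss + total_loss, so no `loss`
# argument is given to compile().
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))

x = np.array([[1.0, 2.0, 3.0]])
y = np.array([[4.0, 5.0, 6.0]])

# big_loss and mse_loss are picked up as metrics automatically; no
# explicit `metrics` property or `reset_metrics` override is needed.
model.fit(x, y, batch_size=10, epochs=1)
```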
Comment From: fchollet
This is now available at HEAD.
Comment From: RyanSaxe
While I know this is closed, it's a bit of a "gotcha" that maybe should be documented.
I had the same issue coming up, which was being caused by doing a calculation inside compute_metrics instead of compute_losses. This worked fine in every other backend. But I wrote a matrix that runs my tests across all backends and caught this specifically for Jax.
This is because I was calculating something that was a function of my losses, but it wasn't getting added to my loss, and rather meant to be displayed as a metric.
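One backend-agnostic way to handle such a loss-derived metric, following the pattern from the accepted answer above, is to compute it inside `compute_loss` (where `self.losses` behaves the same on every backend) and let `compute_metrics` only report tracker state. A minimal sketch with hypothetical names (`MetricOfLossModel`, `loss_ratio_tracker`), reusing the `ActivityRegularizationLayer` from earlier:

```python
class MetricOfLossModel(keras.Model):
    # Hypothetical example: `loss_ratio` is a quantity derived from the
    # layer losses that should be reported as a metric, not added to the
    # training loss a second time.
    def __init__(self):
        super().__init__()
        self.activity = ActivityRegularizationLayer()
        self.loss_ratio_tracker = keras.metrics.Mean(name="loss_ratio")

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # `self.losses` is populated here on every backend, so
        # loss-derived metrics must be computed in this method ...
        extra_loss = sum(self.losses)
        mse_loss = keras.losses.MeanSquaredError()(y, y_pred)
        self.loss_ratio_tracker.update_state(
            extra_loss / (mse_loss + extra_loss)
        )
        return mse_loss + extra_loss

    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        # ... not here: on JAX, `self.losses` is empty in compute_metrics.
        return self.get_metrics_result()

    def call(self, inputs):
        return self.activity(inputs)
```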