I created this layer to monitor what's happening:
```python
from keras.layers import Conv1D

class ReportingConv1D(Conv1D):
    def call(self, inputs):
        # Log input/kernel dtypes and shapes on every call.
        print(f"{inputs.dtype} {inputs.shape} "
              f"{self.kernel.dtype} {self.kernel.shape} {self.kernel.path}")
        return super().call(inputs)
```
forward pass:
backward pass:
Comment From: dhantule
Hi @pumpnineteen, thanks for reporting this.
Could you provide standalone code to reproduce, if possible in a colab gist ?
Comment From: pumpnineteen
https://gist.github.com/pumpnineteen/533f091f48e8fe856d8734488f69a116
As you can see, there is no issue with the forward pass, and no issue with the backward pass when dtype="float32", but there is an error when it's the mixed-precision layer.
Swapping the RematScope mode to None, model.fit runs without issue.
The issue is not limited to Conv1D
Comment From: dhantule
Hi @pumpnineteen, thanks for your response but I'm not able to access the gist.
Comment From: pumpnineteen
@dhantule made the gist public, hope that helps
Comment From: dhantule
Hi @pumpnineteen,
I've tested the code with Keras 3.10.0 and I'm facing the same error in this gist. We'll look into this and update you.
Comment From: divyashreepathihalli
Hi @dhantule @pumpnineteen, thank you for reporting this. This is a known error in the TF backend; it works fine on the JAX backend, for example. The chain of events:

- The tf.image.resize(..., method='bilinear') gradient op produces float32 tensors.
- Keras casts the forward-pass output to float16, so subsequent layers produce float16 gradients, leading to an XLA error.

The TF team would need to re-implement tf.image.resize(..., method='bilinear'). A workaround is to set dtype to float32 on the Conv2D layer when using the TF backend.