I created this layer to monitor what's happening:
class ReportingConv1D(Conv1D):
    def call(self, inputs):
        # Log the input and kernel dtypes/shapes before delegating to Conv1D.
        print(f"{inputs.dtype} {inputs.shape} {self.kernel.dtype} {self.kernel.shape} {self.kernel.path}")
        return super().call(inputs)
forward pass:
backward pass:
Comment From: dhantule
Hi @pumpnineteen, thanks for reporting this.
Could you provide standalone code to reproduce, if possible in a Colab gist?
Comment From: pumpnineteen
https://gist.github.com/pumpnineteen/533f091f48e8fe856d8734488f69a116
As you can see, there is no issue with the forward pass, and no issue with the backward pass when the layer uses dtype="float32", but there is an error when it's the mixed precision layer.
After swapping the RematScope mode to None, model.fit runs without issue.
The issue is not limited to Conv1D.