Hi,
I have a question about keras.ops.eye in Keras 3.
With the TensorFlow backend, keras.ops.eye accepts both integers and floats as its size arguments. However, with the torch and JAX backends, only integers are accepted.
Which behavior is correct?
The documentation does not describe any such limitation. Because Keras 2 accepted both integers and floats, this confused me during migration.
Comment From: dhantule
Hi @yubori,
Thanks for reporting this. keras.ops.eye returns a 2-D tensor with ones on the diagonal and zeros elsewhere (an identity matrix). You can, however, control the data type of the resulting matrix, for example float32 or float64; the diagonal values will still be 1.0 and the other elements 0.0. Attaching a gist.
Comment From: yubori
@dhantule Sorry, my description was not clear enough.
This is a very simple issue.
Please see the following gist.
Comment From: dhantule
@yubori, thanks for reporting this. I have reproduced the error here: keras.ops.eye raises an error when float values are used as input dimensions with JAX and PyTorch as backends. Note that jax.numpy.eye itself does not support float values, and neither does torch.eye. We will look into this issue further. Thanks!
Comment From: yubori
Hi @dhantule, I look forward to the fix. Thank you.
Comment From: mattdangerw
I think in general we should probably let numpy be the guide here. np.eye will not accept floats or numpy array values, and the same is true for jax and torch. This seems correct to me.
The bug is that tf is too permissive here: we could add a check so that keras.ops.eye raises an error for floats on the tf backend, consistent with jax and torch.
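A minimal sketch of the kind of check the tf backend could run, assuming it should follow the np.eye rule that sizes must be integer-like. check_eye_args is a hypothetical helper, not existing Keras code, and the argument names N, M and k are borrowed from np.eye's signature. It relies on operator.index, which accepts int and anything else implementing __index__ but raises TypeError for a float such as 3.0.

```python
import operator


def _as_int(name, value):
    """Return value as an int, or raise TypeError if it is not integer-like."""
    try:
        return operator.index(value)
    except TypeError:
        raise TypeError(
            f"eye() argument {name} must be an integer, "
            f"got {type(value).__name__}: {value!r}"
        ) from None


def check_eye_args(N, M=None, k=0):
    """Hypothetical validation for an eye() implementation.

    Rejects floats (e.g. 3.0) the way np.eye does, instead of
    silently accepting them. M defaults to N when omitted.
    """
    N = _as_int("N", N)
    M = N if M is None else _as_int("M", M)
    k = _as_int("k", k)
    return N, M, k


# Hypothetical placement: the backend's eye() could call
#     N, M, k = check_eye_args(N, M, k)
# before building the matrix with tf.eye.
```

Using operator.index rather than an isinstance(value, int) check keeps NumPy integer scalars working while still rejecting every float.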
I'll mark this as contributors welcome. This is a good first issue to bring tf in line with jax and torch.
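For code migrated from Keras 2 that passes float sizes, one possible workaround is to convert the size explicitly before calling keras.ops.eye. This is a minimal sketch: to_eye_dim is a hypothetical helper, not part of Keras. It accepts ints and whole-number Python floats and rejects values like 2.5 instead of silently truncating them, as int() would.

```python
import operator


def to_eye_dim(value):
    """Hypothetical migration helper: convert a matrix size to a plain int.

    Whole-number floats such as 3.0 (as Keras 2 code may pass) become 3.
    Non-integral floats, inf and nan raise ValueError rather than being
    truncated. Other non-integer types raise TypeError via operator.index.
    """
    if isinstance(value, float):
        if not value.is_integer():
            raise ValueError(f"matrix size must be a whole number, got {value!r}")
        return int(value)
    return operator.index(value)
```

A call site would then look like keras.ops.eye(to_eye_dim(n)), which behaves the same on every backend because only an int reaches keras.ops.eye.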
Comment From: github-actions[bot]
This issue is stale because it has been open for 180 days with no activity. It will be closed if no further activity occurs. Thank you.