Last night I was trying to code some examples with the JAX backend but couldn't get it working because of CUDA and cuDNN issues. Given that all three backends depend on different versions of CUDA and cuDNN, we need to define a compatibility matrix in the README. For example, something like this:
Backend | Version | CUDA version | cuDNN version |
---|---|---|---|
TensorFlow | 2.11 | 11.xx | 8.xx |
JAX | 0.4.7 | 11.xx | 8.xx |
PyTorch | 2.x | 11.xx | 8.xx |
For TensorFlow this information is easy to find, but for the other backends it's not well documented, and even a minor mismatch in the cuDNN version breaks the setup. This information would be especially helpful for people running their setup in the cloud, e.g. on GCP, because most of the virtual machines there are still optimized for TF and PyTorch only, and even then separately. Let me know what you think.
Comment From: f-hafner
Hi, I'm trying to understand the compatibility between Keras and multiple backends (especially PyTorch) and came across this issue. What's the status? I could not find anything in the README.
I found this in the Keras docs:
The following Keras + PyTorch versions are compatible with each other:
torch~=2.1.0 & keras~=3.0
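For reference, `~=` is the PEP 440 "compatible release" operator: `torch~=2.1.0` means `>=2.1.0, ==2.1.*`, and `keras~=3.0` means `>=3.0, ==3.*`. The sketch below spells that rule out; `parse_release` and `compatible_release` are hypothetical helpers that assume plain numeric versions (a local tag such as `+cu121` is ignored, as PEP 440 specifies for public specifiers), and they reject pre-release suffixes like `rc1` rather than handling them.

```python
def parse_release(version):
    """Parse a plain release like '2.1.2+cu121' into (2, 1, 2).

    Assumption: only dotted integers plus an optional local '+...' tag;
    anything else (e.g. '2.1.0rc1') raises ValueError.
    """
    public = version.split("+", 1)[0]  # drop the local version label
    return tuple(int(part) for part in public.split("."))


def compatible_release(installed, spec):
    """True if `installed` satisfies '~=spec' under PEP 440 rules.

    '~=2.1.0' is equivalent to '>=2.1.0, ==2.1.*' (the last segment is dropped
    to form the prefix that must match exactly).
    """
    want = parse_release(spec)
    if len(want) < 2:
        raise ValueError("~= requires at least two release segments")
    have = parse_release(installed)

    # Zero-pad both tuples so that e.g. '3' compares equal to '3.0'.
    width = max(len(have), len(want))
    have_p = have + (0,) * (width - len(have))
    want_p = want + (0,) * (width - len(want))

    prefix = want[:-1]
    return have_p[: len(prefix)] == prefix and have_p >= want_p


print(compatible_release("2.1.2+cu121", "2.1.0"))  # True
print(compatible_release("2.2.0", "2.1.0"))        # False: outside 2.1.*
print(compatible_release("3.3.3", "3.0"))          # True: any 3.x >= 3.0
```

In practice the `packaging` library already implements this, e.g. `"2.1.2" in SpecifierSet("~=2.1.0")` with `SpecifierSet` from `packaging.specifiers`; the hand-written version is only meant to make the rule explicit.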
Is this still the current state (torch 2.1.0 is from late 2023...)?
Thanks!