starter: adaptive_pooling <- this works on the torch and tensorflow backends, not jax.

Comment From: innat

@dhantule It's not a bug for sure. It's a feature request. Please change the tag, it's misleading.

The code I shared is for anyone who wants to contribute, if this feature request is accepted.

This code, though it runs on the tf/torch backends, uses a Python loop, which is problematic. Need to look for vectorized ops or some other trick!
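For context, the loop-heavy pattern being referred to typically looks something like the sketch below (illustrative only, not the exact starter code): one Python loop iteration per output position, which is slow to trace and hard to vectorize across backends.

```python
import numpy as np

def naive_adaptive_avg_pool1d(x, out_size):
    # x: (batch, length, channels). One Python loop iteration per output
    # position, each slicing its own adaptive window -- the pattern the
    # comment above flags as problematic.
    n = x.shape[1]
    pieces = []
    for i in range(out_size):
        start = (i * n) // out_size              # floor(i * n / out)
        end = -(-((i + 1) * n) // out_size)      # ceil((i + 1) * n / out)
        pieces.append(x[:, start:end, :].mean(axis=1, keepdims=True))
    return np.concatenate(pieces, axis=1)        # (batch, out_size, channels)
```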

Comment From: innat

https://github.com/ml-explore/mlx/issues/400

Comment From: innat

@AakashKumarNain Apologies for the interruption. I’m currently working on implementing adaptive pooling in Keras and noticed that tf-addons included an implementation before. However, it doesn’t seem to match the behavior of torch’s adaptive pooling. Since you’re the maintainer, could you please share any insights on whether there were specific challenges or blockers in implementing this method? Any related discussion threads or references would be greatly appreciated. Thanks.

Comment From: dhantule

> @dhantule It's not a bug for sure. It's a feature request. Please change the tag, it's misleading.

I think I accidentally added the bug label; I have changed it.

Comment From: innat

Here’s the updated version of adaptive pooling implemented in Keras. Unlike the previous starter code, it eliminates multiple for loops and introduces an efficient n-dimensional pooling gather method inspired by Hugging Face. The implementation runs seamlessly on both GPU and TPU across all supported backends. It is also numerically identical to torch.nn.AdaptiveAvgPool within a standard tolerance of 1e-6.
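For readers following along, here is a minimal NumPy sketch of one loop-free formulation (not necessarily the exact gather method referenced above): the adaptive windows are encoded as a normalized (output, input) weight matrix and applied with a single einsum, so no Python loop runs over output positions at pooling time. All names are illustrative.

```python
import numpy as np

def adaptive_avg_weights(in_size, out_size):
    # Row i averages the input range [floor(i*in/out), ceil((i+1)*in/out)),
    # the same window rule torch's adaptive pooling uses.
    i = np.arange(out_size)
    starts = (i * in_size) // out_size
    ends = -(-((i + 1) * in_size) // out_size)        # ceil division
    cols = np.arange(in_size)
    window = (cols[None, :] >= starts[:, None]) & (cols[None, :] < ends[:, None])
    return window / window.sum(axis=1, keepdims=True)

def adaptive_avg_pool1d(x, out_size):
    # x: (batch, length, channels); contract the length axis against the
    # precomputed weight matrix instead of looping over output positions.
    w = adaptive_avg_weights(x.shape[1], out_size)    # (out, in)
    return np.einsum("oi,bic->boc", w, x)

x = np.random.rand(4, 37, 8).astype("float32")
y = adaptive_avg_pool1d(x, 7)                         # -> shape (4, 7, 8)
```

The same weight matrix can be applied once per spatial axis for 2D/3D; the trade-off is the memory for the (out, in) matrices versus avoiding traced Python loops.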

Comment From: MalyalaKarthik66

I’ve implemented adaptive average and max pooling across JAX, NumPy, TensorFlow, and PyTorch backends (currently excluding OpenVINO). Please review and share any feedback or suggested changes. PR: #21820

Comment From: innat

@MalyalaKarthik66 Thanks for the PR! Please make sure to address the following points:

  • Benchmark the execution time across the TensorFlow, JAX, and PyTorch backends, using low to high input resolutions and varying output/pooling sizes. Include a comparison with torch.nn.AdaptiveAvgPool.

  • Verify numerical equivalence between the Keras implementation and native PyTorch adaptive pooling (a minimal sketch of such a check follows below).

  • Test the implementation in a real training setup on both GPU (TF/PyTorch) and TPU (JAX).
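As a concrete illustration of the equivalence point, a minimal, self-contained check might look like the following: a brute-force reference that follows torch's window rule, compared against torch.nn.functional.adaptive_avg_pool1d within the 1e-6 tolerance mentioned earlier. Shapes and sizes are arbitrary.

```python
import numpy as np
import torch

def reference_adaptive_avg_pool1d(x, out_size):
    # Brute-force reference in torch's (batch, channels, length) layout,
    # using start = floor(i*L/out) and end = ceil((i+1)*L/out).
    n = x.shape[-1]
    out = np.empty(x.shape[:-1] + (out_size,), dtype=x.dtype)
    for i in range(out_size):
        start = (i * n) // out_size
        end = -(-((i + 1) * n) // out_size)
        out[..., i] = x[..., start:end].mean(axis=-1)
    return out

x = np.random.rand(2, 3, 37).astype("float32")
ref = reference_adaptive_avg_pool1d(x, 7)
got = torch.nn.functional.adaptive_avg_pool1d(torch.from_numpy(x), 7).numpy()
np.testing.assert_allclose(got, ref, atol=1e-6)  # agrees within the 1e-6 tolerance
```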

Comment From: MalyalaKarthik66

@innat Thanks for the feedback! I’ve implemented it for the Torch, JAX, and NumPy backends (excluding TensorFlow and OpenVINO). The PR includes:

  • Benchmarking across the Torch and JAX backends for different input resolutions and output/pooling sizes (a rough timing sketch follows after this list).

  • Numerical equivalence verification between Keras adaptive pooling and PyTorch native adaptive pooling.

  • Real training setup tests on Torch GPU, ensuring correctness for training and inference.

Looking forward to your review!
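Regarding the benchmarking item above, a rough, illustrative sketch of such a timing run is below. Only the torch baseline is shown; swapping in the Keras adaptive pooling op (whatever entry point the PR exposes) at the marked line would give the side-by-side numbers. Sizes, shapes, and repeat counts are arbitrary.

```python
import time
import torch

def bench(fn, x, repeats=20):
    # Average wall-clock time per call; for GPU runs, add torch.cuda.synchronize().
    fn(x)                                    # warm-up
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    return (time.perf_counter() - t0) / repeats

for size in (32, 128, 512):                  # low to high input resolutions
    x = torch.rand(2, 16, size, size)
    t = bench(lambda t_: torch.nn.functional.adaptive_avg_pool2d(t_, (7, 7)), x)
    # Replace the call above with the Keras adaptive pooling op to compare.
    print(f"{size}x{size}: {t * 1e3:.3f} ms")
```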

Comment From: innat

@MalyalaKarthik66 Thanks. Are there any blockers to implementing the TF version?

Comment From: MalyalaKarthik66

@innat , thanks for checking!

Actually, I tried implementing the TensorFlow version at a basic level using a simple for-loop style. For JAX, I implemented the adaptive pooling with optimized, vectorized operations and committed that version.

I didn’t test TensorFlow fully because the for-loop style implementation isn’t efficient, so I focused on JAX to make it performant.

Comment From: innat

@MalyalaKarthik66 Thanks.

I've just quickly checked the JAX implementation; based on some condition, it also uses a for loop, no? Extending this to 3D cases later might cause issues, perhaps.

Also, for the torch cases, the 1D, 2D, and 3D versions all come directly from torch, great.

Comment From: MalyalaKarthik66

@innat I have implemented adaptive pooling for 1D, 2D, and 3D across JAX, TensorFlow, and PyTorch backends.

  • For PyTorch, the implementation uses native PyTorch adaptive pooling.

  • For JAX and TensorFlow, I implemented an efficient n-dimensional two-pool gather adaptive pooling method that eliminates multiple for-loops, enabling robust support on CPU, GPU, and TPU backends.

  • All corresponding unit tests for JAX, TensorFlow, and PyTorch adaptive pooling pass successfully.

  • In real training model tests, PyTorch passes on both GPU and CPU.

  • The TensorFlow backend passes the training tests on CPU but currently fails on GPU with the error AttributeError: 'str' object has no attribute 'base_dtype' during tape.gradient() in test_training_adaptive_pooling.py.

I have tried to fix it but have not been able to resolve it. I would like to confirm whether a similar problem has been encountered before and get guidance on resolving this blocker, which affects the TensorFlow GPU training tests.

Comment From: innat

@MalyalaKarthik66 I've just re-run my implementation with the tf backend, and it works properly there.

Looking at the error you're facing, 'str' object has no attribute 'base_dtype', it seems like somewhere TF tried to read base_dtype but got a plain Python string instead of a tensor/DType. But as you mentioned, it happened at training time with tf. Which tf and keras versions are you using? Also make sure the proper backend is set while testing.
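For what it's worth, this class of error usually means a plain string dtype (e.g. "float32") reached code that expected a tf.DType and called .base_dtype on it. A minimal, hypothetical illustration of the failure and the usual normalization fix:

```python
import tensorflow as tf

dtype = "float32"          # a dtype passed around as a plain string
try:
    dtype.base_dtype       # raises: 'str' object has no attribute 'base_dtype'
except AttributeError as e:
    print(e)

# Normalizing to a tf.DType first avoids the failure:
dtype = tf.as_dtype("float32")
print(dtype.base_dtype)    # <dtype: 'float32'>
```

Where exactly the string sneaks in has to be traced through the gradient path, but checking the dtypes of the tensors and variables handed to tape.gradient() is a reasonable first step.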

stack-overflow, google-forum