Issue for tracking and coordinating mlx backend work:
mlx.math
- [ ] fft
- [ ] fft2
- [ ] rfft
- [ ] irfft
- [ ] stft
- [ ] istft
- [x] logsumexp #19578
- [ ] qr
- [ ] segment_sum #19652
- [ ] segment_max #19652
- [x] erfinv #19628

mlx.numpy
- [ ] einsum
- [ ] bincount
- [ ] nonzero
- [ ] cross
- [ ] vdot
- [ ] nan_to_num
- [ ] copy
- [ ] roll
- [x] median #19568 #19574
- [x] meshgrid #19574
- [x] conjugate
- [x] arctan2 #19759
- [ ] quantile
- [ ] imag
- [ ] real
- [ ] select
- [x] argpartition https://github.com/keras-team/keras/pull/19680
- [ ] slogdet
- [ ] select
- [ ] vectorize
- [ ] correlate
- [x] diag #19714
- [x] diagonal #19714

mlx.image
- [x] rgb_to_grayscale #19609
- [x] resize #19699

mlx.nn
- [ ] max_pool
- [ ] avg_pool
- [ ] conv
- [ ] depthwise_conv
- [ ] separable_conv
- [ ] conv_transpose
- [ ] ctc_loss

mlx.rnn
- [ ] rnn
- [ ] lstm
- [ ] gru

mlx.linalg
- [x] cholesky
- [ ] det
- [ ] eig
- [ ] eigh
- [x] inv
- [ ] lu_factor
- [x] norm #19698
- [x] qr
- [ ] solve
- [ ] solve_triangular
- [x] svd

mlx.core
- [x] np.ndarray of i64 is being cast to i32 in mlx during conversion if dtype is not passed
- [x] https://github.com/ml-explore/mlx/issues/1076
- [ ] https://github.com/ml-explore/mlx/issues/1075
- [x] https://github.com/ml-explore/mlx/issues/1066
- [x] https://github.com/ml-explore/mlx/issues/1065
Comment From: lkarthee
PyTest Output
=========================================================================== test session starts ============================================================================
platform darwin -- Python 3.12.2, pytest-8.1.1, pluggy-1.4.0 -- /Users/kartheek/erlang-ws/github-ws/latest/keras/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /Users/kartheek/erlang-ws/github-ws/latest/keras
configfile: pyproject.toml
plugins: cov-5.0.0
collected 6 items
keras/src/ops/operation_test.py::OperationTest::test_autoconfig PASSED [ 16%]
keras/src/ops/operation_test.py::OperationTest::test_eager_call PASSED [ 33%]
keras/src/ops/operation_test.py::OperationTest::test_input_conversion FAILED [ 50%]
keras/src/ops/operation_test.py::OperationTest::test_serialization PASSED [ 66%]
keras/src/ops/operation_test.py::OperationTest::test_symbolic_call PASSED [ 83%]
keras/src/ops/operation_test.py::OperationTest::test_valid_naming PASSED [100%]
================================================================================= FAILURES =================================================================================
___________________________________________________________________ OperationTest.test_input_conversion ____________________________________________________________________
self = <keras.src.ops.operation_test.OperationTest testMethod=test_input_conversion>
def test_input_conversion(self):
x = np.ones((2,))
y = np.ones((2,))
z = knp.ones((2,)) # mix
if backend.backend() == "torch":
z = z.cpu()
op = OpWithMultipleInputs()
> out = op(x, y, z)
keras/src/ops/operation_test.py:152:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/src/utils/traceback_utils.py:113: in error_handler
return fn(*args, **kwargs)
keras/src/ops/operation.py:56: in __call__
return self.call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Operation name=op_with_multiple_inputs>, x = array([1., 1.]), y = array([1., 1.])
z = <[ValueError('item can only be called on arrays of size 1.') raised in repr()] array object at 0x13f7450c0>
def call(self, x, y, z=None):
# `z` has to be put first due to the order of operations issue with
# torch backend.
> return 3 * z + x + 2 * y
E ValueError: Cannot perform addition on an mlx.core.array and ndarray
keras/src/ops/operation_test.py:14: ValueError
========================================================================= short test summary info ==========================================================================
FAILED keras/src/ops/operation_test.py::OperationTest::test_input_conversion - ValueError: Cannot perform addition on an mlx.core.array and ndarray
======================================================================= 1 failed, 5 passed in 0.13s ========================================================================
Any idea how to fix this test case? add(mx_array, numpy_array) works, but it fails when using the + operator. Should we skip this test for the mlx backend?
Comment From: fchollet
Any idea how to fix this test case? add(mx_array, numpy_array) works, but it fails when using the + operator. Should we skip this test for the mlx backend?
It's not fixable on our side; we should file an issue with the MLX repo. + will hit array.__add__, which is on their side.
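For context, the reason an explicit add works while + fails comes down to Python's binary-operator dispatch: if the left operand's __add__ raises instead of returning NotImplemented, Python never gets a chance to fall back to the right operand's __radd__. A minimal sketch with hypothetical stand-in classes (not mlx's actual implementation):

```python
class StrictArray:
    # Mimics an __add__ that raises on foreign operand types, as the
    # mlx error message suggests: the exception propagates immediately.
    def __add__(self, other):
        if not isinstance(other, StrictArray):
            raise ValueError("Cannot perform addition on mixed types")
        return StrictArray()

class CooperativeArray:
    # Returning NotImplemented lets Python fall back to the other
    # operand's __radd__ before giving up with a TypeError.
    def __add__(self, other):
        if not isinstance(other, CooperativeArray):
            return NotImplemented
        return CooperativeArray()

    def __radd__(self, other):
        return "handled by __radd__"

try:
    StrictArray() + [1.0]
except ValueError as e:
    print("strict:", e)

print([1.0] + CooperativeArray())  # list.__add__ defers, so __radd__ runs
```

This is why the fix has to land in mlx's array.__add__ rather than in Keras.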
Comment From: Faisal-Alsrheed
Thank you for the list.
I am working on:
- keras/backend/mlx/nn.py:conv
- keras/backend/mlx/nn.py:depthwise_conv
- keras/backend/mlx/nn.py:separable_conv
- keras/backend/mlx/nn.py:conv_transpose
Comment From: lkarthee
I am working on segment_sum, segment_max, max_pool, and avg_pool. Thank you.
Comment From: yrahul3910
I want to take a stab at arctan2 (first-time contributor, so I want to start small). I'm working with the mlx team to see if I can add the required functionality there first, and then I'll add the implementation here.
Comment From: lkarthee
Thank you @yrahul3910, please go ahead with adding the arctan2 impl.
Comment From: lkarthee
mx.matmul and mx.tensordot work only for bfloat16, float16, and float32.
FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_tensordot_('int16', 'bool') - ValueError: [matmul] Only real floating point types are supported but int16 and bool were provided which results in int16, which is not a real floating point type.
@fchollet How do we handle this? We could cast integer arguments to float32 when both are integers, with a float32 result. If we go this route, we have to modify test cases in numpy_test.py for mlx. Do you have any suggestions?
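If the cast-to-float32 route were taken, the wrapper could look something like the sketch below, with numpy standing in for mlx (matmul_with_int_cast is a hypothetical helper name, not an existing Keras or mlx API):

```python
import numpy as np

def matmul_with_int_cast(a, b):
    # Sketch of the proposed workaround: if neither operand is a
    # floating point type, cast both to float32 so the backend matmul
    # sees a supported dtype; the result is then float32.
    a, b = np.asarray(a), np.asarray(b)
    if not (np.issubdtype(a.dtype, np.floating)
            or np.issubdtype(b.dtype, np.floating)):
        a = a.astype("float32")
        b = b.astype("float32")
    return np.matmul(a, b)

out = matmul_with_int_cast(np.ones((2, 2), dtype="int16"),
                           np.ones((2, 2), dtype=bool))
print(out.dtype)  # float32
```

The failing int16 × bool case above would then produce a float32 result instead of raising, at the cost of diverging from the dtype-promotion expectations baked into numpy_test.py.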
Comment From: awni
Just want to let you all know some updates to MLX as of 0.16.1 that may be useful here:
- mx.einsum
- mx.nan_to_num
- mx.conjugate
Are there any high priority items we can fix or add to help move this along?
Comment From: lkarthee
Thank you @awni, we need some help moving this forward. I will make a list and get back to you in a day or two.
Comment From: acsweet
I'd like to pick up on this issue (first-time contributor), starting with fft if that's okay.
Comment From: acsweet
I'm going to start with the "easy" stuff already implemented in mlx, and I'll start in mlx.math with:
- fft2
- rfft
- irfft
- qr (I'll have to see how to handle the mode argument from Keras)
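For reference on that mode argument, numpy's qr illustrates the semantics a Keras-style qr op needs to cover (assuming Keras follows the numpy convention of "reduced" vs. "complete" factorizations, which square-only support can't express for rectangular inputs):

```python
import numpy as np

# For an (m, n) matrix with m >= n:
#   "reduced"  -> q is (m, n), r is (n, n)
#   "complete" -> q is (m, m), r is (m, n)
a = np.random.default_rng(0).normal(size=(5, 3))
q_red, r_red = np.linalg.qr(a, mode="reduced")
q_com, r_com = np.linalg.qr(a, mode="complete")

print(q_red.shape, r_red.shape)  # (5, 3) (3, 3)
print(q_com.shape, r_com.shape)  # (5, 5) (5, 3)
assert np.allclose(q_red @ r_red, a)  # both factorizations reconstruct a
```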
Comment From: awni
Sounds great! Let us know how we can help on the MLX side.
Comment From: acsweet
@awni Thank you! I'll keep you updated as I progress.
Right now, would it be possible to get stft and istft implemented on the mlx side? It looks like it was started here: ml-explore/mlx#1004
I saw this implementation too (without an inverse): https://github.com/nuniz/mlx_stft
Comment From: fchollet
Please note, the nn and rnn namespaces are the most important for getting mlx to work with typical workflows.
Comment From: acsweet
I'm going to hold off on math.qr for now; mlx currently only supports square matrices (and no option for the complete or reduced factorization).
I have a PR for fft2, rfft, and irfft (and a fix to fft); if that looks good I'll start looking at the rnn namespace.
It looked like the backend implementations for rnn.gru and rnn.lstm were only implemented for tensorflow, for cudnn-specific speedups with tf. So I think it's safe to follow jax and torch?
Comment From: fchollet
Right, unless mlx actually exposes some cudnn bindings for these
Comment From: acsweet
I'm going to start working through mlx.nn now.
I hope that's okay, but I'm going to start with conv, and if @lkarthee or @Faisal-Alsrheed would like to jump back in, please do! Otherwise I'll keep working through the other functions.
Comment From: acsweet
@awni Would it be possible to get support for non-square matrices implemented in mlx.linalg.qr? I didn't see an open issue for it; I can open a feature enhancement too.
Comment From: awni
Yes, please open an issue about it; it should be straightforward to get it working.
Comment From: acsweet
If the conv implementation looks good, I think I'll get started on the other convolutional functions:
- depthwise_conv
- separable_conv
- conv_transpose
Comment From: acsweet
I have a pull request for the remaining convolutional functions; if those look good I'll continue!
Fadi asked to work on max_pool and avg_pool, so I'm going to work on the remaining nn functions that are failing tests.
Comment From: acsweet
@awni Would it be pretty straightforward to implement singular value norms in linalg::matrix_norm? I can open an issue for it too!
Comment From: acsweet
It's in the PR, but for reference:
- psnr, ctc_loss, and ctc_decode implemented
- ctc_decode with beam_search is not the most efficient without a unique function from mlx
- norm was swapped to use jax's until mlx's implementation supports singular value norms (https://github.com/ml-explore/mlx/issues/1791)
I'm going to continue with numpy and linalg implementations, focusing on passing tests in keras/src/layers
Comment From: fbadine
Pooling functionality added in https://github.com/keras-team/keras/pull/20814
Comment From: acsweet
Once these latest two PRs are merged, I'd like to try merging the master branch into mlx (if that's okay)
Comment From: acsweet
Merged the Keras master branch into mlx and patched a few files for pytest to work.
Going to add new functions and check tests, starting with nn.py.
Will start adjusting related tests that should be skipped for mlx, e.g. float64, flash_attention, etc.
Comment From: acsweet
I'm currently working through getting the layer tests to pass (keras/src/layers), including updates to ops functions as needed, and skipping unsupported tests.
Updates to ops are mostly in math, nn, image, numpy, and core.
Comment From: acsweet
We're continuing to work through the remaining functions in linalg and numpy, and adjusting tests where appropriate.
I'm marking quantize tests as unsupported by mlx for now. The quantized matmul with int8 is quite strict, and float8 isn't supported yet. If anyone has other thoughts on this, I'd be very happy to hear them.
Comment From: awni
The quantized matmul with int8 is quite strict
Just curious what you mean by that? What flexibility is missing?
Comment From: acsweet
Sorry Awni, I think I spoke too soon! I'm still wrapping my head around mlx's quantization and Keras' quantization-related methods (and quantization in general).
It looked like to call quantize, the columns of the input need to be divisible by the group_size (either 64 or 128), and quantized_matmul takes one non-quantized array. Can matrix multiplication be performed with two quantized arrays? I think mlx.matmul only supports floating point types; will this method allow integer types too at some point?
I think I need to read up more and maybe pick someone's brain on this topic soon!
Comment From: awni
It looked like to call quantize the columns of the input needed to be divisible by the group_size (either 64 or 128)
Yes (32, 64 or 128 are supported)
and quantized_matmul was with one non-quantized array
Yes.
Can matrix multiplication be performed with two quantized arrays?
Not yet
I think mlx.matmul only supports floating point types, will this method allow for integer types too at some point?
I think so but it's not been a top priority.
I think you are probably right that quantization might be difficult to support across platforms. Quant formats and options are quite diverse, there isn't a standard yet.
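For readers new to the topic, a toy numpy sketch of group-wise affine quantization shows why the column count must divide evenly by group_size: each group of consecutive values along the last axis gets its own scale and offset. This illustrates the general idea only, not mlx's exact packed format:

```python
import numpy as np

def quantize_groups(w, group_size=64, bits=4):
    # Each group of `group_size` consecutive values along the last axis
    # is mapped to [0, 2**bits - 1] with its own scale and offset, which
    # is why the last dimension must be divisible by group_size.
    assert w.shape[-1] % group_size == 0, "columns must divide by group_size"
    g = w.reshape(*w.shape[:-1], -1, group_size)
    lo = g.min(axis=-1, keepdims=True)
    hi = g.max(axis=-1, keepdims=True)
    scale = (hi - lo) / (2**bits - 1)
    safe = np.where(scale == 0, 1, scale)  # avoid 0/0 on constant groups
    q = np.round((g - lo) / safe).astype(np.uint8)
    return q, scale, lo

def dequantize_groups(q, scale, lo):
    # Undo the affine map and flatten the group axis back out.
    return (q * scale + lo).reshape(*q.shape[:-2], -1)

w = np.random.default_rng(0).normal(size=(8, 128)).astype(np.float32)
q, s, b = quantize_groups(w, group_size=64, bits=4)
w_hat = dequantize_groups(q, s, b)
print("max reconstruction error:", np.abs(w - w_hat).max())
```

Smaller groups give tighter per-group ranges (better accuracy) at the cost of storing more scales and offsets, which is the trade-off behind the 32/64/128 options.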
Comment From: fbadine
New release of MLX (0.23.0), with mx.float64 support on CPU and non-square QR factorization among the introduced features.
https://github.com/ml-explore/mlx/releases/tag/v0.23.0
Comment From: fbadine
An issue was raised with the MLX team regarding solve_triangular (https://github.com/ml-explore/mlx/issues/1871), and a PR was opened in MLX to tackle it: https://github.com/ml-explore/mlx/pull/1876. Once an MLX release includes the fix, I will add solve_triangular support. On the other hand, solve works fine, and I will add it in a PR soon.
Comment From: fbadine
@awni is there any plan to support lu_factor for a non-square array?
Comment From: awni
No plan but we can do it if it’s useful. Please file an issue
Comment From: fbadine
No plan but we can do it if it’s useful. Please file an issue
This is now implemented in mlx: https://github.com/ml-explore/mlx/pull/1889. Once it is released, I will add lu_factor support.
Comment From: fbadine
test_argmax_neg and test_argmin_negative_zero are failing in numpy_test.py::NumpyOneInputOpsDynamicShapeTest due to a bug in MLX where the values are wrong on GPU. An issue was raised with the MLX team:
https://github.com/ml-explore/mlx/issues/1895
Comment From: fbadine
This turned out to be normal behaviour on metal, with no flag to disable it. We need to skip those tests for MLX; I will do so in my next PR.
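For context on why this counts as normal behaviour: under IEEE-754, negative zero compares equal to positive zero, so the tied elements can legitimately be resolved either way depending on the backend's reduction order (e.g. CPU vs. metal GPU). A quick numpy illustration:

```python
import numpy as np

# -0.0 and +0.0 compare equal, so an argmax/argmin over values that tie
# this way has no single "correct" index: whichever element the
# reduction visits in a given order can win.
x = np.array([0.0, -0.0], dtype=np.float32)
assert x[0] == x[1]  # the two elements tie exactly
print("argmax:", np.argmax(x), "argmin:", np.argmin(x))
```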
Comment From: fbadine
@awni is there any plan to implement the eig function for an arbitrary matrix, as suggested in this thread: https://github.com/ml-explore/mlx/pull/1334#discussion_r1722291547? I couldn't find any PR for this.
Comment From: awni
@fbadine there is no plan either way right now. Since it's a reasonable part of the spec to support that, we'd prioritize based on need. If you need this for the Keras back-end, please file an issue in MLX with more details on what you need supported and we'll look into it.
Comment From: fbadine
Yes, it's needed for the mlx backend in Keras. I will file an issue in mlx. Thanks
Comment From: acsweet
With the latest PR we have 100% of tests passing locally!
A few caveats:
- no quantization support implemented yet
- no distributed training implemented yet
- a few linalg ops are wrapped around numpy for now
The main priorities now are testing Keras Hub and getting the github CI tests to pass. There currently seem to be some compile issues, but I think the linux build is a few releases behind. Going to try building locally in a similar environment to the CI tests and report any issues.
Comment From: awni
Wow!! Awesome work!
There currently seem to be some compile issues, but I think the linux build is a few releases behind.
True. Let me know if you need an update there.
Comment From: acsweet
@awni yes please!
I was able to build the current mlx release against the latest ubuntu and run a simple training test. The compile errors disappeared. Going to run the full suite of tests and see if any other cpu/linux specific errors crop up.
Comment From: awni
@awni yes please!
Linux releases are being built now. Should be on PyPi within ~10-20 mins.
Comment From: fbadine
MLX support for the calculation of eigenvalues and eigenvectors of a square matrix was just added: https://github.com/ml-explore/mlx/pull/2188. Once it is released, I will change the eig function in linalg to use mlx's implementation. Thanks @awni
Comment From: acsweet
After the next MLX tagged release we should be passing all tests except one. Calling inspect.getfullargspec() on a callable object wrapped in mx.checkpoint for rematerialization raises an error in python 3.10; it does not in python 3.11. In Keras this is used in keras/src/backend/common/remat_test.py::RematTest::test_remat_basic_call.
It seems like an edge case, and supporting it would add complexity to python/src/mlx_func.cpp in MLX. I'm thinking of skipping the test for MLX and python 3.10 (or swapping to a Dense layer), and maybe raising a warning to use python 3.11 if that case ever arises. Even in 3.11 it doesn't raise an error, but it also doesn't grab the wrapped callable's inputs. Any feedback would be appreciated.
With just MLX, a minimal example would be something like this (test in python 3.10 and 3.11):
import mlx.core as mx
import inspect

def computeExpSum(inputs, training=True):
    return mx.sum(mx.exp(inputs))

decorated_func = mx.checkpoint(computeExpSum)

try:
    print('inspect.signature', inspect.signature(decorated_func))
except Exception as e:
    print(f"Exception (inspect.signature): {e}")

try:
    print('inspect.getfullargspec', inspect.getfullargspec(decorated_func).args)
except Exception as e:
    print(f"Exception (inspect.getfullargspec): {e}")

class CustomClass():
    def __init__(self, w):
        self.w = mx.array(w)

    def __call__(self, x, training=True):
        return mx.matmul(self.w, x)

w = mx.random.normal(shape=(10, 10))
decorated_custom_class = mx.checkpoint(CustomClass(w))

try:
    print('inspect.signature (class object)', inspect.signature(decorated_custom_class))
except Exception as e:
    print(f"Exception (inspect.signature class object): {e}")

try:
    print('inspect.getfullargspec (class object)', inspect.getfullargspec(decorated_custom_class))
except Exception as e:
    print(f"Exception (inspect.getfullargspec class object): {e}")
Also, I'm not even sure getfullargspec returns what it's expected to for a checkpointed function/callable with the other backends either (i.e. to check if "training" or "mask" are in the inputs). I tested with jax, and inspect.signature seemed to work but not inspect.getfullargspec.
Comment From: awni
@acsweet inspect.signature will follow the __wrapped__ attribute if set, but inspect.getfullargspec does not. So one fix I'm playing around with is to set __wrapped__ in mlx_func; then inspect.signature would work, but inspect.getfullargspec still raises. I think making getfullargspec work is probably not worth it... are you able to work around that?
Also, is it useful to have inspect.signature work?
Comment From: acsweet
@awni I think we can work around it. And I was going to test whether it's actually being applied properly with the other backends (or just happens not to raise an error).
@fchollet I think this is a pretty narrow case: remat around a Lambda layer and checking whether the lambda-wrapped function has training or mask arguments. I can try testing with inspect.signature in the Lambda layer (as opposed to inspect.getfullargspec), swapping the remat test to use a custom subclassed layer, or something else. Any thoughts on this?
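One possible shape for that workaround, assuming a fallback from getfullargspec to inspect.signature is acceptable (call_arg_names is a hypothetical helper for illustration, not an existing Keras function):

```python
import inspect

def call_arg_names(fn):
    # Hypothetical helper: prefer getfullargspec, but fall back to
    # inspect.signature (which follows a __wrapped__ attribute on
    # decorated callables) when the wrapped object is unsupported.
    try:
        return inspect.getfullargspec(fn).args
    except TypeError:
        try:
            return list(inspect.signature(fn).parameters)
        except (TypeError, ValueError):
            return []  # give up: treat as accepting no known kwargs

def f(x, training=True, mask=None):
    return x

print(call_arg_names(f))  # ['x', 'training', 'mask']
```

With something like this in the Lambda layer's argument probing, an mx.checkpoint-wrapped callable that breaks getfullargspec would degrade gracefully instead of raising.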
Comment From: acsweet
Sorry @awni, could we get another linux build of mlx for the latest release? I'm getting some cpp errors with 0.26.1, e.g.
error: redeclaration of C++ built-in type '_Float128' [-fpermissive]
When I build it locally it's working fine. (I'm testing on Ubuntu 24.04.2 LTS)
Comment From: awni
@acsweet that's odd, we can definitely put out a new release of the linux build soon but I don't know if doing another build will solve that issue. There might be something that needs to be fixed first.
Comment From: awni
@acsweet Would you mind filing an issue on MLX so we make sure to investigate that?
Comment From: acsweet
Sorry @awni, it was an issue with my python environment (and VS Code). I tried to recreate the issue this morning and it was all working 🙃
The terminal in my editor was displaying a different python environment (e.g. via starship) than the one it was using for code execution. The error occurs when executing with mlx==0.25.2, but the latest version is fine. I can still file an issue against the older version if you'd like.
Comment From: awni
If it's working with the latest version then that's great.. no need to file an issue. Thanks!