With the TensorFlow backend, `keras.ops.correlate` does not match NumPy's output in `mode='same'` and `mode='full'`. The other backends, and `mode='valid'` on TensorFlow, work correctly.

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
from keras import ops

x = np.array([1, 3, 5])
y = np.array([7, 9])

# Print NumPy's reference result next to the keras backend's result for
# each correlation mode, so mismatches are easy to spot.
for mode in ("full", "same", "valid"):
    print("Mode: ", mode)
    print("numpy: ", np.correlate(x, y, mode=mode))
    print(
        "tensorflow: ",
        ops.convert_to_numpy(ops.correlate(x, y, mode=mode)),
    )
Mode:  full
numpy:  [ 9 34 66 35]
tensorflow:  [34. 66. 35.  0.]
Mode:  same
numpy:  [ 9 34 66]
tensorflow:  [34. 66. 35.]
Mode:  valid
numpy:  [34 66]
tensorflow:  [34. 66.]

Comment From: wesselvannierop

Additionally, JAX's `jnp.correlate` supports complex values, whereas the torch and TensorFlow backends appear to cast them to float.

I have an implementation that fixes both issues:

def correlate(x, y, mode="full"):
    """Cross-correlate two 1-D sequences; equivalent to ``np.correlate(x, y, mode)``.

    Works around two issues with ``keras.ops.correlate``: on the TensorFlow
    backend only ``mode="valid"`` matches NumPy, and complex inputs are not
    supported on every backend. ``x`` is zero-padded by ``len(y) - 1`` on
    both sides so that a backend ``"valid"`` correlation produces the
    ``"full"`` result; the requested window is then sliced out of it.
    Complex inputs are handled by correlating real and imaginary parts
    separately.

    Args:
        x: 1-D array-like, real or complex.
        y: 1-D array-like, real or complex. Conjugated, as in NumPy.
        mode: "full", "same", or "valid".

    Returns:
        1-D tensor. Complex (complex128 when the part results are float64,
        otherwise complex64) if either input is complex; otherwise the real
        dtype returned by ``ops.correlate``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes, or if
            either input is empty.
    """
    if mode not in ("full", "same", "valid"):
        raise ValueError(
            f"mode must be one of 'full', 'same' or 'valid', got {mode!r}"
        )

    x = ops.convert_to_tensor(x)
    y = ops.convert_to_tensor(y)
    # NOTE(review): assumes static (Python int) lengths, as the padding
    # below already did in the original implementation.
    x_len = ops.shape(x)[0]
    y_len = ops.shape(y)[0]
    if x_len == 0 or y_len == 0:
        raise ValueError("correlate requires non-empty inputs")

    pad = y_len - 1

    def _full(a, b):
        # "full" correlation built from the backend's "valid" mode, which is
        # the mode that agrees with NumPy on every backend.
        return ops.correlate(ops.pad(a, [[pad, pad]]), b, mode="valid")

    is_complex = "complex" in ops.dtype(x) or "complex" in ops.dtype(y)
    if is_complex:
        xr, xi = ops.real(x), ops.imag(x)
        yr, yi = ops.real(y), ops.imag(y)
        # sum_n x[n+k] * conj(y[n])
        #   = (xr*yr + xi*yi) + 1j * (xi*yr - xr*yi)
        real_part = _full(xr, yr) + _full(xi, yi)
        imag_part = _full(xi, yr) - _full(xr, yi)
        # Keep double precision for float64/complex128 inputs instead of
        # always narrowing to complex64.
        cdtype = (
            "complex128" if ops.dtype(real_part) == "float64" else "complex64"
        )
        full = ops.cast(real_part, cdtype) + 1j * ops.cast(imag_part, cdtype)
    else:
        # Real inputs need a single correlation and no complex round-trip.
        full = _full(x, y)

    if mode == "full":
        return full

    full_len = x_len + y_len - 1
    if mode == "same":
        target_len = max(x_len, y_len)
    else:  # "valid"
        target_len = max(x_len, y_len) - min(x_len, y_len) + 1

    slack = full_len - target_len
    # NumPy centres the window at floor(slack / 2). When y is longer than x
    # it swaps the operands and reverses the output, which shifts the window
    # to ceil(slack / 2). "valid" always has even slack, so both rules agree
    # there.
    start = slack // 2 if x_len >= y_len else (slack + 1) // 2
    return full[start : start + target_len]
Mode:  full
numpy:  [ 9 34 66 35]
custom:  [ 9. 34. 66. 35.]
Mode:  same
numpy:  [ 9 34 66]
custom:  [ 9. 34. 66.]
Mode:  valid
numpy:  [34 66]
custom:  [34. 66.]

Comment From: mehtamansi29

Hi @wesselvannierop, please feel free to open a PR with the fix for this issue.