For the tensorflow backend, the output of keras.ops.correlate does not match numpy's output
in mode='same' and mode='full'. The other backends, and mode='valid' on tensorflow, work fine.
import os

# Select the backend before keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
from keras import ops

x = np.array([1, 3, 5])
y = np.array([7, 9])

# Print the numpy reference next to the keras result for every mode.
for mode in ("full", "same", "valid"):
    print("Mode: ", mode)
    print("numpy: ", np.correlate(x, y, mode=mode))
    keras_result = ops.correlate(x, y, mode=mode)
    print("tensorflow: ", ops.convert_to_numpy(keras_result))
Mode: full
numpy: [ 9 34 66 35]
tensorflow: [34. 66. 35. 0.]
Mode: same
numpy: [ 9 34 66]
tensorflow: [34. 66. 35.]
Mode: valid
numpy: [34 66]
tensorflow: [34. 66.]
Comment From: wesselvannierop
In addition, with the jax backend, jnp.correlate
supports complex values, whereas the torch and tensorflow backends seem to cast to float.
I have an implementation that fixes both issues:
def correlate(x, y, mode="full"):
    """Cross-correlate two 1-D sequences like ``np.correlate(x, y, mode)``.

    The full correlation ``c[k] = sum_n x[n + k] * conj(y[n])`` is obtained by
    zero-padding ``x`` with ``len(y) - 1`` samples on both sides and running
    ``ops.correlate`` in ``"valid"`` mode, which is the only mode this
    function relies on the backend for. Complex inputs are handled by
    correlating the real and imaginary parts separately and recombining them.

    Args:
        x: 1-D array-like, real or complex (assumed 1-D: only axis 0 is
            padded and measured).
        y: 1-D array-like, real or complex.
        mode: "full", "valid", or "same".

    Returns:
        A tensor holding the correlation. It is complex if either input is
        complex; otherwise it has the dtype produced by ``ops.correlate``
        for the given inputs.

    Raises:
        ValueError: If ``mode`` is not one of "full", "valid" or "same".
    """
    if mode not in ("full", "same", "valid"):
        raise ValueError(
            f"mode must be 'full', 'same' or 'valid', got {mode!r}"
        )

    x = ops.convert_to_tensor(x)
    y = ops.convert_to_tensor(y)
    is_complex = "complex" in ops.dtype(x) or "complex" in ops.dtype(y)

    x_len = ops.shape(x)[0]
    y_len = ops.shape(y)[0]

    # Padding x by len(y) - 1 on each side turns a "valid" correlation of the
    # padded signal into the "full" correlation of the original one.
    pad = y_len - 1
    pad_width = [[pad, pad]]

    if is_complex:
        # Split into real and imaginary parts.
        xr, xi = ops.real(x), ops.imag(x)
        yr, yi = ops.real(y), ops.imag(y)
        xr = ops.pad(xr, pad_width)
        xi = ops.pad(xi, pad_width)

        # (xr + j*xi) * conj(yr + j*yi)
        #   = (xr*yr + xi*yi) + j*(xi*yr - xr*yi)
        rr = ops.correlate(xr, yr, mode="valid")
        ii = ops.correlate(xi, yi, mode="valid")
        ri = ops.correlate(xr, yi, mode="valid")
        ir = ops.correlate(xi, yr, mode="valid")
        real_part = rr + ii
        imag_part = ir - ri

        # Keep double precision when the parts are float64 instead of always
        # downcasting to complex64.
        if ops.dtype(real_part) == "float64":
            complex_dtype = "complex128"
        else:
            complex_dtype = "complex64"
        result = ops.cast(real_part, complex_dtype) + 1j * ops.cast(
            imag_part, complex_dtype
        )
    else:
        # Real inputs need a single correlation and no complex round trip.
        result = ops.correlate(ops.pad(x, pad_width), y, mode="valid")

    if mode == "full":
        return result

    shorter = ops.minimum(x_len, y_len)
    longer = ops.maximum(x_len, y_len)

    if mode == "same":
        # Output of length max(M, N), centred in the full result. The full
        # result has min(M, N) - 1 samples more than that.
        target_len = longer
        excess = shorter - 1
        floor_half = excess // 2
        ceil_half = excess - floor_half
        # numpy swaps the operands and reverses the output when y is longer
        # than x, which moves the extra sample (odd-sized excess) to the
        # other end of the window.
        start = ops.where(x_len >= y_len, floor_half, ceil_half)
    else:
        # "valid": output of length max(M, N) - min(M, N) + 1, i.e. the
        # full result with min(M, N) - 1 samples trimmed from each end.
        target_len = longer - shorter + 1
        start = shorter - 1

    start = ops.cast(start, "int32")
    end = start + target_len
    return result[start:end]
Mode: full
numpy: [ 9 34 66 35]
custom: [ 9. 34. 66. 35.]
Mode: same
numpy: [ 9 34 66]
custom: [ 9. 34. 66.]
Mode: valid
numpy: [34 66]
custom: [34. 66.]
Comment From: mehtamansi29
Hi @wesselvannierop - Please feel free to open a PR that fixes this issue.