The output of `ops.prod` doesn't work as a layer argument. To make it work, it has to be cast explicitly, i.e. `int(ops.prod(...))`. Even with that cast, using `ops.prod` in the `call` method is still an issue.
```python
import keras
from keras import layers, ops
import numpy as np


class ProdDenseLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.input_prod = ops.prod(input_shape[1:])
        self.dense = layers.Dense(self.input_prod, activation='relu')
        self.dense.build(input_shape)

    def call(self, inputs):
        scale_factor = ops.prod(ops.shape(inputs)[1:])
        scaled_inputs = inputs * ops.cast(scale_factor, inputs.dtype)
        return self.dense(scaled_inputs)


# Main model using ops.prod
class ProdModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.prod_layer1 = ProdDenseLayer()

    def build(self, input_shape):
        self.prod_layer1.build(input_shape)

    def call(self, inputs):
        batch_size = ops.shape(inputs)[0]
        total_elements = ops.prod(ops.shape(inputs)[1:])
        normalized_inputs = inputs / ops.cast(total_elements, inputs.dtype)
        x = self.prod_layer1(normalized_inputs)
        return x


# Create dummy data
batch_size = 32
input_shape = (10,)
X_train = np.random.randn(batch_size * 10, *input_shape).astype(np.float32)
y_train = np.random.randint(0, 2, (batch_size * 10, 1)).astype(np.float32)

model = ProdModel()
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy']
)
history = model.fit(
    X_train, y_train,
    epochs=3,
    batch_size=batch_size,
    validation_split=0.2,
    verbose=1
)
model.summary()
```
```
Epoch 1/3
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_36/1432196851.py in <cell line: 0>()
     45     metrics=['accuracy']
     46 )
---> 47 history = model.fit(
     48     X_train, y_train,
     49     epochs=3,

/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/tmp/ipykernel_36/1432196851.py in build(self, input_shape)
     25
     26     def build(self, input_shape):
---> 27         self.prod_layer1.build(input_shape)
     28
     29     def call(self, inputs):

/tmp/ipykernel_36/1432196851.py in build(self, input_shape)
     11         self.input_prod = ops.prod(input_shape[1:])
     12         self.dense = layers.Dense(self.input_prod, activation='relu')
---> 13         self.dense.build(input_shape)
     14
     15     def call(self, inputs):

ValueError: Invalid dtype: <property object at 0x79c67d66d3f0>
```
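For context, a minimal sketch of what `Dense` is handed in `build` above; the tuple here is just a stand-in for `input_shape[1:]`, and the exact tensor type depends on the backend in use:

```python
from keras import ops

units = ops.prod((10,))    # scalar backend tensor, not a plain Python int
print(type(units))         # backend tensor/array type, backend-dependent
print(type(int(units)))    # <class 'int'> -- the cast mentioned at the top
```

The `int(...)` cast only works eagerly on a concrete scalar, which is presumably why the same trick doesn't carry over to `call` under tracing.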
Comment From: SamanehSaadat
Hi @innat
I made two changes to your code that made it work:
1. As you mentioned, I cast the output of `ops.prod` to `int`.
2. I added `self.built = True` to the custom layer and the model, and built the model.
Here is the working colab.
I think it's reasonable that the output of `ops.prod` is `float32`. Right?
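For reference, a minimal sketch of those two changes applied to the custom layer from the snippet above (the model needs the same `self.built = True` treatment; only the layer is shown here):

```python
from keras import layers, ops


class ProdDenseLayer(layers.Layer):
    def build(self, input_shape):
        # change 1: cast the scalar tensor returned by ops.prod to a plain
        # Python int before using it as the `units` argument of Dense
        self.input_prod = int(ops.prod(input_shape[1:]))
        self.dense = layers.Dense(self.input_prod, activation="relu")
        self.dense.build(input_shape)
        # change 2: mark the layer as built
        self.built = True

    def call(self, inputs):
        scale_factor = ops.prod(ops.shape(inputs)[1:])
        return self.dense(inputs * ops.cast(scale_factor, inputs.dtype))
```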
Comment From: innat
First of all, setting `self.built = True` is not needed and is irrelevant here.
About casting `ops.prod` to `int`: in that case I could just use `np.prod` instead of `ops.prod` in the first place, but their behaviour is not the same.
The issue with `ops.prod` also shows up when you use it in the `call` method with `ops.reshape` or `layers.Reshape`.
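To make the `call`-side concern concrete, this is the kind of usage meant (a sketch with a hypothetical `FlattenLike` layer; whether a tensor-valued target size is accepted by `ops.reshape` or `layers.Reshape` depends on the backend and on whether the shape is static):

```python
from keras import layers, ops


class FlattenLike(layers.Layer):
    """Hypothetical layer, only to illustrate the pattern being discussed."""

    def call(self, inputs):
        # Under tracing, ops.shape(inputs) can contain symbolic dimensions,
        # so the product is a backend tensor rather than a Python int and
        # cannot simply be cast with int(...) the way it can in build().
        flat_dim = ops.prod(ops.shape(inputs)[1:])
        return ops.reshape(inputs, (-1, flat_dim))
```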
Comment From: SamanehSaadat
Could you explain why you need to use `ops.prod` here?
Comment From: innat
@SamanehSaadat To reproduce the error.
Comment From: SamanehSaadat

> @SamanehSaadat To reproduce the error.

I mean, what's your use case that requires `ops.prod` and can't use something like `math.prod`? I'm trying to understand the problem better.

> The output of `ops.prod` doesn't work as a layer argument.

Isn't this because its type is `float`? Are there other reasons?

> The issue with `ops.prod` also shows up when you use it in the `call` method with `ops.reshape` or `layers.Reshape`.

What's the issue you're seeing with the `call` method, and why do you think the root cause is `ops.prod`?
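For reference, the kind of alternative meant here only needs the static shape (a sketch; the tuple below is just an illustrative stand-in for what `build()` receives):

```python
import math

static_shape = (None, 10)              # e.g. the input_shape passed to build()
units = math.prod(static_shape[1:])    # plain Python int: 10
```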
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Comment From: georgeneedles60-hub
ok thanks for the update