As mentioned in keras-team/keras-core#843, I was trying to further modify the MLM-BERT pipeline for keras-core with backend-agnostic support. The last part of the example creates an end-to-end pipeline that takes raw text as the input to the model, as follows:
def get_end_to_end(model):
    inputs_string = keras.Input(shape=(1,), dtype="string")
    indices = vectorize_layer(inputs_string)
    outputs = model(indices)
    end_to_end_model = keras.Model(inputs_string, outputs, name="end_to_end_model")
    optimizer = keras.optimizers.Adam(learning_rate=config.LR)
    end_to_end_model.compile(
        optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model


end_to_end_classification_model = get_end_to_end(classifer_model)
end_to_end_classification_model.evaluate(test_raw_classifier_ds)
But when executed with the JAX backend, it throws the following error:
ValueError Traceback (most recent call last)
Cell In[10], line 14
10 return end_to_end_model
13 end_to_end_classification_model = get_end_to_end(classifer_model)
---> 14 end_to_end_classification_model.evaluate(test_raw_classifier_ds)
File ~/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/keras_core/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
120 filtered_tb = _process_traceback_frames(e.__traceback__)
121 # To get the full stack trace, call:
122 # `keras_core.config.disable_traceback_filtering()`
--> 123 raise e.with_traceback(filtered_tb) from None
124 finally:
125 del filtered_tb
File ~/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/tree/__init__.py:435, in map_structure(func, *structures, **kwargs)
432 for other in structures[1:]:
433 assert_same_structure(structures[0], other, check_types=check_types)
434 return unflatten_as(structures[0],
--> 435 [func(*args) for args in zip(*map(flatten, structures))])
File ~/Documents/oss/tfjax/keras-core/kscore/lib/python3.10/site-packages/tree/__init__.py:435, in <listcomp>(.0)
432 for other in structures[1:]:
433 assert_same_structure(structures[0], other, check_types=check_types)
434 return unflatten_as(structures[0],
--> 435 [func(*args) for args in zip(*map(flatten, structures))])
ValueError: Invalid dtype: object
Could the input layer having dtype `string` be causing the problem here?
Comment From: fchollet
The dtype `string` is only supported with the TF backend. So there's another issue here, which is that the error message is not clear. It should just say that dtype `string` isn't supported with the JAX backend.
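To make the suggestion concrete, here is a minimal sketch of what an explicit dtype check with a clearer message might look like. The helper `check_input_dtype` and the constant `STRING_CAPABLE_BACKENDS` are hypothetical names used for illustration only; this is not how Keras implements its validation, just one way such a check could be written.

```python
# Hypothetical helper, not part of Keras: illustrates an explicit dtype check
# that fails early with a descriptive message.

# Assumption taken from the comment above: only the TF backend handles strings.
STRING_CAPABLE_BACKENDS = {"tensorflow"}


def check_input_dtype(dtype: str, backend: str) -> None:
    """Raise a descriptive error when `dtype` is unsupported on `backend`."""
    if dtype == "string" and backend not in STRING_CAPABLE_BACKENDS:
        raise ValueError(
            f"dtype 'string' is not supported with the {backend!r} backend; "
            "it is only supported with the 'tensorflow' backend."
        )


# Supported combination: returns without raising.
check_input_dtype("string", "tensorflow")

# Unsupported combination: the error names both the dtype and the backend.
try:
    check_input_dtype("string", "jax")
except ValueError as err:
    message = str(err)
    print(message)
```

With a check like this, the failure would name the dtype and the backend directly instead of surfacing as `Invalid dtype: object` from deep inside `map_structure`.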
Comment From: sachinprasadhs
@Mrutyunjay01, could you please provide a sample of reproducible code? The tutorial you mentioned has other errors that need to be resolved before we reach the reported issue.
Comment From: Mrutyunjay01
Apologies for the late response. I faced the issue when I tried to port BERT MLM to be backend-agnostic. It has been a while, so let me refine the port code and see whether the issue still persists.
cc: @sachinprasadhs
Comment From: Mrutyunjay01
Update:
The issue no longer persists in Keras 3.3, so we can close this. I am currently trying to port BERT MLM to be backend-agnostic and will report if I come across any such issues.
Comment From: Mrutyunjay01
Reopening the issue, as I ran into it again in the draft PR mentioned.
Comment From: sonali-kumari1
Hi @Mrutyunjay01 - Could you please share a minimal reproducible code snippet so we can check whether this issue still persists in the latest version of Keras (3.11.2)? Thanks!
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.