Whenever we save a keras-hub model after quantization, we are unable to load the quantized model back. I've tried the from_preset() method for that model as well as keras.models.load_model; neither works.

I've attached notebook https://colab.research.google.com/gist/pctablet505/b5ef8ab36dceb58527e992b571aefb70/keras-quantized-model-not-loading.ipynb
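
For reference, the notebook does roughly the following (a minimal sketch; the source preset name is illustrative, any Gemma 3 preset reproduces the failure):

import keras_hub

# Load a full-precision model (preset name is illustrative, not the exact one in the notebook).
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")

# Quantize in place to int8 and save as a local preset directory.
gemma_lm.quantize("int8")
gemma_lm.save_to_preset("stored_gemma_int8")

# Reloading the quantized preset is what fails with the traceback below.
gemma_lm_quantized = keras_hub.models.Gemma3CausalLM.from_preset("stored_gemma_int8")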

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-3241078411> in <cell line: 0>()
      8 import keras_hub
      9 
---> 10 gemma_lm_quantized = keras_hub.models.Gemma3CausalLM.from_preset("stored_gemma_int8")
     11 gemma_lm_quantized.generate("hello, what is your name?")

/content/keras_hub_repo/keras_hub/src/models/task.py in from_preset(cls, preset, load_weights, **kwargs)
    196         # images, audio).
    197         load_task_weights = "num_classes" not in kwargs
--> 198         return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
    199 
    200     def load_task_weights(self, filepath):

/content/keras_hub_repo/keras_hub/src/utils/preset_utils.py in load_task(self, cls, load_weights, load_task_weights, **kwargs)
    701             else:
    702                 jax_memory_cleanup(task.backbone)
--> 703             self._load_backbone_weights(task.backbone)
    704         return task
    705 

/content/keras_hub_repo/keras_hub/src/utils/preset_utils.py in _load_backbone_weights(self, backbone)
    754                 # Download the sharded weights.
    755                 _ = get_file(self.preset, sharded_filename)
--> 756         backbone.load_weights(filepath)
    757 
    758 

/content/keras_repo/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/content/keras_repo/keras/src/saving/saving_lib.py in _raise_loading_failure(error_msgs, warn_only)
    648         warnings.warn(msg)
    649     else:
--> 650         raise ValueError(msg)
    651 
    652 

ValueError: A total of 183 objects could not be loaded. Example error message for object <ReversibleEmbedding name=token_embedding, built=True>:

Layer 'token_embedding' expected 1 variables, but received 0 variables during loading. Expected: ['embeddings']

List of objects that could not be loaded:
[<ReversibleEmbedding name=token_embedding, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gat...

I get a similar error for other models as well. I've tested Gemma 2, Gemma 3, Llama 3, Llama 3.1, Llama 3.2, and more.

Comment From: pctablet505

There is a script at keras-hub/tools/quantize_checkpoints.py for quantizing the checkpoints; it works for text-only models but still fails for multimodal models like Gemma 3.

Comment From: pctablet505

@mattdangerw I tried saving and loading a quantized model built with Keras directly, and it works perfectly:

import keras
from keras import layers

# Build a small functional model with an embedding and a few dense layers.
inputs = layers.Input(shape=(None,))
x = layers.Embedding(input_dim=100000, output_dim=1000)(inputs)
x = layers.Dense(100)(x)
x = layers.Dense(100)(x)
x = layers.Dense(100)(x)
model = keras.Model(inputs=inputs, outputs=x)

# Quantize to int8 and save in the .keras format.
model.quantize(mode="int8")
model.save("quantized_model.keras")

# Reloading the quantized model works without errors.
quantized_model = keras.models.load_model("quantized_model.keras")
quantized_model.summary()

This indicates that the problem lies in the way we handle saving and loading for keras_hub models.
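
For completeness, the failure is the same when the quantized keras_hub model is saved with the plain .keras format instead of a preset directory (minimal sketch; the preset name is illustrative):

import keras
import keras_hub

# Quantize a keras_hub task model and save it in the standard .keras format.
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
gemma_lm.quantize("int8")
gemma_lm.save("gemma_int8.keras")

# Loading fails the same way as the preset round-trip above.
reloaded = keras.models.load_model("gemma_int8.keras")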