Whenever we save a keras-hub model after quantization, we are unable to load the quantized model back. I've tried the from_preset() method for that model as well as keras.models.load_model(); neither works.
I've attached a notebook: https://colab.research.google.com/gist/pctablet505/b5ef8ab36dceb58527e992b571aefb70/keras-quantized-model-not-loading.ipynb
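For reference, the failure can be reproduced with roughly the following steps (a minimal sketch of what the notebook does; the source preset name is illustrative, and the exact save call in the notebook may differ):

import keras_hub

# Load a pretrained Gemma3 causal LM (preset name illustrative).
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")

# Quantize the weights in place to int8 and save the result as a local preset.
gemma_lm.quantize("int8")
gemma_lm.save_to_preset("stored_gemma_int8")

# Reloading the saved preset raises the ValueError shown in the traceback below.
gemma_lm_quantized = keras_hub.models.Gemma3CausalLM.from_preset("stored_gemma_int8")
gemma_lm_quantized.generate("hello, what is your name?")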
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[<ipython-input-5-3241078411>](https://localhost:8080/#) in <cell line: 0>()
8 import keras_hub
9
---> 10 gemma_lm_quantized = keras_hub.models.Gemma3CausalLM.from_preset("stored_gemma_int8")
11 gemma_lm_quantized.generate("hello, what is your name?")
/content/keras_hub_repo/keras_hub/src/models/task.py in from_preset(cls, preset, load_weights, **kwargs)
196 # images, audio).
197 load_task_weights = "num_classes" not in kwargs
--> 198 return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
199
200 def load_task_weights(self, filepath):
/content/keras_hub_repo/keras_hub/src/utils/preset_utils.py in load_task(self, cls, load_weights, load_task_weights, **kwargs)
701 else:
702 jax_memory_cleanup(task.backbone)
--> 703 self._load_backbone_weights(task.backbone)
704 return task
705
/content/keras_hub_repo/keras_hub/src/utils/preset_utils.py in _load_backbone_weights(self, backbone)
754 # Download the sharded weights.
755 _ = get_file(self.preset, sharded_filename)
--> 756 backbone.load_weights(filepath)
757
758
/content/keras_repo/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
120 # To get the full stack trace, call:
121 # `keras.config.disable_traceback_filtering()`
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
/content/keras_repo/keras/src/saving/saving_lib.py in _raise_loading_failure(error_msgs, warn_only)
648 warnings.warn(msg)
649 else:
--> 650 raise ValueError(msg)
651
652
ValueError: A total of 183 objects could not be loaded. Example error message for object <ReversibleEmbedding name=token_embedding, built=True>:
Layer 'token_embedding' expected 1 variables, but received 0 variables during loading. Expected: ['embeddings']
List of objects that could not be loaded:
[<ReversibleEmbedding name=token_embedding, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gating_2, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <EinsumDense name=ffw_linear, built=True>, <EinsumDense name=ffw_gating, built=True>, <EinsumDense name=ffw_gat...
I get a similar error for other models as well. I've tested on gemma2, gemma3, llama3, llama3.1, llama3.2, and more.
Comment From: pctablet505
There is a script in keras-hub/tools/quantize_checkpoints.py for quantizing the checkpoints. It works for text-only models, but it still fails for multimodal models like gemma3.
Comment From: pctablet505
@mattdangerw I tried save and load for a quantized model created using Keras directly, and it works perfectly:
import keras
from keras import layers

# Build a small functional model with an embedding followed by a few dense layers.
inputs = layers.Input(shape=(None,))
x = layers.Embedding(input_dim=100000, output_dim=1000)(inputs)
x = layers.Dense(100)(x)
x = layers.Dense(100)(x)
x = layers.Dense(100)(x)
model = keras.Model(inputs=inputs, outputs=x)

# Quantize to int8, save to disk, and reload: this round trip succeeds.
model.quantize(mode="int8")
model.save("quantized_model.keras")
quantized_model = keras.models.load_model("quantized_model.keras")
quantized_model.summary()
This indicates that the problem lies in how keras_hub handles saving and loading for its models.
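For comparison, the analogous keras-hub flow fails on reload whether a preset directory or a .keras file is used. A sketch, reusing the illustrative preset name from the reproduction at the top:

import keras
import keras_hub

# Same illustrative setup as in the reproduction above.
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
gemma_lm.quantize("int8")

# Saving to a .keras file succeeds, but reloading it with
# keras.models.load_model() fails with the same
# "expected 1 variables, but received 0 variables" ValueError.
gemma_lm.save("gemma_int8.keras")
reloaded = keras.models.load_model("gemma_int8.keras")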