Description
Hi,
I am trying to build a custom Keras model by subclassing keras.Model.
It seems that if I add custom kwargs to the __init__ signature, this breaks the mechanism that automatically calls get_config() on the internal layers (both official Keras layers and custom layers).
EDIT
Going deeper, I found the code block that skips get_config() in the Functional class:
/opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py: functional_like_constructor
At this point, reading the comment in the 'return', I am asking myself whether I am trying something that doesn't agree with Keras principles... Anyway: I would like to preserve the 'Functional' approach while still passing custom kwargs to the __init__ of my subclassed model. I am not creating a model different from a pure functional one: I am just adding custom kwargs to __init__. Is there any way to preserve the automatic get_config logic while subclassing with custom kwargs, or a clean way to achieve this?
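To see why the extra kwarg matters, the check can be sketched in plain Python. This is a loose, illustrative mimic of what a signature comparison like functional_like_constructor does (the helper name and classes here are hypothetical, not Keras' actual code): a subclass only counts as "functional-like" if its __init__ declares no named parameters beyond what the functional constructor expects.

```python
import inspect

def extra_init_args(cls):
    # Collect the named parameters of __init__ (excluding self).
    # A signature check like Keras' compares this list against the
    # functional constructor's expected arguments.
    spec = inspect.getfullargspec(cls.__init__)
    return [a for a in spec.args if a != "self"]

class LikeFunctional:
    # Only **kwargs: the signature matches the functional pattern.
    def __init__(self, **kwargs):
        pass

class WithCustomKwarg:
    # One extra named parameter: the signatures no longer match.
    def __init__(self, this_breaks_get_config_mechanism=1, **kwargs):
        pass

print(extra_init_args(LikeFunctional))   # -> []
print(extra_init_args(WithCustomKwarg))  # -> ['this_breaks_get_config_mechanism']
```

Even though WithCustomKwarg forwards **kwargs exactly like LikeFunctional does, the extra named parameter is enough to make the two signatures differ.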
Disruption
This breaks, for example, model saving/loading, because the parameters needed to re-create the custom layers are not preserved and passed back during model loading.
How to reproduce
If we keep the original __init__ signature, the config includes the internal layers' get_config():
import keras
from pprint import pprint

class OriginSignature(keras.Model):
    def __init__(self, **kwargs):
        inputs = keras.Input((1,))
        outputs = keras.layers.Dense(1)(inputs)
        d = dict(inputs=inputs, outputs=outputs)
        super().__init__(**d, **kwargs)

origin_signature_model = OriginSignature()
pprint(origin_signature_model.get_config())
{'input_layers': [['input_layer_94', 0, 0]],
'layers': [{'class_name': 'InputLayer',
'config': {'batch_shape': (None, 1),
'dtype': 'float32',
'name': 'input_layer_94',
'ragged': False,
'sparse': False},
'inbound_nodes': [],
'module': 'keras.layers',
'name': 'input_layer_94',
'registered_name': None},
{'build_config': {'input_shape': (None, 1)},
'class_name': 'Dense',
'config': {'activation': 'linear',
'bias_constraint': None,
'bias_initializer': {'class_name': 'Zeros',
'config': {},
'module': 'keras.initializers',
'registered_name': None},
'bias_regularizer': None,
'dtype': {'class_name': 'DTypePolicy',
'config': {'name': 'float32'},
'module': 'keras',
'registered_name': None},
'kernel_constraint': None,
'kernel_initializer': {'class_name': 'GlorotUniform',
'config': {'seed': None},
'module': 'keras.initializers',
'registered_name': None},
'kernel_regularizer': None,
'name': 'dense_105',
'trainable': True,
'units': 1,
'use_bias': True},
'inbound_nodes': [{'args': ({'class_name': '__keras_tensor__',
'config': {'dtype': 'float32',
'keras_history': ['input_layer_94',
0,
0],
'shape': (None, 1)}},),
'kwargs': {}}],
'module': 'keras.layers',
'name': 'dense_105',
'registered_name': None}],
'name': 'origin_signature_6',
'output_layers': [['dense_105', 0, 0]],
'trainable': True}
whereas if we add a custom this_breaks_get_config_mechanism parameter to __init__, we lose this mechanism:
class ModifiedSignature(keras.Model):
    def __init__(self, this_breaks_get_config_mechanism=1, **kwargs):
        self.this_breaks_get_config_mechanism = this_breaks_get_config_mechanism
        inputs = keras.Input((1,))
        outputs = keras.layers.Dense(1)(inputs)
        d = dict(inputs=inputs, outputs=outputs)
        super().__init__(**d, **kwargs)

modified_init_signature_model = ModifiedSignature()
pprint(modified_init_signature_model.get_config())
{'dtype': {'class_name': 'DTypePolicy',
'config': {'name': 'float32'},
'module': 'keras',
'registered_name': None},
'name': 'modified_signature_6',
'trainable': True}
System info
keras 3.9.2
Comment From: sonali-kumari1
Hi @smartArancina -
Thanks for reporting this!
I have tested the code with the latest version of Keras (3.11.1) in this gist. This issue likely arises from the functional_like_constructor check, which compares your model's __init__ arguments with a functional model's (init_args == functional_init_args). Since you have added this_breaks_get_config_mechanism=1 to your model's __init__, functional_like_constructor returns False, causing the behavior you reported. Please refer to this guide on Customizing Saving and Serialization.