Description

Hi, I am trying to build a custom Keras model by subclassing keras.Model. It seems that if I add custom kwargs to the __init__ function, this breaks the mechanism that automatically calls the .get_config() of internal layers (both official Keras layers and custom layers).

EDIT

Going deeper, I found the code block that skips get_config(): the functional_like_constructor check used by the Functional class (/opt/conda/lib/python3.10/site-packages/keras/src/models/functional.py).

[screenshot of the functional_like_constructor check in functional.py]

At this point, reading the comment in the return statement, I am asking myself whether I am attempting something that conflicts with Keras principles. Anyway: I would like to preserve the functional approach while still being able to pass custom kwargs to the __init__ of my subclassed model. I am not creating a model different from a purely functional one: I just have custom kwargs in the __init__. Is there a way to preserve the automatic get_config() logic while subclassing with custom kwargs, or a clean pattern to follow?
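For anyone who wants to see the gist of the check without reading Keras internals: below is a simplified paraphrase (an assumption based on the source, not the actual Keras code) of what functional_like_constructor does, using a plain base class in place of keras.Model. Any extra parameter in the subclass __init__ makes the signatures differ, so the functional serialization path is skipped.

```python
import inspect

def looks_functional_like(cls, base):
    # Simplified paraphrase (an assumption) of Keras's
    # functional_like_constructor check: the subclass __init__ must expose
    # exactly the same parameters as the base __init__ for Keras to reuse
    # the functional serialization path.
    init_args = list(inspect.signature(cls.__init__).parameters)
    base_args = list(inspect.signature(base.__init__).parameters)
    return init_args == base_args

class Base:
    def __init__(self, *args, **kwargs):
        pass

class SameSignature(Base):
    # Same parameters as Base -> treated as "functional-like".
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

class ExtraKwarg(Base):
    # One extra parameter -> signatures no longer match.
    def __init__(self, extra=1, **kwargs):
        super().__init__(**kwargs)

print(looks_functional_like(SameSignature, Base))  # True
print(looks_functional_like(ExtraKwarg, Base))     # False
```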

Disruption

This breaks, for example, model saving/loading, since the parameters needed to re-create custom layers are not preserved and passed back during model loading.

How to reproduce

If we keep the original __init__ signature, the layers' get_config() is included:

class OriginSignature(keras.Model):

    def __init__(
        self,
        **kwargs
        ):
        inputs = keras.Input((1,))
        outputs = keras.layers.Dense(1)(inputs)
        d = dict(
            inputs=inputs, outputs=outputs
        )
        super().__init__(**d, **kwargs)

origin_signature_model = OriginSignature()
pprint(origin_signature_model.get_config())
{'input_layers': [['input_layer_94', 0, 0]],
 'layers': [{'class_name': 'InputLayer',
             'config': {'batch_shape': (None, 1),
                        'dtype': 'float32',
                        'name': 'input_layer_94',
                        'ragged': False,
                        'sparse': False},
             'inbound_nodes': [],
             'module': 'keras.layers',
             'name': 'input_layer_94',
             'registered_name': None},
            {'build_config': {'input_shape': (None, 1)},
             'class_name': 'Dense',
             'config': {'activation': 'linear',
                        'bias_constraint': None,
                        'bias_initializer': {'class_name': 'Zeros',
                                             'config': {},
                                             'module': 'keras.initializers',
                                             'registered_name': None},
                        'bias_regularizer': None,
                        'dtype': {'class_name': 'DTypePolicy',
                                  'config': {'name': 'float32'},
                                  'module': 'keras',
                                  'registered_name': None},
                        'kernel_constraint': None,
                        'kernel_initializer': {'class_name': 'GlorotUniform',
                                               'config': {'seed': None},
                                               'module': 'keras.initializers',
                                               'registered_name': None},
                        'kernel_regularizer': None,
                        'name': 'dense_105',
                        'trainable': True,
                        'units': 1,
                        'use_bias': True},
             'inbound_nodes': [{'args': ({'class_name': '__keras_tensor__',
                                          'config': {'dtype': 'float32',
                                                     'keras_history': ['input_layer_94',
                                                                       0,
                                                                       0],
                                                     'shape': (None, 1)}},),
                                'kwargs': {}}],
             'module': 'keras.layers',
             'name': 'dense_105',
             'registered_name': None}],
 'name': 'origin_signature_6',
 'output_layers': [['dense_105', 0, 0]],
 'trainable': True}

while if we add a custom parameter this_breaks_get_config_mechanism to the __init__, we lose this mechanism:

class ModifiedSignature(keras.Model):
    def __init__(
        self,
        this_breaks_get_config_mechanism=1,
        **kwargs
        ):
        self.this_breaks_get_config_mechanism = this_breaks_get_config_mechanism
        inputs = keras.Input((1,))
        outputs = keras.layers.Dense(1)(inputs)
        d = dict(
            inputs=inputs, outputs=outputs
        )
        super().__init__(**d, **kwargs)
modified_init_signature_model = ModifiedSignature()
pprint(modified_init_signature_model.get_config())
{'dtype': {'class_name': 'DTypePolicy',
           'config': {'name': 'float32'},
           'module': 'keras',
           'registered_name': None},
 'name': 'modified_signature_6',
 'trainable': True}

System info

keras 3.9.2

Comment From: sonali-kumari1

Hi @smartArancina - Thanks for reporting this! I have tested the code with the latest version of Keras (3.11.1) in this gist. This issue arises from the functional_like_constructor check, which compares your model's __init__ arguments with those of the functional model (init_args == functional_init_args). Since you added this_breaks_get_config_mechanism=1 to your model's __init__, functional_like_constructor returns False, causing the behavior you reported. Please refer to this guide on Customizing Saving and Serialization.