I define an EnsembleModel class that is constructed from a list of other Keras models.

from typing import Any, Callable, Dict, Iterable, Text

import keras
import tensorflow as tf

class EnsembleModel(keras.Model):

  def __init__(
    self,
    models: Iterable[keras.Model],
    reduce_fn: Callable = keras.ops.mean,
    **kwargs):

    super(EnsembleModel, self).__init__(**kwargs)

    self.models = models
    # self.model0 = models[0]
    # self.model1 = models[1]
    self.reduce_fn = reduce_fn

  @tf.function(input_signature=[input_signature])
  def call(
    self,
    input: Dict[Text, Any]) -> Any:

    all_outputs = [keras.ops.reshape(model(input), newshape=(-1,)) for model in self.models]
    output = self.reduce_fn(all_outputs, axis=0)

    return output

averaging_model = EnsembleModel(models=[model0, model1])

I then wish to export the ensemble model:

averaging_model.export("export/1/", input_signature=[input_signature])

But I get an error on the export:

AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g. tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly. See the information below:
        Function name = b'__inference_signature_wrapper___call___10899653'
        Captured Tensor = <ResourceHandle(name="10671455", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[  ]")>
        Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x7fd62d126990>
        Internal Tensor = Tensor("10899255:0", shape=(), dtype=resource)

If I explicitly assign the models to attributes in the constructor:

    self.model0 = models[0]
    self.model1 = models[1]

then the export works fine, even if I don't reference those attributes anywhere else. But I want an instance of the EnsembleModel class to support an arbitrary list of models. How can I ensure the models are "tracked" so that the export doesn't fail?
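A toy, pure-Python illustration of the failure mode may help. This is a simplification, not how Keras or TensorFlow actually implement tracking: the hypothetical `Trackable` class below records a child only when a trackable object is assigned directly to one of its attributes, so members held inside a plain list go unseen unless each one is also registered under its own attribute name.

```python
class Trackable:
    """Toy stand-in for an object whose children a checkpoint/export tracker can see."""

    def __init__(self):
        # Children registered through direct attribute assignment, keyed by name.
        object.__setattr__(self, "_children", {})

    def __setattr__(self, name, value):
        # Only a Trackable assigned directly to an attribute is recorded;
        # a plain list is not Trackable, so its elements are never inspected.
        if isinstance(value, Trackable):
            self._children[name] = value
        object.__setattr__(self, name, value)


class Leaf(Trackable):
    pass


class Ensemble(Trackable):
    def __init__(self, members, register_each=False):
        super().__init__()
        self.members = list(members)  # invisible to the toy tracker
        if register_each:
            # Same idea as assigning self.model0, self.model1, ...
            for i, member in enumerate(self.members):
                setattr(self, f"model_{i}", member)


plain = Ensemble([Leaf(), Leaf()])
registered = Ensemble([Leaf(), Leaf()], register_each=True)
print(sorted(plain._children))       # []
print(sorted(registered._children))  # ['model_0', 'model_1']
```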

Comment From: rlcauvin

Adding this kludge to the EnsembleModel constructor appears to "track" the models in the list and avoids the export error:

    # Register each model so that it is "tracked" for export.
    for i, model in enumerate(self.models):
      setattr(self, f"model_{i}", model)

Using __setattr__ dynamically assigns each model in the collection to an attribute of the EnsembleModel instance so that it is tracked.

Is there a better way?

Comment From: rlcauvin

Closing this issue, as I am satisfied with the kludge mentioned in my last comment. I still welcome comments on the kludge or alternative suggestions.


Comment From: edwardyehuang

@fchollet @james77777778

Is there a more effective approach to making a list of sublayers trackable in Keras 3? This issue does not occur in Keras 2 and may underlie numerous other weight-related issues in Keras 3.

@mehtamansi29 @rlcauvin Could you please reopen this issue temporarily?

Toggle ENABLE_SET_ATTR below to compare the model's tracked children with and without the setattr registration.

import tensorflow as tf
import keras

ENABLE_SET_ATTR = True

class SampleModel(keras.Model):
    def __init__(self, num_blocks):

        super().__init__()

        self.num_blocks = num_blocks


    def build(self, input_shape):

        self.blocks = []

        self.first_block = keras.layers.Dense(10, activation='relu', name="first.block")

        for i in range(self.num_blocks):

            block = keras.layers.Dense(10, activation='relu', name=f"block.{i}.dense")
            self.blocks.append(block)

            if ENABLE_SET_ATTR:
                setattr(self, f"block_{i}", block)

        self.last_block = keras.layers.Dense(10, activation='relu', name="last.block")

        super().build(input_shape)


    def call(self, inputs):

        x = inputs

        x = self.first_block(x)

        for block in self.blocks:
            x = block(x)

        x = self.last_block(x)

        return x


if __name__ == "__main__":

    model = SampleModel(num_blocks=3)
    model(tf.zeros((1, 224, 224, 3)))

    sub_items = tf.train.TrackableView(model).children(model).items()

    for item in sub_items:
        print(item)
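The script above prints only the model's direct children. A more targeted check could walk the whole TensorFlow object graph and report which list-held layers are unreachable. This is a sketch under assumptions: it needs the TensorFlow backend and `tf.train.TrackableView(...).descendants()`, and `untracked_sublayers` and `Blocks` are hypothetical names. What it reports for the list-only case depends on the Keras version, so only the setattr-registered case is expected to come back empty.

```python
import os

# Assumption: the TensorFlow backend, since TrackableView inspects TF's object graph.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras
import tensorflow as tf


def untracked_sublayers(model, layers):
    """Hypothetical helper: return the layers in `layers` that cannot be reached
    from `model` by walking TensorFlow's trackable object graph."""
    reachable = {id(obj) for obj in tf.train.TrackableView(model).descendants()}
    return [layer for layer in layers if id(layer) not in reachable]


class Blocks(keras.Model):
    """Dense layers held in a Python list, optionally also registered
    one by one with setattr (the workaround from this thread)."""

    def __init__(self, num_blocks, register_each):
        super().__init__()
        self.blocks = [keras.layers.Dense(4) for _ in range(num_blocks)]
        if register_each:
            for i, block in enumerate(self.blocks):
                setattr(self, f"block_{i}", block)

    def call(self, x):
        for block in self.blocks:
            x = block(x)
        return x


results = {}
for register_each in (False, True):
    model = Blocks(num_blocks=3, register_each=register_each)
    model(tf.zeros((1, 4)))  # build the layers and create their variables
    results[register_each] = len(untracked_sublayers(model, model.blocks))
    print(f"register_each={register_each}: {results[register_each]} untracked")
```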