I define an `EnsembleModel` class that is constructed from a list of other Keras models:
```python
from typing import Any, Callable, Dict, Iterable, Text

class EnsembleModel(keras.Model):

    def __init__(
        self,
        models: Iterable[keras.Model],
        reduce_fn: Callable = keras.ops.mean,
        **kwargs):
        super().__init__(**kwargs)
        self.models = models
        # self.model0 = models[0]
        # self.model1 = models[1]
        self.reduce_fn = reduce_fn

    @tf.function(input_signature=[input_signature])
    def call(
        self,
        input: Dict[Text, Any]) -> Any:
        all_outputs = [keras.ops.reshape(model(input), newshape=(-1,)) for model in self.models]
        output = self.reduce_fn(all_outputs, axis=0)
        return output

averaging_model = EnsembleModel(models=[model0, model1])
```
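The reduction in `call` can be illustrated without Keras: each model's output is flattened to shape `(batch,)`, the outputs are stacked along a new model axis, and `reduce_fn` collapses that axis. A minimal NumPy sketch (the sample values are made up for illustration):

```python
import numpy as np

# Per-model outputs, each already flattened to shape (batch,), as in call():
all_outputs = [np.array([0.2, 0.4, 0.6]), np.array([0.4, 0.6, 0.8])]

# reduce_fn defaults to a mean over the model axis (axis 0),
# giving one averaged prediction per batch element:
output = np.mean(np.stack(all_outputs, axis=0), axis=0)
# output is approximately [0.3, 0.5, 0.7]
```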
I then wish to export the ensemble model:

```python
averaging_model.export("export/1/", input_signature=[input_signature])
```
But I get an error on the export:

```
AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g.
tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or
assigned to an attribute of the main object directly. See the information below:
Function name = b'__inference_signature_wrapper___call___10899653'
Captured Tensor = <ResourceHandle(name="10671455", device="/job:localhost/replica:0/task:0/device:CPU:0",
container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[ ]")>
Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at
0x7fd62d126990>
Internal Tensor = Tensor("10899255:0", shape=(), dtype=resource)
```
If I explicitly assign the models to attributes in the constructor:

```python
self.model0 = models[0]
self.model1 = models[1]
```
it works fine (even if I don't reference those attributes anywhere else). But I want an instance of the `EnsembleModel` class to support an arbitrary list of models. How can I ensure the models are "tracked" so that I don't get an error on export?
Comment From: rlcauvin
Adding this kludge to the `EnsembleModel` constructor seems to have worked to "track" the models in the list and avoid export errors:

```python
# Register each model so that it is "tracked" for export.
for i, model in enumerate(self.models):
    self.__setattr__(f"model_{i}", model)
```
Using `__setattr__` dynamically assigns each model in the collection to an attribute of the `EnsembleModel` instance so that it is tracked.
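If the same workaround is needed in more than one model class, it can be packaged as a small helper. This is a hypothetical utility (`track_all` is not a Keras API), shown here with a plain namespace standing in for a model instance:

```python
from types import SimpleNamespace

def track_all(container, layers, prefix="tracked_layer"):
    """Assign each element of `layers` to a uniquely named attribute of
    `container`, so attribute-based trackers (such as exporters) can
    discover it. Same idea as the __setattr__ loop above."""
    for i, layer in enumerate(layers):
        setattr(container, f"{prefix}_{i}", layer)
    return list(layers)

# Example usage with a stand-in container; in a real model you would call
# self.models = track_all(self, models, prefix="model") in __init__.
holder = SimpleNamespace()
models = track_all(holder, ["m0", "m1"], prefix="model")
```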
Is there a better way?
Comment From: rlcauvin
Closing this issue, as I am satisfied with the kludge mentioned in my last comment. But I welcome any comments on the kludge or any alternate suggestions.
Comment From: edwardyehuang
@fchollet @james77777778
Is there a more effective approach to making a list of sublayers trackable in Keras 3? This issue does not occur in Keras 2 and may underlie numerous other weight-related issues in Keras 3.
@mehtamansi29 @rlcauvin Could you please reopen this issue temporarily?
Adjust `ENABLE_SET_ATTR` below to see the effect.
```python
import tensorflow as tf
import keras

ENABLE_SET_ATTR = True

class SampleModel(keras.Model):

    def __init__(self, num_blocks):
        super().__init__()
        self.num_blocks = num_blocks

    def build(self, input_shape):
        self.blocks = []
        self.first_block = keras.layers.Dense(10, activation='relu', name="first.block")
        for i in range(self.num_blocks):
            block = keras.layers.Dense(10, activation='relu', name=f"block.{i}.dense")
            self.blocks.append(block)
            if ENABLE_SET_ATTR:
                setattr(self, f"block_{i}", block)
        self.last_block = keras.layers.Dense(10, activation='relu', name="last.block")
        super().build(input_shape)

    def call(self, inputs):
        x = inputs
        x = self.first_block(x)
        for block in self.blocks:
            x = block(x)
        x = self.last_block(x)
        return x

if __name__ == "__main__":
    model = SampleModel(num_blocks=3)
    model(tf.zeros((1, 224, 224, 3)))

    sub_items = tf.train.TrackableView(model).children(model).items()

    for item in sub_items:
        print(item)
```
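The effect the flag exposes can be modeled without TensorFlow at all. Below is a toy sketch (explicitly not Keras's real tracking code) of an attribute-based tracker: mutating an already-assigned plain list never goes through `__setattr__`, so only explicit attribute assignments are observed.

```python
class ToyTracker:
    """Toy model of attribute-based tracking; not Keras's actual mechanism."""

    def __init__(self):
        # Bypass our own __setattr__ while setting up internal state.
        object.__setattr__(self, "tracked_names", [])

    def __setattr__(self, name, value):
        # Every attribute assignment is observed and recorded here.
        self.tracked_names.append(name)
        object.__setattr__(self, name, value)

holder = ToyTracker()
holder.blocks = []                     # the list attribute itself is observed once
holder.blocks.append("dense_0")        # list mutation bypasses __setattr__ entirely
setattr(holder, "block_0", "dense_0")  # explicit assignment is observed
```

In this toy model, only `"blocks"` and `"block_0"` end up recorded; the `append` is invisible to the tracker, which mirrors why the `setattr` loop in `build` makes a difference.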