import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
from flax import nnx
from absl import app
import jax
import keras
import numpy as np
import tensorflow as tf
# Your NNX Modules (as before)
class SimpleMLP(nnx.Module):
def __init__(self, din, dhidden, dout, *, rngs):
self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)
self.linear2 = nnx.Linear(dhidden, dout, rngs=rngs)
def __call__(self, x):
x = self.linear1(x)
x = nnx.relu(x)
x = self.linear2(x)
return x
def main(_):
input_dim = 10
# 1. Create a Linen-compatible class using the bridge
# ToLinen for StatefulMLP
LinenMLP = nnx.bridge.ToLinen(
SimpleMLP(input_dim, 32, 1, rngs=jax.random.PRNGKey(0)),
)
# 2. Use this Linen class with keras.layers.FlaxLayer
# FlaxLayer can take an instance or the class and kwargs
flax_layer = keras.layers.FlaxLayer(LinenMLP())
# 3. Build the Keras model
inputs = keras.Input(shape=(input_dim,), dtype=tf.float32)
# FlaxLayer's call method will pass the training argument
outputs = flax_layer(inputs)
model = keras.Model(inputs, outputs)
# Dummy Data
X_train = np.random.randn(100, input_dim).astype(np.float32)
y_train = np.random.randn(100, 1).astype(np.float32)
# Compile and Fit
model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mse")
print("Training with bridged NNX module...")
model.fit(X_train, y_train, epochs=3, batch_size=8, verbose=1)
print("Training finished.")
# Check variables
print("\nKeras Model Trainable Variables:")
for var in model.trainable_variables:
print(f"{var.name}: {var.shape}")
print("\nKeras Model Non-Trainable Variables (e.g., BatchNorm):")
for var in model.non_trainable_variables:
print(f"{var.name}: {var.shape}")
# Exporting
saved_model_path = "/tmp/nnx_bridged_saved_model"
print(f"\nSaving model to {saved_model_path}...")
model.save(saved_model_path)
print("Model saved.")
# Load and test
restored_model = tf.saved_model.load(saved_model_path)
print("Restored model predictions:", restored_model(X_train[:2]))
if __name__ == "__main__":
app.run(main)
Running this I get
AttributeError: 'jax.jaxlib._jax.ArrayImpl' object has no attribute 'params' It's this line self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)