import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
from flax import nnx
from absl import app
import jax
import keras
import numpy as np
import tensorflow as tf


# Your NNX Modules (as before)
class SimpleMLP(nnx.Module):

  def __init__(self, din, dhidden, dout, *, rngs):
    self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = nnx.relu(x)
    x = self.linear2(x)
    return x



def main(_):
  input_dim = 10
  # 1. Create a Linen-compatible class using the bridge
  # ToLinen for StatefulMLP
  LinenMLP = nnx.bridge.ToLinen(
      SimpleMLP(input_dim, 32, 1, rngs=jax.random.PRNGKey(0)),
  )

  # 2. Use this Linen class with keras.layers.FlaxLayer
  # FlaxLayer can take an instance or the class and kwargs
  flax_layer = keras.layers.FlaxLayer(LinenMLP())

  # 3. Build the Keras model
  inputs = keras.Input(shape=(input_dim,), dtype=tf.float32)
  # FlaxLayer's call method will pass the training argument
  outputs = flax_layer(inputs)
  model = keras.Model(inputs, outputs)

  # Dummy Data
  X_train = np.random.randn(100, input_dim).astype(np.float32)
  y_train = np.random.randn(100, 1).astype(np.float32)

  # Compile and Fit
  model.compile(optimizer=keras.optimizers.Adam(0.01), loss="mse")
  print("Training with bridged NNX module...")
  model.fit(X_train, y_train, epochs=3, batch_size=8, verbose=1)
  print("Training finished.")

  # Check variables
  print("\nKeras Model Trainable Variables:")
  for var in model.trainable_variables:
    print(f"{var.name}: {var.shape}")

  print("\nKeras Model Non-Trainable Variables (e.g., BatchNorm):")
  for var in model.non_trainable_variables:
    print(f"{var.name}: {var.shape}")

  # Exporting
  saved_model_path = "/tmp/nnx_bridged_saved_model"
  print(f"\nSaving model to {saved_model_path}...")
  model.save(saved_model_path)
  print("Model saved.")

  # Load and test
  restored_model = tf.saved_model.load(saved_model_path)
  print("Restored model predictions:", restored_model(X_train[:2]))


if __name__ == "__main__":
  app.run(main)

Running this I get

AttributeError: 'jax.jaxlib._jax.ArrayImpl' object has no attribute 'params' It's this line self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)