I can't seem to get a NNX module to work with FlaxLayer. Normal flax works fine, but when I create a nnx.Module and called nnx.bridge.ToLinen I get the error at the botttom.
class MyFlax(nn.Module): @nn.compact def call(self, x): return nn.Dense(features=1)(x) l = MyFlax() f = keras.layers.FlaxLayer(l) f(jnp.ones(1))
class MyNnx(nnx.Module): def init(self, *, rngs: nnx.Rngs): self.l = nnx.Linear(1, 1, rngs=rngs)
def call(self, **kwargs): print( kwargs) if 'inputs' in kwargs: return self.l(kwargs['inputs'])
MyNnx(rngs=nnx.Rngs(0))(inputs=jnp.ones(1))
l = nnx.bridge.ToLinen(MyNnx(rngs=nnx.Rngs(0))) f = keras.layers.FlaxLayer(l)
f(inputs=jnp.ones(1))
Cell In[101], line 1 ----> 1 f(inputs=jnp.ones(1))
File ~/projects/joe/.venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
[... skipping hidden 9 frame]
File ~/projects/joe/.venv/lib/python3.11/site-packages/flax/nnx/bridge/wrappers.py:263, in ToLinen.call(self, args, kwargs)
260 # TODO: add lazy_init here in case there's an ToNNX
submodule under module
.
261 # update linen variables before call module to save initial state
262 self._update_variables(module)
--> 263 out = module(args, **kwargs)
264 return out
266 # create state
TypeError: 'NoneType' object is not callable