I can't get an NNX module to work with `keras.layers.FlaxLayer`. Plain Flax (Linen) modules work fine, but when I create an `nnx.Module` and wrap it with `nnx.bridge.ToLinen`, I get the error shown at the bottom.
class MyFlax(nn.Module):
    """Minimal Linen module: one Dense layer producing a single feature."""

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=1)(x)


# A plain Linen module instance is handed straight to FlaxLayer.
l = MyFlax()
f = keras.layers.FlaxLayer(l)
f(jnp.ones(1))
class MyNnx(nnx.Module):
    """NNX module wrapping a 1 -> 1 ``nnx.Linear`` layer."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.l = nnx.Linear(1, 1, rngs=rngs)

    def __call__(self, **kwargs):
        # Debug print of whatever keyword arguments the caller passes in.
        print(kwargs)
        # NOTE(review): when 'inputs' is not among the keyword arguments this
        # falls through and implicitly returns None.
        if 'inputs' in kwargs:
            return self.l(kwargs['inputs'])


# Calling the NNX module directly, without ToLinen/Keras involved.
MyNnx(rngs=nnx.Rngs(0))(inputs=jnp.ones(1))

# Wrapping an *instance* of the module with ToLinen, then using it in Keras.
# NOTE(review): presumably ToLinen expects the module class and calls it to
# construct the module; given an instance, that call would hit __call__ above
# (without 'inputs') and yield None, consistent with the TypeError below.
l = nnx.bridge.ToLinen(MyNnx(rngs=nnx.Rngs(0)))
f = keras.layers.FlaxLayer(l)
f(inputs=jnp.ones(1))
Cell In[101], line 1
----> 1 f(inputs=jnp.ones(1))
File ~/projects/joe/.venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback.keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
[... skipping hidden 9 frame]
File ~/projects/joe/.venv/lib/python3.11/site-packages/flax/nnx/bridge/wrappers.py:263, in ToLinen.__call__(self, *args, **kwargs)
260 # TODO: add lazy_init here in case there's an ToNNX submodule under module.
261 # update linen variables before call module to save initial state
262 self._update_variables(module)
--> 263 out = module(*args, **kwargs)
264 return out
266 # create state
TypeError: 'NoneType' object is not callable
Comment From: sonali-kumari1
Hi @joetoth -
Thanks for reporting this issue. I have tested your code in this gist and encountered the same `TypeError: 'NoneType' object is not callable`. We will look into this and update you.
Comment From: hertschuh
@joetoth ,
The issue is that you're passing an `nnx.Module` instance to `ToLinen`; instead, you should pass the class itself. Here is how I fixed it:
class MyFlax(nn.Module):
    """Minimal Linen module: one Dense layer producing a single feature."""

    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=1)(x)


# A Linen module instance is passed straight to FlaxLayer.
l = MyFlax()
f = keras.layers.FlaxLayer(l)
f(jnp.ones(1))
class MyNnx(nnx.Module):
    """NNX module wrapping a 1 -> 1 ``nnx.Linear`` layer."""

    def __init__(self, *, rngs: nnx.Rngs):
        self.l = nnx.Linear(1, 1, rngs=rngs)

    def __call__(self, x):
        # Takes the input positionally, unlike the **kwargs-only version.
        return self.l(x)


# Sanity check: call the NNX module directly.
MyNnx(rngs=nnx.Rngs(0))(jnp.ones(1))

# Pass the NNX Module *class* (not an instance) to ToLinen.
l = nnx.bridge.ToLinen(MyNnx)
f = keras.layers.FlaxLayer(l)
f(inputs=jnp.ones(1))
Note that we're working on a native integration of NNX in Keras. Once that's ready, mixing NNX modules in Keras will be a lot easier.