Due to some recent changes to the torch backend (I think we enabled compilation by default?), a simple torch fit() call with dict inputs will fail.
https://colab.research.google.com/gist/mattdangerw/5f966e9a35667396819efe13d0fa6980/torch-error.ipynb
The direct error is "Could not infer dtype of NoneType", but the stack trace makes it look like the dm-tree library is involved.
Comment From: haifeng-jin
@kiukchung Would you mind taking a look? Thanks!
Comment From: kiukchung
Yes, we’ve recently enabled dynamo compilation for the torch backend (previously we’d always run in eager mode). When jit_compile=True, the PyTorch backend will attempt to compile your model. To run in eager mode, set jit_compile=False or jit_compile="auto".
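For example (a minimal sketch; model is assumed to be an already-built Keras model):

```python
# jit_compile="auto" picks the natural mode per backend (eager for torch,
# compiled for TF/JAX); jit_compile=False always disables compilation.
model.compile(optimizer="adam", loss="mse", jit_compile="auto")

# Forcing compilation is what triggers the error above on the torch backend:
# model.compile(optimizer="adam", loss="mse", jit_compile=True)
```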
Comment From: mattdangerw
Thanks! I still think this would qualify as a bug though.
This code snippet is a basic use case: it feeds dictionary input into a model.fit() call, and it compiles and runs without issues on the other backends.
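Roughly along these lines (a hypothetical sketch with made-up layer names and shapes; the actual notebook may differ):

```python
import numpy as np
import keras

# Two named inputs fed via a dict: a pattern that compiles fine on TF/JAX
# but currently errors on the torch backend when jit_compile=True.
inp_a = keras.Input(shape=(4,), name="a")
inp_b = keras.Input(shape=(4,), name="b")
out = keras.layers.Dense(1)(keras.layers.Concatenate()([inp_a, inp_b]))
model = keras.Model({"a": inp_a, "b": inp_b}, out)

model.compile(optimizer="adam", loss="mse", jit_compile=True)
x = {"a": np.random.rand(32, 4), "b": np.random.rand(32, 4)}
y = np.random.rand(32, 1)
model.fit(x, y, epochs=1)
```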
I suspect this has to do with the dm-tree issues you were mentioning at some point, but we should figure out a fix of some sort.
(To be clear, the issue is not that compilation is enabled; the issue is how it fails in the colab above.)
Comment From: kiukchung
I think the basic use case should be jit_compile="auto", since it defaults to the most natural run mode for each framework (jit-compiled for TF and JAX, eager for torch).
The dm_tree issue isn’t exactly what is causing the exception; rather, it causes graph breaks (which do not manifest as Python errors) and therefore leads to poor performance, since the code inside a graph break runs in Python-land and loses out on compiler optimization opportunities.
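For illustration, a generic example of what causes a graph break under dynamo (not the dm_tree case itself; just data-dependent Python control flow):

```python
import torch

def f(x):
    # Branching on a tensor's value forces dynamo to break the graph and
    # fall back to eager Python to evaluate the condition.
    if x.sum() > 0:
        return x + 1
    return x - 1

torch.compile(f)(torch.randn(8))                  # runs, but with a silent graph break
torch.compile(f, fullgraph=True)(torch.randn(8))  # raises, since graph breaks are forbidden
```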
Comment From: mattdangerw
Do you understand what is causing the error? At a high level this looks like it should be valid Keras code; do we expect this not to compile correctly on torch?
(I can dig in more too, just trying to grok whether this is "works as intended" or not.)
Comment From: kiukchung
I’m on PTO, so I’d have to take a deeper look when I get back next Monday, but from the stack trace this looks like a tracing error. Given that this is simple flagship usage, we should probably fix it and add it as a unit test. In general, though, it’s near impossible to guarantee that all Keras usages will be dynamo compatible (or that one will see perf improvements from dynamo out of the box).
With dynamo, tracing can be a bit contextual: depending on the combination of Keras APIs called (in a specific order) and the type, dtype, and shape of the inputs and intermediates, it can lead to tracing errors, guard failures (and therefore graph breaks), or legitimate graph breaks. There isn’t a great way to absolutely, 100% guarantee that all Keras usages will be dynamo compatible. That’s why with jit_compile="auto" (the default for Keras trainers) we run eager on the PyTorch backend, and only run with dynamo if the user wants to see whether their particular usage speeds up with it.
This is in part due to the way dynamo is designed. PyTorch’s intention is basically that you follow this path (sketched in code after the list):
- run with eager for maximum compatibility (basically everything works pretty well on eager with ok performance)
- try dynamo. If it works out of the box, great.
- Otherwise, put in the effort to make your code compatible with dynamo. Presumably the more effort you put into removing graph breaks and such, the better perf you’ll get. But ultimately it’s up to you to decide how much effort you want to put into it.
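A rough sketch of that progression (model, x, and y are assumed to be an nn.Module and tensors defined elsewhere):

```python
import torch

def forward(model, x, y):
    # Forward pass plus loss; backward is called outside the compiled region.
    return torch.nn.functional.mse_loss(model(x), y)

# 1) Eager: maximum compatibility, reasonable performance.
loss = forward(model, x, y)
loss.backward()

# 2) Try dynamo: wrap the forward pass and see if it works out of the box.
compiled_forward = torch.compile(forward)
loss = compiled_forward(model, x, y)
loss.backward()

# 3) Otherwise, iterate: hunt down graph breaks and guard failures
#    (e.g. with torch._dynamo.explain) and refactor until the perf
#    gain justifies the effort.
```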
This is fundamentally different from how JAX works, since with JAX it’s basically all or nothing: if it runs, you can leave it up to XLA to give you the "best perf", but you need to make your code compatible for even the first run to execute.
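For contrast, a minimal JAX sketch of the all-or-nothing behavior:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Under jit there is no eager fallback: branching on a traced value
    # raises at trace time, and you must rewrite with jnp.where / lax.cond.
    if x.sum() > 0:
        return x + 1
    return x - 1

f(jnp.ones(8))  # raises a tracer/concretization error instead of running eagerly
```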
Comment From: mattdangerw
Thanks! This is really helpful to get the lay of the land.
Comment From: divyashreepathihalli
/gemini help