I'm a bit skeptical about the implementation. For one thing, I don't think it really needs a registry of operation-to-operation mappings. At the moment it's essentially a layer of checking for "is this jaxpr correct?". This can be done without operation substitution, by simply analysing the jaxpr and raising an error if, e.g., two disparate units are added together. For another, I don't think this will scale well: if I call … (+ various nits: …)
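A minimal sketch of the check-only approach described in this comment, assuming units are supplied as one dict of base-unit exponents per argument (e.g. `{"m": 1}`): trace the function with `jax.make_jaxpr`, walk its equations, propagate units, and raise when `add` or `sub` sees mismatched units. `check_units` and `UnitError` are hypothetical names, and only a handful of primitives get rules; anything else raises `NotImplementedError`.

```python
import jax


class UnitError(TypeError):
    """Raised when a jaxpr combines incompatible units (hypothetical)."""


def _combine(a, b, sign=1):
    """Multiply (sign=1) or divide (sign=-1) units stored as sorted (dim, exp) tuples."""
    out = dict(a)
    for dim, exp in b:
        out[dim] = out.get(dim, 0) + sign * exp
    return tuple(sorted((d, e) for d, e in out.items() if e != 0))


def check_units(fun, *args, units):
    """Trace `fun` and propagate units through its jaxpr without evaluating it.

    `units` gives one dict per positional argument. Returns the units of each
    output as a sorted tuple of (dim, exp) pairs, or raises UnitError.
    """
    jaxpr = jax.make_jaxpr(fun)(*args).jaxpr
    env = {var: () for var in jaxpr.constvars}  # captured constants: dimensionless
    for var, u in zip(jaxpr.invars, units):
        env[var] = tuple(sorted(u.items()))

    def read(v):
        # Literals (e.g. the 2.0 in `2.0 * x`) are treated as dimensionless.
        # Checking the class name avoids depending on where JAX exports `Literal`.
        return () if type(v).__name__ == "Literal" else env[v]

    for eqn in jaxpr.eqns:
        name = eqn.primitive.name
        ins = [read(v) for v in eqn.invars]
        if name in ("add", "sub", "max", "min"):
            if ins[0] != ins[1]:
                raise UnitError(f"{name}: incompatible units {ins[0]} and {ins[1]}")
            out = ins[0]
        elif name == "mul":
            out = _combine(ins[0], ins[1])
        elif name == "div":
            out = _combine(ins[0], ins[1], sign=-1)
        elif name == "neg":
            out = ins[0]
        else:
            raise NotImplementedError(f"no unit rule for primitive {name!r}")
        env[eqn.outvars[0]] = out
    return [read(v) for v in jaxpr.outvars]
```

Since this only inspects the jaxpr, it never substitutes operations; it can report units or reject a function, but it cannot convert between compatible units.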
Not really! It's actually also doing unit conversions in cases where they're necessary. It would be possible to provide a more minimal API that just checks, but that's not what I need!
Totally agreed; I'd love to hear suggestions for approaches to this!
This builds on conversations on Twitter to sketch an interface for transforming a raw JAX function (using the `jax.numpy` interface directly) into one that supports units. For example, … or … should both work.
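The transformation described above can be sketched as a small jaxpr interpreter that re-evaluates each equation with `eqn.primitive.bind` (the pattern from JAX's custom-interpreter documentation) while tracking units alongside values, and converting when two compatible but differently scaled units are added. This is an illustration under assumptions, not the implementation in this PR: `with_units`, the `(value, unit_name)` input convention and `UNIT_TABLE` are hypothetical, and only `add`, `sub`, `mul` and `div` are handled.

```python
import jax
import jax.numpy as jnp

# Hypothetical unit table: name -> (scale relative to SI, SI base unit).
UNIT_TABLE = {"m": (1.0, "m"), "km": (1000.0, "m"), "s": (1.0, "s"), "min": (60.0, "s")}


def _combine(a, b, sign=1):
    """Multiply (sign=1) or divide (sign=-1) dimensions stored as sorted (dim, exp) tuples."""
    out = dict(a)
    for dim, exp in b:
        out[dim] = out.get(dim, 0) + sign * exp
    return tuple(sorted((d, e) for d, e in out.items() if e != 0))


def with_units(fun):
    """Wrap a plain jax.numpy function so it accepts (value, unit_name) pairs.

    Every intermediate is tracked as (value, scale, dims), where multiplying
    the value by `scale` gives SI units. Outputs are returned in that form.
    """
    def wrapped(*quantities):
        values = [jnp.asarray(v) for v, _ in quantities]
        closed = jax.make_jaxpr(fun)(*values)
        jaxpr = closed.jaxpr
        env = {var: (c, 1.0, ()) for var, c in zip(jaxpr.constvars, closed.consts)}
        for var, val, (_, unit) in zip(jaxpr.invars, values, quantities):
            scale, base = UNIT_TABLE[unit]
            env[var] = (val, scale, ((base, 1),))

        def read(v):
            # Literal constants are dimensionless with unit scale.
            return (v.val, 1.0, ()) if type(v).__name__ == "Literal" else env[v]

        for eqn in jaxpr.eqns:
            name = eqn.primitive.name
            if name not in ("add", "sub", "mul", "div"):
                raise NotImplementedError(f"no unit rule for primitive {name!r}")
            (x, sx, dx), (y, sy, dy) = [read(v) for v in eqn.invars]
            if name in ("add", "sub"):
                if dx != dy:
                    raise TypeError(f"{name}: cannot combine {dx} with {dy}")
                # Unit conversion: rescale y into x's units before combining.
                y, scale, dims = y * (sy / sx), sx, dx
            elif name == "mul":
                scale, dims = sx * sy, _combine(dx, dy)
            else:  # div
                scale, dims = sx / sy, _combine(dx, dy, sign=-1)
            # Evaluate the original primitive on the (possibly rescaled) values.
            out = eqn.primitive.bind(x, y, **eqn.params)
            env[eqn.outvars[0]] = (out, scale, dims)
        return [read(v) for v in jaxpr.outvars]

    return wrapped
```

For example, adding `(2.0, "km")` and `(500.0, "m")` gives a value of `2.5` with scale `1000.0`, i.e. the result stays in kilometres. Keeping a scale per value, rather than converting every input to SI up front, means a conversion only happens where an operation actually needs one.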
This is so far (very!) incomplete. Some things to do / think about:
- An `add_units` or `make_quantity` primitive that decorates the literal with units that we can use when transforming the jaxpr. This might also be useful for implementing correct derivative rules without overloading `grad`. Other ideas?

/cc @shoyer @mattjj @sschoenholz @patrick-kidger
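A rough sketch of such a primitive, assuming a recent JAX where `Primitive` is exported from `jax.extend.core` (older releases exposed it as `jax.core.Primitive`). It is an identity on values whose `units` param survives into the jaxpr, where an interpreter can read it. `make_quantity_p`, `output_units` and `_combine` are hypothetical names, not this PR's code. No JVP or lowering rules are registered, so as written it works eagerly and under `jax.make_jaxpr`, but not under `jit` or `grad`.

```python
import jax

try:  # recent JAX releases export Primitive from jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # fall back for older releases
    from jax.core import Primitive

# Hypothetical primitive: identity on values, with units recorded in its params
# so the annotation is visible in the traced jaxpr.
make_quantity_p = Primitive("make_quantity")
make_quantity_p.def_impl(lambda x, *, units: x)
make_quantity_p.def_abstract_eval(lambda aval, *, units: aval)


def make_quantity(x, **units):
    """Attach units, e.g. make_quantity(x, m=1, s=-1); stored as a hashable tuple."""
    return make_quantity_p.bind(x, units=tuple(sorted(units.items())))


def _combine(a, b, sign=1):
    """Multiply (sign=1) or divide (sign=-1) units stored as sorted (dim, exp) tuples."""
    out = dict(a)
    for dim, exp in b:
        out[dim] = out.get(dim, 0) + sign * exp
    return tuple(sorted((d, e) for d, e in out.items() if e != 0))


def output_units(fun, *args):
    """Read units from make_quantity equations and propagate them through the jaxpr."""
    jaxpr = jax.make_jaxpr(fun)(*args).jaxpr
    env = {}

    def read(v):
        # Literals and unannotated values are treated as dimensionless.
        return () if type(v).__name__ == "Literal" else env.get(v, ())

    for eqn in jaxpr.eqns:
        name = eqn.primitive.name
        ins = [read(v) for v in eqn.invars]
        if name == "make_quantity":
            out = eqn.params["units"]  # the annotation overrides any input units
        elif name == "mul":
            out = _combine(ins[0], ins[1])
        elif name == "div":
            out = _combine(ins[0], ins[1], sign=-1)
        elif name in ("add", "sub"):
            if ins[0] != ins[1]:
                raise TypeError(f"{name}: incompatible units {ins[0]} and {ins[1]}")
            out = ins[0]
        else:
            raise NotImplementedError(f"no unit rule for primitive {name!r}")
        env[eqn.outvars[0]] = out
    return [read(v) for v in jaxpr.outvars]
```

Because the units live in the equation's params rather than in a wrapper type, the user's function can stay plain `jax.numpy` code. A custom JVP rule on this primitive would be the natural place to try the derivative-rule idea without overloading `grad`.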