
TensorGrad.jl


This package adds gradient definitions for Zygote.jl to most calculations using TensorOperations.jl, and some using Einsum.jl. It exports a macro @grad which rewrites an expression like

```julia
@grad @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l]
```

into something equivalent to this:

```julia
using TensorOperations, LinearAlgebra
using Zygote: @adjoint

fun(b,c,d) = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]  # define a function

@adjoint function fun(b,c,d)
    fwd = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]     # forward pass
    function back(Δa)
        @tensor Δb[i,j] := Δa[i,k] * c[j,k] * d[l,l]     # reverse pass
        @tensor Δc[j,k] := b[i,j] * Δa[i,k] * d[l,l]
        δ = Diagonal(ones(size(d,1)))                    # identity, since ∂tr(d)/∂d = δ
        @tensor Δd[l,l′] := b[i,j] * c[j,k] * Δa[i,k] * δ[l,l′]
        return (Δb, Δc, Δd)
    end
    return (fwd, back)
end

B, C, D = rand(2,3), rand(3,4), rand(5,5)                # example inputs (arbitrary sizes)
A = fun(B,C,D)                                           # apply this to B, C, D
```
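As a plain-Julia cross-check of what this reverse pass computes, the same contraction can be written with ordinary matrix algebra: a = (b*c)·tr(d), so Δb = Δa*c'·tr(d), Δc = b'*Δa·tr(d), and Δd is the scalar ⟨b*c, Δa⟩ times the identity. This is a sketch assuming real matrices and square d; fun_plain, fun_back and check_fun_back are hypothetical names, and nothing here uses TensorOperations or Zygote. The check compares each sensitivity against a central finite difference of ⟨Δa, a⟩.

```julia
using LinearAlgebra

# Forward pass: a[i,k] = b[i,j] * c[j,k] * d[l,l], i.e. the product b*c scaled by tr(d)
fun_plain(b, c, d) = b * c * tr(d)

# Reverse pass: the three sensitivities from the expansion, written as matrix algebra
function fun_back(Δa, b, c, d)
    t = tr(d)
    Δb = Δa * c' * t                                  # Δb[i,j] = Δa[i,k] * c[j,k] * d[l,l]
    Δc = b' * Δa * t                                  # Δc[j,k] = b[i,j] * Δa[i,k] * d[l,l]
    Δd = Diagonal(fill(dot(b * c, Δa), size(d, 1)))   # Δd[l,l′] = (b[i,j] c[j,k] Δa[i,k]) δ[l,l′]
    return (Δb, Δc, Δd)
end

# Compare each sensitivity with a central finite difference of the scalar ⟨Δa, a⟩
# along a random direction; fun_plain is linear in each argument, so they agree to rounding
function check_fun_back(; ε = 1e-3)
    b, c, d, Δa = randn(2, 3), randn(3, 4), randn(5, 5), randn(2, 4)
    Δb, Δc, Δd = fun_back(Δa, b, c, d)
    L(b, c, d) = dot(Δa, fun_plain(b, c, d))
    vb, vc, vd = randn(size(b)), randn(size(c)), randn(size(d))
    fd_b = (L(b + ε * vb, c, d) - L(b - ε * vb, c, d)) / (2ε)
    fd_c = (L(b, c + ε * vc, d) - L(b, c - ε * vc, d)) / (2ε)
    fd_d = (L(b, c, d + ε * vd) - L(b, c, d - ε * vd)) / (2ε)
    return isapprox(fd_b, dot(Δb, vb); rtol = 1e-6, atol = 1e-8) &&
           isapprox(fd_c, dot(Δc, vc); rtol = 1e-6, atol = 1e-8) &&
           isapprox(fd_d, dot(Δd, vd); rtol = 1e-6, atol = 1e-8)
end

check_fun_back()
```

Because the loss is linear in each argument, a fairly large step ε keeps rounding error small without introducing truncation error.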

You may also write @grad B C @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l] to specify that only the sensitivities for B and C are needed; this removes the calculation of Δd above.
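When only B and C are requested, the pullback can skip the δ contraction entirely and put nothing in the slot for d, which Zygote treats as "no gradient". The following is a minimal plain-Julia sketch assuming the generated pullback takes this shape; fun_back_bc is a hypothetical name, not something the macro defines.

```julia
using LinearAlgebra

# Hypothetical pullback for `@grad B C ...`: d is treated as a constant,
# so its sensitivity is never computed and `nothing` fills its slot
function fun_back_bc(Δa, b, c, d)
    t = tr(d)                  # d only enters the forward pass through its trace
    Δb = Δa * c' * t           # Δb[i,j] = Δa[i,k] * c[j,k] * d[l,l]
    Δc = b' * Δa * t           # Δc[j,k] = b[i,j] * Δa[i,k] * d[l,l]
    return (Δb, Δc, nothing)   # no δ-contraction for Δd
end
```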

To see what is being defined, call TensorGrad.verbose(true) before the macro (rather than using @macroexpand1).

If Tracker.jl is loaded, then it will now define the same gradients for B::TrackedArray etc.

Note that this is a fairly crude experiment, probably not something to rely on.

Limitations:

  1. The expression must be one term, and scalar factors are not handled yet.
  2. It makes no attempt to cache intermediate contractions for re-use, and thus if there are many tensors it will do the same work several times (like b[i,j] * c[j,k] above, done twice).
  3. Requires you to add @grad everywhere, so won't work in other people's code.
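Limitation 2 comes down to not keeping intermediate results from the forward pass. As a hypothetical sketch of what caching would look like for the example above (this is not what @grad generates), the forward pass below computes b*c once and the reverse-pass closure reuses it for Δd instead of recomputing it; fun_with_cache is an illustrative name.

```julia
using LinearAlgebra

# Hypothetical sketch: keep the intermediate b*c from the forward pass
# so that the reverse pass does not have to contract b and c again
function fun_with_cache(b, c, d)
    bc = b * c                 # b[i,j] * c[j,k], computed once
    t = tr(d)
    a = bc * t
    function back(Δa)
        Δb = Δa * c' * t
        Δc = b' * Δa * t
        Δd = Diagonal(fill(dot(bc, Δa), size(d, 1)))   # reuses the cached bc
        return (Δb, Δc, Δd)
    end
    return a, back
end
```

With many tensors the saving grows, since each sensitivity would otherwise redo most of the same contraction chain.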

I can solve 1, but 2 seems hard to solve with this design.

It now understands other macros that share the same syntax, such as @einsum. This allows it to treat non-Einstein contractions, such as batched matrix multiplication:

```julia
@grad x @einsum z[i,k,b] := x[i,j,b] * y[j,k,b]
```
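Here the batch index b appears on both sides, so it is not summed; only j is contracted. Written out as explicit loops in plain Julia, the forward pass and the x sensitivity look roughly like the sketch below. batched_mul and batched_mul_back_x are hypothetical names for illustration, not functions from Einsum.jl or TensorGrad.

```julia
# Batched matrix multiplication: z[i,k,b] = Σ_j x[i,j,b] * y[j,k,b]
# (b appears on both sides, so it is a batch index and is not summed)
function batched_mul(x::AbstractArray{<:Real,3}, y::AbstractArray{<:Real,3})
    ni, nj, nb = size(x)
    nk = size(y, 2)
    z = zeros(promote_type(eltype(x), eltype(y)), ni, nk, nb)
    for b in 1:nb, k in 1:nk, j in 1:nj, i in 1:ni
        z[i, k, b] += x[i, j, b] * y[j, k, b]
    end
    return z
end

# Sensitivity for x only (as requested by `@grad x ...`):
# Δx[i,j,b] = Σ_k Δz[i,k,b] * y[j,k,b], again with no sum over b
function batched_mul_back_x(Δz::AbstractArray{<:Real,3}, y::AbstractArray{<:Real,3})
    ni, nk, nb = size(Δz)
    nj = size(y, 1)
    Δx = zeros(promote_type(eltype(Δz), eltype(y)), ni, nj, nb)
    for b in 1:nb, k in 1:nk, j in 1:nj, i in 1:ni
        Δx[i, j, b] += Δz[i, k, b] * y[j, k, b]
    end
    return Δx
end
```

Each batch slice of z is just x[:,:,b] * y[:,:,b], and each slice of Δx is Δz[:,:,b] * y[:,:,b]', which is why this is not a plain Einstein contraction.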

Such expressions are also handled by @ein from OMEinsum.jl, which may make this pointless, as OMEinsum has its own gradients built in. Probably you should use that instead!

An earlier attempt is now TensorTrack.jl, which works at the level of functions such as contract!, and thus gets some re-use, 4. But it is completely limited by 2, being deeply plugged into TensorOperations.

Finally, note also that TensorCast.jl should be almost fully differentiable (although focused on operations other than contractions).

--- Michael Abbott, August 2019

Update:

Essentially the same code was bolted onto Tullio.jl, originally in PR#6, and moved to @tensor in PR#92. It has the same limitations as above. (But it avoids eval by always attaching gradients to a callable struct Eval, rather than to the newly defined functions.)
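The idea of attaching gradients to a callable struct, rather than to each newly defined function, can be sketched in plain Julia. The field names, and the helper pullback_eval standing in for a Zygote @adjoint rule, are assumptions for illustration and not Tullio's actual definitions; the point is that the rule is written once for the Eval type, so no new method has to be created with eval for every expression.

```julia
# Hypothetical sketch: one struct type carries both passes
struct Eval{F,B}
    fwd::F   # forward function: (args...) -> result
    rev::B   # reverse function: (Δ, args...) -> tuple of sensitivities
end

# Calling the struct runs the forward pass
(e::Eval)(args...) = e.fwd(args...)

# A single generic rule covering every Eval; with Zygote, one @adjoint
# method for Eval would play this role. Here it is an ordinary function.
function pullback_eval(e::Eval, args...)
    y = e.fwd(args...)
    back(Δ) = e.rev(Δ, args...)
    return y, back
end

# Usage: a matrix product with its two sensitivities
ev = Eval((b, c) -> b * c, (Δ, b, c) -> (Δ * c', b' * Δ))
y, back = pullback_eval(ev, [1.0 2.0; 3.0 4.0], [0.0 1.0; 1.0 0.0])
```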

The package TensorRules.jl has a macro @∇ which performs manipulations of @tensor expressions, acting on whole functions containing them.
