Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions differentiation.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func forwardDiffAnalysis(outputs, sortedNodes Nodes) (retVal NodeSet, err error)
// diffSet := outputs.Set()
diffSet := outputs.mapSet()

symdiffLogf("Diff Set: %d", diffSet)
symdiffLogf("Diff Set: %v", diffSet)
symdiffLogf("%d", sortedNodes)
// for i := len(sortedNodes) - 1; i ⩾ 0; i-- {
// n := sortedNodes[i]
Expand Down Expand Up @@ -216,7 +216,7 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
// "pullback" function to backpropagate derivatives
activeNodes := affectsOutput.Intersect(affectedByOutput)

symdiffLogf("Active: %d", activeNodes)
symdiffLogf("Active: %v", activeNodes)

symdiffLogf("Sorted: %d", sortedNodes)
symdiffLogf("nodeGradMap: %+#d", FmtNodeMap(nodeGradMap))
Expand Down
9 changes: 7 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ github.com/cznic/strutil v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0
github.com/cznic/xc v0.0.0-20181122101856-45b06973881e/go.mod h1:3oFoiOvCDBYH+swwf5+k/woVmWy7h1Fcyu8Qig/jjX0=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/delaneyj/cogent v0.0.0-20180619184653-2fcea326194c h1:UliKg7JACWAXDW7yFdms6lLwOLK7H3uId3NG5z4f378=
github.com/delaneyj/cogent v0.0.0-20180619184653-2fcea326194c/go.mod h1:hL/k6TDIq37bqQ6sySYVYw+Idnv0JkVmKsmedD5AduQ=
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
Expand Down Expand Up @@ -50,15 +52,20 @@ github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y=
github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/gorgonia-cfd5423e2acc2f8c2b86 v0.0.0-20200313070349-288c2a647837 h1:/7GLXOx1Cd15DDfNpIZguExr6Ui5e2vKVbCf8x52ls0=
github.com/mattn/gorgonia-cfd5423e2acc2f8c2b86 v0.0.0-20200313070349-288c2a647837/go.mod h1:MGXCds9oIEtiTo7SSDV2qlEYxIFO0LdSOf4BlNJYr34=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df h1:rhEzo7J+sDOLI5NulkwtescnyYMSt4J5mkxDMgQRjN4=
github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df/go.mod h1:w+IAy13Luqfsp+plFpT1RiqauADylJKmpkrWFwpjbsc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
Expand Down Expand Up @@ -114,9 +121,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc=
gorgonia.org/cu v0.9.0-beta/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8=
gorgonia.org/cu v0.9.3 h1:IkxE4NWXuZHqr8AnmgoB8WNQPZeD6u0EJNxYjDC0YgY=
gorgonia.org/cu v0.9.3/go.mod h1:LgyAYDkN7HWhh8orGnCY2R8pP9PYbO44ivEbLMatkVU=
gorgonia.org/dawson v1.1.0 h1:o7+eJ3SKi9sheH19lpOat//tDbg0Y+M9iY/lH79VHqY=
gorgonia.org/dawson v1.1.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
gorgonia.org/dawson v1.2.0 h1:hJ/aofhfkReSnJdSMDzypRZ/oWDL1TmeYOauBnXKdFw=
gorgonia.org/dawson v1.2.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
Expand Down
10 changes: 6 additions & 4 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,11 @@ func WithShape(shp ...int) NodeConsOpt {
s := tensor.Shape(tensor.BorrowInts(len(shp)))
copy(s, shp)
f := func(n *Node) {
if n.t == nil && n.shape == nil {
n.shape = s
return
}
nd := n.Dims()
// if nd == 1 && s.IsVector() {
// goto safe
// }
isVec := s.IsColVec() || s.IsRowVec()
acceptVec := (isVec && (nd == 1))
sameDims := nd == s.Dims()
Expand All @@ -209,7 +210,6 @@ func WithShape(shp ...int) NodeConsOpt {
if !acceptVec && !sameDims && !acceptScalar {
panic(fmt.Sprintf("Node %v, has %d dimensions(Shape: %v). Input shape is %v, which has %d dimensions", n, n.Dims(), n.shape, s, s.Dims()))
}
// safe:
n.shape = s
}
return f
Expand Down Expand Up @@ -258,6 +258,8 @@ func newNode(opts ...NodeConsOpt) *Node {
n := borrowNode()
n.dataOn = CPU
n.id = -1
n.t = nil
n.shape = nil

for _, opt := range opts {
opt(n)
Expand Down
21 changes: 8 additions & 13 deletions op_reduction.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (

func reductionType(d int, along []int) hm.Type {
a := hm.TypeVariable('a')
t := makeTensorType(d, a)
t := makeTensorType(d-len(along), a)

axes := make(map[int]bool)
for _, axis := range along {
Expand Down Expand Up @@ -52,24 +52,19 @@ func reductionInferShape(along []int, in tensor.Shape) (tensor.Shape, error) {
if d >= shape.Dims() {
return nil, fmt.Errorf("shape error, along %d is not a valid axis for shape %v", d, in)
}
shape[d] = 1
shape[d] = 0
}
// special cases: if all dimensions are 1 -> ScalarShape, if exactly one dimension is != 1 -> vector
vecD := 0
numNot1 := 0

var dims []int
for _, d := range shape {
if d != 1 {
vecD = d
numNot1++
if numNot1 > 1 {
return shape, nil
}
if d != 0 {
dims = append(dims, d)
}
}
if numNot1 == 0 {
if len(dims) == 0 {
return tensor.ScalarShape(), nil
}
return tensor.Shape{vecD}, nil
return tensor.Shape(dims), nil
}

func reductionDo(op Op, s string, f func(*tensor.Dense, ...int) (*tensor.Dense, error), along []int, inputs ...Value) (retVal Value, err error) {
Expand Down
26 changes: 13 additions & 13 deletions op_reduction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func TestMaxOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{0},
wantShape: []int{1, 2, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{9, 10, 11, 12, 13, 14, 15, 16},
},
{
Expand All @@ -263,7 +263,7 @@ func TestMaxOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{1},
wantShape: []int{2, 1, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{5, 6, 7, 8, 13, 14, 15, 16},
},
{
Expand All @@ -272,7 +272,7 @@ func TestMaxOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{2},
wantShape: []int{2, 2, 1, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{3, 4, 7, 8, 11, 12, 15, 16},
},
{
Expand All @@ -281,7 +281,7 @@ func TestMaxOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{3},
wantShape: []int{2, 2, 2, 1},
wantShape: []int{2, 2, 2},
wantData: []float32{2, 4, 6, 8, 10, 12, 14, 16},
},
{
Expand All @@ -290,7 +290,7 @@ func TestMaxOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{1, 3},
wantShape: []int{2, 1, 2, 1},
wantShape: []int{2, 2},
wantData: []float32{6, 8, 14, 16},
},
{
Expand Down Expand Up @@ -342,7 +342,7 @@ func TestSumOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{0},
wantShape: []int{1, 2, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{10, 12, 14, 16, 18, 20, 22, 24},
},
{
Expand All @@ -351,7 +351,7 @@ func TestSumOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{1},
wantShape: []int{2, 1, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{6, 8, 10, 12, 22, 24, 26, 28},
},
{
Expand All @@ -360,7 +360,7 @@ func TestSumOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{2},
wantShape: []int{2, 2, 1, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{4, 6, 12, 14, 20, 22, 28, 30},
},
{
Expand All @@ -369,7 +369,7 @@ func TestSumOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{3},
wantShape: []int{2, 2, 2, 1},
wantShape: []int{2, 2, 2},
wantData: []float32{3, 7, 11, 15, 19, 23, 27, 31},
},
{
Expand All @@ -378,7 +378,7 @@ func TestSumOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{1, 3},
wantShape: []int{2, 1, 2, 1},
wantShape: []int{2, 2},
wantData: []float32{14, 22, 46, 54},
},
{
Expand Down Expand Up @@ -562,12 +562,12 @@ func TestFollowupOp(t *testing.T) {
Xn := NewTensor(g, tensor.Float64, 4, WithShape(2, 2, 2, 2), WithInit(RangedFrom(1)))
mx := Must(Max(Xn, 1, 2))
sx := Must(Sum(Xn, 1, 2))
y := NewTensor(g, tensor.Float64, 4, WithShape(2, 1, 1, 2), WithInit(RangedFrom(1)))
y := NewTensor(g, tensor.Float64, 2, WithShape(2, 2), WithInit(RangedFrom(1)))

amx := Must(Add(mx, y))
asx := Must(Add(sx, y))
assert.Equal(t, amx.Shape(), tensor.Shape{2, 1, 1, 2})
assert.Equal(t, asx.Shape(), tensor.Shape{2, 1, 1, 2})
assert.Equal(t, amx.Shape(), tensor.Shape{2, 2})
assert.Equal(t, asx.Shape(), tensor.Shape{2, 2})
vm := NewTapeMachine(g)
defer vm.Close()
err := vm.RunAll()
Expand Down
11 changes: 1 addition & 10 deletions operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,16 +323,7 @@ func Sum(a *Node, along ...int) (retVal *Node, err error) {

dims := a.Dims()
if len(along) == 0 {
switch {
case a.IsRowVec():
along = []int{1}
dims = 1
case a.IsColVec(), a.IsVector():
along = []int{0}
dims = 1
default:
along = intRange(0, dims)
}
along = intRange(0, dims)
}

op := newSumOp(along, a.shape, dims)
Expand Down