Skip to content

OneHot op #559

@JoshPattman

Description

@JoshPattman

I am writing some code that requires an op that takes a vector of indices (ints) and returns a matrix of one-hot vectors. I didn't make this a pull request because I am not 100% sure it has enough error checking, or that every function is implemented correctly. I might also just be missing the fact that this is already implemented.

I think, if it's not already there, this would be a good addition to make Gorgonia a bit easier to use.

package main

import (
	"fmt"
	"hash"
	"hash/fnv"

	"github.com/chewxy/hm"
	"gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

// OneHot adds a one-hot encoding op to the graph that x belongs to.
//
// x is the node holding the class indices, numClasses is the width of each
// encoded row, and dType is the element type requested for the result.
//
// It returns an error when x is nil or numClasses is not positive; otherwise
// it returns whatever gorgonia.ApplyOp reports.
func OneHot(x *gorgonia.Node, numClasses int, dType tensor.Dtype) (*gorgonia.Node, error) {
	if x == nil {
		return nil, fmt.Errorf("one-hot: input node is nil")
	}
	// A zero or negative class count cannot describe a valid output shape.
	if numClasses <= 0 {
		return nil, fmt.Errorf("one-hot: numClasses must be positive, got %d", numClasses)
	}

	op := &oneHotOp{numClasses: numClasses, dType: dType}
	return gorgonia.ApplyOp(op, x)
}

// Compile-time checks that *oneHotOp satisfies the gorgonia interfaces it is
// meant to implement.
var (
	_ gorgonia.Op   = (*oneHotOp)(nil)
	_ gorgonia.SDOp = (*oneHotOp)(nil)
)

// oneHotOp is the op behind OneHot: it turns a vector of int indices into a
// matrix with one row per index.
type oneHotOp struct {
	// numClasses is the size of the dimension appended to the input shape,
	// i.e. the number of columns in the tensor built by Do.
	numClasses int
	// dType is the element type of the tensor built by Do.
	dType      tensor.Dtype
}

// DiffWRT implements gorgonia.SDOp.
//
// The result has one entry per input and every entry is false: the op is not
// meant to be differentiated with respect to any of its inputs.
func (*oneHotOp) DiffWRT(inputs int) []bool {
	// make zero-initialises the slice, so every flag starts out false.
	notDifferentiable := make([]bool, inputs)
	return notDifferentiable
}

// SymDiff implements gorgonia.SDOp.
//
// The op has no symbolic derivative (DiffWRT reports false for every input),
// so this always returns a nil node list and a non-nil error. It reports the
// problem through the error return instead of panicking, so an unexpected
// call cannot crash the caller.
func (*oneHotOp) SymDiff(inputs gorgonia.Nodes, output *gorgonia.Node, grad *gorgonia.Node) (retVal gorgonia.Nodes, err error) {
	return nil, fmt.Errorf("one-hot: SymDiff is not implemented because the op is not differentiable")
}

// Arity implements gorgonia.Op. The op takes exactly one input: the vector of
// indices to encode.
func (*oneHotOp) Arity() int {
	const indexInputs = 1
	return indexInputs
}

// CallsExtern implements gorgonia.Op. It always reports false.
func (*oneHotOp) CallsExtern() bool {
	const callsExtern = false
	return callsExtern
}

// Do implements gorgonia.Op.
//
// It expects exactly one input: a one-dimensional value whose data is a []int
// of class indices. The result is a tensor of shape (len(indices), numClasses)
// and element type op.dType in which element (i, indices[i]) is set to one
// (true for tensor.Bool); all other elements are left as tensor.New created
// them.
//
// An error is returned, rather than a panic or a silently wrong result, when:
//   - the number of inputs is not one,
//   - the input is not one-dimensional or its data is not a []int,
//   - an index lies outside [0, numClasses),
//   - op.dType is not one of Int, Float64, Float32 or Bool,
//   - writing into the result tensor fails.
func (op *oneHotOp) Do(inp ...gorgonia.Value) (gorgonia.Value, error) {
	if len(inp) != 1 {
		return nil, fmt.Errorf("one-hot: expected 1 input, got %d", len(inp))
	}

	shape := inp[0].Shape()
	if len(shape) != 1 {
		return nil, fmt.Errorf("one-hot: expected a vector of indices, got shape %v", shape)
	}
	batchSize := shape[0]

	// Assert the backing data once, outside the loop, and check the result.
	indices, ok := inp[0].Data().([]int)
	if !ok {
		return nil, fmt.Errorf("one-hot: expected []int indices, got %T", inp[0].Data())
	}
	if len(indices) < batchSize {
		return nil, fmt.Errorf("one-hot: shape reports %d indices but data holds %d", batchSize, len(indices))
	}

	// Pick the "hot" value once; it only depends on the configured dtype.
	var one interface{}
	switch op.dType {
	case tensor.Int:
		one = int(1)
	case tensor.Float64:
		one = float64(1)
	case tensor.Float32:
		one = float32(1)
	case tensor.Bool:
		one = true
	default:
		return nil, fmt.Errorf("one-hot: unsupported dtype %v", op.dType)
	}

	tens := tensor.New(tensor.WithShape(batchSize, op.numClasses), tensor.Of(op.dType))
	for i, index := range indices[:batchSize] {
		if index < 0 || index >= op.numClasses {
			return nil, fmt.Errorf("one-hot: index %d at position %d is outside [0, %d)", index, i, op.numClasses)
		}
		if err := tens.SetAt(one, i, index); err != nil {
			return nil, fmt.Errorf("one-hot: setting element (%d, %d): %w", i, index, err)
		}
	}
	return tens, nil
}

// InferShape implements gorgonia.Op.
//
// The output shape is the input shape with numClasses appended as a new last
// dimension. The input shape is cloned first, so the caller's shape is never
// modified.
//
// An error is returned when the number of inputs is not one or when the input
// is not a tensor.Shape; the bare type assertion used previously would panic
// in that case.
func (op *oneHotOp) InferShape(inputs ...gorgonia.DimSizer) (tensor.Shape, error) {
	if len(inputs) != 1 {
		return nil, fmt.Errorf("one-hot: expected 1 input shape, got %d", len(inputs))
	}

	in, ok := inputs[0].(tensor.Shape)
	if !ok {
		return nil, fmt.Errorf("one-hot: expected a tensor.Shape, got %T", inputs[0])
	}

	out := in.Clone()
	out = append(out, op.numClasses)
	return out, nil
}

// OverwritesInput implements gorgonia.Op. It always returns -1.
// NOTE(review): presumably -1 tells gorgonia that no input is overwritten;
// confirm against the gorgonia.Op documentation.
func (*oneHotOp) OverwritesInput() int {
	const overwritten = -1
	return overwritten
}

// ReturnsPtr implements gorgonia.Op. It always reports false.
func (*oneHotOp) ReturnsPtr() bool {
	const returnsPtr = false
	return returnsPtr
}

// String implements gorgonia.Op. It returns the fixed display name of the op,
// independent of how the op is configured.
func (*oneHotOp) String() string {
	const name = "OneHotOp"
	return name
}

// Type implements gorgonia.Op.
//
// The op is typed as a function from a one-dimensional tensor of int to a
// two-dimensional tensor of op.dType. The output element type follows the
// configured dtype, the same one Do uses to build its result, instead of
// being fixed to float64. A fixed float64 made the declared type disagree
// with the produced value for any other dtype.
func (op *oneHotOp) Type() hm.Type {
	indices := gorgonia.TensorType{
		Dims: 1,
		Of:   tensor.Int,
	}
	encoded := gorgonia.TensorType{
		Dims: 2,
		Of:   op.dType,
	}
	return hm.NewFnType(indices, encoded)
}

// WriteHash implements gorgonia.Op. It writes the op's identity into h: the
// op name followed by numClasses and dType, so that two one-hot ops with
// different configurations do not feed identical bytes into the hash.
//
// The name is passed as an argument rather than used as the format string,
// which keeps any '%' in it from being interpreted as a formatting verb. The
// result of Fprintf is ignored because hash.Hash documents that Write never
// returns an error.
func (op *oneHotOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "%s{%d,%v}", op.String(), op.numClasses, op.dType)
}

// Hashcode implements gorgonia.Op. It returns the 32-bit FNV-1a hash of the
// bytes WriteHash emits, so the two methods always agree on the op's
// identity. It no longer panics when called.
func (op *oneHotOp) Hashcode() uint32 {
	h := fnv.New32a()
	op.WriteHash(h)
	return h.Sum32()
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions