
Commit

changed []ts.Tensor -> []*ts.Tensor
sugarme committed Jul 5, 2023
1 parent e29288f commit f45f0a7
Showing 25 changed files with 18,892 additions and 15,212 deletions.
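
At the API level the change is mechanical: functions that used to accept or return []ts.Tensor (slices of tensor values) now work with []*ts.Tensor (slices of tensor pointers), so call sites append the returned *ts.Tensor directly instead of dereferencing it. A minimal sketch of the new call-site pattern, assuming the module import path github.com/sugarme/gotch/ts (not stated in this diff); the identifiers are illustrative, not taken from the commit:

package main

import (
    "fmt"

    "github.com/sugarme/gotch/ts"
)

func main() {
    w1 := ts.MustOfSlice([]float64{1, 2})
    w2 := ts.MustOfSlice([]float64{3, 4})

    // Before this commit the affected APIs took []ts.Tensor, so callers wrote
    // append(weights, *w1, *w2); now the pointers are appended as-is.
    weights := make([]*ts.Tensor, 0)
    weights = append(weights, w1, w2)

    fmt.Println(len(weights), weights[0].MustSize())

    for _, w := range weights {
        w.MustDrop() // free C-land memory explicitly
    }
}
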
40 changes: 10 additions & 30 deletions example/char-rnn/main.go
@@ -32,27 +32,18 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De

state := lstm.Step(input, inState)

// 1. Delete inState tensors (from C land memory)
inState.(*nn.LSTMState).Tensor1.MustDrop()
inState.(*nn.LSTMState).Tensor2.MustDrop()
// 2. Then update with current state
// Update with current state
inState = state
// 3. Delete intermediate tensors
input.MustDrop()
inputView.MustDrop()

forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueezeDim(0, true).MustSoftmax(-1, gotch.Float, true)
sampledY := forwardTs.MustMultinomial(1, false, true)
lastLabel = sampledY.Int64Values()[0]
sampledY.MustDrop()
char := data.LabelForChar(lastLabel)

runes = append(runes, char)
}

// Delete the last state
inState.(*nn.LSTMState).Tensor1.MustDrop()
inState.(*nn.LSTMState).Tensor2.MustDrop()
ts.CleanUp(100)
}

return string(runes)
}
@@ -93,42 +84,31 @@ func main() {

batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
batchNarrow.MustDrop()

ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)

lstmOut, outState := lstm.Seq(xsOnehot)
// NOTE. Although outState will not be used. There a hidden memory usage
// on C land memory that is needed to free up. Don't use `_`
outState.(*nn.LSTMState).Tensor1.MustDrop()
outState.(*nn.LSTMState).Tensor2.MustDrop()
xsOnehot.MustDrop()
lstmOut, _ := lstm.Seq(xsOnehot)

logits := linear.Forward(lstmOut)
lstmOut.MustDrop()
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)

loss := lossView.CrossEntropyForLogits(ys)
ys.MustDrop()
lossView.MustDrop()

opt.BackwardStepClip(loss, 0.5)
sumLoss += loss.Float64Values()[0]
cntLoss += 1.0
loss.MustDrop()

batchCount++
if batchCount%500 == 0 {
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
fmt.Printf("\nEpoch %v - Batch %v \n", epoch, batchCount)
}
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
// fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
fmt.Print(".")

ts.CleanUp(100)
} // infinite for-loop

sampleStr := sample(data, lstm, linear, device)
fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
fmt.Printf("\nEpoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
fmt.Println(sampleStr)

dataIter.Data.MustDrop()
dataIter.Indexes.MustDrop()
}
}
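
The net effect in this example is that per-tensor bookkeeping (dropping LSTM states, one-hot inputs, logits, and so on by hand) is replaced by a single ts.CleanUp(100) call at the end of each loop iteration; the argument 100 is taken verbatim from the example and its precise meaning is not documented in this diff. A hedged sketch of the simplified loop body, written as a fragment that reuses the example's variables (dataIter, lstm, linear, opt, labels, device, SeqLen, BatchSize) and imports:

for {
    batchTs, ok := dataIter.Next()
    if !ok {
        break
    }

    batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
    xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true)
    ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)

    lstmOut, _ := lstm.Seq(xsOnehot)
    logits := linear.Forward(lstmOut)
    loss := logits.MustView([]int64{BatchSize * SeqLen, labels}, true).CrossEntropyForLogits(ys)

    opt.BackwardStepClip(loss, 0.5)

    // One call reclaims the intermediate C-land tensors created above,
    // instead of MustDrop on each of them.
    ts.CleanUp(100)
}
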
2 changes: 1 addition & 1 deletion nn/jit.go
@@ -79,7 +79,7 @@ func (m *TrainableCModule) Save(file string) error {
// ForwardT implements ModuleT for TrainableCModule.
// NOTE: train parameter will not be used.
func (m *TrainableCModule) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
retVal, err := m.Inner.ForwardTs([]ts.Tensor{*x})
retVal, err := m.Inner.ForwardTs([]*ts.Tensor{x})
if err != nil {
log.Fatal(err)
}
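
From a caller's perspective nothing changes here: ForwardT still takes a single *ts.Tensor and now wraps it into a one-element []*ts.Tensor before delegating to ForwardTs. A hedged fragment, where m is a *nn.TrainableCModule and x a *ts.Tensor assumed to exist, and fmt is assumed imported:

out := m.ForwardT(x, false) // the train flag is ignored, per the NOTE above
fmt.Println(out.MustSize())
out.MustDrop()
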
12 changes: 6 additions & 6 deletions nn/optimizer.go
@@ -387,12 +387,12 @@ func WithErrorIfNonFinite(v bool) ClipOpt {
}
}

/// Clips gradient L2 norm over all trainable parameters.
// / Clips gradient L2 norm over all trainable parameters.
//
// The norm is computed over all gradients together, as if they were
// concatenated into a single vector.
//
/// Args:
// / Args:
// - max: max norm of the gradient
// - o.NormType. Type of the used p-norm, can be "inf" for infinity norm. Default= 2.0
// - o.ErrorIfNonFinite bool. If true, throw error if total norm of the gradients from paramters is "nan", "inf" or "-inf". Default=false
@@ -413,15 +413,15 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
}

var (
norms []ts.Tensor
norms []*ts.Tensor
totalNorm *ts.Tensor
)

device := opt.varstore.device
if o.NormType == math.Inf(1) {
for _, v := range opt.varstore.vars {
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
norms = append(norms, *n)
norms = append(norms, n)
}
// total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
totalNorm = ts.MustStack(norms, 0).MustMax(true)
@@ -432,7 +432,7 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
norms = append(norms, *x)
norms = append(norms, x)
}
}

@@ -556,7 +556,7 @@ func (opt *Optimizer) ParamGroupNum() int {
return int(ngroup)
}

func (opt *Optimizer) AddParamGroup(tensors []ts.Tensor) {
func (opt *Optimizer) AddParamGroup(tensors []*ts.Tensor) {
err := opt.opt.AddParamGroup(tensors)
if err != nil {
log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)
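
Both changes in this file line up with the VarStore change further down: ClipGradNorm now collects per-parameter gradient norms as []*ts.Tensor internally, and AddParamGroup accepts a []*ts.Tensor, so the pointer slice returned by VarStore.TrainableVariables can be passed straight through. A hedged fragment, assuming vs is a *nn.VarStore and opt a *nn.Optimizer built elsewhere (for example from an Adam config), that a backward pass has already populated gradients, and that log is imported:

opt.AddParamGroup(vs.TrainableVariables()) // []*ts.Tensor, no copying or dereferencing

// Clip the global L2 norm of all trainable gradients to 1.0.
if err := opt.ClipGradNorm(1.0); err != nil {
    log.Fatal(err)
}
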
18 changes: 9 additions & 9 deletions nn/rnn.go
@@ -74,7 +74,7 @@ func DefaultRNNConfig() *RNNConfig {
//
// https://en.wikipedia.org/wiki/Long_short-term_memory
type LSTM struct {
flatWeights []ts.Tensor
flatWeights []*ts.Tensor
hiddenDim int64
config *RNNConfig
device gotch.Device
@@ -89,7 +89,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
}

gateDim := 4 * hiddenDim
flatWeights := make([]ts.Tensor, 0)
flatWeights := make([]*ts.Tensor, 0)

for i := 0; i < int(cfg.NumLayers); i++ {
if i != 0 {
@@ -102,22 +102,22 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})

flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)

case 2: // bi-directional
// forward
wIh := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)

// reverse
wIhR := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
wHhR := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
bIhR := vs.MustZeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
bHhR := vs.MustZeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
flatWeights = append(flatWeights, wIhR, wHhR, bIhR, bHhR)
}
}

@@ -188,7 +188,7 @@ func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {

func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {

output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
output, h, c := input.MustLstm([]*ts.Tensor{inState.(*LSTMState).Tensor1, inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)

return output, &LSTMState{
Tensor1: h,
@@ -209,7 +209,7 @@ func (gs *GRUState) Value() *ts.Tensor {
//
// https://en.wikipedia.org/wiki/Gated_recurrent_unit
type GRU struct {
flatWeights []ts.Tensor
flatWeights []*ts.Tensor
hiddenDim int64
config *RNNConfig
device gotch.Device
@@ -223,7 +223,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
}

gateDim := 3 * hiddenDim
flatWeights := make([]ts.Tensor, 0)
flatWeights := make([]*ts.Tensor, 0)

for i := 0; i < int(cfg.NumLayers); i++ {
for n := 0; n < int(numDirections); n++ {
@@ -239,7 +239,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
bIh := vs.MustZeros("b_ih", []int64{gateDim})
bHh := vs.MustZeros("b_hh", []int64{gateDim})

flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
}
}

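
LSTM and GRU now hold their flat weights as []*ts.Tensor, and MustLstm receives the (hidden, cell) state as []*ts.Tensor, but constructing and running a layer looks the same from the outside. A hedged sketch, assuming vs is a *nn.VarStore, that ts.MustZeros takes (size, dtype, device) as elsewhere in the library, and that fmt is imported; sizes are illustrative:

lstm := nn.NewLSTM(vs.Root(), 10, 32, nn.DefaultRNNConfig())

xs := ts.MustZeros([]int64{8, 20, 10}, gotch.Float, gotch.CPU) // [batch, seq, features]
out, state := lstm.Seq(xs)

h := state.(*nn.LSTMState).H() // hidden state after the last step
fmt.Println(out.MustSize(), h.MustSize())
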
30 changes: 15 additions & 15 deletions nn/sequential.go
@@ -46,19 +46,19 @@ func (s *Sequential) AddFn(fn ts.Module) {
}

// ForwardAll applies the forward pass and returns the output for each layer.
func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []ts.Tensor) {
func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []*ts.Tensor) {

var n uint8 = uint8(len(s.layers))
if len(opts) > 0 {
n = opts[0]
}

if s.IsEmpty() {
return []ts.Tensor{*xs.MustShallowClone()}
return []*ts.Tensor{xs.MustShallowClone()}
}

for i := 0; i < int(n); i++ {
retVal = append(retVal, *s.layers[i].Forward(xs))
retVal = append(retVal, s.layers[i].Forward(xs))
}

return retVal
@@ -85,15 +85,15 @@ func (s *Sequential) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
}

// forward sequentially
outs := make([]ts.Tensor, len(s.layers))
outs := make([]*ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ {
if i == 0 {
outs[0] = *s.layers[i].Forward(xs)
outs[0] = s.layers[i].Forward(xs)
defer outs[0].MustDrop()
} else if i == len(s.layers)-1 {
return s.layers[i].Forward(&outs[i-1])
return s.layers[i].Forward(outs[i-1])
} else {
outs[i] = *s.layers[i].Forward(&outs[i-1])
outs[i] = s.layers[i].Forward(outs[i-1])
defer outs[i].MustDrop()
}
}
@@ -106,7 +106,7 @@ type SequentialT struct {
layers []ts.ModuleT
}

/// SeqT creates a new empty sequential layer.
// / SeqT creates a new empty sequential layer.
func SeqT() *SequentialT {
return &SequentialT{
layers: make([]ts.ModuleT, 0),
@@ -139,15 +139,15 @@ func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
}

// forward sequentially
outs := make([]ts.Tensor, len(s.layers))
outs := make([]*ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ {
if i == 0 {
outs[0] = *s.layers[i].ForwardT(xs, train)
outs[0] = s.layers[i].ForwardT(xs, train)
defer outs[0].MustDrop()
} else if i == len(s.layers)-1 {
return s.layers[i].ForwardT(&outs[i-1], train)
return s.layers[i].ForwardT(outs[i-1], train)
} else {
outs[i] = *s.layers[i].ForwardT(&outs[i-1], train)
outs[i] = s.layers[i].ForwardT(outs[i-1], train)
defer outs[i].MustDrop()
}
}
@@ -179,21 +179,21 @@ func (s *SequentialT) AddFnT(fn ts.ModuleT) {
}

// ForwardAll applies the forward pass and returns the output for each layer.
func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []ts.Tensor) {
func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []*ts.Tensor) {

var n uint8 = uint8(len(s.layers))
if len(opts) > 0 {
n = opts[0]
}

if s.IsEmpty() {
return []ts.Tensor{*xs.MustShallowClone()}
return []*ts.Tensor{xs.MustShallowClone()}
}

currTs := xs
for i := 0; i < int(n); i++ {
res := s.layers[i].ForwardT(currTs, train)
retVal = append(retVal, *res)
retVal = append(retVal, res)
currTs = res
}

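
ForwardAll and ForwardAllT now return []*ts.Tensor, so callers range over pointers and can drop each per-layer output directly. A hedged sketch, assuming vs is a *nn.VarStore, that Path.Sub exists for namespacing the two linear layers, and that fmt is imported; sizes are illustrative:

seq := nn.Seq()
seq.Add(nn.NewLinear(vs.Root().Sub("l1"), 4, 8, nn.DefaultLinearConfig()))
seq.Add(nn.NewLinear(vs.Root().Sub("l2"), 8, 2, nn.DefaultLinearConfig()))

x := ts.MustZeros([]int64{1, 4}, gotch.Float, gotch.CPU)
outs := seq.ForwardAll(x) // []*ts.Tensor, one output per layer
for i, o := range outs {
    fmt.Printf("layer %d output shape: %v\n", i, o.MustSize())
    o.MustDrop()
}
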
6 changes: 3 additions & 3 deletions nn/varstore.go
@@ -78,15 +78,15 @@ func (vs *VarStore) IsEmpty() bool {
}

// TrainableVariabless returns reference to all trainable variables kept in VarStore.
func (vs *VarStore) TrainableVariables() []ts.Tensor {
func (vs *VarStore) TrainableVariables() []*ts.Tensor {
vs.Lock()
defer vs.Unlock()

var trainables []ts.Tensor
var trainables []*ts.Tensor
for _, v := range vs.vars {
x := v.Tensor
if x.MustRequiresGrad() {
trainables = append(trainables, *x)
trainables = append(trainables, x)
}
}

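
TrainableVariables now hands back pointers to the stored variables rather than copies of the tensor structs. A hedged fragment, assuming vs is a populated *nn.VarStore and fmt is imported:

for _, v := range vs.TrainableVariables() { // []*ts.Tensor
    fmt.Println(v.MustSize(), v.MustRequiresGrad())
}
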
5 changes: 3 additions & 2 deletions ts/data.go
@@ -232,6 +232,7 @@ func (tdi *TextDataIter) Progress() float32 {
progress := float32(startIndex) / float32(availableIndices)
return progress
}

// Labels returns the number of different `character` (rune) used by the dataset.
func (td *TextData) Labels() (retVal int64) {
return int64(len(td.CharForLabel))
@@ -281,12 +282,12 @@ func (tdi *TextDataIter) Next() (*Tensor, bool) {
indexes := indexesTs.Int64Values()
indexesTs.MustDrop()

var batch []Tensor
var batch []*Tensor

for _, idx := range indexes {
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
idxTs := tdi.Data.Idx(narrowIdx)
batch = append(batch, *idxTs)
batch = append(batch, idxTs)
}

retVal := MustStack(batch, 0)
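
Next now accumulates its batch as a []*Tensor and stacks it with MustStack. From a caller's side the same pattern looks like the hedged sketch below (values are illustrative; fmt is assumed imported):

a := ts.MustOfSlice([]int64{1, 2, 3})
b := ts.MustOfSlice([]int64{4, 5, 6})

batch := []*ts.Tensor{a, b}
stacked := ts.MustStack(batch, 0) // shape [2, 3]
fmt.Println(stacked.MustSize())

stacked.MustDrop()
a.MustDrop()
b.MustDrop()
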
1 change: 0 additions & 1 deletion ts/data_test.go
@@ -111,5 +111,4 @@ func TestTextDataIter(t *testing.T) {
vals := sum.Int64Values()
t.Logf("sum: %v\n", vals)
}

}
