-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathpersistence.go
123 lines (101 loc) · 2.4 KB
/
persistence.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package rnn
import (
"bytes"
"encoding/gob"
"fmt"
"io/ioutil"
"log"
)
// GobEncode implements gob.GobEncoder. A hand-written encoder is required
// because RNN keeps most of its state in unexported fields, which encoding/gob
// ignores by default. Exporting those fields would be simpler, but their names
// were deliberately kept identical to the original Python code so the two
// implementations can be compared side by side.
//
// The fields are written in a fixed order; GobDecode must read them back in
// exactly the same order. Encoding stops at the first failure, and whatever
// bytes were produced up to that point are returned along with the error.
func (r *RNN) GobEncode() ([]byte, error) {
	var buf bytes.Buffer
	enc := gob.NewEncoder(&buf)

	// NOTE(review): hprev and the m* fields are handed to gob through their
	// address while the other fields are passed by value. The reason is not
	// visible from this file, so that distinction is preserved exactly.
	fields := []interface{}{
		r.Wxh,
		r.Whh,
		r.Why,
		r.bh,
		r.by,
		&r.hprev,
		&r.mWxh,
		&r.mWhh,
		&r.mWhy,
		&r.mbh,
		&r.mby,
		r.data,
		r.charToIndex,
		r.indexToChar,
		r.VocabSize,
		r.n,
		r.loss,
		r.smooth_loss,
	}

	var err error
	for _, field := range fields {
		if err = enc.Encode(field); err != nil {
			break
		}
	}
	return buf.Bytes(), err
}
// GobDecode implements gob.GobDecoder. It is the inverse of GobEncode and
// reads the fields back in exactly the order GobEncode writes them. Decoding
// stops at the first failure and that error is returned. Fields decoded before
// the failure keep their newly decoded values.
func (r *RNN) GobDecode(data []byte) error {
	dec := gob.NewDecoder(bytes.NewBuffer(data))

	// Order must mirror GobEncode field-for-field.
	targets := []interface{}{
		&r.Wxh,
		&r.Whh,
		&r.Why,
		&r.bh,
		&r.by,
		&r.hprev,
		&r.mWxh,
		&r.mWhh,
		&r.mWhy,
		&r.mbh,
		&r.mby,
		&r.data,
		&r.charToIndex,
		&r.indexToChar,
		&r.VocabSize,
		&r.n,
		&r.loss,
		&r.smooth_loss,
	}

	for _, target := range targets {
		if err := dec.Decode(target); err != nil {
			return err
		}
	}
	return nil
}
// SaveTo writes a gob-encoded checkpoint of r to filePath. If the file
// exists, its contents are replaced. Otherwise it is created with permission
// bits 0644 (before umask).
//
// The network is fully encoded in memory before the file is touched, so an
// encoding failure leaves any existing checkpoint at filePath intact. The
// write itself is not atomic: if it fails part-way, the file may be left
// truncated.
//
// Returned errors wrap the underlying cause with %w, so callers can inspect
// it with errors.Is / errors.As.
func (r *RNN) SaveTo(filePath string) error {
	log.Printf("Saving RNN to %s...", filePath)

	// Encode first: ioutil.WriteFile truncates the target, so encoding
	// straight into the file could clobber a good checkpoint on failure.
	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(r); err != nil {
		return fmt.Errorf("error encoding network: %w", err)
	}

	if err := ioutil.WriteFile(filePath, buf.Bytes(), 0644); err != nil {
		return fmt.Errorf("error writing RNN to file: %s: %w", filePath, err)
	}
	return nil
}
// LoadFrom reads a gob-encoded checkpoint (as written by SaveTo) from
// filePath and returns the decoded network. The returned RNN records filePath
// in its checkpointFile field.
//
// On failure it returns a nil *RNN and an error that wraps the underlying
// cause with %w. For example, a missing file can be detected with
// errors.Is(err, os.ErrNotExist).
func LoadFrom(filePath string) (*RNN, error) {
	log.Printf("Loading RNN from %s...", filePath)

	b, err := ioutil.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("error reading RNN checkpoint file: %w", err)
	}

	var result RNN
	if err := gob.NewDecoder(bytes.NewBuffer(b)).Decode(&result); err != nil {
		return nil, fmt.Errorf("error decoding RNN checkpoint file: %w", err)
	}

	result.checkpointFile = filePath
	return &result, nil
}