Skip to content

Commit

Permalink
base.SourceCloner. NB ExampleGridSearchCV fails.
Browse files Browse the repository at this point in the history
  • Loading branch information
pa-m committed Dec 4, 2019
1 parent 26b55ef commit 469516e
Show file tree
Hide file tree
Showing 15 changed files with 143 additions and 58 deletions.
35 changes: 24 additions & 11 deletions base/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,17 @@ package base
import (
"sync"

"golang.org/x/exp/rand"

"github.com/pa-m/randomkit"
"golang.org/x/exp/rand"
)

// A Source represents a source of uniformly-distributed
// pseudo-random int64 values in the range [0, 1<<64).
type Source interface {
Uint64() uint64
Seed(seed uint64)
}
type Source = rand.Source

// SourceCloner is an "golang.org/x/exp/rand".Source with a Clone method
type SourceCloner interface {
Clone() rand.Source
SourceClone() Source
}

// RandomState represents a bit more than random_state pythonic attribute. it's not only a seed but a source with a state as it's name states
Expand All @@ -34,7 +30,14 @@ func NewSource(seed uint64) *randomkit.RKState {
// It is just a standard Source with its operations protected by a sync.Mutex.
type LockedSource struct {
lk sync.Mutex
src Source
src *randomkit.RKState
}

// WithLock executes f while s is locked
func (s *LockedSource) WithLock(f func(Source)) {
s.lk.Lock()
f(s.src)
s.lk.Unlock()
}

// Uint64 ...
Expand All @@ -52,9 +55,9 @@ func (s *LockedSource) Seed(seed uint64) {
s.lk.Unlock()
}

// Clone ...
func (s *LockedSource) Clone() rand.Source {
return &LockedSource{src: s.src.(SourceCloner).Clone()}
// SourceClone ...
func (s *LockedSource) SourceClone() Source {
return &LockedSource{src: s.src.SourceClone().(*randomkit.RKState)}
}

// NewLockedSource returns a rand.Source safe for concurrent access
Expand All @@ -78,3 +81,13 @@ type NormFloat64er interface {
type Intner interface {
Intn(int) int
}

// Permer is implemented by a random source having a method Perm(int) []int
type Permer interface {
Perm(int) []int
}

// Shuffler is implemented by a random source having a method Shuffle(int,func(int,int))
type Shuffler interface {
Shuffler(int, func(int, int))
}
2 changes: 1 addition & 1 deletion base/source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var (

func TestSource(t *testing.T) {
s := NewSource(7)
s2 := s.Clone()
s2 := s.SourceClone()
var a [5]float64
for i := range a {
a[i] = s.Float64()
Expand Down
2 changes: 1 addition & 1 deletion gaussian_process/gpr.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (m *Regressor) IsClassifier() bool { return false }
func (m *Regressor) PredicterClone() base.Predicter {
clone := *m
if cloner, ok := m.RandomState.(base.SourceCloner); ok {
clone.RandomState = cloner.Clone()
clone.RandomState = cloner.SourceClone()
}
if m.Xtrain != nil {
clone.Xtrain = mat.DenseCopyOf(m.Xtrain)
Expand Down
6 changes: 2 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ require (
github.com/jung-kurt/gofpdf v1.10.1 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4
github.com/pa-m/randomkit v0.0.0-20191001073902-db4fd80633df
github.com/pkg/errors v0.8.1
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979
golang.org/x/image v0.0.0-20190902063713-cb417be4ba39 // indirect
golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5 // indirect
golang.org/x/exp v0.0.0-20191129062945-2f5052295587
gonum.org/v1/gonum v0.0.0-20190929233944-b20cf7805fc4
gonum.org/v1/plot v0.0.0-20190615073203-9aa86143727f
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
Expand Down
13 changes: 13 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802 h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af h1:wVe6/Ea46ZMeNkQjjBW6xcqyQA/j5e0D6GytH95g0gQ=
Expand All @@ -18,6 +19,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE=
github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
Expand All @@ -41,6 +43,8 @@ github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a h1:cgsB0XsJwsMq0JifJ
github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a/go.mod h1:gHioqOgOl5Wa4lmyUg/ojarU7Dfdkh/OnTnGA/WexsY=
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4 h1:+LyPTCDcQRARqza7LfS0w7v03e7VYceqQNTE8eRcGA4=
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4/go.mod h1:2Ix1Kyeujyr6FhU2SPX4iyiEpEBjHHcRV/Mki06ACcE=
github.com/pa-m/randomkit v0.0.0-20191001073902-db4fd80633df h1:waQf2YvgkQdOEK4IvtzwNIuFAo2FZd34JtAb/wrLbbc=
github.com/pa-m/randomkit v0.0.0-20191001073902-db4fd80633df/go.mod h1:rEyYBR/jbMkj6lX7VpWTAPPrjDIi/aNhAXmFuLMZS4o=
github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand All @@ -61,6 +65,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de h1:xSjD6HQTqT0H/k60N5yYBtnN1OEkVy7WIo/DYyxKRO0=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f h1:9kQ594xxPWRNKfTOnPjPcgrIJ19zM3ic57aI7PbMyAA=
Expand All @@ -71,8 +76,11 @@ golang.org/x/exp v0.0.0-20190312203227-4b39c73a6495 h1:I6A9Ag9FpEKOjcKrRNjQkPHaw
golang.org/x/exp v0.0.0-20190312203227-4b39c73a6495/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522 h1:OeRHuibLsmZkFj773W4LcfAGsSxJgfPONhr8cmO+eLA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979 h1:Agxu5KLo8o7Bb634SVDnhIfpTvxmzUwhbYAzBvXt6h4=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
golang.org/x/exp v0.0.0-20191129062945-2f5052295587 h1:5Uz0rkjCFu9BC9gCRN7EkwVvhNyQgGWb8KNJrPwBoHY=
golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81 h1:00VmoueYNlNz/aHIilyyQz/MHSqGoWJzpFv/HW8xpzI=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067 h1:KYGJGHOQy8oSi1fDlSpcZF0+juKwk/hEMv5SiwHogR0=
Expand All @@ -87,17 +95,20 @@ golang.org/x/image v0.0.0-20190902063713-cb417be4ba39/go.mod h1:FeLwcggjj3mMvU+o
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20190607214518-6fa95d984e88/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/mobile v0.0.0-20190830201351-c6da95954960/go.mod h1:mJOp/i0LXPxJZ9weeIadcPqKVfS05Ai7m6/t9z1Hs/Y=
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190611141213-3f473d35a33a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313 h1:pczuHS43Cp2ktBEEmLwScxgjWsBSzdaQiKzUyf3DTTc=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
Expand All @@ -112,6 +123,8 @@ golang.org/x/tools v0.0.0-20190611222205-d73e1c7e250b/go.mod h1:/rFqwRUd4F7ZHNgw
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5 h1:xU4gBaA7ny56EkBSp9Uw2MVovJDupIfONnEOZ+FChTY=
golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a h1:TwMENskLwU2NnWBzrJGEWHqSiGUkO/B4rfyhwqDxDYQ=
golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.0.0-20190331200053-3d26580ed485 h1:OB/uP/Puiu5vS5QMRPrXCDWUPb+kt8f1KW8oQzFejQw=
Expand Down
2 changes: 1 addition & 1 deletion linear_model/logistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func NewLogisticRegression() *LogisticRegression {
func (m *LogisticRegression) PredicterClone() base.Predicter {
clone := *m
if sc, ok := m.RandomState.(base.SourceCloner); ok {
clone.RandomState = sc.Clone()
clone.RandomState = sc.SourceClone()
}
return &clone
}
Expand Down
2 changes: 1 addition & 1 deletion model_selection/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (gscv *GridSearchCV) PredicterClone() base.Predicter {
}
clone := *gscv
if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) {
clone.RandomState = sourceCloner.Clone()
clone.RandomState = sourceCloner.SourceClone()
}
return &clone
}
Expand Down
10 changes: 6 additions & 4 deletions model_selection/search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ func chkRandomState(rs rand.Source) {
panic(fmt.Errorf("wrong random state\nexpected:%s\n%s\ngot :%s\n%s", expected, "", got, ""))
}
}

func ExampleGridSearchCV() {
RandomState := base.NewSource(7)
RandomState := base.NewLockedSource(7)
ds := datasets.LoadBoston()
X, Y := preprocessing.NewStandardScaler().FitTransform(ds.X, ds.Y)

Expand All @@ -99,14 +100,15 @@ func ExampleGridSearchCV() {
mlp.BatchSize = 20
mlp.LearningRateInit = .005
mlp.MaxIter = 100

scorer := func(Y, Ypred mat.Matrix) float64 {
return metrics.MeanSquaredError(Y, Ypred, nil, "").At(0, 0)
}
gscv := &GridSearchCV{
Estimator: mlp,
ParamGrid: map[string][]interface{}{
"Alpha": {2e-4, 5e-4, 1e-3},
"WeightDecay": {.0002, .0001, 0},
"Alpha": {1e-4, 2e-4, 5e-4, 1e-3},
"WeightDecay": {1e-4, 1e-5, 1e-6,1e-7,1e-8, 0},
},
Scorer: scorer,
LowerScoreIsBetter: true,
Expand All @@ -120,7 +122,7 @@ func ExampleGridSearchCV() {

// Output:
// Alpha 0.0002
// WeightDecay 0
// WeightDecay 1e-06

}

Expand Down
32 changes: 17 additions & 15 deletions model_selection/split.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package modelselection

import (
"github.com/pa-m/sklearn/base"
"math"
"sort"

"github.com/pa-m/sklearn/base"
"golang.org/x/exp/rand"

"gonum.org/v1/gonum/mat"
)

Expand Down Expand Up @@ -41,7 +39,7 @@ func (splitter *KFold) SplitterClone() Splitter {
}
clone := *splitter
if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) {
clone.RandomState = sourceCloner.Clone()
clone.RandomState = sourceCloner.SourceClone()
}
return &clone
}
Expand Down Expand Up @@ -114,29 +112,33 @@ func (splitter *KFold) GetNSplits(X, Y *mat.Dense) int {

// TrainTestSplit splits X and Y into test set and train set
// testsize must be between 0 and 1
// it does'nt yet produce same sets than scikit-learn du to a different shuffle method
// it produce same sets than scikit-learn
func TrainTestSplit(X, Y mat.Matrix, testsize float64, randomstate uint64) (Xtrain, Xtest, ytrain, ytest *mat.Dense) {
NSamples, NFeatures := X.Dims()
_, NOutputs := Y.Dims()
var testlen int
if testsize > 1 {
testlen = int(math.Round(math.Min(float64(NSamples), testsize)))
testlen = int(math.Ceil(math.Min(float64(NSamples), testsize)))
} else {
testlen = int(math.Round(float64(NSamples) * testsize))
testlen = int(math.Ceil(float64(NSamples) * testsize))
}
Xtest = mat.NewDense(testlen, NFeatures, nil)
ytest = mat.NewDense(testlen, NOutputs, nil)
Xtrain = mat.NewDense(NSamples-testlen, NFeatures, nil)
ytrain = mat.NewDense(NSamples-testlen, NOutputs, nil)
src := base.NewLockedSource(randomstate)
shuffler := rand.New(src)
ind := make([]int, NSamples)
for i := range ind {
ind[i] = i
}
//shuffle ind
slice := sort.IntSlice(ind)
shuffler.Shuffle(slice.Len(), slice.Swap)

var ind []int
src.WithLock(func(src base.Source) {
permer, ok := src.(base.Permer)
if !ok {
panic("Source does not implement Perm")
}
{
ind = permer.Perm(NSamples)
}

})
for i := 0; i < NSamples; i++ {
j := ind[i]
if i < testlen {
Expand Down
79 changes: 67 additions & 12 deletions model_selection/split_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package modelselection

import (
"fmt"
"github.com/pa-m/randomkit"
"testing"

"github.com/pa-m/sklearn/datasets"

"github.com/pa-m/sklearn/base"
"golang.org/x/exp/rand"
Expand Down Expand Up @@ -50,15 +50,70 @@ func perm(r base.Intner, n int) []int {
return m

}
func TestTrainTestSplit(t *testing.T) {
rs := randomkit.NewRandomkitSource(42)
NSamples := 178
ind := make([]int, NSamples)
for i := range ind {
ind[i] = i
}
permer := rand.New(rs)
ind = permer.Perm(178)
fmt.Println(ind)

func _ExampleTrainTestSplit() {

features, target := datasets.LoadWine().GetXY()
RandomState := uint64(42)
_, _, Ytrain, Ytest := TrainTestSplit(features, target, .30, RandomState)
Ntrain, _ := Ytrain.Dims()
ytrain := make([]float64, Ntrain)
mat.Col(ytrain, 0, Ytrain)
fmt.Println(ytrain[:8])
Ntest, _ := Ytest.Dims()
ytest := make([]float64, Ntest)
mat.Col(ytest, 0, Ytest)
fmt.Println(ytest[:8])

// Output:
//[2 1 1 0 1 0 2 1]
//[0 0 2 0 1 0 1 2]
}

func ExampleTrainTestSplit() {
/*
>>> import numpy as np
>>> from sklearn.model_selection import train_test_split
>>> X, y = np.arange(10).reshape((5, 2)), range(5)
>>> X_train, X_test, y_train, y_test = train_test_split(
... X, y, test_size=0.33, random_state=42)
...
>>> X_train
array([[4, 5],
[0, 1],
[6, 7]])
>>> y_train
[2, 0, 3]
>>> X_test
array([[2, 3],
[8, 9]])
>>> y_test
[1, 4]
*/
X := mat.NewDense(5, 2, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
Y := mat.NewDense(5, 1, []float64{0, 1, 2, 3, 4})
RandomState := uint64(42)
Xtrain, Xtest, Ytrain, Ytest := TrainTestSplit(X, Y, .33, RandomState)
fmt.Printf("X_train:\n%g\n", mat.Formatted(Xtrain))
fmt.Printf("Y_train:\n%g\n", mat.Formatted(Ytrain))
fmt.Printf("X_test:\n%g\n", mat.Formatted(Xtest))
fmt.Printf("Y_test:\n%g\n", mat.Formatted(Ytest))

// Output:
//X_train:
//⎡4 5⎤
//⎢0 1⎥
//⎣6 7⎦
//Y_train:
//⎡2⎤
//⎢0⎥
//⎣3⎦
//X_test:
//⎡2 3⎤
//⎣8 9⎦
//Y_test:
//⎡1⎤
//⎣4⎦

}
Loading

0 comments on commit 469516e

Please sign in to comment.