Skip to content

Commit

Permalink
- improved coverage
Browse files Browse the repository at this point in the history
- mlp: add optional batch normalization
- add FitTransform to Transformer
- added ExampleBaseMultilayerPerceptron32_Fit_mnist
  • Loading branch information
pa-m committed Apr 15, 2019
1 parent c8994e9 commit 4af17ee
Show file tree
Hide file tree
Showing 33 changed files with 715 additions and 252 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ before_install:
- go get -t -v ./... && go build -v ./...

script:
- go test -race -coverprofile=coverage.txt -covermode=atomic ./...
- go test -short -parallel 4 -race -coverprofile=coverage.txt -covermode=atomic ./...

after_success:
- bash <(curl -s https://codecov.io/bash)
13 changes: 3 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ Partial port of scikit-learn to [go](http://golang.org)
[![Go Report Card](https://goreportcard.com/badge/github.com/pa-m/sklearn)](https://goreportcard.com/report/github.com/pa-m/sklearn)
[![GoDoc](https://godoc.org/github.com/pa-m/sklearn?status.svg)](https://godoc.org/github.com/pa-m/sklearn)


## Examples
### cluster
[DBSCAN](https://godoc.org/github.com/pa-m/sklearn/cluster#example-DBSCAN) [KMeans](https://godoc.org/github.com/pa-m/sklearn/cluster#example-KMeans)
Expand All @@ -24,7 +23,7 @@ Partial port of scikit-learn to [go](http://golang.org)
### neighbors
[KNeighborsClassifier](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-KNeighborsClassifier) [MinkowskiDistance](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-MinkowskiDistance) [EuclideanDistance](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-EuclideanDistance) [KDTree](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-KDTree) [NearestCentroid](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-NearestCentroid) [KNeighborsRegressor](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-KNeighborsRegressor) [NearestNeighbors](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-NearestNeighbors) [NearestNeighbors.KNeighborsGraph](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-NearestNeighbors-KNeighborsGraph) [NearestNeighbors.Tree](https://godoc.org/github.com/pa-m/sklearn/neighbors#example-NearestNeighbors-Tree)
### neural_network
[MLPClassifier](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPClassifier) [MLPRegressor](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPRegressor)
[MLPClassifier.Unmarshal](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPClassifier-Unmarshal) [MLPClassifier.Fit.mnist](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPClassifier-Fit-mnist) [MLPClassifier.Predict.mnist](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPClassifier-Predict-mnist) [MLPClassifier.Fit.breast.cancer](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPClassifier-Fit-breast-cancer) [MLPRegressor.Fit.boston](https://godoc.org/github.com/pa-m/sklearn/neural_network#example-MLPRegressor-Fit-boston)
### pipeline
[Pipeline](https://godoc.org/github.com/pa-m/sklearn/pipeline#example-Pipeline)
### preprocessing
Expand All @@ -34,20 +33,14 @@ Partial port of scikit-learn to [go](http://golang.org)



This is

- a personal project to get a deeper understanding of how all of this magic works

- a recent work still in progress, subject to refactoring, so interfaces may change, especially args to NewXXX
This is a personal project to get a deeper understanding of how all of this magic works

- linted with ~~gofmt, golint, go vet~~ [revive](https://github.com/mgechev/revive)

- unit tested, but coverage has not yet reached the 90% target

- underdocumented but [scikit-learn doc](http://scikit-learn.org/stable/documentation.html) is your friend


Many thanks to gonum and scikit-learn authors and contributors

PRs are welcome

PRs are welcome
1 change: 1 addition & 0 deletions base/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ type Predicter interface {
type Transformer interface {
Fiter
Transform(X, Y mat.Matrix) (Xout, Yout *mat.Dense)
FitTransform(X, Y mat.Matrix) (Xout, Yout *mat.Dense)
TransformerClone() Transformer
}
2 changes: 0 additions & 2 deletions base/matrix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io"
"log"
"math"

"golang.org/x/exp/rand"
Expand Down Expand Up @@ -421,7 +420,6 @@ func ToDense(m mat.Matrix) *mat.Dense {
// FromDense fills dst (mat.Mutable) with src (mat.Dense)
func FromDense(dst mat.Mutable, dense *mat.Dense) *mat.Dense {
if dst == mat.Mutable(nil) {
log.Println("warning dst is nil")
return dense
}
src := dense.RawMatrix()
Expand Down
16 changes: 14 additions & 2 deletions base/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package base
import (
"sync"

"golang.org/x/exp/rand"

"github.com/pa-m/randomkit"
)

Expand All @@ -13,11 +15,16 @@ type Source interface {
Seed(seed uint64)
}

// SourceCloner is a "golang.org/x/exp/rand".Source that can also produce an
// independent copy of itself (same internal state, separate evolution).
type SourceCloner interface {
	Clone() rand.Source
}

// RandomState represents a bit more than the pythonic random_state attribute:
// it is not only a seed but a full random source with internal state, as its
// name states.
type RandomState = Source

// NewSource returns a new pseudo-random Source seeded with the given value.
func NewSource(seed uint64) Source {
func NewSource(seed uint64) *randomkit.RKState {
var rng randomkit.RKState
rng.Seed(seed)
return &rng
Expand Down Expand Up @@ -45,8 +52,13 @@ func (s *LockedSource) Seed(seed uint64) {
s.lk.Unlock()
}

// Clone returns a new LockedSource wrapping an independent copy of the
// underlying source. The clone starts with its own, unlocked mutex.
// It panics if the wrapped source does not implement SourceCloner.
func (s *LockedSource) Clone() rand.Source {
	inner := s.src.(SourceCloner).Clone()
	return &LockedSource{src: inner}
}

// NewLockedSource returns a rand.Source safe for concurrent access
func NewLockedSource(seed uint64) Source {
func NewLockedSource(seed uint64) *LockedSource {
var s LockedSource
s.src = NewSource(seed)
return &s
Expand Down
14 changes: 14 additions & 0 deletions base/source_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package base

import (
"golang.org/x/exp/rand"

"github.com/pa-m/randomkit"
)

// Compile-time checks that the concrete types satisfy the expected interfaces.
var (
	_ rand.Source   = &randomkit.RKState{}
	_ rand.Source   = &LockedSource{}
	_ Float64er     = &randomkit.RKState{}
	_ NormFloat64er = &randomkit.RKState{}
	_ SourceCloner  = &LockedSource{}
)
21 changes: 21 additions & 0 deletions cluster/dbscan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"image/color"
"os"
"os/exec"
"testing"
"time"

"github.com/pa-m/sklearn/datasets"
Expand All @@ -19,6 +20,26 @@ import (

var visualDebug = flag.Bool("visual", false, "output images for benchmarks and test data")

// TestDBSCAN_PredicterClone checks that PredicterClone returns a copy whose
// printed state matches the original's.
func TestDBSCAN_PredicterClone(t *testing.T) {
	m := NewDBSCAN(&DBSCANConfig{})
	clone := m.PredicterClone()
	// Fix: the original used the typo "+%v" (a literal '+' then %v, which
	// omits field names); "%+v" prints struct fields with their names,
	// making the comparison actually inspect the cloned state.
	if fmt.Sprintf("%+v", clone) != fmt.Sprintf("%+v", m) {
		t.Errorf("cloning failed \n%+v, \n%+v", m, clone)
	}
}

// TestDBSCAN_IsClassifier ensures a freshly constructed DBSCAN reports
// itself as a classifier.
func TestDBSCAN_IsClassifier(t *testing.T) {
	m := NewDBSCAN(&DBSCANConfig{})
	if !m.IsClassifier() {
		t.Fail()
	}
}

// TestDBSCAN_Predict verifies that Predict copies the stored integer Labels
// into the output matrix as float64 values.
func TestDBSCAN_Predict(t *testing.T) {
	m := &DBSCAN{Labels: []int{1, 2, 3}}
	got := m.Predict(mat.NewDense(3, 1, nil), nil).RawMatrix().Data
	if fmt.Sprintf("%#v", got) != "[]float64{1, 2, 3}" {
		t.Fail()
	}
}
func ExampleDBSCAN() {
// adapted from http://scikit-learn.org/stable/_downloads/plot_dbscan.ipynb
// Generate sample data
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module github.com/pa-m/sklearn

require (
github.com/chewxy/math32 v1.0.0
github.com/pa-m/randomkit v0.0.0-20190402202301-70c3c46153e1
github.com/pa-m/randomkit v0.0.0-20190414101838-b61cec1ec1e3
golang.org/x/exp v0.0.0-20190321205749-f0864edee7f3
golang.org/x/tools v0.0.0-20190401205534-4c644d7e323d // indirect
gonum.org/v1/gonum v0.0.0-20190331200053-3d26580ed485
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ github.com/llgcode/draw2d v0.0.0-20180817132918-587a55234ca2 h1:3xDkT1Tbsw2yDtKW
github.com/llgcode/draw2d v0.0.0-20180817132918-587a55234ca2/go.mod h1:mVa0dA29Db2S4LVqDYLlsePDzRJLDfdhVZiI15uY0FA=
github.com/llgcode/ps v0.0.0-20150911083025-f1443b32eedb h1:61ndUreYSlWFeCY44JxDDkngVoI7/1MVhEl98Nm0KOk=
github.com/llgcode/ps v0.0.0-20150911083025-f1443b32eedb/go.mod h1:1l8ky+Ew27CMX29uG+a2hNOKpeNYEQjjtiALiBlFQbY=
github.com/pa-m/randomkit v0.0.0-20190402202301-70c3c46153e1 h1:K+j2m1I9BV3KrGtpZG/vnzp6q/rFeJJaEXUZe69g+/E=
github.com/pa-m/randomkit v0.0.0-20190402202301-70c3c46153e1/go.mod h1:iloaywGVzk8xNSb2ZSf40GA6MwM7OXPEzSXHvf8zHdc=
github.com/pa-m/randomkit v0.0.0-20190414101838-b61cec1ec1e3 h1:/JmnvF6yaoqMtHE3aZuFnkvlfhyO1b6PyUsO5T+cQhg=
github.com/pa-m/randomkit v0.0.0-20190414101838-b61cec1ec1e3/go.mod h1:iloaywGVzk8xNSb2ZSf40GA6MwM7OXPEzSXHvf8zHdc=
github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
18 changes: 0 additions & 18 deletions linear_model/Base.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,6 @@ func (regr *LinearRegression) Predict(X mat.Matrix, Ymutable mat.Mutable) *mat.D
return base.FromDense(Ymutable, Y)
}

// FitTransform is for Pipeline
func (regr *LinearRegression) FitTransform(X, Y *mat.Dense) (Xout, Yout *mat.Dense) {
r, c := Y.Dims()
Xout, Yout = X, mat.NewDense(r, c, nil)
regr.Fit(X, Y)
regr.Predict(X, Yout)
return
}

// SGDRegressor base struct
// should be named GonumOptimizeRegressor
// implemented as a per-output optimization of (possibly regularized) square-loss with gonum/optimize methods
Expand Down Expand Up @@ -296,15 +287,6 @@ func (regr *SGDRegressor) Predict(X mat.Matrix, Ymutable mat.Mutable) *mat.Dense
return base.FromDense(Ymutable, Y)
}

// FitTransform is for Pipeline
func (regr *SGDRegressor) FitTransform(X, Y *mat.Dense) (Xout, Yout *mat.Dense) {
r, c := Y.Dims()
Xout, Yout = X, mat.NewDense(r, c, nil)
regr.Fit(X, Y)
regr.Predict(X, Yout)
return
}

func unused(...interface{}) {}

// LinFitOptions are options for LinFit
Expand Down
9 changes: 0 additions & 9 deletions linear_model/bayes.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,3 @@ func (regr *BayesianRidge) Predict2(X, Y, yStd *mat.Dense) {
return sigmasSquaredData + 1./regr.Alpha
}, sigmasSquaredData)
}

// FitTransform is for Pipeline
func (regr *BayesianRidge) FitTransform(X, Y *mat.Dense) (Xout, Yout *mat.Dense) {
r, c := Y.Dims()
Xout, Yout = X, mat.NewDense(r, c, nil)
regr.Fit(X, Y)
regr.Predict(X, Yout)
return
}
96 changes: 31 additions & 65 deletions model_selection/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strings"

"github.com/pa-m/sklearn/base"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
)
Expand Down Expand Up @@ -46,12 +47,13 @@ func ParameterGrid(paramGrid map[string][]interface{}) (out []map[string]interfa
type GridSearchCV struct {
Estimator base.Predicter
ParamGrid map[string][]interface{}
Scorer func(Ytrue, Ypred *mat.Dense) float64
Scorer func(Ytrue, Ypred mat.Matrix) float64
CV Splitter
Verbose bool
NJobs int
LowerScoreIsBetter bool
UseChannels bool
RandomState rand.Source

CVResults map[string][]interface{}
BestEstimator base.Predicter
Expand All @@ -63,7 +65,13 @@ type GridSearchCV struct {

// PredicterClone returns a shallow copy of the GridSearchCV suitable for
// independent use. When the RandomState's concrete type supports cloning,
// the copy gets its own cloned source so its random stream is independent
// of the original's. A nil receiver yields a nil Predicter.
func (gscv *GridSearchCV) PredicterClone() base.Predicter {
	if gscv == nil {
		return nil
	}
	clone := *gscv
	// A successful two-value type assertion always yields a non-nil
	// interface value, so the former extra comparison against
	// base.SourceCloner(nil) was redundant and has been dropped.
	if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok {
		clone.RandomState = sourceCloner.Clone()
	}
	return &clone
}

Expand Down Expand Up @@ -98,97 +106,55 @@ func (gscv *GridSearchCV) Fit(Xmatrix, Ymatrix mat.Matrix) base.Fiter {
estCloner := gscv.Estimator
// get seed for all estimator clone

type ClonableRandomState interface {
Clone() base.Source
paramArray := ParameterGrid(gscv.ParamGrid)

if gscv.RandomState == rand.Source(nil) {
gscv.RandomState = base.NewSource(0)
}
var clonableRandomState ClonableRandomState
if rs, ok := getParam(gscv.Estimator, "RandomState"); ok {
if rs1, ok := rs.(ClonableRandomState); ok {
clonableRandomState = rs1
}
if gscv.CV == Splitter(nil) {
gscv.CV = &KFold{NSplits: 3, Shuffle: true, RandomState: gscv.RandomState}
}

paramArray := ParameterGrid(gscv.ParamGrid)
gscv.CVResults = make(map[string][]interface{})
for k := range gscv.ParamGrid {
gscv.CVResults[k] = make([]interface{}, len(paramArray))
}
gscv.CVResults["score"] = make([]interface{}, len(paramArray))

type structIn struct {
cvindex int
index int
params map[string]interface{}
estimator base.Predicter
cv Splitter
score float64
}
dowork := func(sin structIn) structIn {
sin.estimator = estCloner.PredicterClone()

if clonableRandomState != ClonableRandomState(nil) {

//setParam(sin.estimator, "RandomState", rand.New(base.NewLockedSource(clonesSeed)))
setParam(sin.estimator, "RandomState", clonableRandomState)
}

for k, v := range sin.params {
setParam(sin.estimator, k, v)
}
CV := gscv.CV.SplitterClone()
cvres := CrossValidate(sin.estimator, X, Y, nil, gscv.Scorer, CV, gscv.NJobs)
dowork := func(sin *structIn) {
cvres := CrossValidate(sin.estimator, X, Y, nil, gscv.Scorer, sin.cv, gscv.NJobs)
sin.score = floats.Sum(cvres.TestScore) / float64(len(cvres.TestScore))
bestFold := bestIdx(cvres.TestScore)
sin.estimator = cvres.Estimator[bestFold]
return sin
}
gscv.BestIndex = -1

/*if gscv.UseChannels { // use channels
chin := make(chan structIn)
chout := make(chan structIn)
worker := func(j int) {
for sin := range chin {
chout <- dowork(sin)
}
}
if gscv.NJobs <= 0 || gscv.NJobs > runtime.NumCPU() {
gscv.NJobs = runtime.NumCPU()
}
for j := 0; j < gscv.NJobs; j++ {
go worker(j)
}
for cvindex, params := range paramArray {
chin <- structIn{cvindex: cvindex, params: params}
}
close(chin)
for range paramArray {
sout := <-chout
for k, v := range sout.params {
gscv.CVResults[k][sout.cvindex] = v
}
gscv.CVResults["score"][sout.cvindex] = sout.score
if gscv.BestIndex == -1 || isBetter(sout.score, gscv.CVResults["score"][gscv.BestIndex].(float64)) {
gscv.BestIndex = sout.cvindex
gscv.BestEstimator = sout.estimator
gscv.BestParams = sout.params
gscv.BestScore = sout.score
}
}
close(chout)
} else*/{ // use sync.workGroup
{
sin := make([]structIn, len(paramArray))
for i, params := range paramArray {
sin[i] = structIn{cvindex: i, params: params, estimator: gscv.Estimator}
sin[i] = structIn{index: i, params: params, estimator: estCloner.PredicterClone(), cv: gscv.CV.SplitterClone()}
for k, v := range sin[i].params {
setParam(sin[i].estimator, k, v)
}
}
base.Parallelize(gscv.NJobs, len(paramArray), func(th, start, end int) {
for i := start; i < end; i++ {
sin[i] = dowork(sin[i])
gscv.CVResults["score"][sin[i].cvindex] = sin[i].score
dowork(&sin[i])
for k, v := range paramArray[i] {
gscv.CVResults[k][i] = v
}
gscv.CVResults["score"][i] = sin[i].score
}
})
for _, sout := range sin {
for i, sout := range sin {
if gscv.BestIndex == -1 || isBetter(sout.score, gscv.CVResults["score"][gscv.BestIndex].(float64)) {
gscv.BestIndex = sout.cvindex
gscv.BestIndex = i
gscv.BestEstimator = sout.estimator
gscv.BestParams = sout.params
gscv.BestScore = sout.score
Expand Down
Loading

0 comments on commit 4af17ee

Please sign in to comment.