Skip to content

Commit

Permalink
polish,docs; move pairwise() to sklearn_seco.util
Browse files Browse the repository at this point in the history
  • Loading branch information
azrdev committed Aug 8, 2019
1 parent a8b068a commit 2717572
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 13 deletions.
3 changes: 2 additions & 1 deletion evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def plot_speedup(results_log_file: str,
import seaborn
seaborn.set(**seaborn_style)
import matplotlib.pyplot as plt
fig: plt.Figure = plt.figure()
fig: plt.Figure = plt.figure(figsize=(9, 4))
ax: plt.Axes = fig.gca(xscale='log', yscale='log')

for other in _OTHER_ALGO:
Expand All @@ -625,6 +625,7 @@ def plot_speedup(results_log_file: str,
t4[('runtime_cv', 'sklearn_seco.Ripper')].values,
t4['speedup'].values)

fig.tight_layout()
if outfile_pattern is not None:
fig.savefig(outfile_pattern)
return fig
Expand Down
5 changes: 5 additions & 0 deletions seco_runtime_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def plot_timings(timings, title=None, figure=None,
axes = figure.gca(xlabel='n_features', ylabel='time[s]')
if title is not None:
axes.set_title(title)

timings = np.asarray(timings)
# sort by n_features so plot lines make sense
timings = timings[np.argsort(timings[:, 1])]

n_samples = timings.T[0]
n_features = timings.T[1]
tm_min = timings.T[2]
Expand Down
1 change: 1 addition & 0 deletions sklearn_seco/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ def wrapper_ordering_classes_by_size(estimator):
self.base_estimator_ = OneVsOneClassifier(self.base_estimator_,
n_jobs=self.n_jobs)
elif self.multi_class_ == "direct":
# TODO: if self.multi_class == 'direct' is set explicitly (not auto-detected via `None`), only an assertion prevents a binary-only learner from silently learning on multiclass training data
self.base_estimator_ = wrapper_ordering_classes_by_size(
self.base_estimator_)
else:
Expand Down
3 changes: 2 additions & 1 deletion sklearn_seco/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,10 @@ class RuleContext:
Methods & Properties provided are:
- `PN`
- `X` and `y`
- `match_rule`
- `pn`
- `match_rule`
- `evaluate_rule`
- `sort_key`
Fields
-----
Expand Down
11 changes: 1 addition & 10 deletions sklearn_seco/concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,7 @@
AbstractSecoImplementation, RuleContext, TheoryContext
from sklearn_seco.ripper_mdl import \
data_description_length, relative_description_length
from sklearn_seco.util import log2


def pairwise(iterable):
    """Return an iterator over overlapping pairs of consecutive items.

    s -> (s0,s1), (s1,s2), (s2, s3), ...

    An input with fewer than two items yields no pairs.

    :param iterable: Any iterable; it is consumed lazily.
    :return: An iterator of 2-tuples `(item, successor)`.
    """
    # copied from itertools docs
    from itertools import tee
    a, b = tee(iterable)
    # Advance the second copy by one so that zip pairs every item with its
    # successor; the `None` default keeps empty input from raising.
    next(b, None)
    return zip(a, b)
from sklearn_seco.util import log2, pairwise


def grow_prune_split(y,
Expand Down
1 change: 0 additions & 1 deletion sklearn_seco/tests/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Union

import numpy as np
import pytest
from sklearn.datasets import make_blobs, make_classification, make_moons
from sklearn.externals import _arff as arff
from sklearn.utils import check_random_state, Bunch
Expand Down
9 changes: 9 additions & 0 deletions sklearn_seco/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def log2(x: float) -> float:
return math.log2(x) if x > 0 else 0


def pairwise(iterable):
    """Return an iterator over overlapping pairs of consecutive items.

    s -> (s0,s1), (s1,s2), (s2, s3), ...

    An input with fewer than two items yields no pairs.

    :param iterable: Any iterable; it is consumed lazily.
    :return: An iterator of 2-tuples `(item, successor)`.
    """
    # copied from itertools docs
    from itertools import tee
    a, b = tee(iterable)
    # Advance the second copy by one so that zip pairs every item with its
    # successor; the `None` default keeps empty input from raising.
    next(b, None)
    return zip(a, b)


def build_categorical_mask(which_features, n_features: int
) -> np.ndarray or None:
""":return: A mask array of length `n_features` based on `which_features`.
Expand Down

0 comments on commit 2717572

Please sign in to comment.