MISC renamed n_iterations to n_iter in all other places.
amueller committed Oct 10, 2012
1 parent 0a81827 commit 9bd0aad
Showing 7 changed files with 44 additions and 33 deletions.
18 changes: 9 additions & 9 deletions benchmarks/bench_plot_svd.py
@@ -12,7 +12,7 @@
 from sklearn.datasets.samples_generator import make_low_rank_matrix


-def compute_bench(samples_range, features_range, n_iterations=3, rank=50):
+def compute_bench(samples_range, features_range, n_iter=3, rank=50):

     it = 0

@@ -36,19 +36,19 @@ def compute_bench(samples_range, features_range, n_iterations=3, rank=50):
             results['scipy svd'].append(time() - tstart)

             gc.collect()
-            print "benching scikit-learn randomized_svd: n_iterations=0"
+            print "benching scikit-learn randomized_svd: n_iter=0"
             tstart = time()
-            randomized_svd(X, rank, n_iterations=0)
-            results['scikit-learn randomized_svd (n_iterations=0)'].append(
+            randomized_svd(X, rank, n_iter=0)
+            results['scikit-learn randomized_svd (n_iter=0)'].append(
                 time() - tstart)

             gc.collect()
-            print ("benching scikit-learn randomized_svd: n_iterations=%d "
-                   % n_iterations)
+            print ("benching scikit-learn randomized_svd: n_iter=%d "
+                   % n_iter)
             tstart = time()
-            randomized_svd(X, rank, n_iterations=n_iterations)
-            results['scikit-learn randomized_svd (n_iterations=%d)'
-                    % n_iterations].append(time() - tstart)
+            randomized_svd(X, rank, n_iter=n_iter)
+            results['scikit-learn randomized_svd (n_iter=%d)'
+                    % n_iter].append(time() - tstart)

     return results
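The benchmark above times each decomposition with a bare gc.collect()/time() idiom. A minimal sketch of that pattern in isolation, using a hypothetical input shape rather than the benchmark's actual ranges:

    import gc
    from time import time

    import numpy as np
    from sklearn.utils.extmath import randomized_svd

    X = np.random.RandomState(0).rand(500, 200)  # hypothetical input matrix

    gc.collect()                     # collect first so the timing excludes GC pauses
    tstart = time()
    randomized_svd(X, 50, n_iter=3)  # the renamed keyword in action
    print "done in %0.3fs" % (time() - tstart)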
2 changes: 1 addition & 1 deletion examples/applications/wikipedia_principal_eigenvector.py
@@ -172,7 +172,7 @@ def get_adjacency_matrix(redirects_filename, page_links_filename, limit=None):

 print "Computing the principal singular vectors using randomized_svd"
 t0 = time()
-U, s, V = randomized_svd(X, 5, n_iterations=3)
+U, s, V = randomized_svd(X, 5, n_iter=3)
 print "done in %0.3fs" % (time() - t0)

 # print the names of the wikipedia related strongest compenents of the the
2 changes: 1 addition & 1 deletion examples/svm/plot_svm_scale_c.py
@@ -128,7 +128,7 @@
     # reduce the variance
     grid = GridSearchCV(clf, refit=False, param_grid=param_grid,
                         cv=ShuffleSplit(n=n_samples, train_size=train_size,
-                                        n_iterations=250, random_state=1))
+                                        n_iter=250, random_state=1))
     grid.fit(X, y)
     scores = [x[1] for x in grid.grid_scores_]
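For context, ShuffleSplit is the cross-validation iterator whose keyword is being renamed here. A hedged usage sketch, assuming the sklearn.cross_validation API of this era, where the iterator is consumed directly:

    from sklearn.cross_validation import ShuffleSplit

    # n_iter (formerly n_iterations) sets how many random train/test
    # splits the iterator yields.
    cv = ShuffleSplit(n=100, train_size=0.8, n_iter=5, random_state=0)
    print len(list(cv))  # 5 random splits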
14 changes: 7 additions & 7 deletions sklearn/cluster/k_means_.py
@@ -897,7 +897,7 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
     return inertia, squared_diff


-def _mini_batch_convergence(model, iteration_idx, n_iterations, tol,
+def _mini_batch_convergence(model, iteration_idx, n_iter, tol,
                             n_samples, centers_squared_diff, batch_inertia,
                             context, verbose=0):
     """Helper function to encapsulte the early stopping logic"""
@@ -926,7 +926,7 @@ def _mini_batch_convergence(model, iteration_idx, n_iterations, tol,
         progress_msg = (
             'Minibatch iteration %d/%d:'
             'mean batch inertia: %f, ewa inertia: %f ' % (
-                iteration_idx + 1, n_iterations, batch_inertia,
+                iteration_idx + 1, n_iter, batch_inertia,
                 ewa_inertia))
         print progress_msg

@@ -935,7 +935,7 @@ def _mini_batch_convergence(model, iteration_idx, n_iterations, tol,
     if tol > 0.0 and ewa_diff < tol:
         if verbose:
             print 'Converged (small centers change) at iteration %d/%d' % (
-                iteration_idx + 1, n_iterations)
+                iteration_idx + 1, n_iter)
         return True

     # Early stopping heuristic due to lack of improvement on smoothed inertia
@@ -952,7 +952,7 @@ def _mini_batch_convergence(model, iteration_idx, n_iterations, tol,
         if verbose:
             print ('Converged (lack of improvement in inertia)'
                    ' at iteration %d/%d' % (
-                       iteration_idx + 1, n_iterations))
+                       iteration_idx + 1, n_iter))
         return True

     # update the convergence context to maintain state across sucessive calls:
@@ -1102,7 +1102,7 @@ def fit(self, X, y=None):

         distances = np.zeros(self.batch_size, dtype=np.float64)
         n_batches = int(np.ceil(float(n_samples) / self.batch_size))
-        n_iterations = int(self.max_iter * n_batches)
+        n_iter = int(self.max_iter * n_batches)

         init_size = self.init_size
         if init_size is None:
@@ -1158,7 +1158,7 @@ def fit(self, X, y=None):

         # Perform the iterative optimization untill the final convergence
         # criterion
-        for iteration_idx in xrange(n_iterations):
+        for iteration_idx in xrange(n_iter):

             # Sample the minibatch from the full dataset
             minibatch_indices = self.random_state.random_integers(
@@ -1172,7 +1172,7 @@ def fit(self, X, y=None):

             # Monitor the convergence and do early stopping if necessary
             if _mini_batch_convergence(
-                    self, iteration_idx, n_iterations, tol, n_samples,
+                    self, iteration_idx, n_iter, tol, n_samples,
                     centers_squared_diff, batch_inertia, convergence_context,
                     verbose=self.verbose):
                 break
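The renamed local variable in fit is simply the total minibatch step budget. A worked example of that arithmetic with hypothetical sizes, mirroring the two lines changed above:

    import numpy as np

    n_samples, batch_size, max_iter = 10000, 100, 10
    n_batches = int(np.ceil(float(n_samples) / batch_size))  # 100 minibatches per pass over the data
    n_iter = int(max_iter * n_batches)                        # 1000 minibatch steps in total
    print n_iter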
2 changes: 1 addition & 1 deletion sklearn/decomposition/pca.py
@@ -485,7 +485,7 @@ def fit(self, X, y=None):
         n_components = self.n_components

         U, S, V = randomized_svd(X, n_components,
-                                 n_iterations=self.iterated_power,
+                                 n_iter=self.iterated_power,
                                  random_state=self.random_state)

         self.explained_variance_ = exp_var = (S ** 2) / n_samples
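Only the internal call changes here: the public RandomizedPCA parameter keeps its iterated_power name and is now forwarded to randomized_svd as n_iter. A minimal usage sketch, assuming the 0.12-era estimator:

    import numpy as np
    from sklearn.decomposition import RandomizedPCA

    X = np.random.RandomState(0).rand(20, 5)
    # iterated_power is passed through to randomized_svd as n_iter.
    pca = RandomizedPCA(n_components=2, iterated_power=3, random_state=0).fit(X)
    print pca.explained_variance_ratio_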
25 changes: 18 additions & 7 deletions sklearn/utils/extmath.py
@@ -4,6 +4,7 @@
 # Authors: G. Varoquaux, A. Gramfort, A. Passos, O. Grisel
 # License: BSD

+import warnings
 import numpy as np
 from scipy import linalg

@@ -78,7 +79,8 @@ def safe_sparse_dot(a, b, dense_output=False):
     return np.dot(a, b)


-def randomized_range_finder(A, size, n_iterations, random_state=None):
+def randomized_range_finder(A, size, n_iter, random_state=None,
+                            n_iterations=None):
     """Computes an orthonormal matrix whose range approximates the range of A.

     Parameters
@@ -87,7 +89,7 @@ def randomized_range_finder(A, size, n_iterations, random_state=None):
         The input data matrix
     size: integer
         Size of the return array
-    n_iterations: integer
+    n_iter: integer
         Number of power iterations used to stabilize the result
     random_state: RandomState or an int seed (0 by default)
         A random number generator instance
@@ -106,6 +108,10 @@ def randomized_range_finder(A, size, n_iterations, random_state=None):
       approximate matrix decompositions
       Halko, et al., 2009 (arXiv:909) http://arxiv.org/pdf/0909.4061
     """
+    if n_iterations is not None:
+        warnings.warn("n_iterations was renamed to n_iter for consistency "
+                      "and will be removed in 0.16.", DeprecationWarning)
+        n_iter = n_iterations
     random_state = check_random_state(random_state)

     # generating random gaussian vectors r with shape: (A.shape[1], size)
@@ -117,16 +123,16 @@ def randomized_range_finder(A, size, n_iterations, random_state=None):

     # perform power iterations with Y to further 'imprint' the top
     # singular vectors of A in Y
-    for i in xrange(n_iterations):
+    for i in xrange(n_iter):
         Y = safe_sparse_dot(A, safe_sparse_dot(A.T, Y))

     # extracting an orthonormal basis of the A range samples
     Q, R = qr_economic(Y)
     return Q


-def randomized_svd(M, n_components, n_oversamples=10, n_iterations=0,
-                   transpose='auto', random_state=0):
+def randomized_svd(M, n_components, n_oversamples=10, n_iter=0,
+                   transpose='auto', random_state=0, n_iterations=None):
     """Computes a truncated randomized SVD

     Parameters
@@ -142,7 +148,7 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iterations=0,
        to ensure proper conditioning. The total number of random vectors
        used to find the range of M is n_components + n_oversamples.

-    n_iterations: int (default is 0)
+    n_iter: int (default is 0)
        Number of power iterations (can be used to deal with very noisy
        problems).

@@ -172,6 +178,11 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iterations=0,
     * A randomized algorithm for the decomposition of matrices
       Per-Gunnar Martinsson, Vladimir Rokhlin and Mark Tygert
     """
+    if n_iterations is not None:
+        warnings.warn("n_iterations was renamed to n_iter for consistency "
+                      "and will be removed in 0.16.", DeprecationWarning)
+        n_iter = n_iterations
+
     random_state = check_random_state(random_state)
     n_random = n_components + n_oversamples
     n_samples, n_features = M.shape
@@ -182,7 +193,7 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iterations=0,
         # this implementation is a bit faster with smaller shape[1]
         M = M.T

-    Q = randomized_range_finder(M, n_random, n_iterations, random_state)
+    Q = randomized_range_finder(M, n_random, n_iter, random_state)

     # project M to the (k + p) dimensional space using the basis vectors
     B = safe_sparse_dot(Q.T, M)
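A minimal sketch of how the backward-compatibility shim above behaves: the old keyword is still accepted but emits a DeprecationWarning, and (with a fixed random_state) both spellings produce the same result:

    import warnings

    import numpy as np
    from sklearn.utils.extmath import randomized_svd

    X = np.random.RandomState(42).rand(40, 20)

    # New spelling:
    U1, s1, V1 = randomized_svd(X, 5, n_iter=2, random_state=0)

    # Old spelling: still works until its removal in 0.16, but warns.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        U2, s2, V2 = randomized_svd(X, 5, n_iterations=2, random_state=0)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    assert np.allclose(s1, s2)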
14 changes: 7 additions & 7 deletions sklearn/utils/tests/test_svd.py
@@ -69,14 +69,14 @@ def test_randomized_svd_low_rank_with_noise():

     # compute the singular values of X using the fast approximate method
     # without the iterated power method
-    _, sa, _ = randomized_svd(X, k, n_iterations=0)
+    _, sa, _ = randomized_svd(X, k, n_iter=0)

     # the approximation does not tolerate the noise:
     assert_greater(np.abs(s[:k] - sa).max(), 0.05)

     # compute the singular values of X using the fast approximate method with
     # iterated power method
-    _, sap, _ = randomized_svd(X, k, n_iterations=5)
+    _, sap, _ = randomized_svd(X, k, n_iter=5)

     # the iterated power method is helping getting rid of the noise:
     assert_almost_equal(s[:k], sap, decimal=3)

@@ -100,14 +100,14 @@ def test_randomized_svd_infinite_rank():

     # compute the singular values of X using the fast approximate method
     # without the iterated power method
-    _, sa, _ = randomized_svd(X, k, n_iterations=0)
+    _, sa, _ = randomized_svd(X, k, n_iter=0)

     # the approximation does not tolerate the noise:
     assert_greater(np.abs(s[:k] - sa).max(), 0.1)

     # compute the singular values of X using the fast approximate method with
     # iterated power method
-    _, sap, _ = randomized_svd(X, k, n_iterations=5)
+    _, sap, _ = randomized_svd(X, k, n_iter=5)

     # the iterated power method is still managing to get most of the structure
     # at the requested rank

@@ -125,11 +125,11 @@ def test_randomized_svd_transpose_consistency():
         effective_rank=rank, tail_strength=0.5, random_state=0)
     assert_equal(X.shape, (n_samples, n_features))

-    U1, s1, V1 = randomized_svd(X, k, n_iterations=3, transpose=False,
+    U1, s1, V1 = randomized_svd(X, k, n_iter=3, transpose=False,
                                 random_state=0)
-    U2, s2, V2 = randomized_svd(X, k, n_iterations=3, transpose=True,
+    U2, s2, V2 = randomized_svd(X, k, n_iter=3, transpose=True,
                                 random_state=0)
-    U3, s3, V3 = randomized_svd(X, k, n_iterations=3, transpose='auto',
+    U3, s3, V3 = randomized_svd(X, k, n_iter=3, transpose='auto',
                                 random_state=0)
     U4, s4, V4 = linalg.svd(X, full_matrices=False)
