Skip to content

Commit

Permalink
Merge branch 'fix-free' of https://github.com/braindatalab/BSI-Zoo in…
Browse files Browse the repository at this point in the history
…to fix-free
  • Loading branch information
anujanegi committed Jan 2, 2025
2 parents ade971e + d7ed170 commit 7e5c7d4
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions bsi_zoo/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,8 @@ def gprime(w):

alpha = alpha * alpha_max

# eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False)

# y->M
# L->gain
x = _solve_reweighted_lasso(
Expand All @@ -618,6 +620,7 @@ def iterative_L2(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweightin
for solving the following problem:
x^(k+1) <-- argmin_x ||y - Lx||^2_Fro + alpha * sum_i w_i^(k)|x_i|
Parameters
----------
L : array, shape (n_sensors, n_sources)
Expand Down Expand Up @@ -652,7 +655,12 @@ def gprime(w):
grp_norm2 = groups_norm2(w.copy(), n_orient)
return np.repeat(grp_norm2, n_orient).ravel() + eps

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

x = _solve_reweighted_lasso(
Expand Down Expand Up @@ -711,7 +719,12 @@ def g(w):
def gprime(w):
return 2.0 * np.repeat(g(w), n_orient).ravel()

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

x = _solve_reweighted_lasso(
Expand Down Expand Up @@ -796,7 +809,12 @@ def iterative_L1_typeII(
n_sensors, n_sources = L.shape
weights = np.ones(n_sources)

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

if isinstance(cov, float):
Expand Down Expand Up @@ -895,7 +913,12 @@ def iterative_L2_typeII(
n_sensors, n_sources = L.shape
weights = np.ones(n_sources)

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

if isinstance(cov, float):
Expand Down

0 comments on commit 7e5c7d4

Please sign in to comment.