Skip to content

Commit

Permalink
apply weight recalculation when using L1/L2-SVM dual solver
Browse files Browse the repository at this point in the history
  • Loading branch information
maclin726 committed Jan 7, 2025
1 parent 6d955f2 commit 450cfea
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions libmultilabel/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import scipy.sparse as sparse
from liblinear.liblinearutil import train
from liblinear.liblinearutil import train, problem, parameter
from tqdm import tqdm

__all__ = [
Expand Down Expand Up @@ -333,8 +333,11 @@ def _do_train(y: np.ndarray, x: sparse.csr_matrix, options: str) -> np.matrix:
if y.shape[0] == 0:
return np.matrix(np.zeros((x.shape[1], 1)))

prob = problem(y, x)
param = parameter(options)
param.w_recalc = True # only works for solving L1/L2-SVM dual
with silent_stderr():
model = train(y, x, options)
model = train(prob, param)

w = np.ctypeslib.as_array(model.w, (x.shape[1], 1))
w = np.asmatrix(w)
Expand Down Expand Up @@ -592,8 +595,11 @@ def train_binary_and_multiclass(
Invalid dataset. Only multi-class dataset is allowed."""
y = np.squeeze(nonzero_label_ids)

prob = problem(2 * y - 1, x)
param = parameter(options)
param.w_recalc = False
with silent_stderr():
model = train(y, x, options)
model = train(prob, param)

# Labels appeared in training set; length may be smaller than num_labels
train_labels = np.array(model.get_labels(), dtype="int")
Expand Down

0 comments on commit 450cfea

Please sign in to comment.