From 450cfeaea742677dfee484ed331a7ed68a83a80e Mon Sep 17 00:00:00 2001 From: maclin726 Date: Tue, 7 Jan 2025 11:27:43 +0000 Subject: [PATCH] apply weight recalculation when using L1/L2-SVM dual solver --- libmultilabel/linear/linear.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/libmultilabel/linear/linear.py b/libmultilabel/linear/linear.py index 4bf839c6..6d1ff897 100644 --- a/libmultilabel/linear/linear.py +++ b/libmultilabel/linear/linear.py @@ -5,7 +5,7 @@ import numpy as np import scipy.sparse as sparse -from liblinear.liblinearutil import train +from liblinear.liblinearutil import train, problem, parameter from tqdm import tqdm __all__ = [ @@ -333,8 +333,11 @@ def _do_train(y: np.ndarray, x: sparse.csr_matrix, options: str) -> np.matrix: if y.shape[0] == 0: return np.matrix(np.zeros((x.shape[1], 1))) + prob = problem(y, x) + param = parameter(options) + param.w_recalc = True # only works for solving L1/L2-SVM dual with silent_stderr(): - model = train(y, x, options) + model = train(prob, param) w = np.ctypeslib.as_array(model.w, (x.shape[1], 1)) w = np.asmatrix(w) @@ -592,8 +595,11 @@ def train_binary_and_multiclass( Invalid dataset. Only multi-class dataset is allowed.""" y = np.squeeze(nonzero_label_ids) + prob = problem(2 * y - 1, x) + param = parameter(options) + param.w_recalc = False with silent_stderr(): - model = train(y, x, options) + model = train(prob, param) # Labels appeared in training set; length may be smaller than num_labels train_labels = np.array(model.get_labels(), dtype="int")