Skip to content

Commit

Permalink
[python-package] add type hints on Dataset constructors (#5458)
Browse files Browse the repository at this point in the history
* [python-package] add type hints on Dataset constructors

* fix __init_from_list_np2d() hint

* add return type

* define a DatasetHandle type
  • Loading branch information
jameslamb authored Sep 3, 2022
1 parent d0ea321 commit c8712a9
Showing 1 changed file with 56 additions and 12 deletions.
68 changes: 56 additions & 12 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
from .libpath import find_lib_path

_DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]

Expand Down Expand Up @@ -1196,10 +1197,19 @@ def current_iteration(self) -> int:
class Dataset:
"""Dataset in LightGBM."""

def __init__(self, data, label=None, reference=None,
weight=None, group=None, init_score=None,
feature_name='auto', categorical_feature='auto', params=None,
free_raw_data=True):
def __init__(
self,
data,
label=None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
init_score=None,
feature_name='auto',
categorical_feature='auto',
params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True
):
"""Initialize Dataset.
Parameters
Expand Down Expand Up @@ -1488,9 +1498,19 @@ def _set_init_score_by_predictor(self, predictor, data, used_indices=None):
return self
self.set_init_score(init_score)

def _lazy_init(self, data, label=None, reference=None,
weight=None, group=None, init_score=None, predictor=None,
feature_name='auto', categorical_feature='auto', params=None):
def _lazy_init(
self,
data,
label=None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
init_score=None,
predictor=None,
feature_name='auto',
categorical_feature='auto',
params: Optional[Dict[str, Any]] = None
) -> "Dataset":
if data is None:
self.handle = None
return self
Expand Down Expand Up @@ -1635,7 +1655,11 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr

return filtered, filtered_idx

def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'] = None):
def __init_from_seqs(
self,
seqs: List[Sequence],
ref_dataset: Optional["Dataset"] = None
) -> "Dataset":
"""
Initialize data from list of Sequence objects.
Expand Down Expand Up @@ -1664,7 +1688,12 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'
self._push_rows(seq[start:end])
return self

def __init_from_np2d(self, mat, params_str, ref_dataset):
def __init_from_np2d(
self,
mat: np.ndarray,
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a 2-D numpy matrix."""
if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional')
Expand All @@ -1687,7 +1716,12 @@ def __init_from_np2d(self, mat, params_str, ref_dataset):
ctypes.byref(self.handle)))
return self

def __init_from_list_np2d(self, mats, params_str, ref_dataset):
def __init_from_list_np2d(
self,
mats: List[np.ndarray],
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a list of 2-D numpy matrices."""
ncol = mats[0].shape[1]
nrow = np.empty((len(mats),), np.int32)
Expand Down Expand Up @@ -1733,7 +1767,12 @@ def __init_from_list_np2d(self, mats, params_str, ref_dataset):
ctypes.byref(self.handle)))
return self

def __init_from_csr(self, csr, params_str, ref_dataset):
def __init_from_csr(
self,
csr: scipy.sparse.csr_matrix,
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a CSR matrix."""
if len(csr.indices) != len(csr.data):
raise ValueError(f'Length mismatch: {len(csr.indices)} vs {len(csr.data)}')
Expand All @@ -1759,7 +1798,12 @@ def __init_from_csr(self, csr, params_str, ref_dataset):
ctypes.byref(self.handle)))
return self

def __init_from_csc(self, csc, params_str, ref_dataset):
def __init_from_csc(
self,
csc: scipy.sparse.csc_matrix,
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a CSC matrix."""
if len(csc.indices) != len(csc.data):
raise ValueError(f'Length mismatch: {len(csc.indices)} vs {len(csc.data)}')
Expand Down

0 comments on commit c8712a9

Please sign in to comment.