Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update datatable usage #4123

Merged
merged 1 commit into from
Feb 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ class DataFrame(object):

# dt
try:
from datatable import DataTable
import datatable
if hasattr(datatable, "Frame"):
DataTable = datatable.Frame
else:
DataTable = datatable.DataTable
DT_INSTALLED = True
except ImportError:

Expand Down
32 changes: 18 additions & 14 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ def _maybe_dt_data(data, feature_names, feature_types):
return data, feature_names, feature_types

data_types_names = tuple(lt.name for lt in data.ltypes)
if not all(type_name in DT_TYPE_MAPPER for type_name in data_types_names):
bad_fields = [data.names[i] for i, type_name in
enumerate(data_types_names) if type_name not in DT_TYPE_MAPPER]

bad_fields = [data.names[i]
for i, type_name in enumerate(data_types_names)
if type_name not in DT_TYPE_MAPPER]
if bad_fields:
msg = """DataFrame.types for data must be int, float or bool.
Did not expect the data types in fields """
raise ValueError(msg + ', '.join(bad_fields))
Expand All @@ -317,7 +317,7 @@ def _maybe_dt_array(array):

# below requires new dt version
# extract first column
array = array.tonumpy()[:, 0].astype('float')
array = array.to_numpy()[:, 0].astype('float')

return array

Expand All @@ -340,7 +340,7 @@ def __init__(self, data, label=None, missing=None,
"""
Parameters
----------
data : string/numpy array/scipy.sparse/pd.DataFrame/DataTable
data : string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
Data source of DMatrix.
When data is string type, it represents the path to a libsvm format txt file,
or binary file that xgboost can read from.
Expand Down Expand Up @@ -497,16 +497,20 @@ def _init_from_npy2d(self, mat, missing, nthread):

def _init_from_dt(self, data, nthread):
"""
Initialize data from a DataTable
Initialize data from a datatable Frame.
"""
cols = []
ptrs = (ctypes.c_void_p * data.ncols)()
for icol in range(data.ncols):
col = data.internal.column(icol)
cols.append(col)
# int64_t (void*)
ptr = col.data_pointer
ptrs[icol] = ctypes.c_void_p(ptr)
if hasattr(data, "internal") and hasattr(data.internal, "column"):
# datatable>0.8.0
for icol in range(data.ncols):
col = data.internal.column(icol)
ptr = col.data_pointer
ptrs[icol] = ctypes.c_void_p(ptr)
else:
# datatable<=0.8.0
from datatable.internal import frame_column_data_r
for icol in range(data.ncols):
ptrs[icol] = frame_column_data_r(data, icol)

# always return stypes for dt ingestion
feature_type_strings = (ctypes.c_char_p * data.ncols)()
Expand Down