Skip to content

Commit

Permalink
Fix unexpected behavior of .set_index() since pandas 0.21.0 (#1723)
Browse files Browse the repository at this point in the history
* fix set_index behavior using pandas 0.21.0

* review comments
  • Loading branch information
benbovy authored and Joe Hamman committed Nov 17, 2017
1 parent 8267fdb commit 1a01208
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
13 changes: 13 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ What's New
.. _whats-new.0.10.0:

v0.10.0 (unreleased)
--------------------

Bug fixes
~~~~~~~~~

- Fixed unexpected behavior in ``Dataset.set_index()`` and
``DataArray.set_index()`` introduced by Pandas 0.21.0. Setting a new
index with a single variable resulted in 1-level
``pandas.MultiIndex`` instead of a simple ``pandas.Index``
(:issue:`1722`). By `Benoit Bovy <https://github.com/benbovy>`_.


v0.10.0 rc2 (13 November 2017)
------------------------------

Expand Down
27 changes: 17 additions & 10 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ def merge_indexes(
names, labels, levels = [], [], []
current_index_variable = variables.get(dim)

for n in var_names:
var = variables[n]
if (current_index_variable is not None and
var.dims != current_index_variable.dims):
raise ValueError(
"dimension mismatch between %r %s and %r %s"
% (dim, current_index_variable.dims, n, var.dims))

if current_index_variable is not None and append:
current_index = current_index_variable.to_index()
if isinstance(current_index, pd.MultiIndex):
Expand All @@ -148,20 +156,19 @@ def merge_indexes(
labels.append(cat.codes)
levels.append(cat.categories)

for n in var_names:
names.append(n)
var = variables[n]
if ((current_index_variable is not None) and
(var.dims != current_index_variable.dims)):
raise ValueError(
"dimension mismatch between %r %s and %r %s"
% (dim, current_index_variable.dims, n, var.dims))
else:
if not len(names) and len(var_names) == 1:
idx = pd.Index(variables[var_names[0]].values)

else:
for n in var_names:
names.append(n)
var = variables[n]
cat = pd.Categorical(var.values, ordered=True)
labels.append(cat.codes)
levels.append(cat.categories)

idx = pd.MultiIndex(labels=labels, levels=levels, names=names)
idx = pd.MultiIndex(labels=labels, levels=levels, names=names)

vars_to_replace[dim] = IndexVariable(dim, idx)
vars_to_remove.extend(var_names)

Expand Down
6 changes: 6 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,6 +2007,12 @@ def test_set_index(self):
ds.set_index(x=mindex.names, inplace=True)
self.assertDatasetIdentical(ds, expected)

# ensure set_index with no existing index and a single data var given
# doesn't return multi-index
ds = Dataset(data_vars={'x_var': ('x', [0, 1, 2])})
expected = Dataset(coords={'x': [0, 1, 2]})
self.assertDataArrayIdentical(ds.set_index(x='x_var'), expected)

def test_reset_index(self):
ds = create_test_multiindex()
mindex = ds['x'].to_index()
Expand Down

0 comments on commit 1a01208

Please sign in to comment.