From 3d5220a9428dced3ce7f69149e4027c6444a3b7a Mon Sep 17 00:00:00 2001 From: Martin Fleischmann Date: Fri, 12 Nov 2021 10:57:12 +0000 Subject: [PATCH] BUG: cuML non-weighted plot fix --- clustergram/clustergram.py | 2 +- clustergram/test_clustergram.py | 104 ++++++++++++++------------------ 2 files changed, 47 insertions(+), 59 deletions(-) diff --git a/clustergram/clustergram.py b/clustergram/clustergram.py index 31816a3..cee97d8 100644 --- a/clustergram/clustergram.py +++ b/clustergram/clustergram.py @@ -756,7 +756,7 @@ def _compute_means_cuml(self): for n in self.k_range: means = self.cluster_centers[n].mean(axis=1) - if isinstance(means, (cp.core.core.ndarray, np.ndarray)): + if isinstance(means, (cp.ndarray, np.ndarray)): self.plot_data[n] = means.take(self.labels[n].values) self.link[n] = dict(zip(means.tolist(), range(n))) else: diff --git a/clustergram/test_clustergram.py b/clustergram/test_clustergram.py index 3b49467..a61b850 100644 --- a/clustergram/test_clustergram.py +++ b/clustergram/test_clustergram.py @@ -225,13 +225,13 @@ def test_cuml_kmeans(): assert clustergram.labels.notna().all().all() expected = [ - 3.7674055099487305, - 2.7064273357391357, - 3.451129913330078, - 4.223802089691162, - 4.125243663787842, - 2.953890800476074, - 3.4818685054779053, + 0.9148379012942314, + 1.0465015769004822, + 0.9405179619789124, + 0.8763175010681152, + 1.5546628013253212, + 1.2617384965221086, + 0.7542384501014437, ] assert expected == [ pytest.approx(float(clustergram.cluster_centers[x].mean().mean()), rel=1e-6) @@ -247,10 +247,10 @@ def test_cuml_kmeans(): ax.get_geometry() == (1, 1, 1) assert clustergram.plot_data_pca.mean().mean() == pytest.approx( - 1.1016593594032404, rel=1e-10 + 1.344412697695078, rel=1e-10 ) assert clustergram.plot_data.mean().mean() == pytest.approx( - 3.7674053507191796, rel=1e-10 + 0.9148379244974681, rel=1e-10 ) # cupy array @@ -265,13 +265,13 @@ def test_cuml_kmeans(): assert clustergram.labels.notna().all().all() expected = [ - 3.7674055099487305, - 2.7064273357391357, - 3.451129913330078, - 4.223802089691162, - 4.125243663787842, - 2.953890800476074, - 3.4818685054779053, + 0.9148379012942314, + 1.0465015769004822, + 0.9405179619789124, + 0.8763175010681152, + 1.5546628013253212, + 1.2617384965221086, + 0.7542384501014437, ] assert expected == [ pytest.approx(float(cp.mean(clustergram.cluster_centers[x])), rel=1e-6) @@ -287,10 +287,10 @@ def test_cuml_kmeans(): ax.get_geometry() == (1, 1, 1) assert clustergram.plot_data_pca.mean().mean() == pytest.approx( - 1.1016593081610544, rel=1e-6 + 1.344412697695078, rel=1e-6 ) assert clustergram.plot_data.mean().mean() == pytest.approx( - 3.7674053737095425, rel=1e-6 + 0.9148379244974681, rel=1e-6 ) @@ -431,17 +431,11 @@ def test_silhouette_score_cuml(): pd.testing.assert_series_equal( clustergram.silhouette_score(), pd.Series( - [ - 0.7494349479675293, - 0.9806153178215027, - 0.6721830368041992, - 0.39418715238571167, - 0.44574037194252014, - 0.08033210784196854, - ], + [0.5359467, 0.5933514, 0.7809184, 0.8807362, 0.68701756, 0.4919311], index=list(range(2, 8)), name="silhouette_score", ), + check_dtype=False, ) clustergram = Clustergram(range(1, 8), backend="cuML", random_state=random_state) @@ -450,17 +444,11 @@ def test_silhouette_score_cuml(): pd.testing.assert_series_equal( clustergram.silhouette_score(), pd.Series( - [ - 0.7494349479675293, - 0.9806153178215027, - 0.6721830368041992, - 0.39418715238571167, - 0.44574037194252014, - 0.08033210784196854, - ], + [0.5359467, 0.5933514, 0.7809184, 0.8807362, 0.68701756, 0.4919311], index=list(range(2, 8)), name="silhouette_score", ), + check_dtype=False, ) @@ -529,12 +517,12 @@ def test_calinski_harabasz_score_cuml(): clustergram.calinski_harabasz_score(), pd.Series( [ - 25.619150510634366, - 15374.042816067375, - 10813.16845006968, - 8818.1163716754, - 8070.657293970755, - 7259.89764652579, + 14.884236661408588, + 18.993060869559063, + 25.53897801880369, + 10495.855575243557, + 10895.935616041483, + 10449.035861758717, ], index=list(range(2, 8)), name="calinski_harabasz_score", @@ -548,12 +536,12 @@ def test_calinski_harabasz_score_cuml(): clustergram.calinski_harabasz_score(), pd.Series( [ - 25.619150510634366, - 15374.042816067375, - 10813.16845006968, - 8818.1163716754, - 8070.657293970755, - 7259.89764652579, + 14.884236661408588, + 18.993060869559063, + 25.53897801880369, + 10495.855575243557, + 10895.935616041483, + 10449.035861758717, ], index=list(range(2, 8)), name="calinski_harabasz_score", @@ -626,12 +614,12 @@ def test_davies_bouldin_score_cuml(): clustergram.davies_bouldin_score(), pd.Series( [ - 0.3107512701086121, - 0.02263161666570639, - 0.2261582258142144, - 0.3839688146565784, - 0.13388392354928222, - 0.279734367840293, + 0.67477383902307, + 0.7673811855139047, + 0.4520342597085474, + 0.02258593626130912, + 0.01451002792630246, + 0.00967011650130667, ], index=list(range(2, 8)), name="davies_bouldin_score", @@ -645,12 +633,12 @@ def test_davies_bouldin_score_cuml(): clustergram.davies_bouldin_score(), pd.Series( [ - 0.3107512701086121, - 0.02263161666570639, - 0.2261582258142144, - 0.3839688146565784, - 0.13388392354928222, - 0.279734367840293, + 0.67477383902307, + 0.7673811855139047, + 0.4520342597085474, + 0.02258593626130912, + 0.01451002792630246, + 0.00967011650130667, ], index=list(range(2, 8)), name="davies_bouldin_score",