diff --git a/src/klib/describe.py b/src/klib/describe.py index 5b1560f..8d95d6b 100644 --- a/src/klib/describe.py +++ b/src/klib/describe.py @@ -92,15 +92,16 @@ def cat_plot( lim_top, lim_bot = top, bottom if n_unique < top + bottom: - lim_top = int(n_unique // 2) - lim_bot = int(n_unique // 2) + 1 - - if n_unique <= 2: - lim_top = lim_bot = int(n_unique // 2) - + if bottom > top: + lim_top = int(n_unique // 2) if int(n_unique // 2) < top else top + lim_bot = n_unique - lim_top + else: + lim_bot = int(n_unique // 2) if int(n_unique // 2) < bottom else bottom + lim_top = n_unique - lim_bot + value_counts_top = value_counts[:lim_top] value_counts_idx_top = value_counts_top.index.tolist() - value_counts_bot = value_counts[-lim_bot:] + value_counts_bot = value_counts[-lim_bot:] if lim_bot > 0 else pd.DataFrame() value_counts_idx_bot = value_counts_bot.index.tolist() if top == 0: