diff --git a/pyproject.toml b/pyproject.toml index a040241..211dabf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ ignore = [ "PGH003", "PLR", "SIM102", + "SIM108", "TRY003", ] diff --git a/src/xlviews/dataframes/heat_frame.py b/src/xlviews/dataframes/heat_frame.py index e68a782..501313d 100644 --- a/src/xlviews/dataframes/heat_frame.py +++ b/src/xlviews/dataframes/heat_frame.py @@ -9,12 +9,15 @@ from xlviews.config import rcParams from xlviews.decorators import turn_off_screen_updating -from xlviews.range.style import set_alignment, set_font +from xlviews.range.formula import aggregate +from xlviews.range.style import set_alignment, set_border, set_color_scale, set_font +from xlviews.utils import rgb from .sheet_frame import SheetFrame +from .style import set_heat_frame_style if TYPE_CHECKING: - from xlwings import Range + from xlwings import Range, Sheet class HeatFrame(SheetFrame): @@ -26,11 +29,91 @@ def __init__( x: str, y: str, value: str, + vmin: float | None = None, + vmax: float | None = None, + sheet: Sheet | None = None, style: bool = True, autofit: bool = True, + font_size: int | None = None, **kwargs, ) -> None: df = data.pivot_table(value, y, x, aggfunc=lambda x: x) df.index.name = None - super().__init__(*args, data=df, index=True, style=False, **kwargs) + super().__init__(*args, data=df, index=True, sheet=sheet, style=False) + + if style: + set_heat_frame_style(self, autofit=autofit, font_size=font_size, **kwargs) + + self.set_adjacent_column_width(1, offset=-1) + + self.set_extrema(vmin, vmax) + self.set_colorbar() + set_color_scale(self.range(index=False), self.vmin, self.vmax) + + self.set_label(value) + + if autofit: + self.label.columns.autofit() + + @property + def vmin(self) -> Range: + return self.cell.offset(len(self), len(self.columns) + 1) + + @property + def vmax(self) -> Range: + return self.cell.offset(1, len(self.columns) + 1) + + @property + def label(self) -> Range: + return self.cell.offset(0, len(self.columns) + 1) + + def set_extrema( + self, + vmin: float | str | None = None, + vmax: float | str | None = None, + ) -> None: + rng = self.range(index=False) + + if vmin is None: + vmin = aggregate("min", rng, formula=True) + + if vmax is None: + vmax = aggregate("max", rng, formula=True) + + self.vmin.value = vmin + self.vmax.value = vmax + + def set_colorbar(self) -> None: + vmin = self.vmin.get_address() + vmax = self.vmax.get_address() + + col = self.vmax.column + start = self.vmax.row + end = self.vmin.row + n = end - start - 1 + for i in range(n): + value = f"={vmax}+{i + 1}*({vmin}-{vmax})/{n + 1}" + self.sheet.range(i + start + 1, col).value = value + + rng = self.sheet.range((start, col), (end, col)) + set_color_scale(rng, self.vmin, self.vmax) + set_font(rng, color=rgb("white"), size=rcParams["frame.font.size"]) + set_alignment(rng, horizontal_alignment="center") + ec = rcParams["frame.gray.border.color"] + set_border(rng, edge_weight=2, edge_color=ec, inside_weight=0) + + if n > 0: + rng = self.sheet.range((start + 1, col), (end - 1, col)) + set_font(rng, size=4) + + def set_label(self, label: str) -> None: + rng = self.label + rng.value = label + set_font(rng, bold=True, size=rcParams["frame.font.size"]) + set_alignment(rng, horizontal_alignment="center") + + def set_adjacent_column_width(self, width: float, offset: int = 1) -> None: + """Set the width of the adjacent empty column.""" + column = self.label.column + offset + self.sheet.range(1, column).column_width = width diff --git a/src/xlviews/dataframes/sheet_frame.py b/src/xlviews/dataframes/sheet_frame.py index cd8db83..536263f 100644 --- a/src/xlviews/dataframes/sheet_frame.py +++ b/src/xlviews/dataframes/sheet_frame.py @@ -173,9 +173,6 @@ def set_data( if style: self.set_style(gray=gray, autofit=autofit, font_size=font_size, **kwargs) - if self.head is None: - self.set_adjacent_column_width(1) - if self.name: book = self.sheet.book refers_to = "=" + self.cell.get_address(include_sheetname=True) @@ -1043,6 +1040,8 @@ def delete(self, direction: str = "up", *, entire: bool = False) -> None: def dist_frame(self, *args, **kwargs) -> DistFrame: from .dist_frame import DistFrame + self.set_adjacent_column_width(1) + self.dist = DistFrame(self, *args, **kwargs) return self.dist diff --git a/src/xlviews/dataframes/style.py b/src/xlviews/dataframes/style.py index 6276596..a270b07 100644 --- a/src/xlviews/dataframes/style.py +++ b/src/xlviews/dataframes/style.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from xlwings import Range + from .heat_frame import HeatFrame from .sheet_frame import SheetFrame from .table import Table @@ -215,3 +216,65 @@ def set_table_style( style.TableStyleElements(even_type).Interior.Color = even_color table.api.TableStyle = style + + +@turn_off_screen_updating +def set_heat_frame_style( + sf: HeatFrame, + *, + autofit: bool = False, + alignment: str | None = "center", + border: bool = True, + font: bool = True, + fill: bool = True, + font_size: int | None = None, +) -> None: + """Set style of SheetFrame. + + Args: + sf: The SheetFrame object. + autofit: Whether to autofit the frame. + alignment: The alignment of the frame. + border: Whether to draw the border. + font: Whether to specify the font. + fill: Whether to fill the frame. + font_size: The font size to specify directly. + """ + cell = sf.cell + sheet = sf.sheet + + set_style = partial( + _set_style, + border=border, + font=font, + fill=fill, + gray=False, + font_size=font_size, + ) + + index_level = sf.index_level + columns_level = sf.columns_level + length = len(sf) + + if index_level > 0: + start = cell.offset(columns_level, 0) + end = cell.offset(columns_level + length - 1, index_level - 1) + set_style(start, end, "index") + + width = len(sf.value_columns) + + start = cell.offset(columns_level - 1, index_level) + end = cell.offset(columns_level - 1, index_level + width - 1) + set_style(start, end, "index") + + start = cell.offset(columns_level, index_level) + end = cell.offset(columns_level + length - 1, index_level + width - 1) + set_style(start, end, "values") + + rng = sheet.range(cell, end) + + if autofit: + rng.columns.autofit() + + if alignment: + set_alignment(rng, alignment) diff --git a/src/xlviews/range/style.py b/src/xlviews/range/style.py index 2d69d80..583b57d 100644 --- a/src/xlviews/range/style.py +++ b/src/xlviews/range/style.py @@ -4,8 +4,14 @@ from typing import TYPE_CHECKING +import xlwings as xw from xlwings import Range, Sheet -from xlwings.constants import BordersIndex, FormatConditionType, LineStyle +from xlwings.constants import ( + BordersIndex, + ConditionValueTypes, + FormatConditionType, + LineStyle, +) from xlviews.config import rcParams from xlviews.utils import constant, rgb, set_font_api @@ -83,10 +89,14 @@ def set_fill(rng: Range | RangeCollection, color: int | str | None = None) -> No def set_font( rng: Range | RangeCollection, name: str | None = None, - **kwargs, + *, + size: float | None = None, + bold: bool | None = None, + italic: bool | None = None, + color: int | str | None = None, ) -> None: name = name or rcParams["frame.font.name"] - set_font_api(rng.api, name, **kwargs) + set_font_api(rng.api, name, size=size, bold=bold, italic=italic, color=color) def set_alignment( @@ -177,5 +187,34 @@ def address(r: Range) -> str: condition.Font.Italic = True -def hide_gridlines(sheet: Sheet) -> None: +def hide_gridlines(sheet: Sheet | None = None) -> None: + sheet = sheet or xw.sheets.active sheet.book.app.api.ActiveWindow.DisplayGridlines = False + + +def set_color_condition(rng: Range, values: list[str], colors: list[int]) -> None: + condition = rng.api.FormatConditions.AddColorScale(len(values)) + condition.SetFirstPriority() + + for k, (value, color) in enumerate(zip(values, colors, strict=True)): + criteria = condition.ColorScaleCriteria(k + 1) + criteria.Type = ConditionValueTypes.xlConditionValueNumber + criteria.Value = value + criteria.FormatColor.Color = color + + +def set_color_scale( + rng: Range, + vmin: float | str | Range, + vmax: float | str | Range, +) -> None: + if isinstance(vmin, Range): + vmin = vmin.get_address() + + if isinstance(vmax, Range): + vmax = vmax.get_address() + + values = [f"={vmin}", f"=({vmin} + {vmax}) / 2", f"={vmax}"] + colors = [rgb(130, 130, 255), rgb(80, 185, 80), rgb(255, 130, 130)] + + set_color_condition(rng, values, colors) diff --git a/tests/chart/axes/test_position.py b/tests/chart/axes/test_position.py index d789168..14829cd 100644 --- a/tests/chart/axes/test_position.py +++ b/tests/chart/axes/test_position.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( ("pos", "left", "top"), - [("right", 269.5, 18), ("inside", 134, 66), ("bottom", 52, 90)], + [("right", 312, 18), ("inside", 134, 66), ("bottom", 52, 90)], ) def test_set_first_position(sheet: Sheet, pos: str, left: float, top: float): from xlviews.chart.axes import ( diff --git a/tests/dataframes/heat_frame/__init__.py b/tests/dataframes/heat_frame/__init__.py deleted file mode 100644 index 3acffb8..0000000 --- a/tests/dataframes/heat_frame/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -if __name__ == "__main__": - from itertools import product - - import pandas as pd - import xlwings as xw - - from xlviews.dataframes.heat_frame import HeatFrame - from xlviews.dataframes.sheet_frame import SheetFrame - - for app in xw.apps: - app.quit() - - book = xw.Book() - sheet = book.sheets.add() - - values = list(product(range(1, 5), range(1, 4))) - df = pd.DataFrame(values, columns=["x", "y"]) - df["v"] = list(range(len(df))) - df = df.set_index(["x", "y"]) - sf = SheetFrame(2, 2, data=df, index=True) - - data = sf.get_address(["v"], formula=True) - hf = HeatFrame(2, 6, data=data, x="x", y="y", value="v") - - hf.range() diff --git a/tests/dataframes/test_heat_frame.py b/tests/dataframes/test_heat_frame.py new file mode 100644 index 0000000..e7d6c30 --- /dev/null +++ b/tests/dataframes/test_heat_frame.py @@ -0,0 +1,112 @@ +from itertools import product + +import numpy as np +import pytest +from pandas import DataFrame +from xlwings import Sheet + +from xlviews.dataframes.heat_frame import HeatFrame +from xlviews.dataframes.sheet_frame import SheetFrame +from xlviews.utils import is_excel_installed + +pytestmark = pytest.mark.skipif(not is_excel_installed(), reason="Excel not installed") + + +@pytest.fixture(scope="module") +def sf(sheet_module: Sheet): + values = list(product(range(1, 5), range(1, 7))) + df = DataFrame(values, columns=["x", "y"]) + df["v"] = list(range(len(df))) + df = df[(df["x"] + df["y"]) % 4 != 0] + df = df.set_index(["x", "y"]) + + sf = SheetFrame(2, 2, data=df, index=True, sheet=sheet_module) + data = sf.get_address(["v"], formula=True) + + return HeatFrame(2, 6, data=data, x="x", y="y", value="v", sheet=sheet_module) + + +def test_index(sf: HeatFrame): + assert sf.sheet.range("F3:F8").value == [1, 2, 3, 4, 5, 6] + + +def test_columns(sf: HeatFrame): + assert sf.sheet.range("G2:J2").value == [1, 2, 3, 4] + + +@pytest.mark.parametrize( + ("i", "value"), + [ + (3, [0, 6, None, 18]), + (4, [1, None, 13, 19]), + (5, [None, 8, 14, 20]), + (6, [3, 9, 15, None]), + (7, [4, 10, None, 22]), + (8, [5, None, 17, 23]), + ], +) +def test_values(sf: HeatFrame, i: int, value: int): + assert sf.sheet.range(f"G{i}:J{i}").value == value + + +def test_vmin(sf: HeatFrame): + assert sf.vmin.get_address() == "$L$8" + + +def test_vmax(sf: HeatFrame): + assert sf.vmax.get_address() == "$L$3" + + +def test_label(sf: HeatFrame): + assert sf.label.get_address() == "$L$2" + + +def test_label_value(sf: HeatFrame): + assert sf.label.value == "v" + + +@pytest.mark.parametrize( + ("i", "value"), + [ + (3, 23), + (4, 23 * 4 / 5), + (5, 23 * 3 / 5), + (6, 23 * 2 / 5), + (7, 23 / 5), + (8, 0), + ], +) +def test_colorbar(sf: HeatFrame, i: int, value: int): + v = sf.sheet.range(f"L{i}").value + assert isinstance(v, float) + np.testing.assert_allclose(v, value) + + +if __name__ == "__main__": + from itertools import product + + import xlwings as xw + from pandas import DataFrame + + from xlviews.dataframes.heat_frame import HeatFrame + from xlviews.dataframes.sheet_frame import SheetFrame + from xlviews.range.style import hide_gridlines + + for app in xw.apps: + app.quit() + + book = xw.Book() + sheet = book.sheets.add() + hide_gridlines() + + values = list(product(range(1, 5), range(1, 7))) + df = DataFrame(values, columns=["x", "y"]) + df["v"] = list(range(len(df))) + df = df[(df["x"] + df["y"]) % 4 != 0] + df = df.set_index(["x", "y"]) + sf = SheetFrame(2, 2, data=df, index=True) # type: ignore + + data = sf.get_address(["v"], formula=True) + hf = HeatFrame(2, 6, data=data, x="x", y="y", value="v") + + hf.range(index=False)