Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: fix a couple more type checking errors flagged by pyright #209

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 89 additions & 50 deletions src/cmasher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from itertools import chain
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, NewType
from typing import TYPE_CHECKING, NewType, overload

import matplotlib as mpl
import numpy as np
Expand All @@ -33,7 +33,7 @@
import os
import sys
from collections.abc import Callable, Iterator
from typing import Literal, Protocol, TypeAlias
from typing import Literal, Protocol, TypeAlias, TypeVar

from matplotlib.artist import Artist
from numpy.typing import NDArray
Expand All @@ -43,6 +43,9 @@
else:
from typing_extensions import Self

T = TypeVar("T", int, float)
RGB: TypeAlias = tuple[T, T, T]

class SupportsDunderLT(Protocol):
def __lt__(self, other: Self, /) -> bool: ...

Expand All @@ -51,6 +54,7 @@ def __gt__(self, other: Self, /) -> bool: ...

SupportsOrdering: TypeAlias = SupportsDunderLT | SupportsDunderGT


_HAS_VISCM = find_spec("viscm") is not None

# All declaration
Expand Down Expand Up @@ -78,12 +82,6 @@ def __gt__(self, other: Self, /) -> bool: ...
Category = NewType("Category", str)
Name = NewType("Name", str)

# Type aliases
RED: TypeAlias = float
GREEN: TypeAlias = float
BLUE: TypeAlias = float
RGB: TypeAlias = list[tuple[RED, GREEN, BLUE]]


# %% HELPER FUNCTIONS
# Define function for obtaining the sorting order for lightness ranking
Expand Down Expand Up @@ -131,17 +129,17 @@ def _get_cmap_lightness_rank(

# Get lightness values of colormap
lab = cspace_converter("sRGB1", "CAM02-UCS")(rgb)
L = lab[:, 0]
lightness = lab[:, 0]

# If cyclic colormap, add first L at the end
# If cyclic colormap, add first lightness at the end
if cm_type == "cyclic":
L = np.r_[L, [L[0]]]
lightness = np.r_[lightness, [lightness[0]]]

# Determine number of values that will be in deltas
N_deltas = len(L) - 1
N_deltas = len(lightness) - 1

# Determine the deltas of the lightness profile
deltas = np.diff(L)
deltas = np.diff(lightness)
derivs = N_deltas * deltas

# Set lightness profile type to 0
Expand All @@ -153,7 +151,7 @@ def _get_cmap_lightness_rank(
L_rmse = np.around(np.std(derivs), 0)

# Calculate starting lightness value
L_start = np.around(L[0], 0)
L_start = np.around(lightness[0], 0)

# Determine type of lightness profile
L_type += (not np.allclose(rgb[0], [0, 0, 0])) * 2
Expand All @@ -173,13 +171,13 @@ def _get_cmap_lightness_rank(
)

# Calculate central lightness value
L_start = np.around(np.average(L[central_i]), 0)
L_start = np.around(np.average(lightness[central_i]), 0)

# Determine lightness range
L_rng = np.around(np.max(L) - np.min(L), 0)
L_rng = np.around(np.max(lightness) - np.min(lightness), 0)

# Determine if cmap goes from dark to light or the opposite
L_slope = (L_start > L[-1]) * 2 - 1
L_slope = (L_start > lightness[-1]) * 2 - 1

# For qualitative/misc colormaps, set all lightness values to zero
else:
Expand Down Expand Up @@ -251,7 +249,7 @@ def _get_cmap_perceptual_rank(
# This function combines multiple colormaps at given nodes
def combine_cmaps(
*cmaps: Colormap | str,
nodes: list[float] | np.ndarray | None = None,
nodes: list[float] | NDArray[np.floating] | None = None,
n_rgb_levels: int = 256,
combined_cmap_name: str = "combined_cmap",
) -> LinearSegmentedColormap:
Expand Down Expand Up @@ -763,7 +761,7 @@ def get_name(x: Colormap) -> str:
if ncols == 1:
ax0 = ax
else:
ax0, ax1 = ax
ax0 = ax[0]
pos0 = ax0.get_position()

# Obtain the colormap type
Expand All @@ -782,25 +780,25 @@ def get_name(x: Colormap) -> str:
if show_grayscale:
# Get lightness values of colormap
lab = cspace_convert(rgb)
L = lab[:, 0]
lightness = lab[:, 0]

# Normalize lightness values
L /= 99.99871678
lightness /= 99.99871678

# Get RGB values for lightness values using neutral
rgb_L = cmrcm.neutral(L)[:, :3]
rgb_L = cmrcm.neutral(lightness)[:, :3]

# Add gray-scale colormap subplot
ax1.imshow(rgb_L[np.newaxis, ...], aspect="auto")
ax[1].imshow(rgb_L[np.newaxis, ...], aspect="auto")

# Check if the lightness profile was requested
if plot_profile and (cm_type != "qualitative"):
# Determine the points that need to be plotted
plot_L = -(L - 0.5)
plot_L = -(lightness - 0.5)
points = np.stack([x, plot_L], axis=1)

# Determine the colors that each point must have
# Use black for L >= 0.5 and white for L <= 0.5.
# Use black for lightness >= 0.5 and white for lightness <= 0.5.
colors = np.zeros_like(plot_L, dtype=int)
colors[plot_L >= 0] = 1

Expand Down Expand Up @@ -838,7 +836,7 @@ def get_name(x: Colormap) -> str:
lc.set_array(np.array(s_colors))

# Add line-collection to this subplot
ax1.add_collection(lc)
ax[1].add_collection(lc)

# Determine positions of colormap name
x_text = pos0.x0 - spacing
Expand Down Expand Up @@ -940,7 +938,21 @@ def get_bibtex() -> None:


# This function returns a list of all colormaps available in CMasher
def get_cmap_list(cmap_type: str = "all") -> list[str]:
def get_cmap_list(
cmap_type: Literal[
"a",
"all",
"s",
"seq",
"sequential",
"d",
"div",
"diverging",
"c",
"cyc",
"cyclic",
] = "all",
) -> list[str]:
"""
Returns a list with the names of all colormaps available in *CMasher* of
the given `cmap_type`.
Expand All @@ -960,9 +972,6 @@ def get_cmap_list(cmap_type: str = "all") -> list[str]:
List containing the names of all colormaps available in *CMasher*.

"""
# Convert cmap_type to lowercase
cmap_type = cmap_type.lower()

# Obtain proper list
if cmap_type in ("a", "all"):
cmaps = list(cmrcm.cmap_d)
Expand All @@ -972,6 +981,8 @@ def get_cmap_list(cmap_type: str = "all") -> list[str]:
cmaps = list(cmrcm.cmap_cd["diverging"])
elif cmap_type in ("c", "cyc", "cyclic"):
cmaps = list(cmrcm.cmap_cd["cyclic"])
else:
raise ValueError(cmap_type)

# Return cmaps
return cmaps
Expand Down Expand Up @@ -1009,14 +1020,14 @@ def get_cmap_type(cmap: Colormap | Name) -> str:

# Get lightness values of colormap
lab = cspace_converter("sRGB1", "CAM02-UCS")(rgb)
L = lab[:, 0]
diff_L = np.diff(L)
lightness = lab[:, 0]
diff_L = np.diff(lightness)

# Obtain central values of lightness
N = cmap.N - 1
central_i = [int(np.floor(N / 2)), int(np.ceil(N / 2))]
diff_L0 = np.diff(L[: central_i[0] + 1])
diff_L1 = np.diff(L[central_i[1] :])
diff_L0 = np.diff(lightness[: central_i[0] + 1])
diff_L1 = np.diff(lightness[central_i[1] :])

# Obtain perceptual differences of last two and first two values
lab_red = lab[[-2, -1, 0, 1]]
Expand Down Expand Up @@ -1423,13 +1434,43 @@ def set_cmap_legend_entry(artist: Artist, label: str) -> None:


# Function to take N equally spaced colors from a colormap
@overload
def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: str = "float",
) -> RGB:
return_fmt: Literal["float", "norm"] = "float",
) -> RGB[float]: ...


@overload
def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: Literal["int", "8bit"],
) -> RGB[int]: ...


@overload
def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: Literal["str", "hex"],
) -> list[str]: ...


def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: Literal["float", "norm", "int", "8bit", "str", "hex"] = "float",
) -> RGB[float] | RGB[int] | list[str]:
"""
Takes `N` equally spaced colors from the provided colormap `cmap` and
returns them.
Expand Down Expand Up @@ -1501,9 +1542,6 @@ def take_cmap_colors(
that describe the same property, but have a different initial state.

"""
# Convert provided fmt to lowercase
return_fmt = return_fmt.lower()

# Obtain the colormap
if isinstance(cmap, str):
cmap = mpl.colormaps[cmap]
Expand All @@ -1519,24 +1557,25 @@ def take_cmap_colors(
stop = int(np.ceil(cmap_range[1] * cmap.N)) - 1

# Pick colors
index: NDArray
index: NDArray[np.int64]
if N is None:
index = np.arange(start, stop + 1, dtype=int)
index = np.arange(start, stop + 1, dtype="int64")
else:
index = np.array(np.rint(np.linspace(start, stop, num=N)), dtype=int)
index = np.array(np.rint(np.linspace(start, stop, num=N)), dtype="int64")
colors = cmap(index)

# Convert colors to proper format
if return_fmt in ("float", "norm", "int", "8bit"):
colors = np.apply_along_axis(to_rgb, 1, colors) # type: ignore [call-overload]
if return_fmt in ("int", "8bit"):
colors = np.array(np.rint(colors * 255), dtype=int)
colors = list(map(tuple, colors))
return [(int(c[0]), int(c[1]), int(c[2])) for c in colors] # type: ignore [misc]
else:
return [(float(c[0]), float(c[1]), float(c[2])) for c in colors] # type: ignore [misc]
elif return_fmt in ("str", "hex"):
return [to_hex(x).upper() for x in colors]
else:
colors = [to_hex(x).upper() for x in colors]

# Return colors
return colors
raise ValueError(return_fmt)


# Function to view what a colormap looks like
Expand Down Expand Up @@ -1578,9 +1617,9 @@ def view_cmap(
if show_grayscale:
# If so, create a colormap of cmap in grayscale
rgb = cmap(np.arange(cmap.N))[:, :3]
L = cspace_convert(rgb)[:, 0]
L /= 99.99871678
rgb_L = cmrcm.neutral(L)[:, :3]
lightness = cspace_convert(rgb)[:, 0]
lightness /= 99.99871678
rgb_L = cmrcm.neutral(lightness)[:, :3]
cmap_L = LC(rgb_L)

# Set that there are two plots to create
Expand Down
Loading