From 4b86c868ded5eed49a657764a69bd6a491305c01 Mon Sep 17 00:00:00 2001 From: Francesco Verdoja Date: Fri, 3 May 2024 14:36:49 +0300 Subject: [PATCH] finalized docstring --- bff/utils.py | 102 +++++++++++++++++++++++++++++++++++----- mod/curation.py | 76 +++++++++++++++++------------- mod/grid.py | 47 ++++++++++++++++++ mod/models.py | 78 ++++++++++++++++++++++++++++-- mod/occupancy.py | 21 +++++++++ mod/utils.py | 47 ++++++++++++++++++ mod/visualisation.py | 11 +++-- plot_quivers_atc.py | 13 ++--- plot_quivers_kth.py | 13 ++--- tests/bff/utils_test.py | 21 +-------- training_atc.py | 2 +- 11 files changed, 338 insertions(+), 93 deletions(-) diff --git a/bff/utils.py b/bff/utils.py index a5802dd..7be38f4 100644 --- a/bff/utils.py +++ b/bff/utils.py @@ -31,6 +31,20 @@ def plot_dir( dpi: int = 300, cmap: str = "inferno", ) -> None: + """Plots the specied directional probabilities extracted from a dynamic + map. + + Args: + occupancy (OccupancyMap): The occupancy map to overlay on top of the + dynamics. + dynamics (np.ndarray): An array containing the 8-directional dynamics. + dir (Direction): The direction for which the dynamics should be + plotted. + dpi (int, optional): The dots-per-inch (pixel per inch) for the created + plot. Default is 300. + cmap (str, optional): The colormap used to map normalized data values + to RGB colors. Default is "inferno". + """ binary_map = occupancy.binary_map plt.figure(dpi=dpi) plt.title(f"Direction: {dir.name}") @@ -58,6 +72,23 @@ def plot_quivers( normalize: bool = True, dpi: int = 300, ) -> None: + """Plots the eight-direction people dynamics as quivers/arrow plots over + the occupancy map. + + Args: + occupancy (np.ndarray): The occupancy grid as a 2D numpy array. + dynamics (np.ndarray): An array containing the 8-directional dynamics. + scale (int, optional): Scaling factor used for reducing arrow density. + Default is 1. + window_size (int, optional): The size of the window to be plotted. If + provided, a center must also be provided. + center (RowColumnPair, optional): The center of the window to be + plotted. If provided, a window size must also be provided. + normalize (bool, optional): If True, the arrow scales are normalized. + Default is True. + dpi (int, optional): The dots-per-inch (pixel per inch) for the plot. + Default is 300. + """ sz_occ = occupancy.shape sz_dyn = dynamics.shape assert sz_occ[0] // scale == sz_dyn[0] and sz_occ[1] // scale == sz_dyn[1] @@ -113,22 +144,12 @@ def plot_quivers( def scale_quivers(d: np.ndarray) -> np.ndarray: + """Scales the quivers by normalizing the arrow lengths.""" max = np.amax(d, axis=1) ret = np.expand_dims(np.where(max != 0, max, 1), axis=1) return d / ret -def random_input( - size: int = 32, - p_occupied: float = 0.5, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - w = Window(size) - a = torch.rand((1, 1, size, size), device=device) < p_occupied - a[0, 0, w.center[0], w.center[1]] = 0 # ensure center is empty - return a.type(torch.float) - - def estimate_dynamics( net: PeopleFlow, occupancy: Union[OccupancyMap, np.ndarray], @@ -138,6 +159,32 @@ def estimate_dynamics( batch_size: int = 4, device: Optional[torch.device] = None, ) -> np.ndarray: + """Estimates the dynamics of people flow in a given occupancy. + The function applies the network on windows of the provided occupancy to + estimate the people flow dynamics in the region. It handles both + OccupancyMap objects and numpy arrays representing occupancy. By default, + if a scale factor greater than 1 is provided and the occupancy is an + OccupancyMap, it will rescale the occupancy binary map accordingly before + applying the network. + + Args: + net (PeopleFlow): PeopleFlow network used for estimating the dynamics. + occupancy (Union[OccupancyMap, np.ndarray]): Occupancy space to + estimate the dynamics upon. + scale (int, optional): Scale down factor for the occupancy map. This + parameter is ignored when occupancy is a numpy array. Default is 1. + net_scale (int, optional): Scale factor for the network window size. + Default is 1. + batch_size (int, optional): Number of patch samples to take from each + batch. Default is 4. + device (torch.device, optional): The device to move the network model + and data to for computation. If not given, automatically set to GPU if + available, else CPU. + + Returns: + np.ndarray: An numpy array containing the estimated dynamics of people + flow. + """ window = Window(net.window_size * net_scale) if not device: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -346,24 +393,45 @@ def load(self, path: str) -> None: class Window: + """A class representing a window on a 2-dimensional grid.""" + def __init__(self, size: int) -> None: + """Init `Window` class + + Args: + size (int): The size of the square window, represented as the + length of one side. + """ self.size = size @property def half_size(self) -> int: + """Returns the integer division of the window size by 2""" return self.size // 2 @property def center(self) -> RowColumnPair: + """Returns the coordinate of the window center""" return (self.half_size, self.half_size) @property def pad_amount(self) -> Sequence[int]: + """Returns the number of rows and columns to pad around the window""" return (self.half_size, self.half_size + self.size % 2 - 1) def corners( self, center: RowColumnPair, bounds: Optional[Sequence[int]] = None ) -> Sequence[int]: + """Returns the corners of a window centered on the center. + + Args: + center (RowColumnPair): The center coordinates of the window. + bounds (Sequence[int], optional): If given, the provided + corners will not exceed these bounds. Bounds should be given as + (min_row, max_row, min_col, max_col). Defaults to None. + Returns: + Sequence[int]: The corner coordinates as (left, top, right, bottom) + """ left, top, right, bottom = ( center[1] - self.half_size, # left center[0] - self.half_size, # top @@ -382,6 +450,18 @@ def corners( def indeces( self, center: RowColumnPair, bounds: Optional[Sequence[int]] = None ) -> Set[RowColumnPair]: + """Generates the indices encompassed by the window centered in center. + + Args: + center (RowColumnPair): The center coordinates of the window. + bounds (Sequence[int], optional): If given, the provided indeces + will not exceed these bounds. Bounds should be given as + (min_row, max_row, min_col, max_col). Defaults to None. + + Returns: + Set[RowColumnPair]: A set of row and column pairs representing each + cell position within the window. + """ left, top, right, bottom = self.corners(center, bounds) indeces = { (row, col) diff --git a/mod/curation.py b/mod/curation.py index aa40435..6e9101f 100644 --- a/mod/curation.py +++ b/mod/curation.py @@ -17,18 +17,15 @@ def mm2m(val: float) -> float: class DataCurator: - """ - Stump of class handing data curation. In current version handles only - subsampling. - """ + """Stump of class handing data curation. In current version handles only + subsampling.""" - # TODO Update description and add logging - def __init__(self, params): - """ - Constructor for data curator class. - :param params: Parameters guiding the curation (see json schema - @config/curator_shema.json) - :type params: dict + def __init__(self, params: dict): + """Init `DataCurator` class. + + Args: + params (dict): Parameters guiding the curation (see json schema + @config/curator_shema.json) """ # load params self.params = params @@ -39,16 +36,12 @@ def __init__(self, params): logger.info("Parameters for data curation {}".format(self.params)) def curate(self): - """ - Function running main curation loop - """ + """Function running main curation loop""" if self.params["chunk_size"] != 0: self.__chunk_processing() def __chunk_processing(self): - """ - Function for curating large csv files in chunks. - """ + """Function for curating large csv files in chunks.""" header = True logger.info( "Processing in chunks of size {}".format(self.params["chunk_size"]) @@ -87,41 +80,50 @@ def __chunk_processing(self): header = False - def __subsample(self, chunk): - """ - Function for subsampling data. Depending on the flag different + def __subsample(self, chunk: pd.DataFrame) -> pd.DataFrame: + """Function for subsampling data. Depending on the flag different subsampling policies can be used. Currently only available one is to keep very n-th line. - :param chunk: Data to be subsample :type chunk: dataframe :return: - Subsampled data :rtype: dataframe + Args: + chunk (pd.Dataframe): Data to be subsampled + + Returns: + pd.Dataframe: Subsampled data """ result = [] if self.params["subsample"]["method"] == "line_keep": result = chunk.iloc[:: self.params["subsample"]["parameter"], :] logger.info( "Was {} rows, remained {} rows".format( - len(chunk.index), len(result.index) + len(chunk.index), len(result.index) # type: ignore ) ) return result - def __drop_columns(self, chunk): - """ - Function dropping redundant columns from the original data + def __drop_columns(self, chunk: pd.DataFrame) -> pd.DataFrame: + """Function dropping redundant columns from the original data + + Args: + chunk (pd.Dataframe): Data to be edited - :param chunk: Data to be edited - :type chunk: dataframe - :return: edited chunk - :rtype: dataframe + Returns: + pd.Dataframe: Edited data """ for c in self.params["drop_columns"]: chunk = chunk.drop(c, axis=1) logger.info("Dropping column: {}".format(c)) return chunk - def __process_column(self, chunk): - # TODO Description and logging + def __process_column(self, chunk: pd.DataFrame) -> pd.DataFrame: + """Function processing columns in the original data + + Args: + chunk (pd.Dataframe): Data to be edited + + Returns: + pd.Dataframe: Edited data + """ for function_name, column_name in zip( self.params["process_columns"]["method"], self.params["process_columns"]["columns"], @@ -133,7 +135,15 @@ def __process_column(self, chunk): chunk[column_name] = chunk[column_name].apply(mm2m) return chunk - def __replace_header(self, chunk): + def __replace_header(self, chunk: pd.DataFrame) -> pd.DataFrame: + """Function replacing the header row from the original data + + Args: + chunk (pd.Dataframe): Data to be edited + + Returns: + pd.Dataframe: Edited data + """ for old_name, new_name in zip( self.params["replace_header"]["input"], self.params["replace_header"]["replacement"], diff --git a/mod/grid.py b/mod/grid.py index cd4d1f4..23f68cc 100644 --- a/mod/grid.py +++ b/mod/grid.py @@ -10,6 +10,19 @@ class Grid(BaseModel): + """Represents a 2D Map of Dynamics as a grid. Each cell of the grid is an + instance of models.Cell or its subclasses. + + Attributes: + resolution (PositiveFloat): The size of the sides of the grid's + squares. + origin (XYCoords): The coordinate reference for the origin point. + model (type[Cell]): The type of cell used to fill the grid. + cells (dict[RCCoords, Cell]): A mapping of grid coordinates to cell + instances. + total_count (int): The total number of data items added to grid. + """ + resolution: PositiveFloat origin: XYCoords model: type[Cell] = Cell @@ -18,6 +31,12 @@ class Grid(BaseModel): @property def dimensions(self) -> RCCoords: + """Calculate the extent of the grid in rows and columns. + + Returns: + RCCoords: The maximum row and column values currently within the + grid. + """ max_r = 0 max_c = 0 for c in self.cells: @@ -28,6 +47,13 @@ def dimensions(self) -> RCCoords: return RCCoords(max_r, max_c) def add_data(self, data: pd.DataFrame) -> None: + """Add positional data to grid points by calculating which cell in the + grid each dataset entry belongs to. + + Args: + data (pd.DataFrame): Data with 'x' and 'y' positional entries + """ + data_with_row_col = data.assign( row=((data.y - self.origin.y) // self.resolution).astype(int), col=((data.x - self.origin.x) // self.resolution).astype(int), @@ -48,6 +74,9 @@ def add_data(self, data: pd.DataFrame) -> None: self.total_count += len(cell_data.index) def update_model(self) -> None: + """Update all Cell models within the grid. Used after all data has been + added to the grid to finalize model parameters.""" + for cell in self.cells.values(): cell.update_model(self.total_count) @@ -55,6 +84,14 @@ def update_model(self) -> None: def assign_prior_to_grid( grid: Grid, prior: list[Probability], alpha: float ) -> None: + """Assigns a single prior probability value to all cells in a grid. + + Args: + grid (Grid): The grid to which the prior will be assigned. + prior (list[Probability]): The prior probabilities to be assigned to + each cell. + alpha (float): Concentration hyperparameter for the Dirichlet prior. + """ for cell in grid.cells.values(): assert isinstance(cell, BayesianDiscreteDirectional) cell.update_prior(prior, alpha) @@ -66,6 +103,16 @@ def assign_cell_priors_to_grid( alpha: float, add_missing_cells: bool = False, ) -> None: + """Assigns individual cell priors to each cell in the grid. + + Args: + grid (Grid): The grid to which the prior will be assigned. + priors (dict[RCCoords, list[Probability]]): Mapping of cell coordinates + to their corresponding prior probabilities. + alpha (float): Concentration hyperparameter for the Dirichlet prior. + add_missing_cells (bool, optional): Set to True to add cells that are + missing in grid but present in priors. Defaults to False. + """ for cell_id in priors: cell = grid.cells.get(cell_id) if cell: diff --git a/mod/models.py b/mod/models.py index 393b4e2..2e50f88 100644 --- a/mod/models.py +++ b/mod/models.py @@ -19,7 +19,18 @@ class Cell(BaseModel): - # TODO previously flipped (x = coords[1], y = coords[0]) + """Defines a single cell in a 2D grid representing a Map of Dynamics. + + Attributes: + coords (XYCoords): The origin coordinates of the cell in 2D space. + index (RCCoords): The row and column indices of the cell within the + grid. + resolution (PositiveFloat): The size of the sides of the cell's square. + probability (Probability): Transition probabilities associated with the + cell. Default is 0. + data (pd.DataFrame): Stored data points that fall within the cell. + """ + coords: XYCoords = Field(frozen=True) index: RCCoords = Field(frozen=True) resolution: PositiveFloat = Field(frozen=True) @@ -31,19 +42,28 @@ class Cell(BaseModel): @property def observation_count(self) -> int: + """The number of data points within the cell.""" return len(self.data.index) @cached_property def center(self) -> XYCoords: + """The center coordinates of the cell.""" return XYCoords( x=self.coords.x + self.resolution / 2, y=self.coords.y + self.resolution / 2, ) def add_data(self, data: pd.DataFrame) -> None: + """Adds the data points in `data` to the cell.""" self.data = pd.concat([self.data, data]) def compute_cell_probability(self, total_observations: int) -> None: + """Compute the probability associated with the cell. + + Args: + total_observations (int): Total observations in the grid to which + the cell belongs. + """ if total_observations: self.probability = self.observation_count / total_observations elif self.observation_count: @@ -53,12 +73,25 @@ def compute_cell_probability(self, total_observations: int) -> None: ) def update_model(self, total_observations: int) -> None: + """Update the cell model by computing the cell probability. + + Args: + total_observations (int): Total observations in the grid to which + the cell belongs. + """ self.compute_cell_probability(total_observations) class DiscreteDirectional(Cell): - """ - Floor Field + """A type of cell that divides the stored data into directional bins, + allowing for modeling of data with directional components. Corresponds to + the Floor Field model. + + Attributes: + bin_count (int): Number of directions, or bins, into which data is + divided. + bins (list[Probability]): Precomputed probabilities for each + directional bin. """ bin_count: int = Field(frozen=True, default=8) @@ -69,6 +102,7 @@ class DiscreteDirectional(Cell): def default_bins( cls, v: list[Probability], values: ValidationInfo ) -> list[Probability]: + """Validate the provided bins or create default even bins.""" if v is not None and len(v) != 0: if len(v) != values.data["bin_count"] or not np.isclose( np.sum(v), 1 @@ -83,13 +117,21 @@ def default_bins( @cached_property def half_split(self) -> float: + """Half of the angular coverage of each bin.""" return np.pi / self.bin_count @cached_property def directions(self) -> np.ndarray: + """An array representing the central directions for each bin.""" return np.arange(0, _2PI, _2PI / self.bin_count) def add_data(self, data: pd.DataFrame) -> None: + """Add data points to the cell, assigning them to a bin based on their + motion angle. + + Args: + data (pd.DataFrame): The data to add. + """ self.data = pd.concat( [ self.data, @@ -100,6 +142,7 @@ def add_data(self, data: pd.DataFrame) -> None: ) def bin_from_angle(self, rad: float) -> int: + """Calculates the bin index for a given angle.""" a = rad % _2PI for i, d in enumerate(self.directions): s = (d - self.half_split) % _2PI @@ -109,6 +152,7 @@ def bin_from_angle(self, rad: float) -> int: raise ValueError(f"{rad} does not represent an angle") def update_bin_probabilities(self) -> None: + """Update the probabilities of each bin.""" if not self.data.empty: for i in range(self.bin_count): self.bins[i] = ( @@ -119,13 +163,27 @@ def update_bin_probabilities(self) -> None: ), f"Bin probability sum equal to {sum(self.bins)}." def update_model(self, total_observations: int) -> None: + """Updates both the cell probabilities and bin probabilities based on + the current data added to the cell. + + Args: + total_observations (int): Total observations in the grid to which + the cell belongs. + """ self.compute_cell_probability(total_observations) self.update_bin_probabilities() class BayesianDiscreteDirectional(DiscreteDirectional): - """ - Bayesian Floor Field + """A subclass of `DiscreteDirectional` that extends the bin probability + calculation with a Bayesian approach. We refer to this as Bayesian Floor + Field. + + Attributes: + priors (list[Probability]): Prior probabilities for each directional + bin. + alpha (NonNegativeFloat): Concentration hyperparameter to be applied + during the Bayesian update. """ priors: list[Probability] = Field(default=[], validate_default=True) @@ -136,6 +194,7 @@ class BayesianDiscreteDirectional(DiscreteDirectional): def default_priors( cls, v: list[Probability], values: ValidationInfo ) -> list[Probability]: + """Validate the provided priors or create default uniform priors.""" if v is not None and len(v) != 0: if len(v) != values.data["bin_count"] or not np.isclose( np.sum(v), 1 @@ -149,11 +208,20 @@ def default_priors( ] def update_prior(self, priors: list[Probability], alpha: float) -> None: + """Update the prior probabilities and the alpha hyperparameter. + + Args: + priors (list[Probability]): The new prior probabilities for each + bin. + alpha (float): The new alpha hyperparameter. + """ self.priors = priors self.alpha = alpha self.update_bin_probabilities() def update_bin_probabilities(self) -> None: + """Updates the bin probabilities according to the current priors, + alpha, and data using a Dirichlet conjugate prior.""" if not self.data.empty: for i in range(self.bin_count): posterior = self.priors[i] * self.alpha + len( diff --git a/mod/occupancy.py b/mod/occupancy.py index 9060b66..07d29fa 100644 --- a/mod/occupancy.py +++ b/mod/occupancy.py @@ -11,6 +11,8 @@ class OccupancyMap: + """A class representing an occupancy map.""" + def __init__( self, image_file: Union[str, Path], @@ -20,6 +22,20 @@ def __init__( occupied_thresh: float, free_thresh: float, ): + """Init `OccupancyMap` class. + + Args: + image_file (Union[str, Path]): path to the occupancy map image. + resolution (float): resultion of the map in meters per pixel. + origin (Union[XYCoords, Sequence[float]]): the 2D pose of the + lower-left pixel in the map, as (x, y), in world coordinates. + negate (bool): whether the white/black free/occupied semantics + should be reversed (interpretation of thresholds is unaffected). + occupied_thresh (float): pixels with occupancy probability greater + than this threshold are considered completely occupied. + free_thresh (float): pixels with occupancy probability less than + this threshold are considered completely free. + """ self.resolution = resolution self.origin = XYCoords(x=origin[0], y=origin[1]) self.negate = negate @@ -41,6 +57,7 @@ def __init__( @property def binary_map(self) -> Image.Image: + """The binarized occupancy map""" if self.negate: return self.map.point(lambda p: p > self.occupied_thresh and 255) else: @@ -49,6 +66,7 @@ def binary_map(self) -> Image.Image: ) def pixel_from_XY(self, coords: XYCoords) -> TDRCCoords: + """Returns the pixel index corresponding to `coods`""" w, h = self.map.size row, column = TDRC_from_XY(coords, self.origin, self.resolution, h) if row < 0 or row >= h or column < 0 or column >= w: @@ -56,6 +74,7 @@ def pixel_from_XY(self, coords: XYCoords) -> TDRCCoords: return TDRCCoords(row, column) def XY_from_pixel(self, pixel: TDRCCoords) -> XYCoords: + """Returns the x-y coords of `pixel`""" w, h = self.map.size if ( pixel.row < 0 @@ -68,6 +87,7 @@ def XY_from_pixel(self, pixel: TDRCCoords) -> XYCoords: @classmethod def from_metadata(cls, metadata: dict) -> "OccupancyMap": + """Construct `OccupancyMap` from the given metadata""" return cls( image_file=metadata["image"], resolution=metadata["resolution"], @@ -79,6 +99,7 @@ def from_metadata(cls, metadata: dict) -> "OccupancyMap": @classmethod def from_yaml(cls, yaml_path: Union[str, Path]) -> "OccupancyMap": + """Construct `OccupancyMap` from a ROS map_server yaml metadata file""" with open(yaml_path, "r") as stream: meta = yaml.safe_load(stream) if not Path(meta["image"]).is_absolute(): diff --git a/mod/utils.py b/mod/utils.py index 41a3836..67934b0 100644 --- a/mod/utils.py +++ b/mod/utils.py @@ -38,6 +38,7 @@ class TDRCCoords(NamedTuple): def RC_from_TDRC(rc: TDRCCoords, num_rows: int) -> RCCoords: + """Converts a top-down row column pair into bottom-up from `num_rows`""" if num_rows <= rc.row: raise ValueError( f"Row index greater than available rows ({rc.row=}, {num_rows=})" @@ -46,6 +47,7 @@ def RC_from_TDRC(rc: TDRCCoords, num_rows: int) -> RCCoords: def TDRC_from_RC(rc: RCCoords, num_rows: int) -> TDRCCoords: + """Converts a bottom-up row column pair into top-down from `num_rows`""" if num_rows <= rc.row: raise ValueError( f"Row index greater than available rows ({rc.row=}, {num_rows=})" @@ -54,6 +56,7 @@ def TDRC_from_RC(rc: RCCoords, num_rows: int) -> TDRCCoords: def RC_from_XY(xy: XYCoords, origin: XYCoords, resolution: float) -> RCCoords: + """Converts an x-y coordinate pair into bottom-up row column coordinates""" row = int((xy.y - origin.y) // resolution) col = int((xy.x - origin.x) // resolution) if row < 0 or col < 0: @@ -62,6 +65,7 @@ def RC_from_XY(xy: XYCoords, origin: XYCoords, resolution: float) -> RCCoords: def XY_from_RC(rc: RCCoords, origin: XYCoords, resolution: float) -> XYCoords: + """Converts a bottom-up row column pair into x-y coordinates""" x = rc.column * resolution + origin.x y = rc.row * resolution + origin.y return XYCoords(x, y) @@ -70,16 +74,20 @@ def XY_from_RC(rc: RCCoords, origin: XYCoords, resolution: float) -> XYCoords: def TDRC_from_XY( xy: XYCoords, origin: XYCoords, resolution: float, num_rows: int ) -> TDRCCoords: + """Converts an x-y coordinate pair into top-down row column coordinates""" return TDRC_from_RC(RC_from_XY(xy, origin, resolution), num_rows) def XY_from_TDRC( rc: TDRCCoords, origin: XYCoords, resolution: float, num_rows: int ) -> XYCoords: + """Converts a top-down row column pair into x-y coordinates""" return XY_from_RC(RC_from_TDRC(rc, num_rows), origin, resolution) class Direction(IntEnum): + """A class representing a cardinal direction as an enumerated integer.""" + E = 0 NE = 1 N = 2 @@ -91,26 +99,32 @@ class Direction(IntEnum): @property def rad(self) -> float: + """The Direction in radians""" return self.value * _2PI / 8 @property def range(self) -> tuple[float, float]: + """Lower and upper bounds of the range of the Direction in radians""" a = self.rad return ((a - np.pi / 8) % _2PI, a + np.pi / 8) @property def u(self) -> float: + """the u component of the Direction""" return np.cos(self.rad) @property def v(self) -> float: + """the v component of the Direction""" return np.sin(self.rad) @property def uv(self) -> tuple[float, float]: + """the (u, v) components of the Direction""" return (self.u, self.v) def contains(self, rad: float) -> bool: + """Returns whether the range of this Direction contains angle `rad`""" a = rad % _2PI s, e = self.range return bool( @@ -120,6 +134,7 @@ def contains(self, rad: float) -> bool: @classmethod def from_rad(cls, rad: float) -> "Direction": + """Returns the Direction corresponding to the angle `rad`""" for dir in Direction: if dir.contains(rad): return dir @@ -129,6 +144,7 @@ def from_rad(cls, rad: float) -> "Direction": def from_points( cls, p1: tuple[float, float], p2: tuple[float, float] ) -> "Direction": + """Returns the Direction of the angle between `p1` and `p2`""" assert p1 != p2 rad = np.arctan2(p2[1] - p1[1], p2[0] - p1[0]) return cls.from_rad(rad) @@ -141,6 +157,21 @@ def extended_validator( tuple[Literal[False], exceptions.ValidationError], tuple[Literal[False], exceptions.SchemaError], ]: + """Validates a JSON file against a schema extended to handle default + properties. + + Args: + json_path (str): The path of the JSON file to be validated. + schema_path (str): The path of the JSON Schema file. + + Returns: + Union[tuple[Literal[True], dict], tuple[Literal[False], + exceptions.ValidationError], tuple[Literal[False], + exceptions.SchemaError]]: Returns a tuple where the first element is + True if validation is successful and False otherwise. The second + element is either the content of the JSON file as a dictionary if + validation is successful, or else an error message. + """ schema_file = open(schema_path, "r") my_schema = json.load(schema_file) @@ -184,4 +215,20 @@ def get_local_settings( tuple[Literal[False], exceptions.ValidationError], tuple[Literal[False], exceptions.SchemaError], ]: + """Fetches and validates local settings from a JSON file. + + Args: + json_path (str, optional): The path of the local settings JSON file. + Defaults to "config/local_settings.json". + schema_path (str, optional): The path of the local settings JSON Schema + file. Defaults to "config/local_settings_schema.json". + + Returns: + Union[tuple[Literal[True], dict], tuple[Literal[False], + exceptions.ValidationError], tuple[Literal[False], + exceptions.SchemaError]]: Returns a tuple where the first element is + True if validation is successful and False otherwise. The second + element is either the local settings data as a dictionary if validation + is successful, or else an error message. + """ return extended_validator(json_path, schema_path) diff --git a/mod/visualisation.py b/mod/visualisation.py index 7a2804f..d2bf647 100644 --- a/mod/visualisation.py +++ b/mod/visualisation.py @@ -9,6 +9,7 @@ def polar2cart(theta: float, r: float) -> tuple[float, float]: + """Converts polar coordinates to cartesian coordinates""" z = r * np.exp(1j * theta) return np.real(z), np.imag(z) @@ -16,19 +17,20 @@ def polar2cart(theta: float, r: float) -> tuple[float, float]: def show_all( grid: Grid, occ: Optional[OccupancyMap] = None, - occ_overlay: bool = False, dpi: int = 100, ) -> None: + """Calls `show_raw` and `show_discrete_directional` in sequence""" show_raw(grid, dpi) if ( grid.model == DiscreteDirectional or grid.model == BayesianDiscreteDirectional ): - show_discrete_directional(grid, occ, occ_overlay, dpi) + show_discrete_directional(grid, occ, dpi) plt.show() def show_raw(grid: Grid, dpi: int = 100) -> None: + """Plots all the data in the given Grid""" plt.figure(dpi=dpi) for cell in grid.cells.values(): @@ -38,10 +40,10 @@ def show_raw(grid: Grid, dpi: int = 100) -> None: def show_discrete_directional( grid: Grid, occ: Optional[OccupancyMap] = None, - occ_overlay: bool = False, dpi: int = 100, save_name: Optional[str] = None, ) -> None: + """Plots the dynamics in `Grid`, optionally overlaid on the map `occ`.""" plt.figure(dpi=dpi) X = [] Y = [] @@ -64,7 +66,7 @@ def show_discrete_directional( for i in range(8) ] ) - if occ and occ_overlay: + if occ: show_occupancy(occ) plt.quiver(X, Y, U, V, scale_units="xy", scale=2, minshaft=2, minlength=0) if save_name: @@ -72,6 +74,7 @@ def show_discrete_directional( def show_occupancy(occ: OccupancyMap) -> None: + """Plots the given occupancy map""" r = occ.resolution o = occ.origin sz = occ.map.size diff --git a/plot_quivers_atc.py b/plot_quivers_atc.py index 428ef0b..c4a748b 100644 --- a/plot_quivers_atc.py +++ b/plot_quivers_atc.py @@ -75,15 +75,12 @@ grid.assign_cell_priors_to_grid( grid=grid_bayes, priors=net_prior, alpha=ALPHA, add_missing_cells=True ) -show_discrete_directional( - grid_bayes, occupancy, occ_overlay=True, dpi=3000, save_name="atc_0" -) +show_discrete_directional(grid_bayes, occupancy, dpi=3000, save_name="atc_0") grid_bayes = pickle.load(open(GRID_BAYES_DATA[10000], "rb")) show_discrete_directional( grid_bayes, occupancy, - occ_overlay=True, dpi=3000, save_name="atc_10k_noprior", ) @@ -91,11 +88,7 @@ grid.assign_cell_priors_to_grid( grid=grid_bayes, priors=net_prior, alpha=ALPHA, add_missing_cells=True ) -show_discrete_directional( - grid_bayes, occupancy, occ_overlay=True, dpi=3000, save_name="atc_10k" -) +show_discrete_directional(grid_bayes, occupancy, dpi=3000, save_name="atc_10k") grid_bayes = pickle.load(open(GRID_TEST_DATA, "rb")) -show_discrete_directional( - grid_bayes, occupancy, occ_overlay=True, dpi=3000, save_name="atc_gt" -) +show_discrete_directional(grid_bayes, occupancy, dpi=3000, save_name="atc_gt") diff --git a/plot_quivers_kth.py b/plot_quivers_kth.py index c217169..dbf3d30 100644 --- a/plot_quivers_kth.py +++ b/plot_quivers_kth.py @@ -66,15 +66,12 @@ grid.assign_cell_priors_to_grid( grid=grid_bayes, priors=net_prior, alpha=ALPHA, add_missing_cells=True ) -show_discrete_directional( - grid_bayes, occupancy, occ_overlay=True, dpi=3000, save_name="kth_0" -) +show_discrete_directional(grid_bayes, occupancy, dpi=3000, save_name="kth_0") grid_bayes = pickle.load(open(GRID_BAYES_DATA[10000], "rb")) show_discrete_directional( grid_bayes, occupancy, - occ_overlay=True, dpi=3000, save_name="kth_10k_noprior", ) @@ -82,11 +79,7 @@ grid.assign_cell_priors_to_grid( grid=grid_bayes, priors=net_prior, alpha=ALPHA, add_missing_cells=True ) -show_discrete_directional( - grid_bayes, occupancy, occ_overlay=True, dpi=3000, save_name="kth_10k" -) +show_discrete_directional(grid_bayes, occupancy, dpi=3000, save_name="kth_10k") grid_bayes = pickle.load(open(GRID_TEST_DATA, "rb")) -show_discrete_directional( - grid_bayes, occupancy, occ_overlay=True, dpi=3000, save_name="kth_gt" -) +show_discrete_directional(grid_bayes, occupancy, dpi=3000, save_name="kth_gt") diff --git a/tests/bff/utils_test.py b/tests/bff/utils_test.py index 025daf7..19a7fb8 100644 --- a/tests/bff/utils_test.py +++ b/tests/bff/utils_test.py @@ -2,16 +2,10 @@ import numpy as np import pytest -from torch import device, float32 +from torch import device from bff.nets import DiscreteDirectional -from bff.utils import ( - Trainer, - Window, - estimate_dynamics, - random_input, - scale_quivers, -) +from bff.utils import Trainer, Window, estimate_dynamics, scale_quivers from mod.occupancy import OccupancyMap @@ -26,17 +20,6 @@ def test_scale_quivers() -> None: assert (scaled[2, :] == [1 / 2, 1, 1 / 4, 1 / 4, 0, 0, 0, 0]).all() -@pytest.mark.parametrize( - ["p_occupied", "expected"], - [(0, np.zeros((1, 1, 16, 16))), (1, np.ones((1, 1, 16, 16)))], -) -def test_random_input(p_occupied: float, expected: np.ndarray) -> None: - expected[0, 0, 8, 8] = 0 # center should always be zero - tensor = random_input(size=16, p_occupied=p_occupied) - assert tensor.dtype == float32 - assert (tensor.numpy() == expected).all() - - def test_window_size() -> None: w = Window(4) assert w.size == 4 diff --git a/training_atc.py b/training_atc.py index 4941f5f..2e4c01b 100644 --- a/training_atc.py +++ b/training_atc.py @@ -66,7 +66,7 @@ grid_train: grid.Grid = pickle.load(open(GRID_TRAIN_DATA, "rb")) grid_test: grid.Grid = pickle.load(open(GRID_TEST_DATA, "rb")) -show_all(grid_train, occupancy, occ_overlay=True, dpi=PLOT_DPI) +show_all(grid_train, occupancy, dpi=PLOT_DPI) # transform = None transform = transforms.Compose(