# Transient validation display function added by Saoussen Belhadj-aissa. July, 2024
def display_transient_validation(pair_distances: list, pair_differences: list, ifgs_dates: list,
                                 site_name: str, distance_rqmt: tuple = (0.1, 50),
                                 n_bins: int = 10, threshold: float = 0.683,
                                 sensor: str = 'Sentinel-1',
                                 validation_data: str = 'GNSS'):
    '''
    Build the transient-requirement validation table and figure, with one
    table row and one subplot per interferogram.

    For every interferogram the point pairs are binned by distance; a pair
    passes when its absolute residual is below 3 * (1 + sqrt(d)), where d is
    the lower edge of the pair's distance bin. A bin (and the 'Total' column)
    succeeds when the fraction of passing pairs exceeds `threshold`.

    Parameters:
       pair_distances   : list  - list of 1d arrays of pair distances, one
                                  array per interferogram
       pair_differences : list  - list of 1d arrays of pair double-differenced
                                  displacement residuals, one array per
                                  interferogram (same order and lengths as
                                  pair_distances); NaN residuals are ignored
       ifgs_dates       : list  - list of (start, end) date pairs, one per
                                  interferogram; both items must provide
                                  strftime (used to label rows and subplots)
       site_name        : str   - name of the cal/val site
       distance_rqmt    : list  - (min, max) distance over which the
                                  requirement is tested, eg. length scales
                                  of 0.1-50 km
       n_bins           : int   - number of distance bins
       threshold        : float - fraction of pairs that must pass, i.e. the
                                  percentile of the Gaussian normal
                                  distribution within which residuals are
                                  expected to be, e.g. 0.683 for 68.3% or
                                  1-sigma limit
       sensor           : str   - sensor used in validation, e.g Sentinel-1
                                  or NISAR
       validation_data  : str   - data used to validate against; GNSS or INSAR

    Return
       validation_table  : styled_df - pass ratio per interferogram (rows)
                                       and distance bin plus 'Total' (columns)
       validation_figure : fig

    Raises
       ValueError : if the three per-interferogram lists differ in length
    '''
    validation_type = 'Transient'
    y_max = 80      # upper y-axis limit shared by all subplots
    num_cols = 3    # subplots per figure row; change to adjust the layout

    n_ifgs = len(pair_distances)  # number of interferograms to validate
    if not (len(pair_differences) == len(ifgs_dates) == n_ifgs):
        raise ValueError(
            'pair_distances, pair_differences and ifgs_dates must have the '
            f'same length, got {n_ifgs}, {len(pair_differences)} and '
            f'{len(ifgs_dates)}'
        )

    # Distance bins and table labels (identical for every interferogram)
    bins = np.linspace(*distance_rqmt, num=n_bins + 1)
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_width = bins[1] - bins[0]
    columns = [f'{bins[j]:.2f}-{bins[j + 1]:.2f}' for j in range(n_bins)] + ['Total']
    index = [f"{start.strftime('%Y%m%d')}-{end.strftime('%Y%m%d')}"
             for start, end in ifgs_dates]

    # Requirement evaluated at the lower edge of each bin
    bin_rqmt = 3 * (1 + np.sqrt(bins[:-1]))

    # Count total / passing pairs per interferogram and per bin;
    # the last column holds the per-interferogram total.
    n_all = np.zeros([n_ifgs, n_bins + 1], dtype=int)
    n_pass = np.zeros([n_ifgs, n_bins + 1], dtype=int)
    distances = []
    residuals = []
    for i in range(n_ifgs):
        dist = np.asarray(pair_distances[i], dtype=float)
        resid = np.asarray(pair_differences[i], dtype=float)
        distances.append(dist)
        residuals.append(resid)

        # NaN residuals carry no information: leave them out of the counts,
        # consistent with the quantile bars below which also drop them.
        valid = ~np.isnan(resid)
        abs_resid = np.abs(resid)
        # digitize returns 0 / n_bins + 1 for pairs outside the tested
        # distance range, so those are never matched by j in 1..n_bins.
        bin_idx = np.digitize(dist, bins)
        for j in range(1, n_bins + 1):
            in_bin = valid & (bin_idx == j)
            n_all[i, j - 1] = np.count_nonzero(in_bin)
            n_pass[i, j - 1] = np.count_nonzero(abs_resid[in_bin] < bin_rqmt[j - 1])
    n_all[:, -1] = n_all[:, :-1].sum(axis=1)
    n_pass[:, -1] = n_pass[:, :-1].sum(axis=1)

    # Pass ratio and success flag; empty bins get ratio 0 and therefore fail
    ratio = n_pass / np.where(n_all > 0, n_all, 1)
    success_or_fail = ratio > threshold

    # Validation table: green/red background by success, bold 'Total' column
    ratio_pd = pd.DataFrame(ratio, columns=columns, index=index)

    def style_specific_cells(val):
        color = '#e6ffe6' if val > threshold else '#ffe6e6'
        return f'background-color: {color}'

    # Styler.map is the non-deprecated name of Styler.applymap and is what
    # the rest of this module already uses.
    styled_df = (ratio_pd.style.map(style_specific_cells)
                 .apply(lambda x: ['font-weight: bold' if x.name == 'Total' else ''
                                   for _ in x], axis=0))

    # One subplot per interferogram; keep at least one row so that an empty
    # input still yields a (blank) figure instead of a matplotlib error.
    num_rows = max(1, (n_ifgs + num_cols - 1) // num_cols)
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 4 * num_rows),
                            squeeze=False)

    legend_kwargs = dict(transform=ax_transform_placeholder) if False else None  # noqa: E501 (replaced per-axes below)
    info_box = dict(boxstyle='square', facecolor='white', alpha=1, linewidth=0.4)
    status_box = {**info_box, 'facecolor': 'none', 'edgecolor': 'none'}

    if validation_data == 'GNSS':
        y_label = "Double-Differenced \n Displacement Residual (mm)"
    else:
        y_label = "Relative Velocity measurement (mm/yr)"

    for i in range(n_ifgs):
        ax = axs[i // num_cols, i % num_cols]
        dist = distances[i]
        resid = residuals[i]

        # Per-bin quantile of the residuals at the requested threshold
        # (NaN residuals removed first; empty bins yield NaN, i.e. no bar).
        df_nonan = pd.DataFrame({'distance': dist,
                                 'double_diff': resid}).dropna(subset=['double_diff'])
        quantile_th = (df_nonan.groupby(pd.cut(df_nonan['distance'], bins),
                                        observed=False)['double_diff']
                       .quantile(q=threshold).values)

        # Scatter of all pairs; smaller, fainter markers for dense data sets
        is_sparse = len(resid) < 1e4
        ax.scatter(dist, resid, color='black',
                   s=8 if is_sparse else 0.3, zorder=1,
                   alpha=0.6 if is_sparse else 0.2, edgecolor='None')

        # Requirement curve 3 * (1 + sqrt(distance)) over the data extent
        if dist.size:
            dist_th = np.linspace(np.nanmin(dist), np.nanmax(dist), 100)
            ax.plot(dist_th, 3 * (1 + np.sqrt(dist_th)), 'r')

        # Vertical lines marking the bin edges
        ax.vlines(bins, 0, y_max, linewidth=0.3, color='gray', zorder=1)

        # Outlined bar per bin at the threshold quantile, colored by success
        for center, quantile, passed in zip(bin_centers, quantile_th,
                                            success_or_fail[i, :-1]):
            ax.bar(center, quantile, align='center', width=bin_width,
                   color='None', edgecolor='#227522' if passed else '#7c1b1b',
                   linewidth=2, zorder=3)

        # Text box with data info, upper left in axes coordinates
        legend_kwargs = dict(transform=ax.transAxes, verticalalignment='top')
        textstr = f'Sensor: {sensor} \n{validation_data}-InSAR point pairs\n'
        textstr += f"Record: {index[i]}"
        ax.text(0.02, 0.95, textstr, fontsize=7, bbox=info_box, **legend_kwargs)

        # Validation status (from the 'Total' column), upper right
        textstr = f'{validation_type.capitalize()} Req \n'
        textstr += f'Site: {site_name}\n'
        if success_or_fail[i, -1]:
            validation_flag = 'PASSED'
            validation_color = '#239d23'
        else:
            validation_flag = 'FAILED'
            validation_color = '#bc2e2e'
        ax.text(0.818, 0.93, textstr, fontsize=8, bbox=status_box, **legend_kwargs)
        ax.text(0.852, 0.82, f"{validation_flag}",
                fontsize=10, weight='bold',
                bbox=status_box, **legend_kwargs)

        # Colored rectangle behind the status text
        rect = patches.Rectangle((0.8, 0.75), 0.19, 0.2,
                                 linewidth=1, edgecolor='black',
                                 facecolor=validation_color,
                                 transform=ax.transAxes)
        ax.add_patch(rect)

        # Title & labels
        ax.set_xlabel("Distance (km)", fontsize=8)
        ax.set_ylabel(y_label, fontsize=8)
        ax.minorticks_on()
        ax.tick_params(axis='x', which='minor', length=4, direction='in',
                       top=False, width=1.5)
        ax.tick_params(axis='both', labelsize=8)
        ax.set_xticks(bin_centers, minor=True)
        ax.set_xticks(np.arange(0, 55, 5))
        ax.set_ylim(0, y_max)
        ax.set_xlim(*distance_rqmt)
        # Use the caller's site name rather than a hard-coded one
        ax.set_title(f"Residuals \n Date range {index[i]} \n "
                     f"Number of station pairs used: {len(dist)} \n "
                     f"Cal/Val Site {site_name} ")

    # Hide the unused subplots of the last row, if any
    for unused_ax in axs.flat[n_ifgs:]:
        unused_ax.axis('off')

    fig.suptitle(f"{validation_type.capitalize()} requirement for site : {site_name} \n",
                 fontsize=18, fontweight='bold')
    fig.tight_layout()

    # Close this figure so it is not rendered twice in notebooks; the
    # returned handle can still be displayed or saved by the caller.
    plt.close(fig)
    return styled_df, fig