From 28fc04764aca36778431905ae6abfedf6cc5337f Mon Sep 17 00:00:00 2001 From: bw4sz Date: Thu, 9 Jan 2025 10:13:13 -0800 Subject: [PATCH 1/2] Improve point plotting behavior, making it the same as box or polygon if no scores are present. --- src/deepforest/visualize.py | 6 +++++- tests/test_visualize.py | 20 ++++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/deepforest/visualize.py b/src/deepforest/visualize.py index 51097d15..1ea216ba 100644 --- a/src/deepforest/visualize.py +++ b/src/deepforest/visualize.py @@ -347,7 +347,11 @@ def convert_to_sv_format(df, width=None, height=None): labels = df['label'].map(label_mapping).values.astype(int) # Extract scores as a numpy array - scores = np.array(df['score'].tolist()) + try: + scores = np.array(df['score'].tolist()) + except KeyError: + scores = np.ones(len(labels)) + scores = np.expand_dims(np.stack(scores), 1) # Create a reverse mapping from integer to string labels diff --git a/tests/test_visualize.py b/tests/test_visualize.py index f4f6941a..bd494ef7 100644 --- a/tests/test_visualize.py +++ b/tests/test_visualize.py @@ -165,8 +165,7 @@ def test_plot_results_point(m, tmpdir): 'y': [15, 25], 'label': ['Tree', 'Tree'], 'image_path': [get_data("OSBS_029.tif"), get_data("OSBS_029.tif")], - 'score': [0.9, 0.8], - 'label': ['Tree', 'Tree'] + 'score': [0.9, 0.8] } df = pd.DataFrame(data) gdf = read_file(df, root_dir=os.path.dirname(get_data("OSBS_029.tif"))) @@ -178,6 +177,23 @@ def test_plot_results_point(m, tmpdir): # Assertions assert os.path.exists(os.path.join(tmpdir, "OSBS_029.png")) +def test_plot_results_point_no_label(m, tmpdir): + # Create a mock DataFrame with point annotations + data = { + 'x': [15, 25], + 'y': [15, 25], + 'label': ['Tree', 'Tree'], + 'image_path': [get_data("OSBS_029.tif"), get_data("OSBS_029.tif")], + } + df = pd.DataFrame(data) + gdf = read_file(df, root_dir=os.path.dirname(get_data("OSBS_029.tif"))) + gdf.root_dir = os.path.dirname(get_data("OSBS_029.tif")) + + # Call the function + visualize.plot_results(gdf, savedir=tmpdir) + + # Assertions + assert os.path.exists(os.path.join(tmpdir, "OSBS_029.png")) def test_plot_results_polygon(m, tmpdir): # Create a mock DataFrame with polygon annotations From 9e57da670086ff745d5993c5901ad2f4408c5324 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Mon, 13 Jan 2025 09:32:32 -0800 Subject: [PATCH 2/2] style change --- src/deepforest/visualize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepforest/visualize.py b/src/deepforest/visualize.py index 1ea216ba..1fd836cd 100644 --- a/src/deepforest/visualize.py +++ b/src/deepforest/visualize.py @@ -351,7 +351,7 @@ def convert_to_sv_format(df, width=None, height=None): scores = np.array(df['score'].tolist()) except KeyError: scores = np.ones(len(labels)) - + scores = np.expand_dims(np.stack(scores), 1) # Create a reverse mapping from integer to string labels