Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make search results more ergonomic #498

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions tagstudio/src/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime, UTC
import shutil
from os import makedirs
Expand Down Expand Up @@ -92,6 +93,33 @@ def get_default_tags() -> tuple[Tag, ...]:
return archive_tag, favorite_tag


@dataclass(frozen=True)
class SearchResult:
"""Wrapper for search results.

:param total_count: total number of items for given query, might be different than len(items)
:param items: items for current page (size matches filter.page_size)
"""

total_count: int
items: list[Entry]

def __bool__(self) -> bool:
"""Boolean evaluation for the wrapper.

:return: True if there are items in the result.
"""
return self.total_count > 0

def __len__(self) -> int:
"""Return the total number of items in the result."""
return len(self.items)

def __getitem__(self, index: int) -> Entry:
"""Allow to access items via index directly on the wrapper."""
return self.items[index]


class Library:
"""Class for the Library object, and all CRUD operations made upon it."""

Expand Down Expand Up @@ -338,7 +366,7 @@ def has_path_entry(self, path: Path) -> bool:
def search_library(
self,
search: FilterState,
) -> tuple[int, list[Entry]]:
) -> SearchResult:
"""Filter library by search query.

:return: number of entries matching the query and one page of results.
Expand Down Expand Up @@ -414,11 +442,14 @@ def search_library(
),
)

entries_ = list(session.scalars(statement).unique())
res = SearchResult(
total_count=count_all,
items=list(session.scalars(statement).unique()),
)

session.expunge_all()

return count_all, entries_
return res

def search_tags(
self,
Expand Down
6 changes: 3 additions & 3 deletions tagstudio/src/core/utils/dupe_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def refresh_dupe_files(self, results_filepath: str | Path):
# The file is not in the library directory
continue

_, entries = self.library.search_library(
results = self.library.search_library(
FilterState(path=path_relative),
)

if not entries:
if not results:
# file not in library
continue

files.append(entries[0])
files.append(results[0])

if not len(files) > 1:
# only one file in the group, nothing to do
Expand Down
12 changes: 6 additions & 6 deletions tagstudio/src/qt/ts_qt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,26 +1009,26 @@ def filter_items(self, filter: FilterState | None = None) -> None:
self.main_window.statusbar.repaint()
start_time = time.time()

query_count, page_items = self.lib.search_library(self.filter)
results = self.lib.search_library(self.filter)

logger.info("items to render", count=len(page_items))
logger.info("items to render", count=len(results))

end_time = time.time()
if self.filter.summary:
self.main_window.statusbar.showMessage(
f'{query_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
f'{results.total_count} Results Found for "{self.filter.summary}" ({format_timespan(end_time - start_time)})'
)
else:
self.main_window.statusbar.showMessage(
f"{query_count} Results ({format_timespan(end_time - start_time)})"
f"{results.total_count} Results ({format_timespan(end_time - start_time)})"
)

# update page content
self.frame_content = list(page_items)
self.frame_content = results.items
self.update_thumbs()

# update pagination
self.pages_count = math.ceil(query_count / self.filter.page_size)
self.pages_count = math.ceil(results.total_count / self.filter.page_size)
self.main_window.pagination.update_buttons(
self.pages_count, self.filter.page_index, emit=False
)
Expand Down
2 changes: 1 addition & 1 deletion tagstudio/src/qt/widgets/item_thumb.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def on_badge_check(self, badge_type: BadgeType):
# update the entry
self.driver.frame_content[idx] = self.lib.search_library(
FilterState(id=entry.id)
)[1][0]
).items[0]

self.driver.update_badges(update_items)

Expand Down
17 changes: 10 additions & 7 deletions tagstudio/src/qt/widgets/preview_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def update_selected_entry(driver: "QtDriver"):
for grid_idx in driver.selected:
entry = driver.frame_content[grid_idx]
# reload entry
_, entries = driver.lib.search_library(FilterState(id=entry.id))
results = driver.lib.search_library(FilterState(id=entry.id))
logger.info(
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
"found item", entries=len(results), grid_idx=grid_idx, lookup_id=entry.id
)
assert entries, f"Entry not found: {entry.id}"
driver.frame_content[grid_idx] = entries[0]
assert results, f"Entry not found: {entry.id}"
driver.frame_content[grid_idx] = next(results)


class PreviewPanel(QWidget):
Expand Down Expand Up @@ -499,11 +499,14 @@ def update_widgets(self) -> bool:
# TODO - Entry reload is maybe not necessary
for grid_idx in self.driver.selected:
entry = self.driver.frame_content[grid_idx]
_, entries = self.lib.search_library(FilterState(id=entry.id))
results = self.lib.search_library(FilterState(id=entry.id))
logger.info(
"found item", entries=entries, grid_idx=grid_idx, lookup_id=entry.id
"found item",
entries=len(results.items),
grid_idx=grid_idx,
lookup_id=entry.id,
)
self.driver.frame_content[grid_idx] = entries[0]
self.driver.frame_content[grid_idx] = results[0]

if len(self.driver.selected) == 1:
# 1 Selected Entry
Expand Down
4 changes: 2 additions & 2 deletions tagstudio/tests/macros/test_missing_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ def test_refresh_missing_files(library: Library):
assert list(registry.fix_missing_files()) == [1, 2]

# `bar.md` should be relinked to new correct path
_, entries = library.search_library(FilterState(path="bar.md"))
assert entries[0].path == pathlib.Path("bar.md")
results = library.search_library(FilterState(path="bar.md"))
assert results[0].path == pathlib.Path("bar.md")
66 changes: 31 additions & 35 deletions tagstudio/tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ def test_library_search(library, generate_tag, entry_full):
assert library.entries_count == 2
tag = list(entry_full.tags)[0]

query_count, items = library.search_library(
results = library.search_library(
FilterState(
tag=tag.name,
),
)

assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1

entry = items[0]
entry = results[0]
assert {x.name for x in entry.tags} == {
"foo",
}
Expand All @@ -92,9 +92,9 @@ def test_tag_search(library):

def test_get_entry(library, entry_min):
assert entry_min.id
cnt, entries = library.search_library(FilterState(id=entry_min.id))
assert len(entries) == cnt == 1
assert entries[0].tags
results = library.search_library(FilterState(id=entry_min.id))
assert len(results) == results.total_count == 1
assert results[0].tags


def test_entries_count(library):
Expand All @@ -103,14 +103,14 @@ def test_entries_count(library):
for x in range(10)
]
library.add_entries(entries)
matches, page = library.search_library(
results = library.search_library(
FilterState(
page_size=5,
)
)

assert matches == 12
assert len(page) == 5
assert results.total_count == 12
assert len(results) == 5


def test_add_field_to_entry(library):
Expand Down Expand Up @@ -144,8 +144,8 @@ def test_add_field_tag(library, entry_full, generate_tag):
library.add_field_tag(entry_full, tag, tag_field.type_key)

# Then
_, entries = library.search_library(FilterState(id=entry_full.id))
tag_field = entries[0].tag_box_fields[0]
results = library.search_library(FilterState(id=entry_full.id))
tag_field = results[0].tag_box_fields[0]
assert [x.name for x in tag_field.tags if x.name == tag_name]


Expand Down Expand Up @@ -177,15 +177,15 @@ def test_search_filter_extensions(library, is_exclude):
library.set_prefs(LibraryPrefs.EXTENSION_LIST, ["md"])

# When
query_count, items = library.search_library(
results = library.search_library(
FilterState(),
)

# Then
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1

entry = items[0]
entry = results[0]
assert (entry.path.suffix == ".txt") == is_exclude


Expand All @@ -198,15 +198,15 @@ def test_search_library_case_insensitive(library):
tag = list(entry.tags)[0]

# When
query_count, items = library.search_library(
results = library.search_library(
FilterState(tag=tag.name.upper()),
)

# Then
assert query_count == 1
assert len(items) == 1
assert results.total_count == 1
assert len(results) == 1

assert items[0].id == entry.id
assert results[0].id == entry.id


def test_preferences(library):
Expand All @@ -229,11 +229,11 @@ def test_save_windows_path(library, generate_tag):
# library.add_tag(tag)
library.add_field_tag(entry, tag, create_field=True)

_, found = library.search_library(FilterState(tag=tag_name))
assert found
results = library.search_library(FilterState(tag=tag_name))
assert results

# path should be saved in posix format
assert str(found[0].path) == "foo/bar.txt"
assert str(results[0].path) == "foo/bar.txt"


def test_remove_entry_field(library, entry_full):
Expand Down Expand Up @@ -310,13 +310,13 @@ def test_mirror_entry_fields(library, entry_full):

entry_id = library.add_entries([target_entry])[0]

_, entries = library.search_library(FilterState(id=entry_id))
new_entry = entries[0]
results = library.search_library(FilterState(id=entry_id))
new_entry = results[0]

library.mirror_entry_fields(new_entry, entry_full)

_, entries = library.search_library(FilterState(id=entry_id))
entry = entries[0]
results = library.search_library(FilterState(id=entry_id))
entry = results[0]

assert len(entry.fields) == 4
assert {x.type_key for x in entry.fields} == {
Expand Down Expand Up @@ -348,13 +348,11 @@ def test_remove_tag_from_field(library, entry_full):
],
)
def test_search_file_name(library, query_name, has_result):
res_count, items = library.search_library(
results = library.search_library(
FilterState(name=query_name),
)

assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
assert results.total_count == has_result


@pytest.mark.parametrize(
Expand All @@ -367,13 +365,11 @@ def test_search_file_name(library, query_name, has_result):
],
)
def test_search_entry_id(library, query_name, has_result):
res_count, items = library.search_library(
results = library.search_library(
FilterState(id=query_name),
)

assert (
res_count == has_result
), f"mismatch with query: {query_name}, result: {res_count}"
assert results.total_count == has_result


def test_update_field_order(library, entry_full):
Expand Down