Skip to content

Commit

Permalink
Add search_entry_counts().
Browse files Browse the repository at this point in the history
For #185.
  • Loading branch information
lemon24 committed Nov 27, 2020
1 parent 08be2d4 commit df39dde
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 8 deletions.
84 changes: 84 additions & 0 deletions src/reader/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
from ._storage import apply_random
from ._storage import apply_recent
from ._types import EntryFilterOptions
from ._utils import exactly_one
from ._utils import join_paginated_iter
from .exceptions import InvalidSearchQueryError
from .exceptions import SearchError
from .exceptions import SearchNotEnabledError
from .types import EntrySearchCounts
from .types import EntrySearchResult
from .types import HighlightedString
from .types import SearchSortOrder
Expand Down Expand Up @@ -110,6 +112,8 @@ def strip_html(text: SQLiteType, features: Optional[str] = None) -> SQLiteType:
# have a look at the lessons here first:
# https://github.com/lemon24/reader/issues/175#issuecomment-657495233

# When adding a new method, add a new test_search.py::test_errors_locked test.


class Search:

Expand Down Expand Up @@ -859,6 +863,35 @@ def value_factory(t: Tuple[Any, ...]) -> EntrySearchResult:

raise

@wrap_exceptions(SearchError)
def search_entry_counts(
    self,
    query: str,
    filter_options: EntryFilterOptions = EntryFilterOptions(),  # noqa: B008
) -> EntrySearchCounts:
    """Count the entries matching a full-text search query.

    The SQL is built by make_search_entry_counts_query(); the search
    string is bound as the ``query`` parameter next to the parameters
    produced for the filter options. The single row handed back by
    exactly_one() is unpacked into an EntrySearchCounts.

    Raises:
        SearchNotEnabledError: SQLite reported "no such table".
        InvalidSearchQueryError: The SQLite error message contains one of
            self._query_error_message_fragments.

    Any other sqlite3.OperationalError is re-raised as-is
    (presumably translated further by wrap_exceptions -- confirm).

    """
    sql_query, query_context = make_search_entry_counts_query(filter_options)
    params = dict(query=query, **query_context)

    try:
        cursor = self.db.execute(str(sql_query), params)
        counts_row = exactly_one(cursor)
    except sqlite3.OperationalError as e:
        # TODO: dedupe with search_entries_page
        message = str(e)
        lowered = message.lower()

        if 'no such table' in lowered:
            raise SearchNotEnabledError()

        # The error is matched on message text (case-insensitively).
        for fragment in self._query_error_message_fragments:
            if fragment in lowered:
                raise InvalidSearchQueryError(message=message)

        raise

    return EntrySearchCounts(*counts_row)


def make_search_entries_query(
filter_options: EntryFilterOptions, sort: SearchSortOrder
Expand Down Expand Up @@ -930,3 +963,54 @@ def make_search_entries_query(
log.debug("_search_entries query\n%s\n", query)

return query, context


def make_search_entry_counts_query(
    filter_options: EntryFilterOptions,
) -> Tuple[Query, Dict[str, Any]]:
    """Build the query that counts the entries matching a search.

    Returns a (query, context) pair: the query selects, in order, the
    total number of matching entries, how many are read, how many are
    important, and how many have a non-empty enclosures array; the context
    is whatever apply_entry_filter_options() returned for filter_options.
    The ``:query`` parameter used by the MATCH clause is not part of the
    returned context; the caller has to supply it.

    """
    # FIXME: dedupe with make_search_entries_query

    # CTE "search": rows of the search table that match, joined with
    # entries so the entry filters can be applied to them.
    matches = (
        Query()
        .SELECT("_id, _feed")
        .FROM("entries_search")
        .JOIN("entries ON (entries.id, entries.feed) = (_id, _feed)")
        .WHERE("entries_search MATCH :query")
        # https://www.mail-archive.com/sqlite-users@mailinglists.sqlite.org/msg115821.html
        # rule 14 https://www.sqlite.org/optoverview.html#subquery_flattening
        .LIMIT("-1 OFFSET 0")
    )

    # Adds the filters to the "search" CTE and gives back their parameters.
    context = apply_entry_filter_options(matches, filter_options)

    # CTE "search_grouped": one row per (entry id, feed) pair found above.
    grouped = (
        Query()
        .SELECT("_id, _feed")
        .FROM("search")
        .GROUP_BY("search._id", "search._feed")
    )

    with_enclosures = """
coalesce(
sum(
NOT (
json_array_length(entries.enclosures) IS NULL OR json_array_length(entries.enclosures) = 0
)
), 0
)
"""
    count_columns = (
        'count(*)',
        'coalesce(sum(read == 1), 0)',
        'coalesce(sum(important == 1), 0)',
        with_enclosures,
    )
    ctes = (
        ("search", str(matches)),
        ("search_grouped", str(grouped)),
    )

    query = (
        Query()
        .WITH(*ctes)
        .SELECT(*count_columns)
        .FROM("entries")
        # FIXME: is this better as a WHERE (entries.id, entries.feed) in ...?
        .JOIN("search_grouped ON (id, feed) = (_id, _feed)")
    )

    log.debug("_search_entry_counts query\n%s\n", query)

    return query, context
49 changes: 49 additions & 0 deletions src/reader/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .types import Entry
from .types import EntryCounts
from .types import EntryInput
from .types import EntrySearchCounts
from .types import EntrySearchResult
from .types import EntrySortOrder
from .types import Feed
Expand Down Expand Up @@ -1125,6 +1126,54 @@ def search_entries(
now = self._now()
return self._search.search_entries(query, now, filter_options, sort)

def search_entry_counts(
    self,
    query: str,
    *,
    feed: Optional[FeedInput] = None,
    entry: Optional[EntryInput] = None,
    read: Optional[bool] = None,
    important: Optional[bool] = None,
    has_enclosures: Optional[bool] = None,
    feed_tags: TagFilterInput = None,
) -> EntrySearchCounts:
    """Count entries matching a full-text search query.

    See :meth:`~Reader.search_entries()` for details on how
    the query syntax and filtering work.

    Search must be enabled to call this method.

    Args:
        query (str): The search query.
        feed (str or Feed or None): Only count the entries for this feed.
        entry (tuple(str, str) or Entry or None):
            Only count the entry with this (feed URL, entry id) tuple.
        read (bool or None): Only count (un)read entries.
        important (bool or None): Only count (un)important entries.
        has_enclosures (bool or None): Only count entries that (don't)
            have enclosures.
        feed_tags (None or bool or list(str or bool or list(str or bool))):
            Only count the entries from feeds matching these tags.

    Returns:
        :class:`EntrySearchCounts`:

    Raises:
        SearchNotEnabledError
        InvalidSearchQueryError
        SearchError
        StorageError

    .. versionadded:: 1.11

    """
    # Bundle the keyword filters into one options object; the search
    # backend does the actual counting.
    options = EntryFilterOptions.from_args(
        feed, entry, read, important, has_enclosures, feed_tags
    )
    return self._search.search_entry_counts(query, options)

def add_feed_tag(self, feed: FeedInput, tag: str) -> None:
"""Add a tag to a feed.
Expand Down
35 changes: 32 additions & 3 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from reader import Entry
from reader import EntryCounts
from reader import EntryNotFoundError
from reader import EntrySearchCounts
from reader import Feed
from reader import FeedCounts
from reader import FeedExistsError
Expand Down Expand Up @@ -2284,6 +2285,23 @@ def test_feed_counts(reader, kwargs, expected):
assert reader.get_feed_counts(**kwargs) == expected


def get_entry_counts(reader, **kwargs):
    """Return ``reader.get_entry_counts()`` called with the given keyword arguments."""
    counts = reader.get_entry_counts(**kwargs)
    return counts


def search_entry_counts(reader, **kwargs):
    """Return ``reader.search_entry_counts()`` for the fixed query 'entry' and the given keyword arguments."""
    counts = reader.search_entry_counts('entry', **kwargs)
    return counts


# Parametrization shared by the counts tests; each case is a tuple of
# (pre_stuff: callable run with the reader before counting,
#  call_method: wrapper that calls the counts method under test,
#  rv_type: the type the result is expected to have).
with_call_entry_counts_method = pytest.mark.parametrize(
'pre_stuff, call_method, rv_type',
[
(lambda _: None, get_entry_counts, EntryCounts),
(enable_and_update_search, search_entry_counts, EntrySearchCounts),
],
)


@pytest.mark.parametrize(
'kwargs, expected',
[
Expand Down Expand Up @@ -2320,15 +2338,22 @@ def test_feed_counts(reader, kwargs, expected):
),
],
)
def test_entry_counts(reader, kwargs, expected):
@with_call_entry_counts_method
def test_entry_counts(reader, kwargs, expected, pre_stuff, call_method, rv_type):
# TODO: fuzz get_entries() == get_entry_counts()

reader._parser = parser = Parser()

one = parser.feed(1, datetime(2010, 1, 3))
two = parser.feed(2, datetime(2010, 1, 3))
three = parser.feed(3, datetime(2010, 1, 1))
one_entries = [parser.entry(1, 1, datetime(2010, 1, 3))]
one_entry = parser.entry(
1,
1,
datetime(2010, 1, 3),
summary='summary',
content=(Content('value3', 'type', 'en'), Content('value2')),
)
two_entries = [
parser.entry(2, i, datetime(2010, 1, 3), enclosures=[]) for i in range(1, 1 + 8)
]
Expand All @@ -2343,10 +2368,14 @@ def test_entry_counts(reader, kwargs, expected):
reader.add_feed(feed)

reader.update_feeds()
pre_stuff(reader)

for entry in two_entries[:2]:
reader.mark_as_read(entry)
for entry in two_entries[:4]:
reader.mark_as_important(entry)

assert reader.get_entry_counts(**kwargs) == expected
rv = call_method(reader, **kwargs)
assert type(rv) is rv_type
# this isn't gonna work as well if the return types get different attributes
assert rv._asdict() == expected._asdict()
13 changes: 13 additions & 0 deletions tests/test_reader_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from reader import Content
from reader import Enclosure
from reader import EntrySearchCounts
from reader import EntrySearchResult
from reader import FeedNotFoundError
from reader import HighlightedString
Expand Down Expand Up @@ -444,6 +445,14 @@ def test_search_entries_fails_if_not_enabled(reader, sort):
assert excinfo.value.message


@rename_argument('reader', 'reader_without_and_with_entries')
def test_search_entry_counts_fails_if_not_enabled(reader):
    """search_entry_counts() raises SearchNotEnabledError when search is not enabled.

    The exception must carry a message and must not have an explicit cause.

    """
    with pytest.raises(SearchNotEnabledError) as excinfo:
        # No list() around the call: unlike search_entries(), this method
        # returns a single counts object, not an iterator to be consumed.
        reader.search_entry_counts('one')
    assert excinfo.value.__cause__ is None
    assert excinfo.value.message


@with_sort
def test_search_entries_basic(reader, sort):
parser = Parser()
Expand Down Expand Up @@ -476,10 +485,12 @@ def test_search_entries_basic(reader, sort):
reader.update_search()

search = lambda *a, **kw: reader.search_entries(*a, sort=sort, **kw)
search_counts = lambda *a, **kw: reader.search_entry_counts(*a, **kw)

# TODO: the asserts below look parametrizable

assert list(search('zero')) == []
assert search_counts('zero') == EntrySearchCounts(0, 0, 0, 0)
assert list(search('one')) == [
EntrySearchResult(
feed.url,
Expand All @@ -490,6 +501,7 @@ def test_search_entries_basic(reader, sort):
},
)
]
assert search_counts('one') == EntrySearchCounts(1, 0, 0, 0)
assert list(search('two')) == [
EntrySearchResult(
feed.url,
Expand Down Expand Up @@ -562,6 +574,7 @@ def test_search_entries_basic(reader, sort):
),
]
}
assert search_counts('summary') == EntrySearchCounts(3, 0, 0, 0)


# search_entries() filtering is tested in test_reader.py::test_entries_filtering{,_error}
Expand Down
34 changes: 29 additions & 5 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def search_entries_chunk_size_1(storage, _, __):
)


def search_entry_counts(storage, _, __):
    """Run search_entry_counts('entry') on a Search built over storage; the result is discarded.

    The two extra positional arguments are ignored.

    """
    search = Search(storage)
    search.search_entry_counts('entry')


@pytest.mark.slow
@pytest.mark.parametrize(
'pre_stuff, do_stuff',
Expand All @@ -141,6 +145,7 @@ def search_entries_chunk_size_1(storage, _, __):
(enable_search, update_search),
(enable_search, search_entries_chunk_size_0),
(enable_search, search_entries_chunk_size_1),
(enable_search, search_entry_counts),
],
)
def test_errors_locked(db_path, pre_stuff, do_stuff):
Expand Down Expand Up @@ -204,28 +209,47 @@ def test_iter_locked(db_path, iter_stuff):
check_iter_locked(db_path, enable_and_update_search, iter_stuff)


class ActuallyOK(Exception):
    """Signals that a search call finished without raising an error of its own."""


def call_search_entries(search, query):
    """Fetch the first result of search.search_entries() for query.

    Raises ActuallyOK if there are no results (StopIteration), so that
    "the query ran fine but matched nothing" can be told apart from errors.

    """
    now = datetime(2010, 1, 1)
    try:
        results = search.search_entries(query, now)
        next(results)
    except StopIteration:
        raise ActuallyOK


def call_search_entry_counts(search, query):
    """Call search.search_entry_counts() for query, then always raise ActuallyOK.

    ActuallyOK is only reached if the call itself did not raise.

    """
    search.search_entry_counts(query)
    # Getting here means the query was accepted.
    raise ActuallyOK


@pytest.mark.parametrize(
'query, exc_type',
[
('\x00', InvalidSearchQueryError),
('"', InvalidSearchQueryError),
# For some reason, on CPython * works when the filtering is inside
# the CTE (it didn't when it was outside), hence the StopIteration.
# the CTE (it didn't when it was outside), hence the ActuallyOK.
# On PyPy 7.3.1 we still get a InvalidSearchQueryError.
# We're fine as long as we don't get another exception.
('*', (StopIteration, InvalidSearchQueryError)),
('*', (ActuallyOK, InvalidSearchQueryError)),
('O:', InvalidSearchQueryError),
('*p', InvalidSearchQueryError),
],
)
def test_invalid_search_query_error(storage, query, exc_type):
@pytest.mark.parametrize(
'call_method', [call_search_entries, call_search_entry_counts,],
)
def test_invalid_search_query_error(storage, query, exc_type, call_method):
# We're not testing this in test_reader_search.py because
# the invalid query strings are search-provider-dependent.
search = Search(storage)
search.enable()
with pytest.raises(exc_type) as excinfo:
next(search.search_entries(query, datetime(2010, 1, 1)))
if isinstance(exc_type, tuple) and StopIteration in exc_type:
call_method(search, query)
if isinstance(exc_type, tuple) and ActuallyOK in exc_type:
return
assert excinfo.value.message
assert excinfo.value.__cause__ is None
Expand Down

0 comments on commit df39dde

Please sign in to comment.