Skip to content

Commit

Permalink
params: validate param option against all available options
Browse files Browse the repository at this point in the history
  • Loading branch information
zzacharo committed Oct 19, 2023
1 parent c2de64f commit 1ebe3d1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
21 changes: 20 additions & 1 deletion invenio_records_resources/services/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,14 @@ def customize(cls, opts):
attrs = {}
if opts.facets:
attrs["facets"] = opts.facets
# these are the selected sort options
if opts.sort_options:
attrs["sort_options"] = opts.sort_options
attrs["sort_default"] = opts.sort_default
attrs["sort_default_no_query"] = opts.sort_default_no_query
# store all available sort options
if opts.available_sort_options:
attrs["available_sort_options"] = opts.available_sort_options
if opts.query_parser_cls:
attrs["query_parser_cls"] = opts.query_parser_cls
return _make_cls(cls, attrs) if attrs else cls
Expand Down Expand Up @@ -128,20 +132,30 @@ def __init__(self, available_options, selected_options):
# Ensure all selected options are availabe.
for o in selected_options:
assert o in available_options, f"Selected option '{o}' is undefined."
self.iterate_all_options = False

self.available_options = available_options
self.selected_options = selected_options

def __iter__(self):
"""Iterate over options to produce RSK options."""
for o in self.selected_options:
for o in (
self.selected_options
if not self.iterate_all_options
else self.available_options
):
yield self.map_option(o, self.available_options[o])

def map_option(self, key, option):
"""Map an option."""
# This interface is used in Invenio-App-RDM.
return (key, option)

def __call__(self):
"""Control if we iterate through all or selected sort options"""
self.iterate_all_options = True
return self


class SortOptionsSelector(OptionsSelector):
"""Sort options for the search configuration."""
Expand Down Expand Up @@ -194,6 +208,11 @@ def sort_options(self):
"""Get sort options for search."""
return {k: v for (k, v) in self._sort}

@property
def available_sort_options(self):
"""Get sort options for search."""
return {k: v for (k, v) in self._sort()}

@property
def sort_default(self):
"""Get default sort method for search."""
Expand Down
1 change: 1 addition & 0 deletions invenio_records_resources/services/records/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class SearchOptions:
fields=["created"],
),
}
available_sort_options = sort_options
facets = {}
pagination_options = {"default_results_per_page": 25, "default_max_results": 10000}
params_interpreters_cls = [QueryStrParam, PaginationParam, SortParam, FacetsParam]
Expand Down
2 changes: 1 addition & 1 deletion invenio_records_resources/services/records/params/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def apply(self, identity, search, params):

def _compute_sort_fields(self, params):
"""Compute sort fields."""
options = deepcopy(self.config.sort_options)
options = deepcopy(self.config.available_sort_options)
if "sort" not in params:
params["sort"] = self._default_sort(params, options)

Expand Down

0 comments on commit 1ebe3d1

Please sign in to comment.