Skip to content

Commit

Permalink
Elastic search search_after pagination (#4473)
Browse files Browse the repository at this point in the history
  • Loading branch information
hamza-56 authored Dec 26, 2024
1 parent 1dde756 commit c3261d4
Show file tree
Hide file tree
Showing 20 changed files with 458 additions and 0 deletions.
1 change: 1 addition & 0 deletions course_discovery/apps/api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@

# One URL prefix per API version; each prefix delegates to that version's own URLconf module.
urlpatterns = [
    path(route='v1/', view=include('course_discovery.apps.api.v1.urls')),
    path(route='v2/', view=include('course_discovery.apps.api.v2.urls')),
]
Empty file.
133 changes: 133 additions & 0 deletions course_discovery/apps/api/v2/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
""" Serializers for api/v2/search/all """

from course_discovery.apps.course_metadata.search_indexes import documents
from course_discovery.apps.course_metadata.search_indexes.constants import SEARCH_INDEX_ADDITIONAL_FIELDS_V2
from course_discovery.apps.course_metadata.search_indexes.documents import (
CourseDocument, CourseRunDocument, LearnerPathwayDocument, PersonDocument, ProgramDocument
)
from course_discovery.apps.course_metadata.search_indexes.serializers.aggregation import (
AggregateSearchListSerializer, AggregateSearchSerializer
)
from course_discovery.apps.course_metadata.search_indexes.serializers.common import SortFieldMixin
from course_discovery.apps.course_metadata.search_indexes.serializers.course import CourseSearchDocumentSerializer
from course_discovery.apps.course_metadata.search_indexes.serializers.course_run import (
CourseRunSearchDocumentSerializer
)
from course_discovery.apps.course_metadata.search_indexes.serializers.learner_pathway import (
LearnerPathwaySearchDocumentSerializer
)
from course_discovery.apps.course_metadata.search_indexes.serializers.person import PersonSearchDocumentSerializer
from course_discovery.apps.course_metadata.search_indexes.serializers.program import ProgramSearchDocumentSerializer
from course_discovery.apps.edx_elasticsearch_dsl_extensions.serializers import DummyDocument


class CourseRunSearchDocumentSerializerV2(SortFieldMixin, CourseRunSearchDocumentSerializer):
    """
    Version 2 serializer for Course Run search documents.

    Builds on `CourseRunSearchDocumentSerializer` and mixes in `SortFieldMixin` so that
    the `sort` information from the Elasticsearch response is part of the output.
    The serialized fields are the base serializer's fields followed by the names in
    `SEARCH_INDEX_ADDITIONAL_FIELDS_V2`.
    """

    class Meta(CourseRunSearchDocumentSerializer.Meta):
        """Options of the base serializer, with the v2-only fields appended."""

        document = CourseRunDocument
        fields = (
            *CourseRunSearchDocumentSerializer.Meta.fields,
            *SEARCH_INDEX_ADDITIONAL_FIELDS_V2,
        )


class CourseSearchDocumentSerializerV2(SortFieldMixin, CourseSearchDocumentSerializer):
    """
    Version 2 serializer for Course search documents.

    Builds on `CourseSearchDocumentSerializer` and mixes in `SortFieldMixin` so that
    the `sort` information from the Elasticsearch response is part of the output.
    The serialized fields are the base serializer's fields followed by the names in
    `SEARCH_INDEX_ADDITIONAL_FIELDS_V2`.
    """

    class Meta(CourseSearchDocumentSerializer.Meta):
        """Options of the base serializer, with the v2-only fields appended."""

        document = CourseDocument
        fields = (
            *CourseSearchDocumentSerializer.Meta.fields,
            *SEARCH_INDEX_ADDITIONAL_FIELDS_V2,
        )


class ProgramSearchDocumentSerializerV2(SortFieldMixin, ProgramSearchDocumentSerializer):
    """
    Version 2 serializer for Program search documents.

    Builds on `ProgramSearchDocumentSerializer` and mixes in `SortFieldMixin` so that
    the `sort` information from the Elasticsearch response is part of the output.
    The serialized fields are the base serializer's fields followed by the names in
    `SEARCH_INDEX_ADDITIONAL_FIELDS_V2`.
    """

    class Meta(ProgramSearchDocumentSerializer.Meta):
        """Options of the base serializer, with the v2-only fields appended."""

        document = ProgramDocument
        fields = (
            *ProgramSearchDocumentSerializer.Meta.fields,
            *SEARCH_INDEX_ADDITIONAL_FIELDS_V2,
        )


class LearnerPathwaySearchDocumentSerializerV2(SortFieldMixin, LearnerPathwaySearchDocumentSerializer):
    """
    Version 2 serializer for Learner Pathway search documents.

    Builds on `LearnerPathwaySearchDocumentSerializer` and mixes in `SortFieldMixin` so that
    the `sort` information from the Elasticsearch response is part of the output.
    The serialized fields are the base serializer's fields followed by the names in
    `SEARCH_INDEX_ADDITIONAL_FIELDS_V2`.
    """

    class Meta(LearnerPathwaySearchDocumentSerializer.Meta):
        """Options of the base serializer, with the v2-only fields appended."""

        document = LearnerPathwayDocument
        fields = (
            *LearnerPathwaySearchDocumentSerializer.Meta.fields,
            *SEARCH_INDEX_ADDITIONAL_FIELDS_V2,
        )


class PersonSearchDocumentSerializerV2(SortFieldMixin, PersonSearchDocumentSerializer):
    """
    Version 2 serializer for Person search documents.

    Builds on `PersonSearchDocumentSerializer` and mixes in `SortFieldMixin` so that
    the `sort` information from the Elasticsearch response is part of the output.
    The serialized fields are the base serializer's fields followed by the names in
    `SEARCH_INDEX_ADDITIONAL_FIELDS_V2`.
    """

    class Meta(PersonSearchDocumentSerializer.Meta):
        """Options of the base serializer, with the v2-only fields appended."""

        document = PersonDocument
        fields = (
            *PersonSearchDocumentSerializer.Meta.fields,
            *SEARCH_INDEX_ADDITIONAL_FIELDS_V2,
        )


# pylint: disable=abstract-method
class AggregateSearchListSerializerV2(AggregateSearchListSerializer):
    """
    List serializer for aggregate search results that dispatches to the V2 serializers.

    Identical in purpose to `AggregateSearchListSerializer`, but every document type is
    mapped to its V2 serializer, which adds the version 2 search index fields needed
    for search_after pagination.
    """

    class Meta(AggregateSearchListSerializer.Meta):
        """Meta options: the document class -> serializer class mapping."""

        # The document classes are the same objects as `documents.<Name>`;
        # they are imported by name at the top of this module.
        serializers = {
            CourseRunDocument: CourseRunSearchDocumentSerializerV2,
            CourseDocument: CourseSearchDocumentSerializerV2,
            ProgramDocument: ProgramSearchDocumentSerializerV2,
            LearnerPathwayDocument: LearnerPathwaySearchDocumentSerializerV2,
            PersonDocument: PersonSearchDocumentSerializerV2,
        }


class AggregateSearchSerializerV2(AggregateSearchSerializer):
    """
    Version 2 serializer for aggregated Elasticsearch documents.

    Lists of results are delegated to `AggregateSearchListSerializerV2`.
    """

    class Meta(AggregateSearchSerializer.Meta):
        """Meta options: placeholder document plus the V2 list serializer."""

        document = DummyDocument
        list_serializer_class = AggregateSearchListSerializerV2
Empty file.
Empty file.
125 changes: 125 additions & 0 deletions course_discovery/apps/api/v2/tests/test_views/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
""" Test cases for api/v2/search/all """

import json

import ddt
from django.urls import reverse

from course_discovery.apps.api.v1.tests.test_views import mixins
from course_discovery.apps.core.tests.mixins import ElasticsearchTestMixin
from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.course_metadata.tests.factories import CourseFactory, CourseRunFactory, PersonFactory
from course_discovery.apps.learner_pathway.models import LearnerPathway
from course_discovery.apps.learner_pathway.tests.factories import LearnerPathwayStepFactory


@ddt.ddt
class AggregateSearchViewSetV2Tests(mixins.LoginMixin, ElasticsearchTestMixin, mixins.APITestCase):
    """Tests for the api/v2/search/all endpoint and its 'search_after' pagination."""

    list_path = reverse("api:v2:search-all-list")

    def _create_people_and_published_course_runs(self, count):
        """
        Create `count` persons and `count` courses for `self.partner`, and give every
        course one published course run of a marketable type.
        """
        PersonFactory.create_batch(count, partner=self.partner)
        courses = CourseFactory.create_batch(count, partner=self.partner)

        for course in courses:
            CourseRunFactory(
                course__partner=self.partner,
                course=course,
                type__is_marketable=True,
                status=CourseRunStatus.Published,
            )

    def fetch_page_data(self, page_size, search_after=None):
        """
        GET one page of results and return the decoded JSON body.

        `search_after` is only sent when a value is given. The request must succeed (HTTP 200).
        """
        query_params = {"page_size": page_size}
        if search_after:
            query_params["search_after"] = search_after
        response = self.client.get(self.list_path, data=query_params)
        assert response.status_code == 200
        return response.json()

    def validate_page_data(self, page_data, expected_size):
        """
        Assert that every result carries 'sort' and 'aggregation_uuid' and that the page
        holds exactly `expected_size` results.
        """
        assert all("sort" in obj for obj in page_data["results"]), "Not all objects have a 'sort' field"
        assert all(
            "aggregation_uuid" in obj for obj in page_data["results"]
        ), "Not all objects have an 'aggregation_uuid' field"
        assert (
            len(page_data["results"]) == expected_size
        ), f"Page does not have the expected number of results ({expected_size})"

    def test_results_include_aggregation_uuid_and_sort_fields(self):
        """
        Test that search results include 'aggregation_uuid' and 'sort' fields
        and that the total result count matches the expected value.
        """
        self._create_people_and_published_course_runs(5)

        response = self.client.get(self.list_path)
        # Assert the status before decoding the body, so a failed request is reported
        # as a status mismatch rather than as a JSON/key error further down.
        assert response.status_code == 200
        response_data = response.json()
        assert response_data["count"] == 15
        self.validate_page_data(response_data, 15)

    @ddt.data((True, 10), (False, 0))
    @ddt.unpack
    def test_learner_pathway_feature_flag(self, include_learner_pathways, expected_result_count):
        """
        Test the inclusion of learner pathways in search results based on a feature flag.
        """
        LearnerPathwayStepFactory.create_batch(10, pathway__partner=self.partner)
        pathways = LearnerPathway.objects.all()
        assert pathways.count() == 10
        query = {
            "include_learner_pathways": include_learner_pathways,
        }

        response = self.client.get(self.list_path, data=query)
        assert response.status_code == 200
        response_data = response.json()

        assert response_data["count"] == expected_result_count

    def test_search_after_pagination(self):
        """
        Test paginated fetching of search results using 'search_after' param.
        """
        objects_per_type = 25
        # Total objects: 25 Persons + 25 Courses + 25 CourseRuns
        expected_total = 3 * objects_per_type
        self._create_people_and_published_course_runs(objects_per_type)

        page_size = 10
        response_data = self.fetch_page_data(page_size)

        assert response_data["count"] == expected_total
        self.validate_page_data(response_data, page_size)

        all_results = response_data["results"]
        next_token = response_data.get("next")

        while next_token:
            response_data = self.fetch_page_data(page_size, search_after=json.dumps(next_token))

            # The last page may be shorter than page_size.
            expected_size = min(page_size, expected_total - len(all_results))
            self.validate_page_data(response_data, expected_size)

            all_results.extend(response_data["results"])
            next_token = response_data.get("next")

            if next_token:
                # The token for the following page must be the sort value of the last hit.
                last_sort_value = response_data["results"][-1]["sort"]
                assert last_sort_value == next_token

        assert len(all_results) == expected_total, "The total number of results does not match the expected count"

        # Fetching everything in one request must yield the same results in the same order.
        single_page_response = self.client.get(self.list_path, data={"page_size": expected_total})
        assert single_page_response.status_code == 200
        single_page_data = single_page_response.json()

        assert (
            len(single_page_data["results"]) == expected_total
        ), "The total number of results in the single request does not match the expected count"
        assert (
            single_page_data["results"] == all_results
        ), "Combined pagination results do not match single request results"
11 changes: 11 additions & 0 deletions course_discovery/apps/api/v2/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""API v2 URLs."""

from rest_framework import routers

from course_discovery.apps.api.v2.views import search as search_views

# URL namespace of this module (routes are reversed as "api:v2:<route name>").
app_name = 'v2'

# Routes are generated by the DRF router from the viewsets registered on it.
router = routers.SimpleRouter()
router.register(
    prefix=r'search/all',
    viewset=search_views.AggregateSearchViewSet,
    basename='search-all',
)
urlpatterns = router.urls
Empty file.
25 changes: 25 additions & 0 deletions course_discovery/apps/api/v2/views/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""API v2 search module for efficient Elasticsearch document searching with search_after pagination."""

from course_discovery.apps.api.v1.views.search import AggregateSearchViewSet as AggregateSearchViewSetV1
from course_discovery.apps.api.v2.serializers import AggregateSearchSerializerV2
from course_discovery.apps.edx_elasticsearch_dsl_extensions.search import SearchAfterSearch
from course_discovery.apps.edx_elasticsearch_dsl_extensions.viewsets import SearchAfterPagination


class AggregateSearchViewSet(AggregateSearchViewSetV1):
    """
    Version 2 aggregate search viewset, paginated with Elasticsearch search_after.

    Reuses the behaviour of the v1 AggregateSearchViewSet and swaps in the V2
    serializer, `SearchAfterPagination`, and a `SearchAfterSearch` object, so that
    large result sets can be paged through efficiently.
    """

    serializer_class = AggregateSearchSerializerV2
    pagination_class = SearchAfterPagination
    ordering_fields = {
        "start": "start",
        "aggregation_uuid": "aggregation_uuid",
    }
    # Default sort: `start` descending, then `aggregation_uuid`
    # (presumably the tie-breaker that search_after needs for a stable order - TODO confirm).
    ordering = ("-start", "aggregation_uuid")

    def __init__(self, *args, **kwargs):
        """Initialise the v1 viewset, then replace `self.search` with a `SearchAfterSearch`."""
        super().__init__(*args, **kwargs)
        doc_type_name = self.document._doc_type.name
        self.search = SearchAfterSearch(using=self.client, index=self.index, doc_type=doc_type_name)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Base set of search index field names.
BASE_SEARCH_INDEX_FIELDS = (
    'aggregation_key',
    'content_type',
    'text',
)
# Field names added on top of the existing ones for version 2 of the search API.
SEARCH_INDEX_ADDITIONAL_FIELDS_V2 = (
    'aggregation_uuid',
    'sort',
)

BASE_PROGRAM_FIELDS = (
'card_image_url',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(self, *args, **kwargs):
self._object = None

aggregation_key = fields.KeywordField()
aggregation_uuid = fields.KeywordField()
partner = fields.TextField(
analyzer=html_strip,
fields={'raw': fields.KeywordField(), 'lower': fields.TextField(analyzer=case_insensitive_keyword)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class CourseDocument(BaseCourseDocument):
def prepare_aggregation_key(self, obj):
    """Return the aggregation key of a course: its `key` prefixed with 'course:'."""
    course_key = obj.key
    return 'course:{}'.format(course_key)

def prepare_aggregation_uuid(self, obj):
    """Return the value indexed as `aggregation_uuid`: the course `uuid` prefixed with 'course:'."""
    course_uuid = obj.uuid
    return 'course:{}'.format(course_uuid)

def prepare_availability(self, obj):
    """Return the availability, as a string, of each course run kept by `filter_visible_runs`."""
    visible_runs = filter_visible_runs(obj.course_runs)
    return [str(run.availability) for run in visible_runs]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def prepare_aggregation_key(self, obj):
# Aggregate CourseRuns by Course key since that is how we plan to dedup CourseRuns on the marketing site.
return 'courserun:{}'.format(obj.course.key)

def prepare_aggregation_uuid(self, obj):
    """Return the value indexed as `aggregation_uuid`: the course run `uuid` prefixed with 'courserun:'."""
    course_run_uuid = obj.uuid
    return 'courserun:{}'.format(course_run_uuid)

def prepare_course_key(self, obj):
    """Return the key of the course that this course run belongs to."""
    parent_course = obj.course
    return parent_course.key

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class LearnerPathwayDocument(BaseDocument, OrganizationsMixin):
def prepare_aggregation_key(self, obj):
    """Return the aggregation key of a learner pathway: its `uuid` prefixed with 'learnerpathway:'."""
    pathway_uuid = obj.uuid
    return 'learnerpathway:{}'.format(pathway_uuid)

def prepare_aggregation_uuid(self, obj):
    """Return the value indexed as `aggregation_uuid`: the pathway `uuid` prefixed with 'learnerpathway:'."""
    pathway_uuid = obj.uuid
    return 'learnerpathway:{}'.format(pathway_uuid)

def prepare_published(self, obj):
    """Return whether the pathway's status equals `PathwayStatus.Active`."""
    is_active = obj.status == PathwayStatus.Active
    return is_active

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class PersonDocument(BaseDocument):
def prepare_aggregation_key(self, obj):
    """Return the aggregation key of a person: its `uuid` prefixed with 'person:'."""
    person_uuid = obj.uuid
    return 'person:{}'.format(person_uuid)

def prepare_aggregation_uuid(self, obj):
    """Return the value indexed as `aggregation_uuid`: the person `uuid` prefixed with 'person:'."""
    person_uuid = obj.uuid
    return 'person:{}'.format(person_uuid)

def prepare_bio_language(self, obj):
if obj.bio_language:
return obj.bio_language.name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class ProgramDocument(BaseDocument, OrganizationsMixin):
def prepare_aggregation_key(self, obj):
    """Return the aggregation key of a program: its `uuid` prefixed with 'program:'."""
    program_uuid = obj.uuid
    return 'program:{}'.format(program_uuid)

def prepare_aggregation_uuid(self, obj):
    """Return the value indexed as `aggregation_uuid`: the program `uuid` prefixed with 'program:'."""
    program_uuid = obj.uuid
    return 'program:{}'.format(program_uuid)

def prepare_credit_backing_organizations(self, obj):
    """Return all of the program's credit-backing organizations, run through `_prepare_organizations`."""
    backing_organizations = obj.credit_backing_organizations.all()
    return self._prepare_organizations(backing_organizations)

Expand Down
Loading

0 comments on commit c3261d4

Please sign in to comment.