Skip to content

Commit

Permalink
more minor adjustments to response types
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Oct 22, 2024
1 parent 099a52f commit 8c38d16
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 15 deletions.
9 changes: 8 additions & 1 deletion elasticsearch_dsl/response/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@
from ..search_base import Request, SearchBase
from ..update_by_query_base import UpdateByQueryBase

__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta"]
__all__ = [
"Response",
"AggResponse",
"UpdateByQueryResponse",
"Hit",
"HitMeta",
"AggregateResponseType",
]


class Response(AttrDict[Any], Generic[_R]):
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch_dsl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl import Query, function
from elasticsearch_dsl import Query, function, index_base
from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl.utils import AttrDict

Expand Down Expand Up @@ -5100,7 +5100,7 @@ class Hit(AttrDict[Any]):
:arg sort:
"""

index: str
index: index_base.IndexBase
id: str
score: Union[float, None]
explanation: "Explanation"
Expand Down
7 changes: 4 additions & 3 deletions examples/async/composite_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

import asyncio
import os
from typing import Any, AsyncIterator, Dict, Mapping, Sequence
from typing import Any, AsyncIterator, Dict, Mapping, Sequence, cast

from elasticsearch.helpers import async_bulk

from elasticsearch_dsl import Agg, AsyncSearch, Response, aggs, async_connections
from elasticsearch_dsl.types import CompositeAggregate
from tests.test_integration.test_data import DATA, GIT_INDEX


Expand All @@ -30,7 +31,7 @@ async def scan_aggs(
source_aggs: Sequence[Mapping[str, Agg]],
inner_aggs: Dict[str, Agg] = {},
size: int = 10,
) -> AsyncIterator[Any]:
) -> AsyncIterator[CompositeAggregate]:
"""
Helper function used to iterate over all possible bucket combinations of
``source_aggs``, returning results of ``inner_aggs`` for each. Uses the
Expand All @@ -54,7 +55,7 @@ async def run_search(**kwargs: Any) -> Response:
response = await run_search()
while response.aggregations["comp"].buckets:
for b in response.aggregations["comp"].buckets:
yield b
yield cast(CompositeAggregate, b)
if "after_key" in response.aggregations["comp"]:
after = response.aggregations["comp"].after_key
else:
Expand Down
4 changes: 2 additions & 2 deletions examples/async/parent_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ async def add_answer(
# required make sure the answer is stored in the same shard
_routing=self.meta.id,
# since we don't have explicit index, ensure same index as self
_index=self.meta.index,
_index=cast(AsyncIndex, self.meta.index),
# set up the parent/child mapping
question_answer={"name": "answer", "parent": self.meta.id},
# pass in the field values
Expand Down Expand Up @@ -218,7 +218,7 @@ async def get_question(self) -> Optional[Question]:
# any attributes set on self would be interpreted as fields
if "question" not in self.meta:
self.meta.question = await Question.get(
id=self.question_answer.parent, index=self.meta.index
id=self.question_answer.parent, index=self.meta.index._name
)
return cast(Optional[Question], self.meta.question)

Expand Down
7 changes: 4 additions & 3 deletions examples/composite_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.

import os
from typing import Any, Dict, Iterator, Mapping, Sequence
from typing import Any, Dict, Iterator, Mapping, Sequence, cast

from elasticsearch.helpers import bulk

from elasticsearch_dsl import Agg, Response, Search, aggs, connections
from elasticsearch_dsl.types import CompositeAggregate
from tests.test_integration.test_data import DATA, GIT_INDEX


Expand All @@ -29,7 +30,7 @@ def scan_aggs(
source_aggs: Sequence[Mapping[str, Agg]],
inner_aggs: Dict[str, Agg] = {},
size: int = 10,
) -> Iterator[Any]:
) -> Iterator[CompositeAggregate]:
"""
Helper function used to iterate over all possible bucket combinations of
``source_aggs``, returning results of ``inner_aggs`` for each. Uses the
Expand All @@ -53,7 +54,7 @@ def run_search(**kwargs: Any) -> Response:
response = run_search()
while response.aggregations["comp"].buckets:
for b in response.aggregations["comp"].buckets:
yield b
yield cast(CompositeAggregate, b)
if "after_key" in response.aggregations["comp"]:
after = response.aggregations["comp"].after_key
else:
Expand Down
4 changes: 2 additions & 2 deletions examples/parent_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def add_answer(
# required make sure the answer is stored in the same shard
_routing=self.meta.id,
# since we don't have explicit index, ensure same index as self
_index=self.meta.index,
_index=cast(Index, self.meta.index),
# set up the parent/child mapping
question_answer={"name": "answer", "parent": self.meta.id},
# pass in the field values
Expand Down Expand Up @@ -217,7 +217,7 @@ def get_question(self) -> Optional[Question]:
# any attributes set on self would be interpreted as fields
if "question" not in self.meta:
self.meta.question = Question.get(
id=self.question_answer.parent, index=self.meta.index
id=self.question_answer.parent, index=self.meta.index._name
)
return cast(Optional[Question], self.meta.question)

Expand Down
7 changes: 7 additions & 0 deletions utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,13 @@ def interface_to_python_class(
k, arg, for_types_py=for_types_py, for_response=for_response
)

if interface == "Hit" and arg["name"] == "index":
# Python DSL replaces the string typed index attribute
# with an Index or AsyncIndex instance. Here we use
# IndexBase, which is a base class for both Index and
# AsyncIndex.
k["args"][-1]["type"] = "index_base.IndexBase"

if "inherits" not in type_ or "type" not in type_["inherits"]:
break

Expand Down
2 changes: 1 addition & 1 deletion utils/templates/response.__init__.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ if TYPE_CHECKING:
from ..update_by_query_base import UpdateByQueryBase
from .. import types

__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta"]
__all__ = ["Response", "AggResponse", "UpdateByQueryResponse", "Hit", "HitMeta", "AggregateResponseType"]


class Response(AttrDict[Any], Generic[_R]):
Expand Down
2 changes: 1 addition & 1 deletion utils/templates/types.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union
from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl import function, Query
from elasticsearch_dsl import function, index_base, Query
from elasticsearch_dsl.utils import AttrDict

PipeSeparatedFlags = str
Expand Down

0 comments on commit 8c38d16

Please sign in to comment.