Skip to content

Commit

Permalink
Improve type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
oschwald committed Jan 28, 2025
1 parent d0da163 commit eda1f07
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
4 changes: 2 additions & 2 deletions geoip2/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class Model(metaclass=ABCMeta):
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.to_dict() == other.to_dict()

def __ne__(self, other):
def __ne__(self, other) -> bool:
return not self.__eq__(other)

# pylint: disable=too-many-branches
def to_dict(self):
def to_dict(self) -> dict:
"""Returns a dict of the object suitable for serialization."""
result = {}
for key, value in self.__dict__.items():
Expand Down
8 changes: 3 additions & 5 deletions geoip2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABCMeta
from collections.abc import Sequence
from typing import Optional, Union
from ipaddress import IPv4Address, IPv6Address

import geoip2.records
from geoip2._internal import Model
Expand Down Expand Up @@ -395,12 +396,9 @@ def __repr__(self) -> str:
)

@property
def ip_address(self):
def ip_address(self) -> Union[IPv4Address, IPv6Address]:
"""The IP address for the record."""
if not isinstance(
self._ip_address,
(ipaddress.IPv4Address, ipaddress.IPv6Address),
):
if not isinstance(self._ip_address, (IPv4Address, IPv6Address)):
self._ip_address = ipaddress.ip_address(self._ip_address)
return self._ip_address

Expand Down
13 changes: 7 additions & 6 deletions geoip2/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# pylint:disable=R0903
from abc import ABCMeta
from collections.abc import Sequence
from ipaddress import IPv4Address, IPv6Address
from typing import Optional, Union

from geoip2._internal import Model
from geoip2.types import IPAddress


class Record(Model, metaclass=ABCMeta):
Expand Down Expand Up @@ -841,7 +843,7 @@ class Traits(Record):
autonomous_system_organization: Optional[str]
connection_type: Optional[str]
domain: Optional[str]
_ip_address: Optional[str]
_ip_address: IPAddress
is_anonymous: bool
is_anonymous_proxy: bool
is_anonymous_vpn: bool
Expand Down Expand Up @@ -912,6 +914,8 @@ def __init__(
self.static_ip_score = static_ip_score
self.user_type = user_type
self.user_count = user_count
if ip_address is None:
raise TypeError("ip_address must be defined")
self._ip_address = ip_address
if network is None:
self._network = None
Expand All @@ -923,12 +927,9 @@ def __init__(
self._prefix_len = prefix_len

@property
def ip_address(self):
def ip_address(self) -> Union[IPv4Address, IPv6Address]:
"""The IP address for the record."""
if not isinstance(
self._ip_address,
(ipaddress.IPv4Address, ipaddress.IPv6Address),
):
if not isinstance(self._ip_address, (IPv4Address, IPv6Address)):
self._ip_address = ipaddress.ip_address(self._ip_address)
return self._ip_address

Expand Down

0 comments on commit eda1f07

Please sign in to comment.