Skip to content

Commit

Permalink
Convert model classes to use keyword arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
oschwald committed Jan 23, 2025
1 parent c84b047 commit 4518919
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 108 deletions.
4 changes: 2 additions & 2 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ History
* BREAKING: The ``raw`` attribute on the model classes has been replaced
with a ``to_dict()`` method. This can be used to get a representation of
the object that is suitable for serialization.
* BREAKING: The record classes now require all arguments other than ``locales``
to be keyword arguments.
* BREAKING: The model and record classes now require all arguments other than
``locales`` to be keyword arguments.
* BREAKING: ``geoip2.mixins`` has been made internal. This normally would not
have been used by external code.
* IMPORTANT: Python 3.9 or greater is required. If you are using an older
Expand Down
11 changes: 4 additions & 7 deletions geoip2/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,9 @@ def _model_for(
ip_address: IPAddress,
) -> Union[Country, Enterprise, City]:
(record, prefix_len) = self._get(types, ip_address)
traits = record.setdefault("traits", {})
traits["ip_address"] = ip_address
traits["prefix_len"] = prefix_len
return model_class(record, locales=self._locales)
return model_class(
self._locales, ip_address=ip_address, prefix_len=prefix_len, **record
)

def _flat_model_for(
self,
Expand All @@ -266,9 +265,7 @@ def _flat_model_for(
ip_address: IPAddress,
) -> Union[ConnectionType, ISP, AnonymousIP, Domain, ASN]:
(record, prefix_len) = self._get(types, ip_address)
record["ip_address"] = ip_address
record["prefix_len"] = prefix_len
return model_class(record)
return model_class(ip_address=ip_address, prefix_len=prefix_len, **record)

def metadata(
self,
Expand Down
211 changes: 155 additions & 56 deletions geoip2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
"""

# pylint: disable=too-many-instance-attributes,too-few-public-methods
# pylint: disable=too-many-instance-attributes,too-few-public-methods,too-many-arguments
import ipaddress
from abc import ABCMeta
from typing import Any, cast, Dict, Optional, Sequence, Union
from typing import Dict, List, Optional, Sequence, Union

import geoip2.records
from geoip2._internal import Model
from geoip2.types import IPAddress


class Country(Model):
Expand Down Expand Up @@ -76,30 +77,44 @@ class Country(Model):
traits: geoip2.records.Traits

def __init__(
self, raw_response: Dict[str, Any], locales: Optional[Sequence[str]] = None
self,
locales: Optional[Sequence[str]],
*,
continent: Optional[Dict] = None,
country: Optional[Dict] = None,
ip_address: Optional[IPAddress] = None,
maxmind: Optional[Dict] = None,
prefix_len: Optional[int] = None,
registered_country: Optional[Dict] = None,
represented_country: Optional[Dict] = None,
traits: Optional[Dict] = None,
**_,
) -> None:
if locales is None:
locales = ["en"]
self._locales = locales
self.continent = geoip2.records.Continent(
locales, **raw_response.get("continent", {})
)
self.country = geoip2.records.Country(
locales, **raw_response.get("country", {})
)
self.continent = geoip2.records.Continent(locales, **(continent or {}))
self.country = geoip2.records.Country(locales, **(country or {}))
self.registered_country = geoip2.records.Country(
locales, **raw_response.get("registered_country", {})
locales, **(registered_country or {})
)
self.represented_country = geoip2.records.RepresentedCountry(
locales, **raw_response.get("represented_country", {})
locales, **(represented_country or {})
)

self.maxmind = geoip2.records.MaxMind(**raw_response.get("maxmind", {}))
self.maxmind = geoip2.records.MaxMind(**(maxmind or {}))

traits = traits or {}
if ip_address is not None:
traits["ip_address"] = ip_address
if prefix_len is not None:
traits["prefix_len"] = prefix_len

self.traits = geoip2.records.Traits(**raw_response.get("traits", {}))
self.traits = geoip2.records.Traits(**traits)

def __repr__(self) -> str:
return f"{self.__module__}.{self.__class__.__name__}({self.to_dict()}, {self._locales})"
return (
f"{self.__module__}.{self.__class__.__name__}({self._locales}, "
f"{', '.join(f'{k}={repr(v)}' for k, v in self.to_dict().items())})"
)


class City(Country):
Expand Down Expand Up @@ -179,15 +194,38 @@ class City(Country):
subdivisions: geoip2.records.Subdivisions

def __init__(
self, raw_response: Dict[str, Any], locales: Optional[Sequence[str]] = None
self,
locales: Optional[Sequence[str]],
*,
city: Optional[Dict] = None,
continent: Optional[Dict] = None,
country: Optional[Dict] = None,
location: Optional[Dict] = None,
ip_address: Optional[IPAddress] = None,
maxmind: Optional[Dict] = None,
postal: Optional[Dict] = None,
prefix_len: Optional[int] = None,
registered_country: Optional[Dict] = None,
represented_country: Optional[Dict] = None,
subdivisions: Optional[List[Dict]] = None,
traits: Optional[Dict] = None,
**_,
) -> None:
super().__init__(raw_response, locales)
self.city = geoip2.records.City(locales, **raw_response.get("city", {}))
self.location = geoip2.records.Location(**raw_response.get("location", {}))
self.postal = geoip2.records.Postal(**raw_response.get("postal", {}))
self.subdivisions = geoip2.records.Subdivisions(
locales, *raw_response.get("subdivisions", [])
super().__init__(
locales,
continent=continent,
country=country,
ip_address=ip_address,
maxmind=maxmind,
prefix_len=prefix_len,
registered_country=registered_country,
represented_country=represented_country,
traits=traits,
)
self.city = geoip2.records.City(locales, **(city or {}))
self.location = geoip2.records.Location(**(location or {}))
self.postal = geoip2.records.Postal(**(postal or {}))
self.subdivisions = geoip2.records.Subdivisions(locales, *(subdivisions or []))


class Insights(City):
Expand Down Expand Up @@ -325,20 +363,28 @@ class SimpleModel(Model, metaclass=ABCMeta):
_network: Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
_prefix_len: int

def __init__(self, raw: Dict[str, Union[bool, str, int]]) -> None:
if network := raw.get("network"):
def __init__(
self,
ip_address: Optional[str],
network: Optional[str],
prefix_len: Optional[int],
) -> None:
if network:
self._network = ipaddress.ip_network(network, False)
self._prefix_len = self._network.prefixlen
else:
# This case is for MMDB lookups where performance is paramount.
# This is why we don't generate the network unless .network is
# used.
self._network = None
self._prefix_len = cast(int, raw.get("prefix_len"))
self.ip_address = cast(str, raw.get("ip_address"))
self._prefix_len = prefix_len
self.ip_address = ip_address

def __repr__(self) -> str:
return f"{self.__module__}.{self.__class__.__name__}({self.to_dict()})"
return (
f"{self.__module__}.{self.__class__.__name__}"
f"({', '.join(f'{k}={repr(v)}' for k, v in self.to_dict().items())})"
)

@property
def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
Expand Down Expand Up @@ -427,14 +473,27 @@ class AnonymousIP(SimpleModel):
is_residential_proxy: bool
is_tor_exit_node: bool

def __init__(self, raw: Dict[str, bool]) -> None:
super().__init__(raw) # type: ignore
self.is_anonymous = raw.get("is_anonymous", False)
self.is_anonymous_vpn = raw.get("is_anonymous_vpn", False)
self.is_hosting_provider = raw.get("is_hosting_provider", False)
self.is_public_proxy = raw.get("is_public_proxy", False)
self.is_residential_proxy = raw.get("is_residential_proxy", False)
self.is_tor_exit_node = raw.get("is_tor_exit_node", False)
def __init__(
self,
*,
is_anonymous: bool = False,
is_anonymous_vpn: bool = False,
is_hosting_provider: bool = False,
is_public_proxy: bool = False,
is_residential_proxy: bool = False,
is_tor_exit_node: bool = False,
ip_address: Optional[str] = None,
network: Optional[str] = None,
prefix_len: Optional[int] = None,
**_,
) -> None:
super().__init__(ip_address, network, prefix_len)
self.is_anonymous = is_anonymous
self.is_anonymous_vpn = is_anonymous_vpn
self.is_hosting_provider = is_hosting_provider
self.is_public_proxy = is_public_proxy
self.is_residential_proxy = is_residential_proxy
self.is_tor_exit_node = is_tor_exit_node


class ASN(SimpleModel):
Expand Down Expand Up @@ -474,14 +533,19 @@ class ASN(SimpleModel):
autonomous_system_organization: Optional[str]

# pylint:disable=too-many-arguments,too-many-positional-arguments
def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
super().__init__(raw)
self.autonomous_system_number = cast(
Optional[int], raw.get("autonomous_system_number")
)
self.autonomous_system_organization = cast(
Optional[str], raw.get("autonomous_system_organization")
)
def __init__(
self,
*,
autonomous_system_number: Optional[int] = None,
autonomous_system_organization: Optional[str] = None,
ip_address: Optional[str] = None,
network: Optional[str] = None,
prefix_len: Optional[int] = None,
**_,
) -> None:
super().__init__(ip_address, network, prefix_len)
self.autonomous_system_number = autonomous_system_number
self.autonomous_system_organization = autonomous_system_organization


class ConnectionType(SimpleModel):
Expand Down Expand Up @@ -520,9 +584,17 @@ class ConnectionType(SimpleModel):

connection_type: Optional[str]

def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
super().__init__(raw)
self.connection_type = cast(Optional[str], raw.get("connection_type"))
def __init__(
self,
*,
connection_type: Optional[str] = None,
ip_address: Optional[str] = None,
network: Optional[str] = None,
prefix_len: Optional[int] = None,
**_,
) -> None:
super().__init__(ip_address, network, prefix_len)
self.connection_type = connection_type


class Domain(SimpleModel):
Expand Down Expand Up @@ -554,9 +626,17 @@ class Domain(SimpleModel):

domain: Optional[str]

def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
super().__init__(raw)
self.domain = cast(Optional[str], raw.get("domain"))
def __init__(
self,
*,
domain: Optional[str] = None,
ip_address: Optional[str] = None,
network: Optional[str] = None,
prefix_len: Optional[int] = None,
**_,
) -> None:
super().__init__(ip_address, network, prefix_len)
self.domain = domain


class ISP(ASN):
Expand Down Expand Up @@ -626,9 +706,28 @@ class ISP(ASN):
organization: Optional[str]

# pylint:disable=too-many-arguments,too-many-positional-arguments
def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
super().__init__(raw)
self.isp = cast(Optional[str], raw.get("isp"))
self.mobile_country_code = cast(Optional[str], raw.get("mobile_country_code"))
self.mobile_network_code = cast(Optional[str], raw.get("mobile_network_code"))
self.organization = cast(Optional[str], raw.get("organization"))
def __init__(
self,
*,
autonomous_system_number: Optional[int] = None,
autonomous_system_organization: Optional[str] = None,
isp: Optional[str] = None,
mobile_country_code: Optional[str] = None,
mobile_network_code: Optional[str] = None,
organization: Optional[str] = None,
ip_address: Optional[str] = None,
network: Optional[str] = None,
prefix_len: Optional[int] = None,
**_,
) -> None:
super().__init__(
autonomous_system_number=autonomous_system_number,
autonomous_system_organization=autonomous_system_organization,
ip_address=ip_address,
network=network,
prefix_len=prefix_len,
)
self.isp = isp
self.mobile_country_code = mobile_country_code
self.mobile_network_code = mobile_network_code
self.organization = organization
4 changes: 2 additions & 2 deletions geoip2/webservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ async def _response_for(
if status != 200:
raise self._exception_for_error(status, content_type, body, uri)
decoded_body = self._handle_success(body, uri)
return model_class(decoded_body, locales=self._locales)
return model_class(self._locales, **decoded_body)

async def close(self):
"""Close underlying session
Expand Down Expand Up @@ -499,7 +499,7 @@ def _response_for(
if status != 200:
raise self._exception_for_error(status, content_type, body, uri)
decoded_body = self._handle_success(body, uri)
return model_class(decoded_body, locales=self._locales)
return model_class(self._locales, **decoded_body)

def close(self):
"""Close underlying session
Expand Down
6 changes: 3 additions & 3 deletions tests/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_connection_type(self) -> None:

self.assertRegex(
str(record),
r"ConnectionType\(\{.*Cellular.*\}\)",
r"ConnectionType\(.*Cellular.*\)",
"ConnectionType str representation is reasonable",
)

Expand Down Expand Up @@ -197,7 +197,7 @@ def test_domain(self) -> None:

self.assertRegex(
str(record),
r"Domain\(\{.*maxmind.com.*\}\)",
r"Domain\(.*maxmind.com.*\)",
"Domain str representation is reasonable",
)

Expand Down Expand Up @@ -247,7 +247,7 @@ def test_isp(self) -> None:

self.assertRegex(
str(record),
r"ISP\(\{.*Telstra.*\}\)",
r"ISP\(.*Telstra.*\)",
"ISP str representation is reasonable",
)

Expand Down
Loading

0 comments on commit 4518919

Please sign in to comment.