Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate DNS logic into their own staticmethods #203

Merged
merged 18 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 6 additions & 25 deletions mcstatus/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import NamedTuple, Optional, TYPE_CHECKING, Tuple, Union
from urllib.parse import urlparse

import dns.asyncresolver
import dns.resolver
from dns.rdatatype import RdataType

import mcstatus.dns

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -126,11 +126,7 @@ def resolve_ip(self, lifetime: Optional[float] = None) -> Union[ipaddress.IPv4Ad
# ValueError is raised if the given address wasn't valid
# this means it's a hostname and we should try to resolve
# the A record
answers = dns.resolver.resolve(self.host, RdataType.A, lifetime=lifetime)
# There should only be one answer here, though in case the server
# does actually point to multiple IPs, we just pick the first one
answer = answers[0]
ip_addr = str(answer).rstrip(".")
ip_addr = mcstatus.dns.resolve_a_record(self.host, lifetime=lifetime)
ip = ipaddress.ip_address(ip_addr)

self._cached_ip = ip
Expand All @@ -151,11 +147,7 @@ async def async_resolve_ip(self, lifetime: Optional[float] = None) -> Union[ipad
# ValueError is raised if the given address wasn't valid
# this means it's a hostname and we should try to resolve
# the A record
answers = await dns.asyncresolver.resolve(self.host, RdataType.A, lifetime=lifetime)
# There should only be one answer here, though in case the server
# does actually point to multiple IPs, we just pick the first one
answer = answers[0]
ip_addr = str(answer).rstrip(".")
ip_addr = await mcstatus.dns.async_resolve_a_record(self.host, lifetime=lifetime)
ip = ipaddress.ip_address(ip_addr)

self._cached_ip = ip
Expand Down Expand Up @@ -195,16 +187,14 @@ def minecraft_srv_address_lookup(
# port which we should use. If there's no such record, fall back
# to the default_port (if it's defined).
try:
answers = dns.resolver.resolve("_minecraft._tcp." + host, RdataType.SRV, lifetime=lifetime)
host, port = mcstatus.dns.resolve_mc_srv(host, lifetime=lifetime)
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
if default_port is None:
raise ValueError(
f"Given address '{address}' doesn't contain port, doesn't have an SRV record pointing to a port,"
" and default_port wasn't specified, can't parse."
)
port = default_port
else:
return _parse_first_found_record(answers)

return Address(host, port)

Expand All @@ -229,22 +219,13 @@ async def async_minecraft_srv_address_lookup(
# port which we should use. If there's no such record, fall back
# to the default_port (if it's defined).
try:
answers = await dns.asyncresolver.resolve("_minecraft._tcp." + host, RdataType.SRV, lifetime=lifetime)
host, port = await mcstatus.dns.async_resolve_mc_srv(host, lifetime=lifetime)
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
if default_port is None:
raise ValueError(
f"Given address '{address}' doesn't contain port, doesn't have an SRV record pointing to a port,"
" and default_port wasn't specified, can't parse."
)
port = default_port
else:
return _parse_first_found_record(answers)

return Address(host, port)


def _parse_first_found_record(answers) -> Address:
answer = answers[0]
host = str(answer.target).rstrip(".")
port = int(answer.port)
return Address(host, port)
90 changes: 90 additions & 0 deletions mcstatus/dns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Optional, Tuple

import dns.asyncresolver
import dns.resolver
from dns.rdatatype import RdataType


def resolve_a_record(hostname: str, lifetime: Optional[float] = None) -> str:
"""Perform a DNS resolution for an A record to given hostname

:param str hostname: The address to resolve for.
:return: The resolved IP address from the A record
:raises dns.exception.DNSException:
One of the exceptions possibly raised by dns.resolver.resolve
Most notably this will be `dns.exception.Timeout`, `dns.resolver.NXDOMAIN` and `dns.resolver.NoAnswer`
"""
answers = dns.resolver.resolve(hostname, RdataType.A, lifetime=lifetime)
# There should only be one answer here, though in case the server
# does actually point to multiple IPs, we just pick the first one
answer = answers[0]
ip = str(answer).rstrip(".")
return ip


async def async_resolve_a_record(hostname: str, lifetime: Optional[float] = None) -> str:
"""Asynchronous alternative to resolve_a_record.

For more details, check the docstring of resolve_a_record function.
"""
answers = await dns.asyncresolver.resolve(hostname, RdataType.A, lifetime=lifetime)
# There should only be one answer here, though in case the server
# does actually point to multiple IPs, we just pick the first one
answer = answers[0]
ip = str(answer).rstrip(".")
return ip


def resolve_srv_record(query_name: str, lifetime: Optional[float] = None) -> Tuple[str, int]:
"""Perform a DNS resolution for SRV record pointing to the Java Server.

:param str address: The address to resolve for.
:return: A tuple of host string and port number
:raises dns.exception.DNSException:
One of the exceptions possibly raised by dns.resolver.resolve
Most notably this will be `dns.exception.Timeout`, `dns.resolver.NXDOMAIN` and `dns.resolver.NoAnswer`
"""
answers = dns.resolver.resolve(query_name, RdataType.SRV, lifetime=lifetime)
# There should only be one answer here, though in case the server
# does actually point to multiple IPs, we just pick the first one
answer = answers[0]
host = str(answer.target).rstrip(".")
port = int(answer.port)
return host, port


async def async_resolve_srv_record(query_name: str, lifetime: Optional[float] = None) -> Tuple[str, int]:
"""Asynchronous alternative to resolve_srv_record.

For more details, check the docstring of resolve_srv_record function.
"""
answers = await dns.asyncresolver.resolve(query_name, RdataType.SRV, lifetime=lifetime)
# There should only be one answer here, though in case the server
# does actually point to multiple IPs, we just pick the first one
answer = answers[0]
host = str(answer.target).rstrip(".")
port = int(answer.port)
return host, port


def resolve_mc_srv(hostname: str, lifetime: Optional[float] = None) -> Tuple[str, int]:
"""Resolve SRV record for a minecraft server on given hostname.

:param str address: The address, without port, on which an SRV record is present.
:return: Obtained target and port from the SRV record, on which the server should live on.
:raises dns.exception.DNSException:
One of the exceptions possibly raised by dns.resolver.resolve
Most notably this will be `dns.exception.Timeout`, `dns.resolver.NXDOMAIN` and `dns.resolver.NoAnswer`

Returns obtained target and port from the SRV record, on which
the minecraft server should live on.
"""
return resolve_srv_record("_minecraft._tcp." + hostname, lifetime=lifetime)


async def async_resolve_mc_srv(hostname: str, lifetime: Optional[float] = None) -> Tuple[str, int]:
"""Asynchronous alternative to resolve_mc_srv.

For more details, check the docstring of resolve_mc_srv function.
"""
return await async_resolve_srv_record("_minecraft._tcp." + hostname, lifetime=lifetime)