Skip to content

Commit

Permalink
Add some missing type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Jun 21, 2024
1 parent 8ddcd24 commit a944982
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 43 deletions.
11 changes: 5 additions & 6 deletions src/aiortc/codecs/g711.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import fractions
from typing import List, Tuple
from typing import List, Tuple, cast

from av import AudioFrame, CodecContext
from av import AudioFrame, AudioResampler, CodecContext
from av.audio.codeccontext import AudioCodecContext
from av.audio.resampler import AudioResampler
from av.frame import Frame
from av.packet import Packet

Expand All @@ -19,7 +18,7 @@

class PcmDecoder(Decoder):
def __init__(self, codec_name: str) -> None:
self.codec: AudioCodecContext = CodecContext.create(codec_name, "r")
self.codec = cast(AudioCodecContext, CodecContext.create(codec_name, "r"))
self.codec.format = "s16"
self.codec.layout = "mono"
self.codec.sample_rate = SAMPLE_RATE
Expand All @@ -28,12 +27,12 @@ def decode(self, encoded_frame: JitterFrame) -> List[Frame]:
packet = Packet(encoded_frame.data)
packet.pts = encoded_frame.timestamp
packet.time_base = TIME_BASE
return self.codec.decode(packet)
return cast(List[Frame], self.codec.decode(packet))


class PcmEncoder(Encoder):
def __init__(self, codec_name: str) -> None:
self.codec: AudioCodecContext = CodecContext.create(codec_name, "w")
self.codec = cast(AudioCodecContext, CodecContext.create(codec_name, "w"))
self.codec.format = "s16"
self.codec.layout = "mono"
self.codec.sample_rate = SAMPLE_RATE
Expand Down
15 changes: 7 additions & 8 deletions src/aiortc/codecs/h264.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import math
from itertools import tee
from struct import pack, unpack_from
from typing import Iterator, List, Optional, Sequence, Tuple, Type, TypeVar
from typing import Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, cast

import av
from av.frame import Frame
from av.packet import Packet
from av.video.codeccontext import VideoCodecContext

from ..jitterbuffer import JitterFrame
from ..mediastreams import VIDEO_TIME_BASE, convert_timebase
Expand Down Expand Up @@ -104,27 +105,25 @@ def parse(cls: Type[DESCRIPTOR_T], data: bytes) -> Tuple[DESCRIPTOR_T, bytes]:

class H264Decoder(Decoder):
def __init__(self) -> None:
self.codec = av.CodecContext.create("h264", "r")
self.codec = cast(VideoCodecContext, av.CodecContext.create("h264", "r"))

def decode(self, encoded_frame: JitterFrame) -> List[Frame]:
try:
packet = av.Packet(encoded_frame.data)
packet.pts = encoded_frame.timestamp
packet.time_base = VIDEO_TIME_BASE
frames = self.codec.decode(packet)
return cast(List[Frame], self.codec.decode(packet))
except av.AVError as e:
logger.warning(
"H264Decoder() failed to decode, skipping package: " + str(e)
)
return []

return frames


def create_encoder_context(
codec_name: str, width: int, height: int, bitrate: int
) -> Tuple[av.CodecContext, bool]:
codec = av.CodecContext.create(codec_name, "w")
) -> Tuple[VideoCodecContext, bool]:
codec = cast(VideoCodecContext, av.CodecContext.create(codec_name, "w"))
codec.width = width
codec.height = height
codec.bit_rate = bitrate
Expand All @@ -144,7 +143,7 @@ class H264Encoder(Encoder):
def __init__(self) -> None:
self.buffer_data = b""
self.buffer_pts: Optional[int] = None
self.codec: Optional[av.CodecContext] = None
self.codec: Optional[VideoCodecContext] = None
self.codec_buffering = False
self.__target_bitrate = DEFAULT_BITRATE

Expand Down
3 changes: 1 addition & 2 deletions src/aiortc/codecs/opus.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import fractions
from typing import List, Tuple

from av import AudioFrame
from av.audio.resampler import AudioResampler
from av import AudioFrame, AudioResampler
from av.frame import Frame
from av.packet import Packet

Expand Down
40 changes: 21 additions & 19 deletions src/aiortc/contrib/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
]


async def blackhole_consume(track):
async def blackhole_consume(track: MediaStreamTrack) -> None:
while True:
try:
await track.recv()
Expand All @@ -53,8 +53,8 @@ class MediaBlackhole:
A media sink that consumes and discards all media.
"""

def __init__(self):
self.__tracks = {}
def __init__(self) -> None:
self.__tracks: Dict[MediaStreamTrack, asyncio.Future] = {}

def addTrack(self, track):
"""
Expand All @@ -65,15 +65,15 @@ def addTrack(self, track):
if track not in self.__tracks:
self.__tracks[track] = None

async def start(self):
async def start(self) -> None:
"""
Start discarding media.
"""
for track, task in self.__tracks.items():
if task is None:
self.__tracks[track] = asyncio.ensure_future(blackhole_consume(track))

async def stop(self):
async def stop(self) -> None:
"""
Stop discarding media.
"""
Expand Down Expand Up @@ -219,12 +219,12 @@ def player_worker_demux(


class PlayerStreamTrack(MediaStreamTrack):
def __init__(self, player, kind):
def __init__(self, player: "MediaPlayer", kind: str) -> None:
super().__init__()
self.kind = kind
self._player = player
self._queue = asyncio.Queue()
self._start = None
self._queue: asyncio.Queue[Union[Frame, Packet]] = asyncio.Queue()
self._start: Optional[float] = None

async def recv(self) -> Union[Frame, Packet]:
if self.readyState != "live":
Expand Down Expand Up @@ -254,7 +254,7 @@ async def recv(self) -> Union[Frame, Packet]:

return data

def stop(self):
def stop(self) -> None:
super().stop()
if self._player is not None:
self._player._stop(self)
Expand Down Expand Up @@ -301,7 +301,7 @@ class MediaPlayer:

def __init__(
self, file, format=None, options=None, timeout=None, loop=False, decode=True
):
) -> None:
self.__container = av.open(
file=file, format=format, mode="r", options=options, timeout=timeout
)
Expand Down Expand Up @@ -393,7 +393,7 @@ def __log_debug(self, msg: str, *args) -> None:


class MediaRecorderContext:
def __init__(self, stream):
def __init__(self, stream) -> None:
self.started = False
self.stream = stream
self.task = None
Expand Down Expand Up @@ -422,7 +422,7 @@ def __init__(self, file, format=None, options=None):
self.__container = av.open(file=file, format=format, mode="w", options=options)
self.__tracks = {}

def addTrack(self, track):
def addTrack(self, track: MediaStreamTrack) -> None:
"""
Add a track to be recorded.
Expand All @@ -445,15 +445,15 @@ def addTrack(self, track):
stream.pix_fmt = "yuv420p"
self.__tracks[track] = MediaRecorderContext(stream)

async def start(self):
async def start(self) -> None:
"""
Start recording.
"""
for track, context in self.__tracks.items():
if context.task is None:
context.task = asyncio.ensure_future(self.__run_track(track, context))

async def stop(self):
async def stop(self) -> None:
"""
Stop recording.
"""
Expand All @@ -470,7 +470,9 @@ async def stop(self):
self.__container.close()
self.__container = None

async def __run_track(self, track: MediaStreamTrack, context: MediaRecorderContext):
async def __run_track(
self, track: MediaStreamTrack, context: MediaRecorderContext
) -> None:
while True:
try:
frame = await track.recv()
Expand All @@ -496,16 +498,16 @@ def __init__(self, relay, source: MediaStreamTrack, buffered: bool) -> None:
self._source: Optional[MediaStreamTrack] = source
self._buffered = buffered

self._frame: Optional[Frame] = None
self._queue: Optional[asyncio.Queue[Optional[Frame]]] = None
self._frame: Union[Frame, Packet, None] = None
self._queue: Optional[asyncio.Queue[Union[Frame, Packet, None]]] = None
self._new_frame_event: Optional[asyncio.Event] = None

if self._buffered:
self._queue = asyncio.Queue()
else:
self._new_frame_event = asyncio.Event()

async def recv(self):
async def recv(self) -> Union[Frame, Packet]:
if self.readyState != "live":
raise MediaStreamError

Expand All @@ -522,7 +524,7 @@ async def recv(self):
raise MediaStreamError
return self._frame

def stop(self):
def stop(self) -> None:
super().stop()
if self._relay is not None:
self._relay._stop(self)
Expand Down
2 changes: 1 addition & 1 deletion src/aiortc/rtcdtlstransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _create_ssl_context(
) -> SSL.Context:
ctx = SSL.Context(SSL.DTLS_METHOD)
ctx.set_verify(
SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: 1
SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT, lambda *args: True
)
ctx.use_certificate(self._cert)
ctx.use_privatekey(self._key)
Expand Down
2 changes: 1 addition & 1 deletion src/aiortc/rtcpeerconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def addTransceiver(
direction=direction, kind=kind, sender_track=track
)

async def close(self):
async def close(self) -> None:
"""
Terminate the ICE agent, ending ICE processing and streams.
"""
Expand Down
6 changes: 3 additions & 3 deletions src/aiortc/rtcrtpsender.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, trackOrKind: Union[MediaStreamTrack, str], transport) -> None
self.__rtp_timestamp = 0
self.__octet_count = 0
self.__packet_count = 0
self.__rtt = None
self.__rtt: Optional[float] = None

# logging
self.__log_debug: Callable[..., None] = lambda *args: None
Expand Down Expand Up @@ -206,7 +206,7 @@ async def send(self, parameters: RTCRtpSendParameters) -> None:
self.__rtcp_task = asyncio.ensure_future(self._run_rtcp())
self.__started = True

async def stop(self):
async def stop(self) -> None:
"""
Irreversibly stop the sender.
"""
Expand All @@ -219,7 +219,7 @@ async def stop(self):
self.__rtcp_task.cancel()
await asyncio.gather(self.__rtp_exited.wait(), self.__rtcp_exited.wait())

async def _handle_rtcp_packet(self, packet):
async def _handle_rtcp_packet(self, packet) -> None:
if isinstance(packet, (RtcpRrPacket, RtcpSrPacket)):
for report in filter(lambda x: x.ssrc == self._ssrc, packet.reports):
# estimate round-trip time
Expand Down
2 changes: 1 addition & 1 deletion src/aiortc/rtcrtptransceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def setCodecPreferences(self, codecs: List[RTCRtpCodecCapability]) -> None:
unique.insert(0, codec)
self._preferred_codecs = unique

async def stop(self):
async def stop(self) -> None:
"""
Permanently stops the :class:`RTCRtpTransceiver`.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/aiortc/rtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,14 @@ def pack_header_extensions(extensions: List[Tuple[int, bytes]]) -> Tuple[int, by
return extension_profile, extension_value


def compute_audio_level_dbov(frame: AudioFrame):
def compute_audio_level_dbov(frame: AudioFrame) -> int:
"""
Compute the energy level as spelled out in RFC 6465, Appendix A.
"""
MAX_SAMPLE_VALUE = 32767
MAX_AUDIO_LEVEL = 0
MIN_AUDIO_LEVEL = -127
rms = 0
rms = 0.0
buf = bytes(frame.planes[0])
s = struct.Struct("h")
for unpacked in s.iter_unpack(buf):
Expand Down

0 comments on commit a944982

Please sign in to comment.