From ff19c642cb22a5b6a073d611e593baa836e5ebe4 Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Fri, 28 Feb 2020 21:28:32 -0800 Subject: [PATCH] Barycentric clipping in the renderer and flat shading Summary: Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step. Also added support for flat shading. Reviewed By: jcjohnson Differential Revision: D19934259 fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904 --- pytorch3d/renderer/blending.py | 13 +-- pytorch3d/renderer/mesh/__init__.py | 11 +- pytorch3d/renderer/mesh/renderer.py | 31 ++++++ pytorch3d/renderer/mesh/shader.py | 3 +- pytorch3d/renderer/mesh/shading.py | 12 ++- pytorch3d/renderer/mesh/texturing.py | 82 ++------------ pytorch3d/renderer/mesh/utils.py | 100 ++++++++++++++++++ tests/data/test_blurry_textured_rendering.png | Bin 0 -> 46164 bytes tests/data/test_simple_sphere_light_flat.png | Bin 0 -> 26694 bytes ...mple_sphere_light_flat_elevated_camera.png | Bin 0 -> 18420 bytes tests/test_mesh_rendering_utils.py | 24 +++++ tests/test_rendering_meshes.py | 67 ++++++++++-- ...{test_utils.py => test_rendering_utils.py} | 0 tests/test_texturing.py | 19 ++-- 14 files changed, 254 insertions(+), 108 deletions(-) create mode 100644 pytorch3d/renderer/mesh/utils.py create mode 100644 tests/data/test_blurry_textured_rendering.png create mode 100644 tests/data/test_simple_sphere_light_flat.png create mode 100644 tests/data/test_simple_sphere_light_flat_elevated_camera.png create mode 100644 tests/test_mesh_rendering_utils.py rename tests/{test_utils.py => test_rendering_utils.py} (100%) diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 0865547a6..8d50c52f1 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -90,7 +90,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: return torch.flip(pixel_colors, [1]) -def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: +def softmax_rgb_blend( + colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100 +) -> torch.Tensor: """ RGB and alpha channel blending to return an RGBA image based on the method proposed in [0] @@ -118,6 +120,8 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: exponential function used to control the opacity of the color. - background_color: (3) element list/tuple/torch.Tensor specifying the RGB values for the background color. + znear: float, near clipping plane in the z direction + zfar: float, far clipping plane in the z direction Returns: RGBA pixel_colors: (N, H, W, 4) @@ -125,6 +129,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: [0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based 3D Reasoning' """ + N, H, W, K = fragments.pix_to_face.shape device = fragments.pix_to_face.device pix_colors = torch.ones( @@ -140,11 +145,6 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: delta = np.exp(1e-10 / blend_params.gamma) * 1e-10 delta = torch.tensor(delta, device=device) - # Near and far clipping planes. - # TODO: add zfar/znear as input params. - zfar = 100.0 - znear = 1.0 - # Mask for padded pixels. mask = fragments.pix_to_face >= 0 @@ -164,6 +164,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: # Weights for each face. Adjust the exponential by the max z to prevent # overflow. zbuf shape (N, H, W, K), find max over K. # TODO: there may still be some instability in the exponent calculation. + z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask z_inv_max = torch.max(z_inv, dim=-1).values[..., None] weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma) diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index 32e431e5f..3ac0e00a2 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from .texturing import ( # isort:skip + interpolate_texture_map, + interpolate_vertex_colors, +) from .rasterize_meshes import rasterize_meshes from .rasterizer import MeshRasterizer, RasterizationSettings from .renderer import MeshRenderer @@ -13,10 +18,6 @@ TexturedSoftPhongShader, ) from .shading import gouraud_shading, phong_shading -from .texturing import ( # isort: skip - interpolate_face_attributes, - interpolate_texture_map, - interpolate_vertex_colors, -) +from .utils import interpolate_face_attributes __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py index aeac515d7..18da3d34a 100644 --- a/pytorch3d/renderer/mesh/renderer.py +++ b/pytorch3d/renderer/mesh/renderer.py @@ -5,6 +5,9 @@ import torch import torch.nn as nn +from .rasterizer import Fragments +from .utils import _clip_barycentric_coordinates, _interpolate_zbuf + # A renderer class should be initialized with a # function for rasterization and a function for shading. # The rasterizer should: @@ -34,6 +37,34 @@ def __init__(self, rasterizer, shader): self.shader = shader def forward(self, meshes_world, **kwargs) -> torch.Tensor: + """ + Render a batch of images from a batch of meshes by rasterizing and then shading. + + NOTE: If the blur radius for rasterization is > 0.0, some pixels can have one or + more barycentric coordinates lying outside the range [0, 1]. For a pixel with + out of bounds barycentric coordinates with respect to a face f, clipping is required + before interpolating the texture uv coordinates and z buffer so that the colors and + depths are limited to the range for the corresponding face. + """ fragments = self.rasterizer(meshes_world, **kwargs) + raster_settings = kwargs.get( + "raster_settings", self.rasterizer.raster_settings + ) + if raster_settings.blur_radius > 0.0: + # TODO: potentially move barycentric clipping to the rasterizer + # if no downstream functions requires unclipped values. + # This will avoid unnecssary re-interpolation of the z buffer. + clipped_bary_coords = _clip_barycentric_coordinates( + fragments.bary_coords + ) + clipped_zbuf = _interpolate_zbuf( + fragments.pix_to_face, clipped_bary_coords, meshes_world + ) + fragments = Fragments( + bary_coords=clipped_bary_coords, + zbuf=clipped_zbuf, + dists=fragments.dists, + pix_to_face=fragments.pix_to_face, + ) images = self.shader(fragments, meshes_world, **kwargs) return images diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 1fbae6a63..efeb4792c 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -270,6 +270,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) colors = phong_shading( meshes=meshes, fragments=fragments, @@ -278,7 +279,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras=cameras, materials=materials, ) - images = softmax_rgb_blend(colors, fragments, self.blend_params) + images = softmax_rgb_blend(colors, fragments, blend_params) return images diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index 69d1e305a..1b9effebc 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -70,8 +70,12 @@ def phong_shading( vertex_normals = meshes.verts_normals_packed() # (V, 3) faces_verts = verts[faces] faces_normals = vertex_normals[faces] - pixel_coords = interpolate_face_attributes(fragments, faces_verts) - pixel_normals = interpolate_face_attributes(fragments, faces_normals) + pixel_coords = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts + ) + pixel_normals = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_normals + ) ambient, diffuse, specular = _apply_lighting( pixel_coords, pixel_normals, lights, cameras, materials ) @@ -122,7 +126,9 @@ def gouraud_shading( ) verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular face_colors = verts_colors_shaded[faces] - colors = interpolate_face_attributes(fragments, face_colors) + colors = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_colors + ) return colors diff --git a/pytorch3d/renderer/mesh/texturing.py b/pytorch3d/renderer/mesh/texturing.py index b79dbe3e7..86f3a4424 100644 --- a/pytorch3d/renderer/mesh/texturing.py +++ b/pytorch3d/renderer/mesh/texturing.py @@ -7,75 +7,7 @@ from pytorch3d.structures.textures import Textures - -def _clip_barycentric_coordinates(bary) -> torch.Tensor: - """ - Args: - bary: barycentric coordinates of shape (...., 3) where `...` represents - an arbitrary number of dimensions - - Returns: - bary: All barycentric coordinate values clipped to the range [0, 1] - and renormalized. The output is the same shape as the input. - """ - if bary.shape[-1] != 3: - msg = "Expected barycentric coords to have last dim = 3; got %r" - raise ValueError(msg % bary.shape) - clipped = bary.clamp(min=0, max=1) - clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5) - clipped = clipped / clipped_sum - return clipped - - -def interpolate_face_attributes( - fragments, face_attributes: torch.Tensor, bary_clip: bool = False -) -> torch.Tensor: - """ - Interpolate arbitrary face attributes using the barycentric coordinates - for each pixel in the rasterized output. - - Args: - fragments: - The outputs of rasterization. From this we use - - - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices - of the faces (in the packed representation) which - overlap each pixel in the image. - - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying - the barycentric coordianates of each pixel - relative to the faces (in the packed - representation) which overlap the pixel. - face_attributes: packed attributes of shape (total_faces, 3, D), - specifying the value of the attribute for each - vertex in the face. - bary_clip: Bool to indicate if barycentric_coords should be clipped - before being used for interpolation. - - Returns: - pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated - value of the face attribute for each pixel. - """ - pix_to_face = fragments.pix_to_face - barycentric_coords = fragments.bary_coords - F, FV, D = face_attributes.shape - if FV != 3: - raise ValueError("Faces can only have three vertices; got %r" % FV) - N, H, W, K, _ = barycentric_coords.shape - if pix_to_face.shape != (N, H, W, K): - msg = "pix_to_face must have shape (batch_size, H, W, K); got %r" - raise ValueError(msg % pix_to_face.shape) - if bary_clip: - barycentric_coords = _clip_barycentric_coordinates(barycentric_coords) - - # Replace empty pixels in pix_to_face with 0 in order to interpolate. - mask = pix_to_face == -1 - pix_to_face = pix_to_face.clone() - pix_to_face[mask] = 0 - idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) - pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D) - pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2) - pixel_vals[mask] = 0 # Replace masked values in output. - return pixel_vals +from .utils import interpolate_face_attributes def interpolate_texture_map(fragments, meshes) -> torch.Tensor: @@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor: relative to the faces (in the packed representation) which overlap the pixel. meshes: Meshes representing a batch of meshes. It is expected that - meshes has a textures attribute which is an instance of the - Textures class. + meshes has a textures attribute which is an instance of the + Textures class. Returns: texels: tensor of shape (N, H, W, K, C) giving the interpolated @@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor: texture_maps = meshes.textures.maps_padded() # pixel_uvs: (N, H, W, K, 2) - pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs) + pixel_uvs = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs + ) N, H_out, W_out, K = fragments.pix_to_face.shape N, H_in, W_in, C = texture_maps.shape # 3 for RGB @@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor: vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :] faces_packed = meshes.faces_packed() faces_textures = vertex_textures[faces_packed] # (F, 3, C) - texels = interpolate_face_attributes(fragments, faces_textures) + texels = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_textures + ) return texels diff --git a/pytorch3d/renderer/mesh/utils.py b/pytorch3d/renderer/mesh/utils.py new file mode 100644 index 000000000..a82f10ca3 --- /dev/null +++ b/pytorch3d/renderer/mesh/utils.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import torch + + +def _clip_barycentric_coordinates(bary) -> torch.Tensor: + """ + Args: + bary: barycentric coordinates of shape (...., 3) where `...` represents + an arbitrary number of dimensions + + Returns: + bary: Barycentric coordinates clipped (i.e any values < 0 are set to 0) + and renormalized. We only clip the negative values. Values > 1 will fall + into the [0, 1] range after renormalization. + The output is the same shape as the input. + """ + if bary.shape[-1] != 3: + msg = "Expected barycentric coords to have last dim = 3; got %r" + raise ValueError(msg % bary.shape) + clipped = bary.clamp(min=0.0) + clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5) + clipped = clipped / clipped_sum + return clipped + + +def interpolate_face_attributes( + pix_to_face: torch.Tensor, + barycentric_coords: torch.Tensor, + face_attributes: torch.Tensor, +) -> torch.Tensor: + """ + Interpolate arbitrary face attributes using the barycentric coordinates + for each pixel in the rasterized output. + + Args: + pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + face_attributes: packed attributes of shape (total_faces, 3, D), + specifying the value of the attribute for each + vertex in the face. + + Returns: + pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated + value of the face attribute for each pixel. + """ + F, FV, D = face_attributes.shape + if FV != 3: + raise ValueError("Faces can only have three vertices; got %r" % FV) + N, H, W, K, _ = barycentric_coords.shape + if pix_to_face.shape != (N, H, W, K): + msg = "pix_to_face must have shape (batch_size, H, W, K); got %r" + raise ValueError(msg % pix_to_face.shape) + + # Replace empty pixels in pix_to_face with 0 in order to interpolate. + mask = pix_to_face == -1 + pix_to_face = pix_to_face.clone() + pix_to_face[mask] = 0 + idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D) + pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D) + pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2) + pixel_vals[mask] = 0 # Replace masked values in output. + return pixel_vals + + +def _interpolate_zbuf( + pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes +) -> torch.Tensor: + """ + A helper function to calculate the z buffer for each pixel in the + rasterized output. + + Args: + pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + meshes: Meshes object representing a batch of meshes. + + Returns: + zbuffer: (N, H, W, K) FloatTensor + """ + verts = meshes.verts_packed() + faces = meshes.faces_packed() + faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1) + return interpolate_face_attributes( + pix_to_face, barycentric_coords, faces_verts_z + )[ + ..., 0 + ] # (1, H, W, K) diff --git a/tests/data/test_blurry_textured_rendering.png b/tests/data/test_blurry_textured_rendering.png new file mode 100644 index 0000000000000000000000000000000000000000..5ab0e4b6379e2a1658ba5499a68f13602daa1a67 GIT binary patch literal 46164 zcmeEu^;cV6uyz6zLUET;LMRT!OL2F1EA9n~OYz|DP~5e+yg+e^YbeEvYjFtfa?^YN zi*J4V{E)1bz9qm1VGAlfMQ409bOelIj2e65=fq0QB$0hg_l>0Qm2X zoTQlM*X*MwpA>gJ7ZOzI{?*-OAFgg|zgaEArxaXNbJ2Vf{bpR^Z@8Eksu%23O_=zg zx*<%A;6>R#qV2%I!3LRT(hz(0V6^z5F)MbvBRwvo=0sp{LX0Wo-B^Ro2X&Fk!!ti) z;ginYqfYbZh^GiYSZl_eR_aV@#>0)J&} zJk@d`GcRZ#RDLfX5!fVV70-{F8~78jUNkd zIFU^`4mXkON&|JDOaQqrFMYWAt$|^OfjaM0Q1=4` zhN5->#By{*c>O8yIY~GWE_4GAWTX3m2vgD+CQ#OW=$(9Ev1x?zvqsS1>|>eM|o|+~F+y7n7RwdlrhS^XY%tR`ta$o5=oF;yY`%(QDIq zY6cA{$sPY4R~o39`r^+PkrnJ9lh>gC+vmI63FFIE0F~CoSxF{AxV1C#friMjLg2q2 zHa1IRLw1Gjj@v*7Oi^zbzORpiob~Qe+@&VaxI3zU^{aG(frKvYP^8x!0A9GR_L*<6VL@Opc4Ms6hc>M$d zsGg(eQ(f1&8JTQtzAokSLBp&i{(BBf(F(CZEyAl^1{8k*V7CSyq5s~B%J#zW69r-R zfJ8&?wn5BrsV4e&0C6f5d3u%oAci)K;v-ZmVNq5lf#8<8ybs(Ra{tVNsX7OaYK>29 zf1O_Nx)|~vgt*&(Tup?s4|pB49v6q*+kTx8z<4xfJ$yE_d@r}Z!5sYteN3zPF!TsD z1*hi-8C#1CakIC-TmV=)1CGyyH`IlYD)AeXoE}5jZx=ro-&?a&L2`h4E-HuC2Vk8{5umVXE<0(ZYrbV<_AmO24!bcY8&xAJs7wa-qU; zPJ@Mu?A6O8CrX={%L<}?muJt#biVyw56vPhd~;++x2qJIZv@;%0hldN9?XjQC+JdX zj>Hf@f9w!KKy6>@=_kZSsFN87znF1TeI?2`8_VlV zL`w#ZdY-%kWDhHFb9dzY^Vh#7JFBWv>7r5rGv#}PSD`7vc3OfGOlp`+f!W8;{r-#~lR^jw{(-V~YYA_m zd(KgfbOYLM$%eic%k)k3f4zuR%=xCNqFliS#W5Xpzn<2aH=S{?39h0XQW*V=p4rwM z+An=X#gIh=&L;sT)AXol1R$c>8wmtnq)|CFoNH$F%FVf{>$a{V%YV0VL;f7TfK}W! zvWe`*Nwq~*MK6)op>ra{JXA5=I*_m8y2#j)*vir>I=ss-T^J{0+Tk0Va6rsb(+1S}q6^sF#{v43dXnBi#|eX#AXWtn{2odH2i+9aHN}#cVHHY7jq5N2$X-97>Mf zGu#9E904dxMi%{6L#{aaS02|cvOO%Ug<$b@{sw-zY{v|2Ua2utLPqN}xdF-9{CzQj zgd>we;CO+GB@rlk$*WlkAU&<%)iR<|;{B99JABk#cv~5*gw+=+{96HdwQyjrdp9xo z!rgKF&*9O|5PKAXq*ny_LtQ}Wi2+{m1IWm9=dd0X z51}~z3lB;TW9q)X@8^?=fsF$bKdshz**R4(k~NS;w|(27=`}iW>1(of`OVZeUq9u` ziE&f(|2u>}3Ot89Is-Ah?x<1v88U6&>aXeVLgYCi-D9|$!MFFpH@-ig(pDYsSFi0j zQR|JGDM6KdAy3ou*D6L{j|49aO3WIZuHIh?)UkC18V`EURnxhVMU52%O~ZO9Hg9hK zGr)tYG;$!j5`@^?W*1ch$XUg{+&eo^&+ZwkR(bp9VN9~k@F}p=_F1w{aRWGh+de%` z-btiG`ryppfWH{u^?$wZhfn^DKk~~NRw@lR*`Pz~|KTfF z#+OGN6T|qpOR=98(TyjH<@gbSYEHO08JzS4%%uyt_N>;5o^knn$G|2PS)86ry94jW z#q_I0L-0+Oz{-?i@qhz+DCejh=ei**>fR^*UH~Tu{B*LMBP8@PBXZr&r5PB|netO& zTSZ<*28^3$QSq9(mWWNRRjzZ;h=OD7*J-Ob;#iPCu7ANY^>#}|8V@jbGA}$d4k$+I zEKFld94xhEtX!AZylg(olp0GEPq49S4O5DGWgciIM&ce#zdXSM-vkCjBvlu}B}%(J zzN+sh`-!q@MNM_ee<^-jxA@5!<;&Y&r`P|Yb(;BVPF5N+?~+NgvdKWfAr2K4Cpk@U zb06XSIGsvzVo-UC{l>;^dAWP05T(i2k z9|^+g+gt~nIy+G?wP@0}4`n*_0}oh^KL4vgsfAMi*IJ{VnR+=G8|=tPWs?|6%1{)G zUw=HGdjsY!?YOmnJTF&kVWCM&qf-B-|F)WJ!AZ@@Mcr%L6f_gH?4F~0w9@Cfc^~KV zcWeA@yuf6arveQ}DUyJbQ z1pvODT?AmW1AOtg0v;V-<^}EfZpGh2_M%|uuli}7A+$Ld&VFa0rn~vc%ls_FNk&HC zBrk|DRn8C~zepz}!)x+1t)5R~GjRC^+2SMot7pKYi8~z6T_{P6Q?{4{pEh3yEsD$t z1wa$Eh&4_jUOwGzUv0UQoUCCtZtqQobn5fGU_7yU4DJ@p6&AkxBO?EuTXRAFWrpann}JvFd1lYuY+0Q=x5k0*RJQ4smCqpD729J+F_5&g6hUWo=L zoy?6XQ|Z07c)4u^VFH9tX)=V0d|AKmeYXufKdUGAU-)&(C-cvp3Y%N^2cqa17ff0F z0Mth&{4~geH77z1GM>j%LXE3#GM*%g0gKL%PohNFpUpc$w$arqOl>$jvP> zo>RIseOX)V*joz(JB0^a%*=k*K4}Y8J>ip`1skrxh#Ztv+=NH!rTdhRpQy?#KpbVd zy8Q6zgaX&ONo=}c{~m2}>FRC6TvdFz@!$OfefdK6?GZH6LQ+qz`KkZpB~xyVIF!?n zIZ9Kes&Eq@{q5q7z>~SfvH=CzO{gmV%HZ91T6!opoi|vH;D7Fe$9=&y0jNx=T3H76 zKi>-D-#zif__*#uJeZHXkLI8@ap3gX!As^O+dTY-cXO>PzM?RiC<_!Q9tSa!)HFMBv=T>Q!QcUu`O8DFbGGrvxG{+&t^ z30f&@4i2O$aA#xpgDQTjDMedRd;7d`=@myhc019v9e0l5m-53YU2el}{aWo#93T)^ zPX3`xTa67lY3dsVlsenYZZ@R2pq#K@u1N_)K?1}CR&`?pVHVc++f6&3yimN{4n12( z5XpuYhhalb@>)}UcnpxaIevr%-+t-dzL?9IGIQbwJKbCa2qRs3f~m+P#d|l7iUr!< z44N+IZXSTX;K*7Wt)SsVV&YfN+tvSBMT?6Py~jxy!+F~5M*~Lv*MR+4f3KPNWsKU* zYoKR5P~yL*?>;PS#JnMO0#^?!&u7-Xym2bV&erkC_iwh1xRk+GP+Y0QD0~WfA>O#~ zXmY3Db#*tuM4lVU=jGbl#gErMcc*#JH6O2khK2o0sCgrB{P!Q?bv&kkg7>|(NAV{B z02#7u4{!tp{=_Q_-0dS-m6NrRln-Jo_1(fAoEoSg{ZaUj7CeVVvSEZXT!6wc^ z6WUW!Hio7m+aPu-oau>e8bBHL^!3`}>v@!bA7=BrrEq7E11T&Q>H;YE8nk#>20!<9 zL68SOFSFWly4KTtXFuI$*|HXF#1V?I`vh&lrnvDoV{vTX~8lUJ7gPS6h%BkyA3r@i?cf-)~TeyX?&GmpO9PK?!1 zXAmUBStK{_-qiE?d!DQhammy$(p`@U%c52bpRo7o9;}GuMA&=EqqA7zN46QC+6+l% z1I2>G@Hh33B~R-fEx}Qo`ZXqB9P&-$5Xi{nd%M)7CGxa)^D-NYcL-LZNnEI!A53`M zkYqNt8XQyo8e@RV!V9I2%rM?re4aG=PW}s6#d`%Fb0Wx zs0PWTPEZjWGewF(bV8t+l!z_cBBMpioCl3$gb2>fh1f%H@ zcRarn!Bvn7lw6p(wDWPhxefJnJP|f{`!=hF-n=elis`q!0JLrK^^Au>wsc2Laf;eE z=7)kiM71mG3TVOBX?Pe$nsn+fC|3m5>Ao=h(zX}*)&?H86fc~iaB-axhfH#t z=itN?<*Qb(9O{b^UJ9t z5+1o-n|pwpwIAHn6Ry^>=ys^o{OtbfaQS)C?egFY-O_G{r{I&@?>=6m`!@~_`3Mz# z3W0mMR1$kw<9H34nvxSbN*&U*@|NRz_>$W10?q zCf$c_+9x>3-W`yXmvi!cZ>u*gIG59@QjY2J?!{dw+XHM5?JbvltU#!W?v-=j+0e^r zw~?%qq)Ar*{3Jv}R838Kk3S1}rmV~w2fjw|U=;(LY6>4e-i;c!xY~^gsH%$``WtyS zC?2di$VMl-MOFWOdCwHA8p`P&QtOKTxuQp`8cYpRBqU+^j3kNGbI zy?xjJS)?W4Z`sj7xk%e|b^P$PO$^5zX_E5i89aaCX*cQVejWfPD=&1L=(mNXg$WXZ znSc}v%;+5!FBBGRurIz|9~4&%;31Qd$*MUVT~e>b4`Q7Sxjfz<*>}aE>@O#QN_G$c zlMfoKxZI9v`vo-CYwbNBGjf?&i#uwzJh4i6t1-*x5`RhM0Wt;|eN+WD9}`QYL?ds4 z{VBvd186@j51ljB}>HpXq5#BA#HzPUw)>A}J?g9DA?U1pu92(owm zRQ1ke-s)JGS)@7x5WZJ2a9t$p9^(GR`$f$e)~4u9?{mB-Ij;3+ky&01~sd z%?Mim0UbbcubQKoYf$!Q{Bmb`h}vYP$baR0DQMrL4aeMfHDU+5ea)?|AqI~xUoU_y zWU)wXfw#jg0TDpT<#uLZ&Hi}IQBS`hi7id|+?>Zk&EHn04?Y^yJ}98TwR1~}YNA=A_Y}uVa5R~gxJ~&^UuHPuMSG{E6Bz<=@O(-q! z8;d3;1VvT2O!hxMfDc=W5}0Dh&kBzBeF&!%y1~T{NR%!DI2|o< zTF`0jWL#Q9lar@48~en4@;6@^2a(wUrp=9Rg{ z`5uSZ;|G|RjO^ARhXja12BVT%%lS)Sgg6L=kU_K@%U++HgBss8;ZrOi*HucG)I5ZF zG)Y#CCH^-p$znH?m3op>`gc9W2{-j^%NpuA3U`W^$61uMXA_yuRVoIT+JgKa#tTG9+d><4s zM-CPQ*lQo+A|Z#bXGEhhJXjY~CBNoXr7BIIGl{^@jvSCNVQfq!O8XkiUl8cbL8IbE z57iGdbXsX;VbDj=;qunOmL*UMpUcg{5C9UsnyfGYSgt6XzPjv2K8`ra7>JXs1Z`@r zyT=^Nv}yT%0IOEmV5ZZz$>@SpO})mbHu@XPK<*?7>fr-2Upu(PAFBN~RMb`;u$UVb>l7WdT%imZ39062n|{Z07^ zG8HmcqJ>9bq(sBxZ%PUfqiPQn+S^9pN^WBwVQ_l&8F{D1YL->!_d9Yl@?;Hxx9ml- zjP)!U=W;vWx&NX%3nL-nwc?m?Yf$+^A@Q%q3%$joy`&_ zp91>r6bPN{c+`^^Bt3XA@~b0VJ>uYwo`8Z#F$v-#_BwA?!+Rpu{UQw2BmutuU8xjR z+pmzg@XeT%XMw`!fwaK6wQX^&jtM}+!NAfZaq5;;Km}IqOO;o(QI(~VMK3`rArT5Yc4Lfc*FO39&baChs zvdMO$$C!Mx<~t)9Epk=3FT!EZOk18>;2C0FOdR@sNQ({#HH4sTR0{m zyYzk&J=ru3V+$~q@bg$H>el(Y@a3J`R5a#2VX4}YLY1AL-hAQlG1_OlTtzN zo5&BL#(GB87$lX^F8g2HQ z@8eP&gH`i(b}?^ODvV=c$DTeCGwwXMx|rnOlX{1?<@0D-sk{j?B*x~1>aQ!92wj*( zm0!H!4>i%2YL*4Szs4gw+By?zyVOy_5mzMiKNC^$=_z#Q&PLFR)z`mHGknfOeNCEV zR?cr=htfG=L-^}TTI;Q@Y-In36l|K`#0IAeI!2bzn#M(t1*i0<7cvnf=lGDp1dg3O zo5t0u_8NlFsjkBaOaSHjY zOqOT!DGoYHT?0e9B7-#4X#wu%q>R&&QRzjBE{SHTJKgr;6(f~$JMUf~ABgv~3#tJ% zANk*CbRAylliaJ209t2gQ9Dl0t&QJnya7u2f z%++D`X6cw1HwdTw>G^Dn{-J2Yx82X_W%X>zibBQ?lRXe9-pkOTJ^Mfrr--Lj$%2dq zy*NHOOQw(L$K2wz_iX1Sv80%UO~NU_OEuG5;EUDhP!$3YjiEuv%U|LW{ufOTbX)g? zNuEaa7^u7|Tg7LRb^~p5K*mTWD8Nh~WjILkDVnGb*QdK!Qm(BM86e)1Kmu%K6KJU$ z=pz|93HubYgV>r2Y1yvQ+w!}rFT8aLG_AiWTt^BPZ8$z=u6;uy;QYlAk_%sHKg>@4 zvr4+vvF6~3S+{79KwCsYI3y#p#cn@$2-C+7tZs^k$PS^eP|uUdR;?fspXx>r%(=<) zy8CNNTHO>F?3z2BF9}F^wpemc^%g&X#ChTs!H_M?KA!|tcxOw2iSj;~ci@b($8C8x zMH!{bQ}dL8qA1ij<|3Ezd6@W$>>n;#Lms1X)Z??3MfG`!j9!mnNlQw?SQ*guPB(yj zSRctIfClEiH+lE6X(7+drQNSzFjlg$ddF%z)Q@G~19v>~>n?Tu3hkj8A?%e`2ao@j zAp6T!y#LYm(b4T5C~~f{39<;k@mq7Bqx`~C*|p`;OiLi3m6!!F1cmp&ye1#d$MS@P z1TG-#&?ZBVApCxUjlHKj@VGaT96gcq?%CCE|18ee_-B>D(q+(W$8X;bSzw|!mo6#D zqt{^9Pey-}t}wY}BiGVtRVuNouv6Z;!(?&Is0z2`Nong@cJa<*rxgBKL%tFD`V9eR9Zn4nL5& z-mI?Z82Q9vCUO9PiQ?7qM!Om3+dw@-CFWq>?slX9*zGckKD{ zZ%gD;w`+XDzq2O=ufqvH6;58^u*Ow_V584Y|de@*u`F_IbLR zwGEX_Lc3mtUjxR&7{0!Ys00i?IQe#*!NMo_S*h4ZxmXsT8(1Q)u7C2ifWvf*yk@-u z+6e6F9tDj2?;5xEeAh)+6UcG|`1m-&PF6doJ#$oWhpdKaUNf-0FJ7eg>My@T zFQ1{;U3#|dadBTA?^uHKWVz!AZI8FQxZNpVTqdjXfwV(qu1Z_8Wt!y}jpZ|-*)_wR zy=D(Ty+~@aW}m+Nuta195@fMKqJ#Y4ljYWjW!LzysT{=0{z$?$SkL<3*W@O138g3Qo4SFOQBM zecM-#Jqp2SZQS(lJiIt6`RKf$^X!-^D$3~@5!Zdth*dPO_}b0k)%`8(tuv;z_3qNj z$^oH1%Dnje)I(E)IDz5Lphxx@9fTv1L#n??IBLzL`3R<^z0p={*-3_`1)@ z!~eNwCC#71PI3N5-kw^^v-Et7@MMTV@3obgcRoYyh6R;k2a@79jiXl@s;ji$%shsvzuiZzFmN$d-hwW3bOUEa^KK?E)ZEJ}9 zf~voazAS5kAHh1HS?!bJifFLXzbNvjx%4Y)5s-O~_lMuj1P9JE^Mi@MpG8Q6V2 zkT|9b)@C%<1FnQI&BlPHmG~Bc*q8TPQOVQ{EVV+T-&1FemUU-q3o(pVaS9S&TF4Pt z??{yo;a_uzkLNjiRqhW%j2r@Tnx^ajV~S)lc$q$gE*v%&jy_-hSUxc@3kf=LXec^jOAs_vGey%%{7A} zS{QM-7>pY}yS-9tLt~eHqXT*{t()_Jh2hFfIZ+-OHMt!A0RPVVtlLR$kawRa<`aBB zNgZ8nW}B9!eWEQcGrRG|0WTBz$7;GZA#1AJu8#2Cca`1v{hoktl13(dB{DpWwIjG5 zyB)5uS5`0Ea-bswX5BO&W`o1kQZjN}%}XV+J>&BP4Qq9J--Y&YR(59Z9gPo+ho&p0 z3L?kD5J=PYe1FIq0`uTN!lCHE5xf1iF*|Q`wQ8qBS7b7*I`2NG58@I6!RQ(@`%`(T zSqTE^y@}FYtJ?Te$f(%Sz}c=_d$8subHPzxeeLt1Q9A|H^rD?^%BLE%SFW5d7np8m zIBlv;zZ3%{Qm?da(4k9V-%z>+Z&zLx!_Up%%I?f5$7;5MXW#A2i6m~c+8BC;*cb7V zH1|W7!`E{VdB*P2@xsC#dmQXx#bY-#p|T^Aq=h#X+XsYu)z;cB_05dclQh6zK>zycg!pJNhh0`0b^x zDn?z62#Wo<>mt=?VcS7tWo0@Z=1dW0N;-7;_j>CGc+iyMM#7uD?PqM?Uf2}Y7DInx z8d4d}Qy$@y)I^&;bvcQGxR@10(wbl_k_qU52MgYR#l8F*QIt%GNpTGA%u0Si_P^)K z6TTfOybE4fPff~iGvOf1v`9)v4xhT42aE8nok;IK%)so0%H(=OJSx-zBW+TZs*)vt+%qy}Bhb5k?9Z(0OB^=d&3(pS3g>zl`i1 zk4xm?)qg%=42PxBUO=}F<0RD2BKz}+CrP7ZqAt<}F+Zn1KK^3Q40{v*4^;-8kJ=pC zm#TFtd-0Hy8GvO?&w={!9BQwG#QYWuG=^m$l8NiI7 zB(4VG&-22U;Afw;R0&uDXLJuKL*#lUx?fkqzrpqRa(wjR&mFuzHs-W&6zj64lW{Uhs&Uce|igHL2@y?keH zvQ*EWwD~X}yrYjQwF;NBfMB}a>qjl&85H8~Sz%0~l7>{GFFP$(wHEt9E)-D#tio@+ zCIc>2o|YW9Rq8hQi};{+IC4`+Jq|6#wcadICJF6|Y%^YbGhxNvi4A(_fqWngHnMdD z|K)0JZH+KobNdP+$iY`TcXN_QJ;~x}hYhPE{ zaYt{6rjK9E zhO&$`Y=FLJCMrTA}1h6w?s`<9Xh|mk(<@cO+z;h;*p+-vCKU(j@f0JtKfvj*% zT`X?!OG?c7svt{%bCf5W3CWB{S4~;#b`3!yQ}kc~kkIO-VLl@sNDVk@aJ4F1kcq~? zlK~0Um`KV;wX>J+Nc8o>aw-Jr5>Bh}eN=J^yMyron_EA_%joIN56+d)#HhK$`0b2+ zhXk}MyGWAlzqkQXl=P^z#Q%HgMG#6Vlam7$k}3#m=$@)^vO3~W!Q!2HAI+pDI8rkBjYatoiraD1$DY6uK}d>JcVqb~ z!KdM(HVC=VRGm<@rfhRtuYqeCqg?iyP>p)n^o4srp*L{wWm);3)y^imZ*E=&L@rj& zMILzh?Ud+3UGS+wm{9X%UNS*yZ5-thq}LE$un%-dErqHBR2d2C_h@f~<>fgJ(0<)@ z`N1^zi`?ht7SWRI*fM16{BfXTIKNPnweVu0$rXT?6^O{icv18uD8OG5hWS3m_4$(` zoZ6#h>EwnrkAMp)e(T2?PGwix=c^=jp23uTmI|2~%I4G)j7k3)hEgyQ7|q>eVmt!I zM;#0S;3K1;vZ1GP#?!8%)bCItHAaRBd7|xJ0|Am$m@c#|=OnmGSDFw81dZQi%ZR)R z5BL2xalfXI*rMF!sohx;PNIU;F2UP{oDkXFW(}IM11M(6k#o z=0@;Fe)0CtzzF%eu4BDKDGnBoH?0=q$|t3%70r(w%Ne`1Eq59%P# zAJ4S!C88#b8S+znMen{(zh^-r+9yM^+I`DOiw4a=8vQ6zJgPG@VLjcO$NOFwc8~Rw z99fiel`7=u2BLLxP-+I)dE)Je>~0Fn3Z4Ub5S;!qCjB4@W+Iy<|51mCfRH7sT{VYh zTPJ5|HU~-9`2N5VKpJrj{YkI;>ek>leY|lF-D~k%dmR{ow3B=Yx+6{)1;OMlEp)+k zDz&-~9%=%Hie$28E7Az->ccQMSbzzv^zA@N?l2ZIiLi)!*dh$+!hzJYuWkwM`S66` zeSJ^G8fJU7_;x)WRvbid6Cu;?Z%gnFXfX_!sSRF5R*o2&$k}~ADcI6Ns)CZ=#Zv-265d;TIV!Q9cOR8lET#6)z1K^;L8(eQ;JhB!&_{uQ?z> zyOLP(E2ob~k>kAlQ~033PEY4IsVfs&srvlnoEOGATNbZCyPez(apx zfRpf3h$?7n&!hbIUSFpjU3|0(QTU0r>?{ODf)Y4sI?In_oHI-~NjYd3k5>stsub zz2yMvy1RMyH%ZveL+R<;vrsIuH>uVVG9I{`&r#AR)MO?p%hh5Wa#Ooo5MVQcgs{gH z-QW*fP2`S?K}nS)+~6lG?Vj-dgUR%$fqNG21HDfip>_$FWWga1PgF^TMXsx@>q2D} znH?ClOo?qZTXmw_ieKxZM-jaed*yZ}^pdLVK+7L-uog}qcm{G^#h-XR2SaP3?VxN| zm>1q+G&3^EJld*H>n`lCkCA66kgL(Pr~(ZM?gwjpnS$HQy)+{NTvo8-onHVHh+w8H zlMxcW@c`5<*vn@*<~+QbFcdX~eq(XD1VLWA4GNp=qvDRGzFsm9c0KBKi9(|c~s6dPjS^65+xf)6nO6E*||%)1Zjf8?Ko$(b(f%}d=Z1S^-G zXbJ;V1wx98vA&ad^4!rGnh%%MOudqwwHuyKyQ2hGcSk>K53^!RlIkKG82O6+_dMOc z?9`6v5WYd}2-w{_8(EmLdgmgPGWKj&`YHT{7N_FC$<~7gj~!j_9-lPZ^}!Ugb(kjJ zNo_e6#ChX)gK!&{8@8Up_e``O$ulC8_;ywwjy>qr3EGyZvH~>2=?W*joAj$CyY)wQ ziVclz@X9WXHvbGqy)uhyN-vP7X9pxe;KBHR1J~c`XZaeh3{`e2&$g)LNP+@!13^V= zJy04pm$t?S(|i}hbPOO$Kyl~`)|XdKeW&gIYhVBr%`+a6g3^Y?IIjQof_cAUXU0LI z5~nt=PTllF-=r;Q1BWfsV-D!k6DZ1twZT?95eU7eoo(fgq&=0 zEYQ;5tmcuPTEZf>e$ZpCM*B&CQt05}xzI-73M{krMX`noF%Eb*TLI^{i{ZY^&C3(m z6+d#CC^J6lf^ddaeaO1HhvjV1B;ab0@a<{Ow+>^x_d*>v%BNALlR#=X`!GmqXXRr` zqamlZO{+*VrS1(iZzLUh1Ts6SJxV!DCVQcmqP#?a`q;PEDjeUfzriPe<@PQR2fpIx z?eTn{T*fBU7Em6{_yr?*K&;S8-Pn#vKn8`=qS~s33Ef#ei%~f_!1&^4Jz5Dl8 z2^^Nr>zHWb%+m>-Hz6~g&BCeYp9guaN(%B3C`08x4}U-goq8eKX1z}W^=bRPKi4uS znM9l#fTHTVEpu*u-A1+Xb6FEsjr!HF$rczN<1D@SyX@ExhSm>VpQ`Wun-F@B&-FRk zv(Xq-G`g}iLr-EqY7O?J7*e!-XX39Q?r-k~6 zp!`?%Q2|0h+A&6vpyyUhOO_DU6=Fc%U-F+wgYCHc6D}8Z*n0R>ER2H-@8HfDP`7k^ zF7`5<_XxB8@h2<`G5h5}*Y)@230R+ln0zV3GepeCcr~E8 zolQ+y#Y(oXusRcpO3p26tvPB}PI$Zi7isKyLeYH@>RHt$9B0kwu6iz;xn@7tT_9Zqbzu1? zec;QO=i{eyUgE)&1r6ubFTYnZG|dip7$q&#Lxeq@t&@`9O3z=dSarheF$eCL2g%K= z>j&kpcKgC{q#oW!awIgov{II4im>Z7)3|^qCtD7$+TahS6xB^RS6y&G4>>{f#%+~7{ASx>8^WM^wccu|O1lYPZho5Y3UF3~Mw>)c)h-MwglMZ9RKI3^rF6D0 z(7tDRquKVwD{o2+McU$ZAo43KgR9D(?9hsqCBqLhr3D4jh(!t$pn(rwI&M=?>0E$W zJI7P=whp7RASLJNPWdiw{}l(c|M||U`Lz4rIH7}JI&eWiRY)h@I!Cc3Cn*R&fdHR? zR4w@%!!iLL1bXv4Kw<=t+(-I0s>0aKLFaOY16~d|X$|*`kDuQJ-Wyp!50uw^U!zrb zyYatUN61$sGntc{gVpy{x)d(3OHd0={jf#(ufBP_nEY}-J2ka~AQx^KoHm z?+Ps}EkyxVUBioA2zq@&AUEK@gZl}q_j#N=&s=IKZr=T``m+4**JC*jXI}$WFHSs< zJs;+c0{jvrn1?l`?T(|_I<`)7wytd5-Q3)K-TpJG-FzgY5pA*-5_EPZQ_ZdSK-Pk5 z=VZrzN_#yUok2=huyw9@=l!Y&V!^}MvOaQZCT+D|qy%1p4XA0olZxXLv0-W33zHr96Fuizjd$+j&1>nE*sw-{BsOiY>WP<;R?1_2U!p+!sWWG+n{&Y0j}7=ti7IM>NL9n8^r!WvO9~uc z2Ei}WB$V|9Km^4(0&2goiDfUvk&+9X-I}3J*$JT47GwiNO%WZ2n2;j9>BmJ&!VT-e zIaxynfIGOMWQDV~SEoOpCZ?pNiB&QSD#Q#l7DZ^4x5m$Z)1@w@aXEGwRj*$Cj?QU9 zP#M{-P?Ojd`#FTiuh{cQXFdLI&Dsgjq6~MxG$f_Ga^ZZaO7*d&?8c0e3E?mDNnIm- zBbld-EvQzNtlE7b9=u8(rG9m&iB-85q zJadt>rNj*0qqr!iy50}~xj@lzKn_*jW|_J4i({S;QAtF%)DynE|KL9tKN@AvpWmTM zY0lb~F&8Xd??`V~q%4TSlyFI5X;{ zXmVNN%~Dsz6%!VG6IyF;Bw&1s=Xvr`dN={!G5l>kHA;YVgsPiQe0XA4L6DU!H&wY{ zXl7T#dVlHR8O0h~C@ zt}N~zqL#F{#kv>bCCxL#V<8~w;CBk?G*jC3pGtVkyj2B?GTT+)X-%TEm5IvlG6LkM zA(1x1UeaM?D5@~!iEu_fl73kFTN>bQ0*^RK7y$F9kmn60fO~rF)asvPp(xJ=oVf`< zd{MrUf`Ntg*rVz)B&sMyMUZN)a-Fviuk(d2=9v^-zYRka$P(Ng0s zQ2wNpjAZnrX1S-(FDn09n>@kn%55JEG0l3d`_dg&=n?aO^bEU{(Q&@RtR)Wb z1q6fWx;yrdM;s;#+SKHpu#9}ac1L>iba){cj_FgswKGS$9h%9oQJzEsQ^+D0CnRMFNGRL5U)~4n2B3)MwfFajsZ@*XvI2T~0a_3i zWY?DGw|BJC$cFxid<6kF*V(nPQBBF8D;X$v=S1EbWSyD4U_|!Y(sF#h{cV@f zQO5vxGJOG(^6hQrX1=O;E9+r&K+$xlrwe4U8KG}8sZ?4!J5$1+TI7s;I!lAEerW+^XIm$12lKHqdX0{u$CZ&^5JPr7^hU5g zTdO_c6Q9pz9vO=pKKu_$XBiiD_q6e)LAoTBknV0+M7q1X6_f_a1?ff_>5>L1rCXXs zTDn`hVPT&i_y2jxhc{<;&zUo4=DNOfx3pY%1hl+Q?vEkx@=jP-m>uAD5`3|G!`LBi zZ~&4h8WfCpA!YCo#+vimlf+*5_}N8b#2lyjStMf$KP|pSh74Zw8Npp!2z?mazx8@ zu@(1uyRnxPVeF%q;Q@W9jB#S&%8JZvKOxAYlHBuV21K25MD4YlLU1UCsb?Kqa&k4th0l7>{t+{{9Gvs#Y;r6 z!_p9;)}T!I3qiHbs{2!h*kH-@0S#|8d$FucJlZpbv1n}RAA2(R_0oA41(r_9-wh98}Jug(@Y9xZ91~zlYHGakLVm5D>Dkp=MLdNtLqn z=X3|CWsf1Wh)0BC#MMPeRUVNn=jfva^W(lU@aM=d>2f4z+!2j>-8V2}bg!*$gT6wP z>cy?HRbCQqj1eY{9BTYVyV%c){MV=v=XtKN5XX`J1`jL8D@!dYbLAnVo#y>`s{Z$0 z^`;07#ge__LX&<$bANYLez#hi9g+i_&vHOg{KOvkHjWJ;X&#Ch2Mo&`T-^-U)&b7+&Pq^yE$m;AgNwZFV< z|5F~MtiW4@%Zn5osrZC0-;bu2eGa<%5JE9Cb=v%Ft3bGki|^~ukD8^*K$r*7kO=$Z zvv3THic}c&ZT5-1qLehsd6ll|OMlw4`HR--#}9_J{a?(ig_!(q%UV49jbBR3512p6 zctB_dzVl?{I|>Q01h;%I(O@{@Qv-b1KyaFS4Y_{T`A*)?A(7t#6&bO-oCfnQ$yDkO ze)}*yDyZd8o7w!mSvF!%#eHD$DWssq__fX%yoZ(=JO+zM z^U|Q;g!GN3dO<4>Wt4^ms3p-QcN9sM|C~-eDh82QPV zVTu{*9OPqiBs#)~)}VvsJA1){KG(#IAS5iD`!Sojg+HgW&n-SI40vtzV4o%oBTzDV zhJWPmIrA`i;Mb&=c#K0M>7?laIIM31B03HDF8uBU%FHpoC#$vRVvrO~weWW2keK}b z=l6X+aN`b-q1BqB0#%%a!QqFcyvhgkZpM!tbj$~@sAc_&mYFM0*vxEzXi5R>#Oogn ziG!9KvCePN(Iutdzg6c)n|60sQJK5M+O3UB(RDXaq2wSjMCJVG{RgQiJUT~GOIXpC zH~geNr0PlDWI{wZ10}r6hoxN+-A$2Viw&tlEy!L58T4<-^Ix3e8YdEgkZ>T|=v;X) zc>pR9M?8Mt1nhYY6Ujan@F%|~Qp|$XJ(cxhTE27oSt{ab8nIP)Yc~P<-FE!hQz?`; z*w>!1T!mE_iHKfRT|yb#t>&MrNrRSWafRwGrM|zTi!OT7a;chqhdg%7ZiDYj_T!@1 z138AIR1z$#EwaPbh+FRw zVH^}xz-XEQZxA*{XyUxhWkt;M5S4qS6r3I-%Y;KcHt4tx=^dvAzU0)~72HQv+DR=P zUH=jp^13TI`cw4}pYe3)Md@A`y~=Y>jI_s_^A6iW$8kqAd)a3hdBAxI;wHnU9+n<_ zj{!0^g6qfRB7a6?Rh5Ws917Q^i-y`(*vNc}w%SH@HJGtvG5Bgbk7Rz1uFuvghm|uV zUEG|7RO&V|hRp{X1&ApQ7_`Wsn02tzz3Jk{=#6z~ZsLHn@`cfhq@Lj{j$}L4&=GA# z;dUNjaUByMQXzjH;=T~t?V&&>T2J~r!yZ3wI1P(n2_%bC>&&cqq?zFd*P>7JTIG43 z8+DTI3M}yA^+p_0h`z?i-G}pmS6sm+Z=Kb3un|DL`>f8pcbAHU$RBUt1A%Cc3@h-5 zLDghYs=cfRS-?C_bbI!k3I~+iI&RaPDYIJm2x39LreaV-P%xmQs_~1dSU@I{X8$^z z+DhM=DB3SEAuwv31*+zqE${Pt2!Yb=wUp!;9ecrkjyuV;W()%Q1PQynOcgf>4HURe zIjB1lc2{0t01wLz_(S~j{b%3xho?Wkwg#HpmCoCx<-cB_Daux-2UK|)t8imY_=P!2 zY)U8fsrUAIl>S6iZ}=l)7Cq>viPBE0UcW&YH4%QmgXo4X@fqmJ0r!~JFjV);sEw=G z&I4Lf(qIz%)DZMGx3bFQ8U@|l`ZwMd-R2Pr4suX%_}at~(N(UlMflmXyQKmi%ExDn z57gviEvgJWUh32dDu@hVYa*q+$l_)G`u-vzrl{POtpNj{ixBxeeO&7d<)4)KgXmvV z-V0H9QAg-0pYFsKGapk)3Qyi;EpzY&xS>jC`k9dbPW;%< zZTN|CV!O)$4X-aoy2pZcA00Y|LMa}@FcN-k$avb&TeTHYqBYA#X4QpO+n?Y@v(OJ1 zgkzn!E50tEF6wN&#Iv~IH~9FG2~0kL8(jq0d`i6xDut@N?ot5lfHjk+4(Yp7tTW_-y@p?cE(chp!c7(GHd zM44f6XUksm^4ngV-CB7~hjyu>>u@ik@|ZvRJu%?MT`3+xu%#nNWfS?)mx?c2UTF=| zoVVNe)+Nqp$^)BBc4omWKe5Xjxcv;1-JKa4u#s%m26MU3u&{(wtc-_+3C04K485XW z+Rq0WgaV>^x(yn`AsADp#!0)I>gy;>Nhw-dTB{C zj~cPSnr_PUpop&YKzSPHM(tr^{|{X9D&w`Ag5E#z#({j9W!6q17jY^X1)8G7aziV; zra%?c5WKHVi3Ae6!&-u9oV2JQJKXLA7Th_E9{AFH<=e(Bf^Am_ag|+#?so&3MY8!M zdzoyv717;L4A&fC>H=<{AyH*Pk^Zr=hcT-BCQ5$J16D2}yUC)g-k3c2XvDN3J#oDg z6RI|nx>ws4@^n7#OVfZCG>DE#lzxbjmeqGMO(vhue9#CYGzU?K;DsMhmzTeOe9I(` z>ViUz$?afIYQg}qt$Ghy(fN><_}R(+mhAc3$MT9VaHOL*lJUuLTEQu-G|Fe5rln3* zx?juo`g92u-dz%kgNEXve~qb~^_&gsAgPd%9+e%WXXm6r3qCJY4iXCt^{TxeDq{h^ zNjXKI*=z?IhK$jmlVB3Loh&zwb1+EWI+pk?9WAYNbA!6(JhE^K{wvjUgyIvGkZh|D zuMqIarervp>y|{ZN-1MF3{cX&yN&HKf8`uKJo|e-;{Q2KeZE=4CdZLMAlU}$LaJAt2PI`Q^Ya$Rm!k|+Svo!|S?gBvsQ{p7E3UxJlnr?JH!u<(bZQ|f zNh+_D72*QlDo)T4><``g3dX(r;uiM~gnXhhsUfLBiaqOwF0&2txWsDIGwcbnzk4Ac zYrta1en@i2%{Ov_67u233b>BawEa-6MwvHbG9BEtsR9}dCbhpNs^sE6eg{IL@n-1q__yd}CRi z!+p&6{8$f<&(eX`M?n1++!O@c0siuAUH&Bc50LyeadQxdFee7AFI`28{LaQTZ4CSU zgg{%QFKalqG67R*HhO>O{Aab6lTBR##3mXti$8n%67bwO1Yaq}lF0E%H|B>ZqYyt- zDDrc&ehaGHnX@!HyIWB z?pX8D&L*I(cDGNI{T7^12f=^gdtGfwem@=So>#p5V|9eUP8k#B!~{L(9f4)9KDBhM zYD2`{TYs4b=cg7U{XEPx0{;;G@BE`*TfYP+?UKOg`KbP!>gH=m&b=`7$^ZFem@9l! z{~}aF*>V5H%TOiNJMD4GbH2;lPw0R|rEc~<#SsL$qAgSvBk_4bEjW?Nh)Loah|HoR z^VOz%0ps&Yc13!HK7Mj&elbNH8K@sWmOBycC2~9hvpV$#+cLQIk1 zQ8=JdkKbHt-^8oKtuS0Em=29^4%&OSAz;9uwh)YKs_m*RbOyPrPHee2o6~XUcdtP$ z$M}Rf$Uz&>Nyq*t6%pHs0JsT1CiLZ6)iqXv>S8izX}Or!`@VKbPf+)x1tIP16j|;u z#kc<^+EPqu$EBgRWGPclNePem*x!34=`pmajZx6{^4);yHs8=ep1~FkW+Z_mbzac}R@sKud-i>$>=SW6cI5i7#(<5p(S`wELex)Jc2q^-dbm>-8PB5PBR+?e@OK#aX zBU!2hWk!|svpPHj*tbq!^OQWO={0B2&zpecw$Pgd?b9v35ihB$ye!WSre#lLgR8aqmqoQ&?2~mEnElxeIFS?|14aKYMW=KsVhuqs2BDy5onS^hpxqW-C-q z!!tM#6gEh1RT){5)}cgEUpuQPYQ%_pj_E>Z1VecFk}-Q|sZx(oZ#P7tJx2p|+`5j{ zMB6F%x37emuJQmfeg7F9n}lCe(qYTCP$@>r3~nWZy4G<1@u@&p?jUTN#Q zF(M8o)#PL+myRu&_DRZ@P6n%2y+YqFdMO_Oa6~x#4;6BsEE>P@D?nh5Yn`dd%ua0| zoXHFNP0R(Nrm(i%>iVimBV`TRZRt2M8O?LeiDaGqLnof6(+O~A{8x>v=#KS#KFlzy z#(L`xHm|kVMk3BcW-5hF(gqLWUHN-82=hp9rZ04quL~%o%Za6PKkG&dN$O49K|@G* z8Kqd|@mQor9R+DN(sY4&5drj+iB6pC!)_6!tfIGHFThU|ZWAhhY|h zS)HIp9CeO-{^HS!$eyJ#au(9=dVCeL^2 z3OmVn?Kg^Ze{J+@?i2W=s~Vxr?-la^@`Yd&kSO6NWE|2;*{~~V*wB1QXsE>DMdan- z_K^lLm*R4#yx0`-*+nRzY-s#`Sik?TQpGwF8Z%pk0cz3s_aTkXpU*(fnW2r&Vq%`- zGDsM29Y_~+aHz*2Y*c`xqXlqD8rHN{rwxLOA=oVEB}_je!w7715e<0xShFwbo1|{;9r0%;A*tRB_^<_gPAYCtUsB zppjwEbiud`?|!~lJm0pM`e~wTr;nV4vH6IjGT9K%yFwZmv?SJWAG>iOyPaxWg!>pgj8!jT<0b5%a; z)*OtjWQ57s9{)LgE17yWH_F4XsnRIQSZDX9^6889S7MG^{o$;_#_}{>ncnMWSlN) zcEnjymXwR;PL|d@A8t);NOcxM`uwAUgpeH>DLNJb4OyMzMS{u6q;tBD^pDSte{cSH zgrP^ipilKK->+OvMzYP?d%p`6iNIv)XIszW=}?~qX&AGX38l3P?oU>7{pUe(?%$B9CgMJ%!Dwt=)Q)#pz46#MkIyU&)x`T;5{r78-sX z6n((E=qi38f8H)TS840Zi7G*~!ZL4NT~fA?D1LNle(V9A4E0ng_$DMKoDPd zh@LbbR#Ieh_j~Jc7<^Qf2HLDO4=HB#ZNFHH$jLv^d?V`fkg5{micP!HQ8K0_4At0DW~ z2iAk{DVu(`z#&nJn|FQyx_ei^4%~EJm!!&^YCo&m`q9*oxz6TEDa*2Req$e42?JJcouZmuaus z%4Fv0Tx;I#p}cNtOaENRRV2@H@%$j}Ra@ z6XxW6#C=-DCGPb1jv*6Hq~wv6_zfIZQ=VO{^{FKvo8J&;ZDGxPO8uW+HpY=&uE}Q^QU}^79@`k{ zQH43SSi&&2h+LXTn zhmWjPXkWFaHDHhH5^pHtW+CExQw^W^zvq?0YFy-%?%c?Q7Xz-TJE|ur?t!U*`oaW> znQQNVRzY09DJr{ppLD~J4ZqGbYm?aN= zu+`Zi8&@WGzxbPlq}O&@aBAD1z9))a47RFUw$6=cGN?gV3x{k$9zwwL19RRsc>$y8&;OoyQ+{h08}wC`IASla#kS$5cTH@KAGQIrEJn|LK`CJiter8s~`@~ z8b|v`J7(;-TrBUlFLiFOYR$-Cj%M`tdL9xPH+rnZSyV+opL{-?q^Q3ETe1=; z>0r-X=h`W{zNl5g#lwr?E^Tu%%}nn@WY0zFs(D(SbaRGd13+WwQx{r7qTOgKmQiaD|Xq6strPV)qx&a<2s(8 z?li;vMGi8Xv?jEopWADz$$M#(=y;t_~p%SG}ttwMA@(R(Sb&QhOU}01cgAP+q=deRD=TB z@(8kpdK5@*YgV-FwQ+r4(6aCx_Ej>QCu6I`!Wn>*u)i=5_~T=vMTNj}Xyziha8IHq z{gEV^kbG!8_w&RU97C8BE^(UD~<2CMZ*hx&d7K#BK2YP!v@?)5$}(?J^+a zrFxG$G&k#Gwqao5nduhfIADeI@3E&u$Xf3w@W1Qef18pR%$NQoH)6gch7bIm|5&pw z%uc;eFuUPhs}E?SuwL+aVJ^Jm=IJyZwrA9-?$8w0C5ve)3i}PYFzfJb=uUo-9AwaJ z{YM&mRG(CVB_mTsE#SLPK>k7=EWOjUQFjefu-tJ!W}V5&^y!yyX0!h5mxUca`X*F0 znn%=M6NS{tknIzv%LFW-WytrNN_ic*6ACgqIo<@fv3`e4Ah>frYMnnEv~Uz*+ze!@ ztl{|$#{QxCaZ@L~W%P5pEH)t_tmG!b3{9&>-KxaH9{|KrF%hrH&dr@^}@JK zUB4XHaOJ!CbwFy6y02EU9>#p*)wlz;iL}utGG%|9xen`YDq<929Y@hel|Bk4Ez7N& zpZ1RfXeZ*Ls{l_3s5OxeJaRqTQ#LnC)Gic@RdFSm&ogpJKo@{#!0hwW1Zz!)0M_RI z-JRzz_+K^uCE=TM)|xOpY5vl6f#=F$WXbNL!bgY&6JK*pL_%M-+$P3>U1Z=ze`A~B zju}hK6IVR12P#+n+6*@iWS`KW9P692;{LllPUPXC!&rAfB$zs8Vrh&Ize{i{?a|VG z&5(`5)wL(K(eTp#%OjVVxcKt1C*X`{R8tFcSvm;^R@-}f;7~In{n_?ff^y+^n+_?c zV+Xj`z;|bAZ70i|?bo+wPa`UoCjt&;HW(oyw_6%(5Gb8X7HtK0qj+(Ab*|D5PIzcSRjMR~&&>9TP^ zDa`K;Lo~9wDW@e*W|cLw>r$?ttA?qFc^iCJEju1?ynA*mAC3Q71lTRV7Eb`abpYUg zB3-sK^JnK#qltk=Y9>&={T4zp%QOo5ao|gLKRQwa>T-yu>IA4hfF}C{VGVLOLFF3^ z`GPHo2t3pE#DvA;mu0Z1M=PH)ty) zFfJWzzgr#8Sd-*rAT-NyJKoV8*~j5*6n z8*f2YuAn4YH`$F3fKMw7z|H`hArzB;I=st0`|T|LTSY4=xMs!S&= z^G5bbJ>g6MREZq5lnrba=BKDvw@I?m(>sl($wuVxruwKB2j;+`Y{tChTevK<5q`k0 zyp1`W@|`%aEoZw>&8zi}{~j;Qq79Hze*#BTQDrY40S>|#sP{>>pD!1;-0?nsQE*CR&4mr1_=CBmAY%7iSv5#p@tm{F}+&_lEw&7~SF*RV6ppT07N%iy+qGvT0UO8h$!W z8;}s?m1}!Fh@bk&U+@H6paqO&Z@w84h~^X)7Qc{)fu2=10jD2moF2xmUVjBg2ky4| z|McnIvl|uWV$*9<`#E6^Lv+)U`m`+>)z@Wr zxhPX+rXdHk!^E;D)Tw&&W`E*}vreroQhwMME%)A_Qu(1nMWvUpdjZ=Zz@@nQK-s?> z&4_lK0Bkj&e&{max* zwJ{h3657qap+P=r=X!!RR|5qs%Iw9Ha3Ffo(Q}D(uHv>XuZU6!HLYdeNC5obH}fgC zfzfa7YMyM&z6MaSg-T9}?zE-h3OsX}fan^J8*n==-+ICnM#lyE7p}2bM(|X^eLxtk z3(l@ak^%~2#FsOvgj;`FO<9?&Bx-l)KS^|*v8AEys~u{wgq!}D%g9ahntTkpM`Rcw z_wT)&q%nlEbm^_bd91=fpv~YLc^@x#V0UIw^t+G_05pu6w(FMvnoNN%?!8=tt&Dbd zz?D%BmQ0$^KVL4rdrwK&>J&qkjTETX2{SN%x@81k%|5vZw2c8J z8qseYo@jRz=T5RxR8i9 z8{Ie;7Nryi2;gK<$qHgUTQDCU0g5DStE?$@LijypQvKIksg!ELtZVNozDNvA4JEqf5JZ`6ew%?JDlW zGQOa1nea{yIvs<|AKW} z%SaB#rW`z~(_FvX8^sT`^$d}IU-W6$=@SYA#gS;a5Y?;c?Iw70hGv{sFhC=_J5_Za zE@~@cAH^5zgvu1EuulC3Kc;B5Gfi(FjGosg)0Rq;Yuh$m`lL*bYYHa6uON7+&~^#! z7WNk!8h0qR2j!7;nEyTTXW8}Iz_2Xtbc51taK*GdK0|rjRd4+Zge=bsfh5fI^OFX8 zPK4)k9u#oB+nyT~$uugnNzmWR0DQG^n`S^b2Jr7o>tPfjKM|YW4=M#sc-z`2@aLaw@($vKJx2g+63ORH z5{8ao;r?2&jonZ-dev)r4t{Tj#=A9v$d2u- z1QdbB`sM&@D~%-bU+hheN3(2N*g~9CToUG)MijSF7S))We7*yvZb6<(gP*yfiQ$bK zyGTe&oE=TW+??wfPhjJfosAExKqnM$IxeqB#9gi%(FQruA)k*jPalHNvnLJ`gUxrM zD_y=iDUp9m7nG}CMlrcW>aWu)B(f_*VXu5us7XSHm3Z^y^tiTav@>{)hzr7(6*Lve zs$1^|@>+`M8o1SJ(oxqZ((}W4-Gc&Z;uPC!+Ow4?l?mq6`IeykCrJ))tVu;Bonw@D z?q~FEG~4{gzHZ9fFP8VbYQ;-kOPk7xGsMJvM}E+17sK7nw!ufnTD2YjgJxn89MH-Y zW``J#r4|NpoP`7{RTrgx#FGLuGwmRVXEzH^j^=h5cuZe>SYOk2+2z3xX*zg#2N4Dv z;Kn&MbpvlwhE0m6zc+(OwttVpm${TbJG^IN-_BkbB|f&%-O~DKAahVYYdh4z$*0H8 zfo|ShO|;XR{ZG2YA?cMKs)Q)X?3Jp>C?`ptE`A2Orn}*%3q({ZL|V;neHfrt`$wgi z$0yGF;?Y2FzaQ3eh`W)AP+=;nu$EbLYX&Y}(cNwG4gJ6^E~&N00IEX>9{d;%!)3@G zLK&srl){G`q6Hk~@AJVQS7D1z_cb-^d$`73e@9T}kJY*I|PP@bCec+UCqx-W} zoPdn(9mB6TY;1^@tLNuzLk&Zjjx6>~ukCH}1gSQVo~9s&aOARYZr~QNzLyA*cnH59 z*6n9_sSC7m%gR-AkUJx~0Gf^=h4gH0x|PMX@4BFVhfBGM4Zp3Pe}#;v^S%0RfYPLq zoqv6J=3H8;wv!^oo_b3)M=TBl@k|s0W7dh{L`f+djMwa}v9?L#75}6d4auFnl4Oyy zB@9e+p0Q1$^^zGh8e@Zlg?I|gHP~5y=P>4)62D^BV&?se90&xQF!x?;KvnV!#VmIT z6r36(m%$1e>~d~d9$7mtyi&~6(%~GZ^hKY1n7~F)-v2hA+>4DcLCI~+TyOozx|k(c z_0dx}NJOC$TI30`!Ov3-&R>x*I6wh^KafN~AR`l={o8iEMd!pY3nzK1Guv)Q^AKX= zPbbydX5JI*7ux0iFg_Mn7Sre{Q|hTC@4tFwLJ~es9ag+=m7$e*$9v($=|+VuhztXbO*W#w6#aGe2d@P+teKQO3^ z1diD$d*H8mYIRxt16Ot?elKxjo{O%K$O4GwXKFuIJItID&P<_=dyjN7Oey`fA z3UiXXrx>PpOo??JNdkse=78Pjwhb&44D&sHU3C#{i6e#ocH10b8;x{}%Z~V1QWSDnj9KPl?`;KX`Zy}~~gXc@LU0gfp zr3CMf0yv%g_vc_OU0C`h+@u70NM+W^TrENRveHE8kzX^I+3~x4%_xEEcub*!%bN@z zJu)JOyFzCiD>*BK-?wNuD8Nt)Mx?lFG92NFJ0}HXvls zr&kKZWbR*9wQU8DQ)Pb{p$PJ~FqnRcDLhZVXp-%82G+rb7wmJIfM%k#G)5xRwQ#~DKK7Lln( z)PpZRwI9cFQ$7AwYd_ta=sC)Ms|qEZq^$!14Y2GAzqa)jAgqUNpt{E4Bg6?hfc9E% zNq7WnHz+i(1?~b3ua!}jW*FaUZHmZp{3OuP{G7VgF4ZK#{g_fVJ~}R@)t+3;(ePE| z-@a$A8bq+=V);}lU>j}g13veT?Hl~~`G33roKz>tX{GAI+8SxN(zia*GQED)oj?{w zgOmHVVYWYv6AHsSzS%t6DLr&lRbiX9q8DI+7cea1ZKwsbU42=8!KNL^zhn!Og9B-= zx6HrkM^U*0ejT8aomkiBXec=V6J9oq*(H`uFI-s!R1TLzn*RHmI_fjTm6XF9bernR=k$RG87v717}Ft1(3 zGEE|cL^^cpjs5uy%3t*_@q~Uf_CJ_cHO}f=0rj6|omF^2F8pHa=U~Ilgs}WPnt%8C zCMTQ(Ud-`QL_!2GeuD2O;!R0(71$yXqN!*T9 zBH1!7x<5#o;)LsT(oiw?KTQ$`g1XYl0@fcQnXa%#a4n9iY9|XEV$WvpwF!wuk)-yq zeUG~zr|RAC-Q89Yf*JwH8;!i@`p{I;P1_;8GDkATJM`?vLW*6H#V-nDCHy;omK~cc zfPsG&2pg!NHSq1vRr;MSw4V%V;<|s2&`MvtvCx!XnYULL;`hO~G=bv$mgM~$Gmce^ zU2j%4535_8@3E|V>#0B!;@zVU|9c!>a1RXd^Nzw~0iLH@K$-s<37#xM`rY%#KcvYj z5G{}O7YqhKjCi$T=#*XGuWwb!U(q4ggQu)}c_zUh8cO{nImkq7SfvZF1>KLE(Qtvv zMP=hGDc!;+uMNBV4pt1r%45lvIE)Qeuef-%acE9fR)a*K(U<3uRy`+AQ*es_m_KCC zNW`I@KbIDz{Vaor*J{{Qwcz-LP^`4510C;BS2p(rabVp@Exos)cum}%cA}fJRqcto zC@s|Pjb7IZ+Va;hP|4yliT_p2lq&swP7^Z7JG~+Vk8ZT2s~;6dK;I5n(gM_Q^{y*M z-kRz60~YOGf8w#t1iIo4K$-V*b)_9z10J_h4cYWXfScJfKvlk?T2BvwmM8yY@N(r1 zsLQrpfCP%~(B%C*Yq9)ReD!X3yF2$`t`h|Wx`o=Be8hhUajE_4K5;*XQXEN#_~+9* z)GU?WSm(|>oiQJF%B0$`grAG- z_yI7jT%T3n6%PbXfljmR#(q2olKT~}cK4Za0b@I0oY(2-Z4UhY#M<`J>=6uT7Cn#y z*P#>ScI4JJ8`j5yO8xbQLXzG}onp*aVC%(H{m0+GILpGQvcap^1Ysixom+Z-dCp7( z|8_%MMs=k9>0~1`6%=xUin@|Aunhzx{z|Y2BHS!5cXMXGm+5-jQ zRqSdjeFQM)s&pTH z8W$?(PJUrdCfc<=t_N%1Ss&5G>9+5bX`L0YD2tfi`=jr+=126=w}%VUA7{ZQ&F&~% z1E@WP=S`-y+9|5c*}X$+0DAKZ8K@qUw$q{>M<%upAl^J)&h0N<*C@!JMmT@3(+ z_$Ahw_Ummk0OkzoSl-Wus^NL1gJkiw;1q zx;Z%bt+QHl-*M6cKpd-|iK32nN8o!3`09g~Kz*?MyuEko?+Iz!U26Bu+7&(Tw~8K` zzg20!ciigvNjkI_e^KqXUlL|LCyA~R&2y8WW`HC~+wb4f%Rib)Tek^H_kTW&t+#}2 zLM(5mmd5a;e)NZT(X@Yv1zi5Ej}jrLNVxY>1d$0@)_Z7{&yi~#h-N1<@D z@vL{I+-(5bV^FF)Jr%oy`1#n*Q;$`7olyD$xGN^bYw4|bdo3%Z0Kmj@yQ2Mw0c|S*=Bs3hwPc;_;xxVP@0A?I`inbY61zNNl9)*k3F1$-I)O#Yn2}hF^SfYDA z#g4bl;0HVZ^>Mg2RQRNqE3Wl%MOaOmCZyvqCcop5Zu0r3=tFZiVazU8CGZPs95WlT_CBt)D%5tNDaz-5id2K_#KLU&jNkol$w8erxB*vleksmVc8l!PhE{Di&D&6+RpY
  • 5=xSXy3+g}fHT}fA>;$$O-YX>Hecn9i4Le@Wn2r-G! z>iUK!*kTDf!ggDb}rCz^IpxAcbIDGh zZrTukSo7o<5IFT#R6 zT!q5m^nzJFF;h|bA@JJCPRm5a_pv_Mb(SB#CngfoNK4|w8T-13Jisz2BD@lVSa^D< zqak#MU#jat04aaj2^MR1vkdPxgXwkb2XpMtE{jXN9ysfwr=mm!{%?{Ya=NzCE`YUm z{a+YCBJdR0p`jA@+Y3(f3%m(vHw7%^iI(bRu>JdLwOTM)EsR07)0=ue`H@pDxZ>~A z$X;&noTh%B0^(=OW8Ulnhww!1?4mDLs6yL8DGVW|7I@U=0{>#VD7P*sr6hL}bn-8V zKHHKdCbj+n$mt-zN3;NJL~>eH3eRkh@_57L(qNN6v!z#mDVo!OT2n0IhtC2LRrr>6 z?WPzc?(Pi$az$jpF;3uc%(mU2W>yOIikrrGkZPh+Zup}5R5GD+5N_<4S>5lMjN0f^ zqILDYuNZffublOTed>UEOoETbEp?Qc0&((WON1*}&o3rDM0x^GVR;1z1i7n$ zH4d`U;1|$*`+KKU!blz1=$RBMZYcV<>a7t1pBR$FiT{cB?)zEvEu<9u|tI1?Y}m?$755UO&EkUGH2fG{4^3%NIY(px?6UhOHvHZW_^{f&;d z#)Z-cdgjiEg3*9^HQ%e|umwqgN$+Avuff7Y{5x`)Dr7&H#1*q;n0oAhtRh_sR&pwU zBQ=Tw3_601^KMX6xNG2vku_BWSmMVvr4zA;MlF_!OtmnTb|_c6#|hcSV&uCblu0uR z%1TqbqJE-1zk9>ww(7OJ+j#oU(L=ujlayjBF9ozL0E7x>9lskdzpXv{H*Xw5qxT*- zq9yD1zRGYQD*0o6{rSXvEBCuC`JBvqJp&PLFFF&S0SNPtIA4MjVf4jD%nz(o>i(Y+ z(NvA5SFIFEU#3B$q!bTT6GsORQi}vH+AVR;k2_(Dok15_Y3>mp{u<;Q_*sZEeSfem zRS9mEv3F<>tN0Rhvf9dsx@@Qfpg}TqI45xcxGexaG~4F*p?O?`~2nH{yp~sE{=(?Y$dh zAD(|=@4sCzQF1JO*^2p#x90cpl5oeyD~JC*L@K1Pk}5&ci(PoN?8T?jsI^3Gay@8P zvw|$*Y(!6;WK6>)IqY1^2VF_7F1)I84&TdOa%MFUT2*1>OVepbo|L6*MGZlKelY=qbzf@{moYrUIYLSc1v$|i)r{c{_%P8 zx!&(qGatt2CT=99ZZI&Jo`%gwAT7Cx_?5qgxC%hnXQZJ5!kqD@kN zm?A|fPiUaYFgy}$+*3$QcM#&jrtmyUdiDX00@&V#4sP^irOAd@)i;I=xA%7yvK8^yO$On&}& zS&n?mmg{s%W#x{KmzNEyFV*nFsV-dCa&PXcNzBsQZvS)qnRxf|1KcZYy&f;ffE=Aa zQ|c;dusT6Y!J%&CGTW&{)J5}X_4(449&VGsQ(i21yn}%Pc4_GoRQgI;vBz=(U6Z3= z%ifONPM96RdxNJS=B&J$k7Q@3tCL03q~jt!HwWgv92@pa5SS7AS(chQkDGy{xGx2y z-5oH5Vz<6&@KzySuwXkZ$Sjltw^GkPuM1yUUeB{rz}>vPzpsO~$^VyUECZ^>n7c6z3yn^5hV})yaTLorYXd{$8(%M^; zJ2djgUM+Ji2D3@zusylKhkCS=XDkQ!99oxd$D%WeD|A_G{1b@VfM(at2?sL=m_AI< z?je;_YN7C3a@BGs?gU3$o9AXNLoX7my5!JZp&$NVCQ7=`IysVO$DIdl%g0F6Ev|vz zsXxSVz{q(2aX*JJDN?;41?UIm*XA&%@MNBWttZX?Ye3Oax*-Ws) z!@y4u)dq49P8no0U`)uj-$2B8{O2xzbol?K)$-ioTKvWu!<>BY_Si3A(x5R6aa*C0u1=a|Kk#R+CSmoT(udypl`UL*czxWm?@9kG3I!%^SCE z5|p+xuKp9Gl07+5!nU7VK16A%B5&f}htyn#rSAdcd(~mH&A+@e2JY&G)E(N33 z>$PVRAg3JI_8(imIBIS8#di4Orda$`pYi5aInePw|Fq1lAdn@obn}kWF`e@V> zEBAcvzJ$sK$wid@3B+ypC*DhO{*m;#gy8(2p;y@-PPKx-@Lg(3^oh1L;k@_s6o`NZy&? z*b#c_`r~nck8eVy-r4zpfi`LXKNl6-c0f{$mSx-bJe`wcwkCQf3AWR#cB zUy0r(?OA1P{PfzOo|6?>sWdq_gFL)*P{EzlFt_B>=Eq=f?zs%!7wyXA{+6uccvHOS z0$`VwSCq$f=wM?nB?A`nKs`o1UQ_8SNQQ3xp88K=*%F%zrE|e_XX2yn2RgCs;zJu; z81M3ZcqCTUVlZ|>pV8}H>%SI%3cc-Ge0Dz-!ZVeIMqavCw#>NuU$hb|_-NdeAvcAS zfHSnj?DELU)Am$!I7Tf-crYGFMeIhhg*ky&if;pMS?ILlX*;R)o?VF{Tl%Sh*`3hh zhf2x_nBE-6f~i6$Ic$(T9TFdJDP$-7CWkFFes#Y5+Bal;t>>j$I@Itc=R2R%1c$coZ1{XiU7%)dwYsK zQ{_9kJ_fa$Z9y#!0#vHiwjuNS%e(Y#*6@L+A%R`3NEi81j z=iX$DjX1n2amH}|1ozYKprh2(pNLkhO*+~JRV zKZokp;Q#x`ET?{{vzYiB(g0`IM2zdm#^ujC(JyG*!2e#eUo(F% z@>3mJi4&t*8pv8{8`wl?^vjdaKCf> z&UL^Fjw0#WfC`>2Yq6j))y=(ejt|OB6YCAJJXEZ}oDC1$-PwcFEO#&&EaNfw@#L!$ zS4$SDry)H_l(eRTAUnf4QKimo8(RJ2s*iAcF`u|psd~YwrpB}95NsErr{<9hMW=&| zKLf=#6$TGk?J=?uOS8Z3`qc%G&0X?6a832QjEI6x6-w(qm|)Vr6f$&Gj(Mu79Kan< zlFc=Qy^p=UOgJ16r_#BLP9GNv{^uND+Be=@lAtIj`>Xx@ag2%!F&H5dg}}0 z`6}4qf0w!vaI9f$*2&1nAoNxU-J55@nATa_dGI|}XF0RG0bT_xll28VXm2rgw#F!k zL${Y)+-YjOQnUCykkk{txsmEIl7(5>!iA~87LfeR0tDUNijG` zak_u>_vxRkU(5X)^F`tS7E{?KD(ff2ey zBND?lpSdH-ZN|Jh2QtBTqTL(9=UvR}-?{%PrQEGWN2X?kzpBuEd0XO#%)V@XDyHj_ zVb_H9d?}hM7wP0$eD#rh=f)kyd*V74R65;d#*Dt}EY^I!r}@ISxI8Rvje}c>Ji)BP zL_BqzLJExh?P0K5hk5t+DrE&P{I5SGS-(8px_(y0}i~#kNO-Cz8}*TOX#k2I6#QGK8jIb)kKl zQYZ;(SbfpJ=6;`5XroT|@T$}{=Tem+tHaXyiHhJyiFBU5iWVC-k^+C(1`ubt)=G?} zzBt(|V+ovy8-WrjS!-o5 zQIyoERG?2(uwjM~W|s0GiMEZqCa1@QSJ_$yAzVb>$bAyU6n%T0Q04LT&tJ4pYDwCE z0(L`Q%bNvM?bW5u>E5293V6Q2%JUKR8#53Q?jIR@5f67cmg^r>C#saUS^}!L!j)FC zQsdO)e%U+hb2~>Gm&Ifiu?%~sTS8Q70BUT>*9C6%>8*@O<6_97#ldS|K4S6!GAH>) z2Evt7MVzg2Y_?qWPoI+2{KQ}V4N5^`pGU43DKi{JE&^ZNqI*i>y39z#f*vg(UJx+GWA!)`ya9`upJo< z5Z%r!T%1(8_^{0VwCS+xE^|?S4z?me>@#ooH?gTiQ0R8DSZv&nL^vS$RY77@%|?QA z->AfuWyq&l-)Y-ld1`s}PK`9kiwTxREK;1rKO+NJzX^YK@mv>dDP^#wI^q1Rrl!Vm z=V8|*_1uh7Qg;#o;1DPVpfW}fTs&LEVkdarGRg8_KQibGaEdDm8cc~GIpV1ABa8BS zz4roRBoa5R@t%y?l>hTs#7QCuXiWaH=o&i-9=~*ETTtl?vCmwdeO1XL)GmqodA@}Q za48LGHvkqn-nn8qclO7xWksagV~IHOy_Y7_#bwbM?c8ESSyzGP$%E{fD#Jk~xKLbg zmNx{|PNY$}3AOBtNYR_wcLogG%#PvD1F}XI{mh!%x7#@$;vk-3jF<6Mv2%L5HJ~#3ftr7lsbHokP3KS^x zh`Y-MHSE;8UqBuCGHjK?o~FwMG%K-d!v5sqF~!kpras5>E?C)fKh1b{^p_40T=cOJ z)8cBjDDC|}rSXk!xODBeU(QF~7Cb;a;hIW<=(~bgUwz2Z7}J6zTx~O1Zfg(H{xSYc*XW7B1$2V^TgeKm}c$^iAzY?3uFE1XXm=EFV?j6zS4d= zcu`S|nk3g${PG2%98%}sMe*crI7%orAb5?q((O&x=)JWuW4YkC&z`*nqF`ikbEj@?uN%&|u(Kd2+|49cp+(62zQw&wMM= z%x=^^kfCpI2BfNT>Ct-=7AjAKz*7S~?`Kl4#RVj8plkIc%Ers9-CFid8qY_Fg?3*Q za ziM|p@rf6D$bL@Cy?J`VXZs*lNtM^V%CkvDt$ zU%vk=4R-`Bafx#veS`yWXU#x0#wq6f2NUb8p_v-C8SfHwuB|XEI0Q+667Cc-5{!FE zVBvV+NtPE8XzyY~_)7?VLH64xU&q)XleQ{Lkgz$J4mWc`cNG;Te+VPj@_ZKJ-X=G*ESFvsT=xpYddS9jL_IqGbQV98PR4H2Y}x|EuqgLTJz zm zQk7IW@P4+OUFlo#RooHTUi*ti+D*t(elndK@KMGR_LeMz(5UCtP5AO)6XkP@pKklnqeQL7oiQcZvNQx_v2m_tfAJ3iQLpdFr%fV|6 zV&h0i>wMztq@rw`F_0mFBh6$9&jrXWZ;)T?fv&6Ou7kXyyN|hlD*(HiJ zUV|tH;USv$Mkf#~`EBn0S#03HH^~S+{nQ)-5|;*#WcI6QoM;{qh?|3GQN$Hd{>m-0 zDOT82N3M~%am7p2g#E!OynxxD@pV*ZecXCjk~=QaJg08$!}(S? z{PksxO4ARR5F}5$s!H5(KWJs*-kejm)OH(yrpXirxIEX|nh;c9MW#DojC97ZYKltH zr+YLUD6DdGkk!!#?jFw;B7r7@z13RN<2;OtWhQh*AxGSXOZT0faTVkcAii%>-w`4h z7WXb+Po4KA%*~{t!|IeF4ga?6a-_Ss@=dkh+ZIGdI+ITr4lP0oy!QaC2QAXYQ{;Ga zKBJK&1L49M9Pt1XE)pF3CcZwaDHdghz7DzB$=pYR(RVJ(UpiZ#Wh2k5c1W~Cr zG3r6UT9=D4ro zr@9EUMyE-?P|;Q95#F7k9C@eYQQ~t%M-1}-De#g#c}x z@zIc#dTl;%VsUdj93j@cn!+JZ4-uNP{j^J0wQhrmh%aK&jOx6YHl0}`fdP{CHnmzG zSK6+r2rVN)RCV=Sk;(U|?MjXS{1r8ovtBqKU9H9@!V1SVRB+BMkW~Jg7R&i~>=Xv# z#7*w=j3P0+)K&)97=^;!aBZKi@@6vOJ{Fw}RAvg~Lr=m(35>I|e#Ls-h0uwat=*)$ z$3Z`w=Ll^$P=Ssz`c7@^kqY-{;s_vhE(_C6hLXGbx2b)~P-m!TzAoca(A6GP7PpkQRoP90) zF|U*tqen32bMjfp$rA`H{dpcphhY7<>5%(;7_!4E%=x*Jo|>oHRbf)3nT}CGEA~ZzU`{;z zLApDv4lDn4!>g9(fi%)+Rk2_c<)jK(EKvQ<`B|t%z$#|7YL2JHt#R3(LJ>L+1=pS~ z;39#x6tNf>Ih$XW4qX(B%gSnMzGy%*mzwDl)4x#|wL3&w>Wm$W1zQlhfAJLK)oki2 z#Avu{m2g(Qi0pd4{W9c#Q_u;k4m=rx^KmZQmWX!p-_4QNILTmMbJg*B#uFMt@6UoeSDXLwY%^&EY zg2O^U?hJ-Z{T^K7`B4MwMQ2oGGCY6}00{CLu0<;!HiXoY0yWa7Cbw6!&Z%mg*D$cS z3}GPY3N#g)a78IvDiOs%!zIF}G8ct|=j4LNawk{bf#->s%;)R#Jr3v{zk#Av^yk6Qy^^rxB{Y8afakc$-k#S3z=Z+ZyitF4+<+ni$Jrfq(lL%~j z8AcGeD-qqD3+mh+v`$8_fj$p)tWc=O8$FE3NKYnCZ zP*Cz&_h0xQxy}g%+dP66q~IAo<9qSjP@5mhuph-ENS*|efww9?=8x~D7d-#0zz;XE*$j@Bp zoF~LdjNL9$W3}O)ux&{OHkDONf z-#xJkByEtFZrx($f+S*E%?G-r)@F;8fHNV;fE8emlgMIe0sNTbq#+7z>181a`YdL# z@FMBGR5){n1=->}V|t3!<22H2U+q%MKO(F{5dc6q>2~ow9?zB#|ASC`v)l~}2luUS z2wYtWMP751lp39q_C|^gPv!b4`pn|RcjBCuMBiSJ(s)w-s&K{mgv0q7_(2h&Ef~iA z?PDY>(V+R~b-k&f;%q_UU(ko0pCsU$EDeX=!+4KrZxf8U`o6Vx&gKw_WFrshVoVyxsT@@APS^t09zk;5p_hPoZXf6C6y-x042B2KQ8JYDB0-!B27_x{Xn0b#O+3>Vz29zTf6Ld9YF#}A!DBVw$V;rhvrbdPp+E3jF zKZgx4t0^Tg9BF03Y5mtzpQ%2uPsqxg(T*;khiDUQHeAHTuOv2^R$=6H1wsLM- zv8*+`;vbpN?*YdA*&rZmMCQM}XB`w@SlcSEPBekJtD&?R19;s%NOx<$+IGW3VEm}z!NU+^l30tLc;dAWD7mQ z7`;y2_=Rz(H7U8BH2c!9`#(g2GM!`LUhAkwH6~=dTpDsstATaW>GPJ~^0SP^xWV4+ z3Y^+)N#kt9eu#kg7$yzr*ZBlAq+p7NE{gXrVa&uXz<;+i=s=|XTm<%2qJ{#ItF=`u zdF||HyfQQS*%YOWi{Yr zNoYIUM?B2J8U?QgxyqipEOQha22+H&-k#Gm{l8R7$J8`my*SD^00 z+z{X62sh=><-(W9BVn)m2&wta63dHc>SDk%KWYN%=Bzo*m%OPgcFXOIINL4HfpEd! z3(r~{vEPKAp{i4ElW|S9a$){MxZrvopcm>dL4gMPvODu7T}r!;w}} z*FBL@9n#X$b#uhjIMST-9><6Uf=oO(MB@-|f6I}I<9eh>TXhd}hr2Y3i`vWiTy4W- zE!OmqJ4CgFeq`hy`rmUz#nKXNDW`BB#N>#*| zmj&7kujWPul^7z&%Q(U>lVCh+cUV?GZc_y8ZH{wG63Sy45Xz^j1g%6_M^|BBQRQ>q z>$%r4S>Z7fBpe!Q`b&~Mrt(R;#AjTNlt@o#Z&6@-eg`C1CcLsezRlIi*??c_23kdJ z+ev5!CyjOh<$4Dmn&CU3)}v-t+9-ZNFZGoqU1BN$tCk{5N473XbL~q-bI|EDO#qAW zm5Xs8v0xZo?&QROZOGI3)SdlccX&83R;ea02%YyVjZe0hpq@~+ z*HBaQbQw@(99+RaDxk8+*6CZtF-4X-)?nY1{T$8hMHWtT&DL<%e!0A}#}(- z@HZ-M|9ShH{ZZiGME;198@D-2OzB3yoA@Vk&%5@|L3k(Vrt&2Kge)h%#zRDO48RSQ z`^6>9!6d|2{zf}@2aGNsCvBW#PM2`vuS;wZcdF`zUXeO#CekunnO3CbIbDSc``(77 zw3cy`H^bppv=$ukzfSFeop;<+zP(CacEx$lD|jH$kAqSURee?6vray4)QM=b8@=%A zs{wb))+_XLrrSY|Na|r#wjzHq??GX+BA%~q&+fbJow6z$D<+)Z{xTVVIj%?=?jLWx zTlf4y=BTjat2@s&3rA;i5h5KfO~?T)L5?4dG%qX!9w|K^|GmPx!>W*K`sa6r>6?t{ zXxQZ~&?Y{BmsA>jV7k4zu#U^o%Ir@7HLv%I2tq>N^;BgS9oY`-5m1E*(rDeHi8Z#D#0<4``EyjVN+bOG5_phH zXxW#a9Kz4~oNC+3HHb?EZ2w#8t!jR7a@AL|N@0AbsLC<9{t|PKKFrrauq|-)Bij1S zv&M3=CxXrc;6$T|7YLVC`2%__Qpto3Ldudt_1rCE>SdNeUc6%rS6%)6p^MM^I{nKw zyJA0%KV?!Iswx={M8>)kqz130=c>b!)4AO3Kgtg4yW~1#Kb`hJDJ0ZMS|UTTFe8f? zO!;~3o=YrfAV}quBVbqn=8bG#zc=2{TJnZ#{OP1xj)A`%a$ZS|+DS87oyD*!7!MD6sLp+OxNa9YpQ$?U$cPg%3&n$PH%j9cxK8ZX{qdBVpcJRM zz^dsZpBej2hAfJan?m z4HfMmfUuv71Md9^sU9|(uZ)|e{t+A?2f}Rwh;4!^b6~4q9^QgDDXMPRhQvL`H+Pw5 zA2YC9#s6^G@iZUQ|Mp(aruY{qpu%T8vA7AUPH*qLpeP9a*3|eDSx`hmLRdIKx1kk` zkd+Cu-!=v;%Ny@t-^g(6**#q1J*S=sdZ^($UqKFxBJE7g*RutCSsH9y-#0OL1M|?hEWpAn@_+ zVa&oU6)OU)cg?0XC=xDXW$FYrfx7bHRi_RVV2v_BwL%~P!E@9gBA=B7d|E}XKnFmN|=(< zfBU+p>F@FSASK+%=f9y5jJUQoB&EbWG$~n#$l`5^j4O z^p4IH#=%`>X=HtqhgL_$9#rpuNLei-CE2%KA-P#u&`Hji8aW{$p;&D^lVqsl|8AVI zEB76+`=DD6*rw0t&)u}4jH!MbrU$2IuMtCl=IaZmYC zI4z>1-|c!@2W;0^>2m?$MVtoqXSC)~15-3x=U;U4HKjA=A>3E)5@>6XqA`Xyq}C}Y zSio3gE3~^*{JbVhgVC5^B<9$6_an=r9El6v^Zf`l+nPT2kj1s%skAl4Jaj%i-AHC3 z4tVqobctRKx`NEIPytw_Q}_QS@k9OW?JfK2I#&xkxQbgmmjR7ge5AqsxO{f{#MnBF z7`MD?4-;(`^&&wFQ@B+@;m>Sl`Z;nB~p8;Lx4^uu8)qL!J7ylsWh#--3`jA!Kzn2w+NAh_sY)Y6N**NCNWsrp zhVT{bjsT$d-QZleXY*aHhG7p|ziz;+E|KM#R#wUD;m|=t0s#agIxbhPzlph9d3xkA z>uElhG7=XurnkwaQIY#z_Q}|4;ar{;&W)1E~(KdhE?IDB3{bR+Kk_UxwRC;39v(cbt zJ*-@KxiNvvFhz6l_MW<4H&s{nah*zVSoV~XLj3&z1%7gie*C_--`6Dt56sc=A)hHD zG8^Vj6Ce{XV|@*HK} z!39V7T^|`g{b1}G^#%p%K3jp;M+msp8oaRS)|%vZ^q7S~VWa;X4$hX?15G89+h{9~ zJ31!a3By-yZ)bsZbC^lU`rTdb-EVP+Cmwmg6}TPQoB5cgqjFGRweM%(G+*DY`mtQC zeDUAE8TqM>e`SJOeysuC*WT?d^Qzy+>>lWYFX-Sl|MgyQoPV#?A_G#O0BPsuruCnr z)B(Ejc@CC_%=-40mcWXN2R-!F$~yTu(8nA zOpNf^{>l#9plV7$Fvfz?uFmQ-H=W+Y=d-}~uU^GD3A(Vg+PnJ$fh*f{rz>Od%%n&~{2&aka(N{^%EL@yvwpL#F#C4gmgC=xzHbNs($&-B zz06?+SZB;FS$(p3q3oab^=Y(xW#MiqxAiFLDNv&1mZ{ezy7~HJDM}O-CHZh~@2i&Y z%nM#+Mbq#D(NEgzK7`?JM?MPr%7v1*ZYAr-|JLUu>K6vTJT<+^*kU`gJjmNPpQV9x zy2Ay_d%C&Bk0Y>_KZ(S`oErZI`aXi6fA*ot1RS5z;li0nAJG#>nvfV+02rs(z5Mpb zgZIW&5r8j_2S_o?S%UX#Tgjg0_d<#ET9X zsCXTa*<523Ag7xU{T{ zzXYM&<++D#N(4Y{`{Ci?_I!JkwrqIW+wj|44-XHXS+sFj%r*PtpkJG~kPs7wJ3nR# zFRb^DhjbKW^`}7I?YI;&kn`vE*0&;)!?c}LU@GqM#+gm2&v0)Kb~2UAIv#9X8k9+v zO_vC{Z5x#ceBeAC?m+>b(saL@s65(vPx1M5w&ebG;XfSDW*e2L$G{|&^bJpH@?-I$-h z`PxY2m?o6d(#V~`l+og9Yw2Tac7RRRfzG2ZtslVowss+g#Mah!*OcCvX2z(i4mOey z*m-#9?Dn;EYTO(J&~u0hF(;ggoH#T5i;w?U7Z4C&{Bu510tx&`q|8xzSgMein3%Y@ zOS}7CP42=T?C@`+g)Z@vC^iHTAxT9eM(=U;^Ye3ePv|h~*fUKHoXwS|1I*Nz5U1Qn zIZS@)X+EA$Nnl*^!qAfg$uj|PdpX}yskSrDMlkIN95q!xrVY~E-L^*tS0^sK(o<;( zJ5=+|R}l`tIJn%OeDZsSiO4wJvIPHp#Z^o i2*CgU`Tx#sc|>~S*h99f;?o6yKVa}F`%T*H!~X#4j=7uw literal 0 HcmV?d00001 diff --git a/tests/data/test_simple_sphere_light_flat.png b/tests/data/test_simple_sphere_light_flat.png new file mode 100644 index 0000000000000000000000000000000000000000..f8573d6dff086141184054720485de83b32e25f8 GIT binary patch literal 26694 zcmeFY5CMWEgS!)aaEF-?7$it=fN=RrIarN z0KVfXevs4fN;@`OjACjkOWQv|Ju)NXHIaER#|tc_bAZFi1t5;K>ifN$>}R)`6Ix}G<5JUcWUuf8B|<#!peZqX?K z>*nU4J5J?E-{1X>0TN2%dbmB$R&>Z5I5{~D^W$2O{ES9HzNFe&@>pTqA@1*>dH*$3 z{_N9x7P&^X0RYzRkL*5A@mbq@7z`GGP8H-Kq8EDQv!K~>_wev1oG(P~Mz)a{`QGWP ziP2M^d3K)???;kbAbHSGA|Nj_035d;0yUr;AjdkzLVjijuC!`0L6uXxhq#ZAD$0qD~^X^7OVJYTtrhshYm4IfFcH;XE&Xd zNI+K)4+oB>=Ja3a)3A{}ztQknC6c)5a6|`$B7zHf=c-CXi{hFGO=7)T?owa=Rlj?sB z07WYedKc&C{|p;gvNw;9M-9WSREsxC1Zihm50~Z~+=q*8T0xwe^JC$y?d{Ej`tDA? zl?;2e*19OBbo_uWarQPkv1f!a2L&qNSqtP%KtSLuWP(27c-%Xs0ZE4^=&7aVx`rAB zg3jrk2Pj^BCe=07MOoKDnt>y!;etaOr@*Z2v7oF}U1LHUD-^e_w~qTwsZr z@UydS0UI0IgrX(JPjqAxVukB~UmNXhZ6C)9R?B|C&mG;Fx1%87p#^gdm15|GY*!3R zxSYvis`*c|*S5_pG0cQs$YyUK2cBsN-sV&n>2kH%-0Y#pqvd>DMO^xpkTtF1nMYYo zbuv!MZwZkNu}vT4e^CtJVhtl55CdE+^YECs7tPZ7O8c32E1-s3*7)=BAkQ0x*>=jp z)njq-8%a%(|AvUFOVAq3LZ6ZC$oT>9Uu8W!?tZ$yN-iUgdh{Cf*YDoSWY`7eCC19F z?S)+aB8?dMRH6)2s+b10u&Fm2U1x_?2Pkc1QPQeL-r1JVq~oRFU{iS+*9F(9OZYtX4dKb zCtDL+D&DzQ%8nRDXAdQj=$antpzR`Wg-Q9cgXrEP3k)<&q+^NbOs}?GC@mKh%jAWn zv6L;{s4n2f{Qv~Xw5XG#G~Rd#fZeO?&e%By*=`Y3eL_H9)FQ=7f!`E_loo)WKMyNj z-y6oN{TWlcK%6+{>avrhun^&a{5K5OItbpjC5#*W`*7~tt5bzaK>()0nh?xxjLed= zl8pg23No%w3e|0%r|Wqp0xl73?6w2#-dp zgI#ZLdBEUN+^hG)4Ix(Mtj~d7A|W2<_3MqEBV`Pf78u;<6p#{n&1cPU>!HWbr|lD` z`!-y^XDV*cK2Urc)_#?)DM?{oPpvMlSgA|q^~{HZWq%t<%IHWGT9kL}Nio#SOBEOWIX}}QRlXIYb4)VYLsEX*Q53ci^f!L<9k&}CuE&7HH$J@{oS7&jo&WAc%G88}2Q4@eUf^Ou5*NBmJb7d4qkgS79PWtWDK3?N) zHrpibR>ppHR|Uv|gRG3aYopMG)Meq?kSgisO~uO6UtQ1yfp}O4A{y zO=9*HGilF8$agk-tiOW#DgGGv86m#$U8Rpb%dK=9je*c9S>;-BpYq9&aw%j5J!ODK zvo$8J*6dim_`$qX-)c=j4rA>SZu3&E16o;K0F)S;d1r{tMk>&RNogLH>;hbp`;ps~4wp#gy_p$-V;^xLyqT`TVU-v7-hI_=rQ^E}hP<_NZ~b zQaB@m+NH%gA(ze)E|7%xzpeSMP~nS2E{PPVd6gn)sI$Vr7rtcz{KUepRdAI8KYg!Y z5_je4Y2G`^a2yoW_BDSaiu2p!y59@Csh}0+^=V!nd>4WXhkbB+>hQb2xVq~3`~0sz zx(}d!plJg4sEgUjJdlQC0bw-p!QVYg;#I&Ucxdc`fd17TgT4^bM=JCQbNUjRF6V2) zi2gh%ba!`G)N~#ajYR>8hUv`p+j>_Ha`-I$BZt@`qNAhb>#ey)W$GAlY2wQ@e6A)2 z^;7CvY|3|ehu`JdkRJ1=j3kZK;a8Eey^ZS@I)M=qQmZj9cCuG>zcyUuE`O zP)$}2jF^#BS?DC8ndt4&_(50+ovue1rBtSyEt%lA5Iuu_MyP39t@9=6{HqzQ399KT#$ZH2nT zo;17>TiLaIkx*Vp$@7uLU^Qo=E{5|IG7KG|0>mjNrCVL1yKtrmRXLPE{ywmJj?WJS zsHjq{|NcX%2GZdprwXX5S5rPO-9reH4H#9GJzN(qhiKCA()0{)=HB-sIP9ug3E=-J z&4#UsVCTl;8;RpuPSjw-Rrw^klV9R)Y=VHliELx=8UqIaOA=Offe8x);V&N)1BBop zDfyvLk=I1iUoXmCB?TJ~cTiQGGBr)@#>8quz9tpDZV@G z3$(Ds*+`u|&dU1J=tnI#h2sxOVUU&+Drh0k+3$106K;hqwMa}5J(l(3RlrgLIaR2D zVkGR4eE~p5zSq)?3`vB@sS!@@Cq>^G^<%u#*3l>)YL)%+uKeO{AGM%*wSk<+6q8;C9xscF`UC`m z8g4J(60)(Q9(U+bgU-rsu%fHPz3D)<4PDG#eDkj8!J7D%U7%KwR#Qf!g24;NhwEZ{Do z6+dK&Iu}mtvwgFZi~2T=KpRS5#ZV?CatWpwNc%SwfJ+c*T8}-9vBYHoSfQ(`clEq) zk#I76^s43A+}s48fODLhE5UG$kvG$EeUUH&^`VlrZ;otfaS@MdsRBgH=-U}%>?Aei zKH^%GVhW?EI&_T$!fy+sS4uwgs9^A;o<_00|I|p-?DC6EN>rP*>>f3lbOgOuOFp6B z9s}`kZKv5$g9FLpCMGVYTxhW%sLGt3)n*AN-2I#MW-j1L#;o++?#@!cCu-<-(`=>u zNJ~uHoZ7?k1gi2c;#Ek9vyx5IrYKC1GkadFE_*=Q-t6dG!&7m(`~(nzM{&7f26;W@?@D6KunR8bYsOpfrNrl*PCWd z>Qsat1z6vS&e~&uUTT8Jr+e2KcXXR={kAXRI9(~}zj>6{-YyYfy~+&`HxVdKIkPeu zs__~FsoX3<}0%53Bs=f3&~FuVbbK!I)6ygK4Cf~z=rdoi1I>+Vc;^XN;Mtf zHGc-wGwI3^HP7Q&cBx$IvG<1K?K>vlsh0*BG}I2V)$*)9@k!%^DxxX%{^!|Iu|#1z zUJv=GO{e6&pCz^?!px1!?JB;FLWg?4D z9*?gC#&DXsz#x*TL|*qRftMj=PzekZqz6Ohs|?a|EdR} zx@RIpxQLH;0@sQlio@W^U}z& zN7flw6^_~SbrtxtMtD(``!VlZ%6w!yJrkhT)si<2msH_a>LV`KhvL5_V!T9#m5|U+ z9K|M_lTzmzaIri2%eA?TJL7ispZ4(M>~wfT2U#aSGKnAGuRgE4P~`L?i-No55uy~V}3eLw< z#E6mL7xHu5>V`FLx;r_1eh%`Mk6mNPr%`Fo0B^D&^?&w7(rQBsKD$N0M!qA8yyn}%?mYc+wNn-8*bwx21}PEvQDX7{%H(uq-Kf?#B zGAF0TjhLt0b?%!$!7}DBWl4aagajUce!u-yU5-XmnQZAnq#A13Mnfa5L3k~LtlpE@ zvx8xpFy7!=PeJWP&9S6adp=FGsOvBHMRhR2+gpNgj8$AcU>f=%;+D*5J*^_%Ua>|7 zgeRRF1o$YxA&?4aR=$qGKurb4JPD-9Vpkh2q6i6oc@Ss~WDY__C7CtmBAPh*#=<;B z^RD6}TrmbHiq3r~<08#>&&qX8oKUeyV!)?RrM8%zo|=09UNY-q12kg>JLm*j19rpB z*37^}^t`@rgf!0IJM-7HzQl-jF~46v=&ySDy4){qE*@j~vlf|24gjd58+V6kE`NYz z(uJcJI{f>ppmWv-{y__wMY6@DzND?)@t62!sF2;NSb)urgsgQr1s^Z0ICtEBM*aFS ze?M83CB$Ocm>LHc6_uM8-(EHWUt=C}qbt0o@eTvbSbF26T^!k>1N!W}rc}_iZx|E5 z|4S*aECg-rM=p)0n6Nq*fBuB5h)C$UM5IHe@=y3^Wh{00@3%j`{uj7b@6(jxcEx3) z+8%iFSwT;69sC^~%v&5w;D`w@xtF_P9k~9*^AvO9nIa$$RT2}|#9<=H^(i(`?z7@7 zfc8CCWsS=hGPGuMa*Y!*gOdE-5q20RoPqpR(g^;1!a8yUB!o5sEdWty@9rk&Gu)CY zG|%ZDiF+{qW9W1vjMI7}^r#^{5!1ZC%&jg(CxYh&r7ZTl^pYB-Q0YBO2=@CcSc-Qc z2l6P4K2(kVjVJo*KZq*y@*H%WxpP5KBg3KDYjJDNf+@iC>lc+{=7Ay^9~}o?{NOA# ze!A6<{5@t9z97M^DC)@WtsoReZj5zfV`JZEy5T3cuHF6pqIrl1$AnTB!(Z$*-k#QR z1i1gK0_Q&eUPiU0SAM-j|2Nb*+dM|{56ego=Y$51_F=STF{Vxny2$GKoS?e40L6%D zCZ$ph?`9?ob7-2*(NKu1+zeYG!5>rZnQ&TEzk2=P95+70I3IOM>sOmaK%Zq9Q z-JC%9m`r!5!mD*E?s$nnL*V*660q8=l~2f~Od+-f)#)?)nEtB^{mA)^f{+}(h|U{? zb|LrJv*c)MhpH#6!A0+C-6ynC`#lX^An0G$2e9`nmI*_bV4ae|zSMovUYH$_s zDZG5ZYX{VSZ9TU2^1+)QJy93e6ljfVf4|pA=MhPXceR%rDg3O431@8F(?xbD(S5jc-{S5yiCZd}}P`yZ&(p4CV#@T$J;m z9=VjTv~GT(pRSk7?w`DLu-7hZOrRz#jaZLfjO|s99diG-oAqm)K zySox~lhZOqfgp{-^hVSVEr9-aX$96%!wWUC8owxa+~bdnTsKCWhFRjUc2geBpXX<1 zypln)Q3{{begC|ra48Jw7Un-LjyR{4B6qXBx-nOM{_yl;tx)D2ZIusJC;JdWmU4X1S}ebNu_G_U1HWRVhUtRopfLcG#tb@l6~`+`vJOdvy;~*s?pA7 z81%N;;vo>Nh-Tt&E2LD3{8srw68A2<1ebb6k=4};;yFSwX6kv&${m zUo!f{KW44;*n&Su+f2Sc7DPK?gG+V6@rr}aDKA#p2?ym)%e~6wEM4E+|Ib~1DAvMH zTsZecE!}M;J{uzZVM~A`q(Q;sZhd`ye}5m$*s_nsAI@7+2isKbW#ZfonELy)ckJ{b zN`gch>9}Tyhi*80ip$T7==+HgUBiEMigI1}762`iXdg8k$q!_G(hp<7>;8ea_!4xs z_tRx8!nl%xKz?4{BjI}S89ZyB2>#Dp(h8;`C+c2}<$n(9haoLvTa9DKL}ZQKItxr8 z!JF@HNNS|M6IEhRKx}fuUiJmVw;vZdKyH6yq!+w*K@K{KCs>OtIN>c&=gg87h42uA z9%YKxeKXuR+Ob}UuQZoZ8W&sLPA_)5bZs0CICHQcaDDL4Xx|^s=sW8xbTR0Btz%2* zhtZ;H8S%*^^?9($voixqmws8w5hm;59$#1p6W5r4{d){a#dYaO)3DIGATu*{vY!>3 zED*%%)m}2KpO$5)%6D{L*wnLH()x*Ob1Myo9+WP!0X-r)pwb)%E$q5f~FYj^QUeV)gU~o%NjlgE((Rtnc-LAs31jG6aY({pOi| z`{e<#Pq-$uX5+cUb@rz~76A#5(! zpS(HKQ<%&wS5#P5IBnToDDAjfLZsHo=LSt~7n+K9g?F~WPwv^;LDJhr11Tk><>(2DLqYQc_J`eDQz z%a*C>@o6wK3<~8Y40`9-pJ(;Gj?xU zP4Cs5vOly!Gx3Z=yRD(R`(}HTRh?G9pi^mMHEq9Ah6!{Kwu%jRs&EB$`QxLa-if7Q zkKh`Ihf5y^&WNUw0wpZ~`Tua98(y>`X4sC6QwaYqt!#J#IC)-Lbedh%K=t?P7UDy1 zm|uG**RAoI`&*Ov2xZQnwLJcY=~^`IUWvzf%|If1Dgz05zc>IK=cGq4rjZ&M>3&#r zR15_M6y~IT9lsb9M_s99b|oZFC@^-{6_OKLr`2|>YAG_2!|J)BHF_IJ?9%c6nx>$c zMp|Zi*2azSbfbrr@J?<2%kC0C7mC6=3|<)G4ymZjt3*Tg)M(OdO{tiZo!X7zm4BJTh+;VcS^+T) z<_~ws|GgZ%1Eo57pE(pXA_P!SyJ5v5yK!L@{qO2dM(hed0$w2}e{$|(Hy+ZM$$UfD z!W_5tKKC(nfU54vF_a1B!s9bf$-h{(hoIgatNQfcbfxl;4S>A#Yx6;4NZW*8;fGH6?KD(>u#K5$ip~_A|BxXSS_06@$+Axj!*-N&WqwE3;>>2dOEk zD;8T3jbCyW$f@UF+MVdj?&D@Ct2=S0u^6*S06(tj?Ey0kr1f>4=5fXRkEORi^?#4% zjM)2Z5H8E{EknMzmb%>-Z@sihoJ`$zAo@U=MqiUwym05(E(_`B+uxt4xV(Bn-rsNwM1XE}&>uSW>@ z0fG7)q62`fN|^LQp5-;X*E;^6BhG{Wy=Gd1*vw%?rTT3jUjL=G1Tn_GvfYK70*KiHm>6o1T<$ z-RtkO|HTC560yulq+V{`*pHuJOK)NUKiR)Sop37F-#0E+!mEC95R=)De{lUT2Zh%H zK=|zfle#xST&~DF8Lz6D(Cz4{Fa{nkZ2b3ej@nFsSHj>Zj<38%D9xE*x^$u*$I9dY zOumEraROHms-Y63tC?Ql{=X4UyE1X(4opi+WyOSv zI$3Y{0n88k?~ZlpnhQtYn*YlXDmr22S1_pkO!y1nG`JS9U1QP#D zGHxh*wn0ICV@Y^@)d-bTfpuN!saAowI7-zN>gr2Ng6`#I&euubynn|=I@Gb@%45)f zZmyY?KD;{)pQm}+9J}hc*SmNg@o3~+t!jcrR9V}+u$~L}y?l=Jy)$!MolYM044rju z*7%drQvcM9lqHpFF$NL^aj-#LW3u3QMUJ zq6bjmIO}))WBxFevp`UESzD`w)#IeR%qX4uxy3|#tJJzmXGi%-^?xy9PdhnE1_MqApTr(~z0$V75*o>$wZ34P|746T%9q^+9t zK)b-%SA_KYQ`jJJW|EaEd1c)Cm!q1YKUY_?q-brgYUt zdcQ3KjLHXEwC>5;^s8h^k;Xs%Z5#cY`5h!9eft?UyL?GQM(i(DRYU7(X1zx=YcX0^ zN~ucnl(LKo5Bw}0mSJ32_S_1?xeTnZ#TK2ba(Un0+}+(h zJsB*c%lK+0F}5M@?)I=B*Q!a6%j2a(v;ZC1T)3&9gHrJNXn9+_%ifY+Qk-3Q7g7++ zB2pdq?Z2iMtsjjCFK*D0F1atO{`-KV>i=g>T8Ns*FO-~C5MgjLXl)RnF3;!?TmJS@ zum{C`01%(d7&uW0jtI%qqb*_JEg98z);huv--`(WIpj!hkUTh8HLjkWb)2>@%1K|| z_=hzSoP2y6MlX4cRI%Z2^NAbvX}lzK*e7|oG2eF@PGqNE7;y}yE9_L^C0&YozUgVP zr-yT-79AEbHHBGx21p7|r}9((2*gX`x@8Rs_|22}V)IgG^l2*xgab^h^0P&oQxcw{ z+^<^bZa_ED>MYx@t4p@-w%W+&xNKM&jJ)OJtW+Gnk*3+(-?zDYmmh~#t|=~_+t^?GY~RNS~!qtieiu(R2=dwo+8k|LKx=qEwc3E6a3SBGh@d_WK~apJ4khDr4P?`VK<#CBIvG%fTZw zeSYgJOo9(pE0#j)`4^9eM)Y&cjWb7(;6M8I`uEt+vWUCYJ;Y@brit&v)aHt>)y#sd ziE!&+=(#-Iw5@9s;N1>3y@EWx=eU)Lm_|FVAZfBoSn{QCS}Rp=+nO2F0ub>)5zi!lsVERi0Az|xHzqsiK~k=8I}=EOj))74_pe@ z0_C@j%0?h~U?=An`2m2N!7^aZ(nKS;NMiK`$HRABya8krKYfWOTt_>0pOzc;?Fdt2 zc7!#j%Y(C%Yx0w|HtH|(WWVdE9#-#lcd&eqGZ{G6W+5Y{ub-U8)EnOszi7Fz@eB@wav{j^x(h0^r&Z4)MDRle5k({)Aroep6i>0 ztf<&a_ud8Npr(hEsb?AvV^ znf3yxxg7xFsEwa=bu_3g11s_~ycW~KMx(6KBQ&kngo{^Mc9(xa=fi;15+&)zCJsFs zqW&0(neSk1(v!;v8A{?y(ruHzEnehn6Bt=t0go#8K^tQ_SRYY`ZDEqQF^7$0*+#x#*!)_vHU#|D!~Y{WTJm+GTB| zanYR=zbfG&x3_zD{QH22x0=_~{ldWAIUGZ(=X`hi%4G$A)uQ-C#j&S;+3gBnm#}AC zzO)x(Y=Di?+NBf^t~+(kDB81VbcE+FtJCj}2akC~-6EiW=iKFTTBgU*sI`ToUiL8x z?`Ev{eSy^dPHGGM<{DAuUw_&a=y@B#^|S*bMn$~UdIuvtw&{jXq@;Iw_*IE0~&Ra2(4y8Q+o&_5MNsaaFSvY`Zc16Y1?$CTfJt{ z8Dt7C+%Vx0<9<6bP*4;w41`o*JZa3rqlK;`2k0DboaB`H*3LAhh6V*Z%sj@X_3O{N zJ~^EZ{Z(7b6f2mlO5&4zMQ>;X5yKVoP3Z0X4huFpRiy0qP-O?Xgn(RE-tIJ z=vk>bFdc8L)}i=|Fe^K0n`lGJDAZp^N7^cd6XNAlQ|3P@<8cHghjofBeY`i zn-@L%h$s7*1g(7P^T%9`XwemA-e1Y^X=s~cjY$KVTxx7 z<*y5VP*HI|d`Z%99L?w4E{o%F0o(VyAKtCtkgiHd_pOlv$;Q!}uaZvp60`I^>bWrv z7&3T5&(|y1uOoMjD5^l|qcqbUf5CIqx2;QHYbEj;;k92KEOZf!w?8!Fs@`fz<$8Yr zRg_(|2L}YNs-q{RHPd0dg0;ZmKg{pcwMeQr2=`YgBrh2p*X~yPmW{@6#Vq7_7Y_>a z;D1`VtPhzcTSwOS)~YZ&)E*>*lK%bqXJ%_yh2#?WKR=_F!`^JEPLZerKZAL`2MmT3yR@Czz0qNlIh+% zen3)dN5B)&m;JAPBg5Vvm7Zms2;W{99Ua~M<-aHz|MiGdYBAZN+i|}iFGrJe{+AkX zt}IYAt#T@i6o{9EAYs(3zSrR$!|~8LV<^!yfLgea)ynXUVA3W&HsFb9c0?$iJT2ck zkn@ZYcI5Z256MV>;~k_@fvs00ggKv~eUaO;oTW9k|6J3>?3in<105uS@&yIT z9>y(=ns}&YHzCGj+ez_I>g99#YDs9LTkv|XwXIDzr1QGwc{bOp3ncp2FM?9u&ic>HI}NOiX?3IA30#S6@_|pW9wd7XMjX zU}wea8;g?W=m0&s%JAcQ46EHM`n|^l@$4}+mIJ>hKQ9)9OI!<)&eZS6Cr&@81Dg(h zS);*@!x*hC762)Q?TS={`kHtm)}r?OlH=A6SpBnjc(}yWR!7*0>nRhq{$i5d328TZ zU@_Df7T@!eKvXP@9@RTMvtQ#Mi?7)zZ>JhJP&_8IXiS?6&0&a=VA1FE<|7m(wz22d{K5M~xnC1I*!t5pQyndL@HYTl#}wWwIF zJCuIG!M%`{KHcfBm*`y4y$1yz7j;M*tHsn37m<%8TH(sqwf5z0aJRf^EChlTCbyh{ zjY>2@1$3?NE!<(*BNP=^Q)t55A~oOhefB6b2?$E@;%|7Tw>0j%m(efe3>!5mFe+me zJW1Y(E+Mt<7HeQew5j8#&m_y}8UJ_%*BCZR_&Y=80y zO6?M%2akRK;c!sVf%t>CU2Jf&#>Ld*I!cvbKNZ+mUxz9OFbh~7&!n(Wze-aAwdf~? z)NS5q_q}T2Kvq}r@bHv#@J1f49dJP6&zzm`_swH{$zYlkAt+moMO!W31M>xGWzbfc zln@Vx6{?Q}r>c0KYqMgu97ioY=nuQ5I>{s4z3BgKChVQ6d244<%^CtZ>6a~Z(rYuY zZ}mlE2Ks$BC(yoJ91TsCH%K@?$q{GW()`(Ka9T~&MLg}gEmZXPf9P7MzG^SYBg-K&Y|C8_OfWb zqS-0L)j!Qf|LZ$Lae|c+=MTu)QO2L_ksYfsZog2!;2JTyh|EAllg&$1X23(3&0de0 z43p!TW%PsZvD4dCU6zfleYdh-uHqjr&Kmcy;n@;RzrR+!IGn0oh2~HhT|=B`9vSgk zidjYmR?9Y{6Wl>3)87C(p)1EnKh}hQsZ;WxUugYe6Y9i@L#k?-K~7D*u!R1ay0mGb z=jYJFazIG4Q*t?|4*f)|z1Jd@*EW5%sJ7C9yhj+|nDvZ|b&A?;e3Exc!N@^+-(W$h z{BBSFP`5PAVhiM1SZA76#?q5f6iM!DsCdg{lFSMLh;B?Nt$1>Mh~G`t>F-~W^C5ay zu0*2KqkA>tf?W-p`&cz!F2KX9{z`nKiC4+|TcFDnCB?Z8%|UIPP>zzvP&*2U&LOH7 zHF<(e^_kA)zMHaCBEp;9XjC+4%ro^GlZ;<5nMI<&=lQxK@cGdArFkZi!t;eA&aS=4 z-9k^8VmPj={)X;^m~7#XdxDj_I4_1a2d3LUa#^<^YW;Y0nIwBvJK1;4Y`-c z*x?IN(tZMuFQ+rqGqc6s>M{*C%#REfK~~P=(ux5@4ihI?EtoD zbP#>k zokXZ!7kJK&Age4`&kd0!n{it?F2*`{!a7tCp}R95af+1~1L7oMKv9&(ijT(O*X*D@ zZXHHWyTHX44U}<>N>8nT`hgzG(dX^qq=JG37-G^OHPsh{6aOTU?Zc$TPTif&fcpzds$v7Xw$KT)Z@lyeYueGxRue}y7 z-S-h=y7o%Ym;X3qF#!K)OjN>oHstB5%9+}DS9H6>JasZU?ZIB7$JIH=GXeZoYe~id zRZ3}o>t)G@*+0_9t4+;Wgh`IvpX21uF19G+iMpYy#aQo%uTZ#`d;MQR4k;n*rhQTs z7F(l+XfSCrLtNU`1N*n79+OFESmW&^B>m`HWFZQvCIT_+j`o?wEpQACS5ZGI-pkq} z+^>9Gxs?fNQgb*up#c#^1K-+;_o+k@Ft5X$4N+yK)9)N@Kf=3*T3EG0MsT)_*L2t4 z$OqTSpiMFztY$CL?g|?;=KAwr$5a6((Z-JnQ%6fb!@4YHcwVr1Rj1!D4=bJ@j&ko-f;|=AB%hs87&_&MzF?R^@I89d;V}pd*hLfh`T3(a&yfW@-jN znNp5u=256rR-78<&=g)Njka^U^}1imp8{O5sufu=5to@SCPUKcRHqMhi2Juy!nbH= zC9wIg`5`)leKIzGk`!;$x34$lJBxF0I31`wFAaf7MnL5%Xm?|rp{iDn$wY_G28Hdl zq+>%Dibz`y9!Jlc-$)mbF$BfAQpIZ6WOn-#!2DK&b%whLAG<6n!voQmxKgdr^FY*5vuY~>3Q!X)wtN~k<s1vEoN2IQ_1Sw2K}n!z_rXPY@S)mfoDoGK zST_#csFnUGKA&Y}$#D$c32ve(jGn$zRCDb={p#F```%5sb}odvc&qa2yeul<*D^8z z(el;e3+#dgQ}9_2=KtuHAFM-oV-!&aFS#6%Gq`*hCRt1P0JHo#Q#+`^>jAr@!0)_O z!xk`W3^!96(@1rxECTR63OO&6$bQ}?%Z(nGiY1mbO_pY%8=Yspo0>j(FSa2RVo|*I z0krY*Fx!#8b8~$-`@ayAdvj{SzVmqmfp5)PQcAI-!gnUF+-A^CM%sEb+moja5GLZq zFyAZ_d>w(>MSu4A@DT1!;XT2V<6U++fax?Ihpa5ni`yX>BMq{v^G^4+h_ zU3}?uzvQw#TA%!J?va7TOgp6v`wh6Dx3{aig~+8Q-W418%tVknt`y(A)bS@htK7ezTvT;A%4gd$l7lQ%G3k?qoSejrsKA z0(o?6#0Ilw99%pKYY4euvd$Adb2kUzYvD+Il9tx_UF{?(O-x9Okx234Bg1dtUuPcF zwpAPLFyR0nmwBN+HRD=6K_JtGtVP|WPmhZ)+-NYha9ja?dXos$4x1-Mew@2z-V#{tl$E-u)Jv0Mw0I#r8~4~vLadc9jCeh=!nc&}k1 z7W3EbHkwW*;Q3m+eH!`D?*CX{D&e?x9oM;7RGYSLdiXV&m5ShN4~}sY`>S2f9?RpN zk9cs(=2fC=^N}6p;76{~e<3I|trzLvCQ0Ulw@+h;4u~Ibju(@7jL8N5G9{Ib9gERH z5Mb58OCZw<2-qbB0Vd9#Xf=DM$jBC7#T$CJ#W1-Z5u5 z+<*?`eFeky)F`Rhj@EoG^4xag3w@;xT-Q0|qWss0RT-vl2mCc2j@*LqTU<&*T>?OM znrLsVMFbv#`A}Q&L=)EMlUQrA7ceJWf88N}fzeau@-#s8v0nG$fLa;OZFFa6N{>mu zlP0O#oFe6`#_E;RMsf71m}>45IW}1xI-h5TjEt_bNO?}`9I}*e-rWvSWAQ1Q~p1Z;oVnE z);TDkhizf~KD`oTiSW0Ces+96sj#FZGU;n(RTnM9l<_03r)i_xVO%NcUr{|?gP)oy z2vOIo>JsQM1*CT7v-=$BJMA9?03g&D0mtU(b|N>LTk@r%k(z!f-8jrNMj){f$AhW7 z@2$POXhg8gM%7}-GO4kZIq+;L<+ke+dU2CV+k!R+n%hUzwaQAP@*fJ*Z>hG=Jw80e zuGvU%NAQv4g3|zRol1hz*I!Ulrbey%w#hJ41*nA?z6c#X9p_IybOp#8u`JK03YYrb zdr{93Cb>UJl-mK8Rx@tjE4ezY?7CYO9$YvmVKtz&;v_P_I5RdndS zd`SK{P@Ua!na)l2xgWmPwU9^_8Y>@zPSe6mE^bJ|S3+7HxCGXc&~HT$fEw<`)Ku`g z{Ai@LJcle(u<1>q>!2!le4jNO9d+YYC2}Z%<|XP26$HTqw3J#j^+}I*T)%cFk=>c@ zcL68FXD*`%Pj@&p?)&N)b`;%;unJ6{9feAr3TcN^d3}{X!P-*NK*G*nL)D`Q!M~p> zd1_M4B>;R5ny!IU#l;S}NfxKE&^?WS?01Ab9reBT-bL+ABMq!gJkbF)^$kl6&@Z=* z9RC{N2E;5ar9n<~(dp{yJ^3ZKHsD?oh|0|d>gS^FSZP1IF`PKHiT-T~#r-?;{msze z2D`~Hqp{j~C+Uc~`i9l=uTo-~tWQ2}#~%xo0PC9QnSY%-ssAGV_RVi_t$~bJl_Dp-;%WiVM6uJYx$#k{)*0YUq~t37D-iQy=BeW=0a#iSnyW z+8diVC?=tynoR%T98n4FH*NcaI(M!0jsG!V_mYzCaXVX@&v6vzAU?&N^KGeqBpgI- zD&4l}`5I+N7-QJP(x4mvfqMH+d*oFU-;Xqm%bo>R$+3RJ&J`cOQ?L3VbaHOTS4e6a zKiO>@j!B&3dq-yVx0jm{XI;XR}5#7`H#O9IvExNtAI_=z@XaI0m0G#~^aHnBU zF^!q;tv<7=vIX&7#ujXmz`eIFI48SUchd?23;N93|55&P%P_6~fW2zn%%X~^!a7{E8_%KS#j-WO#gn z?I|l2I2ZRI^Rr$&`Vtr|$7Do4Dm`Kcc;&A1SO7IX`#3?SB4+|a27c2y>+h&wS#tH1 zY2S&hls1!YjyEnRhdru4w8 zosyo2Ugoog+99gC2A;<{%R@4~zFH6M$p?n+WSi2IPbf!doup_Jd9U`4TdA{)Qbt0o zVbAt@KSos(JBAmyf1+x>94UKb=)|DOiOZf?maeR>R1OY5CjrmE8R!9}gK1-RZ6=tG zDC9$&XfU!R*H{r_yGF5-Q4GTiN)(vqqAK-sDAC3Imw#T53*B;-kfB*E7) z5A=8IB~yhFK72N|h8Hn;fD=qim|VB{a2p=UD#Ms4otm^${I)SYI)Y=3p12ZdtpN!I zFa>Dd{x$4ZO*RaScvf?^LH*m*;g6E*^Df1e6<$_%5uU3zu|V04MM{ zSNB-Ju;y)k*S6Hc2y3kl&TQ-HFf*Ny7;*+<;?TC{~Vs`jIp_E?kY7})Ws`pi>r9)e!r-4960dT zE@(`cfq4)!M`^YSg}xkTCBJxm>dS6OAvZQ2GQtE!fBtyN!mLy5Nl@T>Sz>jlpEtNL zjg}w#|Fn13e@(v8|KI2s43ut^hBu59$svt&D<}=pAtkj@Bcvo0q?JYx1tf*h5&{B4 zIz+lra*Ta$KcBzh``h>D-P<1bIoG-F`<&~%o(HZkO&Ip}Ix@Q7=ERvT6Px1ldtXEf zoIXnBM}2yfZ9~~^1o9xL&=P9bSt?JcMOL$dhCvzzT}m28P9uCGZ5!LCc1(d=h~re* zkA|Z%gbY|`P-+_+;f(KPP{+?%WhKD`=}@6))p6ZTSBA!2p8ml-`SgydCnvwaAZFz9 zkHN09u)pO*j^MuX`=$OSpYquZm2!!Pn}gC5WInpSTR%d7qK#T#_C)XvdyFwxEmw6r z^&`xkvnfr#^6C2tyskxJAJ#A-qakQ-uKVDA61BKz)5qMm_v+A7#w$}8bPDyNMThEh z)*ogQs~TJFD1IER_C5PYwp-ZpgD{Nb<=ha+{~PGK1LD8bWf(rtmYk`(UQISXssx%? zl+UBRvUIlPo^|F|+#WT<({0vw3iJ|~;kSn;tc4>iETp&(h#B#B31KLQ7~`ICRpZ*M zo~gBzw$F@EYl;OvUWAFVXq$J>Lhkp>arU_#nNA-3zB4;Q$*A06h&K`ZMQcLi#AP7S zhDZX0KnmY7cA8odX%bK#nd;4P};xgZkR@ zbW|uj++tQw=j*SB+x1#ULvufj6fAsZLJ9x@u6;z?BNf-}XtHxAGO?ZL zvyrl$fz{wu)!w6`t<)BLUda(Y>P5Bl=KVmQ_j=i6-{66Htw(b99_Q;hHwLE;G?=(g4b^y z1@nkpyAH^6;Kg;Bi+&xq|m4q7LHOx>- zS_-)nH$5Blmy@--WOWok6uuXORG*fes^lTM6qdbGnY4vVD`~UPY3|)KKkv$0#W8t@ z%og+Yd~pZB6poLmV`ogaLM-k-ZXb7 z_g#c((@N7$3)B*8!B2a7N|YTKGUUJ$wm{TC$Q28LQ|Ecu_Sx$>bEPFdng=a?B|eAA zC^!h#Rd##jX)Yfz)J>0#O1%!EEeU0y4{w;Ji}w>ta#iET`7tD)!g*y zUK~C=&VJ>R$W!=%m`oR5@2=hcdKMPSMVR(}F6Y)HtMQSb`!KDMPtZ;c$B=Hozt*a; zKBz;TyEu~lE26NfFbtGU{bp_HcYws?ls^4Hbu0-pC^eDTtuB?3jO>p!jU@||aVqyC z@9daKe zNz(m#BDCQ;sOQWm{(`Z+u<~usljG=6@;jhh6l`0O_ zjNY|u92c>Ko#KV>I=e5d0I@xY+}NO)zwo#VbW~K1li!ERFgK%A)t}2QcCKLvB8Rk> zfAj%4e)6+|D>opEj2kF5u+JI2sHQ5Mx3ds@L?I-XHnyL3^-&~d0Z(Ues2JI%5!(LY zPDF6prB!L5G7%#?wFJ4fyMy7JN_1ejMUGBm%od2_Z2UJpx+nmQjT;0pkW5 zw=Zh9C9k?;T-MJybBhlPpSmz9Uq`f9O!URg2OKNYJA(S-)eJk$!Z+5y!NTaM*p;53 z61VLCzSfnE`wADxDF+M|ma*Co##bz+EOz`{*WUH}Hc9YT{@KIn@5J8b>uCws+}wnQ z??S1JT0k|P_?Jz3b1&&DTUU&a36kd4%{Z1N4POjX2i!AtVFf+?Wd4xBINMb^lXLoa zi}-%@i`<54y8?Ap{ChP`y#r;WJ%}!3rK$cXnK8xj@1H+^CMPpInW9O}vV(TyRAc`n zDKn=prcy0%&CKH;P8`?c41VBNF2+om%4r>J-#T&cwQoC`23_tRya_6Q8YKKHv%^^V z#WP|g&AnK&OsCf8GcpYoDHsy5{taVW9aQC6<4T68N30VGLAaNf5XYI%1r(#H5WQRY zc;uA?<@aj-Wx=35tcA05M~NRE$O;8ORpJzu1|<*Exxp($ElN+SaTRDfUxiRdoi{dG zBNZ&^SK*He)AKXB>>?Hf9R>}upqnDf`p2IPh*69VNsNZ?dsgy5Kz|hSlJ1`m zZ){1)3aW8{2Y4JJdHfT~Go+(VoPZvAs{?G%sgqD&7dSRwzh%}-*w1$T@XP{~&!`m*)&S5E5P`s1;r^9}Z@kn!zVH`UO1~qTbl{0$0^i=f z2n`L5hyVi5R>BitD_z9ol?gHg5|4e+nLjB>#wzG;f3lZX+BkU-Ih(N{-bN0xxbMN~ zBlAkHPm{x`Dn{I@tya?|M7s}5x``*c*#ez0XGmg{Wp@g=+%6_|di z-GZ-A181W`r>e@sm_REw&^;Mylb~$FRintrN*Z73QevHa$QiX+SZRSrT>We*Cd|2B zYcVy9WKsFAkz5@&+ZnW-m+grbt6Y=r3O6CT`;u;YtZM2!lWNP>-%tp|H#T)oNf0)^ zUcO|Y#2%qFK8E6Di|e@UyagT1?6Ry#Z8&No45zlZ4+soCcL1$v^4tIP1q%oc36OC7 zjTXO)@6^2zOH6f?s>}A0#kE(46t#h`2n~xnMfkvbZ%cV7l6Ir!)RZoNVCwu%K8WwU zA%vjX9FC1<qvr%T^< z;PSI-^?~-Xq-A54yqo;wDpmEcBo~ms3nsjHRrjk@Wy>mww!{>JC4I%_)boXIl(@ z4cV_4duQ$(eM%VO%zNS zvty3H11G19XW+EU(?xe@I^hiSF>jSRp+kYy%pp>f{MO;gE7gI%w}V;@q9Vmalr^9- zpbD@TXUX=?A^Q)yO~1z~OS$hnZIn~%e^o@QlO%jIZSBya)Akok@W!XB4utjHSv&;@ z>-rPLBg)xTdU?F%vCRQq>7U36n$+-cZq*t`Wg*?M|QSe$UUd-ep$pP?fE2F}%~A zEFT~{@_AJ|j;}HoHM@9g-&;qAx-U?OT;i2bU6oci-iw@Vb#nKFy7Lth&=lQyJ3#x$ z1KTUlSeED>$+oi-P8hK;-gh?~ZHEkTqrBDWM84;{7}PT;7IE<_x;)#NQ5%xnamKhw zd{nb@M~h7-9N&ESc^L?9NDV0PM_4sOMAA*$*KDO@Q7fsZNz9?P2{h;LYij8C!rslm z!jfalL-bkG-3eg+r^DrnFi{!Tyon2+b$JU_i~j77y@M~+Nyd_2F0zbgd8W(gN_W&? zKhmR>8VD@oPUVrEd99LWJyq3>RBBmewsU3nJc|Y)pQj%cEC$3XhWO2U4)?L$koEX{ zW47U@_!Ak~sC?idA%259erWWh@+rf3)&idW8_1*hwu&N;S<8R8=+IJ=U-X-aFtPNn z()5#rI9)TQjjxg&H$;U=WByJV+RjY`phdmU_e`kwrYc`+;#6G1DLshq_Hn9q9KX9= z_c8lhp(+=}7SmrAoTn9>GFeGkxoN*Sg+0*}`La#!Z8y8jiuu(+Fj8s-BETa5OG408 zTSIP`*1U(W(ox#}cJ3U)@cV!2li4Qxq|MCKyv!Cw-Ab!#L_k9=ju|+8Ojj@)JM-$C zm`q&wm45NAo9Qul@6I9v4Yc~DcT{M+Ns!0%09*et>N{g_G_f5@=kb~u>frtsRq#gz zC_^Z#u_D)jBAb&=uq@RHEx~u@7v~l8%GQnkUtE=O?-uaky1)C$K?v;pa_ix2*kysw ztihsNN`ZXLu-uLQ$kWt-&^`4{kA7=|^X2Te`+LPFVbxOOrp+hcO;!)m*K%O3&Hp8# z?7he|N`K|}vzmu@HH-;X7hX|q>)Ky41ks0|Rgt<|czk)Fv({X&sA0oDtrvA#J*V!9 z|7S8SsSaK~;+84)>U4V<{c?Yg45@p|8;tk$h|zp4{7J~)-7AKYZ8$`Rf*Uwgg|x3I(x>{k8bs=?k%N~N; z_JvW1+@uBkdjPY*Bg8r`EE)9G+w2SGhd%#_47vuIH|Hk9hnbruX?~OUN_%lfX-puD zOgi(HQzaA0p;v|(X0#~eN^H;(ajI-sn$MiGmAaNj{}Gq}F>|M_nEiyov#yzz38}MEY(A`2BjAau^+DtF%{@3ns3r~2J|$+-3R+)RHwNVIricR{CT7yWSMvqh#!SL-Oz>s*xx&NKUULkT_3PTjn?w=B^>gTUgi`oy zfEn*~v+1_fWx1GIYe@rsm4?{=4$JCV%1hs@un&&jmMx(_cn3)POz*#aeIL9%5AZ(h zRr@F;UI>BQ1azWkT@3P*PBW`$bLtU(Td>IVqearT>6o+k#J4}Oo$0$>A5@w?xItaGU|t3uPD zJGG|d`JW8<0Ap~7nIxF$96{W^tvb6F#ThDMyt5v^nUO5p&}P0Tu1k790Gbph&2(lD z!Jma8oNg9C5+Hv2X%b86#6K#+tjIqqlWR^sGhC-#pOV;j6>9PCOCmXLv1vyz7WaKlRY!3S! zZtZ2O+QQl3ywS^oN%&n9mV^x;8rsc78y#$0nnCHeGXu#!D+3i(+t)D}0H~v7>m{h| z?g#h&>|1UD!rw@VD`g@3(XQdB^8pF2+nK9ZZz_wYjLq$7bqyJz?4h~uLSXaJ_3GmD zJ-vh|K+nD~(MT|Q;~#MAzqMthzc?E`-2;3jU@&$ddKN+AoKhP|^U553s#9A(PT{ka z19OHbMkW$xlg3dK)oZ?0)ADlp;fbXqU1REWMvjV!Kko}ZE86Z9LLo=(2%W3)yx<`fyV}BK zL-WYg_1aVBudbo_NmM2L?E{5$hf|FfsACWp=s<4&Fh=+bz~ITgo`imVb=JuX&}L$W zNGv^Nb;_I`P^}HJR8?Qvh8ntmLQkN7%vdG_m z_XW+UUb*ABLiwPAnT6SQQ}&fBCGH`%lah|8L%Ox>S?h!sxPCp~A?Y5&3;m$T7L>mD z=}rFYgql1wkd*!WXtU_mSW#6q?$6?{*SU5?x<~Cm|GocfAC-MBYljC-ZjJ!}UL+~4 z$p>k?g>Fk_-Fn-)Q^H0p`RGRwREV{PD6%$|MPg`o7`lrY=`1P@;`noCK_vJ%_2XAvG zi6-{OOx?>xI`iRKSCxu3hd>=L_uXf(C`;ZP@}{g5V6rd@FdX`)Tpo4m`^x8IT$D@? zOmbrY{RaJ9CPQ$s;FmZ-{WN`6QWAK&JMSbSEUhD_A98^|6RJD#f>5_T@2Z+Iu(;Hm z!vpjoANOWk$CuRWt}8lOf=yxbSVw=ixu*lhTdH+SqPPs-233n<=VIsUW%(uCU|9D z)zB%wrap(~0Q4D3TApx=b5}V-dU~`B*ShD2l70TyXL;XAYcfNtL|rP2XJQ=vL_!%3 z1Ej>^D5{Ng-}Rm}n?J;t-g2)bFqWCsPS z;s)oe<4H?G7QT}{yG(qP*vKi;vnI>yiVp<}-SRztD*4o(=a$W)Wv5A2!n=#!8uFsY zHFKwE0(S!obbk`K47UEEl=byMNPwt zw=%8FOiY3Fpf=iCqet)_ zx+0ygG6BF7Un3DUa?`8PJ+XUyD64*6UUuhzA_Z83`yIH}FuZ&u1kGOj8T^_q@s-`? z@3Mg6$>2j643v2SFwoM*H+SJBArp>m8z&_(?WWac+XAtb2K5-KS z5LDZuYZ#|8qVg1N@AOozjg`d6oqg@Rba@^(A(nDGG#g?RN<5WV@m9VSz$J}Vm<}km_99&ISb%Af ze`gN4{+RY!Y1J~=CGw1UBc$xwcpX)~gc}boc^dZSgE{>HgQX&bKh4~@dd~iJKec$< zk9%gvzI^VX;l8D$u6QVz8HpwS0zUYStTK-&tyi(-lFJQdWGV}nNlh9Ul&O7egh3GJ zk;S;;>?oAW_JY2Um6War%LyzIzp}HKd11K?F2P9w4;+AJcfc}DQqqO2I0vuU?6VM3 zupbgs`)Rm1?OBfOfOSKKoM8b?EU_O+UR6Eu;4$Wl8`&eGT@>P@Jr4vJhtj@N2=6=l zh%tZ`76kEx1A9m$nC)==wnoqwaQWcL$PRauAVDM}!$oZ_RDeu;x3An|f;%^|OJi9| zLZJi;UK88Ls{k<$hPe|~p^?vXgcYq`;N%;|2&7|z64g!D)jcvSdmOJ=?eWvGsp__Ba0&W(EPOk9VAD}*? zW3YLbWR^QYiKQlq1fTGha)V0*)4ZPX1AZZaOt-j_iONEJ0uB54sKT#JN7SP|Ys(w( znt*>vqyJ-PVJ>#7hE>o?K`vV}c3E2)lLYfKGb*1I&U&IQaq_>|if}T2w*c%{3~2ik==GvZEDqAAv)NUm3VJaD3Wwxhwi>|a!P~JGr1majAZWR zT1em=BJrT94Rci>mi6Ld>mSy&P6k9rQ!{~oKViU}omliyA_^6!A|YXPU+y82Y6Lcm z*q=5`Xln}#9~WwQCSdYIhhh6Mw>Y=IlokTd7ac=e)VrxH_MsnNW znU}-Y2}GHI%iEa%Z}(-{ozx zTOSi2f0p+7vAIx3*ITwU!nYH+iJO0?Hie|4klrib6zVDflbHzFqsuMzH_lnrh+NHb zn&5Ph zH!YN$fS2Ej2_)#mxY;(-#fu`|%KlDlo7FGD{W~&t=awdv7G z)t->a^L?pp$s&jCS@8+DHzEB>BC&7rmrLc%Km2$iNIp)ucG(L{;9+3R`!AXttZF}xv@bPw#k&I$KN0WD zrD)05Psk=_x}ISFkY@e8CR^V-uSN-U!wN6)^WnwhA-PeXDk8(g%&VaJ=*djH)Uq-) z4+h}`pzyW93Ev1xkC6>*G#UImNOv9il4FuIo*^~LvTR)T)}tEMi5rI{eF}+T$=Bzf zBxKxFoum{g|j-h|l*q4@;jE9ZZ!}~YtiwQ`O3j>s~JN7@nWxZId zSI94Uc$1KnW~>nVZ;S#ozJz8*Iqxl)2C7w`@WEau>-0zj%;!qV<1e}y^2-rF6j4(8 ztBT47c|NeyA;)1{Wm5;R_iw35rcK}0`$;6rVYq(dFb~_Ev-lC!D8iLT(CY$L|IcQh zoay8SbF(A`8t1}@0XSqbUVD$C@OZ2E+}4CqC1xnsB>oceLtmipi~A!uy{qKNX@jDv zr*KZXT+$X&Mt> z`&&jB&#d1WXMz|p7;ry)pb_NSLA3)(UPLcTXRE~eT(qY(s6G~nDkmzp_?G`w6l9AF zuZ*xG1%@h67(X0WHk~Jv`}QX$vU^{3MX^UWYFxj{K$!AzX(+)cY55=M?aR{hSXIC? zAMSo<4%jYU8DGiPneslZ@M~0J@?w5S1?40Plw{txu=o_IKUad{c^oT6)nH>#&W#Zh zQT;dhA&Ymb@b2Br1Im}P@;&}V!NeXSTU9cFR*Sz3f?d)P_A*Al|FSOJ0EYV8Q3nNF zsa5it>pKN|f8svEf$N^$T4nnUAhpVT@4aZ=RMq=kN`YpRA>-KBpYeD2W359<(g(o5 z5TkFIs@Q5Wtis>{X}_nOSKRnv^;9#-)YL&qdwTSJHCi0o;j27=`!?TwCZZCDtmNUC zxZ_cgvyV2s{_{8|uDtTqIaal1C;@{=fh$dcEAJ6}Au@~HYJM}*y5EfK?RG8bNl>!XW20_E%-UR5@;m$j-=2!^i$X$sc@C8{5k$9tdkHXmVky^&4(SdprZ z&Hs)O38L`xP1{+R?S?{r(Uhc$C5!zd3CBPQmw41x6s_AKgq5MI?|~s37c(=TD6na@ wj7n0#=LRY4l-xMc|378s|FH#*MlTQ^ksJ&*f<3vwd?1jvy1rWN1KX(o0ncFEg8%>k literal 0 HcmV?d00001 diff --git a/tests/data/test_simple_sphere_light_flat_elevated_camera.png b/tests/data/test_simple_sphere_light_flat_elevated_camera.png new file mode 100644 index 0000000000000000000000000000000000000000..e2a19c837f4d08d25161f1e027803f82681ed6f4 GIT binary patch literal 18420 zcmd?Q^;=ud^FEy5F2&ti+`Ui;F2$j^OL2$d5}*`^Qrz9$Deh1xUL1-Q3GRVl-@LBp z{rM}N{Bm;bIXioH&&<8E_sm48tIA_zkYfM<0Bl7C8BG8H0sa>OfcD?TpIo{I0Fdug zl#%@8lYN?FSm!vZ0iIsr8;zIo0R(-bMF1iIzkSb2#2<q3EdtrS##hvOtH+a!Nol4W@eS>Tg|sgj8R!1Xnca&=Cf&~3|1~901L33n|KxX) zJ8CEdFw-~jKW#Z{tm}P&{b>c9U=jk11R1UZDfy9%_?ZFiq^&8pK{FVxEcpPlGNb5r zDEK2dFcbmY?W*!-6D1HJ3Pc7Nx{FGgB7O!?z}s6S4;}Kc%s#6F>fek7f@~4=PzeF- znC`(y{6KaL48VTaj447gP#F~&z$babg>4UDg!e+Hg)X700Cw{M`85Ats3iy_gqQ(` zz}a3Qb;KDU27p}R?3{)lcrOVAW6HdK4j~8NG6TxNhrLGXh-P6303w;!FJN*2HX#6w z7FLJ*2auc(*doJwUV_WlR0mv$;XSt@+5-@n0TMTx%;i7{;5vU~kOcNRFEf||c@~$Y zJY?;k^D#;v3hgspnC_0CIg+d4MNSyDbVwEgvMU=ctRL5YORf*elYsWQ0No^=A_TuL zWiSnkft#@1nG|Tmrw~l~Rrjz1-Y@53W59?M!{}Wspcbzfq#TiU#0qcLzDOCEgFnp( ze;QXISPy8@ve+F+=G1bbn_C zGQ;UmhxdG|5WGy(^_>I*W9w?mGK^}#!8a-&z=lfkA>TgGa+FgXZ|Cu5Hi=3sszvxmj-d2w(HIxfLzZLfa2P8api)8VWxBs z*c#a;SSBXyv%WY&yr6T$#4w}qfdJ_Pa128bmH{T-U@Y-y6$xN9HkSw(EpG&t56}tM zrR)RaTEVsJxPJm$J$%|qMIFG_?|{AaHQ}>>I(&2?3=A{u!&v}-KT{J9GkADg8|4s; zYlEx~m>5vPn!(PSp34Uie_bLm!xJx*m4w^-Ix{o4B>w<`8(F+$PXfq}P{-aUbDvL) zfXq!}q63Cp=ySj|5U?dD2~_?TgwZE;;iiI$%uQgT4wl^}$in=;dtp83eHa%7DoDuO zs3sr53tJvp$p3c_{DRguele>8MCQieqVHq8fW*PI7SOJt4v?Nf1uo->_erV)o-eL| zjyQSSXv_crTrH7j(4TR@Ienwc*#DDrBawh;rhB1y7lL5>O&8lN;=~;u~53C%L#e#s3~qFE|IdB45x^{7(WHs~ZJe z6*8@l_`f0G+K>!K8M6d#vvs}%8c>jr5BJ^ba9m^9@aeby{jL|-v7wN;H#Y)!LnCEi z0TDLN`RUYW4}ckhFjnx`ivaJ^id@**R+oHGpV+w2y`WhhL9{i`DGBVeUJmt%t~N!r z(UK>`?Rx`2ID~+0%)YntHhqPR&oy) zhxsu(`arYODNry9E<}a_*u70JywsZepA|C z6q=nQ9N?fx5F7lva2flHJ#m!%Vvy51x5XqhyN&D2V`e?r+q>+v2IbCHuuXl&hkt|H zwjJN+_Y#PdaL=ISk%q@F9~T$51P?lfBTc>;;W3BmeZgzPlST<~#V6t9?ezSX- zo*I3t=DV&JQcIoP*B06PNP{pPm!J4^n`eqU-R-GgAxr?*{rx-YP_szC--BF zt&USrphYbq)N9@5wrfh$ON=4(3-_4^tU`eC@7&E%CcSu?R-EX+4p1;|TP?sDTLo`@ zV`Dfs$g8VN^`N(PFrR_eyrV>@hJ^JrsA#epaM4 zkjB>$gVZ(H<9O=?BFTRpQ&{m z&nNQ=KTfzAE_&bcz1;lLHxX%>9O&)I&Pt8qXrhSm@o*%qT@%Xm;9%H>I+nKVuM)UBj1w@e9U@ok*SeZX@>kPIZCxy^p5OWU-Hwj@?- zN_crWP2n$m10^1v7+Z~$0IH3H6H=ATNU`VL%jLA)f+(H$4n{HZpWd&TayX~Uh;%%$ ztuNe5h%q3(wvdh*rM9-SYjCHnX#F0a>YtYccGSLM!c7BsFTWBt*Hk@;hJISE@?>nGkZ12$Qbra-oZzqI zAA>KFk>mFv!{a!NpcBo)Gs&Zrbbgc*fv`_|b_sG;L|)BHy6TeGVW|^hO6W!CEwYf4 zgwnVgKWYgzKnNtYY@dfT50RM@IR7d4mUPZ5Rgl~@6aBNXoIZ(j>-jI|*;BN6^Sz;p zLu^ii@9fT@3~(qGF%CkOIRz=03WqkPPrNtE-@W0cAlzqQ8`W`;;$avkSN&nd(sxGB+yYhLa`gd zJ>6fUZ(5lc9yo{~>K`kQ1pT$*gL;Bk*^pJ^LH@!2)ObM>%=eb%rL}IBt0)K6K^5~r z-;BPE6uCJ9k2V2;wCy!gEU;5C_mHFpd*h80x21`K7a8ZnvVVYRZ!*SDy!8|s8o#KPN3Z;OE#s4 z%kme$B{Iv2k~WZ69wyR`EckXjEHrrw{jF^k@w|{(QN?|(?)hF3!$T3u`+pinEWo7`U)RV|9S#_)AJ+>*cmZ8IN6i-dHl#i{+{mVY}c)Cr-TbX9+ z{q*UUlcrhVj5J<01d^w#;&Br1}M&{1L!HgX}+y0;wu?XQl z9|2Ad&5+peC+l?yv^EM@dL$vXW=a%72Amcs92`=9P=&l^E|^eY`sqlp;qsl~@<`3Y z9nqp>e5w@T$-R-U?7^>perhyRk(FIX1g$p3&v%CzIK90qt@B}BI#If=;(XNc-w`)# z@=2u-s7h{{pXbjhbi~Ru6JZ8-Z$X+%U&;UOM%mQu<6P;AGKa@?2Z|lqU&>B9lh}3G zeHN{{OFC0v)d;1*iEy->&%&Y@Fu-P;C3k+kf22$p<-KzYFsWVUP!WBEXj0EJIeqre ziUk>D8LAm`nA8k$p83itZR%ZWkcV}h&8ILd-7Gi-_Si-)z@Y$hfFf9!GPA2^>VmEp zb&eb3f$zDswO{YDYZ{bTxS`3UpKZampfr!LjT&}2ET1aklapc_CMB9KzfS!sMxEe~ z2K@@aFvFq30n2v254-AUFalws}qm8JYqcq_qFrt`7mbAGi&7tzR0)xs~U6k46K|g zB8Dthi|2xuc)u!A9{aev6rZv`|DD>Fq9iU(i;6S_681H+^Sn+7Ctick>Z(Kozc=kW z#x%5zZ{6A#))d_jVX1!e_U3L-Bi^2GkO-x*{(2g*=8nLXj7skxh8rQh?@LipzjyXtVNW}&k(XQedyj4co#MC~-hGbf z_Loc73x1j$GlQ=I-N)S?uh_}0C$6=R$9Z~IpLX6a3FpR+FMP}-W+%X%ue?R0>bz9i z?7rEIk2v0Z_f2;)2Vduu8efGU2O~csD(_N8|NOn|!aM~TR(>-DAFui3-%cZ|?(Gh5 z;x-3r(^grB>5UCzG8cneTxguET(cBS!yA4{RAw+Wme$CYsMB<}eO&f6CEqOmDV>UG z5tSLCi^^3~)lTboj3gA|Mn}>rS)BoeP$;_lCSL`@%?p_SN4ST9j`3{)glmiMzKxOH;VKllZEHcoDa2~ zmIBL(OCQOGpWU9b2-_-DmkyI+7X;JhCv}cL`d6DMP1_(5O7Kx#`9(+4G{2JW#_H$k zf2klde2Yzf*c2u|!Z{x!$|m$F#+2+vOr`ThlI?pQe$FVM5DwvZnTQg6duwduskfi( zh{gkR=ADxQ=k*N|xXPMLY%g`PH5=-ej4&T=4tdI0CEof-u)z`+qABTEGA+L)O&{|kFFO1i} z#)1(@EeNgyTY&Rr$ji=0iR?JJhxh%yHfzK*73)25fqyrSjtfrb211`5zN)B*PM><<~Xe5ZQutc4xaOlGdalzD}Nhq4O9$OnpI zN+sU0}dToIkTjP&1TB$JI)YhZ@#E{Po-$8HOJn3z39v31~%P zB7waslNj06Hlv7=bupvg9EUy2X9N(L?xf&j?*jm;j4lxp`(fP~{Hh}iRw zN)tDeaJ_Yx+@r{KC<%#wHN(GHH+(&yjU+~rtfqkuW2(3t@ zfM2woi@v;BU6jD+5xPozvFQ1l>)iNyLhmJ0)SKjdz95IPkWT-~DcWa^UZcLK&xt1p zjm47${w_Fqn)-3Ral0%CO(a+m3k;s5_ICTZBpidEeokVtuoL)ZeRtK#-sxB~(QC3b zu|^bsn9wQl0bP&Vot_zxQ5pF=pj9i>%Rtsu@mRR1v&^O0-U@H4U@*f-&^>Brw@DLz zeoFmAAdiyyD~S6dI6%bDR=I;zLV$xALmClzW^k-huZf)2Ir}knO`||dS&Zcl;W9U~ zN~pju1|=pUd203V^8Jz2MWDaz`PmJLK-090V8-{)ojV2NlaqyBwGSEJBEQmP%(Qim zczw_FiWNLzo{0XTfE5kRE~zVX$(s=LEwdqX$q~~ka5--s9`4q#wtM!x6tr*K&{53@ zAe`1)q|H@|^tPqi@)?{2toH zmaLm4uXsN<&xK2zrD4G?we?ZH#q8TOzlhqu>8$Cmbe_^NYv;(MK-R=H_N}V^}4LsIoz*#w{<*h6_qQ z?(R2m5X=8(G^2JjX?XsDZfJ7>VkYdK24}g1y}iA|5ZXjt=^f;f#x9^YN{}tC<=OPc zD|@3}oie6skAf!A**D@sV)!pfkE%8&gDg#4FakUpJH^wK6V;+YPaMnA7$Mt8H_g>X|-!)F)t2LaN~Jua5Oh_&ZJ@x8W}6> zd8z$2YFN*_qD^oy55(x3@T6OF>!_b4>)?}}`TcR-bYm^CX)MFOnX7u<)aV$K0wvp? zaDVIX&kXUFEZWSImcBif?k~OWQe?uCzz@CyLokHaqpZRCUA#boPP zF%2sjcyvs%a zT+y2Byt;zBIa=FKgY;QVT!a#*=KaT(N`n~Pwey6lH&9iLWrnsf=is03#Qg@Z_CI_f z!0f&rfjpsZ8c?Ij0eCU>4Vo1K4#vFCs%fSnuqV{Le9^U*?+gWweYMe#Q0u19QI8IR z44yJ&EU&-i?pXZd(mwUtZ73M|8KulYy~oS-E?9z$${z#=_zl_ks~2)(a|3N+Q4^HK zs1m0Z5Dl6dVCR}Jq=}Y*=a5%9`sf=YH_LhZCf^Q<>e(UPtL0|N)=%>7pD z%Nw3Akr14={z$MFInl1mtmYpW=TQmZp*ZN%nev>7QJ;^*bx-=GSpqKI5ze#TXXzZ!wo_TkC6;w+s5y zlQyd6ZJ@4)L)hn+{38&?d!{;M{%BVGlcR#2LK=))}MxqSahca2y^$N?{HY!3G8Qh9;@;V6xZccQrTcdAVNpfB@w zM;g|s?8!$q+^o%&L2}V$5i9<2Kg}2fDC%b5An5bg)9yEKr_Fn^N&T8$Y^qHJ9C#Fl zkJLh}`b*o4@2+B}6EXu$ir_%`A|Gi1=(B!$Y4<`C7O_h=k{aPYVa=`@U0e1N>zA_O z7-Ll=3Jlzzd7%i0#b5C9W$R*728 zRWQ{FW2&F;GVfYV;@#4FQ7&V{NIq}!*tAe_Y2TvMz80d$=Draxj~|~`(^LBB)q7K~!g-aPB zuPP+Ve?84k&6Y4v;*c{kdfNxyXAz0m{>fNGmEnnz;@1%rRFy3u?1%GzQVfXryF{m= zEllGc(I5MS)c3ps0Trh%Yk%dFLRo&CjORRyf7wAdcWn9TKa?lf3-S3to~RMF@8Xu( zhWhG@8!>q4!*BgNZj>;kZYWt~LM!q%(xoWL$?|EJDR=@0AP3+q#X=CVHMnrFQo-$1 zAne~rDu9({`Ht)8nvz9r-V#EU*5ZZpZrXz17-cSU@X6i&tjMN7+&eY!BMwE`2*zaw z%k!XER>}M`NdJDn=k2L`x}vr>vfHp-Q;$8qFuyXw6U-+gmG{S1`^%1OJJD{)j}d7? zlBjB(!w4p%fx0+B5rI zk7%Ila4AG9qPt7%Z#Cmo8#O$*4$;RKx0|dP<#I}*n`=fwzF>w`EGm|A?KJdxD&cja zycWM@VpvJYk?I=EUF1Knu8^9Qup+6}$>kk(CQv*5sXXht$$y>i_pdeNwJDXq9l==y ziy!Y>$V-}k#B-!Q|KEJQwu9sp^T;5AR!}iK;B)3G9Baj{5Vut0FUEx7gc;|VdTR>j zgz@qkJ?8RgIJn!mc08z_Xsp{sPC&7MT*)WvRw}Y$nf}$B;DPsO<&8H_btf_}rDseO z<$RB&Iv!eEj+?REMUUszl1AwHF7N(?*Qc2P2qjcWAEN6N+87Tx7lgX2SPQVIW(YSo z2=8_`Ma4NB-kvqt4R6FhV9%!|xa8QIfew4pm>j_A)oQ9JY^n=0g z57VpqPBezy|5kMa#Pb?nkw}40;N?+V&`xj37d=YIW7Ag{;!|0ycFyyn_`&JeZ-82K z@mY6G%348{Hi_-ANtuIpD(8*N3N+58ayhQ~JnTS;CGcE7*@+B&)dzVA3uDXw#N*qT z)ZO@^0$mM?4a=mmIp)sdgQ@5m=>5|HDA=t|Rx2#0=={vp-AUIzk{MM_>+GkGtQT`KevlSbg9^WCvJhRj?u(3% zN*Nnj!*8P<-RpuKf2HFa>*M70L`$R zs2up zXbALajr`_6eM1x6E54>TNq6K)XGQ(K21up&wXxA)&%otRYS~Wtgt$f&s3gHFIQtpu zO3$#6aPJz2LP~MUA*b|#ECdUgb{~OPuv9uVKSi28e>(8YVSLwVV}~s# z{gAeZ&sXoo&34wLxF-Wk1?q*A>WVpd!E9sw=I9h)LY}W&Q=hx}*J5&w@inD|h@xmT z?P&CDSoPWNJWXxVXChs+M0eLE9!tk3R!<(emp|29^>-d&T|pe~V78{;?}mL@1K^-m z+IqQTjzo)R(te0{=iR?(^kL5KOz3cmg)rtu-yc517h7*0gz-(*%UJ#JT+ZT+t#{U6 zZ^fWEW)yP4U(ey5eT008OtVr18<>b#Q?IOT zqSy1`O{9Jk%vl+$Vo7!_-Eeh2Src`jm%X1&{G?jtitga=H=mtJ%fIg_GF)Oti%g)d zDxwfu8ullhdFR}y`kMG#|I4Kh`{;|7l&l!G!nAnXl@EqfP=*RL7S(A)N2#h?B|0Ty zO}NedvHoUmN1DK!@%*O9OW82~x$G-GSmL)mMbMB${`EtENY20*Eeb-4@rwYq zBz&6~r?%a6tZS?+t+1!Y{IOi`DZsf)_+r`dMec1pDE1;C0edv#?Sh_fsQ{DA3Q8Qi z2&cm@4ASkg2C{ya?97v$RW8q0*}J{F72M#b(cpj&)J46jn2*1LeB1}MUx|80cxgv> z5t(bcE%KG1M!@N}EQ4syLbMA-J5gAoKfENq;WxoQ(ejI4%_1EhmuZ*rU3l`u1NlKC zg}Q34d+ny7^UvPiI%>;NTl*M5nULV4hAJ`|#@5Tr3-0R$E{!CR0g%zIZcmAt-0)iB z!y4YhXmsNIxQ_4bz_iha7B!3X{HDi3%}v_R9z*Y>cg?BZ(q-pd^1OCmnb@W(u2@9D z3A!uT=cwdM@Ii@9mqEOR#ZvMn;@sN>s0t4)(GpxCgm0?c1!}aD8V50y?8%^b40z_z zuPyvzqxL>3P4VG;OR%B>1idsR_}Og1pIs=+W?41s)|DZ6pb@?b?+orzO!Rxqn=wB?^ ztxRZY{G$Df4UbACl#cN+@v^z-q#f8Kn3zCxfIK+7L^2nf3RrwFSfPNGJ| zko#oYmwI2gWl{hw`|^ynrrV)et7EunA}TXP^JGbhXQTT`S|?>bbV_m|5Zt5T-%i6<#*jU@WJu^|=+`!joywNK9x=y0_%9VF=aZm8R?wsbzjx z;OB2Uzs}Z@|KO{{{>e~$jjtS#J#EcuVe?w0TBZnwl?9XcRN4MxA;F9h+xheE^Q#J` z{c9;|8fYl^l%;c(xbvDh+F(|jcF~MkL#?{|l#lzZYDctw5+a|@q3LP&uTmOPQ$9`1 zo#wd5RuV~e!!L--3}y{K8p#%V_1En^0Ye8ABRdDz573}<=y(+v3+^@#P^4z+G|l8I zeCXS>fliaA3*a1%(U2sZpPf*+!L;^)Lro51fi^}T7rS`3Y@+A&vhz+dY`AUpFxk^q zYV1}69C8eF5>CK>)s3iV3r&B))!UFSaxdqZ}sf<$~GUV!3hp z>2Vs|d|jlH;ubRMtfNXTghC@_l)AyOcELO6o6j4Y99$S5Nr67^b5zxGKj4F|EX~Z&D(( zEj2||S;cpJQ0KIRo5%;0Ouos6pfG?@dM?7biW$I;65rng8fjL?`mXE5jK&-nS9!E) z`nKc}MtyYDHcd#_&B}xn5@tturZO#`GxyBT#_}Bpho*Ym^n5|`15a_m!v)|whU1rr zV93j(Ins>8iA7{*GpF8h;g1FS!fCVTcWS38i-x9MVv@^byp-q);tB9rlG(ka+@Q>5V41zH{b?zbn$ z<6}3DMl1O1Y?cC;V`rwOqZ9iuivq0RV6KHJL3`2V0Fy)g){HF^_cT28V|Fp;d@IxB zsLGV)_i--V?SDJQ&zFq=0O^E;a)r+gNRB+`FMW8c6s{{?u+2!Ssd#9{Y} zUd^W?@b~^r(tAIA2w+AGY0ND;L}cD5hVYE^;)|MoU-aiV#(%smk(Z>jvi)LDK@x}a zX3ueT1w@_jG6aBsX+^52j5vksQBJDUT}ZFPhR38mS3J=;LCX8Ll{nzWGOm}b)Jc(h z3t@HgbDBxhS?r6o-siPFUhgf-r16i^36qnS&FizYE~dQ#iKpuG_V3 zvwtfx#z$Ds@xdI=tP1Fs|9$b{um3srHyD=O<4w-e%)CWk`h1MZunTPIkKt4N9Hnca zM&##S5i_Kj!Mx{{yzmpAZDbYrqdJAb4lI|TuXj5J+n9Y5x3Vif`5oQH>H>-J>%#MHZ&ZZ8^^*c3keCPMc=*orPw_O>22^`Tm(dWhZJq5ew=_%T z^|RvuLI31~-ilrYzZ?Y51@EHeJAxNCy*u<6sFO4n>awV^hqY2U99szOS~z$RxU>6K zzJz0)&fdHpRn?`EU{aF;ZwPc_lsg|X07RfGF)TDJsEcR2wif!Z@wd@sujdt5(Jd!@ z&)}a2uBhiHKPI-b*&}E`n`?RAP|9w%G55zcQu)g+WR?DQ6q~S*oX6)q)=x#E{|2b} zyL@a;GX3KIFqhUt*pX^s9jW=f#IThLSR4T-B!ujp|9RFiI>gT8m+lTr_WkEL36TXL zNY1>TzrNJ98xQGa7w`H^R@2l|p-jcC{`i>{3-~CI;l$zr`WiyM@rthwD0+Ixd-e51 zE+Z~r$5Ze`8~v!G4p^NtplbZPZ0mzO=a!IuqCgp+eCtCffd^wb4}N{55%aq_>pZtd zW|#)OJ{8F)yue4}TBAV0n)baxB; zA>%Tugrw8QCM9)SB8zjvFgt@P!84~HZquNc-eiy6p(MXrPAWP9j@}Q6J!w|sS?oa! zkfG+A-2td5GGQN*N#N~*=OKBG#2%^OQ4Q=>S9tM zMy=xL=Rs~$k?iPsRJppY8AOQVk+QKqubtF1^#K(ELROaW7?Wae{N!fwb?}}rq^2;6 z7y4xrdj1Q}qJ2KTZ+SgTc(HwY-Z&3}r?2|Z?ZFarbNx5Nk>COJVdD^dNZpqQ#GPu$ zA?lepOYs&`?uhBru<6(HO+7*0z#6`YdJwY$){Li!k0D}EWur;9@%lfx)}*17Q5!n3 z6`o97THz4BpQdjFBq2yRUC(!G`n$U>lc4M`cJrsJ*q4r;`|kBNC_Y`Ca>eRoj`p8Q z!k?2f3gh;u!>pfg%0R|~=2RXtaDSA4us!ZrMk6% zk+-6spM6r*JDpbh_4RB%Xnq3xrzuW9O|Rh7()|O}09q|M$e{Xutq$x8_Da`c8){_0 zkchFEYlV}0+Zr(RxZXtwz}lg5I7#FFt(S4^5b$jvkL9RVPSQ&mu~{pC!Ky{54* z>o_d)jAQ+$tiOCl3@x_~yA#{}vMzEvsRAF727^;i)l!+IPh2Wl)|Bjpm!HShdv`FfxzQNF z4*~mpdHAN6^*lX073mbsw`!xRhB{sv&x(xMR^6+EO-O^y9IjQ;A)W&LiGCJJ<8skU zbJ5nyeAFR1^0pcVoBwVMBV&S7ZsBK}X}GKj{)bXYAo~DYUDumoEs;NyZ-cOiTA0+d z)4~nVG{*U{Y2~a;Bg@9W+R1$(>eCE}(MKAJqjvZ)3tqlIQ#7pcc6su}wvZ2(oAVZ@ z_!>ns0r4ETZ^DBo@!$U`6DL~N(RYt@AF+h3e9%#Fpk4xI7S|)o=Rwv%gogx%3k4E- zV@nxgjMS+dNSoe&6-G-0kl9x#9;v2BDNl%;nD<_1!wwD>uJqP z<~?;>Bk^vzFlc^>=U$~pL_{=}onT=b0B2Xvj9mJp`3G!?zTF3kY8&|#3lc2wos|Dj zX-aQ=f(fg9T*|d5gcwKzBa937`H3S-{*r{__k{+)=JVp?t1lQlkSQ2UMQx`e3D^1h zsT;HH=qD){s4wo;Tx_&_bNrQPUS#I{HI0MZJu3T)R{!x15@)`9QtTyXeCC7~JLro) z{8EcnTyMM9rDCXgWNHMP-k$A&j*8 z>16hs%I2OCNDh0;02AopSG-&(Jfm-?>*;F3+>F8q0pghEDsl2MM8A;O*bmzU zi!`MyeayC}eJn8lL!8CKTvR1wMO96TF?9b$74S5bw&Kd8}C@SAxcyK^ixC(+8VRlbh^@2=o2uU2iV; z#3^yi1$A1z^kPZIbuw^=oBr00zS&5w`TW*l25mK)>i~h*%|6`iH+(?|Ji#0DYEYEV z>QvIpYIGAhOw?W3y=}1jN&;GX)hg9}B+jxczk#3uX0p+9du{~1gx;r#UaXL9yUfE2 zp^le1Wyy^7>3xhy`p!&VTi0C$n@#cdFUCHnA1Cl7zB7*Q&h{QnzzvRKKPQ!QMHr#H zHq*hy#XvI1Rc}jqe|MePnr%G0&n^v=bYM1w{~mhn4TH- zj*VdLvk(d1S#2p|7TJr;%y`paa6SywX)oh`nnJcsX2Bdc*0yl5Et&5Wlj zE5~{CwLf`2T;}X~9x6-PQ*Zg`8STHrK``-ylr-}EmtxEDu)5SkjAT@gmNJsZ{5*da zASgU06;>G{Q8=pItm0A#XzZ?eI5Q-Ql1tjUb0G-6B?Amr-4(fsX~{yxIeB>AFl+ob80XclR_6`mf$q^odgx={$F zqq=qy$ZAuBl8%zTmYzGo^t5X5FPNe=Biy6Y6$y637mP^ENs^5>m}@pux&f zI@l`#2(7NBlE|q`?bc7zkvq^{ca6c4B+7}UM^6#DXCYLA8P5CT-k0P0R< zG$l0cPfNQ)Q$!A)~Ge^kw`t^TADFieIHYGp5 zAE^q;ROb2?K@s~R30tf6(ZiMl_;1&}9MuhIUfFGeYN?ZOFy}Ns*M{=)D`3X{!ePrQ zUZKy%FpKtR_Z#7rdnc-iIn?a-&BcDQo)0-r)IV$L5bGXaIyV+fzL`)G=spBn9xHh`MJr|Oq2aScJIW-PkV z$X!2P{uvf%Cw)W*paF|<2T?xl#BfE&eO2j0@8U8ExEvOa;T6ao*$d~BFXVG6ck@Kx?%AY*{^eRCAy#&(4gH{3W!Pwm@jx%l9)5*m>%y%J=-shJ?VAmlyH7;<<&e;eW|5D+3{Oa~tM7rkOvl zty`+Hd=^&L_xo2%XD;)_ngZrG&HFo*VW6O=`_bA+UX*F5F`Z)W{PLj9&6SI4O)k z5bS57oRg2y3fC%{NhTZclCjqMZF@|)iw|{d;n99)lU-#=#C>XDaBwU_Vd>KB52)b_ zvg|lrh_Xy!NGo6)fwF&VN)TU)VclRQn>x}NC-duY~awD{p_PpHI!Q)ttDi5BadL7Z= zQ)13A`Caps)Hy!((t^e4wVY=OM>+lV$hnpE)E^c$i)?FTqFWHoV8(A(W4~>aG^Yqz zaFx10Hb@QbQ9RP`eXW~^cm@Y2>>X$>PcR{=zTCOI3Zm^P6BG{^%mYVy-F8LzzRK!@ zdhTKfYRsVRWh`JOk?V=LDe}_k`vrHH0}y0~C6)YW{wObgh2P**uf`2DCZ%~8lMj6m zFjf<|n;789b*1GfD&}hr5V7zaXGe@|s@$7O*X~?hF3BXCx`2`cK|9CZ@OQZBh1Vja zjmm>yvyYFD+*s40p^d?OMNauoV^xN{C@oS(k0pv*tr&Z58h>@eUJZCr!s1UNc6`TI zlzwze?X3}+_L|*29Lj@4=7N@ylqJ%};R>19fz)U;IV=4gqsWXQdehf>6FQCOyvpo{ zkHQog)q49|gJlsU#fVXLn!&rR^!xS?Il`x$%dFBHmZkeR~Z|E}@yqfj% zH8l7YI{WT>4Li-oK&Ebf@)W#SEBB_5YK2zDyoPGM+pXZCJ6Ghn@vnYZqWM4`s+98I zjExTE^%&yU-|nO~54oykrq2teL|mx*_0MI`ohW??=1b<9Q6hzg7lfo?HD}_B3bj)6_!F#$v z!)~{(zox?Rev&`B%Mq#Y0?H@MSx#rMt^MK*7>~KrTpxNJF$l+^4wY+4Z%`a;Zp5Q6O_$~I^j z1nLuRRBW?DeJr|{ov!EGD{FmM!zy=gB`IEUc>ss3+2CQgkz2#xBb=0$FfCQ)r3fyIC>Yuzcn#L)>+cjspxH62!J zN>NleiBA%INiuvlC=i0sPV@rFdtv!a%y|$KH9yqoJ6!0Yhj~6b)UiKN1M}%*LPVWQ zQO=H(-B(4dec{v&XVPO9Wxu`|id#WSB;6siet5a!dTsP6as8YRaE@^ay3VruZo7J^ zR)%KP{mA?4>z(um<)>3hg5=@TII5gLP!l0|c@tXxl{Z%YDbd%v>9(uUb$WRZ7ZkKj z64dX;wDl2Qu1~-Dv@c|NmOqD=yb@Wcrx-YF0X2nh?~(q&sgZ1200TW$Un{b)^QhVA zV}Wmmi3b0#oOAI?`U>FqPsnEvtkfF$ppk;+IvO|<3DA?r2N}8?SyDKxo_oIc`@P@K z1<JD|gHAfhay%6$vG0b|_E8#`QJAj&R6h(PO_$(;i(g zv8&sPOC4rn4$R$aZ#UJ%=v_5O#)F`GLpwuxR7 zM9J)mEqAcG3G1l^=hoEot_^Hj(uUC5?C#ktCPO_1EXDK&=UvWr&_(+C8^DmAIsVqF zUm3R_IUkkIcA@%En-_iyl2qe9$jKFTjjwaBjbcWNPjY%J%p&w?Js>Pll+g-3V|r^a zKjcI-3@y#A`}tCm>jd~7gODMdFT(-IXF@Q#WNh4IlQfrYBwY?9+Tcw+QevJv>eGoKW;6F43+j13r1| zA$$hJb8LCi8m~Rp9YQ(9?{TBH) z{Z^fjJq^@g@KmU#W)j^gGKf$hiAg`@liq#if8Q8Iq{ zVZrDmv*k9))5lmkzIQ55xLhv0b<}~n!3NSV>ss5L+PZ7b?$OwsaM&IvE6XZQJxCPY zjo`o2nXQ%}zCBpv-Ab!e2!=lm@?QS88|QKX`QqY~pTe5o!w!;UsmidOe`>pqyDxA` z>cnVBQv8u~<<9*>@4I(B>`W(CWr9d=B+VB>Vxh99AbV&JtPXdec)UQZrGyT)<7!*D zgT8;bb8l{>O5c1j*H8!_K*!}_3Eiv|v{3#_|YmvE%Wv#`>+vCAm33(x-k1R&!@=vW`nPD)PkldJ zWeh5E&X~`2nMv@X?!-b-0wrc1u@=8kAx6Lwjo1CD{BFW-t>s@Y$hr*BGKAP6L{FZk zXMpAJ`;;d5GM1x>bB~ri3B?~4GF0A+wcP5cc;fk=UXX#Xg>+rfDQzb`gV1h8wA?V- znHfi$NcUg;N!8b$`r7L!i#@F}hODvBkb6p-RyLm6b8kg8(NyqVDOK1XBWSUIWSl-1 zpn6CKe1S*vE~W*SVJ0rw7b#f^_TXhpcF>5l_n1r-rV;}K;0gdrF==B=_L(>8-+i-l zP5G?1jd}So2YRw!#B7!l?o#%n4|E&AO?FOEOd1M~6X3HY94|L56`7W 0.8] *= 2.0 + negative_mask = bary < 0.0 + positive_mask = bary > 1.0 + clipped = _clip_barycentric_coordinates(bary) + self.assertTrue(clipped[negative_mask].sum() == 0) + self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0) + self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N))) diff --git a/tests/test_rendering_meshes.py b/tests/test_rendering_meshes.py index e6c782d2b..93e7ebf75 100644 --- a/tests/test_rendering_meshes.py +++ b/tests/test_rendering_meshes.py @@ -25,6 +25,7 @@ from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.mesh.shader import ( BlendParams, + HardFlatShader, HardGouraudShader, HardPhongShader, SoftSilhouetteShader, @@ -99,8 +100,9 @@ def test_simple_sphere(self, elevated_camera=False): images = renderer(sphere_mesh) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: + filename = "DEBUG_simple_sphere_light%s.png" % postfix Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_simple_sphere_light%s.png" % postfix + DATA_DIR / filename ) # Load reference image @@ -117,8 +119,9 @@ def test_simple_sphere(self, elevated_camera=False): images = renderer(sphere_mesh, lights=lights) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: + filename = "DEBUG_simple_sphere_dark%s.png" % postfix Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_simple_sphere_dark%s.png" % postfix + DATA_DIR / filename ) # Load reference image @@ -140,8 +143,9 @@ def test_simple_sphere(self, elevated_camera=False): images = renderer(sphere_mesh) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: + filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / "DEBUG_simple_sphere_light_gouraud%s.png" % postfix + DATA_DIR / filename ) # Load reference image @@ -149,7 +153,30 @@ def test_simple_sphere(self, elevated_camera=False): "test_simple_sphere_light_gouraud%s.png" % postfix ) self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005)) - self.assertFalse(torch.allclose(rgb, image_ref_phong, atol=0.005)) + + ###################################### + # Change the shader to a HardFlatShader + ###################################### + lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None] + renderer = MeshRenderer( + rasterizer=rasterizer, + shader=HardFlatShader( + lights=lights, cameras=cameras, materials=materials + ), + ) + images = renderer(sphere_mesh) + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / filename + ) + + # Load reference image + image_ref_flat = load_rgb_image( + "test_simple_sphere_light_flat%s.png" % postfix + ) + self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005)) def test_simple_sphere_elevated_camera(self): """ @@ -287,9 +314,6 @@ def test_texture_map(self): materials = Materials(device=device) lights = PointLights(device=device) lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None] - raster_settings = RasterizationSettings( - image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0 - ) # Init renderer renderer = MeshRenderer( @@ -327,3 +351,32 @@ def test_texture_map(self): images = renderer(mesh2) images[0, ...].sum().backward() self.assertIsNotNone(verts.grad) + + ################################# + # Add blurring to rasterization + ################################# + + blend_params = BlendParams(sigma=5e-4, gamma=1e-4) + raster_settings = RasterizationSettings( + image_size=512, + blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma, + faces_per_pixel=100, + bin_size=0, + ) + + images = renderer( + mesh.clone(), + raster_settings=raster_settings, + blend_params=blend_params, + ) + rgb = images[0, ..., :3].squeeze().cpu() + + # Load reference image + image_ref = load_rgb_image("test_blurry_textured_rendering.png") + + if DEBUG: + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_blurry_textured_rendering.png" + ) + + self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) diff --git a/tests/test_utils.py b/tests/test_rendering_utils.py similarity index 100% rename from tests/test_utils.py rename to tests/test_rendering_utils.py diff --git a/tests/test_texturing.py b/tests/test_texturing.py index cd3399db9..aea04553b 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -8,7 +8,6 @@ from pytorch3d.renderer.mesh.rasterizer import Fragments from pytorch3d.renderer.mesh.texturing import ( - _clip_barycentric_coordinates, interpolate_face_attributes, interpolate_texture_map, interpolate_vertex_colors, @@ -94,7 +93,9 @@ def test_interpolate_face_attributes_fail(self): dists=pix_to_face, ) with self.assertRaises(ValueError): - interpolate_face_attributes(fragments, face_attributes) + interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_attributes + ) # 2. pix_to_face must have shape (N, H, W, K) pix_to_face = torch.ones((1, 1, 1, 1, 3)) @@ -105,7 +106,9 @@ def test_interpolate_face_attributes_fail(self): dists=pix_to_face, ) with self.assertRaises(ValueError): - interpolate_face_attributes(fragments, face_attributes) + interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, face_attributes + ) def test_interpolate_texture_map(self): barycentric_coords = torch.tensor( @@ -220,13 +223,3 @@ def test_extend(self): ) with self.assertRaises(ValueError): tex_mesh.extend(N=-1) - - def test_clip_barycentric_coords(self): - barycentric_coords = torch.tensor( - [[1.5, -0.3, -0.2], [1.2, 0.3, -0.5]], dtype=torch.float32 - ) - expected_out = torch.tensor( - [[1.0, 0.0, 0.0], [1.0 / 1.3, 0.3 / 1.3, 0.0]], dtype=torch.float32 - ) - clipped = _clip_barycentric_coordinates(barycentric_coords) - self.assertTrue(torch.allclose(clipped, expected_out))