From f8a1e6c7e7171a1ba345a99ca3f1de8435568f45 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sat, 11 Jun 2022 01:14:52 +0200 Subject: [PATCH 1/6] Add CVAE in Flax --- examples/cvae-flax/__init__.py | 0 examples/cvae-flax/data.py | 42 ++++++++++ examples/cvae-flax/main.py | 97 ++++++++++++++++++++++ examples/cvae-flax/models.py | 94 +++++++++++++++++++++ examples/cvae-flax/train_baseline.py | 81 ++++++++++++++++++ examples/cvae-flax/train_cvae.py | 119 +++++++++++++++++++++++++++ 6 files changed, 433 insertions(+) create mode 100644 examples/cvae-flax/__init__.py create mode 100644 examples/cvae-flax/data.py create mode 100644 examples/cvae-flax/main.py create mode 100644 examples/cvae-flax/models.py create mode 100644 examples/cvae-flax/train_baseline.py create mode 100644 examples/cvae-flax/train_cvae.py diff --git a/examples/cvae-flax/__init__.py b/examples/cvae-flax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/cvae-flax/data.py b/examples/cvae-flax/data.py new file mode 100644 index 000000000..7a8b69d1f --- /dev/null +++ b/examples/cvae-flax/data.py @@ -0,0 +1,42 @@ +import numpy as np +from jax import lax + +from numpyro.examples.datasets import _load + + +def load_dataset( + dset, batch_size=None, split="train", shuffle=True, num_datapoints=None, seed=23 +): + data = _load(dset, num_datapoints) + if isinstance(data, dict): + arrays = data[split] + num_records = len(arrays[0]) + idxs = np.arange(num_records) + if not batch_size: + batch_size = num_records + + Y = arrays[0] + X = Y.copy() + _, m, n = X.shape + + X = X[:, (m // 2) :, : (n // 2)] + arrays = (X, Y) + + def init(): + np.random.seed(seed) + return ( + num_records // batch_size, + np.random.permutation(idxs) if shuffle else idxs, + ) + + def get_batch(i=0, idxs=idxs): + ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size) + d = tuple( + np.take(a, ret_idx, axis=0) + if isinstance(a, list) + else lax.index_take(a, (ret_idx,), axes=(0,)) + for a in arrays + ) + return d + + return init, get_batch diff --git a/examples/cvae-flax/main.py b/examples/cvae-flax/main.py new file mode 100644 index 000000000..9d78b2b02 --- /dev/null +++ b/examples/cvae-flax/main.py @@ -0,0 +1,97 @@ +import argparse +from flax import linen as nn +from flax.core import freeze +import jax +from jax import jit, random +import jax.numpy as jnp +import matplotlib.pyplot as plt +import optax + +import numpyro +from numpyro.contrib.module import flax_module +import numpyro.distributions as dist +from numpyro.examples.datasets import MNIST + +from data import load_dataset +from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model +from train_baseline import train_baseline +from train_cvae import train_cvae + + +def main(args): + train_init, train_fetch = load_dataset( + MNIST, batch_size=args.batch_size, split="train", seed=args.rng_seed + ) + test_init, test_fetch = load_dataset( + MNIST, batch_size=args.batch_size, split="test", seed=args.rng_seed + ) + + num_train, train_idx = train_init() + num_test, test_idx = test_init() + + baseline = BaselineNet() + baseline_params = train_baseline( + baseline, + num_train, + train_idx, + train_fetch, + num_test, + test_idx, + test_fetch, + n_epochs=args.num_epochs, + ) + + cvae_params = train_cvae( + cvae_model, + cvae_guide, + baseline_params, + num_train, + train_idx, + train_fetch, + num_test, + test_idx, + test_fetch, + n_epochs=args.num_epochs, + ) + + x_test, y_test = test_fetch(0, test_idx) + + baseline = BaselineNet() + _ = baseline.init(random.PRNGKey(0), x_test) + recognition_net = Encoder() + _ = recognition_net.init(random.PRNGKey(0), x_test, y_test) + generation_net = Decoder() + _ = generation_net.init(random.PRNGKey(0), jnp.ones((1, 256), dtype=jnp.float32)) + + y_hat_base = baseline.apply({"params": cvae_params["baseline$params"]}, x_test) + z_loc, z_scale = recognition_net.apply( + {"params": cvae_params["recognition_net$params"]}, x_test, y_hat_base + ) + y_hat_vae = generation_net.apply( + {"params": cvae_params["generation_net$params"]}, z_loc + ) + + fig, axs = plt.subplots(4, 10, figsize=(15, 5)) + for i in range(10): + axs[0][i].imshow(x_test[i]) + axs[1][i].imshow(y_test[i]) + axs[2][i].imshow(y_hat_base[i]) + axs[3][i].imshow(y_hat_vae[i]) + for j, label in enumerate(["Input", "Truth", "Baseline", "CVAE"]): + axs[j][i].set_xticks([]) + axs[j][i].set_yticks([]) + if i == 0: + axs[j][i].set_ylabel(label, rotation="horizontal", labelpad=40) + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Conditional Variational Autoencoder on MNIST using Flax" + ) + parser.add_argument("--batch-size", type=int, default=128) + parser.add_argument("--rng_seed", type=int, default=23) + parser.add_argument("--num-epochs", type=int, default=10) + + args = parser.parse_args() + main(args) diff --git a/examples/cvae-flax/models.py b/examples/cvae-flax/models.py new file mode 100644 index 000000000..c19f79512 --- /dev/null +++ b/examples/cvae-flax/models.py @@ -0,0 +1,94 @@ +import functools +from typing import Any, Callable, Optional + +from flax import linen as nn, struct +import jax +from jax import lax +from jax import numpy as jnp + +import numpyro +from numpyro.contrib.module import flax_module +import numpyro.distributions as dist + + +def cross_entropy_loss(y_pred, y): + log_p = jnp.log(y_pred) + log_not_p = jnp.log1p(-y_pred) + return -y * log_p - (1.0 - y) * log_not_p + + +class BaselineNet(nn.Module): + hidden_size: int = 512 + + @nn.compact + def __call__(self, x): + batch_size, _, _ = x.shape + y_hat = nn.relu(nn.Dense(self.hidden_size)(x.reshape(-1, 196))) + y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat)) + y_hat = nn.Dense(784)(y_hat) + y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28)) + return y_hat + + +class Encoder(nn.Module): + hidden_size: int = 512 + latent_dim: int = 256 + + @nn.compact + def __call__(self, x, y): + z = jnp.concatenate([x.reshape(-1, 196), y.reshape(-1, 784)], axis=-1) + hidden = nn.relu(nn.Dense(self.hidden_size)(z)) + hidden = nn.relu(nn.Dense(self.hidden_size)(hidden)) + z_loc = nn.Dense(self.latent_dim)(hidden) + z_scale = jnp.exp(nn.Dense(self.latent_dim)(hidden)) + return z_loc, z_scale + + +class Decoder(nn.Module): + hidden_size: int = 512 + latent_dim: int = 256 + + @nn.compact + def __call__(self, z): + y_hat = nn.relu(nn.Dense(self.hidden_size)(z)) + y_hat = nn.relu(nn.Dense(self.hidden_size)(y_hat)) + y_hat = nn.Dense(784)(y_hat) + y_hat = nn.sigmoid(y_hat).reshape((-1, 28, 28)) + return y_hat + + +def cvae_model(x, y=None): + baseline_net = flax_module( + "baseline", BaselineNet(), x=jnp.ones((1, 14, 14), dtype=jnp.float32) + ) + prior_net = flax_module( + "prior_net", + Encoder(), + x=jnp.ones((1, 14, 14), dtype=jnp.float32), + y=jnp.ones((1, 28, 28), dtype=jnp.float32), + ) + generation_net = flax_module( + "generation_net", Decoder(), z=jnp.ones((1, 256), dtype=jnp.float32) + ) + + y_hat = baseline_net(x) + z_loc, z_scale = prior_net(x, y_hat) + z = numpyro.sample("z", dist.Normal(z_loc, z_scale)) + loc = generation_net(z) + + if y is None: + numpyro.deterministic("y", loc) + else: + numpyro.sample("y", dist.Bernoulli(loc), obs=y) + return loc + + +def cvae_guide(x, y=None): + recognition_net = flax_module( + "recognition_net", + Encoder(), + x=jnp.ones((1, 14, 14), dtype=jnp.float32), + y=jnp.ones((1, 28, 28), dtype=jnp.float32), + ) + loc, scale = recognition_net(x, y) + numpyro.sample("z", dist.Normal(loc, scale)) diff --git a/examples/cvae-flax/train_baseline.py b/examples/cvae-flax/train_baseline.py new file mode 100644 index 000000000..3565a693c --- /dev/null +++ b/examples/cvae-flax/train_baseline.py @@ -0,0 +1,81 @@ +from flax.training.train_state import TrainState +import jax +from jax import lax, numpy as jnp, random +import optax + +from models import cross_entropy_loss + + +def create_train_state(model, x, learning_rate_fn): + params = model.init(random.PRNGKey(0), x) + tx = optax.adam(learning_rate_fn) + state = TrainState.create(apply_fn=model.apply, params=params, tx=tx) + return state + + +@jax.jit +def train_step(state, x_batched, y_batched): + def loss_fn(params): + y_pred = state.apply_fn(params, x_batched) + loss = cross_entropy_loss(y_pred, y_batched) + return jnp.sum(loss) + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + + new_state = state.apply_gradients(grads=grads) + return new_state, loss + + +def train_epoch(state, train_fetch, num_train, train_idx, epoch_rng): + def _fn(i, val): + state, loss_sum = val + x_batched, y_batched = train_fetch(i, train_idx) + state, loss = train_step(state, x_batched, y_batched) + loss_sum += loss + return state, loss_sum + + return lax.fori_loop(0, num_train, _fn, (state, 0.0)) + + +def eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng): + def _fn(i, loss_sum): + x_batched, y_batched = test_fetch(i, test_idx) + y_pred = state.apply_fn(state.params, x_batched) + loss = cross_entropy_loss(y_pred, y_batched) + loss_sum += jnp.sum(loss) + return loss_sum + + loss = lax.fori_loop(0, num_test, _fn, 0.0) + loss = loss / num_test + return loss + + +def train_baseline( + model, + num_train, + train_idx, + train_fetch, + num_test, + test_idx, + test_fetch, + n_epochs=100, +): + + state = create_train_state(model, train_fetch(0, train_idx)[0], 0.003) + + rng = random.PRNGKey(0) + best_val_loss = jnp.inf + best_state = state + for i in range(n_epochs): + epoch_rng = jax.random.fold_in(rng, i) + state, train_loss = train_epoch( + state, train_fetch, num_train, train_idx, epoch_rng + ) + val_loss = eval_epoch(state, test_fetch, num_test, test_idx, epoch_rng) + print(f"Epoch loss - train loss: {train_loss}, validation loss: {val_loss}") + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = state + + return best_state.params diff --git a/examples/cvae-flax/train_cvae.py b/examples/cvae-flax/train_cvae.py new file mode 100644 index 000000000..7e4a6aca6 --- /dev/null +++ b/examples/cvae-flax/train_cvae.py @@ -0,0 +1,119 @@ +from flax import traverse_util +from flax.training.train_state import TrainState +import jax +from jax import lax, numpy as jnp, random +import optax + +from numpyro.infer import SVI, Trace_ELBO +from numpyro.infer.svi import SVIState + +from models import cross_entropy_loss + + +def flattened_traversal(fn): + def mask(tree): + flat = traverse_util.flatten_dict(tree) + return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()}) + + return mask + + +def create_train_state( + rng, model, guide, train_fetch, baseline_params, learning_rate_fn +): + label_fn = flattened_traversal( + lambda path, _: "adam" if not path[0].startswith("baseline") else "none" + ) + tx = optax.multi_transform( + {"adam": optax.adam(learning_rate_fn), "none": optax.set_to_zero()}, label_fn + ) + + svi = SVI(model, guide, tx, loss=Trace_ELBO()) + x_batched, y_batched = train_fetch(0) + state = svi.init(rng, x=x_batched, y=y_batched) + + svi_params = state.optim_state[1][0] + svi_params["baseline$params"] = baseline_params.unfreeze()["params"] + state = SVIState( + optim_state=(state.optim_state[0], (svi_params, state.optim_state[1][1])), + mutable_state=state.mutable_state, + rng_key=state.rng_key, + ) + return svi, state + + +@jax.jit +def train_step(state, x_batched, y_batched): + def loss_fn(params): + y_pred = state.apply_fn(params, x_batched) + loss = cross_entropy_loss(y_pred, y_batched) + return jnp.sum(loss) + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + + new_state = state.apply_gradients(grads=grads) + return new_state, loss + + +def train_epoch(svi, state, train_fetch, num_train, train_idx, epoch_rng): + def _fn(i, val): + state, loss_sum = val + x_batched, y_batched = train_fetch(i, train_idx) + state, loss = svi.update(state, x=x_batched, y=y_batched) + loss_sum += loss + return state, loss_sum + + return lax.fori_loop(0, num_train, _fn, (state, 0.0)) + + +def eval_epoch(svi, state, test_fetch, num_test, test_idx, epoch_rng): + def _fn(i, loss_sum): + x_batched, y_batched = test_fetch(i, test_idx) + loss = svi.evaluate(state, x=x_batched, y=y_batched) + loss_sum += loss + return loss_sum + + loss = lax.fori_loop(0, num_test, _fn, 0.0) + loss = loss / num_test + return loss + + +def train_cvae( + model, + guide, + baseline_params, + num_train, + train_idx, + train_fetch, + num_test, + test_idx, + test_fetch, + n_epochs=100, +): + + svi, state = create_train_state( + random.PRNGKey(23), model, guide, train_fetch, baseline_params, 0.003 + ) + + p1 = baseline_params.unfreeze()["params"]["Dense_0"]["kernel"] + p2 = state.optim_state[1][0]["baseline$params"]["Dense_0"]["kernel"] + assert jnp.all(p1 == p2) + + rng = random.PRNGKey(0) + best_val_loss = jnp.inf + best_state = state + for i in range(n_epochs): + epoch_rng = jax.random.fold_in(rng, i) + state, train_loss = train_epoch( + svi, state, train_fetch, num_train, train_idx, epoch_rng + ) + val_loss = eval_epoch(svi, state, test_fetch, num_test, test_idx, epoch_rng) + print(f"Epoch loss - train loss: {train_loss}, validation loss: {val_loss}") + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = state + + p2 = best_state.optim_state[1][0]["baseline$params"]["Dense_0"]["kernel"] + assert jnp.all(p1 == p2) + return svi.get_params(best_state) From 04b9ad136028cb01725377955dde79826ab8320a Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sat, 11 Jun 2022 11:20:59 +0200 Subject: [PATCH 2/6] Lint files --- examples/cvae-flax/data.py | 4 ++++ examples/cvae-flax/main.py | 21 +++++++++------------ examples/cvae-flax/models.py | 8 +++----- examples/cvae-flax/train_baseline.py | 7 +++++-- examples/cvae-flax/train_cvae.py | 8 +++++--- 5 files changed, 26 insertions(+), 22 deletions(-) diff --git a/examples/cvae-flax/data.py b/examples/cvae-flax/data.py index 7a8b69d1f..657097a27 100644 --- a/examples/cvae-flax/data.py +++ b/examples/cvae-flax/data.py @@ -1,4 +1,8 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import numpy as np + from jax import lax from numpyro.examples.datasets import _load diff --git a/examples/cvae-flax/main.py b/examples/cvae-flax/main.py index 9d78b2b02..af2d75f7f 100644 --- a/examples/cvae-flax/main.py +++ b/examples/cvae-flax/main.py @@ -1,22 +1,19 @@ -import argparse -from flax import linen as nn -from flax.core import freeze -import jax -from jax import jit, random -import jax.numpy as jnp -import matplotlib.pyplot as plt -import optax +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 -import numpyro -from numpyro.contrib.module import flax_module -import numpyro.distributions as dist -from numpyro.examples.datasets import MNIST +import argparse from data import load_dataset +import matplotlib.pyplot as plt from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model from train_baseline import train_baseline from train_cvae import train_cvae +from jax import random +import jax.numpy as jnp + +from numpyro.examples.datasets import MNIST + def main(args): train_init, train_fetch = load_dataset( diff --git a/examples/cvae-flax/models.py b/examples/cvae-flax/models.py index c19f79512..cdd9933f3 100644 --- a/examples/cvae-flax/models.py +++ b/examples/cvae-flax/models.py @@ -1,9 +1,7 @@ -import functools -from typing import Any, Callable, Optional +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 -from flax import linen as nn, struct -import jax -from jax import lax +from flax import linen as nn from jax import numpy as jnp import numpyro diff --git a/examples/cvae-flax/train_baseline.py b/examples/cvae-flax/train_baseline.py index 3565a693c..3478806e7 100644 --- a/examples/cvae-flax/train_baseline.py +++ b/examples/cvae-flax/train_baseline.py @@ -1,10 +1,13 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from models import cross_entropy_loss + from flax.training.train_state import TrainState import jax from jax import lax, numpy as jnp, random import optax -from models import cross_entropy_loss - def create_train_state(model, x, learning_rate_fn): params = model.init(random.PRNGKey(0), x) diff --git a/examples/cvae-flax/train_cvae.py b/examples/cvae-flax/train_cvae.py index 7e4a6aca6..9fb1c3ccb 100644 --- a/examples/cvae-flax/train_cvae.py +++ b/examples/cvae-flax/train_cvae.py @@ -1,5 +1,9 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from models import cross_entropy_loss + from flax import traverse_util -from flax.training.train_state import TrainState import jax from jax import lax, numpy as jnp, random import optax @@ -7,8 +11,6 @@ from numpyro.infer import SVI, Trace_ELBO from numpyro.infer.svi import SVIState -from models import cross_entropy_loss - def flattened_traversal(fn): def mask(tree): From 32a1ec111b7be8e7bd1b2021aaf2d59a1ad1ef1e Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 13 Jun 2022 08:58:52 +0200 Subject: [PATCH 3/6] Remove init calls --- examples/cvae-flax/data.py | 2 +- examples/cvae-flax/main.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/cvae-flax/data.py b/examples/cvae-flax/data.py index 657097a27..09afc2bdc 100644 --- a/examples/cvae-flax/data.py +++ b/examples/cvae-flax/data.py @@ -23,7 +23,7 @@ def load_dataset( X = Y.copy() _, m, n = X.shape - X = X[:, (m // 2) :, : (n // 2)] + X = X[:, (m // 2):, : (n // 2)] arrays = (X, Y) def init(): diff --git a/examples/cvae-flax/main.py b/examples/cvae-flax/main.py index af2d75f7f..24cd13eb2 100644 --- a/examples/cvae-flax/main.py +++ b/examples/cvae-flax/main.py @@ -54,11 +54,8 @@ def main(args): x_test, y_test = test_fetch(0, test_idx) baseline = BaselineNet() - _ = baseline.init(random.PRNGKey(0), x_test) recognition_net = Encoder() - _ = recognition_net.init(random.PRNGKey(0), x_test, y_test) generation_net = Decoder() - _ = generation_net.init(random.PRNGKey(0), jnp.ones((1, 256), dtype=jnp.float32)) y_hat_base = baseline.apply({"params": cvae_params["baseline$params"]}, x_test) z_loc, z_scale = recognition_net.apply( From 33180ebeb34a192b534c28eedfa64573492f4840 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Mon, 13 Jun 2022 11:44:10 +0200 Subject: [PATCH 4/6] Add README --- examples/cvae-flax/README.md | 15 +++++++++++++++ examples/cvae-flax/assets/cvae_predictions.png | Bin 0 -> 85224 bytes examples/cvae-flax/data.py | 2 +- examples/cvae-flax/main.py | 3 --- 4 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 examples/cvae-flax/README.md create mode 100644 examples/cvae-flax/assets/cvae_predictions.png diff --git a/examples/cvae-flax/README.md b/examples/cvae-flax/README.md new file mode 100644 index 000000000..ac022fd19 --- /dev/null +++ b/examples/cvae-flax/README.md @@ -0,0 +1,15 @@ +## Conditional Variational Autoencoder in Flax + +Trains a *Conditional Variational Autoencoder* (CVAE) on the MNIST data using Flax' neural network API. +The model is a port of [Pyro's CVAE example](https://pyro.ai/examples/cvae.html) which describes the model as well as the data. + +The model first trains a baseline to predict an entire MNIST image from a single quadrant of it (i.e., input is one quadrant of an image, output is the entire image (not the other three quadrants)). +Then, in a second model, the generation/prior/recognition nets of the CVAE are trained while keeping the model parameters of the baseline fixed/frozen. +We use Optax' `multi_transform` to apply different gradient transformations to the trainable parameters and the frozen parameters. + +Running `main.py` trains the model(s) and plots a figure in the end comparing the baseline prediction with the CVAE prediction like this one: + +![CVAE prediction](./assets/cvae_predictions.png) + + + diff --git a/examples/cvae-flax/assets/cvae_predictions.png b/examples/cvae-flax/assets/cvae_predictions.png new file mode 100644 index 0000000000000000000000000000000000000000..cda1718f5ae9fe0bd4b04c27245ba5cea07fc9a7 GIT binary patch literal 85224 zcmeFa2Ut_-);1i68PO4892KNHSg=q;Fm!^Wh|(1mk!ENS2q1(aHH-xW6h;vcLQ@1p z1%mV%B@_`7ga8RON=fJ;1_(8LPt-YQhVlEad9U-m@0s)d?(0&*-aC6gdp+w}t|vw)8_g#k z)Rs1TC50=5r<}6Yjk&D$IJOq)d^VE+J!Yw{8l``2uJlsUX<_xm6RulND~4>i_3qlv zOY2^o?OWI--4SQ|yu5O#-+zpy8irc73TCjihE-=xyvILeJUo+q_8Ty=KWrfqdHJ0@ zf7r$S#U%7+;k7$Mc5i#Lb{}N-Llds#{;=)bc=PZdc5r_=weLQ7{y%Ofe}2gK$2|~; z_K)8?tlbB>XR_~4WBm2T*CBos)K@8f)|aok_*oENm&<2Cd^I|s1o72$e^!XE-sQ6( zzWV6Tg7_Lnd=|vlFygZyzJ?K>1@SeE_$-M3UxpFKQ8~O2$P;N%DDu{Kn*_Zqqs7&s zi!-C9`0x--cu%owzH`UZsPR!!R>Qkj*Ikr<5TAgLh)=v-$5`=IUAMBNPS}F4+CnHOZi-o1bM3oL z1G^*_PUbW~IL&V5qnRpm6gwBq^>Be>?hXQEA#iEwiqdLI-COmpx>>wLXn&(bt~|K3 zk&j1Ee&j>n7m?#(ntW9;nwB=;(#aD8i?(1?_Ch?`s&|ScK2PgEe}o+VQfk|3ISx|R zu|`V1ow}8z2qS`{fIUHi1iXp6d`EYH6E$vzo-$RRyTT~uJT4JxTzIxsV&6T~M1S^)Z3s)t?&qQo`ur}NFY7lYk$;Nd)43SH}aB8DtAmqMjfGm_?pW~5YO z%>q{%B(DZN0%C5Dspr|dtx~*N4W>iTE4QvIhSg++DIk@TKQaiQ=l0#ErbLv$I{^#;#GR$xH^i^V5u^q&r=jO~qJrL+kJs ztHupQAE$!CT5qZN_#~0t&~G+t&Tx1J^r5o417Re4Pm$HPdVV*;H#lmf=pUN!bE;O5 zniV_x7X|El!y}37DR$N!zf6y{1UWZIO!XD>1)%OGZCMgHa1ZtFU7VDvvAeLzZp4SQ zJ35|;A|jDPh3C2>hmHg7tCySVOI*CY@1CObDuOznDfd7_gC2bKp)2BUQuVQ>kstS(5Y#^`pE zb4Buto5!`c%w4IGTrAYPWQaT(0!op&2nlL+jMnZ{$KX-5DE?MGhIH zV9o9(ZKx>ir_@a^&iSK)su<-%En6mSA^WnLsXPCK>!i`Xq|yJ|!J0>9pRlas$8RsU z2I+VTTF+)msj`h1Ebk`8WpxTRE_C&9GEWlcjvDHC*01ypJZ*8atkRN7kvq&g%}a3a z$~sBBt1NTi9;@hm2jf=%rK_Y6DOCn+k_;BnQf=miK&!h7A+V?S!=0WL#&10E_>9w& z*$oiQGsjQ~C3Xf&3%hhYiK5J_casF z%P1u>Ou#CoYS{s%d4=(j8Q%x?M%p@!{^PAVDAJq9OoOY=?|IrtLKS8A>6nhf;!2AUZyX}TLx z-=k%^JF9=jX5JnssP7bUFYmkP?#`S%4${Hy)v1s(iMqJ5lf;37xr6sut?6jf-SmL; zVC&t|F%9e8qFFAQ-$Ec&7qiK2@j*kf=Td{5bWQIHLqf>T&u7(nC#|#*8(4ENI9@|27g!q zqFFa0uW@>Oi(1)KC0nRb7&~ASIRu&6t!}0XT#@m^B!5z|I^v*cBw-!IL;#9jmX8<` zQ+?LVrmkvyDCR~Jm+`?mo}3<6Ntt_!PQ*7-suhD7;^R^S^dbp>&-?Tcv&_tbI$i1@ zSh`ZD|DRRh%~u@%lTiM(jN)?c=jDYxC?EX<1bdS#Wf0HxE2W*}*EaSX`L#pGQ$kX? z_#}}S`z&GzGU0^QoM0D8aE<)PqfBm;ShhG2898*@eadDxJzu;ha!72oKoRVK`NHoe zg--(Y=rZ5!z%D!o?eqtW90K{#sU&w#dxoeH%)%N&O`Rlg{-?By2DA+7Vcf@`g)@L)m1>#?|${v*@ z`+pl^%?L=B0CDhqSW{pXZ0yB39Z$$y;U=MK+)3gnqoVR8u@mmVsd~6h=gb=?5oqCJ zf%-1nWBI4KD)H*P_!3P@br7>8)TruL*Q(U zDp!KK_lxa1ar@TBt$S-v>xyrem6fea0V(Mpwh;*+Qq9-%%t5>1H?do`D7f{OUdXrA zIeuKm{Kf9-#g4psf-+TXa&j_HRT)Xa&Ucfm1#RYQgv!s(9}bwCn6Fo^ZuJOa_xU#5 z7FM*K>@HGp>3m*FOLVaXDK7Yq*T`Fy$wJaVFRJBEbC6vJ@@BIdk%F=j`%N=zd&^M1=6knp$q!yx zoS*W=iJE2_oPNDl!;YkxtWV>unC-L}d>msGkz-ME;`niWeSNX3?|$*9-J)9y1m)|u z+J#8%o22`Jkq)dd&}a=-v!ZzP9bv`2{mdxe+8Z0UKFZH$NlNR-$%L~j>6-bDghV%( z(3;0EsiIvJfM|O|-FkOOrv`c!78TV!P_P@oP?sG8mj@-I!Aw^Srex^H*C!xSgiHGU zT^j6Q%kKpS3rR+ijWN=`3uQ!AdQz0~C54gy)sCQ_-0)S!WyP4c&$NSSpTrIX&CKA?!^ zESJ6y6&;u#jAl%QnbnT8rrXVov~FRBeUo6O_Tlv|yXAT2kq2Dv!SA=ZqbPm#l<~1?4TFVkW}@w&l&s!q>SZ*s`9-16ExPim(fH8KeQIm<;;b&nbljhyS2=y<=$9eul5;IAEcBz+ zj;E3zn3~FmvKGZ&;~WA;*|%Im<;tZ-k2QA>A%Bd}o}lF`r++S)vH@#4Yy8!9UdD%N z;2NBD$9x($Lm;yZQM%RX-H4x?7Nt}X8k{bG<%CHbyvH-0rKEJ$S@Y+->BvrQCJ;zI zrQ^An_nfnbENn zv%R!L91Z|22pgG#FwyQefdYyg{m5rpJu~)u2HYcWuY>$bB(lYr0pk*K_wcE6?PX~5*)G=>m?+aOZ!6D4Zf0nGpNho_)dl#3 zC%zr%l5e&1m?P=YKI0snNI~!Nes|o8kDm>-I!GPNSof@QXRcL6v4!Fq7ozPB%vFeT z;iXsC`pQ>t?TVrd2VGz9%`vM=4VWsEgEjC7XuQoy5D`^hy=#^|AK$a)`>#Jv+{nW) ziHHVCbfrGKxVX4pKA0J46m1BAqj+|dsB$i87B2wn9Ztxtm@1c}if&Q3^fJVVY*CCH z?c|bue!q=u1d8)^4?PKG6hE~zSArAAKwZa5Q=qOFD`#`u`-ZDx_8LZTR+fA%sebh2 zd^kQ-IWUS>;_6Ns0YlF4<6Itj`!G58i%R)+DARH;$ITOct69>F1Z8(?ORrX) zWQnB#A>4d9TlvV5y5|=H&-hS66q9nS%x?8fB!}JIei( zKWY8jtx9ANa*t6-gOw~|}7F&=9qOhi}Ka31SHu54anqZN-w?8@U(IDcZ|`Cq}D9s zRI7eP``5p!0IeLjt1A%>!XYo!0HA0d+@OVwQ zGr*@gMYCdzwPg!83@by?>%n}8{XRUyhY3V;!!w`NH5AXcv=HS6KSpB#xv!iosub7!M9<9LDq4r8SF==Jq=VOlc} zHSbjcwWw>=l@!;N44RQ>ToVC?G=H3`H5Y+6{Lj$M61i#uNNdMy=|Y?aDOEFril--u zW_dV}d(nga{qH8_Q_)xezBJ59scHuaFAxCix7A3lpg=;^%?JQG!If};FWV$@Btfn{ zFTW{rXd)geugIGK-%hCqmgW-ggh#?J+3FC;N_UjFT?$ITw zHskw4G#wiwq!durVj!K1Ja&Ft_PM)ZX$5q3`F=>UZMBj zD_ET%=y=-VGo=(e3r;xCnWra1W8Hl8_O>BQo+p6X#z07Z-@okcaQau9k)<=zE3mv zB=I$`@7W-qCHMEdZYaFGP0>wCEeYwYSwo9nz zDcY3>vavl@rEEhlv^$9khfHpfWAkUzOba2-$n%S&E&lQMKy3r`wVm7!d4Pg`;W%O1Yl zt$jXSem|>d_Wo%AjMHBX=7JJxPbesP$*4}g1Vs+FAT6n2OM8Au2=ah{Ca_7@VqhR{ zp3-p$=CF+i88Y`!Z_*C4lGKJlAU%0HdvRhgs6;ox>~6@dTP2nzw553->10q>Y+VWl zrNL1ThiY^%ZI?yF&}}D)3%lu~$vI$f2?0K!+^V^XJZ^cxo@*G|2yEIwF35=bXIopgH!i$1UoAOFEW-@cbcBf(2Zrl4?kd2{sv`HE14eTUv1_1B3I!P!o6qKr9I6~|LTqpJTk?noM%R-X z0B+1NNeI$f?g;{BZ3(mVRIBG=*N%gc2J7A`-YI{+D88ttD37@~+XY*0b~|@h72^LW zOTMS#B#{wST0B+|E!!^`M%?#NpT63&&&%k)fsgjdo@pORDfPMt);o-ato_bcFlqg} z*&^5Ov%3+mJAj@n5ISUaJhgU&odZR=7YZM1!=Hv!R^;AwX7A<37g7pK=zIerO6QCe z|83q+liebi$nylY6u4R~rP>mNJK5}3)$%Sjv;nX%xZnQO%8pB?S+k|_eeIO^D3ZS( z^nA&ShjW2Q5PjZpkmfx8yqv5}uI%`ov_m9@2$BmycastnipC0ZV()CpfBISsqK){v z-abp|zj3`ePAUFF=oFdX%?kK9`{=*F|7aOKH>J@{T(DlAXY)0lGCFbi`-BaTGL@Wj zZ!~3)#ax}nMJ$dzJ8O_prBl}7lFK|_sMB#SccA>T{s zh>Yl5n_RDHzt^a_etK6P`C{h_Ke}Hvd>A#?6+qt#9w}|7%%6Lm_HEeo=f0-zYrfFK zbPfCG;NjnU$w$?_Tq@vhP=R@-;bx?N##x>O58)>z#XbKppNF0_)uJ6b%Zn;~kx*wM z-aPQeQSg}ep~!>rs(c-<#B!023TJr>{&CQO9~=xjh%pBTWj-I=0a1(MT0+Uz_V&Uz zgIR5bhm*hEmywy7sRXbkPs!`{{Ktw8haHRrUuu-el>t|jE-3YUfHz5Qe*eZ6-0pDX zZwJM|Sfs&N9@K$@1D_AB+X*Wg*7~F7aerI*qO6?(-qh0+v;h*Lmw%Nxpycr8er$9Y zzc}EOiM+Wgm9`5&TaUruVxt?PB~usYCM~LY^6&lSr!DnNzymqG|4mMxM6o_3%I*t{ zo1^V@Ym|@K{(T zDDE6i*)XyJ7|HDHY^UKt=3qk>z(3;{+RuUc`jgH3vK7d+{E{hMmo|6-biht;0Y%?a zOH&o@Wzp;IQ$T$=g$vRDkK4zwrXWym)H@7n3tLG>a6jwir!Go#*N~MC+amq{?+@JDKYX{Ljk}b<#Q&j|qUx$4m zFZ&MuCHKr-muohY9|Zs0D8KE-zjn#l-@+I1JfCJn4ejwbdOLKmJZJ~hRhY=aoc6s* zNJ7QkB*d>z7K6H2DILLwL(aEciz(TJ30|$-i4^fk#^9YkkjhCE?C4@`(!5%v;IM~S zT1>Vpk`H4ZU*BTt+c(jLE%TfVtI5pW%Te@AS;eX%_bz{^#R>;67Ggh&DeZ98)s8qg ziYg?>kEzPl*nc(*iQ?nQrlYL7~zO<{3E zi=*<$k1*b#599Hn1VN0L2-JnaMgmoSt1za^w+UlT75))w(>MAMt7`Xht9jJ@$Xys8 zK1`!=G(V=^to#OH+NPNw(|ObV#%4&W(TW;#{6#@@6s~`z3D>*<)3>A{43*~0+`tD1jMt3|tX44wQuz z){rOsd|^|bFa#2KMxWhx04Y~GAA#+vbP3EmQ*9hq_xK^WwChMvS?_q_^%LN#sFKZ? z44mW+sEg*j8>O%%5Q>DTMG91xdQ*@~E4nqbsH3Zx29LuDKFd1rf>nv7I-$hrqYnz@ zqHuDMz@$3fh8EK>GVH;7y*tm)JJL#>2U_4A$b1h<<%bMwI0=e$qV&F)JO?kJ@=*CO zXGFZuEhp-&JU7E<&QAJ%P~nGYXKlm9e78ntz}`#$PNVuIO*$S_2EZhf$wXMFWOZttN zVynr5$M@TmZ4$+y0ZP@PP?B*pPR@5_O@m!mhdZqBG>)PBMoeE+z#*hWMGl%$D0gui6p}IBPcGhw2}ktrfZ~D>1z5is zW#)L>F=7jSfvq0aU?UpG7iq&Q;p3=9UuWACbt_fURRWgLWM9#MOKiERciEm=f~PIs z#cjgi7n?U@^4vMjltR~jL8!5WIYm71?QRg_3Q&;GgUb4}I!K zsBxdB8nS+Z6G*Xy4Mf9PIMH5V>Ubi0Gv=qQ1LO#Z77=4n-*S|BTpXao?Nw=_r{u#&D?k^AwQ3GnYrV_z@X}5a({Q(> z$n?%t@kxI(za(OF2l0AD9k(4Clsfh!)Mds0JOw#Dmku}8+l0db{bckfHC1(1DRCI2 z#HQpBJ<|#ex2gF{cd)AHxm@N=w~^HM#}_WJeC0z8S38w!XuoEYSStK48~kT)&-+n> z##7_p>-FJnU!$vjfa(PG`_eLa!w3>t)^SZ$ag1t}sJ+9U2Y2!wA@-Eit`J@QcWov6EwFuV!k`ns3rg^j_|UD>~bn@@K8PpgHNyq<>d2 zG;_D6fzfjad1CppGexG_;VfkW)gb|A!N#fEq0z41L4kRj=L9gFLN@%OX%EJ9`2{ek zfj^}j@hcmv2G)LMK?}A{=PaAtm+-WAPJ=a$?~57S zh%)42YS2$esLFw?fMHaE!PsJsFq9P~CJr?Y9)E!~r*!`aRqqcxLwU}go58M&$R-;y z^HxtO=ZHds+9vpXo7#qUHTC9+9@>eFno#sDYWe1DNl8>U6gjoD=m#g^aswzDxIn+a zw9ewv{ZEP*m%qm_#OKPes))Dgeqy|7_#zkwNf8Solkv0+c`ISA_iZ7;j6x2~4IHt! zLAY%tw;A~IK{}xX{e$u2YH!IKgs8p&dS_AgUZl_1tR=+@vjawk<=6x!!{l5yh^Bd+ zNP)`6j~CKCW8aWEd&Qtl&=DbMkNeh^2Du~N@B)pj)we4tw{TF%*=utN%<=E(Rr#hA z%gRo?pO|z^r_cafDt;wkxk|TO&gg{(IhECM6s|ZCCf}gU^L=(zz6E8XwRCrUKvlU^ zu$$&Y%9-GD{wKQ+(m==|?_TChkr7drLcRzzl9d;IEDSH$hxe`ya$Nq2MvuNt)o*&# z!&Bv12K+1Hti`8j=MSdrSAqWu>>Og`GSyp_otf!4Y{MlEmqmZ<{rKG#9UT5I{@P?c zicT~WO=76SeTyu`vkuMGRBmVs4 z#C2W)BQEJ(@u;Jj(TxE7n2PuFZxT%mZAt=Oi?0{e1+YYrp0ZEd(gn20QC ze6O^>D)6N-s?11`w?$M4Dqy&3u+iv)ka_TwIdQq^J)7>4WI*nM*PMovuJ(W_Zcv=PQHg%W!M7w$otuVOd zCZTsb0gmd-2Jx0eV}Q>~Sln>X3xNRN#<%S^2@%(+{Fo*fN{52%KVs1KVQ${fBaS^; z4_oNKj{1%(3Qv9aR<(CwA`^~cUSzVG?$lTP2-R0i-q-2vp>JdsAOEJ*`01Pc?dV`lq<`Nw<=D*NoL6vEx0xGp*xI?f+C94}%o z9rXtd7N&q=OalE1ZNL6>>`hQ1KNC>$9|Ll_itGjK&W?_bj9++oGjoc&54aY#o*32= zSu=zo;?Ke%*8p5+zQ6{(JH|26#zy-yyvSC*8If}3`wUYwQ~_%~=Yw4*P})MkO3l1~ zCQ;qjO+?jm=ndNU^-p1sUGl1+tQB9tfd5n`0F?u?1vWZpC0cC)i*JL6iO&h-X@YJB zCEMCt2J}*TNKt6SD@@&~A7t50H+L z3sgY+Z5xiass7->p6Vk_zsH+KOK7$&FR(~}Eo)s*8fLVc8XJCEURD0HdVn*p-&f&l z%%TzW-K!U92UbS2Qd6M)R|gBJ!Ar4=Pu2+aS9P$4zn=hIF#24wtls{NS+Wf*8AVie zYSCmH^p$&Cba6XQvh(xv?GYjxSFejGj}NxX(#nYLM?(R{@=>gh9SIb*iNuw96DW$yGeQf5xAyw+p_h z3(bi0oC@Q?RrwgznYYJCEx~gIXU)t?-r(}Gvrnl{mD-iBa+XM-?WGO8v*L{2enKh? z)(0$+6QFM|w6jC?4ml6ZSmqY=)GCx&N@LhHwrAha>9IP+sR4mB@x5Z(Kc*O=0d+4d z5hXqKLm;IKk19E$91kc}DP!yTHeh%KDtF>z%Wu}bM;P;Af^nF*_tubUw1|zn zJbO!s(bVi9z=mujC1KLlqbqPYifFhLM!1D?1Q8#sY73NSY>?eeR2#!$vo#f}$ypYL!*1NO5P2w?rgY!MkGPdQROtUD{Tr$CNi{SFwa+e)>(L_1c&%0 zWVu~HVjC3CxF81QWd|Qbaw=_4QN&hr<`pG3w={1Mg@io%peLgJDkr~QWR!++=TvYP z1S)mOGEvE?jVcqf@MIm&`p1&X?|7dr7f?BKCE`I!YZP_l3NTjNdO^=m%zA;S)f>M& zQW<|W?~#cXA3r$oR+ew5meYrc2?Hh}WJV|U~Q7xn7yOVA2heB9^LBWqd# zy2dg=Pe4v!Z+m4r6?E#zK4=+AlSKRXr+?Y^AuDVc6xA0>`)b}2wRi2@4u8yp%i zy443xuDPAb8SXGC+=U9dg36*#AC|QGAw~x*+Vh+h(AdM$;$QQP`$^mZc2@W@K&J4x z+^K6en{`h7pWCO-b?@-V84|#)ctOPP6$8Cz{9k z`7jTy144}Gcidftpkb0IGm2QBGAFo1w#Y%{^s<&YqJ8Pn{DSvNP=5)}yi3rW^{-9d zV`Bw@nyOFv$2zjQ0I4Z;~a;reOJzfSfEhjYfz^pif~S%B|(wCb>-@s0be=?RC!w2hZ@8_CxK>HPz?zL!YX$4-F1Psb-ZG~ zf<`Z`)*!BHIVV8{jTcS}ABIy$)_U!it6(98;5^>_!kbxNUmqKt8u+2{qu~Z!K!LGt zT`wF)J1yJ|E!@f+7C~=YYPrv3Vy8JGxtT2tI<~q4Sp+k!;L(iuw&ET)j9l5!JulEp zn8|HTe70PK?0V6$3OI#-*tCD?;HKYaQ$YVq2De=$*{ki;!PTxpCsG@DJrJqeB`-OY zPxzkJ{JZN)z}Xjijkd)SGUBZh-71;&GvKg8H91^5JNXoA-GpKHxmqbqvSDA>jhdjmkF>`ZuFyIcfubuZl0Hm$DEx?NQzgQzO+q9qsMU=MvUDL0b2xC&&ad4J_p2LeN^6MKXqIRxOS_Lznb`K&5Sg zo|?+_vc^Gq+>k^ZG>wifj^`PZvW$)8tLJU~fn#lZ;n1Wtdt}{#z9bejBK_bLWBxj1 zJ7Toj;wGU~G3zEln8&SgP$i)~vmWcD576`3FGEoM_*|5Am1Pel`A(0>GdfpEnL1@Vqfx&=Zk^V;tiMpV|5sE zvE64A2AOv=ZmK2tc_FJ8wbB7s^nddnUceEEzhT&O5WUKBmF%cq#)E|5TUl~_Ui{;E zeRDP6X7?IeEsK96rph;11u17ao{M#Ii-fi2w`^&+X^5W?hVwvPO0!>4;BZBiCW}jc zPWy)e6*Zz!T6x`pUeL^z$ger`wmB#uU|ICFR`@Y~)!pKt)pC^CAZZ*8<6QM0&nYO|?k=-YY2%dT?5Tn-M@$B*rDb}mQx~+h+7<;@r_McJqgf`|Ai~~x z*3>kH>wmyMi>dTAr$8jr|E;jtClD5t8dCSTblfzw2r#5Ho9D{_4B*mrtm^ZYUTouw zykgm+n}kN#;yFsvhpy{2jzdnQn)UAn+g()-Ad5M9S$G0QuKfGUol5Qw4MU7%gRkmMijF zL;+$Y6zQzEy!ysA*}ZZLG^eb0@mNz~1M}j|r_n*sBPboH}8ttc;B64X~sbi1?!^8s5#ra7AAo z6sRBENoCDGo0NbWqf0fA)${LK;RP#C;_`iTj(DZnfOTsZ}=zdAy)MU#X$ZQUw5x zN;kaYz8)(w=FjHGmFQE*9Y~LH4^`{=6zi%n< zY3ntQBg0zXm^@s+OfSUxEtjXjStB&+_ZYdnYGJ6dYk2@AkWwLs?9`gdz?KCMF=}Ww z{MAax>hhpKngex29GW9vEe6%5o8snv#deL+=ry!}l~m%OO4+ClSrf?_eKHDXNnQuh zwZ&N6L#IucNeT5s$n}N+3doE2Z=R@3dbGvud6OnvM@27@gw9hz#xC0+vudndx-^3A zy22bxAb#srnYcJl=A2iqZNV6~O!xXn3xP=IW(j~IkD6(UK(EKcBtcC}%Z`pML%wdu zE)Ycu@R`hOpblk6tA~@yr|)3b=^^Ttzk{o;Emz=cqDeI7M;p@L?Bh(MAL1_v!IC>@RuSO>gc^dQ|Efim2u;S{7g?E%)D?P6@N`lHED&LScM zjx*(O)`x6Rg?M*Ik{|Pn)c^yB>ksgx*yl^()t#sTLF3pmmam86*w_MG@v7VMtPt>s z&=2i=si^9}3jiDmSQakS4+AAkS&;$yXbf5qcj-lA29usk<_NSP)`buo5cX?4Aq4XG zqh0=1k)F5jPQ>6d9lK-T1ycB`a-z)-%EGxUPVyaiK^Hq}6J}L&y)Hp8=@|0N>uMUz zJG{d72YSa+;WO#zFZaGo;WIGYqjKuz({MTmqwBgAY9}hehe?nHMb7J~4L0gMX7&q4 z>-XJcvzII>N&TyDfPOK{;ikSp*at$B`G>7Qp9d(Q(yHp+-FBeUa z$qfiM%8_UiB%^SkXaUXwtuq@j$;z4sKzXcf{`EmPX(?wEWRdZ*0MUPiYw_9};!Ikc zQ-_mgeLV_{mg6kTSLm1DFBQS4p#l;k0wL#zzO1K3o50##q&ZgVR zubijw87&0Tc+wObKXlJFE#-|CmsE#-MkW*op7ipFs^reapDO!||v?h-|Y6piKU zAScpek{=2;b>5O!T*>BRe{-NN%^jE28YqkG)Mt+O zKsugyzZZp*_8g1h{2*=7AAE(PQO(WcPPmp}{k+NZ3!K-mn;Xv-2~-z?D);Kj{A|c) zPaPYn8%o&Y%a8!+r@ggA!>PR0LIh<(VU7$6gI!mx*G-S+6!~hfd(dr6I%MSb`AR4Z zS&q1z$QToWj$5zHV*MP*sbK^=YV~d;v>)&$%o#7G@q3F(8a6|;*Z+rb%0AJ$LmoM$ zV7)`f!lSh6_0snm$#+#j)VhCVWrh-je0rpjQ({3bl$oxXNdv?;^SW{Tw=m(H5{)|=Eh;jPTyckvxt@w@_e()n?aR5MPAgL z8poJpDIl;XT^y^SJwh*Cq|6vt<3ot{bWXvWpd0-8?;cOq(8|!|yOFe#VnAC#FCEe{ zAy~jY0j2O`e6|5IUea?ADKeF~1q#tA=;|!M)2{!R#1j(D=8RX>WXcvHE_c^BE`$sN zkJ<6mZN(c;E8ir~huJcF2hN%Rb;5D^RcG74Jd)%_%xpJU1rgZSjy2Pu%sz2#{>bej zc@S4I?Z15|xzK;BqK2mHWY8e_ZCXSTVm{^4%CSLu-C!{6C*%}WiSwK~A)dDxBky$w zjXM>fzihn=qAkc|tc7C$SJd>2Jo7IB-Txv9?k`{Rx5*dwd&AmOL4L)mcX-l2K8=Db zm|S9iJRY+Sx!&q!cjB1W?_W3U}oxd(k@cCt@lbNA4XD47$;$p zI#%PjoXkY_zFW`*Ax4bk(izGGbHN4gn1vdkGff#|iNhx$H$TL_VpekeM9iJG^;almR zz&g1*Z4^yol`(5K)i_4CZq>hX52rH10y&EK3WFcO#gm`e({aWf0(K)+sLD zyBhk~YH%o#1b9>Ka%x&g=Zk~%MZD2f1!OhF?lh&BbCinh8o385?nh)^2>X~XBR0BH zW^%AjlUraV{W~@90S1kD&QbvjPvdpJy%AHutChbPw(fhEI+*KS0M{`v2~FM+Pu<#I z1CBs*HC(YeV0qC;j~NhL+$$`C7RkCzeN&FTjzS}Ru-DUqNqh3pJiV#??~ z-IteAZAB5o(Ew$N<4O(3!04=M3GG&36!Wr+0cmd}rOcCr9mTxRF#ayB*IC4D$TOLI zUM4)fn8+4B0ZC&;;CTL&CH;Mk1<`)4uE+jK%E*Y_^L9`x-BTfCA|kEfsZFsgBFGPp z_j;iy2FnXu;EFTXbn%%%3t{k>`0`+vP=cAS|2~jGPf%7mT?Pkus;c(6g+v@4nV8?x z+Pdm*K+#yNi~OKswAiAPx)}ota_PZ#y?dn#_V`cHwg5|jp51^+9RZHyZ%Tu7flQ{~cm?+1TDQAt-ZxWW6O%+ak7I6)yEnIpbXCHRX@m&`Tsks_7 z-~(qZ6XR<_pn{4_>~$*PHJc<;blS-esAP#OQ0N*xMTV_r7)uyn-Elw zsL34fQLI`2L}oG?P|eaM?&KAanxuz;f~>cK9zi;jwwCoCjNG0Fh-zdUR0>%yb_Ge9 z>0!3OSq+A68!<9I{)@oexinq_MUC3-cr>6|(SehyGF>tib^6^bM4-HY5Mqgb*^&tuGY~Zj zM&JMb(gtm5HjcNB(UpfL5#X%m`DcTeJ|SsxC9pWb3W$pNpH<5oNG>&wIptIEnBkxb z86&h^6V~s8sA!nsQZoV&w0^isivoZV(SsdIm=vF?qwjM*xDBWp8=oSa%SG2T5Op;i5;)tfPB>>zqgb06ovu|L8ERN<+V>zcTASsqtrpYxl z1Rdl*~0450S?*WWKM{4C!N_+BL z7k`^)j@|5P{16D|t^l6D4(BcMPu=&vj;4o!E%Sc4Ezr(59V593Gb`Od?oxEk z??yC|)bR_Ir$U zNx4E+MO_|`WXVp>ga=L%dc!aPV=Tx6Nsjgf?|WXe{cRC4y@!4^op4sud$VEByzBHk zoBm0}UhEHm7u0)`Y8+|aU|m;{;k#q)l>NpX+ctuyfu@k~`&^I@jV zYIjy(GyDlZ?aFWclb>58<^FPWr2@oDj7zH2dIY*sV;hV!oE+3Pn&<&4i4%AS^RU+X zMzK(52Yxd_(2*HqZfOqMHH{oSh7W;`%mjRF9q7ok7yRv@Dj16q7>fevCI;S`H_aBrL055qxqNH@6d0bofvXj#(Ft~$~ z(ULH&*~y-gKwjHDfB9)W{=ZsD*Niaa+LQl~=D&Y!>HqEevc9MoUq|^*_55$?ApOr( zqR%Z^rv&A#lQInwUVLD_eRvG~S*>Hj#@~9~053BIbi2Y2Oz>`VE+4qbcjx&p^^;qc z{dG$UdF#QAT<33rSFb+?eA|-V28iCoT#vi)(nNuAS@cnGgNuL7Jeu(X;}xU&n_>NI zJ0ZWO{udfOJ~sbq?-5q>8aWWHtlyJa-(FylwN6bBb@5(`9_hRm2hbYsNB_Th1IbHW zB_*ZjA4(Nw4K!UA?~f{JhdJ5IzW>b^{Ybl_1A6)O9@K3IzUTm4&`btA;0;yFR^T~) zboBV`AZYYI0OvPT0&eAo`R9Yc7d`kdJT_NjU#{9QhT*XY4MmA&<8;d%Z+NA_CZX}$c2eIk8Fco~1|j`Sb5^O_vp zeOv0SLEl@JiPbqP))iv#YAIt9Fk(wWUQ~?D&{-N90_Rq8s>oegMEBLX?h3O%c6DpV zVAJ^ixjsLA(vG>TO8InNNZ8qJ9+@%-M|D-tx!-p=@|q7{TKm3n@a}{J_)*Zi`Nu6} z@4t6vt0(po{rnJpo6d>7)-4rJYJ2uyJe=wAtmDubBjbjIY?T;AS`yRV>ApuM@&-#G`K z*`_GpZ-hWp7F^WOv2e6M5oNu@Q6xv_!6-#xi{tvJiCZNkWpT%gC&uB9FQkJfbFMS1 zodZtb4IIXE50Q3=eXlq?bM`^wnOcU~fu@A6-~Z7B!;Ie~{QHT2+5AP-`uFbqhhnQ- zijayrtJwKI_bmRT-K)w&ibYQpoB71L1RYQCN86BV2mBwLsuIqt!tw3uU|t<7^TrA~ z`a_{L22B?f$o!V z9J?NXzbyh_*O&5YA-*quKjR-i&pW*Ml12GjOnXSC$u{`ow$?4~nR^D$RHbX)z8@yb zr)U@E80H0ek=85Nh?&uA1iZzNQ+3g_ksy|HIsyM>Tc6ZKL$_ZE5R39Wjhn6qG>%5t$Nb ze-==gs;J0pL^hL%AtwNcL3ML9cML&qcEstwJiabc$Hi>CH81Kml7Kij20oI z&rNI>X>Vt(_k4{D%P0OTaWZk~Hy=T?C%wbTJ!&tUG3?a|*NyYkte2&o9WZyCD=L@D zt|4z73E!v_fN6t{9W2#je5`mt4p{246l&NRRZ1+WQ6XK0l6UB=V4?kVPvbtyUDZO{ z5?XxxP<8cC^2=3;|X~@bo{x*{8DSMFQVvG)akZZ>-6wce$N^H=r(AU zkiadlq$5=|Uj!PaF=DNka~(Ti#vnH?>)kMRpV=6fc_~2O#>TR6eTLE~uUw%~FS*97 z|3rU5Tg5`}jBpjciJ|BQpV8gSqB}>A;&e+{qD%jwf%L$$3UEiwagm9xPjDmWNvl+y z@U(_o18`D(ggwb4`V5-xwTUh*mIVFv2nwpS2z5sYdvfc=Y~PUnSlMOM#iQ!}0^wy3 ze2t~I*9Il$5Xj2zANyyE)y98PEQi1Qx4xK~_q)Ole#`xbqz4NZpE{m90zrPL{eiG$ zgs-~?18%!IT68K?r=kw&7RlU*`f&3G7K#XVC(YQ_o54kA zaQ7n}&lJ}kt4|m=IP%i$CRFzo^i+gZ)^c|a=A3V4br#s)=5@dz?M)+-=iF#By_?j4Pvg7 zT`8LOuj@6-JK@g;T$V4V)eOEk92#^MLBJPVx8?PdJ0FW^n6snXSVDUvE&Sa1-Ny$7 zS6(j~QL)e^)61`k%YQ6mr(UfTU?KWxG-baJuCV#Jxedp8iN@~-8OX*@MM#! zzjmk3J+(aZ#~IZw%+S*F@(%dZlg8vN8rwxRZa%pFyNlpn+F^?B`2E1-XyMK!$|0Ya z{2p>4>2lX}a(&2zt!$bhma1uFajMEDak>tD55-CE*`^fN5Y04Z`o?c~37y@)CRXa} zaV14p_UtqgV5)o98!4q3T8I1j-|4uebD8t?rF3%6EpF1~m%oWoEv?o`w7m|+e=vUz zG-@@6^Ir>oBK45N+ma^7kitx!%NCR<|Cu4Iv3kQ|u-ZDwwNDoNx{R$|NW8WQfoBEi z_%ohu3lEI+SoeN)E041sqSzQylo^7mbjlAYB@^SndF-FItjo(-9bbuMn}NTsaT}*K z(Yh36Ewm|Rdbh3aVq76%LK8dnm?0khqs%^W^3%*X<69Kfy<7OzMxvJATmxI(R`rw% zw3Li6qry$DXS1hosqqlPuNKF@dhug;vh1wUKQ^!Oy{`cy^|!CP*Z-|A{?hjnG`sxd zR(_gy%snuFX5l+j=8Z{giA|Le>=eb~66$Dk+s@Kwu_c#M8&JbL19juMUt%D9?(%c1 zGc3~BUjk*VxR&G ziT^;}oL#^fg?L$NK$PasGGtBEBXfAEKJx}S(6CVt&Yx&I2xk~)F=Jv}3{_BmGqD#) zqqW%Mq*SKH!JW%u(X+7g`tZ;js|!jG%JvoCPO~sGYsaviCbpH1U9sm| z>@a`A;|ZptMOVDjl@qvXCRHegVUYS}ltbjw4YG`qg{(k39)0N0uz4?bUrta@u)a*;mnCmEOR52*w zCO43>B+o`(##(E(aOWmhMYQRx9!e~_(}35u%83T6Ek=uQjzo4(=cK<(6_4pS^u^jw zhNfn#{P`nqt@<&Dk#-K=>4TrecVoq`KzxeCt*>hPP)-KArcG}c8{=6{rwU?a^^dz$ zyEEh0o1cif=~q_kHg-5oZx*2AOb~_4-u!d|)i9;;;V#x6)p-M|foSbtp<0K*bkt&h zWH&h;qxBc}>!aAIH|zpITW@2DdwSyt`g!0{l4WyP0?CpntH4gFOG~el`56U2Kp|dS z$y1OUNypfY^QEf|bSCzZPXgBo!-F=!b_nXpft{@N)SZG>9)-sZG)vCTpbrjiw;lXz z5Ns+@Hl3QpVaFb7sxwQ|BXD)n&u)~*sz?2-_1j)ZU%T`6j4}yb+kArGhw)P7iJ^ za+!Y=X=Y*@^O}3i=4A2(E1svVzs>#pG|6bh*ak@H%HzJq{r`})yzoB%KQ-t53w}by z;DWkeBTbvDt>Jmg2WD`_(@R}Ek0Vd~ru^L04ZmF*jG(DaFgEJU$fb9~}#k+>c9SrQM@H~;3@!RQaRNW*l z`%HnV=;L64CblGOPsXFXaUS7=p)vb0&Fn(Z&+eqSw_W1hhr+T$`}5whk#4E3+QC*l z>f`=jjlE8rCT%Srws3O~w5;7}gd-k~5_cd`g7s@iAGaolq_qz_^`9 z$kb3*y}7eb(g{~yixad(MbX&bB}m z;0SrRYy@;KpHgCSL<@?bQ16uiINsBkyh){-iP#kWs;~OC^$kU3bDp#n1@UoD9onE3wufz4$5l>9pmRiRlV_jJM7=C0yTjbKdCt zc%3AlP94FjPI2s^U;MtUaXgt`wcK}N!aqx6S9oBvj%=M3Ll?cTAK@jqs6^87Dt`Gk z7}$IHMGgR8j_5+~A)eJZNn1qHbmRu=e3O+?Wjhu_qa%4`&z?^2>>YS+!CN&|+H;iQ zhW=Y@S5kt%+VxLU?4WJ`-WQFO6TpSy&z|fek9h7Td;YZ*inb{On_tUB6WUb;Ah;eF z%>Xrf(K_KskIkrx+*nv=FbrE#ZnYH($&Du=!n^pal^zPx?O-ob7@B@LvjOkaTA+SJ zTNN5p9_w^s=EY%kH!o(Q0uuAFhxF$V?c;+uXU96RX=1YhBSULf;qScxxQff6nRM?} zj~H;QYO%_DPwzgiW55+o_7}}{z|NWi$r(gew(z$?Iz4V#D%e==(!~p&(cXy4Es$bM zp2r-3cVya=%3PsVEK)sJz>gV$OxGwU>?e(BJl}TeeqUjWm6oW?j^&Tk7+Ppb_w=G8tN)t4e#8i3ifTA;L^iRFd zd7bbeLjhjLWIvJOgtsa;p@>ADumgT_4=m)F1?!GY@bQ3&%=a+<;xfWtPbXCL<=m+kL5<>hHLI_K|`N=pBeX8bC|98&FceLD`cb zeW{TvzEIz`hjMQ+`mA^87Hgw6QOs#c5UD9ZI*UZfN5o^zRAWjPX+AY{Z9$JFJfuUX zn?;5(DM)<(Y%vy!PeN1jDiGPT0Yi~NAqdRB{13FL-b4kQ+a-v?16GUjporUiHX52} zQTrpb>t-RX>gkbP*?~`V*bI;@#IVPzEf|*oY3LbEpU%#;P0fDSx(tY`=P~=>i_6^z z*3ry-NtmU`Y>5HeDSr|jnNCWauFTL{k~J=x)zdJ3p)}r^g#@l(iLF`4f8qra!Yn#- zIO@<{ZD`6ktIBoFB|{>9P>0L*RnMww8Cye5S(+|(?#e|p(>7eNv%f>RMNZ@A)NKKW z{CRy|TvkTheatcgFcIAneS4J4o)`IUPTkZHI>o|G>fxmxpvuHgBmYjr?L=}k7sU9^4&<|39D z9)fx7j$hL1OD(|u{gkWd(U?-J_*m73Dl5phL<7^2seAd~>*V&tlY%wVV$7)JYS5b1 z^bK-_rXo#_$AKgv;3VldPhfIsRK30ObXQLG{z}c%A?|Klb9;&sg2I$BL$(9F;rc{jt`=dwMYEZID)eUfJG?tj@@|c%!?hYi6 zh`~t3djkxbC*D0`&ttI(MDIcUWid5KGh^<|wDfnZ?(*g?@{D+Qo=%Yf(gi8XhoxRS zMyearS*SmUlb?nX*=ZTfXp<~}Vxrtd8=x-lPrnsjCvW!zkv=Pa zeU&JgY)3$OG9Pc9vCB>7@NGoYqnv@NVN=tT%rsm5`!(FTx zLt{KE^r&W`dCip13N*yb_dUWRFjqByw=@Wq8R@)MQ@uxZ0PN|0SN2wjQXA_=T4k3s z7Fy5GHfk9)hV6lO=m-e4_OAAm;pqu&m3m!u{7+fekZvQUv!0ArgDY<@4vh+m1(?sW z?K~#2j~lDi*w|nI8#WzeyN6jUc9%4B*;n)_cb$y{73D*C_lc#mbRIsCNozD-^lUeP z4@Y-1&mGo9t#Qj@(LFU=KfN-Yx@sWgj*;rhuZcbDb$HE9VXo*YLAPMFm};Dq^e`%r zwk%>`u#aqhV3zKxYdov^dnX{c=`6jXJ{~3Z^E0vd_HZnm>A6nHLol0n)`?|GSCl@$ zKH1&E?}t-kt@UEUze6q9c{}L#k>hAhyZgwF=KSxKofzmVIy~7~T40O6gEwnzjms{z zq6qRs_8&Tyxr{SD+Y`&UZ?&6w=Wz3QCp>S02&&YQ-R7>8ylM{rG27(@ex#EX?z}#-- z^SA%Yjs)Li8~RdyxJo$G`1ZFYK{he5oiF{AV(SE^e_J2Z`J07{lDHilJ)aIp7x}Ww zvkME8q9~8N8Q>ye{l`E5E9c9BG<6r$rz61})>8R8Q$8W0`m{s`$k)Z<*k<6sFbMpL za-tW;a_#NyPsPNcSdc2))@CDt8`C4;!Bh(dhDHvZbSN&hmCgW9i3MDEJ2`9T-|kaQ;g=Kc;ei!N zb5Rj*tMHzAxJ#jYeR|NmO^_K0{E|viQ&Us=4*}yX>gU`lC*o&4;?1*9&j}vWq7?6> z7$iPp;J1204S|U3norwl@EXj}y>qBd%w4K2#TbQ3TA~+D@Dn(a*VDH0)!HzyCE?oo z`Vt}2)>gV&D36Qp*iQNMm{#ZyVnZ;P)GldD>NfDr0i$b~=u5JEJma`Kc*bp2ebr6} zixoe=1jeXh2#fea$6L2sv^2*0Djrp&X zU^8$sPlD3Fgus7XlokJGy`q7A7fin{V1LeQ_+l)S-AHmead25ADpetTJUVlIyUQ_m zZ=BsqNFXw^$GL{W^#xO}M8CN~TH`PI@p6`~lc zdhGbY;+~$KrMfDD5k`>e>co#@CREdjd%i41y@JS7&GXfLS`j8>V4H+M9487?i8&06 zXreZdNpYx?LRM%0_PAi9FQEqCf2VD!_%IUrj6b*2;OwLAu4n;P2<}FiVV|^(H8ytl zY=viTe0=<=inESVsWcdveFcMW{h%vXkQh-D)7jlOE;1F1g zExBp%`;#e@SBE@F2$#ghq5Z2({~x209Q2w z{QgGy7QRI3^322ZLnr(sW31|WjuZWF0rqNBsiMM)ypyP~mZ^&C4_%S46(%hDxsCaX zW#XhQ`;;;m)-ag|_S@&Nxeqev1I!DkIBaTqa~E*;0n>0*iVLh`O2BaJK#t^DoNx+^k5*j} zq2w~YoT51Ou8j9b|K09`lo$zY7 z>L_)pX5*57+}g*Iv(Cm;IM$x-Q{L5nprzK=U!G;RwNA(_z1(t0l40@f*XkTGZE1bl z*0!#p!Io=c@xx*ZzpWM+zeRq?buZD$Mt}a9*mj{iEpjyAPO)c2o((X^S(L~WxGjY+ zDGzJ`H|Nz~k5h6!{}|F|{A3F#ee5^y+UTBXluZo(W`#!Xk9oQUl`AeT?QZEs4%Rn- zBu@umsGe6lGGo!Xf-k{*@k<$$J=2~h!>B{8=d<)faG9}h zEq$<4o#{y$&{%roGeX?kTBMs~+zwl^JXV-Oj(5>K$P+sBTBw+`b`QG1(o8J5PUt*koWzsAH9 z&m}zw|I67ZBdLH?ngHrOp+8CaeP*)ojnBVhQ$BxyDL{*zsn$bm|* zS0A7E+S~g8eS`dTjJJc!Oza*watDI>!^NNlriJl0awUO$3pwV-=p&O|oKWL>n()?e zPNRERSvU0@BOgDy)z^36zx(dv%Tz&rZ^t+)et2>N z_7U8yde203AJ~YYx`4XoU@-tdOD8{M1INN8V3ue}ojp@^!tTld75G#sysGBI-1T#p zo@Wyb333U?RpGJerm;yhOMcGFyC%Zg;}rsJWwH{BPtLj33&c(_xxs%-P=+F`3T zqMn|ObjH<|`%qs_2i}tfHsY5mC5pl1z_c-Xae^$HDU=;lykoB#%h;{X)7fwB^&Lep zlg|QoE@Jcg=e|L)T&}9$CS3?GwJ#Dpqt_#{{|-R#On6;5qm=Tee>UcbPokiE26%mL zQoR-T)YHpraHF5NgW$)B9CQX}#r%&o?P94XqqhbD22tO#k&H5dd2j}bam+YLkX}QS z0qefyq>VF)t9^cy92~~CT{_-g;9_kgQ%zAdJohI>m_I!)iI0ptJP9C*3OQ`V|Np7! z0-^TBg#XWos}8}^wdI-q@Fct^H9l%>2>>X^apqWEs7NYbT8eB2)|^_V4d6hMw#g=@ zZOnBf@|P+}&o~#%j1xxtw1%0ikh(|^Pr${2wC{MPy}vzeCARgWopCRN5KOp>yQ1P*f^9P!eZ+Uq~Z zvs&1>)WU&*fu&U@(Z(mDXa7vFYl5fe^9y#fi;FI~C-#HPn?nTWdQXyVOM+-IWn#*$ zaBXO4aIl6ySC#m|KykC3Y{8Ax`f&0M2Y{ANX`9VL?V#535rB!Pd^zwEy;px}^m8Na z@OG{KxpU2tGf93bTejF-tn~QjvvgGy+g@B^4G-szRCwJ{{82=@FPHZv<^u2q4*UeR zkg{~U6un;oMF?-goSdDDzhEt68Q1vrUw{4OuSS3I;>FRVb}AJk1!l zZ33L$1MK;`y1F`LqvPzA`ba^yDhR;9rPN{Vm4nSq-%r1(KKjsir<9jveMmZ>Ej`)1 z*gV%EU0zBfBEtyf-_Y#%x23-#Fck!9)$S^)WiHI)Em&2z0)ocQTYT(n{Rs( z;R`MSR?MsS=jVJy%}d}P!rS>q9@YC4snU$XJQgGlCqN=mR#c>OrhJ${4kt>+XR2hg zRn;`N9d>JycDZEfH`UfXRPbA!ZPx~}!Ey!;c(vfq(m;07;ex0UkC|}URQT)Q(t?#g zN3UJG_LN|&1{N}{!BXAgA6L2y(7myn@{GLl6i-|cy$Ku%NV?YVe;sfcSj#SD%f;D} zw#wm1gI7}osOT41xev=C0TmcCWCJ?8F5h#CVn zm%mrr$^_Yi!09jdKhK;`H;fbiDvXJaUKA|{-0{gfTj{;}n7-s!5Yc-udu}?=;2#Kb zR3S#p9sstQBClVe!4hYc-bC<_odP_-CY|Ib#l#QZW8bcmViH!+Q_m}WT8@0#P2g+< z5Iz$48W!`DL@i(PoHdYr6o7jH4pIfSSAajva+#|$KvwK<^>kHZHh*rri^Y!xNozCN zaZdg4b8gOU-gaD&a0Piv{bL%+8t}(n@(!fi;a=peA}cFRNdNaDYzZkJXbk_lg1(>6 zI#Y2yr}s68+((9hHnEiJ#Q9~tts!FY=%@{(}Z|Hw=sf%3J=v%y324`ui z7h3y9&AWCw$s+?vtE{myY)RENRylx=zEO*$O{j|Tm1Od1Xh=vE#65r?Cw#bwo~qW8 zU%ox72OqMG+UMZ?+HC-oG2Yz*#h)@pqpOixcI%a}69;!9jH#gs6ZCAOGEgIP_de?0 zul+4beO3^vtzMZ-4xtr$FqVHhneLYm(*G(WFyA&U7L9a2qYA5_RHclV4tOk2ENnL_ zQk8sWG1Ej*A2P8eCOK}GpcCE>Y6wg?qt(%@Xr;f!n1YW#uTOFFIHjp5$kFk^ER8lR z__R5#-S$IdqVM&io@c`~S>=1h4!eQVuXt!Dwq%ZLgq=EC-mVQT-aDrTwLO(lA}YaYe<46*eV0C9G1_#IcG#sXlAZWtnn5`Ih*$2AyD1K<&b zWV;lpxi2p&bDfNMRW~CS(0WxgVj_4&y@*MfAfQY-esL79Yn4O=W5G^Yb%E+GmcOHQQaGszD zvR@*}mvmK%!sq<+vOv|~>h0H?%_sh>Z2jDuvd&lB&^7RD$-dHA_fiFu;7VK@lt|~_ zrD#*wqBarceY7Ilmz?+@&l?L*}M zylkeLBzd#?YjyUgRog8HOGVm~BI=3an@MsJ zQU0S)u}A?fDh0fPV&O!7?$Ik%tmt5RmZr)Yx!Q1$Dq`PL@IPPTW`%c?DOFg&cOno7 z9#w&r5&y`|H~(L`xdNY?NGE!84F+S{gax(WX{W3U6`+7RWma`uTP~)H2>g4Tb?A%P zxq`LHGxnn>=muJ*`ryZpAKPFB_V!(=zEkYpzBWASkI>L?DwNI#4h90;1~PTUE&y)A z&Xbdq!HJ3Ju!%;K&6_r%KTazEOo<+Fl2vJ$RBab`2ZK*JV-z0e7;;z*KzM*hDiyG$ zG~Kw(0)QlaBQW=}4otuW(Y}@-SpnT5w=s(l_a8*sGml@td213B8tSO(6L~ToBo>9B z7^7_>0i>4&=F^J}Gjo3OwImLwHrQ>mgvpF)D{bJhUAF=)S z9uV8Fjf{*8jgQ~UJLGfM1+zR935?Mz!nn3+ zgMwiIcEKgAww1cFQ)L&k`G5v-{Q@QxX-?$u8TT6-^Wb~t#?q;985FPBH7+)*@8_lY zG3>Z6B!^;a4BD=D+CEY|60pA=CYUH-FTbS73#b6`pV+v|@D-coJ$+8#mpkZ&gDiAV zmIx^A;kT%Ygf-qmZ3CTGX5Qg+H;wM?;!Gzyh}YLSf+3#ZFn8ho<7j63QGr)ZbVhUv zIj3hN7VR%!gd7zJ4yq&g2Q_+4Lr(f%*bASI0kX6OVF-4L`?Wbu(tE((*ef=S)U?`T z(CMVpEuE8Y2f@Qu_6dQ7dO#2LF18wrd7%%~4ux%{@#}4YJ)^3i6tOfi%%%W6*w^$M zqzJc*$yg4-?nzY<0df*`KoI;`s2KY>? z%#78mrj;|6!czs~phR?z|1D~GNBH!sdLU1H!VvEYs(V$KqMX!|Wr5C!q*_)E7h+}| zu_ZNsO#6H@uC?aNCYp@_A*Ej-`XD_2F*~6x^#-nD@y$-Z&K;J*wGbURBY*WhwG&%{ zPgo!Io!rKeSXVa~!-*@kHJwhPDImaRKF9gYtt~V%WCOER*U4rTDsq+ib#2EdI@l=& z_ZiY{T%ASly+Wf7WNna!J3`^VZP{6PgPg<4z^&NIf=x3t`!{Zv- z4@P9MzCSKZY`2e?iVd9QR{#OX4iGz29)aqv)?ZQ(*8Y=;f44dO7|V}HXa>IJAS(Jj zIa?mA@b6ZPpa6$>r2F3XU3#cmeEZZ`b%0~DkYPOt&hk@4 zBJnPXbgZDDAb8-R-IudLuP`wpBs5+AzebM+gZUYX*f6M|5NiRo{Dt!c&Bdi{ZEb+H zJ_UZYXZxQTnfXnqnq+db-%)fO{oDrV=FV1!vayFm|rnj}v{3YYs>$>hMu303Iqna*QYA?y*P zn*n$8aC~$!Sw7v6v^|q)I3rojX22507r*bk(!_P^wI+&UJkZ; zRKi3#ir-Y9|yQ5pp)U&E8v^pFjIO$}M0* zcm}mteYS)<=8j0EPdugzw+ApV)^((#-Y%OL#65uA{o1|#2;91Skn=sNu$li9IVgE4 zw)|a#D7oFps3crY%>MLyXqyJK3vxo_vou9B1ih~1wQ@pf3b|U1S;k1%1FFibE2Ogd zOjIUSAZ@3eqoKg@+Z1|#W(R`x@s_LgA^4IlhT^K}gW%&M zqrdMHxQX)(C_V?ke;;+Aq=Zxj72>~dru)bLy)V`(3XDouzBNd*aLa37$D!47W1VT$ zYT+df{%APEb%eXwjC(BVqUjcA*OZ)UlRoxFQo3n~XI$q3b}HS~))FwG1FF}_?}jz8 zx(=+o#uCgsorYwxt#3pW$+8{@D%iML^G-M*o$)5HCDl$?bQb7OZMD4+ zP>CAO4UHb&1%)i&3a#3J>~xwDiw2=#=JO37oNnhu*Vd{mO#wf-ndx08U20SbGvJ1D zPoz!n9D#x0@rXN5ZQyBP{fW@TeDzRJKY36mXwy~0EvP?^NhWKDd#OTGDy}!k)s^kZ z^z!EOKa=7Z*+;IE-B*BWVzDsGw#Invg;xiRWi4z*9S!AC!ug=cU?RQ!0Eu760=d8< zB@PQbr30|K!*ke2uShk_CmDF$qBitkZ4nd!`V6wSpv)67Y?j*?2*mf*L|MOQ#>!Tn zJh#J6Iol>hk%H&a_(baaH?gvk2DR*_S3#PC=MpvjT)KrCFW|>2Gy(?~-l_+$k1uTn z9Z+v3J^b%}mm)t<$=XiT6nPMbWb;T=(n=~8Vs#f`ryh0I-yrX&cTKl0EC?o&ShQGQ z>`Na(Vg!9feD}W9Hb0d!i>dsrN9=k2;07Rm?TIDL)o$;2HEEnW*J+{`{ymCVyVeUN z#^%?_ef1a&JTcs`4@O7Dc~|YKr!fytjPDYy4ace-7_Q$!dytw6wwmfqJn{b4aAk`t z5fp!F2-5^$J$5F-I$fu)g&UeZ$-@zxax*N?T-|U#KZ{>iC2nbSCoISK?t~6WS;jZW zQAW6MakD=+Fk3!b26~h#It6&V=Dn0orw`Y}GC!RncO#T6Nm=QqEOGchI+&G^djH;y z_20hx|GnG#e#iVKHB-YrVj~*<+MW&?c%N*oTCjLtgn|vnH2CGF7TX{-e}AlQu0Ze= z^plRvns(5Vgc9hX*;|fT5$mlQ$~|iYzgsa?<$A9fF%x9kYF{rHP9-in;`tf z6-|i~fzC;Y@%qYJ=YH!elDk1&grFyk>eW;Y1YY0H5uI`pOih62E$Xb{T8##C_`P`M z6%g}^ra*&4Po+5iv9G(`yPO?$jWMkfR~~!KFTQ;zbh#iyw9N7SEWu7anBq0j7=#=fwp5D}Zg64hFCvQNj~F%zs6? z*WbjNO40yWT&+0A>3*#w~B9R-5W!hCJ! zhRqr~17|A)6Hx(AF!0i)R8+8G(n5UT-7_QWAJ)Y-1C`U7+aX3$XB2ko^?tbyJl{hK zXCKCZ=T$}%$?93Wgs@IGfPZULvJ>W>$4|`pwcx`@GemF-dDQ7u26vxJ%zm1KbIC6g z>LNvkq1g#i&#kPb6KUFA77&|(tO-7erjMV@q@$GiLI2D7N&c(YyyX$=PS~$7I+`^4 z`7@BgR!`gh0PWJ+ZhQ0!k-m-`tg%^)b5Aq|BF)VR>HIM?5#`2=-JsMTmP}dtn=D01 z>FxQN|EV3rS4STPg>6Q48jN+oznLX;z^AEdkp6Z%QrTwL6mmq61$x-;D`!Ee31gYzXK+ z{3H9_dVFGHvygC+YYJ;h7LQYQ4$*CbM{xwoZtq2hR3-?$lq#^cbgs4R-Ao zPvE4vi>a4;FzL|6W(>vfa(cK<3A>VEF{H{|=*+*3lTrl*00V}%f8KzFaQkK6{L5X- z?;0+*9DohidB~f+{XlduuvOF|Wo8v-zL8}#Zm^v_*aE81G$!V7uNV`?(PAQMmou-E z5pyx=jTF@d0d*`QC`i|S1N>+HYL8u=m^`YnJz(PZ%FF5eqD5vW_m!BOT(OZGI2xUG=vx%fZ?X>L8jS-( z0K+7E*GXk1Q>~o8Pft`jf9T)aPygS(`+s^XT=ghAbv-9%r{Hdet55TO_>vk~8%kEr z1B>1{0emXZg7T<(Yv5ISMNq17C*CvwBZKqYZUMmI+&cq=DAX7{$wNQ|hl=YRAX;%S zoL52F6>mcQ9?BEdS`BQ6tWq30ozyaz8?QW;4C_s3L23spYt&1d#KOa%Kfrhv(Q;RtwN$;Ikzt%hA7q0D zVxKUq?oL7UDMU76OYPX(Ur6NMMwf+(p$@oEOHzRpl1JP+oYb_lQKut^e5^t zo=RsUX;-0-sC&-~Xb;6|gbzJ9X}(`(70=xD&?Z#FWlsb&aygw@QMOGqVi%ftxqo%? z*k-d`v59F}S;#@+h1LSM7SM#n%WPaB;D4qq$Cwo{pow%t)D-e6;wGvg4^?^>UqwA$ z^)r5Pq1cCjy~wTyovt}&hGYGJ3rgqvbuq@MR~%C`pT-QdGtE!-cLz@vQ<^LT+nF0*Nc&q+Z~ zzeAZ6mUj z%g)|5)b}eT61n~7zb8RTFMbaEpW2Y8RAVKWVlY)ekWPpM8CwTz%QGmTJFJceq zyx_YBTmpT$C`JKN{fs@=CbGmcZ#sJMdlX+>v%M2;j+y)c>P0lmDi!(H8gcm*QuyAS z&iaok@AXdPphwq@6U>=fRXc!K1yTF_WItWd8HmGacSM^{&GC2F-Nb-Q)K>b}uY zB!oq^8Eh!Kyu0U=&ks!H3~AKFU(~pMP){Y?0-Q=|B{804$?e0bjK#~dz&(P7Xc>5q zoPq77A@;h%9luzbk)LbcKWLPiz`d}1td*Qc#X^s^GW$>COJ1jE5LreBw{R7H%_8Lo z8I)pEV@e@>GOJtaZ&_fHI{K(L9^E$xyzy=7r~r@oZR3~2q;YU$;JXnbIp_l_EGS#u zM+#eP-@}u+q}hhXICo<>bG`x^vSxmEuEVLw8EK@F=o&Qt2=$F zJ%||_y^q&~yb8Iw&{!j3UMalIn?=bvIq`|Ipo{2vud^u1jmmD8laY_WQg<_*<)oy9 zh$;IA*dlHGZ+lGsy)S;L6)C;FdLRw?`kbjP99h9uhw|#d@lntMK-Wo8)(vt*;$DVA zFVLH@=~LrQaAr)AY^8CeIQct32D#cUz6PxyOh7q*7`jNZyswYL7<=(%vOwRpyITv0 z4TP7GXiCP1MD+^6oIVYDJk|N*B2o4}A1N&HYr(*gXF>0{n-)oH3yRsc5```KFcpV` z&|0pndV5mv!XI;E9iFXK_C{<*q(f^#yjOmYr8(e(tj6nVv>+? zk#%KF=$-lRfP00Yj!x^@4_s;XEE=cRi^qx{({6Pue4f6#U}r`-{20!zijgqA5qtw4 z+lfi+!ZSXcqY1yV#0-7as>`rqPz5zx{T;bI+~+$5p?psSODcGtAt6cy8um=80Cz}& z(kW#;(vfUcvnz&ls$ED4mBs>i;8j z{F|*B{#&Dn|MYKAf_M2Fv~KGT0XO`Co1MgZB1U z(14Q(^cah0I4USk+s7(Uc8c*)gLWl;7y&k3g&^+MJ$?*aj-mt)ku{P=z$@{e4}zrm zi3~JP#&^nly@p(#VJER26b8T26|I{fCpMSNw1JA&!mJ&D z{?>J3`{CKO*t_lo*9lhZ3WKu|#j%_bg1XjgHE5*1t8QwoH1UWc*c=_+Tj>$U?KqNa zLXmLmEvQlk*z=R}B6Gl2d<18|c2?+PiYk;n;!HqQ7w9!~*R6%Oua+Teo^4~%z8(fS zY>gI+$!V881bmt#24rq{v@+VCfHhkTz~4qp--)pD_d|?-wZ$iMM4VH9V?mHVONmFj zCk|8m!Hqs<-``u7UAg@3DkwV^f|5V7Mmi&imQAhqb;2#aYwJk-n*`IvTVZf)^7(z0 z*_O+XUk`7EUcc_fK9U~?ikTHZVD1p8*N`|-iMSa%kSSRH)gLryCqq%< zTtVPP={@e8I53h8w!`vLU#`cep{Wl8sqnOJa_C*JMeUAcbdhw4vUa86b6g zN?*1LxR}->u4C!OcfuV$pe8#o2v$nugt@ye&;MLT7N}GZ_rQlsfPn&yOL*yhMuGiS zRYB&-A})Jir~D1t)gddusp%GnXdE;>atr$~BlKfdeBvhCVfheJ*j{O24QoWjO_b-U zCuu>|Yon$VWAMS(@&E=Qt^kU?)n8BmoQB3)pPJZ3vv|Tjw|1!mcIvtyhp7jehb_I-$%T40 z{l}xUz%mo$ZYZ9gF*i`wzEzZLYq{_li4qTslU@_ta22aVm10lgQ-FEHX>DWW8?aL> zRUo|Am~6(thfQTtID8BGN7(7o$1SMx@C;6RtyvN$Fodl&_3$%W(h2juSJ#ljP9b+% zi@m0T%Op&;O96`|To%`~3x)-&A2^J4Z3Qbvf zT)dw?{bYIPF?p~6wacMDefrv>*12ua7%+q@O>kR>-Jes9;ls&1$z3HSOH+#e-$czy z|D*Hu4G;fdokD+p3hwXg*eyX|J8r#;4o z{^-P3eyBchM?^?^#XdKnBUFL|X>|2o3NRD$S!uyeaa`9fkvx8#0kKUw@2WS}36JxW zPLdF^k!paclmP;JoC)eMYj4>|{7=j?FXnSh_~H-pRRiqQZw(V_kV{CWx4uAp^J z7UZF#PRp@;U?S1^5kFbwNFB@tJaE&`E_!jae^ahb;;3Y7epD^#l#qf zWzX>j&q^)e_oM^7_``4hkZ^l4L>Hu|!%3?x{mk!CBKcf1693*zfYKh8fP46Kb(_U? zGt-cHT==>N$&!!1bDca#Sp!i{)~&6A%2cH<>rcnscwnC&w~|=}6r|F3fYghx@?2^p zCtS5GeM)C_5o3f+z|4%9!Q3pjGyxJ2=@>{f{YU7SAb&t+Uc>|{eL5B~<*0@$c@>-V zBnZTXw4LArVtmIC5GAI=2n(Ip*J8ei62zYMR`;8$ooqqDL(1G~%LqoCV(rVIga0*H zfbaeX#|w>bb64t-!a#>a@G;!BLYLn(TEQAkp}^g5lxd}wuf^*v; z(1euD444Qt2pvYc&CF(q%ns+9XqoNUaywsQ9x5Am?SK>H2@^t>e4IC_N#0b49Bk8X zYaeD?0i#Pn(bls6&EC65HFfUm!n7V_Em7&JB64p_6%`d@K&}C;6P~mBY5h{gFhwGX^YdYYmvkm{~6>OFc@uY-S8G<=|?LCCeW>g0Ym7x;l;FD zorUJ|8yf4;LIE|3R{j$Ep7F09C%GkJ`uwGyFVH(5zXm;vez}PKJ!PIXKcgux`zc!h ze~+|SfImuhSP;CEUwm<;Mo>ut)$tg|MpbmAmKJm{`A_ti|AqN@%_{G1ip1#9)Zg#I z#w%=33G2{RiF=*;e*NZ5j6MAG8AD0=^OF{js97K79&uY!DwBaBe$}GoDd9R|{Z7t3 z`CHtJK35#I3@`l>U*si7Try z{?4NO!gFmkVRjDsfLTDMbE#+3aLRY+(_g|}^!m?1z;LGf$F<}i9PWfd;&T=vd+u}U z#h!avh_DCK+;?F{F7w5huu%K6=WA_2jH?xr`FzZ}eSCbyu1_wW+$>`DG&RgXe($Eb zt}RmD(yp^j7y26Iz%Y+lgT(#CiWvrbzF|i2(YI@SwY5MQP+INzk9CO^|J)M9pPv1X zy2Q*jzjsd-M(hNDTJ($soSY&%c?#lX&ja8=!XamB)~j7Ncem#MvSs&zH{in5GCN{H zry;!oX$N?b3;d!6KY_ZZ9Jj+Zju^5nFM0nPA^Cs&;vvwV_Qn5C9Q?i44sTw!Zk>XF%Dr-Io8IdX^3el@fZufWdEkAo zUb64}5qZWzAAFybedg(}Z0ZGk)?$xn=xEy8m$%n)BO?$dMqTfE4)3_K|FWMmzO*?y z+$$~2u4?1jH4!tDMH`y$?At+Y=sIO@Q_{@H^IM&^vMMqDSUULQ4J)2nYj{|+ROyHF z_Dsp^-+k6W3pLP?9qm-}HnkE6txCrpTpvFV=F5g?*eCMDkHLxqx-lnvLkq-XV z=BmU6FK`|A)^c}(>&OMyaR6M$O>iA0;5xob3o`)MaTr|3U2q-W?pS$|@*8Q)e_U>UtMVn_~Z1ChH}_{1|Yoqe+?t~pSIHe-@*U%eSH!o2{ppUG9B`oP8#}-{P15QZl(*fPGi5=#vb}|9ymbZ`z;Tl7Gd6V=k`% z)zt})*4EatzR0?%v0-4Iyumb!W%Soixxd1&M-|{%DI`?_KADKv*tutGek|-dbpjZm z7GS`C*|#I8zbh2Heb;ix6udS4``bT$TmNSt4%J%h0prbDPf1KLL0yKaY*_eOZ-qTa zGZZEUL!XMPW-GHDS0BoDtgk6)c>s)eIxyacZu-?HmNbL6&HwiH-+r40K78}T53knJ zEymsb2grcd(_jDe=KR}A>^a{X3lFE4`FdaGz5lDxn)i&|CrwMID;f!NGYpm@-(I|H z`x^24cfe7-zP-U3xU=oRZ@v3vEqJ^3?{EM3Eil&q_+f;Wa#G`)huCxT^QO|o2S)TO z#_gUj!!Ob(OK;M+s%h*y@IAjSX->}tU;BH%`jXNX@U|uW7WkguJO2IcAHNMR{j(2i zy+?LBYi{lbuaHxJ78mUQ|H;8v+Su>0RO~%B^r=0aeR;^hKyBk~8$HXjktmcyXQb%) z$mzi&fujGY1`~Z^H<*RmG6Vh4}ZWh1q$eHd15Mjgq@l0NO0 z8cy7kGTkRDUVeK}I5q&5$N#5sN&E0t`<=%(v47Z~hAu)m<;X_g{5_31bWr-o`Tyf% zckzc)e}Y6<`_Dc2itxK}c=^gGW4OAq)+I%e`U^daFOHmB5W+9?5!m;y;}8r*HsZVa z^F}0!E?;t;a%Z{+G?qAr5!jK1SL7ah{xynlU0-Ms%2pp4v1R)yZD|$#)hjZke)c}O z2m7rj4Cm7@-95$&&pa4e6TO`}(hK3@@hozWG6#$3#EY5{oh~t7LS2}#z1nd#4)$a_ zhP;d1lhE^m>=3>0`tgp#>k!D7G|Vv2l}0d>?tuu#bRAgs2}%R_%Y3e(xF^}gt9sTs z2&JPnRPsaZEuYubQ9=T5N;|jwiLS8a>ARy|=|+KalZLUSp}j=T_0?<-{K@H@t_E#& zz`g8=pu&;dfv=n9YvP^U#1>4+cb?o4*kd1*;sRYfP(;C8VAutdFu^fLU}rh$N@$S_ zb~hxHt$qgM)VN^yu^H|Q2;K*7f8vwm`F~Sf{QIrHcr*1+S|pBt*17oaJNUKi_r3fc zZIrR`?85wL$7;P%LPemyMa7wA1oGj9&miGS1L$y{TU{oFqK^wk7UF^#R=tQ$KZhK< zuAPe>Y-~C-Hm}}dY*Pw;_M_3fgRd} zbI#lO#t<3hr``((I4>TDW4)OhSW)u{p~B1^{-MIwxJnWQtLjXIyD9=mVhV-)(Gi!7 zX=iQ(-OmuI5d4hs;p)V*VRcNm0UF@F35ZVrnEThRmKj--g9W&^tz@}?gWnIs>NYt7 zbQ#QBJ@9u!%vNnT)8&cm=YJ;ms9O6)V<~9*L&?C;qKtjXLrR4lNX7VIJ_|F_{-x5a0}E3x z<&TSBuMBJ(;4p0e?1s-SORy1JJ@~!0 z2xRqEnOQ5R&342q(#{uKESm5MJNKNBX?LHn1>#I``G%Iga+LB#Q_eG(FLAYleEPE$ zOBhxT9@$YG!)>gxW~D`JH*w!&+i=_b$HlmwyZm z4b2pzi{dLcn7D4TEx#@4HG!qKcJvj+X` zex9R*@fAR8H{z7a_CHEO{k|0}^Z#YHf$qImTb_^1{}1Di`hr^CHp5rHmQ3Vt#ZBo} zRSuUP3pUt(Mq4l;910nPON*buDqngVC8zZp+r)(va6k>@4_}}2V8dn9t7OCJ+^ZC- zFxku#NeVkGON2f6CZub5e1oF`iZbjOH$#qKE5VGPR{qqvu z_P9+A?;4%0&s`sp#5bIZgJF`;OgQ(VAGxvb2IkphxJqT`Cto&vV{E~DR2)-anxym^ zeAyUaKCzz}R9%eq)Iaf=HtI4IHQlj?Kf)A0)1E*>Yz#IkzPUZ274GwFllDk;Im39} zrD}_gQu6rLHyp2hk~=4N!2R#t8FEI=)O#jEYBwNif#J8wtIY)Ym5QCNH$<};m(hr0 z+!0Jk^mp4Rc6uY2_)YW5=Lh62*9&}w!|gB9rC-ml!BDSKuqzWW@CAQ*H@nVD+taR&nVCL7P0Pi`WxlNdkk@KQf<$Q5nJhK*o*n*C?<6DB7EEN%tn zipp5HZr@Phi%ZGeXkrt7aWP=}ka~u0_Y6*8T{s1wQ~K?zT%`6ND`iVj;Zry87?E_@ zL~N_|z-!jyZ+~|`fBf(-?1lEaV{Ah2;hVQ2j_J*{z>`J#2xL`}9dyv6;zeb-1FgCf z<<0h}bCQ&23#bv{306?jT9HR#0o$2z*AjxUU%Pw|s z?fd%N6)O;(Z;~V6$&~pQARrNb11G=lcM28GH4+ev(l8=?y=o;IqH}(zJnh7Vg|aY( zhfC!2bdVYH=2Y@ILIiMYGY`EH$oX(H?lpGkJ+Z(kZ$HxESoJL8FB^Vw_hT0DbrV-6Z)T6@~c=I@i@4@I%PjxpB(EC*GG(HkxvS;N#2zvIA^FarVT_fa{Myl zSOyp>c~D}LHrpTlPQSRtgI_0Wh1HM8tz={i7)ul z##*!Xsu~0h8Mklct@<0d@Fs^OtW9u&A@O7*ZJ2_YO-vbo5ygfv?q=61Sx_S+@E&zk zjP&P{anYMu3wI~ZBP|fe-gN=LVwlHmxEpoO)dCWln`csZoAh82` zrDNB}^ZEyc9qvJXUff`ss6)h3wvlmty?As{aTfO~joenXgmtpMMzKs7l@(LDl?_zu z#rj$-YvdY3TE()9O8E?mY#h@Efn{6rpsmdfOv)ja;iJ^t#t>P=`Z=>ctJws6i@0A@ zAG1p!dr*qs4aLWa2-*EQfwqU1=xhmfDC&N1Z#n}RH|;>{El@mFtGc*scvw=4f?Sj@>{{Rj(kQj5woi)UbxhTv ziwc61ytyKo7l8a9T%UjVA20;byZDMU`*yFy1rbt2uliY)>+YU{;)_)F7 zmq%|xiMG#FMasi}A@e1GDF zEr9G#s6?ANS1Ia+XW0mbtjU=*4Nmh?-YGSa5)4wTp}{AUItb}>HgI&o%kK__LUWl) zCfBuU2UO;0F1e)rGMoF#zdgy|w6DUf9W~T2Ag>eTx7o4=XrdVqtOab`2;5iJrTo%b zY2rbSj(A#$vrWQA6(M|IMnwvTj}L4k6U#*=QQCwdw}d9y>C^U55%+~i;8Y0@JE%uJ zTj4WS)N2&RQUV)JU~{ii*rXBWjjI>2F700Bqc(t7A`f%ZT0LCX?6f`E=9J8iDwRQT}_ZAk5SdAfjR3ls8#~=Cw-jrF_R~qV%}no2=N0# z-%-r@G*1bMQkOQ{5NTX}y#*E@n?e3#(NTCs`HLhTm? zI5}@XhQ*7Rz26|s(N}Vr?7nELkGF5zgsM&2!@X0Mnq=d2W5UliR~E-@>PId{NN@;g zM`?ZK4F`$ADq;O5*=p9z{j^YFuf9BGoaSGkHtitRJ9E9=sb(3JJIzeJAO_@Rs}>VjXbLtf6AN~fq#Ne2+p+&gKZ z(0=}Dc(tuD0_o33U`q-F%~Urh$jt~B+_?fgKVMPk(%K9Q3wa!>G1(qcRSS*L!aF)( zOC;U4gK`_>300_@!M@(_>prq9bW2cW!sq==k@#RLW^#g)@+HI(3rvxEo*-+?Ya||k z?>D0OKdq?~qKjlN&%v?gCL~I0Rw=o3RpyzgR{?2#1ZLIia8WHkKexTF$MqH)3H}A} znE7X`LRUKa1@0=|`jX5MRMW@nD=!foIJ93DdNFx2H49eTE_-7s{8u#eM+gTGH zg|1v{kpfGrvC3&JbGXUwgnucC5FVQ=Llv+C9#G1n4w${XbM+GWtjGE}%Y!e1zLh)~ zBLC%3&_2mS1ew*B#(@d^bZdeQlrsM_@VT*0vTC*TW@S31LEyXzTDc%}oPA!JhZee5 zwkeW$FOmuG*A$~nkTjG8MD$DFxTG2OQt_Yjj_)HH|7g^}ymI6m2=LpGv-hV{&*-3q zuiE;XYcR4>1d`ZSMs`p*$*H92hp7|iZ))FsUz%U4Do@iapFhZHO&#YpQ38Iik9XI|2J4pHayxFJv%DAyhF!g;etbGvq`R&_sD< zE;Nfqi&W%N%T>9bYbUX#i16so1ebSrb7+dQ1bG!{?TAqwk`|x6b?lON6@GijKATmU zYqb3?%aezGzT|9ac24LL%@~r@3mj8Y;MG&xYCL}l{AnvrufzIt{}Rc!P4RBjHiPI| zXmm1b^U-qMqg-@)jWITjhl9Um|_E6 zF^i@dX+5Y}A6xwo8`$3Pzg}4Wtv35!NlkorRjg7-Dgo=0kEu&?O+YURQ_nv_*};^AFKQ9iZ|uma!La9@=HZc^HF6L1HX z`0Ax#UqGt|F3;Q6-u}e$YwP9L2ufaL?f3ywTx2ADX_4)k0&@Wkz7Z{9?ByLXg9ado z%6j?)?ZoX5`8gKH<1}S>FK?nsL<~AVP+JvI<;PYk^(zw(%-_)?lxPUH{4}(gt~2#% zW=SPWlVPs@gxQ<4`x>sYLRUdug5;H&1?bnstmSHT{M^ir>i+LE&kEr9>CrWlXjsl7 z1dRlDbHH6&o7C4gxPd0S>Y02umr70HsN*{i+j*;n0%e|obCG!hARzJvWaq@iRI1bN z-H$#)f3RQ|*J)(UE#Ck#x(4xOzzvT`Z`&T+Sidyvsu~t+h)j9mHmW8-ZZ^&2MJ;zn zd9&vIS>76Mq0$sZHL)&f>`q!{TL93f-~qnYlPEw!1QwD9D1_(RJX0PmvDJ&_tRaK| zup6_TQv`}ZU%rgW@xAouaGrqy_fEZiWD_D;wD}ok4*> z&}q>j3c4?5qfm@((7{P+8?B|O=Mu2aV)GWbhfo(PyfMms29s6--_!~Rz>odqjEleA zcW_>hyophS7TLbOMwx~<$q7oBEPuD(pcQWAiqJie6tY$i^lrX5$5r~2jO2RnCK1oYpNEf2<<7-l7g9*(Ra+6;e{@Tr+ae9lgG|&;l1{PG zF$aNnsu?%fGv6F(E*A()kj3hG4!Lw;l9aVDavHWt!V1nTdQIotKS9Gglj=hWU_dw7 z1`2wmJ`E=?@|Zz1=?_XLA1;@@`n~d?(9`?dqP=xPoGS|<)cg8}MQ*dqIB;^)a(Ap= zVL8P;cFZef28G5Qaei#YZ)i<4!_EE8fU zRu~8g%YAEtNt8Qo)1JvmlaI&_F^8zqf_Q!+Y-5QnDCp0x1f*XoF;LXXRiL5y5HqPk z4(m%us53j?R5p~bf9mqb+Jc*{FGFAtI!DkpIEqV&jovJo7&J8wEy4w>EKU@sIFX7) zrO5;Id#pf%x$23YRLogz{zzG>2ik>eu80@r1|=AD^{hynjXy}oZ0JZ2tVaLUiXvM- z_rA6W3x(qcd2imo$S;k7=9ifUkCaWxga;2T_|oo656N3!Z7=0@C~+B4FHqN!-IG63 zsBS6`cGxUgp|?JtrQe_S*654vyS1_R-I}Uc?YB_+$mNUxzMr;~NkBH^f)!ZSy*&xv zx18BYjLJj1heBhe&ipaKO+{Ph#H_Mp`I|vs*=cI}yrd0VQa0o=81LICNxC0GoqzUq zqvX^BCt^^VY$S;g(R2~?Er{rk|Egu;sf#T}X<0vi+niRekt5+vk^%#Q>*CMP1L2Jd z4e>HH0YM4F0}Yv@19(XTFAYd8pi>6JZOSVAtIebgZIb-yr4>AlMhQA4etDkH z7NL~yn|T`A#zIoC#)YWG!OP{^#vc}XHO^<9IDvRPyFg3JJxrnpZ~JPs+Mc0wt?;QCC>YjPUZ{2CKn;LTnJ+; zC+&d3Lzyq&1GP@-lO&6rNH{=bDin*$D_xdq=o8|iyvU?|eT_4%#^5)=xGv_$*98(r zEX*C2YO1TNmqvkRLs?y2T@n{Rh@BIq>6z0tq8s(MFZ{j#_lM;KFgW{DsyR>fW5<}0 z=daie4Gp{~^+J>yUXV@x5S)^HDd#&q)SOowVCgm?^bZUK5{6co|HxJJ0j72_u#pZ+ zixp7>Ock(f8m|YezSCxW^BOy4*YJnmQqmY9(p+*-*iU3TmGR1bH7A-*VAcKR z;}>Vx=9NIiCQ5m{uV%a`FB7Mfx~io-;O>W|)P`?hlJytZ7Jc92K&AtnxE$PLL^}5a zC!jS7vX1DF0fOteK&_II}g7-;_Z_~;(gyX#Op!u?z zWl_xgp(Zg92zsHLGz~a_kFL?vW8D^ga66w~g*3Zp2y8*Na}WbM3|wo9P;uVIh^OH# z_hHHKtc1BR4MT5$7Ze(LR};N}n=#K1HkZjXu2DmEo9QwLI7j2=dA_b;3!3j=ds`D1 zePv@EyN?1kes!%d7ZdanD*=(BSfhKek9+2?5!Iph5C~IbAxQ*3y+kF;-!6;TaDg!= zogzOYkMo`~F-+d>J)j+I31TNal`;xa*G`)W%l15rfV}MLsY@q6*~-`#z5C!lla-=puP6 z-M3+w`H3B*du={oA5?TZL_a(>Zq&<(U!Js?|E7$?&Pm#7?o^g)RKB+PLf7;8+6aSW zwZB6r8T9EYk~v--23=Wa_9KuDpoqMAkBAUxQ@;bMEel{!se$^)42lxRTOp6rCrocz zHm=%u*RAO*9;kUxg8P(Jfv&z)qt zLPtzZ(kb!Tq92Y`^z#LXhtAN#bg(J1-8>^JBE&s~*t<@f@a~3-`##pa@WB9sxvxV0 zwrzZvXTh|qTbKm(>W33vu=Aek%}`^sv75k9l=_WywExd7XF%nw65iy>v!}s=i=w&{;;wZ;!iY8yrpTy(x0Xw#U#I zcwCeD>uuwK2C(AbwqX2e&;ZtvPDF4vTn_`S4QgrX1W~GMBhn74Pi4mOr$K3!U?a;j zo=WvQ9^KbGG?<#?9bNEnq@Fo4_@i5t%Mn}m*@x(OG_}w?i7+i`JbR$=%3{TS&@JH> z%8lOL(vjMBFPv~GE<2Q|`}FR2F@okZpeEqKc(FXy<|sQf-Cwm8I^qzp%X0@aVYPi^ zuePjX&srhVjWHC|-5N5DUA;bkf=e1%I$@AMS-o=B1RIazOtN-K&P5-fq9dIGcKM=i z2%6uzOUY?ZM$Ho28rX47qG>Z5t<G`#;xv`W1d5b$4N5A^W@lq!=h@CRZcP!vo~o zRGEW%nWdJApJ%aHZF5Q~w^BoOTvPeG>(Mags+)em4k`v^J_7?pwO*n6owW)C)jJn; zupbzgr$;qT;Gt>{ITok^GWNT|~g^I0|QgUchjQ zxR$OVo;Hy(K=w6l4p*LJ6lwVQW9McxZmfr!vS<23K-VJvxAOqKr%aqP0&yWgrk>P3 zMQsRyNq8XY_*;6W@xJnU-DV{%5|~YKX)zTTNFX#-rqQGj5jYxaj5ZU1aQlKQcNF+T zy0|t4^b$sH9142r@_};;L4Z>8)4x858GV5^Gk1&~tC?2k021nQkS-OHB!C}UtdW1j zLayMe9G{9;j16Px%%a|!3+R3}!{GAPliv%0$Yceeb|$%~@>ue`2oVRcfqb}r0sxVF z`q%f9+GiG~2k-$h*WJ6n({N2Y3pLiS6ARInr66i)oicZbx%ZnDbcx)P$5R0Bv{(6sy7+AXqm)R!H&*1DX#PjGMvzC?}=N z6}p=fC(-8Rsv$kqZSk~9vA-E8+8sqL4qbLPw+?PplC^_XpVI(ib{UWfD z0OqD7e@^b{=@GA^Ej~l2mzO3gNqvNOS1rpl8d&>U7X$t=9yv8&nWUw4i01?}!o)zV z&wgo9m={G~kO+b`5+pSlZ^3`5RB1LESY4oASFS-{0`V3x5V`rPil|lBF*#qF2%=eThv+&69SLoWIHy6gD%{;AG3rUSbaRA!iZL|b=_1+W*#d|Ay7){*fZ8B?bN4e3tRpd~;*L@kt2c@cQhY;iiQ6|2 zyi2$L@}*&@y=bGulqjYGH#&sYMdXTICCLOhABUx5rULOX5VTou%s!?1(9$N}Gy(BQ z+9!Gj-weC|BPC}+NGM=A%MO`Mf#fja3u4jJbYIfy7WmCcX7Qs)Ed61W)4YnztRH&w ziuWAb_|PXMkcc1KgkWHt`Q`LdXTcN&W@omZc3D{k4ELtxzeTS z1{>X6bPB=^0>WtADkyaSAm`bYek(I4D_LLgj1Xkg)|E^PfL3YaM^sl`K|`_O+n~)p zJ4qBb;~96vVf+89wH#B-$O@eX+Z?{ zd%5`V2DGuYseKLym#R_jF-6v_4&E$VTGe{V)IAh5ZbTKjnKEI^CqF0i4|0wb&jwT^ zL``e2%^Y!9nvABFo1O%nk`89S?^t>F=#)pL2Dm`BXY7<&Qj-ofau$1$AdPNImEM$q zKEfwH{WRC~LCYA|vndgd(8b}B5Y_SBdd2|9Fmn9hV2sw>Cf^jaz{n17}+c?;wt6Ec(s4Dp;qa!FS{w3GtzK#=w=0(_)k(I_N}1i}4j4YpQY%=*0QURI1EEF5 z#pN1;eoZ~)FtCdr@qpa^0zNKK{Y(nj+$+7oPhDz!v+-bbtg+5olM)x8G}J)O`9N7v z>yQSicH}2bC~dy?E3OkPPp1zqXzEGtCK(3Tm>@YSUslL`fgF`Vrt1D5 zhT_xeIG~H?3&aFipOCG;S@^j+mZFXa zl#EF(kyU=YDt?gvK!Si^9gfFhu{@SCNBw9iOg%}{cNJ+^>3ISHac-u%alS$Q~CpFyO3{vXPLt2F7x zUJwxQf|V`7$AMmuJMoF8x~Or4E`(ANvbhYnMNXt=MQu`#g99WD6WcJ zW}jQ|MBvVfJ87!ec@2bIoAjHN1sV>X&LX0*p56aA`;?0xTx)F{_mIk+fPR{ zwCCGD))Z@jV8d_KOvmKK_pfLIpk>4>NEkc`9!;PV@yHJ(GgZm**$B02na3$`DhC7_ zqW;0H#rdKR?q;_l+sig9*B7OYQ(;DQ$BQh`CEQ3KXZP2&VU{YPm{)q zloHL}Faw23jmIi6!%sdG1$aQbGH;$aP_6Q=7NjUVt0bU^5tuL#$CV4{EY;GY66Bw? z8fJI!41k!hQFDc7G~N^xT;pLbwITW+Q%;aYKlt(7U_DbKdRqn*qUx54)pInT34SGs zYiMLtrjcJg4!o^p8K7_XWz86~j)H2avP+Rd{{hXhAHEL8s_56a4Iq&(tEH1%8i!JK z%YeYc`(m?ZENe4?}Aa`)omujfj#_#d4!am=q~*zwcd3Xkr-TVm_~VDZ?%&gbn(>`vlmthD)&bMh-59s`Z0O=r+mt9r zd&DQmJcW!6XIqe3;kxdAedM~z=T+g41`Loo!Zfrn^D<~>QQh~ncu45=Pv5;nBEH)} z{KC9_rG(3Vv@~>L&@ftQBh9bTlyB$9jYFa2_D5txCo7WZIgh1u=#LPAx{ zVODfF9-iCdhmbz^2>M1gU0{qZlI(4&9oIhA(>gXwJ`3Q*-3+;O+tJs>rD4#9)4>mJ zuM$r#_Q(|rjq`%*l*57UvyH1bc@kB(0|Lu=0)S(UB?gb4b1Du+9NxzbnY&Y8@X*Jl zalvku$vLz#;Tc?5PXCb-p}v&=5!JWQlK-Y>4O8fue^cZc6(Bne2RJT%4WAo#*a(6a zYVdRc7rX3JfIo1olbQBvy2JKKcAuSh;fTHV_8)7rxmW9CBIe{96+ll^qc$S_lR(4X zl{t$+U?+NTgY{>_w?l*V?-~I=ZwQR?p65krTbegoSSW22s+3WJ)ZTPb!qT}okBETB3u z5auVw1iCbORCJ(SK;wKxttIyG#I_obuI)AFY?%lWZt6fC8hRYI8){@KVDed_tNG<_ z0#bIPPbbE_WK4A+Z#UArnL{=!2cz4{oIPpfr_*v$GM|zC)(<3 zzQ#HW@cl@tq7Cgo+kif1-7Z>ED{uItg#{Mi=^`-7(ZWmn0dQlyH5^_a@WuV!X0XPS z?FQ3E+OixTV`%xxYi-u=~3%|xTJ)C^k1RnUlTyF#e z8)*YIK3-OmaZ7J`2*#=CWAOGm3%9Z)&iyIHv9?!Syc%EefKm?8cw zo5$uTcR@i+W1xDuwdl+m3@Re*}h{k(T8{1zfPj)S9O9xW9119TNJ=Ud9pcqJK zkJ5YEQ^l7C8zapF6d7b(`*Q+9nrgQ6RqKIVFr< z$Kbc0*@VqX-Fo<2&)SA_pFx@TtxF!}VNtH5lMM*QT~|KWf+1gggjUaz$9LBJ^d#Os z_@*OT(>M$LLco0}NyC8Py#QwxJoRRs#nd>Z-V_^ZKW zHJ%l(w%Om-50Fj1!DKz?(x1E8r!D&=J79R+xJ~*_wL(7lQ0UcY7qOq)L_WLY zQCwsV7TE0`el{LiF_?dn$63B7L~}<79UI{EkF-{;&%f|NRTb>Je_hnv#RG*s`W6{? zn#h#8WM}Be_=`SDsR%7R#0AYejJZG`3ASkPQ)zi|SmG_Y+zo{QLrzOQUp{(w1JjnAjw5Hnj$+MT7VIq-G2iEA-0NH z*~nvG)H-Bwg2>3vo4GLXNcs*C^Bkqhsvv1UvkZ%JG_+lnS?Rt<)LZ??ir@IDYyO*! z$c&yF0`no9?JCutj%%Zc-Bo+|Vv97a^@uS6jE!_R$W4|XDzIw+P9<2O1Yk0(=twK< zP&|TuITlVe;H;e!eH16h?|(pkY>CY{q0;QkSM(g<2??^Z2rTa`2_SF=f1q$Xcua%Y zj)AaNxSWqNL#j`Y+J!=Ae9gFJ)PVG~c#)YKssIlrLPXsM9Osn(qpWP((lT9FJo&;L zAYyv5vJq0rW`Pm1VOj1R7oX5LxaR?qfB%&`*3_`SpwgkD;6>c|*{YPvIGL%!q*_c?!97Lbe2 zM~g(m?N=C5lzA9wL%+dyyQ)#3 zJIC7i^2V>Uv=UZ-P%_lgiv0`758_yfK9fP+zG~u!6aYu z7QWYs7Y7>^9r{+G&}@G4PN%BRi4P|L#$GyxE@s1Gd`yLl>{5S1X&%4;r1((Tk~cqz zGVwX1Q?a(v3Gs_!2XAi!llE1a!`}O%>dJIiIet>L7GS_%S*x}{w_Fn=U}u+@AlTU^ zAUf~lw^z3?A$}f`mW-X&SKp+m-ORi$^(|0Mk(WuN77xYqufd>Ls*&G!ou{srhs`>K z^Ht4YFfh}(*m#hFCxLp)dj3*|y^LQ}^;er4wqXYsN`ZqC zVfM7b34=c=zOC#b<2MSfUoCU22XMLW!K1rb9JX^D?Lcw3OU}f8zSx;&%-YU9lMtB2 z!ek(4ozD$C%0G$v3As)xuOdp@a$4d3JB{R0vGU~OC8;uoxNmDOehYxZp9dOpVY9Z+ zT2D7lpFD0jzw5RI&kA7AcPs-yAmjE`gs+zcmDHa&I*1|O(oY5 z?j7*!oAl)7^=D6D?M1^B-4MiS*v4$2T|Fn69r8F%C!%dssE6<*77M+@Il zkA*_#cJi0wc|TGLP;#l*$Mh`2v5`5m|%aD@!eF3^?fHtB#Sw6^IMHcl${(hs28X9LdH zL~VwOL{f!wj$VQ{h?uUr`Kp$;e&;1EPYwFpPgSj0?(JIHosC=N>8F=L$1vU^s`M3& zn_z6*4Te*)nR@-fyGEFfvP(}ilWWLUJz+uBs#Q;x{*rzWy@$JBht|f2YGtEIf#xr; zqVBiOv>mmsA(tW!qi6zJ#5UQ@Jon9^P@v=_$fke5Q5$C;*H;ERN~g*4=Ry!r^sb$9 z*vGKuigG7rs}_m}>$XCWCU`000bu>H=*R{%Hyobqd$8)xAq0+&${Q51-4@+ZOS{H9>Iy@h~{`+Qq|8&y$ zj{#s4gN*uHp@r3!&_d8_bEIeiHdwtyGSNgi7Fv|hry^4-CsJp$(%s5IY}y8}(l)b= zt!vreDBfq0OGP43+2~?09FSn{yQ3}NFfb(^_IOeYaNi(I#|@N$K+DsP3I>9kF067q z2@uu}i9V4DK`+7R(n{k7fCzt4CCn;~LohD9_k#0#X_$^f3ea-T6`qDYy5*>IK2zi# z1nUc^H_pY!lx%vPLWNx7BZF|CLEUpA^q_|o%T+I~;wq(iLMl0dZV88vqvbN4%(_(h zPQ}r?)UC`szj{}q1&b(W`w{~>7ec7PB*@ON0T0xXHTfu3yr&> zZlaZ>97Q4FUY3Gk@qj0Y+U-z$kOL?;1!swFTcA1-SjjM$Ov-+2U#A zWdqSi75)e_q;=p!3s99n%2pf6-}}$0v@Gx3!0e6tHfzfpk4T;MXJhpP76C5kD@Qw;q5p#Wl|wcBq9c!oA%wl%q_>dg;J0~d!&@3rpa5`I8M+1o-stj}rd zSfM;{-4hGRQ|jk>z5to%Bu$7PX-cN~pm)o@>MaC5Q}$>SOs?!92B9v4QqG;GfHXal zDClV5%cO9Qepl;_;lMT&;_0`=C!?df-rREbb*fA79Q5^mjeKrqU+)4Xu{Q_@Yrohu z|H@?jExNWsC0ccoYjC<%TJ>0{hjqPOIi$KekEwnJ5BvMH!2aDryq?aHHrZ`cM<<$Z zMbc;H2F9CYf!rKr-v+JQH%F~Bg9ZQjmB0R0pAq zoKasuN!kfT_N06Zubu?6PETiEX@+`LATY>TBk~}6HeHMvz-%~Oc4G-=!kbFk_oUP4 zIc74bW;g#$2TBnN-M8I{1bLPO3&=QLHi?;XGRPk=&j!1_;DDEXnUuIdW>J)8v<#OV z83&Ry$4`;Qy1UoKMd}Fbun}-_jOZ5xL;C&^nSvQbi@Huz{rMFE`|A|55-vJW&H&m| zl2e#%EPA%*vv0f)9~ODEJ**ERqb+|f1Jh1Imb1KHUOwVeash4wvgHz=k{FX2` zsRA3~LZe;J7=>#Ia+4y2?K~2A`k5s*%Jid$qAgmYW>8EcOQA-zTO0Rs}T)XTUndQ>AN` zKa(As=uy^|rJ}tsTY}ZOfwvTbiY#UAtu>=)_h8RK&YQ%>1wPA~n~=9o7a6tFp|XT` z%6|N3csR^m8sBY$%?S6EYgi`;5qL`}uH~EZ+r+*+(wdCBy#^Lq+{%%ldp9~_4f>6b z-ab9Z-%=CLBknSI?&Cz$tC-}-GV>jiwm=KAQF$CXBxFn)oH5{@BIa!t)kqE<;Cf6ty#cPCbVWHp8 zvg6(zf*w;OYk{ZN0=o;|7bCDx0+pNnI9Oqf1!>kfI)>&Fa}w@n${v$@mfz=+ z4c}yA!ns)+ON~}IbOyg{ij-}bv|@F5OWvl=JYKaMEl-H+DU`2?@bnNCG<|V=w$4e~ z27H=y(=mVyD+gnI_GE6Jj=E(HLK-EQdqQ?te0Gg|mJSljHRbG3p*p^|0}O~i9Y!$9 zthl3Lx0S9$mJ%1?bDn;tDa2$rsx7t5{`l(O#?DxMoon=$$W5jl)QDkjW{z7DlEC?>jU-| zY-MI@S<)!S+(L!-JEgXD83Y|PB!Hs+QRE%RWv_wh!EXsUc zv2fyW^F)kmNb{rQk0zU@05Xv1+Zb+FZvmj?r}r&yANSn;lCbURyYD=~<++ksYVDl) z^rer6LB;bjiV1bxhY3GSH>M6Y-^2gqqCshvh?6mxtQy!}da?nv$>N1V%&Dr0i)a*_ z7yN*8R)>#vB07&nbKehMO{Pg-s+`dR1rZB)Mj@}_Z2#oz)+aMiAY_dR3>hGH+AOoWcX{WHr9w6IwwpF4mU#N_pg%nyswX1fg^)kDY z<6^L|vd&Ze9h@&C0-V0wlxc8xVV{>*S~&{Gv6D7n(M9vG10EikFw3OK){3z2XvT{k zu*Zg<(?&b`k@j62M{R%$x{ zrUZ(rZ`1Y3RvSHKms0&FYiG%qw7=VBUg|He8_4fie){QvbqC>xhYYF1Qde#AKx8u? zODmMMIkV`+x^#yQk*xUWOC72cmH<01ac7w=M4%J~wht9lf+#b(p+!sO*@)M8HK3d6 z8vdcBC?vAW_N142SwMNEf4nCgU`u9U_$-@X2Cd#&5cX0)6D8|#+xhbgfms6DC2wX> zg(1uz_R%}?3|`pZjxJI>Oin%-M{IPj5bbTUeRfqdHaVSEK7iCQE!|0;uv^u;>So;Z z3Xx&3_hG&~fYtLHP4z3NMjn4zN7qf_YjMl7xU9Tp+;l+2b1YnDC;LhanQ@hml7nX& z(`TI<8h6X0p3F(EQn(V9J+u~e3-4T>Fp9f?DiHAxEXQkwv{^AN_5X9w_b;9}>17NK zk}0?;5P;bDA+Wczk=6YvI5;0a^p^b522w4Lb4MOv?nltZu;{d_S^RwYiT~N&xrQ}y zW^p{~1FVRpTUG(((ygtxsu&{;mq3(Osfwav5G|0F=uqxR41t6Y!3%9w5HFYr1fdEj zZD@-zM!A?4P;^MZiYW*I7m^VoM2Hv?knF*(eY)B&yI<-)%Y4X_xnw5u&U-m?{^$H& z)QXCqOQ0qj+>=TEz>t}$E9p@TYg}qvk<`?Yao&y&S6f1s37ij?Mb^LT_b$dS`ghNy|uT6jd-v8gp3& z&W~|lL$FCt90$uX@{L+_2=PakoD9sDFIAV3h$$({6e6y;8?3!TvL_jl8HWP!gc>>% z-4;U+O=`C6i)Sy9uO@O6pbpkNRnXMV)Mp}|Bx$+gSm9cMeB)iimB&p_z`fYrq~;k0 zIN^2ZBI<1~fhem-USvR(ya;Cx2f}c)clbiXvnff<>t*(VnDSAdAYG?Tmkb9>lUShR z`k3pI*SLO=LjdgH7sn<3x~6PiZzsGcA~UKlJ74LN@!VC%zvmfX>af`aGAojUJ{dqc znd)d58BVL6fhR$h>-#O{hTwHadntIGFEx1s^bBG{N-1ZdbH?4h9u#__H{!&4 zw}4B&_jV19A`PZG9+KS^ljQ+59@$zZ8~eDLEDizL<*Vux z=X$A`@ihBrI@>k!blX0x`Qa_>9!2U!tte5aHjEUf$Q5_nf)(}hM!i?GH_aEVjO_=v za_K4g^Z_L!5?#VxLY?=h^;G#^bK0rqxXndf)K};IdYyGJ+izP>5lVzv*~I8nH>7^& zeuzDUb_MrU-X8>>cRz~4eVT(a!_&98gob>6A2?!O*-={*Il<0K5-Sc(=^BcwLyg;2 zEYa2X6@9A4G#)=%UXKRo0LC$vO{b}|27nUwP33#s**`I%1J{8yv$iqA?2gQ}NMu^2 z+_bK^d7<6rv=fg}mcSE#=QFzc0tBq7iblkRl>|*EGSRI~j3?CYdZuyW| zhmHjcM*&VJ#|-O7zh6L8k99(xvx}((wA_I3r?n8{cw`&LLN|DxNeVxds5+B4F5I8I zYP3G=oRf1F|K6iuWy49|_k~vFX-V+y!hBWi_{mE#7ecQkO!bH@{y=-9YBq{E*ozl}{}IyEp%Y-YYA0c&TtA$<+$#$T3|a{A^i2{Gd+ z)T>r^GKRmhwP=_6Cp^7(5d=D2N}Gl!w7WVI^he{4 z@Z-GpAgEPGw#IPC0s#)3&^HVjza9j^(PN|2@(%QoFinsd5OfNtypTo}Yi8a7OSP}g zpEL_HOLOxi{hd}le%%y2J-EdN|HB1iAo`GmGaWlFX&WPeUReJ2iE9prvnrDI*%oG- z##_-LA+Ocla}L435tP>LQ9DrJv-F?*e6LDD(aiek$Au%yU|^tY)F(V?>Ps#?4@hJ2 z)${fK%TUGBi@MC(o2uLdm!8VP4m>^31gFA_-$A^^x+1`lRnNp%Oz;Qg)l0J^RWZKk z-3pcr$dn4^PBTSKhBuc9o&7K{T=wp+w$NltD`>@$wtIbYA)|(OQ@=7V@!ih3Tl}%^ z4HOKopi8&`a#j5;w;c0$a#(1-_p(ju#xv*7szRkHSk3G zw(Go*>x5Ds(Kue3DY&&1=JniDXaSIwi>7c5yBbo45KVU|8H&Znnqx3F0r;eu8tLSw&Y7}7%2&G|vnN1=Sp z*R;rl@Mv2^ByE{zR%xsVtYCx1P@)<*ACVYG1-MN|X)*zab0B``i9k4jWB3j9>V%yv zdInMNw=67;4FM7pwWU9AN@gz0n=Afg)VgQptX+q=$I8Vsxr?pg28<(S#L8g;PSH3tPA=K=D z4%CK&r&oc?vTY2dZT#^?hET<~2U{uE z-d;&A5bOSt@!T=1VlxplE z97#N}ER@*4Vnkw0wh{=Lc4o2+}d$)+*cu}pT$xXDpta?Hg|ngsuECc%_v YIcJMc1ni5RFf7-t+4Pan=hO6m0+8y>&;S4c literal 0 HcmV?d00001 diff --git a/examples/cvae-flax/data.py b/examples/cvae-flax/data.py index 09afc2bdc..657097a27 100644 --- a/examples/cvae-flax/data.py +++ b/examples/cvae-flax/data.py @@ -23,7 +23,7 @@ def load_dataset( X = Y.copy() _, m, n = X.shape - X = X[:, (m // 2):, : (n // 2)] + X = X[:, (m // 2) :, : (n // 2)] arrays = (X, Y) def init(): diff --git a/examples/cvae-flax/main.py b/examples/cvae-flax/main.py index 24cd13eb2..10507e82e 100644 --- a/examples/cvae-flax/main.py +++ b/examples/cvae-flax/main.py @@ -9,9 +9,6 @@ from train_baseline import train_baseline from train_cvae import train_cvae -from jax import random -import jax.numpy as jnp - from numpyro.examples.datasets import MNIST From 96dce0422f9f2e5af3e6ac00c169339972f979de Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Wed, 15 Jun 2022 23:56:35 +0200 Subject: [PATCH 5/6] Clean up the code a bit --- examples/cvae-flax/train_baseline.py | 1 - examples/cvae-flax/train_cvae.py | 16 ---------------- 2 files changed, 17 deletions(-) diff --git a/examples/cvae-flax/train_baseline.py b/examples/cvae-flax/train_baseline.py index 3478806e7..9581180ea 100644 --- a/examples/cvae-flax/train_baseline.py +++ b/examples/cvae-flax/train_baseline.py @@ -16,7 +16,6 @@ def create_train_state(model, x, learning_rate_fn): return state -@jax.jit def train_step(state, x_batched, y_batched): def loss_fn(params): y_pred = state.apply_fn(params, x_batched) diff --git a/examples/cvae-flax/train_cvae.py b/examples/cvae-flax/train_cvae.py index 9fb1c3ccb..cd32c7a09 100644 --- a/examples/cvae-flax/train_cvae.py +++ b/examples/cvae-flax/train_cvae.py @@ -1,8 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from models import cross_entropy_loss - from flax import traverse_util import jax from jax import lax, numpy as jnp, random @@ -44,20 +42,6 @@ def create_train_state( return svi, state -@jax.jit -def train_step(state, x_batched, y_batched): - def loss_fn(params): - y_pred = state.apply_fn(params, x_batched) - loss = cross_entropy_loss(y_pred, y_batched) - return jnp.sum(loss) - - grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn(state.params) - - new_state = state.apply_gradients(grads=grads) - return new_state, loss - - def train_epoch(svi, state, train_fetch, num_train, train_idx, epoch_rng): def _fn(i, val): state, loss_sum = val From 9081ad2ae395d7f62dd059144d39d794adc48a82 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Sat, 18 Jun 2022 00:26:06 +0200 Subject: [PATCH 6/6] Move some files and add entry in sphinx docs --- .../source/_static/img/examples/cvae.png | Bin docs/source/index.rst | 1 + examples/cvae-flax/README.md | 5 +-- examples/cvae.py | 31 ++++++++++++++++++ 4 files changed, 33 insertions(+), 4 deletions(-) rename examples/cvae-flax/assets/cvae_predictions.png => docs/source/_static/img/examples/cvae.png (100%) create mode 100644 examples/cvae.py diff --git a/examples/cvae-flax/assets/cvae_predictions.png b/docs/source/_static/img/examples/cvae.png similarity index 100% rename from examples/cvae-flax/assets/cvae_predictions.png rename to docs/source/_static/img/examples/cvae.png diff --git a/docs/source/index.rst b/docs/source/index.rst index 13b9c1086..438b5706c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -73,6 +73,7 @@ NumPyro documentation examples/holt_winters examples/mortality examples/zero_inflated_poisson + examples/cvae .. nbgallery:: :maxdepth: 1 diff --git a/examples/cvae-flax/README.md b/examples/cvae-flax/README.md index ac022fd19..222e3913c 100644 --- a/examples/cvae-flax/README.md +++ b/examples/cvae-flax/README.md @@ -1,7 +1,6 @@ ## Conditional Variational Autoencoder in Flax Trains a *Conditional Variational Autoencoder* (CVAE) on the MNIST data using Flax' neural network API. -The model is a port of [Pyro's CVAE example](https://pyro.ai/examples/cvae.html) which describes the model as well as the data. The model first trains a baseline to predict an entire MNIST image from a single quadrant of it (i.e., input is one quadrant of an image, output is the entire image (not the other three quadrants)). Then, in a second model, the generation/prior/recognition nets of the CVAE are trained while keeping the model parameters of the baseline fixed/frozen. @@ -9,7 +8,5 @@ We use Optax' `multi_transform` to apply different gradient transformations to t Running `main.py` trains the model(s) and plots a figure in the end comparing the baseline prediction with the CVAE prediction like this one: -![CVAE prediction](./assets/cvae_predictions.png) - - +![CVAE prediction](https://github.com/pyro-ppl/numpyro/tree/master/docs/source/_static/img/examples/cvae.png) diff --git a/examples/cvae.py b/examples/cvae.py new file mode 100644 index 000000000..d7a340498 --- /dev/null +++ b/examples/cvae.py @@ -0,0 +1,31 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +""" +Example: Conditional Variational Autoencoder in Flax +==================================================== + +This example trains a *Conditional Variational Autoencoder* (CVAE) [1] on the MNIST data +using Flax' neural network API. The implementation can be found here: +https://github.com/pyro-ppl/numpyro/tree/master/examples/cvae-flax + +The model is a port of Pyro's excellent CVAE example which describes the model as well as the data in detail: +https://pyro.ai/examples/cvae.html + +The model first trains a baseline to predict an entire MNIST image from a single quadrant of it +(i.e., input is one quadrant of an image, output is the entire image (not the other three quadrants)). +Then, in a second model, the generation/prior/recognition nets of the CVAE are trained while keeping the model +parameters of the baseline fixed/frozen. We use Optax' `multi_transform` to apply different gradient transformations +to the trainable parameters and the frozen parameters. + + +.. image:: ../_static/img/examples/cvae.png + :align: center + +**References:** + + 1. Kihyuk Sohn, Xinchen Yan, Honglak Lee (2015), "Learning Structured Output Representation using Deep + Conditional Generative Models + (https://papers.nips.cc/paper/5775-learning-structured-output-representation-using-deep-conditional-generative-models) + +"""