{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "P26jqMgJh3Q9" }, "source": [ "# Tutorial 2: Data Parallel and Fully Sharded Data Parallel Training\n", "\n", "[![Open in GitHub](https://img.shields.io/badge/Open%20in-GitHub-181717?style=flat-square&logo=github)](https://github.com/sshkhr/MinText/blob/main/docs/tutorials/2_Data_Parallel_and_FSDP.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sshkhr/MinText/blob/main/docs/tutorials/2_Data_Parallel_and_FSDP.ipynb)" ] }, { "cell_type": "markdown", "metadata": { "id": "Y589kcC_h84I" }, "source": [ "In the previous tutorial, we explored the basics of JAX parallelization, including device meshes, sharded matrices, and collective operations. In this tutorial, we'll build on those concepts to explore the first parallelism strategy used to scale models, data parallelism (DP), and then the second, Fully Sharded Data Parallelism (FSDP).\n", "\n", "**Learning objectives:**\n", "- Understand data parallelism\n", "- Learn how to implement neural networks in Flax\n", "- Learn about the Fully Sharded Data Parallel (FSDP) strategy\n", "\n", "**Prerequisites (covered in Tutorial 1):**\n", "- Basic familiarity with JAX\n", "- Understanding of JAX sharding\n", "- Understanding of sharded matrix operations\n" ] }, { "cell_type": "markdown", "metadata": { "id": "BZry_w4N15Vy" }, "source": [ "## 0. Background" ] }, { "cell_type": "markdown", "metadata": { "id": "C_wGnB2mh3Q9" }, "source": [ "### 0.1 Setup\n", "\n", "Let's start by importing the necessary libraries and initializing our environment."
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "si6y3Ea7h3Q-" }, "outputs": [], "source": [ "import os\n", "# Uncomment to simulate 8 devices on the CPU backend (only if you are not using a TPU runtime)\n", "#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding\n", "from jax.experimental import mesh_utils\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import time\n", "from functools import partial\n", "import dataclasses" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IVxCl7LUigdF", "outputId": "b97acd48-a735-48ea-c3d9-6e7395efb741" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX version: 0.5.2\n", "Available devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)]...\n", "Number of devices: 8\n" ] } ], "source": [ "# Check available devices\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"Available devices: {jax.devices()[:4]}...\")\n", "print(f\"Number of devices: {jax.device_count()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "3kAeWl08h3Q-" }, "source": [ "### 0.2 Model Representation\n", "\n", "For simplicity, we start with a small feed-forward model.
The model consists of two fully-connected (dense) layers:\n", "\n", "- **Win**: `bf16[D, F]` (up-projection)\n", "- **Wout**: `bf16[F, D]` (down-projection)\n", "\n", "The input and output are:\n", "- **In**: `bf16[B, D]`\n", "- **Out**: `bf16[B, D]`\n", "\n", "where:\n", "- **D** = `d_model` (input/output dimension)\n", "- **F** = `d_ff` (feed-forward or hidden dimension)\n", "- **B** = batch size (total tokens)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ackK14iBw8Mb" }, "source": [ "![MLP](https://github.com/jax-ml/scaling-book/blob/main/assets/img/simple-transformer.png?raw=true)\n", "\n", " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] }, { "cell_type": "markdown", "metadata": { "id": "k7P0-zkOuMNM" }, "source": [ "### 0.3 Why Parallelism?\n", "\n", "As we scale up our models and datasets, we need to distribute the computation across multiple devices to train them in a reasonable amount of time. For example, the original LLaMA models were trained on 2048 A100 GPUs. Training Llama 3 8B (the smallest Llama 3 model) on a single GPU would take approximately 1.46 million GPU-hours, or about 166 years. This is why we need parallelism.\n", "\n", "We also want to minimize the time spent on the various computations in the training process: since some operations are independent of each other, we can spread them across multiple devices and run them concurrently. We will see more examples of this when we discuss each training strategy.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "HYoUSMnyxS2A" }, "source": [ "### 0.4 Communication vs Computation Trade-offs\n", "\n", "The goal of scaling is to achieve **strong scaling**: a linear increase in throughput as we add chips. To get there, parallel algorithms are designed to hide inter-chip communication by overlapping it with useful FLOPs."
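] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the shapes from Section 0.2 concrete, here is a minimal sketch of the two-layer block in bfloat16. The toy sizes and the ReLU between the two projections are assumptions for illustration, because the shape list does not specify a nonlinearity. Real models use far larger `D` and `F`.

```python
import jax
import jax.numpy as jnp

# Toy sizes (assumed for illustration): B tokens, model dim D, hidden dim F.
B, D, F = 16, 8, 32
key_x, key_in, key_out = jax.random.split(jax.random.key(0), 3)

x = jax.random.normal(key_x, (B, D), dtype=jnp.bfloat16)        # In:   bf16[B, D]
w_in = jax.random.normal(key_in, (D, F), dtype=jnp.bfloat16)    # Win:  bf16[D, F]
w_out = jax.random.normal(key_out, (F, D), dtype=jnp.bfloat16)  # Wout: bf16[F, D]

def mlp(x, w_in, w_out):
    tmp = jax.nn.relu(x @ w_in)  # up-projection to [B, F]; the ReLU is an assumption
    return tmp @ w_out           # down-projection back to [B, D]

out = mlp(x, w_in, w_out)
print(out.shape, out.dtype)
# Each weight matrix holds D * F parameters, i.e. 2 * D * F bytes in bf16.
print('bytes per weight matrix:', w_in.size * w_in.dtype.itemsize)
```

The second print shows that one bf16 weight matrix occupies `2 * D * F` bytes. This is the quantity the communication analysis in Section 2.3 builds on."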
] }, { "cell_type": "markdown", "metadata": { "id": "NBN_ggN1IsbS" }, "source": [ "Let $T_\\text{ops}$ be the time spent on computation (FLOPs) and $T_{\\text{comms}}$ the time spent on communication (data transfer). In this tutorial we ignore intra-chip communication and focus on inter-chip communication costs. In practice you would consider both, but here we assume that intra-chip communication is already well overlapped with computation.\n", "\n", "The ratio of these two times is crucial for determining the efficiency of parallel training. An algorithm becomes **compute-bound** when:\n", "\n", "$$\\frac{T_{\\text{ops}}}{T_{\\text{comms}}} > 1$$\n", "\n", "In that regime the time spent on computation exceeds the time spent on communication, so the communication can be hidden behind the computation and we achieve good scaling performance. Let us see what this means in practice." ] }, { "cell_type": "markdown", "metadata": { "id": "Wk_MulhTxMes" }, "source": [ "## 1. Data Parallel" ] }, { "cell_type": "markdown", "metadata": { "id": "CC5U0fvpIsbS" }, "source": [ "Data parallelism (DP) is a parallelization strategy in which we split the input batch across multiple devices. Each device computes on a different subset of the data while holding a full copy of the model. Because each device only processes its own slice, the global batch can be much larger than what a single device could hold in memory. As long as the model itself fits on one device, data parallelism lets us scale training simply by adding devices and distributing the data across them."
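] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal single-device sketch of this idea follows. It assumes a toy two-layer model, where the sizes and the ReLU are illustrative, and it uses plain array slices in place of real devices. The batch is split into `X` equal shards, and every shard runs through the same replicated weights. The concatenated per-shard outputs are then compared against the unsharded forward pass.

```python
import jax
import jax.numpy as jnp

def mlp(x, w_in, w_out):
    # Toy two-layer block; the ReLU between the layers is an assumption.
    return jax.nn.relu(x @ w_in) @ w_out

B, D, F, X = 32, 8, 16, 4  # toy sizes; X is the number of simulated devices
key_x, key_in, key_out = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key_x, (B, D))
w_in = jax.random.normal(key_in, (D, F))
w_out = jax.random.normal(key_out, (F, D))

# Data parallelism: split the batch, replicate the weights.
shards = jnp.split(x, X, axis=0)  # X arrays of shape [B // X, D]
per_device_out = [mlp(shard, w_in, w_out) for shard in shards]

# Concatenating the per-shard outputs should recover the single-device result.
dp_out = jnp.concatenate(per_device_out, axis=0)
full_out = mlp(x, w_in, w_out)
print(dp_out.shape, jnp.allclose(dp_out, full_out, atol=1e-4))
```

The comparison should report a match. Each example is processed independently, so the forward pass needs no communication between shards."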
] }, { "cell_type": "markdown", "metadata": { "id": "5BybVQZ6h3Q-" }, "source": [ "### 1.1 Data Parallelism Theory\n", "\n", "**Sharding**: Input data and model activations are sharded along the batch dimension across devices; model parameters are replicated on every device.\n", "\n", "**Equation** (for our MLP example):\n", "$$\\text{In}[B_X, D] \\cdot_D W_{\\text{in}}[D, F] \\cdot_F W_{\\text{out}}[F, D] \\rightarrow \\text{Out}[B_X, D]$$\n", "\n", "where $B_X$ indicates the batch is sharded across $X$ devices." ] }, { "cell_type": "markdown", "metadata": { "id": "_Kb3_BRYIsbS" }, "source": [ "![Data Parallelism](https://github.com/jax-ml/scaling-book/blob/main/assets/img/data-parallelism.png?raw=true)\n", "\n", " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Data Parallel Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the key advantages of data parallelism is that there is no communication between devices during the forward pass. Each device runs the forward pass independently on its shard of the batch, without moving any data around." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Forward pass\n", "\n", "1. $\\text{Tmp}[B_X, F] = \\text{In}[B_X, D] \\cdot_D W_{\\text{in}}[D, F]$\n", "2. $\\text{Out}[B_X, D] = \\text{Tmp}[B_X, F] \\cdot_F W_{\\text{out}}[F, D]$\n", "3. $\\text{Loss}[B_X] = \\ldots$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the backward pass, we need gradients computed over the full batch, which is spread across all devices. To get them, each device first computes the gradients for its own shard of the batch (Steps 2 and 5 below). These are only partial gradients, so we then aggregate them across all devices to obtain the full-batch gradients (Steps 3 and 6 below).
This aggregation is done with an `AllReduce` over the gradients from all devices.\n", "\n", "One important thing to note is that the rest of the backward pass for the current iteration does not depend on the reduced gradients from this iteration. We say that the gradient AllReduce is not on the **critical path**: the rest of the backward pass can continue while the gradients are being communicated, so we can overlap gradient communication with computation on each device. This is crucial for achieving good scaling performance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Backward pass\n", "\n", "1. $d\\text{Out}[B_X, D] = \\ldots$\n", "2. $dW_{\\text{out}}[F, D]\\{U_X\\} = \\text{Tmp}[B_X, F] \\cdot_B d\\text{Out}[B_X, D]$\n", "3. $dW_{\\text{out}}[F, D] = \\textbf{AllReduce}(dW_{\\text{out}}[F, D]\\{U_X\\})$ (*not on critical path, can be done async*)\n", "4. $d\\text{Tmp}[B_X, F] = d\\text{Out}[B_X, D] \\cdot_D W_{\\text{out}}[F, D]$\n", "5. $dW_{\\text{in}}[D, F]\\{U_X\\} = \\text{In}[B_X, D] \\cdot_B d\\text{Tmp}[B_X, F]$\n", "6. $dW_{\\text{in}}[D, F] = \\textbf{AllReduce}(dW_{\\text{in}}[D, F]\\{U_X\\})$ (*not on critical path, can be done async*)\n", "7. $d\\text{In}[B_X, D] = d\\text{Tmp}[B_X, F] \\cdot_F W_{\\text{in}}[D, F]$ (*needed for previous layers*)\n", "\n", "Here $\\{U_X\\}$ marks a per-device partial gradient that is still unreduced, i.e. it has not yet been summed across the $X$ devices." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Compute Bound versus Communication Bound\n", "\n", "As we can see above, we have two AllReduces per layer, each over $2DF$ bytes (for bf16 weights). When does data parallelism make us communication-bound?\n", "\n", "Let $C$ = per-chip compute throughput (FLOPs/s), $W_{\\text{ici}}$ = **bidirectional** network bandwidth, and $X$ = number of shards across which the batch is partitioned. Let's calculate the time required to perform the relevant matmuls (compute time), $T_\\text{math}$ (the $T_\\text{ops}$ from Section 0.4), and the required communication time $T_\\text{comms}$.
Since this parallelism scheme requires no communication in the forward pass, we only need to calculate these quantities for the backward pass.\n", "\n", "#### Communication time\n", "\n", "The time to perform an AllReduce on a 1D mesh depends only on the total bytes of the array being AllReduced and the ICI bandwidth $W_\\text{ici}$: the AllReduce takes $2 \\cdot \\text{total bytes} / W_\\text{ici}$. We AllReduce the gradients of both $W_\\text{in}$ and $W_\\text{out}$, so there are 2 AllReduces per layer, and each one is over a weight matrix of $DF$ parameters, or $2DF$ bytes. Multiplying the AllReduce factor of 2, the 2 AllReduces, and the $2DF$ bytes each, the total AllReduce time for a single layer is\n", "\n", "$$T_\\text{comms} = \\frac{2 \\cdot 2 \\cdot 2 \\cdot D \\cdot F}{W_\\text{ici}} = \\frac{8 \\cdot D \\cdot F}{W_\\text{ici}}$$\n", "\n", "#### Computation time\n", "\n", "Each layer comprises two matmuls in the forward pass, or four matmuls in the backward pass, each of which requires $2(B/X)DF$ FLOPs. Thus, for a single layer in the backward pass, we have\n", "\n", "$$T_\\text{math} = \\frac{4 \\cdot 2 \\cdot B \\cdot D \\cdot F}{X \\cdot C} = \\frac{8 \\cdot B \\cdot D \\cdot F}{X \\cdot C}$$\n", "\n", "Since communication overlaps with computation, the total time per layer is the max of these two quantities:\n", "\n", "$$\\begin{aligned}\n", "T &\\approx \\max\\left(\\frac{8 \\cdot B \\cdot D \\cdot F}{X \\cdot C}, \\frac{8 \\cdot D \\cdot F}{W_\\text{ici}}\\right) \\\\\n", "T &\\approx 8 \\cdot D \\cdot F \\cdot \\max\\left(\\frac{B}{X \\cdot C}, \\frac{1}{W_\\text{ici}}\\right)\n", "\\end{aligned}$$\n", "\n", "We become compute-bound when $T_\\text{math}/T_\\text{comms} > 1$, i.e. when $\\frac{B}{X} > \\frac{C}{W_\\text{ici}}$.\n", "\n", "The upshot is that, to remain compute-bound with data parallelism, we need the per-device batch size $B / X$ to exceed the ICI operational intensity, $C / W_\\text{ici}$.
This is ultimately a consequence of the fact that the computation time scales with the per-device batch size, while the communication time is independent of it (since we are transferring weight gradients, whose size does not depend on the batch).\n", "\n", "Let's apply this analysis to TPU v2 pods to see when data parallelism becomes communication-bound.\n", "\n", "**TPU v2 Specifications:**\n", "- Per-chip compute: $C = 4.5 \\times 10^{13}$ FLOPs/s (45 TFLOPs/s in bfloat16)\n", "- ICI bandwidth: $W_\\text{ici} = 2.48 \\times 10^{11}$ bytes/s (4 links at 496 Gbit/s per direction = 248 GB/s in one direction)\n", "- 2D torus topology connecting up to 256 chips\n", "\n", "**Critical Condition:**\n", "Data parallelism remains compute-bound when the per-device batch size exceeds:\n", "\n", "$$\\frac{B}{X} > \\frac{C}{W_\\text{ici}} = \\frac{4.5 \\times 10^{13}}{2.48 \\times 10^{11}} \\approx 181$$\n", "\n", "This plugs in the one-direction bandwidth. Using the bidirectional figure from the derivation above would halve the threshold, so ~181 is a conservative estimate.\n", "\n", "**Key Results:**\n", "- **Minimum batch size per chip: ~181 tokens** to avoid a communication bottleneck\n", "- For a full 256-chip TPU v2 pod, this translates to a minimum global batch size of ~46K tokens (181 × 256)\n", "\n", "**Practical Implications:**\n", "The relatively low threshold of 181 tokens per chip makes TPU v2 pods well-suited for typical training workloads. **This tells us that it's fairly hard to become bottlenecked by pure data parallelism!** Most production models use batch sizes well above this threshold, so data-parallel training can keep the chips busy without being limited by communication." ] }, { "cell_type": "markdown", "metadata": { "id": "LmEAlrsuh3Q-" }, "source": [ "## 3. Example: 8-way Data Parallel Training with Plain JAX" ] }, { "cell_type": "markdown", "metadata": { "id": "Z1skQ8igh3Q-" }, "source": [ "### 3.1 Create dataset and define model\n", "\n", "First, let's generate a synthetic dataset and define a simple feed-forward neural network."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "aka1dFaWh3Q-" }, "outputs": [], "source": [ "def get_linear_layer(key, dim_in, dim_out):\n", "    # Scaled Gaussian weights and Gaussian biases for one dense layer\n", "    k1, k2 = jax.random.split(key)\n", "    W = jax.random.normal(k1, (dim_in, dim_out)) / jnp.sqrt(dim_in)\n", "    b = jax.random.normal(k2, (dim_out,))\n", "    return W, b\n", "\n", "def get_model_and_data(key, layer_sizes, batch_size):\n", "    # Use independent keys for the parameters, the inputs, and the targets\n", "    model_key, input_key, target_key = jax.random.split(key, 3)\n", "    layer_keys = jax.random.split(model_key, len(layer_sizes) - 1)\n", "\n", "    # The model is just a list of (W, b) linear layers\n", "    model = [\n", "        get_linear_layer(k, in_dim, out_dim)\n", "        for k, in_dim, out_dim in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:])\n", "    ]\n", "\n", "    # The data is just random numbers for both inputs and targets\n", "    input_data = jax.random.normal(input_key, (batch_size, layer_sizes[0]))\n", "    target_data = jax.random.normal(target_key, (batch_size, layer_sizes[-1]))\n", "\n", "    return model, (input_data, target_data)" ] }, { "cell_type": "markdown", "metadata": { "id": "b93qN9SxIsbT" }, "source": [ "We will use a simple feed-forward neural network with three linear layers (`128 -> 2048 -> 2048 -> 128`), a slightly deeper version of the MLP described in the background section. The input and output dimensions are 128, the hidden dimensions are 2048, and the batch size is 8192."
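] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before running anything, we can plug this example into the analysis from Section 2.3. The sketch below is a back-of-the-envelope estimate, not a measurement, and it rests on several assumptions:
- It reuses the TPU v2 numbers quoted there, although your runtime may use a different accelerator.
- It treats each linear layer separately, with two backward matmuls and one weight-gradient AllReduce per layer, and ignores the biases.
- It compares bf16 weights with the float32 arrays this notebook actually creates.

```python
# Hardware numbers: the TPU v2 figures quoted in Section 2.3 (assumed here).
C = 4.5e13       # per-chip compute, FLOPs/s
W_ICI = 2.48e11  # ICI bandwidth, bytes/s

def dp_layer_times(b_per_device, d_in, d_out, bytes_per_param=2):
    '''Estimated backward-pass compute and AllReduce time for one linear layer.'''
    # Two backward matmuls (dW and dIn), each 2 * b * d_in * d_out FLOPs.
    t_math = 2 * 2 * b_per_device * d_in * d_out / C
    # One AllReduce of the weight gradient (biases ignored): 2 * bytes / W_ici.
    t_comms = 2 * bytes_per_param * d_in * d_out / W_ICI
    return t_math, t_comms

layer_sizes = [128, 2048, 2048, 128]
batch_size, num_devices = 8192, 8
b_per_device = batch_size // num_devices  # 1024 tokens per device

print(f'critical per-device batch (bf16): {C / W_ICI:.0f} tokens')
for d_in, d_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    for nbytes in (2, 4):  # bf16 vs. the float32 arrays used in this notebook
        t_math, t_comms = dp_layer_times(b_per_device, d_in, d_out, nbytes)
        kind = 'compute' if t_math > t_comms else 'communication'
        print(f'{d_in}x{d_out}, {nbytes} bytes/param: '
              f'T_math/T_comms = {t_math / t_comms:.2f} ({kind}-bound)')
```

With 1024 tokens per device, the ratio comes out to about 5.6 for bf16 and 2.8 for float32. The value is the same for every layer because the layer dimensions cancel. Under these assumptions, the 8-way run below should be compute-bound rather than communication-bound."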
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "y_EO4Cc99Lo6" }, "outputs": [], "source": [ "# A simple MLP: 128 -> 2048 -> 2048 -> 128\n", "layer_sizes = [128, 2048, 2048, 128]\n", "batch_size = 8192" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "rUkwAjZr-LnO" }, "outputs": [], "source": [ "model, batch = get_model_and_data(jax.random.key(0), layer_sizes, batch_size)" ] }, { "cell_type": "markdown", "metadata": { "id": "_lKR_AJvIsbT" }, "source": [ "Now we define our model's forward pass and loss function. We are using plain JAX for now, so this might seem a bit verbose if you are coming from PyTorch or TensorFlow. We will see how to simplify neural network modelling with Flax in the next section." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "ftK4v74m-ZP9" }, "outputs": [], "source": [ "def predict(model, inputs):\n", "    # Apply each linear layer followed by ReLU; the last layer's\n", "    # pre-activation output is returned as the prediction\n", "    for W, b in model:\n", "        outputs = jnp.dot(inputs, W) + b\n", "        inputs = jnp.maximum(outputs, 0)  # ReLU activation\n", "    return outputs\n", "\n", "def loss(model, batch):\n", "    # Squared error summed over features, averaged over the batch\n", "    inputs, targets = batch\n", "    predictions = predict(model, inputs)\n", "    return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))" ] }, { "cell_type": "markdown", "metadata": { "id": "Y0M8bdIkIsbT" }, "source": [ "Finally, we compile the loss and its gradient with `jax.jit`, so that the XLA compiler can optimize our forward and backward passes." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "CdizzgTi_olY" }, "outputs": [], "source": [ "loss_jit = jax.jit(loss)\n", "gradfun = jax.jit(jax.grad(loss))" ] }, { "cell_type": "markdown", "metadata": { "id": "nZyEdC3Zh3Q-" }, "source": [ "### 3.2 Single-Device Baseline\n", "\n", "Let's first establish a baseline by training on a single device."
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "KFq25YzKh3Q-" }, "outputs": [], "source": [ "batch_single = jax.device_put(batch, jax.devices()[0])\n", "params_single = jax.device_put(model, jax.devices()[0])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a1wHGqYh_0gR", "outputId": "de8e7e2b-9511-49a9-dcde-c0eb27861cd7" }, "outputs": [ { "data": { "text/plain": [ "Array(435.76917, dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_jit(params_single, batch_single)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bUvLRRplAUK7", "outputId": "6977aaf2-2d10-4674-a92a-569bfafa6ab4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The slowest run took 20.17 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "58.1 ms ± 92 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], "source": [ "%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "wEhK17TPh3Q_" }, "source": [ "### 3.3 8-way Data Parallel Training\n", "\n", "Now let's implement 8-way data parallel training where we'll shard the batch across 8 devices." 
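] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before handing the sharding to `jax.jit`, it can help to see the backward pass from Section 2.2 spelled out. The sketch below simulates 8-way data parallelism on a single device:
- `jax.vmap` with `axis_name='batch'` stands in for the 8 devices.
- Each slice computes its local gradient.
- `jax.lax.pmean` plays the role of the gradient AllReduce. It takes a mean rather than a sum because the loss averages over the batch.

The toy model and sizes are assumptions for illustration, not the network defined above.

```python
import jax
import jax.numpy as jnp

def toy_loss(params, batch):
    # Squared error of a toy two-layer ReLU MLP, averaged over the batch (assumed model)
    w_in, w_out = params
    x, y = batch
    pred = jax.nn.relu(x @ w_in) @ w_out
    return jnp.mean(jnp.sum((pred - y) ** 2, axis=-1))

def dp_grads(params, sharded_batch):
    def per_device(p, shard):
        g = jax.grad(toy_loss)(p, shard)            # local, unreduced gradient {U_X}
        return jax.lax.pmean(g, axis_name='batch')  # stand-in for the AllReduce
    # Parameters are shared (in_axes=None); the batch is mapped over its leading axis
    return jax.vmap(per_device, in_axes=(None, 0), axis_name='batch')(params, sharded_batch)

X, B, D, F = 8, 64, 16, 32  # 8 simulated devices, toy sizes (assumed)
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
params = (jax.random.normal(k1, (D, F)) / D ** 0.5,
          jax.random.normal(k2, (F, D)) / F ** 0.5)
x = jax.random.normal(k3, (B, D))
y = jax.random.normal(k4, (B, D))

# [B, D] -> [X, B // X, D]: the leading axis indexes the simulated device
sharded = (x.reshape(X, B // X, D), y.reshape(X, B // X, D))
g_dp = dp_grads(params, sharded)
g_full = jax.grad(toy_loss)(params, (x, y))

for g_dev, g_ref in zip(g_dp, g_full):
    # Every simulated device should end up holding the full-batch gradient
    print(g_dev.shape, jnp.allclose(g_dev[0], g_ref, atol=1e-4))
```

In the real 8-device run below we never write this reduction ourselves. Because the batch is sharded and the parameters are replicated, the compiler behind `jax.jit` should insert the gradient AllReduce automatically."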
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mhH5pGm_h3Q_", "outputId": "364e6466-34d3-4f85-b41a-7f885156e5d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mesh shape: OrderedDict([('batch', 8)])\n", "Mesh axis names: ('batch',)\n" ] } ], "source": [ "# Create an 8-device mesh for data parallelism\n", "mesh = jax.make_mesh((8,), ('batch',))\n", "print(f\"Mesh shape: {mesh.shape}\")\n", "print(f\"Mesh axis names: {mesh.axis_names}\")\n", "\n", "# Create sharding specifications\n", "\n", "## Shard data along the batch dimension\n", "batch_sharding = NamedSharding(mesh, P('batch'))\n", "## Replicate parameters across all devices\n", "replicated_sharding = NamedSharding(mesh, P())" ] }, { "cell_type": "markdown", "metadata": { "id": "zxvHebcwIsbU" }, "source": [ "We shard the batch along the `batch` mesh axis, so each device holds 8192 / 8 = 1024 examples, while the model parameters are replicated on every device." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "6I47F3-tADf6" }, "outputs": [], "source": [ "batch = jax.device_put(batch, batch_sharding)\n", "params = jax.device_put(model, replicated_sharding)" ] }, { "cell_type": "markdown", "metadata": { "id": "dtIz6yahIsbU" }, "source": [ "Let's visualize how the batch is sharded across devices."
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 460 }, "id": "dfGRyGI3IsbU", "outputId": "f9cfad37-7edb-4d7c-da8c-389dcc77ff1f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Visualizing batch sharding across 8 devices:\n", "Original batch shape: (8192, 128)\n", "TPU_0(process=0,(0,0,0,0)) slice(0, 1024, None) (1024, 128)\n", "TPU_1(process=0,(0,0,0,1)) slice(1024, 2048, None) (1024, 128)\n", "TPU_2(process=0,(1,0,0,0)) slice(2048, 3072, None) (1024, 128)\n", "TPU_3(process=0,(1,0,0,1)) slice(3072, 4096, None) (1024, 128)\n", "TPU_6(process=0,(1,1,0,0)) slice(4096, 5120, None) (1024, 128)\n", "TPU_7(process=0,(1,1,0,1)) slice(5120, 6144, None) (1024, 128)\n", "TPU_4(process=0,(0,1,0,0)) slice(6144, 7168, None) (1024, 128)\n", "TPU_5(process=0,(0,1,0,1)) slice(7168, 8192, None) (1024, 128)\n" ] }, { "data": { "text/html": [ "
  TPU 0  \n",
              "         \n",
              "  TPU 1  \n",
              "         \n",
              "  TPU 2  \n",
              "         \n",
              "  TPU 3  \n",
              "         \n",
              "  TPU 6  \n",
              "         \n",
              "  TPU 7  \n",
              "         \n",
              "  TPU 4  \n",
              "         \n",
              "  TPU 5  \n",
              "         \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(\"Visualizing batch sharding across 8 devices:\")\n", "print(\"Original batch shape:\", batch[0].shape)\n", "\n", "for shard in batch[0].addressable_shards:\n", " print(shard.device, 
 shard.index[0], shard.data.shape)\n", "\n", "jax.debug.visualize_array_sharding(batch[0])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HZ4VRe8bh3Q_", "outputId": "2ce6a772-5f97-4c7a-d720-efc3e78db7fe" }, "outputs": [ { "data": { "text/plain": [ "Array(435.7692, dtype=float32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss_jit(params, batch)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "tdS8RpmEAJWh", "outputId": "1712638b-8830-4ff1-9a92-945d5e8bed51" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The slowest run took 82.19 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "60 ms ± 113 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], "source": [ "%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3kmUnp4fRw8i", "outputId": "9ddb9bb4-a9b6-4069-8809-3c013a762955" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0, loss: 409.836945\n", "Step 5, loss: 312.069946\n", "Step 10, loss: 251.479156\n", "Step 15, loss: 213.499405\n", "Step 20, loss: 189.462906\n", "Step 25, loss: 174.309052\n", "\n", "Final loss: 166.210480\n" ] } ], "source": [ "step_size = 1e-5\n", "losses = []\n", "\n", "for step in range(30):\n", "    grads = gradfun(params, batch)\n", "    params = [(W - step_size * dW, b - step_size * db)\n", "              for (W, b), (dW, db) in zip(params, grads)]\n", "\n", "    current_loss = loss_jit(params, batch)\n", "    losses.append(float(current_loss))\n", "\n", "    if step % 5 == 0:\n", "        print(f\"Step {step}, loss: {current_loss:.6f}\")\n", "\n", "print(f\"\\nFinal loss: {loss_jit(params, batch):.6f}\")" ] }, { "cell_type":
"code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 564 }, "id": "aa2z1Y9GIsbV", "outputId": "87cd814b-ef1e-4278-cdde-c8fb49e88bc2" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAArcAAAIjCAYAAAAZajMiAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAZ5VJREFUeJzt3Xd4VGXCxuFnZtJIJYUkBEKABAihdwJSpAQQEQQVCwJ2IRZE0cUVpOiiWNaGiA1QxIILKAhCpEuRDqGFToCQhJZCQvp8f0Tm2yxVTHKSye++rrl0zpyZ85y8GXx8OcVktVqtAgAAAOyA2egAAAAAQHGh3AIAAMBuUG4BAABgNyi3AAAAsBuUWwAAANgNyi0AAADsBuUWAAAAdoNyCwAAALtBuQUAAIDdoNwCqNCGDh2qmjVr3tR7x40bJ5PJVLyBAAB/C+UWQJlkMplu6LFy5Uqjoxpi6NChcnd3NzrGDZs3b5569eolPz8/OTk5KSgoSPfcc4+WL19udDQAdsZktVqtRocAgP81a9asIs+/+uorxcTE6Ouvvy6yvHv37goICLjp7eTm5qqgoEDOzs5/+b15eXnKy8uTi4vLTW//Zg0dOlQ//vijLly4UOrb/iusVqsefvhhzZgxQ82aNdNdd92lwMBAnTp1SvPmzdOWLVu0du1atWvXzuioAOyEg9EBAOBKBg0aVOT5hg0bFBMTc9ny/5WZmSlXV9cb3o6jo+NN5ZMkBwcHOTjwx+i1vPPOO5oxY4ZGjBihd999t8hhHP/85z/19ddfF8vP0Gq1KisrS5UqVfrbnwWgfOOwBADlVufOndWwYUNt2bJFHTt2lKurq15++WVJ0k8//aTevXsrKChIzs7OCg0N1cSJE5Wfn1/kM/73mNujR4/KZDLp7bff1qeffqrQ0FA5OzurVatW2rRpU5H3XumYW5PJpKeeekrz589Xw4YN5ezsrAYNGujXX3+9LP/KlSvVsmVLubi4KDQ0VNOmTSv243jnzJmjFi1aqFKlSvLz89OgQYN08uTJIuskJibqoYceUvXq1eXs7KyqVauqb9++Onr0qG2dzZs3q0ePHvLz81OlSpVUq1YtPfzww9fc9sWLFzVp0iSFh4fr7bffvuJ+Pfjgg2rdurWkqx/DPGPGDJlMpiJ5atasqdtvv11LlixRy5YtValSJU2bNk0NGzbUrbfeetlnFBQUqFq1arrrrruKLHvvvffUoEEDubi4KCAgQE888YTOnz9/zf0CULYx5QCgXDt79qx69eqle++9V4MGDbIdojBjxgy5u7tr5MiRcnd31/LlyzV27FilpaXprbfeuu7nzp49W+np6XriiSdkMpk0efJk9e/fX4cPH77ubO/vv/+uuXPnavjw4fLw8NAHH3ygAQMGKD4+Xr6+vpKkbdu2qWfPnqpatarGjx+v/Px8TZgwQVWqVPn7P5Q/zZgxQw899JBatWqlSZMmKSkpSe+//77Wrl2rbdu2qXLlypKkAQMGaPfu3Xr66adVs2ZNJScnKyYmRvHx8bbnUVFRqlKliv7xj3+ocuXKOnr0qObOnXvdn8O5c+c0YsQIWSyWYtuvS+Li4nTffffpiSee0GOPPaZ69epp4MCBGjdunBITExUYGFgkS0JCgu69917bsieeeML2M3rmmWd05MgRffTRR9q2bZvWrl37t2b1ARjICgDlQHR0tPV//8jq1KmTVZL1k08+uWz9zMzMy5Y98cQTVldXV2tWVpZt2ZAhQ6whISG250eOHLFKsvr6+
lrPnTtnW/7TTz9ZJVkXLFhgW/bqq69elkmS1cnJyXrw4EHbsh07dlglWT/88EPbsj59+lhdXV2tJ0+etC07cOCA1cHB4bLPvJIhQ4ZY3dzcrvp6Tk6O1d/f39qwYUPrxYsXbcsXLlxolWQdO3as1Wq1Ws+fP2+VZH3rrbeu+lnz5s2zSrJu2rTpurn+2/vvv2+VZJ03b94NrX+ln6fVarVOnz7dKsl65MgR27KQkBCrJOuvv/5aZN24uLjLftZWq9U6fPhwq7u7u+33Ys2aNVZJ1m+++abIer/++usVlwMoPzgsAUC55uzsrIceeuiy5f997GV6errOnDmjDh06KDMzU/v27bvu5w4cOFDe3t625x06dJAkHT58+Lrv7datm0JDQ23PGzduLE9PT9t78/Pz9dtvv6lfv34KCgqyrRcWFqZevXpd9/NvxObNm5WcnKzhw4cXOeGtd+/eCg8P1y+//CKp8Ofk5OSklStXXvWv4y/N8C5cuFC5ubk3nCEtLU2S5OHhcZN7cW21atVSjx49iiyrW7eumjZtqu+//962LD8/Xz/++KP69Olj+72YM2eOvLy81L17d505c8b2aNGihdzd3bVixYoSyQyg5FFuAZRr1apVk5OT02XLd+/erTvvvFNeXl7y9PRUlSpVbCejpaamXvdza9SoUeT5paJ7I8dj/u97L73/0nuTk5N18eJFhYWFXbbelZbdjGPHjkmS6tWrd9lr4eHhttednZ315ptvavHixQoICFDHjh01efJkJSYm2tbv1KmTBgwYoPHjx8vPz099+/bV9OnTlZ2dfc0Mnp6ekgr/56Ik1KpV64rLBw4cqLVr19qOLV65cqWSk5M1cOBA2zoHDhxQamqq/P39VaVKlSKPCxcuKDk5uUQyAyh5lFsA5dqVzo5PSUlRp06dtGPHDk2YMEELFixQTEyM3nzzTUmFJxJdz9WOEbXewNUT/857jTBixAjt379fkyZNkouLi8aMGaP69etr27ZtkgpPkvvxxx+1fv16PfXUUzp58qQefvhhtWjR4pqXIgsPD5ckxcbG3lCOq51I978nAV5ytSsjDBw4UFarVXPmzJEk/fDDD/Ly8lLPnj1t6xQUFMjf318xMTFXfEyYMOGGMgMoeyi3AOzOypUrdfbsWc2YMUPPPvusbr/9dnXr1q3IYQZG8vf3l4uLiw4ePHjZa1dadjNCQkIkFZ509b/i4uJsr18SGhqq559/XkuXLtWuXbuUk5Ojd955p8g6bdu21euvv67Nmzfrm2++0e7du/Xdd99dNcMtt9wib29vffvtt1ctqP/t0vikpKQUWX5plvlG1apVS61bt9b333+vvLw8zZ07V/369StyLePQ0FCdPXtW7du3V7du3S57NGnS5C9tE0DZQbkFYHcuzZz+90xpTk6OPv74Y6MiFWGxWNStWzfNnz9fCQkJtuUHDx7U4sWLi2UbLVu2lL+/vz755JMihw8sXrxYe/fuVe/evSUVXhc4KyuryHtDQ0Pl4eFhe9/58+cvm3Vu2rSpJF3z0ARXV1e99NJL2rt3r1566aUrzlzPmjVLGzdutG1XklavXm17PSMjQzNnzrzR3bYZOHCgNmzYoC+//FJnzpwpckiCJN1zzz3Kz8/XxIkTL3tvXl7eZQUbQPnBpcAA2J127drJ29tbQ4YM0TPPPCOTyaSvv/66TB0WMG7cOC1dulTt27fXsGHDlJ+fr48++kgNGzbU9u3bb+gzcnNz9dprr1223MfHR8OHD9ebb76phx56SJ06ddJ9991nuxRYzZo19dxzz0mS9u/fr65du+qee+5RRESEHBwcNG/ePCUlJdkumzVz5kx9/PHHuvPOOxUaGqr09HR99tln8vT01G233XbNjKNGjdLu3bv1zjvvaMWKFbY7lCUmJmr+/PnauHGj1q1bJ0mKiopSjRo19Mgjj2jUqFGyWCz68ssvVaVKFcXHx/+Fn25heX3hhRf0wgsvyMfHR926dSvye
qdOnfTEE09o0qRJ2r59u6KiouTo6KgDBw5ozpw5ev/994tcExdA+UG5BWB3fH19tXDhQj3//PN65ZVX5O3trUGDBqlr166XnV1vlBYtWmjx4sV64YUXNGbMGAUHB2vChAnau3fvDV3NQSqcjR4zZsxly0NDQzV8+HANHTpUrq6ueuONN/TSSy/Jzc1Nd955p958803bFRCCg4N13333admyZba7hYWHh+uHH37QgAEDJBUWwY0bN+q7775TUlKSvLy81Lp1a33zzTdXPanrErPZrK+++kp9+/bVp59+qrfffltpaWmqUqWK7eS1yMhISYV3i5s3b56GDx+uMWPGKDAwUCNGjJC3t/cVr4hxLdWrV1e7du20du1aPfroo1e8Zu0nn3yiFi1aaNq0aXr55Zfl4OCgmjVratCgQWrfvv1f2h6AssNkLUtTGQBQwfXr10+7d+/WgQMHjI4CAOUSx9wCgEEuXrxY5PmBAwe0aNEide7c2ZhAAGAHmLkFAINUrVpVQ4cOVe3atXXs2DFNnTpV2dnZ2rZtm+rUqWN0PAAolzjmFgAM0rNnT3377bdKTEyUs7OzIiMj9a9//YtiCwB/AzO3AAAAsBsccwsAAAC7QbkFAACA3eCYWxXeYzwhIUEeHh5Xvbc5AAAAjGO1WpWenq6goCCZzVefn6XcSkpISFBwcLDRMQAAAHAdx48fV/Xq1a/6OuVWkoeHh6TCH5anp2eJby83N1dLly613e4RpY8xMB5jYDzGwHiMQdnAOBjvRsYgLS1NwcHBtt52NZRbyXYogqenZ6mVW1dXV3l6evIlMghjYDzGwHiMgfEYg7KBcTDeXxmD6x1CygllAAAAsBuUWwAAANgNyi0AAADsBuUWAAAAdoNyCwAAALtBuQUAAIDdoNwCAADAblBuAQAAYDcotwAAALAblFsAAADYDcotAAAA7AblFgAAAHaDcgsAAAC7QbkFAACA3aDcAgAAwG5QbgEAAGA3KLelLL/AqkWxiTqcZnQSAAAA+0O5LWUfrzioZ3/YqYXxFqOjAAAA2B3KbSm7u2WwHC0mHUo3adPR80bHAQAAsCuU21IW6OWi/s2qSZI+WX3Y4DQAAAD2hXJrgMc71JRJVq0+cFaxJ1KNjgMAAGA3KLcGqOHjquZ+VknSxysPGpwGAADAflBuDdK9WoEk6dfdiTqYnG5wGgAAAPtAuTVIVVepe31/Wa3SxysPGR0HAADALlBuDTSsUy1J0k/bE3T8XKbBaQAAAMo/yq2BGlXzUoc6fsovsOqTVczeAgAA/F2UW4NF3xomSZqz+YSS07IMTgMAAFC+UW4N1qaWj1qGeCsnv0CfreG6twAAAH8H5dZgJpPJNnv7zR/xOp+RY3AiAACA8otyWwZ0rldFDYI8lZmTr+nrjhodBwAAoNyi3JYB/z17O2PtEV3IzjM4EQAAQPlUZsrtG2+8IZPJpBEjRtiWZWVlKTo6Wr6+vnJ3d9eAAQOUlJRU5H3x8fHq3bu3XF1d5e/vr1GjRikvr/yVwx4NAlW7ipvSsvI0a8Mxo+MAAACUS2Wi3G7atEnTpk1T48aNiyx/7rnntGDBAs2ZM0erVq1SQkKC+vfvb3s9Pz9fvXv3Vk5OjtatW6eZM2dqxowZGjt2bGnvwt9mMZs0vHPh7O3na44oKzff4EQAAADlj+Hl9sKFC3rggQf02Wefydvb27Y8NTVVX3zxhd5991116dJFLVq00PTp07Vu3Tpt2LBBkrR06VLt2bNHs2bNUtOmTdWrVy9NnDhRU6ZMUU5O+Tsxq2/TIFWrXElnLmTrh83HjY4DAABQ7jgYHSA6Olq9e/dWt27d9Nprr9mWb9myRbm5uerWrZttWXh4uGrUqKH169erbdu2Wr9+vRo1aqSAgADbOj169NCwYcO0e/duNWvW7IrbzM7OVnZ2tu15WlqaJCk3N1e5ubnFvYuXubSNK23rsVtCNG7hPn2y8pDualZVjhbD///DLl1rDFA6G
APjMQbGYwzKBsbBeDcyBjc6PoaW2++++05bt27Vpk2bLnstMTFRTk5Oqly5cpHlAQEBSkxMtK3z38X20uuXXruaSZMmafz48ZctX7p0qVxdXf/qbty0mJiYy5a55UsejhYlpGbpta+XqI2/tdTyVERXGgOULsbAeIyB8RiDsoFxMN61xiAzM/OGPsOwcnv8+HE9++yziomJkYuLS6lue/To0Ro5cqTteVpamoKDgxUVFSVPT88S335ubq5iYmLUvXt3OTo6Xvb6ae8jmrzkgNaneGjs4PaymE0lnqmiud4YoOQxBsZjDIzHGJQNjIPxbmQMLv1N+/UYVm63bNmi5ORkNW/e3LYsPz9fq1ev1kcffaQlS5YoJydHKSkpRWZvk5KSFBgYKEkKDAzUxo0bi3zupaspXFrnSpydneXs7HzZckdHx1L9pb7a9ga3q61pq4/qyNlM/RZ3Rrc3Diq1TBVNaY85LscYGI8xMB5jUDYwDsa71hjc6NgYdkBn165dFRsbq+3bt9seLVu21AMPPGD7d0dHRy1btsz2nri4OMXHxysyMlKSFBkZqdjYWCUnJ9vWiYmJkaenpyIiIkp9n4qLu7ODhrarKUmasuKQrFYOTQAAALgRhs3cenh4qGHDhkWWubm5ydfX17b8kUce0ciRI+Xj4yNPT089/fTTioyMVNu2bSVJUVFRioiI0IMPPqjJkycrMTFRr7zyiqKjo684M1uePNS+pj5bc1h7T6VpRVyyuoQHXP9NAAAAFVyZPhX/3//+t26//XYNGDBAHTt2VGBgoObOnWt73WKxaOHChbJYLIqMjNSgQYM0ePBgTZgwwcDUxaOyq5MGtQ2RJH20/CCztwAAADfA8EuB/beVK1cWee7i4qIpU6ZoypQpV31PSEiIFi1aVMLJjPHoLbU0Y91RbY1P0YbD5xQZ6mt0JAAAgDKtTM/cVnT+ni4a2DJYkvTxyoMGpwEAACj7KLdl3OMda8tiNmnNgTPacTzF6DgAAABlGuW2jAv2cVW/ptUkSR+tYPYWAADgWii35cCwzqEymaSYPUmKS0w3Og4AAECZRbktB8L83dWrYeFNKTj2FgAA4Ooot+XE8M5hkqQFOxJ07GyGwWkAAADKJsptOdGwmpc616uiAqv0yapDRscBAAAokyi35chTtxbO3v645YQSU7MMTgMAAFD2UG7LkZY1fdS6lo9y8636dPVho+MAAACUOZTbcubS7O3sjcd09kK2wWkAAADKFsptOdOhjp8aV/dSVm6Bpq89anQcAACAMoVyW86YTCbblRNmrj+qtKxcgxMBAACUHZTbcigqIkB1/N2VnpWnr9cfMzoOAABAmUG5LYfMZpOG3xoqSfry9yO6mJNvcCIAAICygXJbTvVpHKQaPq46m5GjbzfGGx0HAACgTKDcllMOFrOe7FQ4e/vp6sPKzmP2FgAAgHJbjg1oUU0Bns5KTMvSdxuPGx0HAADAcJTbcszZwaKnu9SRJH24/KAysvMMTgQAAGAsym05N7BVsEJ8XXXmQramrz1idBwAAABDUW7LOUeLWSO715UkTVt9WCmZOQYnAgAAMA7l1g70aRyk8EAPpWflaeqqQ0bHAQAAMAzl1g6YzSa92LOeJGnG2qNKTM0yOBEAAIAxKLd24tZ6/moZ4q3svAJ9sPyA0XEAAAAMQbm1EyaTSS/2DJck/bDpuI6eyTA4EQAAQOmj3NqR1rV8dGu9KsorsOrdmP1GxwEAACh1lFs780KPwmNvf96RoN0JqQanAQAAKF2UWzvTIMhLfZoESZLeXhJncBoAAIDSRbm1Q893rysHs0kr4k5r09FzRscBAAAoNZRbO1TTz033tAqWJL25eJ+sVqvBiQAAAEoH5dZOPdOljpwdzNp87LxWxCUbHQcAAKBUUG7tVKCXi4a2qylJemvJfhUUMHsLAADsH+XWjg3rHCoPFwftPZWmBTsTjI4DAABQ4ii3dqyyq5Oe6FhbkvRuzH7l5hcYn
AgAAKBkUW7t3EPta8nP3UnHzmbq+03HjY4DAABQoii3ds7N2UFP3RomSfpg2QFdzMk3OBEAAEDJodxWAPe1qaHq3pWUnJ6tmeuPGh0HAACgxFBuKwBnB4ue61ZXkjR15SGlXsw1OBEAAEDJoNxWEP2aVVPdAHelXszVp6sPGR0HAACgRFBuKwiL2aTno+pJkr78/aiS07MMTgQAAFD8KLcVSFREgJoGV9bF3HxNWX7Q6DgAAADFjnJbgZhMJr3Ys3D2dvbGeB0/l2lwIgAAgOJFua1g2oX6qUMdP+XmW/XvmP1GxwEAAChWlNsKaFSPwtnbedtPKi4x3eA0AAAAxYdyWwE1rl5ZtzUKlNUqvb00zug4AAAAxYZyW0GN7F5PZpMUsydJW46dNzoOAABAsaDcVlBh/u66q0V1SdJbS/bJarUanAgAAODvo9xWYM92qysnB7M2HD6nNQfOGB0HAADgb6PcVmDVKlfSg21DJElvLYlTQQGztwAAoHyj3FZwwzuHys3JotiTqVq8K9HoOAAAAH8L5baC83V31qMdakuS3omJU15+gcGJAAAAbh7lFnq0Qy35uDnp8OkM/WfrCaPjAAAA3DTKLeTh4qjhnUMlSe/9dkBZufkGJwIAALg5lFtIkga1DVGQl4tOpWbp6/XHjI4DAABwUyi3kCS5OFo0ontdSdIHyw7ozIVsgxMBAAD8dZRb2AxoXl0Nq3kqPTtPby/htrwAAKD8odzCxmI2aVyfBpKk7zcfV+yJVIMTAQAA/DWUWxTRsqaP+jUNktUqjVuwm9vyAgCAcoVyi8v8o1d9uTpZtOXYef20PcHoOAAAADeMcovLBHq5KPrWMEnSpMV7lZGdZ3AiAACAG0O5xRU9ckst1fBxVVJatqasOGh0HAAAgBtCucUVuTha9Erv+pKkz9cc0bGzGQYnAgAAuD7KLa6qe0SAOtTxU05+gV77Za/RcQAAAK6LcourMplMerVPhBzMJsXsSdLq/aeNjgQAAHBNlFtcU5i/hwZH1pQkTVi4R7n5BcYGAgAAuAbKLa7r2W515OvmpIPJF/TV+mNGxwEAALgqyi2uy6uSo0b1qCdJei9mv85cyDY4EQAAwJVRbnFD7m4ZrIbVPJWenae3l8QZHQcAAOCKKLe4IRazSeP6NJAkfb/5uGJPpBqcCAAA4HKUW9ywljV91K9pkKxWadyC3bJarUZHAgAAKIJyi7/kH73qy9XJoi3Hzuun7QlGxwEAACiCcou/JNDLRdG3hkmSJi3eq4zsPIMTAQAA/D/KLf6yR26ppRo+rkpKy9aUFQeNjgMAAGBDucVf5uJo0Su960uSPl9zRMfOZhicCAAAoBDlFjele0SAOtTxU05+gV77Za/RcQAAACRRbnGTTCaTXu0TIQezSTF7krR6/2mjIwEAABhbbqdOnarGjRvL09NTnp6eioyM1OLFi22vd+7cWSaTqcjjySefLPIZ8fHx6t27t1xdXeXv769Ro0YpL4+TnEpDmL+HBkfWlCRNWLhHufkFxgYCAAAVnqHltnr16nrjjTe0ZcsWbd68WV26dFHfvn21e/du2zqPPfaYTp06ZXtMnjzZ9lp+fr569+6tnJwcrVu3TjNnztSMGTM0duxYI3anQnq2Wx35ujnpYPIFfbX+mNFxAABABWdoue3Tp49uu+021alTR3Xr1tXrr78ud3d3bdiwwbaOq6urAgMDbQ9PT0/ba0uXLtWePXs0a9YsNW3aVL169dLEiRM1ZcoU5eTkGLFLFY5XJUeN6lFPkvRezH6duZBtcCIAAFCRORgd4JL8/HzNmTNHGRkZioyMtC3/5ptvNGvWLAUGBqpPnz4aM2aMXF1dJUnr169Xo0aNFBAQYFu/R48eGjZsmHbv3q1mzZpdcVvZ2dnKzv7/EpaWliZJys3NVW5ubknsXhGXtlEa2yoN/ZoE6usNR7U7IV2TF+/V6/0aGB3puuxtDMojxsB4jIHxGIOygXEw3o2MwY2Oj
+HlNjY2VpGRkcrKypK7u7vmzZuniIgISdL999+vkJAQBQUFaefOnXrppZcUFxenuXPnSpISExOLFFtJtueJiYlX3eakSZM0fvz4y5YvXbrUVpxLQ0xMTKltq6R185Z2JzhozpYTqpFzTMHuRie6MfY0BuUVY2A8xsB4jEHZwDgY71pjkJmZeUOfYXi5rVevnrZv367U1FT9+OOPGjJkiFatWqWIiAg9/vjjtvUaNWqkqlWrqmvXrjp06JBCQ0NvepujR4/WyJEjbc/T0tIUHBysqKioIoc9lJTc3FzFxMSoe/fucnR0LPHtlZYjc2L1885TWp7qp+/ubiWTyWR0pKuy1zEoTxgD4zEGxmMMygbGwXg3MgaX/qb9egwvt05OTgoLK7yda4sWLbRp0ya9//77mjZt2mXrtmnTRpJ08OBBhYaGKjAwUBs3biyyTlJSkiQpMDDwqtt0dnaWs7PzZcsdHR1L9Ze6tLdX0l7uHaHf9iVra3yKFu0+rX7Nqhkd6brsbQzKI8bAeIyB8RiDsoFxMN61xuBGx6bMXee2oKCgyPGw/2379u2SpKpVq0qSIiMjFRsbq+TkZNs6MTEx8vT0tB3agNIT6OWi6FsL/0dl0uK9ysjmkmwAAKB0GVpuR48erdWrV+vo0aOKjY3V6NGjtXLlSj3wwAM6dOiQJk6cqC1btujo0aP6+eefNXjwYHXs2FGNGzeWJEVFRSkiIkIPPvigduzYoSVLluiVV15RdHT0FWdmUfIeuaWWavi4KiktW1NWHDQ6DgAAqGAMLbfJyckaPHiw6tWrp65du2rTpk1asmSJunfvLicnJ/3222+KiopSeHi4nn/+eQ0YMEALFiywvd9isWjhwoWyWCyKjIzUoEGDNHjwYE2YMMHAvarYXBwteqV3fUnS52uO6NjZDIMTAQCAisTQY26/+OKLq74WHBysVatWXfczQkJCtGjRouKMhb+pe0SAOtTx05oDZzRx4V59PqSl0ZEAAEAFUeaOuUX5ZzKZ9GqfCDmYTfptb5KW7r76ZdkAAACKE+UWJSLM30OPdawtSRrz0y6lZXFhbAAAUPIotygxz3ato5q+hSeXTf51n9FxAABABUC5RYlxcbToX/0bSZJmbYjXpqPnDE4EAADsHeUWJapdqJ/ubRUsSfrHf3YqKzff4EQAAMCeUW5R4kb3qq8qHs46dDpDH3PtWwAAUIIotyhxXq6OGn9HA0nS1FWHFJeYbnAiAABgryi3KBW9Ggaqe0SAcvOteuk/O5VfYDU6EgAAsEOUW5QKk8mkiX0bysPZQduPp+jr9UeNjgQAAOwQ5RalJtDLRS/1CpckTV4Sp5MpFw1OBAAA7A3lFqXq/tY11KqmtzJz8vXKvFhZrRyeAAAAig/lFqXKbDZpUv/GcrKYtSLutBbsPGV0JAAAYEcotyh1Yf7ueqpLmCRp/M+7dT4jx+BEAADAXlBuYYgnO4WqboC7zmbk6LVf9hodBwAA2AnKLQzh5GDWGwMay2SS/rP1hNYcOG10JAAAYAcotzBM8xreGhJZU5L08rxYZebkGRsIAACUe5RbGOqFHvUU5OWi4+cu6r3fDhgdBwAAlHOUWxjK3dlBr93ZUJL0+ZrDij2RanAiAABQnlFuYbgu4QG6o0mQCqzSS//Zqdz8AqMjAQCAcopyizJhbJ8IVXZ11J5Tafp8zRGj4wAAgHKKcosywc/dWWN6R0iS3vttv46eyTA4EQAAKI8otygz+jevpg51/JSdV6DRc7k1LwAA+OsotygzTCaTXu/XSC6OZq0/fFZzNp8wOhIAAChnKLcoU2r4uur57vUkSa/9skfJ6VkGJwIAAOUJ5RZlzkPta6pRNS+lZeVp/M97jI4DAADKEcotyhwHi1lvDGgki9mkX2JPKWZPktGRAABAOUG5RZnUIMhLj3WoLUkaM3+X0rNyDU4EAADKA8otyqwR3eooxNdViWlZmvxrnNFxAABAOUC5RZnl4mjRpP6NJElfb
zimzUfPGZwIAACUdZRblGntQv00sGWwpMJb82bn5RucCAAAlGWUW5R5L99WX37uzjp0OkPv/XbA6DgAAKAMo9yizPNyddRr/RpKkj5ZdUh/HD5rcCIAAFBWUW5RLvRsGKi7W1SX1SqN/GGHUi9y9QQAAHA5yi3KjVfvaKAQX1edTLmosT/tMjoOAAAogyi3KDfcnR3074FNZTGb9NP2BP20/aTRkQAAQBlDuUW50ryGt57uEiZJemXeLp04n2lwIgAAUJZQblHuPHVrmJrXqKz07DyN/H6H8gusRkcCAABlBOUW5Y6Dxaz3BjaTu7ODNh49p09WHTI6EgAAKCMotyiXavi6atwdDSRJ/47Zr50nUowNBAAAygTKLcqtAc2rqXejqsorsGrEd9uVmZNndCQAAGAwyi3KLZPJpNfvbKhATxcdPpOhiQv3Gh0JAAAYjHKLcq2yq5PevaeJTCbp243xWro70ehIAADAQJRblHvtwvz0WIfakqR/zI1VcnqWwYkAAIBRKLewC89H1VVEVU+dy8jRqDk7ZbVyeTAAACoiyi3sgrODRe/f21TODmat2n9aM9cdNToSAAAwAOUWdqNOgIdevq2+JOlfi/dpf1K6wYkAAEBpo9zCrgyODFHnelWUk1egZ77dpuy8fKMjAQCAUkS5hV0xmUyafFdj+bg5aV9iut5eEmd0JAAAUIoot7A7/h4umjygsSTpszVHtPbgGYMTAQCA0kK5hV3qFhGg+9vUkCQ9/8MOnc/IMTgRAAAoDZRb2K1XetdXbT83JaZl6eV5sVweDACACoByC7vl6uSg9+9tJgezSYt3JWrOlhNGRwIAACWMcgu71qi6l57rXleSNP7n3Tp2NsPgRAAAoCRRbmH3nuwUqta1fJSRk68R329XXn6B0ZEAAEAJodzC7lnMJr17TxN5uDhoW3yKPlx+0OhIAACghFBuUSFU93bVa/0aSpI+XH5A2+JTjA0EAABKBOUWFUbfptXUt2mQCqzS8z/GKoublwEAYHcot6hQJvRtqGqVK+n4+Yv68Qi//gAA2Bv+644KxauSo/49sKnMJmnTabN+2MzlwQAAsCeUW1Q4rWv5aETXMEnS+F/2KfZEqsGJAABAcaHcokJ6okMtNfQuUE5egZ6ctYXb8wIAYCcot6iQzGaTHggrUA2fSjqZclHPfr9d+QXcnhcAgPKOcosKy9VBmnJfU7k4mrV6/2m9v+yA0ZEAAMDfRLlFhRYe6KF/3dlIkvTBsgNavi/J4EQAAODvoNyiwuvfvLoebBsiSRrx3XbFn800OBEAALhZlFtA0iu311fT4MpKy8rTk7O2KCuXOzwAAFAeUW4BSc4OFk0d1Fy+bk7acypNr8zfJauVE8wAAChvKLfAn6p6VdKH9zWT2ST9uOWEvt143OhIAADgL6LcAv+lXZifXuhRT5I07ufd2nE8xdhAAADgL6HcAv9jWKdQRUUEKCe/QMNmbdE5bvAAAEC5QbkF/ofJZNLb9zRRLT83JaRm6dnvtnGDBwAAygnKLXAFni6OmjqouSo5WrTmwBm999t+oyMBAIAbQLkFriI80FNvDCi8wcOHyw/qtz3c4AEAgLKOcgtcQ9+m1TS0XU1J0nM/bNfRMxnGBgIAANdEuQWu4+Xb6qt5jcpK//MGDxdzuMEDAABllaHldurUqWrcuLE8PT3l6empyMhILV682PZ6VlaWoqOj5evrK3d3dw0YMEBJSUX/ajg+Pl69e/eWq6ur/P39NWrUKOXl5ZX2rsCOOTmY9fEDLeTn7qR9ien65/xYbvAAAEAZZWi5rV69ut544w1t2bJFmzdvVpcuXdS3b1/t3r1bkvTcc89pwYIFmjNnjlatWqWEhAT179/f9v78/Hz17t1bOTk5WrdunWbOnKkZM2Zo7NixRu0S7FSgl4s+vK+5LGaT5m49qW/+iDc6EgAAuAJDy22fPn102223qU6dOqpbt65ef/11ubu7a8OGDUpNTdUXX3yhd
999V126dFGLFi00ffp0rVu3Ths2bJAkLV26VHv27NGsWbPUtGlT9erVSxMnTtSUKVOUk8O1SVG8IkN99eKfN3gYv2C3tsWfNzgRAAD4Xw5GB7gkPz9fc+bMUUZGhiIjI7Vlyxbl5uaqW7dutnXCw8NVo0YNrV+/Xm3bttX69evVqFEjBQQE2Nbp0aOHhg0bpt27d6tZs2ZX3FZ2drays7Ntz9PS0iRJubm5ys3NLaE9/H+XtlEa28KV3ewYPBQZrC3HzmnpnmQNm7VF84dHytfNqSQi2j2+B8ZjDIzHGJQNjIPxbmQMbnR8DC+3sbGxioyMVFZWltzd3TVv3jxFRERo+/btcnJyUuXKlYusHxAQoMTERElSYmJikWJ76fVLr13NpEmTNH78+MuWL126VK6urn9zj25cTExMqW0LV3YzY9DVTdruYlFiWrYGT12uYfULZDaVQLgKgu+B8RgD4zEGZQPjYLxrjUFmZuYNfYbh5bZevXravn27UlNT9eOPP2rIkCFatWpViW5z9OjRGjlypO15WlqagoODFRUVJU9PzxLdtlT4fx4xMTHq3r27HB0dS3x7uNzfHYNGbS5owLQN2p8qxTmF6vnudUogpX3je2A8xsB4jEHZwDgY70bG4NLftF+P4eXWyclJYWFhkqQWLVpo06ZNev/99zVw4EDl5OQoJSWlyOxtUlKSAgMDJUmBgYHauHFjkc+7dDWFS+tcibOzs5ydnS9b7ujoWKq/1KW9PVzuZscgorq33ryriZ75dps+WX1EzUN8FNXg6r9zuDq+B8ZjDIzHGJQNjIPxrjUGNzo2Ze46twUFBcrOzlaLFi3k6OioZcuW2V6Li4tTfHy8IiMjJUmRkZGKjY1VcnKybZ2YmBh5enoqIiKi1LOjYrmjSZAeal9TkvT8Dzt06PQFYwMBAABjZ25Hjx6tXr16qUaNGkpPT9fs2bO1cuVKLVmyRF5eXnrkkUc0cuRI+fj4yNPTU08//bQiIyPVtm1bSVJUVJQiIiL04IMPavLkyUpMTNQrr7yi6OjoK87MAsXt5dvqa9fJVG06el6PzNikecPby5sTzAAAMIyhM7fJyckaPHiw6tWrp65du2rTpk1asmSJunfvLkn697//rdtvv10DBgxQx44dFRgYqLlz59reb7FYtHDhQlksFkVGRmrQoEEaPHiwJkyYYNQuoYJxtBTe4KFa5Uo6ejZTT87aopy8AqNjAQBQYRk6c/vFF19c83UXFxdNmTJFU6ZMueo6ISEhWrRoUXFHA25YFQ9nfTm0lQZMXac/jpzT6LmxevvuxjKZuIQCAAClrcwdcwuUR/UCPTTlgcI7mP1n6wl9vPKQ0ZEAAKiQKLdAMelUt4rG9Sk8kfGtJXFaFHvK4EQAAFQ8N1Vujx8/rhMnTtieb9y4USNGjNCnn35abMGA8ujByJoa2q6mJOm577dr+/EUQ/MAAFDR3FS5vf/++7VixQpJhXcC6969uzZu3Kh//vOfnMyFCm/M7RHqEu6v7LwCPTpzs06mXDQ6EgAAFcZNldtdu3apdevWkqQffvhBDRs21Lp16/TNN99oxowZxZkPKHcsZpM+uK+ZwgM9dOZCth6ZsUkXsvOMjgUAQIVwU+U2NzfXdh3Z3377TXfccYckKTw8XKdOcZwh4O7soC+GtpKfu7P2Jabr6dlblZfPJcIAAChpN1VuGzRooE8++URr1qxRTEyMevbsKUlKSEiQr69vsQYEyqtqlSvpiyEt5eJo1oq403rtl71GRwIAwO7dVLl98803NW3aNHXu3Fn33XefmjRpIkn6+eefbYcrAJCaBFfWu/c0lSTNWHdUX68/amgeAADs3U3dxKFz5846c+aM0tLS5O3tbVv++OOPy9XVtdjCAfbgtkZVNapHPb21JE7jFuxRsI+rOtfzNzoWAAB26aZmbi9evKjs7GxbsT127Jjee+89xcXFyd+f/2gD/2t451Dd1aK68gusemr2NsUlphsdCQAAu
3RT5bZv37766quvJEkpKSlq06aN3nnnHfXr109Tp04t1oCAPTCZTPrXnY3UppaPLmTn6eEZm3Q6PdvoWAAA2J2bKrdbt25Vhw4dJEk//vijAgICdOzYMX311Vf64IMPijUgYC+cHMz6ZFAL1fJz08mUi3r8683Kys03OhYAAHblpsptZmamPDw8JElLly5V//79ZTab1bZtWx07dqxYAwL2xNvNSV8MaSmvSo7aFp+iF+bsUEGB1ehYAADYjZsqt2FhYZo/f76OHz+uJUuWKCoqSpKUnJwsT0/PYg0I2JvaVdz1yaAWcrSYtHDnKb33236jIwEAYDduqtyOHTtWL7zwgmrWrKnWrVsrMjJSUuEsbrNmzYo1IGCPIkN99fqdjSRJHyw/qHnbThicCAAA+3BTlwK76667dMstt+jUqVO2a9xKUteuXXXnnXcWWzjAnt3TMlhHzmRo6spDeunHWFX3dlWrmj5GxwIAoFy7qZlbSQoMDFSzZs2UkJCgEycKZ51at26t8PDwYgsH2LtRUfXUq2GgcvIL9PhXm3XsbIbRkQAAKNduqtwWFBRowoQJ8vLyUkhIiEJCQlS5cmVNnDhRBQUFxZ0RsFtms0nv3tNUjat76Xxmrh6esUmpF3ONjgUAQLl1U+X2n//8pz766CO98cYb2rZtm7Zt26Z//etf+vDDDzVmzJjizgjYtUpOFn0+uKWqerno0OkMDf9mi3Lz+Z9EAABuxk2V25kzZ+rzzz/XsGHD1LhxYzVu3FjDhw/XZ599phkzZhRzRMD++Xu66IshreTmZNHag2f18txYWa1cIgwAgL/qpsrtuXPnrnhsbXh4uM6dO/e3QwEVUUSQpz64r5nMJmnOlhOauHAvBRcAgL/opsptkyZN9NFHH122/KOPPlLjxo3/diigoupaP0BvDij8Dn259oje++2AwYkAAChfbupSYJMnT1bv3r3122+/2a5xu379eh0/flyLFi0q1oBARXN3y2BlZOdp3II9en/ZAbk7O+ixjrWNjgUAQLlwUzO3nTp10v79+3XnnXcqJSVFKSkp6t+/v3bv3q2vv/66uDMCFc7Q9rU0qkc9SdLri/Zq9h/xBicCAKB8uKmZW0kKCgrS66+/XmTZjh079MUXX+jTTz/928GAii761jBdyM7T1JWH9M/5sXJztqhv02pGxwIAoEy76Zs4ACh5L/aopwfbhshqlUb+sENLdycaHQkAgDKNcguUYSaTSePvaKD+zaopv8Cqp2Zv0+8HzhgdCwCAMotyC5RxZrNJk+9qrB4NApSTX6DHvtqsLce45B4AAFfyl4657d+//zVfT0lJ+TtZAFyFg8WsD+5rpse+2qLV+09r6PRN+vaxtmpYzcvoaAAAlCl/aebWy8vrmo+QkBANHjy4pLICFZqzg0XTBrVQ65o+Ss/K0+AvN+pgcrrRsQAAKFP+0szt9OnTSyoHgBtQycmiz4e21AOf/aHYk6ka9PlGzXkyUsE+rkZHAwCgTOCYW6Cc8XRx1MyHW6uOv7sS07L0wOd/KCkty+hYAACUCZRboBzycXPSrEfbqIaPq+LPZeqBz//QuYwco2MBAGA4yi1QTgV4uuibR9so0NNFB5MvaPCXfygtK9foWAAAGIpyC5RjwT6umvVoG/m6OWnXyTQ9PH2TMnPyjI4FAIBhKLdAORfm766vHmktDxcHbT52Xk98vUXZeflGxwIAwBCUW8AONAjy0oyHWsvVyaI1B87omW+3KS+/wOhYAACUOsotYCdahHjrs8Et5WQxa8nuJI36cacKCqxGxwIAoFRRbgE70j7MT1MeaC6L2aR5205q7M+7ZLVScAEAFQflFrAz3SMC9O49TWQySbM2xOuNX/dRcAEAFQblFrBDfZtW07/ubCRJmrbqsN78NY6CCwCoECi3gJ26r3UNjb09QpL0yapDGr9gDwUXAGD3KLeAHXv4llp6rV9DSdKMdUf18rxYTjIDANg1yi1g5wa1DdFbdzWW2
SR9u/G4Xpizg8uEAQDsFuUWqADubhms9+5tJovZpLnbTurZ77Yrl4ILALBDlFuggrijSZCm3N9cjhaTfok9pWGztnInMwCA3aHcAhVIz4aB+nRwSzk7mPXb3iQ9OnOzLuZQcAEA9oNyC1Qwt9bz1/ShrVTJsfBWvQ/N2KiM7DyjYwEAUCwot0AF1C7MT18/0lruzg7acPicHvziD6Vl5RodCwCAv41yC1RQLWv66JtH28irkqO2xqfogc/+0PmMHKNjAQDwt1BugQqsSXBlfftYW/m4OSn2ZKru+2yDzlzINjoWAAA3jXILVHARQZ76/vG2quLhrH2J6Ro4bb0SU7OMjgUAwE2h3AJQnQAP/fBEpIK8XHTodIYGfrpeJ85nGh0LAIC/jHILQJJUy89N3z8RqWCfSjp2NlMDp23Q0TMZRscCAOAvodwCsAn2cdWcJ9qptp+bTqZc1D3T1utgcrrRsQAAuGGUWwBFBHq56PsnIlUvwEPJ6dkaOG2D9p5KMzoWAAA3hHIL4DJVPJz17eNt1SDIU2czcnTfZxu080SK0bEAALguyi2AK/Jxc9Lsx9qqWY3KSsnM1QOf/aEtx84ZHQsAgGui3AK4Kq9Kjvr6kTZqXctH6dl5evCLjVp38IzRsQAAuCrKLYBrcnd20MyHWqtDHT9l5uRr6PRNWrAjwehYAABcEeUWwHVVcrLos8EtdVujQOXkF+jpb7fp8zWHjY4FAMBlKLcAboiLo0Uf3tdcQ9vVlCS99stevbZwjwoKrMYGAwDgv1BuAdwwi9mkV/tEaHSvcEnS578f0bPfb1d2Xr7ByQAAKES5BfCXmEwmPdEpVP8e2EQOZpMW7EjQkC83Ki0r1+hoAABQbgHcnDubVdf0h1rJzcmiDYfP6Z5P1isxNcvoWACACo5yC+CmdahTRd8/EakqHs7al5iu/h+v1YEkbtcLADAO5RbA39KwmpfmDmun2lXclJCapQFT12nTUW72AAAwBuUWwN8W7OOq/zzZTs1rVFZaVp4e+PwP/brrlNGxAAAVEOUWQLHwdnPSN4+2Vbf6AcrJK9Cwb7Zq5rqjRscCAFQwlFsAxaaSk0WfDGqu+9vUkNUqvfrzbr356z5ZrVwLFwBQOii3AIqVg8Ws1/s11PPd60qSpq48pOd/2KGcvAKDkwEAKgLKLYBiZzKZ9HTXOpp8V2NZzCbN3XZSj8zcpAvZeUZHAwDYOcotgBJzT8tgfT6kpSo5WrTmwBkNnLZeyelcCxcAUHIotwBK1K31/PXd423l6+ak3Qlp6v/xOh0+fcHoWAAAO0W5BVDimgRX1tzh7RTi66oT5y9qwNR12nY8xehYAAA7ZGi5nTRpklq1aiUPDw/5+/urX79+iouLK7JO586dZTKZijyefPLJIuvEx8erd+/ecnV1lb+/v0aNGqW8PI7tA8qSEF83/WdYOzWu7qXzmbkaPH2zYs+ZjI4FALAzhpbbVatWKTo6Whs2bFBMTIxyc3MVFRWljIyMIus99thjOnXqlO0xefJk22v5+fnq3bu3cnJytG7dOs2cOVMzZszQ2LFjS3t3AFyHn7uzvn2srTrXq6Ks3AJ9EWfWtNVHuFQYAKDYOBi58V9//bXI8xkzZsjf319btmxRx44dbctdXV0VGBh4xc9YunSp9uzZo99++00BAQFq2rSpJk6cqJdeeknjxo2Tk5NTie4DgL/GzdlBnw1uqVfmxer7zSf0dswB7UlM11t3NZGbs6F/JAEA7ECZ+i9JamqqJMnHx6fI8m+++UazZs1SYGCg+vTpozFjxsjV1VWStH79ejVq1EgBAQG29Xv06KFhw4Zp9+7datas2WXbyc7OVnZ2tu15WlqaJCk3N1e5ubnFvl//69I2SmNbuDLGwHiv3lZHOndM/znqoEWxiTqQlK6P72+qmr5uRkerMPgeGI8xKBsYB+PdyBjc6PiYrGXk7wMLCgp0xx13KCUlRb///rtt+aeffqqQkBAFBQVp586de
umll9S6dWvNnTtXkvT444/r2LFjWrJkie09mZmZcnNz06JFi9SrV6/LtjVu3DiNHz/+suWzZ8+2lWYApeNIuvRlnEVpuSZVslj1YJ0CNfAuE38sAQDKkMzMTN1///1KTU2Vp6fnVdcrMzO30dHR2rVrV5FiKxWW10saNWqkqlWrqmvXrjp06JBCQ0NvalujR4/WyJEjbc/T0tIUHBysqKioa/6wiktubq5iYmLUvXt3OTo6lvj2cDnGwHiXxuDx/t11d1aBnv5uh7bGp+izOIueuTVUwzvVltnMCWclie+B8RiDsoFxMN6NjMGlv2m/njJRbp966iktXLhQq1evVvXq1a+5bps2bSRJBw8eVGhoqAIDA7Vx48Yi6yQlJUnSVY/TdXZ2lrOz82XLHR0dS/WXurS3h8sxBsZzdHRUNVdHffd4pCYs3K1ZG+L1/vJD2pN4Qe/e00QeLoxPSeN7YDzGoGxgHIx3rTG40bEx9GoJVqtVTz31lObNm6fly5erVq1a133P9u3bJUlVq1aVJEVGRio2NlbJycm2dWJiYuTp6amIiIgSyQ2g+Dk5mPVav0Z6c0AjOVnMitmTpH5T1upgMjd8AADcOEPLbXR0tGbNmqXZs2fLw8NDiYmJSkxM1MWLFyVJhw4d0sSJE7VlyxYdPXpUP//8swYPHqyOHTuqcePGkqSoqChFRETowQcf1I4dO7RkyRK98sorio6OvuLsLICybWCrGvrhyUgFerro0OkM9ZuyVkt3JxodCwBQThhabqdOnarU1FR17txZVatWtT2+//57SZKTk5N+++03RUVFKTw8XM8//7wGDBigBQsW2D7DYrFo4cKFslgsioyM1KBBgzR48GBNmDDBqN0C8Dc1Da6sBU/fota1fHQhO0+Pf71F78bsV0EBJ5oBAK7N0GNur3ehhuDgYK1ateq6nxMSEqJFixYVVywAZUAVD2d982gbvf7LXs1Yd1QfLDug3SdT9e7ApvKqxDFxAIArM3TmFgCuxdFi1rg7Guidu5vI2cGsZfuS1W/KWh1ISjc6GgCgjKLcAijzBrSorh+fbKdqlSvpyJnC43AXx54yOhYAoAyi3AIoFxpV99LPT7VXZG1fZeTka9g3WzX5133K5zhcAMB/odwCKDd83Z319SOt9egthZcN/HjlIT00Y5NSMnMMTgYAKCsotwDKFQeLWa/cHqH3720qF0ezVu8/rTs+Wqu9p27szjUAAPtGuQVQLvVtWk1zh7VXde9Kij+Xqf4fr9O8bSeMjgUAMBjlFkC5FRHkqQVP3aIOdfx0MTdfz32/QyN/2K6M7DyjowEADEK5BVCuebs5acZDrfVct7oym6S5W0/q9g9/166TqUZHAwAYgHILoNyzmE16tlsdffd4pKp6uejImQz1/3idvvz9yHVvFgMAsC+UWwB2o3UtHy1+toOiIgKUk1+gCQv36NGZm3Uug6spAEBFQbkFYFcquzpp2oMtNKFvAzn9eVezXu+v1vpDZ42OBgAoBZRbAHbHZDJpcGRNzR/eXqFV3JSUlq37P9+gd5fGKS+/wOh4AIASRLkFYLcigjy14OlbNLBlsKxW6YPlB3XfZxt0MuWi0dEAACWEcgvArrk6OejNuxrrg/uayd3ZQZuOntdt76/Rr7sSjY4GACgBlFsAFcIdTYK06JkOahJcWakXc/XkrC0aM3+XsnLzjY4GAChGlFsAFUYNX1fNeSJST3SqLUn6esMx9ZuyVgeT0w1OBgAoLpRbABWKk4NZo3vV18yHW8vP3Un7EtN1+4e/67uN8VwTFwDsAOUWQIXUqW4VLXq2gzrU8VNWboH+MTdWT3+7TWlZuUZHAwD8DZRbABWWv4eLZj7UWv/oFS4Hs0kLd55S7w/WaFv8eaOjAQBuEuUWQIVmNpv0ZKdQzXkyUtW9K+n4uYu6+5P1mrLioPILOEwBAMobyi0ASGpWw1uLnu2g2xtXVV6BVW8tidNdn6zTodMXjI4GAPgLKLcA8
CdPF0d9eF8zvXVXY3k4O2hbfIpue3+NPl9zWAXM4gJAuUC5BYD/YjKZdHfLYC15rqM61PFTdl6BXvtlr+79dIOOnc0wOh4A4DootwBwBUGVK+mrh1vrX3c2kpuTRRuPnlPP99boq/VHmcUFgDKMcgsAV2EymXR/mxr6dURHRdb21cXcfI39abcGffGHjp/LNDoeAOAKKLcAcB3BPq765tE2Gn9HA1VytGjdobPq+d5qzf6DGz8AQFlDuQWAG2A2mzSkXU0tfraDWtX0VkZOvl6eF6vBX25UQspFo+MBAP5EuQWAv6Cmn5u+ezxSr/SuL2cHs9YcOKMe/16tOZuPM4sLAGUA5RYA/iKL2aRHO9TWomc7qGlwZaVn52nUjzv16MzNSkrLMjoeAFRolFsAuEmhVdz145OReqlnuJwsZi3bl6yof6/W/G0nmcUFAINQbgHgb3CwmDWsc6gWPnOLGlXzUurFXI34fruenLVFp9OzjY4HABUO5RYAikHdAA/NHd5Oz3evK0eLSUt2Jynq36u0cGeC0dEAoEKh3AJAMXG0mPV01zr6KfoW1a/qqfOZuXpq9jZFz97KLC4AlBLKLQAUs4ggT/0U3V7PdAmTxWzSLztPqcs7K/XV+qPK5+5mAFCiKLcAUAKcHMwaGVVP84e3V8NqnkrPytPYn3ar75TftS3+vNHxAMBuUW4BoAQ1qu6ln6Jv0cS+DeTh4qBdJ9PUf+o6jZ4bq/MZOUbHAwC7Q7kFgBJmMZv0YGRNLX++s/o3ryarVfp2Y7y6vLNS32+KVwGHKgBAsaHcAkApqeLhrHfvaaofnohUvQAPnc/M1Uv/idVdn6zT7oRUo+MBgF2g3AJAKWtdy0cLn7lF/7ytvtycLNoan6I+H/6ucT/vVlpWrtHxAKBco9wCgAEcLWY91rG2lj3fWb0bV1WBVZqx7qi6vrOKO5wBwN9AuQUAAwV6uWjK/c319SOtVdvPTafTszXi++2677MNOpCUbnQ8ACh3KLcAUAZ0qFNFi0d00Kge9eTiaNaGw+fU6/01mrR4rzKy84yOBwDlBuUWAMoIZweLom8NU8xzndStfoDyCqyatuqwur27SotjT3GoAgDcAMotAJQxwT6u+nxIS30+uKWqe1fSqdQsDftmq4ZO36SjZzKMjgcAZRrlFgDKqG4RAYp5rpOe7hImJ4tZq/afVtR7q/X2kjgOVQCAq6DcAkAZVsnJouej6unXER3UoY6fcvIK9NGKg+r89kp9tzFe+dwAAgCKoNwCQDlQu4q7vnq4tT4Z1Fwhvq46nZ6tf8yNVe8P1mjNgdNGxwOAMoNyCwDlhMlkUs+GVRXzXCe90ru+PF0ctC8xXQ9+sVEPTd/IpcMAQJRbACh3nBzMerRDba0adaseal9TDmaTVsSdVs/31+iV+bE6cyHb6IgAYBjKLQCUU95uTnq1TwMtfa6joiIClF9g1awN8br1rZWauvKQsnLzjY4IAKWOcgsA5VztKu76dHBLffd4WzWs5qn07Dy9+es+dX1nlX7ekcD1cQFUKJRbALATbWv76ufoW/TO3U0U6OmikykX9cy329R/6jptOXbe6HgAUCootwBgR8xmkwa0qK4VL3TWyO515epk0bb4FA2Yuk7Rs7fq+LlMoyMCQImi3AKAHarkZNEzXeto5QudNbBlsEwm6Zedp9T1nVWatGiv0rJyjY4IACWCcgsAdszf00Vv3tVYvzzdQe3DfJWTX6Bpqw+r81srNeuPeOUXGJ0QAIoX5RYAKoCIIE/NeqSNvhzaUqFV3HQuI0fjF+7TGzssWhSbqALudAbATlBuAaCCMJlM6hIeoF9HdNTEvg3k7eqo5CyTnv1hp3q9v0a/7jpFyQVQ7lFuAaCCcbSY9WBkTS177hb1rJ4vd2cHxSWl68lZW3X7h79r6e5ELh8GoNyi3AJABeXh4qhewVatfL6DnukSJndnB+05labHv96iOz5aq+X7kii5AModyi0AVHBelRw1M
qqe1rx4q4Z3DpWrk0WxJ1P18IzN6vfxOq3af5qSC6DcoNwCACQV3s73xZ7hWvPirXqiU21VcrRox/EUDflyo+76ZL1+P3CGkgugzKPcAgCK8HV31uhe9bX6xVv16C215Oxg1pZj5zXoiz80cNoGrT901uiIAHBVlFsAwBVV8XDWK7dHaM2Lt2pou5pycjBr49Fzuu+zDbrv0w3aeOSc0REB4DKUWwDANfl7umjcHQ20alRnPdg2RI4Wk9YfPqt7pq3XoM//0JZj542OCAA2lFsAwA2p6lVJE/s11MpRt+r+NjXkYDbp94NnNGDqOg35cqO2H08xOiIAUG4BAH9NtcqV9K87G2nFC501sGWwLGaTVu0/rX5T1mrwlxu17hAnngEwDuUWAHBTgn1c9eZdjbX8+U66q0V1mU3S6v2ndf9nf6jflLVaFHtK+dzxDEApo9wCAP6WEF83vX13E6184VY92DZEzg5m7TiRquHfbFWXd1Zq1oZjysrNNzomgAqCcgsAKBY1fF01sV9DrftHFz3TtY4quzrq2NlMvTJ/l255c7k+Wn5AqZm5RscEYOcotwCAYuXr7qyR3etq7Utd9GqfCFWrXElnLuTo7aX7FfnGMk1cuEcJKReNjgnATlFuAQAlws3ZQQ+1r6WVozrrvYFNFR7oocycfH3x+xF1nLxCI3/YrrjEdKNjArAzDkYHAADYN0eLWf2aVVPfpkFatf+0pq06rPWHz2ru1pOau/WkuoT764mOtdW6lo9MJpPRcQGUc5RbAECpMJlM6lzPX53r+WvH8RRNW31Ii3clavm+ZC3fl6ymwZX1ZKdQRUUEyGym5AK4OZRbAECpaxJcWR8/0EJHzmToszWH9eOWE9p+PEVPztqi2n5uerxjbfVrVk0ujhajowIoZzjmFgBgmFp+bvrXnY209qUuir41VJ4uDjp8JkP/mBurW95coXeXxikxNcvomADKEcotAMBwVTycNapHuNaN7qpXetdXVS8XnbmQrQ+WH1T7N5cr+put+uPwWe58BuC6DC23kyZNUqtWreTh4SF/f3/169dPcXFxRdbJyspSdHS0fH195e7urgEDBigpKanIOvHx8erdu7dcXV3l7++vUaNGKS8vrzR3BQBQDNydHfRoh9pa/eKt+uj+Zmpd00f5BVb9EntKAz/doF7vr9HsP+KVmcOf8QCuzNByu2rVKkVHR2vDhg2KiYlRbm6uoqKilJGRYVvnueee04IFCzRnzhytWrVKCQkJ6t+/v+31/Px89e7dWzk5OVq3bp1mzpypGTNmaOzYsUbsEgCgGDhazLq9cZB+eDJSi5/toPta15CLo1n7EtP18rxYtflX4fVyj57JuP6HAahQDD2h7Ndffy3yfMaMGfL399eWLVvUsWNHpaam6osvvtDs2bPVpUsXSdL06dNVv359bdiwQW3bttXSpUu1Z88e/fbbbwoICFDTpk01ceJEvfTSSxo3bpycnJwu2252drays7Ntz9PS0iRJubm5ys0t+bvnXNpGaWwLV8YYGI8xMF55GYMwv0qa0Cdcz3cL1X+2ntSsP47r+PmL+uL3I/ri9yPqVMdPg9oGq2OYX7m7ykJ5GQN7xzgY70bG4EbHx2QtQwcwHTx4UHXq1FFsbKwaNmyo5cuXq2vXrjp//rwqV65sWy8kJEQjRozQc889p7Fjx+rnn3/W9u3bba8fOXJEtWvX1tatW9WsWbPLtjNu3DiNHz/+suWzZ8+Wq6trSewaAKCYFFilvSkmrUk0aW/K//8FpJ+zVbcEFqiNv1WuXAsIsDuZmZm6//77lZqaKk9Pz6uuV2a+/gUFBRoxYoTat2+vhg0bSpISExPl5ORUpNhKUkBAgBITE23rBAQEXPb6pdeuZPTo0Ro5cqTteVpamoKDgxUVFXXNH1Zxyc3NVUxMjLp37y5HR8cS3x4uxxgYjzEwXnkeg9sljZJ07Gymvtl4XD9uPakzWXmaf8yiXxPM6tukqh5oXUP1q
3oYHfWayvMY2BPGwXg3MgaX/qb9espMuY2OjtauXbv0+++/l/i2nJ2d5ezsfNlyR0fHUv2lLu3t4XKMgfEYA+OV5zEIC/TSq3d4aVTPcM3flqCv1h/VvsR0fb/5pL7ffFKta/pocLsQ9WgQKEdL2b1AUHkeA3vCOBjvWmNwo2NTJsrtU089pYULF2r16tWqXr26bXlgYKBycnKUkpJSZPY2KSlJgYGBtnU2btxY5PMuXU3h0joAAPvm6uSg+9vU0H2tg7XxyDl9teGYft2VqI1Hz2nj0XPy93DWwFbBuqdlsIJ9OPwMsGeG/m+s1WrVU089pXnz5mn58uWqVatWkddbtGghR0dHLVu2zLYsLi5O8fHxioyMlCRFRkYqNjZWycnJtnViYmLk6empiIiI0tkRAECZYDKZ1Ka2r6bc31xrX+qiZ7rWkZ+7s5LTs/Xh8oPqMHmFHvh8g37aflJZuflGxwVQAgyduY2Ojtbs2bP1008/ycPDw3aMrJeXlypVqiQvLy898sgjGjlypHx8fOTp6amnn35akZGRatu2rSQpKipKERERevDBBzV58mQlJibqlVdeUXR09BUPPQAAVAyBXi4a2b2unro1TEt2J+qHzce15sAZrT14VmsPnpVXJUfd2aya7mkZrIigkj/fAkDpMLTcTp06VZLUuXPnIsunT5+uoUOHSpL+/e9/y2w2a8CAAcrOzlaPHj308ccf29a1WCxauHChhg0bpsjISLm5uWnIkCGaMGFCae0GAKAMc3Iwq0+TIPVpEqTj5zI1Z8sJ/bj5uBJSszRj3VHNWHdUjat76Z6WwbqjaZA8XTjmEijPDC23N3IVMhcXF02ZMkVTpky56johISFatGhRcUYDANihYB9XjexeV892raM1B07rh83HFbMnSTtPpGrniVS99sse3daoqu5tVUOtanrLZCpf180FUEZOKAMAoDRZzCZ1ruevzvX8dfZCtuZtO6nvNx3XgeQLmrv1pOZuPanafm66p1Ww+jevJn8PF6MjA7hBlFsAQIXm6+6sRzvU1iO31NLW+BT9sOm4FuxM0OEzGXpj8T69tSROXcL9dW+rYHWqW0UOZfiSYgAotwAASCq80kKLEG+1CPHWmD4R+mVngr7fdFxb41MUsydJMXuSFODprLtaVNfdLYJV08/N6MgAroByCwDA/3B3dtDAVjU0sFUNHUhK1/ebjmvutpNKSsvWlBWHNGXFITUJrqy+TYJ0e5OqHLYAlCGUWwAArqFOgIdeuT1CL/YM1297k/T9puNac+C0dhxP0Y7jKXrtlz1qH+anO5oEqUfDQK62ABiMcgsAwA1wcjDrtkZVdVujqjqdnq1fdibopx0J2hafojUHzmjNgTP65/xd6hrur75Ng9S5nr9cHC1GxwYqHMotAAB/URUPZw1tX0tD29fSsbMZ+nl7guZvP6lDpzO0eFeiFu9KlIeLg3o1DFTfptXUtravLGYuKwaUBsotAAB/Q4ivm57uWkdPdQnTnlNp+nl7gn7ekaBTqVn6YfMJ/bD5hPw9nNWnSZD6Ng1So2peXD8XKEGUWwAAioHJZFKDIC81CPLSSz3DtenoOc3fnqBFsaeUnJ6tL34/oi9+P6Jafm6648+iG1yZ28QDxY1yCwBAMTObTWpT21dtavtq/B0NtHr/af20I0ExexJ15EyG3l92QO8vO6BG1TwV6mBSk5SLqlmFE9GA4kC5BQCgBDk5mNUtIkDdIgKUkZ2npXsS9dP2BK05cEaxJ9MUK4vmv7NGDat5qmeDQPVoEKgwf3cOXQBuEuUWAIBS4ubsoDubVdedzarr7IVsLdh+Ql+v2qsjF0zadTJNu06m6e2l+1Xbz01RDQLVs2GgGlfzkpmT0YAbRrkFAMAAvu7OeqBNDXmf3aU2Hbtq5YFzWrI7UWsPntXhMxn6ZNUhfbLqkAI9XRTVIEA9GwSqdS0fbv8LXAflFgAAg/m6O+ve1jV0b+saSs/K1Yq401qyO1Er9
yUrMS1LX60/pq/WH1NlV0d1DQ9Qz4aB6lDHj+voAldAuQUAoAzxcHHUHU2CdEeTIGXl5mvdoTP6dVeiftubrHMZOfrP1hP6z9YTcnWyqFPdKurRIFC3hvvLqxInpAES5RYAgDLLxdGiLuEB6hIeoLz8Am0+dl6/7krU0t2JSkjNst0wwtFiUmSon3o0CFD3+gHy93QxOjpgGMotAADlgIPFrLa1fdW2tq9e7ROhXSfTtGR3on7dnaiDyRe0ev9prd5/Wv+ct0uNq3vp1nr+6lrfXw2DOCENFQvlFgCAcsZkMqlRdS81qu6lF3rU06HTF7Rkd6KW7E7SjuMp2nkiVTtPpOr9ZQdUxcNZXer569Zwf3Wo4yc3Z/7TD/vGbzgAAOVcaBV3De8cpuGdw5ScnqWVcae1fG+y1hw4rdPp2fp+83F9v/m4nCxmtanto67h/uoSHqAavq5GRweKHeUWAAA74u/hontaBuuelsHKzsvXpiPntWxfkpbtTVb8uUytOXBGaw6c0bgFe1TH311dwv3VJdxfLUK8ucwY7ALlFgAAO+XsYNEtdfx0Sx0/jb09QodOZ2jFvmQt25ekTUfP60DyBR1IvqBpqw/L08VBnesVFt1OdavI283J6PjATaHcAgBQAZhMJoX5uyvM312Pdayt1Iu5Wr3/tFbsS9aKuGSdz8zVzzsS9POOBJlNUosQ78LjdMOqqEGQJyelodyg3AIAUAF5VXJUnyZB6tMkSPkFVm0/fl7L9iZr+b5k7UtM16aj57Xp6HlNVpy8XR3VLrRwBviWMD8F+3CsLsouyi0AABWcxWxSixAftQjx0Ys9w3Uy5aKW70vWqrhkbTh8Tuczc/VL7Cn9EntKklTDx9VWdNuF+qqyK4cwoOyg3AIAgCKqVa6kB9uG6MG2IcrNL9DOEyn6/cBZ/X7wtLbFpyj+XKZm/xGv2X/Ey2SSGlXzUvswP3UI81PzEG9uCwxDUW4BAMBVOVrMtlndZ7vV0YXsPP1x+Kx+P3hGaw+e0f6kC7br6k5deUjODma1ruWjW8L81D7MTxFVOV4XpYtyCwAAbpi7s4O61g9Q1/oBkqSktCytPXhGvx84o98PnlFyerbtcmOS5OPmpHahvrolzE+Rob6q4eMqk4myi5JDuQUAADctwNNF/ZtXV//m1WW1WnUw+YLWHCic1d1w+KzOZeRo4c5TWriz8HjdQE8Xta7lo9a1fNSmlo/C/N0puyhWlFsAAFAsTCaT6gR4qE6Ahx6+pZZy8wu0/XiKfv+z7O44kaLEtCzbJcekwpnd1jV9bIW3flVPWTiMAX8D5RYAAJQIR4tZrWr6qFVNHz3Xva4u5uRr2/Hz2njknDYeOaet8ed1LiNHv+5O1K+7EyVJHs4OalnTW21q+6p1LR81quYlR+6chr+AcgsAAEpFJSeL2oX6qV2onyQpJ69AsSdT9ceRs9p45Jw2Hz2v9Ow8rYg7rRVxpwvf42hR85DKalOrsOw2Da7M1RhwTZRbAABgCCcHs1qEeKtFiLeGd5byC6zaeypNfxw5p41/Ft7zmblae/Cs1h48W/gei1lNgr3UqqaPmtfwVrMaleXr7mzsjqBModwCAIAywWI2qWE1LzWs5qVHbqmlggKrDp6+8GfZPac/Dp9Vcnq27e5pl4T4utqKbvMa3qoX6MGhDBUY5RYAAJRJZrNJdQM8VDfAQw+2DZHValX8uUz9cficNh87p23xKTqQfEHHzmbq2NlMzdt2UpLk4mhW42qV1SykspoFe6t5SGX5e7gYvDcoLZRbAABQLphMJoX4uinE1033tAqWJKVezNWO4ynaGn9e2+JTtC3+vNKy8rTx6DltPHrO9t5qlSupeYi3mgVXVrMaldUgyEtODszu2iPKLQAAKLe8KjmqY90q6li3iiSpoMCqw2cyipTduKR0nUy5qJMpF7Xgz0uQOTmY1TDIU81reKtxNQ+dy5asVquRu4JiQrkFAAB2w2w2KczfXWH+7rqnZ
eHsbnpWrnaeSNW2+PPa+mfhPZ+Zq63xKdoan/LnOx30YdxKNa5eWY2realR9cpqUt1L/p4czlDeUG4BAIBd83BxVPswP7UPK7wEmdVq1dGzmX+W3fPaeuy89iWm6VxGrlbGndbKPy9DJkkBns5qVK2w6Daq7qXG1SvLx83JqF3BDaDcAgCACsVkMqmWn5tq+bmpf/Pqys3N1U8LF6lm0/bak3hBO0+kKvZEqg4kpyspLVtJaUn6bW+S7f3VvSupcXUvW+ltUM1LXpUcDdwj/DfKLQAAqPAczVKT6l5qWcvPtiwjO097TqVp54lU7TyRotgTqTp8JkMnzl/UifMXtSg20bZuLT83Narm9Wfp9VL9IE95ulB4jUC5BQAAuAI3Zwfb7YMvScvK1a4Tqdp5snB2d8eJFJ04f1FHzmToyJkM/fznCWuSFOxTSRFVPRVR1UsRQZ6KCPJUkJeLTCaTEbtTYVBuAQAAbpCni6PahfmpXdj/z/Cey8hR7MlU7Tyeoh0nUrX3VJpOplzU8XOFjyW7//+QBq9KjoWFN8jT9s/QKu5clqwYUW4BAAD+Bh83J3WqW0Wd/rwcmSSlZOZoz6k07UlIs/3zYPIFpV7M1frDZ7X+8Fnbuk4Ws8L83YsU3vpVPTmO9yZRbgEAAIpZZVcntQv1U7vQ/5/hzc7L14GkC7ayu/dUYfFNzyo8tnfPqbQin1Hdu/CwhnqBhXdpqxfooVp+btxa+DootwAAAKXA2cGihtW81LCal22Z1WrVifMXL5vlPZly0Xbi2tI9/39Yg6PFpNp+7qob6KG6/oX/rBfgoWAfV1nMHMsrUW4BAAAMYzKZFOzjqmAfV/VoEGhbnpqZqz2nCmd39yel//m4oAvZeYpLSldcUnqRz3FxNKuOv4fqBLirXoCHrfRWrYAnsFFuAQAAyhgvV0dFhvoqMtTXtsxqtSohNUv7EwvL7aXSeyDpgrJyCxR7MlWxJ1OLfI6Hs0Nh4f3z0IY6/h4K9XdToKf9ll7KLQAAQDlgMplUrXIlVatcSbeG+9uW5xdYFX8uU3GJ6f81y5uuw6czlJ6d9z+3GS7k5mRR7Sruql3FTaFV3Asf/m6q6esmF0dLKe9Z8aLcAgAAlGMW8//fca1nw/8/tCEnr0BHzmTYyu6+xHQdOn1B8WczlZGTf8WZXpOp8EQ2W+Gt4q7QKm6qXcVdfu5O5WK2l3ILAABgh5wczKoXWHiVhf+Wm1+g+HOZOpR8QYdOZ+jQ6QuFj+QLSsvKs12fd2Xc6SLv83RxUKh/YeG9NOPbNLiyAjxdSnO3rotyCwAAUIE4Wsy2Wdn/ZrVadTYj57LSe/h0ho6fz1RaVp62xado238d4jD29gg9fEutUt6Da6PcAgAAQCaTSX7uzvJzd1ab2r5FXsvKzdfRsxk6lJzxZ+EtLMDh/zMrXBZQbgEAAHBNLo4WhQd6KjzQ0+go18UtLgAAAGA3KLcAAACwG5RbAAAA2A3KLQAAAOwG5RYAAAB2g3ILAAAAu0G5BQAAgN2g3AIAAMBuUG4BAABgNyi3AAAAsBuUWwAAANgNyi0AAADsBuUWAAAAdoNyCwAAALtBuQUAAIDdoNwCAADAblBuAQAAYDcotwAAALAbDkYHKAusVqskKS0trVS2l5ubq8zMTKWlpcnR0bFUtomiGAPjMQbGYwyMxxiUDYyD8W5kDC71tEu97Woot5LS09MlScHBwQYnAQAAwLWkp6fLy8vrqq+brNervxVAQUGBEhIS5OHhIZPJVOLbS0tLU3BwsI4fPy5PT88S3x4uxxgYjzEwHmNgPMagbGAcjHcjY2C1WpWenq6goCCZzVc/spaZW0lms1nVq1cv9e16enryJTIYY2A8xsB4jIHxGIOygXEw3vXG4FoztpdwQhkAAADsBuUWAAAAdoNyawBnZ2e9+uqrcnZ2NjpKhcUYGI8xMB5jYDzGoGxgHIxXnGPACWUAAACwG8zcAgAAwG5QbgEAAGA3KLcAA
ACwG5RbAAAA2A3KbSmbMmWKatasKRcXF7Vp00YbN240OlKFMm7cOJlMpiKP8PBwo2PZtdWrV6tPnz4KCgqSyWTS/Pnzi7xutVo1duxYVa1aVZUqVVK3bt104MABY8LaqeuNwdChQy/7XvTs2dOYsHZq0qRJatWqlTw8POTv769+/fopLi6uyDpZWVmKjo6Wr6+v3N3dNWDAACUlJRmU2P7cyBh07tz5su/Ck08+aVBi+zN16lQ1btzYdqOGyMhILV682PZ6cX0HKLel6Pvvv9fIkSP16quvauvWrWrSpIl69Oih5ORko6NVKA0aNNCpU6dsj99//93oSHYtIyNDTZo00ZQpU674+uTJk/XBBx/ok08+0R9//CE3Nzf16NFDWVlZpZzUfl1vDCSpZ8+eRb4X3377bSkmtH+rVq1SdHS0NmzYoJiYGOXm5ioqKkoZGRm2dZ577jktWLBAc+bM0apVq5SQkKD+/fsbmNq+3MgYSNJjjz1W5LswefJkgxLbn+rVq+uNN97Qli1btHnzZnXp0kV9+/bV7t27JRXjd8CKUtO6dWtrdHS07Xl+fr41KCjIOmnSJANTVSyvvvqqtUmTJkbHqLAkWefNm2d7XlBQYA0MDLS+9dZbtmUpKSlWZ2dn67fffmtAQvv3v2NgtVqtQ4YMsfbt29eQPBVVcnKyVZJ11apVVqu18Pfe0dHROmfOHNs6e/futUqyrl+/3qiYdu1/x8BqtVo7depkffbZZ40LVQF5e3tbP//882L9DjBzW0pycnK0ZcsWdevWzbbMbDarW7duWr9+vYHJKp4DBw4oKChItWvX1gMPPKD4+HijI1VYR44cUWJiYpHvhZeXl9q0acP3opStXLlS/v7+qlevnoYNG6azZ88aHcmupaamSpJ8fHwkSVu2bFFubm6R70J4eLhq1KjBd6GE/O8YXPLNN9/Iz89PDRs21OjRo5WZmWlEPLuXn5+v7777ThkZGYqMjCzW74BDcYfFlZ05c0b5+fkKCAgosjwgIED79u0zKFXF06ZNG82YMUP16tXTqVOnNH78eHXo0EG7du2Sh4eH0fEqnMTEREm64vfi0msoeT179lT//v1Vq1YtHTp0SC+//LJ69eql9evXy2KxGB3P7hQUFGjEiBFq3769GjZsKKnwu+Dk5KTKlSsXWZfvQsm40hhI0v3336+QkBAFBQVp586deumllxQXF6e5c+camNa+xMbGKjIyUllZWXJ3d9e8efMUERGh7du3F9t3gHKLCqVXr162f2/cuLHatGmjkJAQ/fDDD3rkkUcMTAYY595777X9e6NGjdS4cWOFhoZq5cqV6tq1q4HJ7FN0dLR27drF8f4GutoYPP7447Z/b9SokapWraquXbvq0KFDCg0NLe2YdqlevXravn27UlNT9eOPP2rIkCFatWpVsW6DwxJKiZ+fnywWy2Vn/SUlJSkwMNCgVKhcubLq1q2rgwcPGh2lQrr0u8/3omypXbu2/Pz8+F6UgKeeekoLFy7UihUrVL16ddvywMBA5eTkKCUlpcj6fBeK39XG4EratGkjSXwXipGTk5PCwsLUokULTZo0SU2aNNH7779frN8Bym0pcXJyUosWLbRs2TLbsoKCAi1btkyRkZEGJqvYLly4oEOHDqlq1apGR6mQatWqpcDAwCLfi7S0NP3xxx98Lwx04sQJnT17lu9FMbJarXrqqac0b948LV++XLVq1SryeosWLeTo6FjkuxAXF6f4+Hi+C8XkemNwJdu3b5ckvgslqKCgQNnZ2cX6HeCwhFI0cuRIDRkyRC1btlTr1q313nvvKSMjQw899JDR0SqMF154QX369FFISIgSEhL06quvymKx6L777jM6mt26cOFCkVmPI0eOaPv27fLx8VGNGjU0YsQIvfbaa6pTp45q1aqlMWPGKCgoSP369TMutJ251hj4+Pho/PjxGjBggAIDA3Xo0CG9+OKLCgsLU48ePQxMbV+io6M1e/Zs/fTTT
/Lw8LAdQ+jl5aVKlSrJy8tLjzzyiEaOHCkfHx95enrq6aefVmRkpNq2bWtwevtwvTE4dOiQZs+erdtuu02+vr7auXOnnnvuOXXs2FGNGzc2OL19GD16tHr16qUaNWooPT1ds2fP1sqVK7VkyZLi/Q4U7wUdcD0ffvihtUaNGlYnJydr69atrRs2bDA6UoUycOBAa9WqVa1OTk7WatWqWQcOHGg9ePCg0bHs2ooVK6ySLnsMGTLEarUWXg5szJgx1oCAAKuzs7O1a9eu1ri4OGND25lrjUFmZqY1KirKWqVKFaujo6M1JCTE+thjj1kTExONjm1XrvTzl2SdPn26bZ2LFy9ahw8fbvX29ra6urpa77zzTuupU6eMC21nrjcG8fHx1o4dO1p9fHyszs7O1rCwMOuoUaOsqampxga3Iw8//LA1JCTE6uTkZK1SpYq1a9eu1qVLl9peL67vgMlqtVr/bhMHAAAAygKOuQUAAIDdoNwCAADAblBuAQAAYDcotwAAALAblFsAAADYDcotAAAA7AblFgAAAHaDcgsAAAC7QbkFAACA3aDcAkAZd/r0aQ0bNkw1atSQs7OzAgMD1aNHD61du1aSZDKZNH/+fGNDAkAZ4WB0AADAtQ0YMEA5OTmaOXOmateuraSkJC1btkxnz541OhoAlDkmq9VqNToEAODKUlJS5O3trZUrV6pTp06XvV6zZk0dO3bM9jwkJERHjx6VJP30008aP3689uzZo6CgIA0ZMkT//Oc/5eBQOK9hMpn08ccf6+eff9bKlStVtWpVTZ48WXfddVep7BsAlAQOSwCAMszd3V3u7u6aP3++srOzL3t906ZNkqTp06fr1KlTtudr1qzR4MGD9eyzz2rPnj2aNm2aZsyYoddff73I+8eMGaMBAwZox44deuCBB3Tvvfdq7969Jb9jAFBCmLkFgDLuP//5jx577DFdvHhRzZs3V6dOnXTvvfeqcePGkgpnYOfNm6d+/frZ3tOtWzd17dpVo0ePti2bNWuWXnzxRSUkJNje9+STT2rq1Km2ddq2bavmzZvr448/Lp2dA4BixswtAJRxAwYMUEJCgn7++Wf17NlTK1euVPPmzTVjxoyrvmfHjh2aMGGCbebX3d1djz32mE6dOqXMzEzbepGRkUXeFxkZycwtgHKNE8oAoBxwcXFR9+7d1b17d40ZM0aPPvqoXn31VQ0dOvSK61+4cEHjx49X//79r/hZAGCvmLkFgHIoIiJCGRkZkiRHR0fl5+cXeb158+aKi4tTWFjYZQ+z+f//6N+wYUOR923YsEH169cv+R0AgBLCzC0AlGFnz57V3XffrYcffliNGzeWh4eHNm/erMmTJ6tv376SCq+YsGzZMrVv317Ozs7y9vbW2LFjdfvtt6tGjRq66667ZDabtWPHDu3atUuvvfaa7fPnzJmjli1b6pZbbtE333yjjRs36osvvjBqdwHgb+OEMgAow7KzszVu3DgtXbpUhw4dUm5uroKDg3X33Xfr5ZdfVqVKlbRgwQKNHDlSR48eVbVq1WyXAluyZIkmTJigbdu2ydHRUeHh4Xr00Uf12GOPSSo8oWzKlCmaP3++Vq9erapVq+rNN9/UPffcY+AeA8DfQ7kFgArqSldZAIDyjmNuAQAAYDcotwAAALAbnFAGABUUR6UBsEfM3AIAAMBuUG4BAABgNyi3AAAAsBuUWwAAANgNyi0AAADsBuUWAAAAdoNyCwAAALtBuQUAAIDd+D8z4oVHOI9fgAAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize the loss curve\n", "plt.figure(figsize=(8, 6))\n", "plt.plot(losses)\n", "plt.xlabel('Step')\n", "plt.ylabel('Loss')\n", "plt.title('Training Loss Curve')\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "rLuPnH66h3Q_" }, "source": [ "## 4. Data Parallel Training with Flax NNX\n", "\n", "JAX is powerful, but its API is often too low-level for conveniently building neural networks. For people familiar with PyTorch, the JAX ecosystem offers [Flax](https://flax.readthedocs.io/en/latest/), a neural network library whose API closely resembles PyTorch's. Let's implement the same 8-way data parallel training using the Flax NNX API, which provides these higher-level abstractions." ] }, { "cell_type": "markdown", "metadata": { "id": "nJ-NDbFy2VKK" }, "source": [ "### 4.1 Import Flax" ] }, { "cell_type": "markdown", "metadata": { "id": "NnarairyIsbb" }, "source": [ "We need to import (or install) the `flax` library, whose `nnx` module provides a high-level API for building neural networks in JAX (analogous to the `torch.nn` API in PyTorch). We also import `optax`, which provides the optimizers used below.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "4E8YR_rlh3Q_" }, "outputs": [], "source": [ "# Import Flax NNX and Optax, installing them first if necessary\n", "try:\n", "    import flax.nnx as nnx\n", "    import optax\n", "except ImportError:\n", "    !pip install -q flax optax\n", "    import flax.nnx as nnx\n", "    import optax" ] }, { "cell_type": "markdown", "metadata": { "id": "Snm3tRbC2ZFA" }, "source": [ "### 4.2 Define Model using Flax NNX API" ] }, { "cell_type": "markdown", "metadata": { "id": "qgCRtNjPIsbb" }, "source": [ "We define the same two-layer MLP as before, but now using Flax's `nnx.Module` API."
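, "\n", "For reference, here is what this module computes, written out using Flax's convention that `nnx.Linear` maps $x \\mapsto x W + b$ with a kernel $W$ of shape `(in_features, out_features)`:\n", "\n", "$$f(x) = \\mathrm{ReLU}(x W_1 + b_1)\\, W_2 + b_2, \\qquad W_1 \\in \\mathbb{R}^{d_\\text{in} \\times d_\\text{mid}},\\; W_2 \\in \\mathbb{R}^{d_\\text{mid} \\times d_\\text{out}}$$\n", "\n", "With `din = dout = 128` and `dmid = 2048` (the sizes used in Section 4.3), that is $2 \\cdot 128 \\cdot 2048 + 2048 + 128 = 526{,}464$ parameters, and under data parallelism every device holds a full copy of all of them."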
] }, { "cell_type": "code", "execution_count": 42, "metadata": { "id": "K9Wlxs67h3Q_" }, "outputs": [], "source": [ "class MLP(nnx.Module):\n", "    def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):\n", "        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)  # din -> dmid\n", "        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)  # dmid -> dout\n", "\n", "    def __call__(self, x):\n", "        x = nnx.relu(self.linear1(x))\n", "        return self.linear2(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "sthOrJSt2rpe" }, "source": [ "### 4.3 Replicate Model and Optimizer State across devices" ] }, { "cell_type": "markdown", "metadata": { "id": "eooWeVsUIsbb" }, "source": [ "We replicate the model parameters and optimizer state on every device using the `replicated_sharding` we defined earlier: `nnx.state` extracts the state, `jax.device_put` places it with the replicated sharding, and `nnx.update` writes it back into the model and optimizer." ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "id": "Ittq3UFPZ_zY" }, "outputs": [], "source": [ "model = MLP(128, 2048, 128, rngs=nnx.Rngs(0))\n", "optimizer = nnx.Optimizer(model, optax.adamw(1e-2))\n", "\n", "# Replicate the model and optimizer state on every device\n", "state = nnx.state((model, optimizer))\n", "state = jax.device_put(state, replicated_sharding)\n", "nnx.update((model, optimizer), state)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 216 }, "id": "4LM_wq6kaLtU", "outputId": "7400b6da-5217-4ca4-ebb5-4072fd2e6ed3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model sharding\n" ] }, { "data": { "text/html": [ "
                                                                                \n",
              "                                                                                \n",
              "                                                                                \n",
              "                                                                                \n",
              "                                                                                \n",
              "                              TPU 0,1,2,3,4,5,6,7                               \n",
              "                                                                                \n",
              "                                                                                \n",
              "                                                                                \n",
              "                                                                                \n",
              "                                                                                \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,1,2,3,4,5,6,7\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# visualize model sharding\n", "print('model sharding')\n", "jax.debug.visualize_array_sharding(model.linear1.kernel.value)" ] }, { "cell_type": "markdown", "metadata": { "id": "sIclsMKG2yVJ" }, "source": [ "### 4.4 Define Train Step" ] }, { "cell_type": "markdown", "metadata": { "id": "ivxVqggDIsbb" }, "source": [ "We use a simple mean squared error (MSE) loss. Note that the train step contains no explicit communication: `nnx.jit` compiles it for the shardings of its inputs, and XLA inserts any collectives that are needed (for example, an all-reduce of the gradients when the batch is split across devices while the parameters are replicated)." ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "id": "v1mGmZygaS6V" }, "outputs": [], "source": [ "@nnx.jit\n", "def train_step(model: MLP, optimizer: nnx.Optimizer, x, y):\n", "    def loss_fn(model: MLP):\n", "        y_pred = model(x)\n", "        return jnp.mean((y - y_pred) ** 2)\n", "\n", "    loss, grads = nnx.value_and_grad(loss_fn)(model)\n", "    optimizer.update(grads)\n", "    return loss" ] }, { "cell_type": "markdown", "metadata": { "id": "071bAWOL2_l1" }, "source": [ "### 4.5 Distributed Training" ] }, { "cell_type": "markdown", "metadata": { "id": "g2y6UawuIsbb" }, "source": [ "To simulate sequence modelling, we use a synthetic dataset of 128-step sequences, each a randomly weighted sum of periodic functions (sines and cosines) plus noise. The target for each sequence is a shifted, non-linearly transformed copy of its input."
] }, { "cell_type": "code", "execution_count": 46, "metadata": { "id": "NVnPMia93CXI" }, "outputs": [], "source": [ "def dataset(steps, batch_size):\n", "    \"\"\"Yield `steps` batches of 128-dimensional input/target sequences.\"\"\"\n", "    for _ in range(steps):\n", "        # Each input is a noisy, randomly weighted mix of three periodic base signals;\n", "        # the target is a shifted, non-linearly transformed version of the input.\n", "\n", "        # Generate base signals (random phases, shared by the whole batch)\n", "        t = np.linspace(0, 4*np.pi, 128)\n", "        base_patterns = np.array([\n", "            np.sin(t + np.random.uniform(0, 2*np.pi)),\n", "            np.cos(2*t + np.random.uniform(0, 2*np.pi)),\n", "            np.sin(3*t + np.random.uniform(0, 2*np.pi))\n", "        ])\n", "\n", "        # Create batch of sequences\n", "        x = np.zeros((batch_size, 128))\n", "        y = np.zeros((batch_size, 128))\n", "\n", "        for i in range(batch_size):\n", "            # Mix base patterns with random weights\n", "            weights = np.random.randn(3)\n", "            signal = np.sum(weights[:, np.newaxis] * base_patterns, axis=0)\n", "\n", "            # Add noise\n", "            x[i] = signal + np.random.normal(0, 0.1, 128)\n", "\n", "            # Target: input shifted by 5 positions and scaled, plus a quadratic term and noise\n", "            y[i] = np.roll(x[i], 5) * 0.8 + 0.1 * x[i]**2\n", "            y[i] += np.random.normal(0, 0.05, 128)\n", "\n", "        yield x.astype(np.float32), y.astype(np.float32)" ] }, { "cell_type": "markdown", "metadata": { "id": "smYBaqIvIsbc" }, "source": [ "We can visualize a sample from the dataset to see what it looks like."
] }, { "cell_type": "code", "execution_count": 47, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 290 }, "id": "4FHz_U88Isbc", "outputId": "cb8a0897-79e5-4008-9e94-d6df15385cde" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA/QAAAERCAYAAADVM+x6AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA2tdJREFUeJzsnXd4HNXVh9/Zpt57syXbkiX3insBbNwA0wyhhwRI8oVAcEInlBRKiCEkQCCUkNC7KTbGBdsY927LlmzJalbvvW2Z74+7K8m2urZI8n2fR8+MVrMzV6vRzJx7zvn9FFVVVSQSiUQikUgkEolEIpEMKDSuHoBEIpFIJBKJRCKRSCSSniMDeolEIpFIJBKJRCKRSAYgMqCXSCQSiUQikUgkEolkACIDeolEIpFIJBKJRCKRSAYgMqCXSCQSiUQikUgkEolkACIDeolEIpFIJBKJRCKRSAYgMqCXSCQSiUQikUgkEolkACIDeolEIpFIJBKJRCKRSAYgOlcPoL9jsVjIz8/Hx8cHRVFcPRyJRCKRSCQSiUQikQxyVFWlpqaGyMhINJqO8/AyoO+C/Px8YmJiXD0MiUQikUgkEolEIpGcZ5w+fZro6OgOfy4D+i7w8fEBxAfp6+vr4tF0jMVioaSkhJCQkE5ncCTnD/KckJyNPCck7SHPC8nZyHNCcjbynJCcjTwnHE91dTUxMTEt8WhHyIC+C2xl9r6+vv0+oG9sbMTX11f+U0kAeU5IzkWeE5L2kOeF5GzkOSE5G3lOSM5GnhPOo6u2b/npSyQSiUQikUgkEolEMgCRAb1EIpFIJBKJRCKRSCQDEBnQSyQSiUQikUgkEolEMgCRAb1EIpFIJBKJRCKRSCQDEBnQSyQSiUQikUgkEolEMgCRAb1EIpFIJBK7UFDVwJ+/OU5KQbWrhyKRSCQSyXmBDOglEolEIpHYhf/uyGZ3ZjkvbkzDYlFdPRyJRCKRSAY9Ayqg/+GHH7jsssuIjIxEURRWr17d5Xu2bNnCpEmTcHNzY8SIEbz99tsOH6dEIpFIJOcbxTWN7DxVCkBeZQPb0ktdPCKJRCKRSAY/Ayqgr6urY/z48bz88svd2j4zM5Nly5Zx4YUXcujQIX77299y++2389133zl4pBKJRCKRnF+sOVKARQW9VgHgo705MksvkUgkEomD0bl6AD1hyZIlLFmypNvbv/rqq8TFxbFq1SoAkpKS+PHHH3nhhRdYtGiRo4YpkUgkEsl5RaPRzPpjRQD85qJ4XvvhFKfLG/gxvZS5CSEuHp1EIpFIJIOXAZWh7yk7d+5kwYIFZ7y2aNEidu7c6aIRSSQSiUQy+NicWkxtk4lwP3fmJYSwfEIUAB/tPS2z9BKJRCKROJABlaHvKYWFhYSFhZ3xWlhYGNXV1TQ0NODh4XHOe5qammhqamr5vrpaKPVaLBYsFotjB9wHLBYLqqr26zFKnIs8JyRnI88JSXv09bywWFS+PJSHisqlY8MBlWVjw1l9MI/s8jq2pZUwJz7YvoOWOBR5rZCcjTwnJGcjzwnH093PdlAH9L3h6aef5sknnzzn9ZKSEhobG10wou5xsriOzKJKFo5W0WgGdeGFpJtYLBaqqqpQVXlOSATynJC0R1/Pi6P5tWSV1OCu1zAuWKG4uBiAC4d5s/poKf/7MZ14XzMaRbH30CUOQl4rJGcjzwnJ2QzkcyKrvJHYQHdXD6NLampqurXdoA7ow8PDKSoqOuO1oqIifH19283OA
zz00EOsXLmy5fvq6mpiYmIICQnB19fXoePtLVUNRt5Yk01xVT31Ol9+PjsOnXZg/WNJ7I/FYkFRFEJCQgbchVbiGOQ5IWmPvp4XP+4uRafXsXR8JEOjIlpev35WAN+fqqGo3sypGi2zRsgs/UBBXiskZyPPCcnZDNRzYn92BX/elMbMYUE8sHgkSj+ebHZ3796kw6AO6GfMmMHatWvPeG3Dhg3MmDGjw/e4ubnh5uZ2zusajabfnqx+HgYWJIXy/q4s1iQXkl3ewAOLEwnwMrh6aBIXoyhKvz53Jc5HnhOS9ujteXG6vJ4DOZVoFIXLx0ed8X4/TzcunxDFR3tP89G+XGaNCEGj6b8PTpIzkdcKydnIc0JyNgPtnDBbVN7ekY2CQqivO1qt1tVD6pTufq4D49O3Ultby6FDhzh06BAgbOkOHTpETk4OILLrt9xyS8v2v/zlL8nIyOD+++8nNTWVV155hY8//ph7773XFcN3GBqNwk3Th3L33Gg89FqO5Vfz248OkVpY7eqhSSQSiWQQ882RAgAuiA0k3O/cTMLyCZF46LVkl9WzK6PM2cOTSCQSiaSFDccLySmvx9tNx3VTY1w9HLsxoAL6ffv2MXHiRCZOnAjAypUrmThxIo899hgABQUFLcE9QFxcHGvWrGHDhg2MHz+eVatW8cYbbwxay7pJ0T48v2I8MYEelNc18+BnR1mXXODqYUkkEolkAPLloTwe+vwoR3Ir2/15TaORTSmire3yCZHtbuPjruey8aIM/0OpeC+RSCQSF9HQbOa93SJO/MkFMfi46108IvsxoEru58+fj6p2/DDw9ttvt/uegwcPOnBU/YuoAA9WrZjA3zeeZMepMl7efIq0olp+MW84Bt2Amr+RSCQSiYvYeaqMN7ZlAvDIF1UsHhPObbNi8TS0PjZsOF5Ek8lCbLAXY6P8OtzX8olRfH24gMzSOnZnljNjeJDDxy+RSCQSSVs+3X+aynojEX7uLB0b0fUbBhAywhuEeBi0PLgkkVtmDEVRYP3xIp5dl+rqYUkkEolkAJBX2cALG08CMCzEC4B1yYX8+r0D7M8uB0Qfoq3c/vLxkZ2KCvm667m0JUuf0+nEvEQikUgk9qakpokvDuYB8NNZsegHmXj44PptJC0oisKKKTE8ftloNArsySzndHm9q4clkUgkkn5Mo9HMU2tTaGg2MyrCl1UrxvOXK8cQ5utOaW0zT3x1nL9vPMnGlCJKaprw9dAxLyGky/0unxCFu15DRkkdb2zLlKX3EolEIukW+ZUN3P3BQd7ZmYW5l/eOd3ZmYTSrjI70ZcawwVclJgP6Qc7koQFMjQ0E4LtjhS4ejUQikUj6K6qq8vLmdHLK6vH31PPAkkR0Wg3jov156YaJLJ8QiaLAppRiXvo+HYDFYyK61c7l56Hn57OHAfDV4Xz+8X1arx/MJBKJRHL+sCm1mMzSOj7el8sjXxylrLapR+9PK6ph84kSAH4+O65f29T1FhnQnwcsGhMOwPepxTSbLC4ejUQikUj6I2uPFrLlRAkaBR5YnEhgG+tTd72W2+cM45mrxhHpL9TsNRqFpdb7S3dYPCacexfGo7FOCjy7LlXekyQSiUTSKUdOV7asH8uv5p4PD3GozWudoaoqb20XejAXjgwhPszHASN0PTKgPw+YPCSAIG8DNY0mdkrbIIlEIpGcRXppA29aH3pumxXHmA5E7kZF+vKP6yfyi3nDeHRZEkHebj06zkWJYTy4JAmdVmHnqTL++M0xGprNfR6/RCKRSAYfDc1mThbXAvDk8tHEBntR1WDksS+T+WBPTpftW7syyknOq0avVbh5RqwTRuwaZEB/HqDRKFwySmRRZNm9RCKRSNpSWd/Myz/mYrKozBwRxPIOLOhsuOm0XDousqWdq6fMGB7EE5eNxl2v4fDpKv7wZTI1jcZe7atdakvg3Wtg7X1gkRUAEolEMlA5XlCFxaIS5uvGpCEB/G3FOC4ZFYaqwvu7c
3ji62NU1bd//zCaLby9Q0xUXzkxihCfnk1ADyRkQH+esHBUGBoFjuZWkVfZ4OrhSCQSiaQfYLaoPLf+JBX1JqL9PfjtxQlO6S8cH+PPX64ci7ebjhOFNTz4+VHK65rts/Pd/4K6Eji9B1K/ts8+JRKJROJ0Dp+uAmBctD8gJpR/c3E8v10Qj0Gn4WBOJf/3/n7+sDqZFzem8e6ubNYlF7Avq5yP950mv7IRf08910yOceFv4XgGlA+9pBcYG2HrM4SEjmLS0NHsy6pg/bFCbpsV5+qRSSQSicTFfH4gl6N5VbjrNTy0JBEPg9Zpx04I8+GZq8fyhy+PkVNWz+NfHePF6yag0XRvQsFktmCyqLjr24y54Aikb2r9fve/IXYOePaumkAikUgkruNIbiUA46LPbAO7OCmMEaHePPNtKrkVDZ321N84bYhT722uQGboBzuZW+HUZtj9KstGCCGjTSnFGM29LENUVcj8AXL3y1JGiUQiGcAUVzfy4d7TANw0OYyYQE+nj2FokBfPXTMOLzctWaV17GpP56Wh8pyXVFXlz2tSuPnN3RRVN4oXLRbY8Q+xPnIphCRCcy3sfMlxv4BEIpFIHEJNo5GM0jqgNUPflqFBXrz4k4n8+Yox/HZBPDdPH8riMeFMjQ0kLtgLPw89k4cGsHBU98VbByoyQz/YObVZLC1mJjXuJtBrBOV1zezOKGd2fHDP93fsC9j+olj3DILhF0H8JRAcD4PQBkIikUgGK69vy6DZZGFMpB+z4toXwXMGYb7uLBsXycd7T/Pp/lxmDA8SZf8NlbBtlZhEHnU5zF7Zcp85eLqS/dkVAOzLqmDZuAhI/QZK08DgDdPuFL30X/xCZOwTlkDMVJf9jhKJRCLpGUfzqlBViAn0OMN1pS0GnYbxMf7OHVg/RGboBzNNNZC7t+VbzYk1LEgKBWDdsYKe7688A3b9S6zrPaC+DI5+Ap/fAZ/cCgfegepe7FcikUgkTmVfVjm7MsrRaBR+OW+Yy315Lx8XiUGnIa24liO5VZCzGz69TQTzAMe/ggP/A0R2/v3dOS3vTSmohsZq2PuGeGHyT8EjAEISYMxV4rUfnwdTz7yLJRKJROI6juSe2T8v6RgZ0A9msneAxQR+0aD3hKpclgYXoyhCZKKgqgfieKYm2PgkmJthyHS45StY9BQMvxC0BqjIFg9TH/xEZFSa6x33e0kkEomk1zSZzLy6NQOA5eMjGeKCUvuz8fPUs3BUGAa1iaJvn4Fv74f6cggYChNuFBvtewtOfMuBnApOFNa0vDeloBoO/Bcaq8T2o69s3fGUn4NXCFTnt0wISCQSiaT/c9QW0HdgoyppRQb0gxlbuf2IBaI0Hgg6vYFJQwIA2HC8qPv72vUvqMgSWY95D4DOALGzYMETcMtqmP8gRE0S2x7/SmRWcvfb7VeRSCQSiX34bH8eRdWNBHkbuP6CIa4eTgvXxNTw+9rniCvaQKPJAmOvgateF+Xz1qBe/eE5fty8FoBFo4V7i1KZhfHIZ2InM+8GbZtuQoMnzPyNWD/8obiPSSQSiaRfU1HXTE55PYoCY6JlQN8VMqAfrLQttx82H5IuFesZW1gSL7IxG44XYeqOOF7WdtE7D3DhI+eqBRu8YOQSuPQFuPR58ImAmkJYs1Jm6yUSiaQfkV/ZwKf7hRDe7bOH9Q/lX1WFA/8jeOO9xOkrqNL483HYb0UgrrP6Bk+9HeIXUt/YzMW5rxBLPjdNH8rQQE+uavicxmYjxM6G6Cnn7j9uLgydJSrWtq2yr6BrxlZ4cxFk77TfPiUSieQ850ieyM7HBXvh66538Wj6PzKgH6xkbRcPLwGxEBgn1H4Dh4G5mSnGA/h76qmsN7Ins7zz/dSWwNZnxPq467oWFYqaDNe8BaOvEN/LbL1EIpH0C1RV5d8/ZGA0q0yI8WfWiCBXD0lwfDXsfRNUC+6JC/mr9wN8VhhGXmWbtjCNBnXu/Rwyx+GmNvGA+hb+5nIWe
KYRbzpJvUmB6f/X/v4VBWbdDTp3YWt3cp0dx/4lmBohfYP99imRSCTnOUesNnSyf757yIB+sJKxRSyHXyiWitKSpdeeXMNCqzje+s7K7i0W2PKUEBsKjhcZku5g8ITZ94qM/RnZ+ufBbOzlLySRSCSSvrDzVBn7syvQaRV+OX+4y4XwAChOhZ0vi/UL7sT/sj8zelgUFhVWH8w7Y9O9p2v5u+YWinWRhOvr4dv7mVnyIQA/elwMflEdH8cnHKbcJtZ3vQINFX0fu6kZCo+K9dKTfd+fRCKRSAA4bO2fHy/L7buFDOgHI2eX29sYsVAI2JWdYkl4NQAHcipaPXzP5siHkHdAZDUufkz0zfeEqElnZeu/hO8eBmMPxPgkEolE0mcajWZe3yaE8K6aGEWUv4eLR4SYLN74hJjojZsDE24A4JrJ0QBsTCmivK4ZENUFH+zJoVHxIGf6E+h8QqEiGz9TGdUaPz40zaHRaO78eGOugaDh4h65/g99bwcrOiqEYgGqcmV7mUQikdiBoupGiqob0SgwOlIG9N1BBvSDEVu5fWCcKLm34e4Lw+YBEJK7gfExfqgqfLz39Ln7KE5ttQCadTf4twonpRfXUlnf3L2x2LL1S54VEwOn98Ca34kHOYlEIpE4hQ/35FBa20yojxsrpsS4ejiib37LM1BTAL6RQmzVWjEwOtKPxHAfTGaVrw6JLP2ezHLSi2tx12tYOm2cuKcYvNBrFLYEXEUj7pwsqunsiEIsb/7Dwqe+8CisvQ+a63r/O+QdOPP3KT/V+31JJBKJBGi1q0sI8+kfOi8DABnQD0Zs5fZts/M2EpeJZfomfjIhBEURZfc7TpW2blN5GtY/Chaz2MfIpS0/2pRSxL0fHeLPa1J6NqYh04VgnpsPFB2Dr34j+vMlEolE4lBKa5tYfSgfgDvnDsNd3w8ekI58DNnbQauHBU+Ke0MbbFn6tcmF1DaZ+GCP8J1fNjYCP0+9yLQvfxkW/hFz3MWA1b6uK4JHwLJV1ntRsphgbupiIqAj8qzaMFqrYJMsu5dIJJI+czS3EoBxsty+28iAfrDRUbm9jYgJwpfeWM+YxgNcOVH0HP5zUzqltU1QniGC7boSkZWf87uWrElaUQ0vb04H4ERhDTWNPeyHDxsNl/8TvIKFddBXd4nJA4lEIpE4jE0pRZgtKkkRPkwb1g+E8AqPwu5XxfrM30BIwjmbTI0NZEigJw3NZp5am8Kpkjo89FqunBTdulFgHAybR5K1JDOloJuBeWii0Hhx94XilN5VjTXVQMkJsT5ioViWpvdsHxKJRCI5A1VVW/rnpSBe95EB/WAju4NyexuK0pqlT/2Gm6YPJT7Um9omE+989R3qV3cLsaCgEXDZi+KBB6isb+Yva1MwmtWWXR3P70XZfGAcXP6SmFSoKRRBfYnMakgkEokjsFhUNljFTxePCXfxaBD3l41PgmqBERdD0uXtbqbRKC1Z+qPWh7tLx0fg53GufVFShMjupxRUY7Go5/y8XYLj4dK/g7ufCMy/uRcaq7r/e+QfEr+DfwwMnSlekxl6iUQi6RN5lQ2U1zWj1yokRvh0/QYJIAP6QYfSWbm9jYTFoNFC0TH0VVn8btFIEslkYcYzVFSUQegokb2w+s0bzRaeXptKWW0z0QEezI4PBuB4d8ob28M3QmTqg+OhoRK+vkc8HEkkEonErhzKraSougkvNy2zRgS7djAWC3z/lzYVYL9vqQBrjznxwYT4CB96D72WKya2r2IfF+yNu15DfbOZnPIeCNMFDYfL/g4eAVCWLoL67qrf28rtoyZDsLXCoCJTKN9LJBKJpFfY+ucTI3xx0/WD9rABggzoBxFKc23rQ0ZnAb1nYGtGIXUNUbXHeUR5Cze1iT1NQzg55cmWzDzAG9syOV5QjYdByyPLkpgyNADoZYa+7RguexEixoOxHjY/JR72JBKJRNIhzSYLnx/I7did5Cy+Sy4EYP7IUNc/HB16T7SE6dxg4
ZNCNLUTdFoNN04TgqzXTo3B1/3c7DyAVqOQECYyOamFPbwvBQ4TQb1nIJSdEkF9d9Tq2wb03qGiJ99iFkG9RCKRSHrFYWv/vLSr6xkyoB9EGAr2dl5u35bEy8QydS2sewA/vYXqkEm85vlL/vp9Dg3Nwv5n/bFC1h4tQFHg95eMJDrAs8VCIq24liZTFzZBnQ7YC5b8VUwe1BZBzs7e70sikUjOA75PLeI/27P40zfHuywvr6xvZldmOQCLRru43L6mCA78T6zPvlcE0t3g4qQw3rtjGldP6sRjHkiKEJPQvZpoDoiFy/4hgvryTEj7rvPta0ugMgcUjdClUZTWLL0su5dIJJJeYbGoLS1Wsn++Z8iAfhBhOL1drAy7sOuNo6eKrIKxHsxGlLi5jLzlHwT6elNU3ci/tp4itbCaf20VNjw3XDCEC+JECX6YrxuBXgbMFpWThbV9G7TeHUZae/qPr+7bviQSiWSQk1UmssfZZfVsSi3udNtNKcVYLCrxYd7EBXs5Y3gds/cN4dkeOUG0ffUAX3c9Siel+dAmoO+uMN7Z+MfAhBvF+rHVwoauI/KtdnXBCa3VbC0BfVrvji+RSCTnOVllddQ0mnDXa4gP9Xb1cAYUMqAfLDRVoy8+ItY7K7e3odHA6KvEesIiWPAE3p4e/O6SkWgU2JxazGNfHsNkVpkxPIhr2/gWK4rCqEjbw1MPRIQ6IukykeE4vQeqcvu+P4lEIhmk5Fc2tKy/uyubRmP7VVKqqvLdMVFuv9jV2fmSE5C2XqxP/79O++Z7S2K4D4oCRdWNVNT1so89YTHo3IULS8HhjrdrW25vIzheLGVAL5FIJL3iaJ6IKUZH+qHTyhC1J8hPa7CQtR1FtZXbD+3ee8ZdB9d/CPMfEiJ5wKhIX66bKnoWG5rNDAn05N4FCWg0Zz6AjbYG9Mf60kdvwy8Koi8Q68e/6t57cvfDqe/7fmyJRCIZQORViIBer1Uor2vmy0N57W53NK+KgqpGPPRa5sSH2HcQZhN8+4BV+6SLtitVhV2viPX4SyBkpH3HYsXLTcfQIFGF0C0/+vZw84Z4qwXdsS/a30ZVOwjorRn6svSuPxOJRCKRnMPh07Zye9k/31NkQD9IUDK3AqDGze/+mzQaoTh/VrbkuqkxzBweRKS/Ow8vS8LDcK6Q0ihreWNqQQ3m7toEdcboK8XyxFowNXW+bUU2fHufsD6SPvYSieQ8oclkpqRWXB9/OisOgM/251FZf25G2padnzcypN1reJ8oSYWcXXDyu9ZgvSNydgoXE60Bpt5u33Gchc2+rtcOLACjrhDLrG1QV3ruzytzxOtaA4SPaX3dNwr0nqKtoDKn98eXSCSS85Cs0jqO5lUCsn++N8iAfjBgaoLCHpTbd4FWo/DQ0iRevWkyUf4e7W4TG+SFp0FLg9FMZmkf++gBYqaBTwQ01XSeeVdV2P5iawak6Fjfjy2RSCQDgMKqRlQVvNy0XDo2gvhQbxqMZt7fc2YAWd1oZMepMgAuGRVm/4GUpLauH/1UiKu2h8UMu/4l1sdeAz4OGEsbksJtrWB9COiDR4hA3WKGlK/P/bktOx8+Rqj129BohA0eyLJ7iUQi6QEHcyq4/9MjNBotDAvxYpirNV8GIDKgHwzo3FBv+ISa6Q8Ib1870ZkIkUajtIgQ2aXsXqOBUZeL9WOrO94uY0vrAxVAUXLfjy2RSCQDAFu5faS/BxqNws9miyz9d8mFnG7jv745tRiTWWVYiBcjHCEsZAvofSLE8sfnobCda3HqNyJb7e7bKjjnQGzaLqdK6jrUFugWtoqx1G9Ee0Fb2iu3tyGV7iUSiaRHbDhexBNfH6fBaGZMlC9/vmLMOW2+kq6RAf1gweBFc9Q0px6yRRjPHgE9wMgloNWLh8Xi1HN/3lwPO18W6yGJYlmcYp9jSyQSST8n1yqIF22tnBoT5ccFcYFYVPjvjizgTDG8RaPDu1SH7xW26+7seyFuDpiNsOEPw
s7NRnM97PuPWJ98m+hPdzChPm4EeBmwWFTSi/tQORY3DzwCRGl99o+tr1vMon0AOgjobcJ4/Segt1hUssvqUDtT7ZdIJBIno6oq7+7K5h+b0rBYVOYlhPDk5WPwcdc7/uCmZkjfJJaDhAEX0L/88svExsbi7u7OtGnT2LNnT4fbvv322yiKcsaXu7u7E0c7uBkd2VreaJeHBY+AVsu941+e+/OD70BdicgKXfyYeK08Qzw4SiQSySDHpnAf2aYV6qczY9EosDuznOS8KlIKajhd3oCbTsO8BDuL4YFoi7K5kYSMhPkPCzHW+nJY/2irBsrh96GhAvyiIely+4+jHRRFae2j78tEs1YPiVY71bYVY6UnobkWDN4Q3I643xnCeJbeH9+OvL0ji7veP8j640WuHopEIpEAYDRbeGHDST7aK3Swrp0Sze8uScCgc1JYmrEFNv0RvvqNc47nBAZUQP/RRx+xcuVKHn/8cQ4cOMD48eNZtGgRxcUde/H6+vpSUFDQ8pWdne3EEQ9u4kN90GsVKuuN5Fc12meno5aLZfpGaGzzQFaRDUc+EuszfyOU8b1DQbVA6Qn7HFsikUj6MbaS+6iA1oA+JtCTS6y2dG9tz2SdNTs/Jz4ELzed/QdRYs0++0SAhz8YPGHRU+DmI6qrfvibyNQf+VhsN+0XoHXAODrAJtiaUtjHyrFRy0HRQP5BKM8Ur+XuE8vICaJN7Gz8hwqxvOY6qMnv2/HtQH2ziXXJ4nxYc6TAxaORSCQSqG0y8fhXx9h8ogSNAnddNIKbZ8Q6ppqsI2xJw9hZzjumgxlQAf3zzz/PHXfcwW233caoUaN49dVX8fT05K233urwPYqiEB4e3vIVFuZYUZ7zCYNOQ0KYyIYcy7ODHz1A2GgIGiGUgk+uE6+1FcIbOrP1HzB0lFgWHbfPsSUSiaQfk19lDejPEiu94YIhuOs1pBXVsjlVTHBfMtpB9zpb/3xoYutrvpGw8EkRAKeth6/vFpn68LEQO8cx4+iAtg4slr44sHiHivsNtD785R8Qy/bK7UFMXAQOE+v9QBhvc2oJDVYtgczSOjJL61w8IolEcj5TXN3IA58e4WhuFR56LY9dNppF1glpp1F2SuhvabSQeKlzj+1ABkxA39zczP79+1mwYEHLaxqNhgULFrBz584O31dbW8vQoUOJiYlh+fLlHDsmVdHtyahIO6gKt0VRYPQVYv3YalG2aBPC0xpEdt5G2GixlEr3EolkkFPTaKS6QQi0RfidGdAHeBm4elJ0y/dDAj1JDPdxzEBsAX1I4pmvR02GmXeJ9Wprdnr6/51ji+po4oK9cNNpqG0ykWutaOg1o68Sy5PfifYBm/BfRwE9tOmjd21Ar6oqa4+KrLyHXtgWbkqRZfcSicQ1pBfX8LtPDpNTXk+Qt4Fnrh7L5KEBzh/I8dViGTsHPAOdf3wH4bw6uD5SWlqK2Ww+J8MeFhZGamo7AmrAyJEjeeuttxg3bhxVVVX87W9/Y+bMmRw7dozo6Oh239PU1ERTU6sPenW1CFQtFguWftIT1x4WiwVVVZ0+xqRwH1RUjuVV2e/Ywy5C2fkKVOehZv6AsvMlANTx14N3eGtvYkgiCkDxcVSz2ekPjv0dV50Tkv6LPCcGLqfL61FRCfJyw02nnPM3vHx8BGuPFlBe38zCUaGoqtptbZOenBeKVRBPDR55bp940hVQkoZy8lsYfjFqSKLTe8k1CsSHenM0v4pj+ZVEB/RBNydiPIpftNAM+GGVqBzzCkb1je749wqKF/el0pOoLvw/S86rIru8Djedll/Mi+OFjWlsOVHMLdOHoNN2ncuR1wrJ2chzQnI23T0n9mSW89z6kzSZzMQGefHYpUkEe7s5/1wy1qOkbQBATbq832iddEZ3P6MBE9D3hhkzZjBjxoyW72fOnElSUhKvvfYaf/rTn9p9z9NPP82TTz55zuslJSU0NtqpT9wBWCwWqqqqUFUVTXu9fQ4iSGvGb
DJxuqyWk9n5+HvY55TyjJyNx6k1qOufQLE0Y/YMpTJyAbTVS7D4E2hWUWqKqchKxuIl2yna4qpzQtJ/kefEwCUluwqT0USgm6FD3ZhfTQ8lubCOKaHaTrVlzqa754XSWElgVT6gUGbxP/N6bCPxFvSBkzAGJ7b/cycQ5Q0HjSb2pRcyMaRv57l71IV4lb0F6d8D0OQ3ktqSkg631ymB+JmMWPKPUVFU5LKJ5k9352Eympg91JuRfioeWpXS6ga+P5LJhKiuqzfktUJyNvKckJxNd86J79MqeHdfIRYVRod7cdecMCz1VRS7QM/aLeM7vBuqMXtHUamNdNk9qifU1NR0a7sBE9AHBwej1WopKjqzZKyoqIjw8O71X+j1eiZOnEh6enqH2zz00EOsXLmy5fvq6mpiYmIICQnB19e3d4N3AhaLBUVRCAkJcfqFdkRYEZlldRQ1G0gYGmyfnV5wI0r2ekAFjR7t/N8TGnFuVYUSngQlqQSbiyF0rH2OPUhw5Tkh6Z/Ic2LgUpvRgE6vY3hEIKGhoe1uExoK05J6vu9unxfZ6Sg6PQTEEho1tOPtwpzcE3kWFyToWXeyivQKE/6BwX1TTvZbgXLy4xb1fm3CXDw7+PwBCPRD+dENLPWEemvAywFOA11QXtfM4cJ0dHod10wfQWSwFwvH1PPV4XwOFBm5ZGIn47di12tFU7WwMBw6E6Kn9m1fEpch7x+Ss+nsnLBYVP67M5svDpWi0elYlBTGr+YN61aFkENQVZRtW0CnRzvhGkIHiKZad93ZBkxAbzAYmDx5Mps2beKKK64AxIm0adMm7rrrrm7tw2w2c/ToUZYuXdrhNm5ubri5uZ3zukaj6fcXMEVRXDLOMVF+ZJXVk1pYw7yRXT8odIvAWNGnmLcfhs5EiZvd/nZho6EkFaUkBRIW2ufYgwhXnROS/os8JwYmBVVNKCjEBHo65G/XrfOizKpwH5KI0o/Pn3HR/gR4Giiva2bN0UKuntx+i1238PCD+IWQ8g0ASvSU9hXubRg8ICAWyjNQytLBx/kPjRtTirGokBTuy4hQkY1fkBTG14cL2JtVQV2zuVtez3a5VlgssOVpyNkF6RtgxX/B2/mTHBL7IO8fkrNp75wwmi2s2pDG9vRSFBRunj6UFVOinatkfzaFycLqWueGMnJJ59fxfkR3/9cGxm9jZeXKlbz++uv897//JSUlhV/96lfU1dVx2223AXDLLbfw0EMPtWz/xz/+kfXr15ORkcGBAwe46aabyM7O5vbbb3fVrzAosQnjHeuL7297zL4Xxl8Pc+/reJswq9J9sVS6l0gkg5e8djzonU6xTRCvHQ/2foS7Xstts2IB+Gjvacrrmvu2w9FXCVHWsDHg1Y0qtBZhvJN9O24vMFvUFuvCpeMiWl4fFuJNbLAXJrPKtrRS5w3o0LsimAdh5/fjC8K5RnJeY7GofLz3NNvSOm5fkQxcPtp7mu3ppWg1CisvSeDaqTGuDeah1a1k+MXg3n8rrnvLgAror7vuOv72t7/x2GOPMWHCBA4dOsS6detahPJycnIoKGj1Wq2oqOCOO+4gKSmJpUuXUl1dzY4dOxg1apSrfoVByehIPwCyyuqoazLZb8f+MTD9l52rUIZale5L08DUx4c2iUQi6YdYLCr5rg7oVbWNZV0v6vqdzPyEUBLCfGgwmnl7R1bfdhY0HK79Hyx5pnvbByeIpQuU7ndnlFFW24yfh55Zw8+cfLg4UVTQbUpxUt9o7j7YZ7UVnnADaHSQvV0410jOaw7lVvLOrmz+uu4E29OdOMHUHnKCya6cLq/n0/25APzukgQutFflbl9oqISMzWJ91OUuHYqjGFABPcBdd91FdnY2TU1N7N69m2nTprX8bMuWLbz99tst37/wwgst2xYWFrJmzRomTpzoglEPbgK9DIT7uaOqkFpo5yx9V/iEg0cAWEwuyYZIJIOFtKIafkwrxdwX726JQyira6bZZEGjUQjzObclz
CnUFEJjlfDuDRzumjH0AI1G4ZfzhCf85tTivt+bfCPArZtWgLYMfZnzA/q1ySKpccnosHO0A+aPDEGjwMmiGnIrHKxIVVsMm54UwVLipTDtFzDxRvGz7S9Co5OfFST9iu1tqkRWrT/BicLuCX91RnddPc5gw+Pw/rXi+ibpM6qq8sqWdMwWlamxgcweYSddrb5ych2YjWKy9WzL1UHCgAvoJf2T0TY/enuX3XeForT60cuye4mkVxzJreT+z47w7LpU7v7gIPuzK1w9JEkbbOX2Eb7urhMUsmXnA4eDzuCaMfSQ+DAfFiSJCr5/b83A4qzJqqARYllbLPzrncTp8noOn65CUWDx6HOFCf09DUweKirevk91YJbebBSBUmO1mNyYdY94fcJNEDBUfCa7XnHc8SX9GpPZws6MMgCiAzwwmlX+vOY4RdU9c5IymS2kFFTz8d7TPLr6KCte3cmq9Se6v4PaYlEtUlsM21bJTL0d2JhSTHJeNW46Db+cN8z1ZfYgdDyOfyXWRy0ftBbXMqCX2IVREQ7qo+8OodYWiqJjzj+2RDLAOVVSy5+/ScFkVtEokFNezxNfHeOJr45xutwFvjKSc8irEAF9VIAL++dLrA/K/bx//mxunTkUD4OWtOJaNjkyiG2LwQv8rEJ8pR276tibb63Z+amxgYT6tq+MfHGSKH/9PrXYcRMcO18WE+xuPrDwj60TQDoDzL1fPFCf+FaU5EvOO47mVVHTaMLXQ8dzK8YTF+xFZb2RP359vMu2zaLqRj4/kMsTXx3jhtd3c/+nR3hnVzaHT1fRZLKw5UQJOWXdvG9l/di6fnoPpK3vw28lqao38taPmQDcMG1Ih9cgp5O3H6rzwOANIy529WgchgzoJXZhdJTooz9ZVEOzyeLcg0thPImkVxRUNfDEV8doMJoZE+XH27ddwPIJkWg0CvuzK7jr/QO8uvUU1Y1GVw/1vMbl/fMAJSliOQD659vi72ng+gtiAPjfziz76rx0hq2P3kll941GMxutvfFLx0Z0uN3U2EC83XSU1TZzJK/K/gNJ2wjHvhDrFz4MvpFn/jx8DIy+Uqz/8DcwNth/DJJ+ja1nfubwYLzddDx22SgCvQzklNfzzLepmMznPkOW1jbx8uZ07nxnP//ZnsX+7AoajGa83XTMHB7EL+YNY9IQfwC+OZrfvYFkbxdL/yFiueOfUF/e11/vvOWt7ZnUNpmIDfbi8vGRXb/BWRxfLZYJl4DehfdQByMDeoldiPRzx99Tj9Gsklbc916oHhE8EhSNKJuqlYqpEkl3qKxv5rEvj1FZbyQu2ItHlyUR4GXg9jnDeOXGSUyLC8SiwpojBdz5v33scLVw0XmMreQ+yt9FGQ+LBUpaLesGGpeOiyTK34PKeiMf7MlxzkGdrHS/5UQJDc1mwv3cmRjj3+F2Bp2GOQmir/X7lCL7DqI8E354TqxPvEn4zrfH1DvAOxRqClpF8yTnBWaL2lJuP3N4EADB3m784dJRuOk0HDpdyWs/ZLT0w1fWN/PGtgzu/N8+1iUXYrGojIv24/Y5cbz4kwm8d/s0HlqaxKXjIrlmspi425xa3PXEXVMN5B8U6wv/KP5fm2qEvkNfqCmCslN928cA5EhuFd+nFqMo8OsLh7uuNexsaksge4dYTxqcYng2+sknLhnoKIrSUna/L8vJ/bcGTwgU4kcySy+RdE19s4nHvzpGYVUjYb7uPHn5aLzcdC0/j/L34NFLR/HnK8YQF+xFXZOZlzanY2wncyJxPLaAPjrA0zUDqMoBYz3o3ITH+gBDr9Vwx9w4AL4+UuCcVhInKt2rqsrao6LcfunYcDSazntEL04UugI7TpXR0Gy21yDg+z+BqRGiJsOUn3e8rcET5vxerB/9tNUOUTLoOZpXRXWDCR93HeOi/VteHxHqzX2LRqIosC65kI/2nuadXdnc8b99fHkoH6NZZUyUL89cPZa/XDmW5ROiGBbifca5PibKlyGBnjQaL
V2315zeDRaz0HQIjIN5D4jEUMYWyPyhd7/cqe/ho5vg8zuc2mrjaoxmC69sEZMYi8eEkxjejyzhUr8G1QIR48XfeRAjA3qJ3ZiXEAKIjF5lvZMt5GTZvUTSLYxmC39Zm0pGSR3+nnr+uHw0AV7ti5yNj/HnhesmEOhloKbRxN5MWY7obIxmC8VWsSiXldzb+ueD44XK/QBk8tBALogLxGJReX1bRu8UsXuCTRivKheaah16qPTiWjJL69BrlRYRwM5ICPMmyt+DJpOFH+1VeVOdLzKTGh1c9Chouni8HDIN4heKh+2tzwohPcmgx1ZuP2NYENqzJp6mDQvi57NF0PXe7hw+3nuaRqOFEaHePLl8NE9dObbFJrk9FEXh0nGi3WTNkfzONSKyrOX2Q2eLZXC8sFYE+PGFnrkwqCoc+B9sfBLMzWKi4PD73X//AGfN8TLyqxrw99Rzy4xY5w+gtkRMDO55XbTxrP8DfHU3fHwrHPpAbDNqufPH5WRkQC+xGzOGBzEi1JsGo5lP9uU69+A2P3opjCeRdIjFovLajnyO5lXhodfy+GWjuwwStRqFC0eKybqNjvavbqiAjK1gdlKf8wCgsKoRiwoeei0BnnrXDKLY2j8fMrD658/m57Pj0GkVDuZUsivDwZNTHv6twnh7/u3QQ9lKmKcNC8LHvetzRFEULmoRx7NT2X3hUbEMSQTPwO69Z8Zd4O4H5Rmw+1X7jEPSbzFbVHaeEufqrPj27cwuHx/JMmtQPiTQk4eWJvL8teOZNCSgW4rp80eG4mHQkl/ZyKHcyg4GYhQZeoDYWa2vT7pV/M/Wl3f/fDQ1w+anYO+b4vvhF4nlqc1ikmuQk1fZwDfHxN/0jjnD8G5T6ecULGZY94DQPzj4LqR8LSosCg5DRZaYYPGNhLi5zh2XC5ABvcRuKIrCrTNjAeGF21MLkj5hy9CXnJDBgETSDnmVDTz+9XH2na5Br9HwyLIkRoR6d+u9F1uzfvuzy6moc1D1TUUWfH4nbHgMtv/dMccYgOS1COK5u84CqEXhfuD1z7cl0t+DKyZEAfCvraeocbTYo82u7fiX4gHfQey2Vs5cENfNQBq4KDEURYHkvGq+OJjb94qFwiNiGT62++/x8Ie5bUrvU9f2bQySfs2x/CqqGox4u+kYF9V+pl1RFH4xdxiv3TyZf14/kZnDg3t03fMwaFlgnaxac6Sg/Y3yD0FznZh4ajtJqTOI0nuA1DWQu7/zgzVUwpqVQh1f0cCclbDgcYiZJipPjnzU7XEPRCwWlVe2nMJkUZk8JIA5HUzSOJRjX4jKIDcfIbY5+afiunvxY7Dsebj6TVjxNmhdNBnuRGRAL7ErE2L8GR/jh8ms8v5uJ4kPAfhGi39oczOUn3+CJBJJRzSbLLy/O4ffvH+Aw7mV6LUKKxfGM74T4ayziQn0ZGS4DxYVtpx0QJa+MBm+vEsIW4KYZZeWVkCrZZ3Lyu3NRiiz9oMOMMu69vjJBTFE+rtTUdfM6z9kOPZgMRfAhBvF+g/POSRjV1TdSE5ZPRoFpgwN6Pb7gr3duHKimNx468csXvsho282dr0J6EFkzib/VKxvWwUFR3o/Bkm/xtbeMX1YUKeiaYqiEOnv0aUWREfYXB72ZpW3n1jKttrVDZ15bmtIxLg2LgzPdezCUJEFq38lKlMM3rD0uday7gnXi2Xq2kGtmv/J/tMczatCrxWTME6fcK4vbxXVvOBOmP1bmHIbjLlK2NNFT4bgEUL75TxABvQSu3OrtYdm84lissvqnHNQjUb60UskZ3HodCW/+eAAH+zJwWhWmRjjz1+WDmPWiJ7PpNuyHhuPF9u3/zh7h8hyNNWI/+GRS8TrW58VWZTznHxXC+KVZ4qJUjef1hLyAYybTsu9CxPQKLD5RAm7rOXqDmPKzyBsjDiXNz5p915xW3Z+VKRvt8rt2/LTmbH8bHYsILKZf1mbQqOxFyJ59eVQeVqs9zSgB1HqPGweWEyw4Q9QU
9jzfUj6NZY25faz44MceqzoAE8mDvFHVWkRi2xBVc/tnz+bC+4E7zDhwvDZ7bD6/2D1r8Wk81d3w9f3iNeq88EnApa/BNFTWt8fMUHcy8zNkPyZQ35HV3PodCXvWZN2t0wNJ9zPBQ4su/4lrqshiZB4qfOP38+QAb3E7sSH+TBzeBCqCu/szHbegW3+yFIYT3KeU1nfzKr1J/jD6mTyKxvx99Rz/+KRPHHZKEJ92hfA64rZ8SHotQo55fWkF9tJ5Ct1LXz3CJiaYMh0uPR5US7nGymy9Ttfsc9xBjBtS+5dQolVgTw4AVxV8m9nEsN9W7LTL29Op6rBgaX3Wp0o/3TzEZ/l7tfsuvs9mSJI6km5vQ1FUbhyYjQPLklEr1XYk1nOw58f7XlbTVGyWAbEgnsvFK41Gpj/kBASbKgU14RmJzgRSJzGsfxqKutFuf34Nur2jmKZNUu//lgRTaY2k1QlJ6CuRPiRR01u/80Gz9ZWkKpckSQqShbZ+ILDrSX74WPhyn+dq56uKK1Z+uNfDrqJ6dLaJv723QlUFRYmhTFnmL/zB5F/yNrqoMDse7sW4TwPkJ+AxCHcNH0oGkVkD1IKeqAW2hfCxohlkQzoJecvKQXV/OrdA2w5UYKiwLJxEbx602TmxIf0qSTO203HDKtvcJ/F8VQVDr4nsvCqBRIWwyV/EQ9Zeo82fYzfQM7uvh1rgNPqQe8qhXtrQB86sAXxzuaGaUMZEuhJZb2R17Y6uE3LJ0wErABHP2nNEPaR+mYTyXni/npBXO+znrNGBPOXK8fi464jrbiW+z493DNrP5sgXsS4Xo8BvQcsego8AkSLx5anwSJtMgcL2091r9zeXkyNDSTM143aJhM/nGzj5GArt4+eKnrmOyLmArj6DVj8tDgvL/mz8Ktf+KSYoFv8DFz6gjhf22PobPCPEZVnKV/b7xdzMSazhb+uS6WqwUhcsBd3znWBFZzZ1Kqzk3gphA5sbRd7IQN6iUOICfRsEdL6744sx1sEQes/dXWeUMsegPSph1Fy3nM0t4rHvkymtslEXLAXq1aM55fzhp/hMd8XbJZYP5wsodnUy4dtiwV2vtSq/D3hBpj/oMhk2oicAGOvEes/PCceis5D6ppMVNaL7HFUgIsD+kHQP98Wg07DbxfEo1FgW1ppi52Ww4idBWNXiPUtT0NN39XlD2RXYraoRPl79HnCJynCl+dWjCfcz52i6iYe+OwoqUXdzCza+t7D2w/oTWYLm1KKKKjqoB/Zhk+YCJy0eqFUfeDt7v8Ckn6LxaK2/H/NGuHYcnsbGo3CkjEiS//NkfzWZ1DbZFrsnK53Ehwv+uxjZ0HcHNEWMmy+6M8eOqNzoTWNBsZbbfCOfiLU8AcB/9uZTUpBDR4GLQ8sScRN5wIb02NfiFYwNx+44A7nH7+fIgN6icO4/oIh6LUKx/KrOZDjhADbzQcChop1m83SAOKjvTn85N+7+Gy/HRSHJecdB3IqePyrZBqNFsbH+PHXa8YRH+Zj12OMj/YnyNtAbZOJPb3xpDebRDBz9FPx/Yy7YNov2i/lnnqH6NmuK4GdL/dt4AMUW/+8v6ceT4OT7YAAjI3iwQkGvGVde8SH+XDNlBgAXtmSTmW9gx+6p/1C9Hs21cCmP/bZkaUv5fbtEeXvwd+uGU9iuA+1zSae23yajSldTDwYG6D0pFhvp3/ebFFZteEkf9+Yxp/XpHR9bwsfA3Os5c77/+tQdwCJczheIMrtvdy0PRJj7SsLR4eh1ypklNSRWlgjet7LM4Qi/ZDpjh9A/ELwCoG6Ukjf4PjjOZgdp0r54mAeAL+9ON41VWN1Za1CeNN+IWwvJYAM6CUOJMTHjWXjIgH4745s52SfnehHv+VEMXd/cJANx4v6HIBX1DXz0d7TNBjNvL0ji6fWplDXJO33JN1jd0YZf/rmOEazypTYAB67dDTuevvPnGs0ChcnWsXxunrQPxtTs7Cks1n8XPQojFvR8fZ6d5G5V
xQ48S1k7+zDyAcmeS2CeC7KzpeliZYIz0DwcoElkRP4ydQYYoO9qG4w8a8tpxw7marVC1srg5foye2DP73ZorI3S0yU2yugB/Dz1PPnK8cwe0QwZovKP75P5z/bMzu+fxcdF+eIdyj4hJ/xI4tF5cWNJ/kxTWRnc8rquze5P3IxjLtOrG95unVSqa801Q66fuaBgE3dflpcEHonlNvb8HXXMy+hjYWdLTsfMa53Wg89RauHcdeK9UPvD+gWkoKqBv6+MQ2A5RMimdkLYV27sPtfYKwXLWAjl7lmDP0UGdBLHMqKKdF4GLRkltaxzdEljdDa53lsNex5XcyMOoCKumZe2XyKzNI6/rEpjT98mUxhVTv2KN1k9aE8jGaVEB83dFqFXRnl3PvRITJLB/jDR2EyfPlrOLne1SMZtPyYVspT36ZiMqvMHB7Ew0uTMOgcd2m/yFp2fzCngrLapu69qbkevr0fsreD1gCL/iKyF10RPhbGWh+IfngOGp2kx9FPaBHE83NRQF9sK7dPGjSCeGej11pL7zUKO06VsS3Nwfcp30iYd79YP/IRHPmkV7tJKaimtsmEt5uOpAj7BiduOi2/X5jA8jHiof3zA3k8tTaFhuZ2FPA7sKtTVZVXtqSz+UQJGgVGR4ox2jJ8XTLtl0I53NQEm54Uy75QVwof3wzvXCX0O+zsNiBpH4tFZUeLur3zg8Bl40TZ/Y/ppTSlbxUvdqRu7wgSLxXVo1W5kLXNece1I00mM0+vTaWh2UxiuA8/nRnrmoHkH4K0DeJeNOu3UgjvLOSnIXEovu56rp4k1ITf2ZmNyezgGcqhs0SZbnMtHHwX3r8Ovv+LUDa1I+/syqbBaCbUxw29VuHw6Sruev8AXx7K63ElQk2jkW+PCpueX80fzrNXjyPEx42CqkZ+/8lhNqc6wPfbGZRnwroHRVC/5WnI2eXqEQ06NqcW89x3qVgsKvMSQrh/caLDMyBR/h4kRVg96U+UdP2Gxir45l7IPwh6T+HXO3Rm9w849edCXKi+DH74q8Mm6fojLvegH6T982czPMSb66yl9//acorqRgcHe8Pmw9TbxfrOl+Dkdz3exd4s0fIyNTYAbS/9ujtDo1G4clwIKxckoNMq7M4s54HPjlB69iSeTRCvTf+8qqr8+4cMvjtWhEaB310ykpVWq8DDp6vIKOmGS4ZGAxc+IkTHyjPF59RbVBW2/lXY65kaRWXEp7dB7v7e71PSLVIKq6moa8bDoHWKuv3ZjAj1JjHcBzdTLWWn9qMCxDoxoDd4wugrxPqh98W5OMD4cM9pMkvr8PPQ88CSRKeIGp6D2QQ/viDWky6TQnjtIAN6icO5fHwUvh46iqobOZBT6diDeQXBte8INdLwscLXNm09fH4nfPUbyPqxz4c4VVLbUm78+0UjeemGSYyJ8qPJZOGNbZn8/tPDZJd1P7O+5kgBDUYzscFeTBkaQEKYDy9cN4GJQ/xpNll4fsNJXt6c3nsRMldQVwrfPiB6RfWeoiRz45NQmu7qkQ0aNh4v4oWNJ7GosHBUGCsXJjjkwb49bIKX25NPoR5bLSZr6trx864tEb69JamixPHSF4TgXU/QuQmFcEUDmdvgvWvg698K5eBBnrG39dC7RBBPVYVFEwhP5UHOtVOiGRLoSW2TyTmTqBNvai3H3fJMj+9NuzOsAb0dy+3bY/7IEJ66cix+HnoyS+tY+fFh0oqsIpUWc6tNrDVDr6oqb+/I4psjBSgK3LMgnrkJIYT6ujPLWqa7urtZes9A0ZoDcPwryNjSu18i5Ws4vVtUB037pZgkqDwNa1bChsfFdUriEGztFtPjAh1aOdYZN04fymjzceoam0lpDsHkFebcAYy+Spx7JakiS19bIu6X9eXCprGxul+L5u2wOhTcMXcYwd5urhlE6jdQkSV65qdKIbz2kAG9xOF4GLTMt/YxbT3phAcljUaokS5/Ca58TZT2arRCife7R4T3dS9RVZU3tmWgq
jA3IZikCF8i/T146sox/PrCEXgYtKQV1XLPh4f4eO/pLvsxG5rNfHkoH4AVk6NbbMX8PPQ8cdlorr9gCIoC65ILufuDg3y2P/fcDMlZNJss7Egv5ZlvU3liXSY5PbEfsgdNtaK8urZIVEtc9y5ETRJ9T+selA9PdqCqwcirW0+hqqKk8K4LR6BxUjAPMCc+GINOw9TCD2jaskpM3rx7FfzvClh7v2h3ObleTKJVZIn+68v+0ftZ9bDRcMmfhGCWqops/w9/g3euhHUPQ/rGfv1A1BtUVSW/UrTxuER8qDxDCBJqDX2zIxsg6LQaLhvf6l3tcGFSRYHp/wcjl7ROeOYf7NZb8yobyKtsQKNRmDSkA9ssO5IU4cuqa8czJNCTirpmHvjsCH9YnczqjVuora2hWeuJGhALwPt7cvj8gAjY/2/+cC5KbA2erpwoqvW2ppV2eR9rIXoKTLhRrG99DqoLejb4qjzY9YpYv+BO4Q9+3bsw5moxSZixRZTiH/pAluHbmTVHClh7VPy9ZseHuGwcE2L8uT0mH0WBzY0JPPNtqnMTJJ6BkGjt917/BzEp/e5V4v71v+Xw38vgP0v6pUVrcU0j+ZWNaBSYMtTx15p2MRtFdQPA5Fudo38wAJEBvcQpzBspLua7M8rb78NzFKGJYob/+o9g5FLx2uEPei1OsuNUGcl51Rh0Gn46s9V/U1EUFo8J55UbJzEtLhCzReWdXdl8d6yw0/2tP15IbZOJCD93Zp8lMqLRKNwwbQiPXzYKbzcdeZUNvL0ji5+9vZeHvzjKhuNFLcJ5FovK4dOVvLgxjZvf3M3T36ayI6OMrPJG/rruBI3Gbn7mpWnwwfXiJvPOlfDuNfDetfD+T+DDG2H1ryF9U8efn9kohM/KToksyNK/iaqJhX8UDgR1JSKob3byJMMgY82RAppMFoaHePGLucOcGswDeBp0zI3zYazxCNUNRvCNEg/HDRUiE3bwXdj8F6gpED+7/CUI7KNfbexsWP4y3PCRULcNGiEqcLK3w6Y/wed3iD7FQUJ5XTMNRjMaBcL93J0/gNN7xDJyoqiSOA+YmxCCm05DTnm9UMV2NIoCc+8T57a5WUxOdaM9bK/VYWJslK/dLCm7IszXnedWjGPy0ACMZpVDpytJP7KTgqpGvi0N4cY39/L7Tw7z4Z7TgMjmLbbahtmID/NhTJQvFovK14fzu3/wKT8Tk3rNtfD9n7rvDmCxwJanhBJ/xHgRxAO4ecOsu+Gq1yFsjPj57lf7JFIoacViEYmPV7eewqIKu1OXBYMApiYia44S4edBqttYdmeW8/S3Kc4N6ifcIPQztAaRYFLOCr8sJkj5ynnj6SaHT1cB4n/XWdeac0hbLxJEnoFCk0DSLjKglziF+FBvwv3caTJZ2J3ZTmmuo/EOgZm/EeXflTmQ1/PeuWaThf9sF2q7V02KIsTn3IfcYG83HlmWxM3ThX3eaz9kcLKo/QfDZpOlJZNx9eToDoOyyUMD+fctk/n1hcMZHemLqgq/8X9sEsH7U1/s4bb/7OHR1clsTCmivtlMkLeBKydE4eeuI6einjd/7IZKsMUC254X1i4NlaIcrK5EXEhrCkSwVJQs7JY+v10oxrbNYlkssPVZ8dnqPWHJX8HX+kDn5gOLnxVBflm6EDmyOHFiZxDRaDS3PAxfMzmmparD2SwLzMVNbSLX6E3mgtfYd9F7bEl6ko1BN7BZuYA9tSHsbo6jbvELreeBPfAJFw9H17wJK96GSTeL86oiC774Zb/McvQGW3Y+1NfdqcrQLeRYXQWGTHP+sV2Ep0HXIty1/ljffeK7hUYLFz8uJk6M9bD2PqjI7vQtu60B/QVxzvH0tuFp0PH4ZaN44boJ/PrC4VzkX4SbXkO2fhg1jSZOWCdBbp0Zy+XjI9vdx5UTowH4NrmQ+uZuBuZaHVz0BzB4Cwcbm21VVxz5SGi46D1F287ZIlrBI+Dyf8L0X4nvMwemaFl/oqHZzF/Wp
rRUHt48Yyh3X+zcCrJzyNsPxga8AsL5+RWLMeg07Muq4M9rjtNkctJziHcoXP8B3L4B7vge7twMd26BOzbD1W+IbU7v6XfJjiO5lQCMj3aRPZzFLBIEAOOvP28ml3uDDOglTkFRFOYliCz91pN9L7lOK6phf3YPfbANnpBozdInf9bjY64+lEdRdRNB3gaunhQNJ9bBx7eKrHYbFEVhxZRoZg4PwmRWeWptClX155byfZ9aTHldM0HeBi4cGdrpsX3c9SweE8EzV4/jjVuncPOMocQEejCq4QBXH/s11xevIlTfyOIx4Tx91VjeunUqt82K5c6ZkSgorEsubOll65BTm0Q/pN4TrngFrn5TZDCufFVkRi//p8iUGLxEBv67h2H1r4SwkKrC3teFAqlGCwufhJCEM/fvGwGLnxYz1Dm7YMc/B6RAjKtZf7yopapj5nDnPtC3ZXj1PvRaDQe047n7w8M8+W0Gqw7reLFgNM83LedPht/yZ27nuW2ljrOsDIwT4mJXvykybU01sO4BOPDOgD+38irFg51Lyu2b68TkHUDM+RPQA1wySlivbUsr6X7A2Vd0VueHkJFCRHLt7ztsTappNHI8X2TNpjm4f749FEVhRKg3i0eHM8lwmiEBnvzqJ1fy/LXj+fWFw3loSSLXTI7u8P1ThgYQ5e9BQ7OZDcd7MGniGwHz7hPrh96D03s7377sFOx7U6zP/E3Hk4oajRDZUhQxcV3fw+cKSQtltU08+PkR9mSWo9cq3L94JNdOcd2kMyBaNLa/KNZjZzNhSABPXDYad72GgzmV/PHr492vYLQ3iiLOv6ARopLN3Ay5e1wzlnZQVVGJAzA+xt81g0jfJJJM7n7i/1TSITKglzgNW0B/IKeSqobe96o1NJt5ZHUyT3x1vOdB/egrxUU0Z6cQxekm5XXNfLJPbH/rjFjc1SahuluRBdtWnVOCrigK9yyIJ8rfg7LaZv76XSrmNkGN2aLy2QFRHnzlxKgeicWE+bpz7ZQYXl4ewyO+awjx0jPbO4/Xfd/k1xP0jInya5kNHx3u1eIy8M/v0yiu7sBaz9gAu18T6xNvFOWNwSNEUB6aJHqXI8aJ/qXrPxQ9jTp3KE4RwkKf/by1x2nufRBzQfvHCU0SLRCKAse+gKOfdvv3loDJbGkRlLpqUpTrsh6mJpScHfh76jmsn4CHXktcsBczhwdx5cQofjlvOCsXJmDQadifXcFb2+3kI90RXkFw2d/FDV9VYe8bovWjs2xHPw/486wZepd40OfuE5kRv2jxdR6RFOFDdIAHTSYLP9hh8rnbGLxgybPgPwRqi+Hg/9rdbH92BRYVhgR6EubrglYMG1W5osVGa0AflkR8mA+Lx0R06U+t0ShcYe2l//JQfs+cb4bNh1GXi/XNf+nY8cJshM1PieXQWUKnoDMMXhBgbQmyTWRJekRGSS2/++QwGSVCDf0vV45ljgv75gFxjn59N9QUiuvYxJsBGBvtxxOXj8ZDr+VIbhVPfn38jOczp6MoEDdXrGf+4LpxnMXp8gYq643otQqJ4S7oW7dYWq+D464FvYvcXgYIMqCXOI2YQE+GhXgJX9I+eNLvzCht6cN/ZfOpbvfkq6pKuS4UYqaLF4590e1jvrMzm0ajhYQwHzExkfKVyAaCCGpPfX/OezwNOh5emoS7XsOR3Cre3dVaRvljeimFVY34uOtYNDq82+NowWJB2fosHpYGAmOS8A6KQVOTLzzfz7LiueGCGEaG+1DfbOa57060f+M6/KEor/cJb/X97gh3X5h2pygfG3M1aPUiGwIiW9rVw9OweUJpGGDXy1L5vgdsSyulpKYJf0/9GWJTTuf0HjDW4x8Sxaq7ruejX0znH9dP5KGlSfxsdhzLxkVwYWIo9y4QVRpfHsrvWTauN2j1MPf34kujEw9Gq38lJu7qy8X/xZFPhH3V6v+D/yyFL37V/X5cJ+NSy7rT1raFIdOdf2wXoygKl4wW/1tOK7u34
REg/JUBTm1uV+hxj7Xcftow52fnz8BmVxcyUlQY9ICLEkPx99RTUtPE9lM9bMGbcZeozGmogPdWCF2XPa+L/2+bV/3+t0Vrl7ufuB50J0McNlosi471bDwSThbV8OBnRymrbSY6wIO/rRhPUoSLhcsqT8PX94jJMf8hcNmLYuLXyuhIP55cLoL65LwqDuRUuHCwQNwcscze2W8EXg9by+1HRfq6xqEgc4v4O7r5iGScpFNkQC9xKvYou/++jaVQcU0T7+3uvN8QRDC/av1Jbn1rD1+apgsv0hPfitLSLkgvrmFTqniwu2NuHBpLswiAQWScAfa81vow0YYhQZ7cfXE8AJ/uz2XHKVF+/LE12798QiTuem2XYziH46tFFk1rgIsfgyv/1VpyvPb3wuLHik6r4b5FI/EwaEktrOH9sz+v2hIhFAgi0O7uw5lnoBAWuu49MXs6/VfCiqk7jLtOlPKqamvwIOkUi0Xl0/2iquPy8ZEuswACWu2j4uZh0Os6LKmcHR/M9RcMAeDlzekcs5YKO5Sky8TDm2eQqKD5+GYh8LhmpaiqSV0jHtqN9aLFpLRrETJXYCu5d3pAr6qtgnjnWbm9jYtGhqHVKKQV13bPM92eRE4U/bZNNULwsQ0ms4X92SLwmBrbzYA+/6BQ0O+G2F6PKDwilhHje/xWg07D0rGiBP6LA7k9cxTQucGCJyEgVrgDFCWLHts1K+HtZcJZw1YtNud34j7VHVoCepmh7yn/2Z5Jg9HMmCg//nrNONeIeLalIktk5utKxXly2YvCaeUskiJ8uShJtDvu7OnEkr0JSRJjNNZD/gHXjsXKYWu5/YQYFwgaWiyidQ5E4sjg5fwxDDB69URoMpnYuHEjr732GjU1IkuZn59Pba2Tb3ySAcfchBBRbZ1fTXFNB+XfnVBS08SRXBEU/PrC4QB8fTi/1Re3Az7Zl9syifBGZhBpTYGoxno4ua7T96mqyr9/EDZ180eGiLKj1DUiO+ATDstWgXeYmAU+8nG7+5gTH8LyCUIg6O8b0vjycB45ZfV46LUtDzU9ojJHKPICTP+lUI/3CBAe3/ELxUPOtlWw82WxjijT/81FI8RnsT+35UINCGVfU5PwER42v+fj8QmDGb+G8T/pXiYExHa2snxbpkfSKfuyK8gpF+fNkt6cN/bC1AzZO8R6N86Xn0yNYeaIIMwWlafXpnbc9mFPwse0KlirqlAU9osWWZBJt8CCx4WVIgg7y36G0WyhsFpMEDq9h95mV6dzg4gJzj12P8HPU8/0YSKbt97RlSVWzBaV0tom0krqyAmeQ7PZgnri2zO2OZZfTX2zGT8PPSPDfLreacFhYSl56nuRrbSnYKTtum31n+8pS8dGoNcqnCqpIzmvumdvDhgK1/5XtH/NewDiLwGvEFFiX3BE3PfiLxHVYN3FFtCXnJT2dT3gWH4VyXnVaDUKv7skAR93vWsHVJ4BX/9WVGUFDRetWJ1M6tj+z/dkljtO66U7aDQQa83S94Oye7NF5UieeNZ2iSBe9nbxt9R7trpTSDqlxx4E2dnZLF68mJycHJqamli4cCE+Pj48++yzNDU18eqrrzpinJJBQrC3G6Mj/UjOq2LbyVKu7kQ8pz02nyhGVWFMlB+Lx0SQnFfN1pMl/OP7dF64djy6dtSgd2eU8Y613P3CxFC2nizhs4Zp3Nr4BWFHP0c76spz1W8R/pv/2nKKlIIaDDoNt86MFTd6WzZ7/PVi1nDancI269B7wmu0nZvHT2fGcqqkluS8at76MQuApWPDe37zs5hFb6CpCaImw6g2ZUg6A1z4CPjFwL63UJI/xacoHZY8CR7+zIkP4VBOJeuPF7Fqw0n++ZOJ+NWcFJYgiiKEg5wpXmN7ECxKFrOx7fwNJK18Zs3OLxkbjrer7GNAiPYY68UDdOioLjfXaBTuXZBAYdURMkrq+NOaFP569Tg8DL2oTOkJXkGw/CXRP+kZeK46bm0J5B2wBibXO3YsP
cT2cOnvqSfIq2flzH0mZ5dYRk7qcSn1YOKS0WFsTy9ly4libpsVi5uu/fP1RGENXxzMY0KMP/NHhnSr4iqztI7vU4vJq2igvK6JsrpmqhqMLbIOIeZoHqqph/LN/LdiC4GhkQwN9GxxTJkaG9i1fkZpGqx7SNwr3HxExv+7h2Du/TBycefvVVUozwRjB9fk+nLRn6worYFwD/Hz0HNxUhjrkgv54mAeY3sTNPhGiK/EpWLMVbmiIqG+TFSB9WhAMa2fU1l6a/WdpFM+2SfuSwuSQgn2drECeWm6qNRorILgeJFwce/8vBoT6Yu3m46qBiPHC6oZE+UiNXcQE87HvoCsbaK6ROPge2QnpBfX0tBsxstNy/AQb+ceXFXhgLV3fsxV0ne+m/T4Cfqee+5hypQpVFRU4OHRmjm48sor2bRpk10HJxmc9LbsXlVVvk8R5fYXJYoyqdvnxOHtpiOrtI4vrGJhbckpq2fV+pMALBsXwcqFCTy6LInDHlMpa9ZTkJ1GfeauM95jsah8eSiPX793gH1ZFWg1Cr+YO0zcrNLWi2y8Z1Crr/3wi0VgY2wQYlztoNNquH9RIgHWh3O9VmH5BCEMREOlEIf77HZ4awn8+PeOlXYPviN69t182rfhURQhXHfxY6DVYyjYi/LJT0U/pqpyx9xhxAR6UFHXzD82nYQdL4n3JSwWvZDOJGiEEDlpqoHKLOcee4BxPL+a4wXV6LRKh3ZQTuPUZrEcNr/bkzDuei2PLhuFv6eerNI6nt9wwjnZEEURD/ztWd1EjBPLwiPniFq6mrVHCwC4ZFSY84UPbS0wHQlbnidMiPYn1MeNuiYzOzooxz1RWMMfViezPb2Ulzen89P/7OGNbRkUVDWcs22Tycz3qUXc98lh7v7gIKsP5rE3q5xTJXVU1otgXqNAkLcBj5BYcvRxoFoILPiBzanFvL0jq2UcU+O6KIGtzBGtV811oiT++g9Fxtpihi1PixL1jsrcC4/C1/egfPYz/DfeK4Lbc7axVrUEDhP3ol5yxcQoFAX2ZpVzuryPdl2KAv4xQjRvym3C1aan7w8bI9ZlH323SC+uYX92BRqFHidn7E7ZKfjmtyKYD0mEZc93GcyDeDabGiv+n3ZluLjsPmKCCF4bq0V1jQuxVXGOi/Z3zT2o9KQQXh67wrnHHsD0OKDftm0bjz76KAbDmTP3sbGx5OWdG1BJJGcza0QQWo1CZmkdOWXdv4mnFdeSV9mAQadhtlVJ19/TwO1zhDrtB3tyyK9sfZCqaTTy5zXHrb1dvtw+W2w3NTaQP1w5hQMeM2gwmtn79etU1AkRkoySWn7/yWHe2JZJo9HCqAhf/nn9RC4ZHW71w3xP7Hz89a3ZK0URJecAJ9Z2KPIW4GXgoSWJBHkb+MnkCAJK9sB3j8C7VwsLt9I0kfk89gV8cL0Q+mlq00pQnAr7/2v9EH8L3p0oyI64GPWyf2D2iRbtARufgPWP4t5Uzv2LEtEoYDq5AVPhMRFUT729238Hu6HRtmZ4Zdl9p9gcES4aGUqQK7MgPSy3b0uIjxuPLEtCr1XYlVHOJ/u77zLhEILixQNDP5tQyq2o50huFYpC7wQz+0Jzbev/4nkoiNcWjUZhwaiOxfHSi2t47MtkGoxmRoR6E+brTl2TmS8P5fOLd/bzxFfH2J9dTk5ZPW9sy+Cnb+3lhQ1ppBbWoNEozBoRzK8vHM5jl43i7z+ZwDs/v4Av/m8Wb992AS/fOIkLL7uZoUGe3BKUwk3TYpgTH8yQQE/Gx/gxeWgnAX1NEaz5nZgoDhkprELdvMUE8HhrJcqe12H738+cyCo5CWvvhy/vagkmtPXFKF/dBRlbzzxGH8vtbUT5ezBlqKho23Gq90K5dkP20fcIW3Z+bkIIEX4uVCCvyBaZ+aYaUVmx7G89yupOt9q/7soo65meg73RaGHobLHu4rJ7myDeOGeX27fNzo9aDh7+zj3+AKbHdZsWiwWz+VxV8dzcXHx8e
j9TKzl/8HHXM3loAHsyy9l6spibZ8R2632brNn5mcODzijXvSgxlM0nijl8uoqXN6fz5yvGYFHhue9OUFDVSKiPGw8uTjqjHH90pB/eV92J+f0tRNYc4ZkPNxCfkMTXh/OxqOBh0PKzWbFcMiq8dXby1GaozrP6YV565uDCx8Dwi0Sv4q6XxexwO+XrST5NvD1yD5zYAIcqW38QkggJi8A3UgTtxcdFFuX4l8IiLnGZsOlRLTD8QhhxcdcfWEgilRf/jdD8DSiH3oesHyH/ILHTfkl8YCiXZX5Ng96Mz7Qb2xWMcQrhYyFvPxQmi4u35BxyyurZk1mOosBVrs6C9LDc/mwSw335xbzhvPR9OmuPFrrWo1irg7BRouy+4IjINvYD1iUXAjBlaCChzrYly90vrjH+MeJadJ6zICmMD/bkkJxXRV5lQ4ueQUZJLX9YfYz6ZjOjInx54vLRuOk0HMip4JsjBezPrmj5akuojxuLRoezcFRYS7VWR2hGXIRh50sEN+dzXWwDXJDY9YDry0VgY1P2XvLXVjEpjUZorngFC3HIY6tFafrEm8W9xhZAKBpIXIqatBzj5ufRVaQIC8hJt8Dk28R+bLoT4eN68nG2y/gYP/ZmlZNa2LkOjlMIs17TZIa+S06X17PTmtFeMTnGdQOpyhPnfEOlKLNf+lyPq0YmDQnAoNNQVN1EZmkdw5xdYt6WuLkiMZT1I8y82yWtiE0mMykFQtdigrP95/MOiP8/rUHoMkm6TY/PlEsuuYS///3vLd8rikJtbS2PP/44S5cutefYJIOYtmX33ZkRbW7jCWwrt7ehKAq/vnAEBp2wh9uYIsoTD+ZU4qbT8MiyJPw8z+1VHxoXT+jYC9FrNSSWbeDLQyKYnzkiiH/dOInFYyJag3mLRZS7gygBas8P84I7xUUo70BrFtNGfbnIwn94Axz9RNx8PALEBWvFf+Cq10Sv0JDpcMUrcMmfhfBPU40QwHv3alFG6RkEs+/tfq+71iAewq5+XcxcN9fBtlXcXfokfpZKyhV/oVDvKmwPhDJD3yG27PyMYUHOF0g7m16U25/N/JEh6LQK5XXN5Fc5QSCvM8LblN33AxqNZjamiGzw0rFOzs5Dm3L78zs7byPEx41JQ0Q2fMMxMdGSU1bPH75MprbJxMhwH+FnbdCi0ShMiQ3kictH8+rNk1k+IRJPgxaNAhfEBfLE5aN4/ZYpXDs1pstgHhBBic3KqgvxVsDqcHKf6CP3CReTyu1lt8ZeAxc/LmweM7fB53eKYF5RRFn+de/A3PsgaATVs/+AOuYa8b4D/4P1jwrlcFsZvh0Cepu/9YnCGtdmR0EojSsaMSFS23snnvOBT/adRlVh+rBAhgT1sL3BXtQUiWC+rlRYGS79W69aQNz1WiZaA9ddGR20OzqLqMlCCK6uBEpSXTKElIIajGaVIG+D8585DlirUJMu7b5DhQToRUC/atUqtm/fzqhRo2hsbOSGG25oKbd/9tlnHTFGySDkgrhA3PViRrQ7M/P7ssqpbTIR6GVgfLT/OT+P8PPgBqs91qtbT7Ha2k//2wUJnc62ek+6jugAD+aynyHeFh5dlsRDS5LOLWvO2iasUAzeHfth+ka09vvsekX4WzdWwa5XRQn90U/B3Cx6Ghc/Azd9Jqzezs4MKop4kLvmbZj/oFDRN1t9Sec90K2+sHMIHAbLXxEevjp3AhGzr2s8lrffX+wsQq0PUDUF8gGqHYprGtlinchyeY9iH8rt2+Km05IYLh66jlrL+lyGLSApONJxT7ET2ZZWSl2TmTDf1kDSaagqSu75bVfXHpdYy+43pRaTXVbHI6uPUt1gIj7UuyWYP5sofw9unzOMd34+jQ/unM4fLh3F5KHdELI7mwSreF36xs69qY2NsO5BEWh7BFjdVzppyRp+oQh+DNZ7Y9xcuOY/cNEjwg3ChkYr2snmPyQmh7O3w6c/E1UcPhGdH6ObDAvxQq9VqGk0uX6Cz+DZe
j8ulln6jiiqbmzRQLp2iouy83VlIpivKRTnbEcTWN1kRpuye5eiM7S2O7mo7L5t/7xTK+hK00W7j0YH429w3nEHCT0O6KOjozl8+DAPP/ww9957LxMnTuSZZ57h4MGDhIaGdr2DPvLyyy8TGxuLu7s706ZNY8+ePZ1u/8knn5CYmIi7uztjx45l7dq1Dh+jpGvc9VpmWO1CuiOOZ/Oenz8ypMOHoismRhEX7EWzSfQFXjs1htnxXZSSR01CFxTHMD8NL00uZJp1TGegqq3Z+TFXin7Ejph4o3igqsqF7x6G938iVPFNjaJEednzwhN16IyuFUw1Ghi5BK57F+bdDwuegCF9eNDWaGDcCljxNtqEhWxzm8u6hiRqm0y932dfMXgKcTzoN1nS/sR3yYVYLCpjo/1I6I5NlSPpY7l9W8ZG+QO0WFC6jLBR4v+wrkQ8GLqYb61ieItGh9tfiCj/UKde5NqqLFGCrXPvlbf4YGVqXCD+nnoq642s/PgwlfVG4oK9eHL56C7dJgw6DZ6GPjhSRE0RJfJNNZCzs/1tLBbY8pRoW3LzEcG8Xzcm/yIniGz8de/CJX8SGc6OGLlY3Lc8g8QkNfS5f96GXqtpUdE+2S/K7m199DKg74jPDuRhUWHiEH/iXXFfaqgQwXxVrphYuvTvfc7mTo0LRKMIB4oiZ1irdkbcXLHM2uaSieZW/3kn98+fWCOWsbPsMll4vtGrmkmdTsdNN93EX//6V1555RVuv/32MxTvHcVHH33EypUrefzxxzlw4ADjx49n0aJFFBcXt7v9jh07uP766/n5z3/OwYMHueKKK7jiiitITpaCJ/2BeSPFP+yPaaWYO1G8rqo3stfai3hxYliH22k1CvcsiMfDoGVOfDA3WjP2naIoMOZqFEA5/CEcel/cyM1tgtzTe4Rgnd4DbOWHHWHwgik/s75vtwiAguNFRv6KVyB6cs+t4XQG0UM//MKeva8jfCPwWPwEu8JvREUhtaCHHsD2xvZgKMvuz2FvljjvF47q+Lx3GhlbxLIP5fY2bEI7R/OqXFtmq/eA4ASx7uLzL724hrTiWrQaxf5/75zdwof8i1+eK3BmxVB4QKxEnd92dWej12pa2ryaTRaGBHnypyvGOMdvW6OB+EVi/eR37W+z703xN9XoRKtW0PDu798zUOgldIewUXDVv1sn8+womjjSWrHTP/ropdJ9Z1TUG1v0jFySnW+shjW/FxWTXiFw6Qt2Cf583fWMtlrW7ezA1cJpxEwTFTFVucKL3YnUNpk4VVILiAy90zA1QdoGsZ54aefbStqlx1PH//vf/zr9+S233NLrwXTF888/zx133MFtt90GwKuvvsqaNWt46623ePDBB8/Z/sUXX2Tx4sXcd999APzpT39iw4YNvPTSS7z66qsOG6eke4yP9sfXQ/h/Hjpd2aFy7w9pJVgsKiNCvbvs1Roe4s2Hd0xHUeh+qVD8QtEfWFcCu18Tr+ncxUx95AQhTgLdV9xMvFQE83WlMOEGiJ3TLz3WR0f6UVDVyLH8aqbEurBXKXwsJH/m8oCqv1FZ30xmaR1AS3+fyzA1Q9Z2sd6HcnsbCWE+6LUKlfVGcisaiAl0UQ8miLL74hQoPAwJl7hsGGuPigqBWSOC8Pe0Y0DdUCGsykCUSm/6I2j/LKqE2qAvPChWZLn9OSweE86aIwWE+brzlyvG4OfhhGDeRsIiOPSeyNDXl5+ZiTz5nRC0A1HFFTnBsWPxCoblL0F1fveqALpJYrgvX5JPaqGLJ5ehNUNfelJc9+Tk1hmsSy3HaBEOQKMjXeAPvunJ1taSS18QrY52YsawII7mVrEro4wrJkbZbb89xuAJ0VNFi0vmDz2bpOsjR3OrsKiibSjYmY46mdtEJZJ3mKhMkvSYHgf099xzzxnfG41G6uvrMRgMeHp6Oiygb25uZv/+/Tz00EMtr2k0GhYsWMDOne2Xou3cuZOVK1ee8dqiRYtYvXp1h8dpamqiq
amp5fvqanGDsVgsWPqZV3FbLBYLqqr26zGejUaBWcODWJtcyJcH80gK98Zdf24Z+qaUIlRULhwZ0u3fT1XpfuZP6yb6yzM2oxQcEQ/2TTVCfT1vv3UbvRAH6u7nu/BPZ37vgr9LV+dEUoQ3G1IKSc6rcu15EzoaBaA8A7WxplWV+Txnf3Y5KirDQ7zxddfZ5W/U6+tEzm4UYz14BaOGJPb5fNZpIDHchyN5VRw+XUGUv5PV3NsSNhaFj6DgCKqL/g9qm0xsPVGMisri0WH2+39UVZQtfxVBfUAsBMRBxmbY8AfURU8LASbA0lCFriwVdBos0VNdcr3qz4T5uPHmrZNx12vRazXOvV76xaCEJkFxCmrahladloIjKFv/CoA64UYYsdCuf7eOrxUK+EbZbrJ2OVZ8qBcqKpmlddQ3Gdt9DnAa3uEo7v7QWIlaktqasZdQVd/E5rQKULSsmByNqqrOrbAqPIqSuw80etSlq8R5aMdz/oLYAF77QeVYfjXltY32nVjtKbGzUawBvTrpVqcd9tDpClRUxsf4des6Z6/YQ0n9BgDVphsi70EtdPez7XFAX1FRcc5raWlp/OpXv2rJhDuC0tJSzGYzYWFnliKGhYWRmtq+EmRhYWG72xcWdtwr+fTTT/Pkk0+e83pJSQmNjS7uq+kEi8VCVZUoX9X0w2xwR0wM0/HNIRN7Mkr4xX8r+fm0CBLDWgO6vKomUvMr0SiQ6K922F5hF8LmiS/VgrY6F31JMrqyFHQVp2gcdgmNtWahfjtA6OqcCDMYMRlNpORVkFdQiF7ruvPG3y0QbV0R1Se2Ywyb4LJx9Cd+TMnDZDQxwl9rt/O+t9cJ7+Q1uJmMNIRMob7EPn7Rsb4aDmSZ2HWygClhrnuAV7ThBJqMUHqK8tNpqG5O7hsENpwop66xmSg/N4K1jRQXN3X9pm7glvEd3hlbUTV6qsb9ErNvDD41lRgK9qCuuZ/q2Y9hCk5Cl/MjXiYjZu8hVDVooWHgXOecSYOLjusWOh3v/COYjn5JVdg8NLUF+G1+EE1zA01RM6gdcinY+d7o7GcKH70o595z4jSJoS6s2AF8fGIx1O6lLm0XjYrjtaEGCp8dLqah2UhcsIFo92bHPo+1g8/ONzGYjDTGzqfO5GX3cx4g2kdHVnkjGw5lMW+Ev933310UzwQCTGaU4hNUZBzG4m2/SoTO2HuqCJPRxBCv7j1v2+M6oaktJCBnD6BQETQVi5PPq/5OTU33WpH6oNbSSnx8PM888ww33XRTh8H1QOGhhx46I6tfXV1NTEwMISEh+Pq6oLyom1gsFhRFISQkZEAF9KGh8CdPX/65+RSltU387Yd8lo4J59YZsXgYtKxNz0Kn1zEtLpARQ5zoixwWDvGtZT8GoP/+9dunq3MiJEQlxK+AivpmKlUPRoc6P5CxocRMhvQNBDTlQqjryp77CxaLSlq5OPfnjo4h1E5/m15dJ8zNKKWHQKfHe9wyvO0kfjoryZ2vUirIqDQRHNyx2KXjCUUJiYeKLEJMhRAT79Sjq6rKjo156PQ6rpg89JxJ6F5TkY1y/F3Q6VGn/x9BCdZS+sueQVn/KOTuJWjPs6jLnketTcOs1aIdMdcp4raSHuJ3Bcrxd9HV5eFmykPZ9wJYGiFiDNqlf8JTZ/8KF2c/U4yNKWdHRhklzTrmuvocjJ2KUnIIv8bT+Lp6LP2EhmYzP2RloNXquGnmcMLCnCxaVp6JUnoY9Aa8Zv4MLz/H/F3mj2ri3d05pJSbWOHSv30oytCpkHeA4JoUGOZ4odLS2iaK6y0YDHpmjx7SLZ0Qu1wnsr9C0ekheirBsaN7t49BjLt7967vdgnoQQjl5efn22t35xAcHIxWq6WoqOiM14uKiggPb9+vNzw8vEfbA7i5ueHmdm7fiEaj6feBsqIoA2KcZzM5NohXbvTjP9uzWJdcyLfJRezPruTXF41g68lSFBQuTgwbcL9Xf6Crc2JUpC870stIKahlb
LSTbbLaEjEO0jegFCX3S70BZ5NZVktVgwlPvY5RkX52Pfd7fJ1I3wrGBvAKQQkbY7e/T0K4L+46LTWNJnIrG4kNdmGrRfhYqMhCKU6G4fOceuijuVXkVjTgoddy8Sg7XedMzbD5z8LuMnoKytgVrX83jbsQT1v3AOQfQvn2flQULAooQ6bL62x/xMMf4mbDqc0o6x4Uf1evEFj0NIrBcdlsZz5TJEX6sjOjnJNFta4/B8OtZfbFx4UWT0d6PKrac5HbAcr3JwqpazYT5mNg5vBg5/+NjnwolnFzUQKGOuwwM4eH8N7u0xzOraLJpLZrS+k0hs2HvAMoWdtgouNt3I7l16CgMCLEGz/P7vfP9+k6YTHDyXViPelSFFf/7/dDuvu59viT++qrr874+vLLL3n11Ve56aabmDVrVo8H2l0MBgOTJ09m06ZNLa9ZLBY2bdrEjBkz2n3PjBkzztgeYMOGDR1uL3EdngYdv75wBH+6Ygxhvm4U1zTx+JfHKK9rxttN51rRtkHM6EiR+T3eX5Tui4+f6TBwnnLA6uowNtrPpa0QWMxCMBJg9JV2nWzRazWMsooqHc1zsX2dzaatwP7Wicl5Vfzu48P87bsT7EgvpdFoPuPna5OFVd38kaF9szhry763hHCUu6/wED/776Z3h0VPi/7gphpoqkbVutvNikziAGy9peZm4c6w+BnwasdmdYAyMkxcC1ILa1zrfAEQkmi1syztuM2usRo++Sl8/gswG506PGdjsah8c0Qk7C4ZGej8aqqaQjhlfZZ3sD95TKAHkf7umMwq+7LLHXqsLhk6W0wYFR+HqjyHH+5Qi/+8E6s1T+8WdqnufjDUcTHk+UCPnx6uuOKKM763lVpcdNFFrFq1yl7japeVK1dy6623MmXKFC644AL+/ve/U1dX16J6f8sttxAVFcXTTwtF33vuuYd58+axatUqli1bxocffsi+ffv497//7dBxSnrPhBh//nn9JP67M4s1R8SD7tyEEAw6OWvnCGwqtccLqrFYVNeVPfsPFR7KTTVQlgahSa4ZRz/hQE4lIHx+XUraBmGd4+4rAno7MybKj4M5lRzNq+Ky8U5sqTmb8HFiWXoSmuuFynAfUVWVNUcLeH1bJhaLysmiGraeLEGvVZg0JICZI4KID/Vhh9UiacmYjivHekTufjj8gVife79QJm8PgycseQa+WQmlJzGGjkOnlYre/ZboqcJzu7YILvoDBI9w9YjsyvBQLzQa4XxRXNNEmK8LhTL17hA0AkpOQFEy+LTTBrP9RWGdBiLDmHSZU4foTA7kVJBf2YinXsusOBe05h35WEwuR02G0ESHHkpRFKYPC+LzA3nsyihjTrwL/dC9giD6AhH0Hv0EZv/WYYdSVZXDuZUAjHemq06q1Xs+YRFonegeMgjpcUDvSjXs6667jpKSEh577DEKCwuZMGEC69ata+k5zMnJOaM0YebMmbz//vs8+uijPPzww8THx7N69WrGjJGqpf0ZD4OWX84bzuwRwezPruCqSS60DxnkxAV54WHQ0tBsJrOsjuEh3q4ZiEYjgqrs7cK+7jwO6BuazaRYKyYmDnFhG0Tb7Pz46+0S5J5Nix99bpVrJ5R8woRdTm2RsLCLntyn3TWbLLy69RQbjouWr9nxwYR4u7HjVBlF1Y3szixnd2Zr9icx3Idh9vjfa6xqtahLugzi5nS+vZsPLFuFmvI1db5j8Oj7CCSOQqMVlnHNtcKxYJDhptMyPNiLtOJaUgtrXBvQA4SOEgF98XEYcfGZPzv1PaRvbP3+wDuigmKQBiRfHxbZ+QWjwnDXOzm50lDZGvRNcHzZOcCM4SKg35tVgdFscW2V3PifiID+xFqY/NPuWSf3gvTiWspqm9FrlZbKOYdTVwbZO8T6yKXOOeYgZsClPe+66y6ys7Npampi9+7dTJvW6pm7ZcsW3n777TO2X7FiBSdOnKCpqYnk5GSWLpUnzUBhTJQft86M7ZYwh6R3a
DQKoyLExftYfj8puz/P/eiP5lVhtqiE+boR6efCh9q09VCdJ0rhRl3hkEOMCPHGQ6+ltslEVlmdQ47RbSKsWfrCw33aTXldM498cZQNx4tQFPjpzFjuXzSSn82O4/VbJvPiTyZw/QVDGBrUOkFit+qEXf+CuhLhET7j1917j7svjL8ei5edxPgkjsMreFAG8zZGhvsAcKI/+NHb+ugLk898va4Utj0v1sf/BDyDxETgibXOHZ+TOF1ez4GcShQFLh3rHKX1Mzj2BZgaITi+xWbT0SSE+hDgZaCh2cwRa9baZUROhJCRYGoSn4WdUVWVDceLePgL8dw1LtofN52TdANOrgPVIlq/AuOcc8xBTLcy9Gd7uXfG888/3+vBSCQS5zMqwpf92RUcy6/icpeWPdseoI6eV2JDZ3MgR/TPTxwSIASZXIHFLLJO4LDsPIDO2ke/P7uCo3lV9slS95bwcaLFoA999CcKa/jL2hQq6prxctNy36KRTB7aqv+hKArDQrwZFuLNDdOGkF/ZQHldM2Oi7FDG2lDZmjWc/5Dos5ZIBhCJEb58c6SA1MLu2TQ5lFCr2nZZmgimdG7ivrT1r6I1LDgBpt4B3qGw/R/WLP0S0A2utpVvrK2PU2MDCfdzp7jYiZMtxgZI/kysT7jBac8EGo3CtLhA1iUXsiuj/IxruNNRFBh3HWz6owjox18vWkLsQG2TiVc2p7MtTVjRjo324+6LneTyoqqtk2CJy5xzzEFOtwL6gwcPdmtnLnv4lEgkvcZWXnU8vxpVVV33fxw8ErQGaKgQfdv+Ma4Zh4s5aA3oJ7my3P7kdyI77+EPo69w6KHGRvmxP7uCI7lVLJ/gwvaaM4QZjT0un92cWsw/vk/DZFYZEujJw8uSiPLvPKiO9Pcgsottuk3aBjHu4ITWyTGJZACRaM3QZ5TU0WyyuFY7xyccPAOhvlyU3keMg5SvRPmz1gAXPgxaHSReBoc+EJUxJ9Y4RGvEVdQ2mfg+VbQNuWSyP3WNmDzxjYK4+U499MzhQaxLLmRbWgk/mxXnerX7Pa9DTQGc/NYu59ix/CpWrT9JSU0TGo3CTdOGcPWkaOe1vRUcFs95ek/x+0n6TLcC+s2bNzt6HBKJxEUkhPmg0woxooKqRvsFGD1FZxClZYVHxdd5GNAXVTeSX9mIRoHxMS4QHwLhMnCwTXbewZleWx99cp6L++gDYkX5eWO1EMcL674fbk2jsSWYnxYXyO8uGencB0BVhdSvxXrSpc47rkRiR0J93PD31FNZb+RUSS1JEU7q5W0PRRHXgMxtUHQMPAJg5yviZxfc2VoirDPAxBvhx7/DwXdh5LJBk6XfeLyIRqOFIYGejIv2c677gNkkxPAAxl/ndDvb8dH+RPq7k1/ZyPrjha6dbNZoYdy1QojxyMeQtLzXn4fZovLR3tN8tDcHiwrhfu7ct2gkCWE+dh50F9h0EUZc5LAKwPONAddDL5FI7ItBpyE+VJQ6u76P3tbHfH720dvs6hLDfe1nYdZT0tZDdb54gB213OGHGxbijYdBS32zmYzSWocfr0MUpfX862HZ/bH8akxmlUh/dx5emuT8bE5RMlRkg84dRixw7rElEjuhKAojrYFFan/oow9r0wa25RnRyx05EcZcfeZ2I5eJ0vu60taJtQFOW6u6y8ZHOL9y79T3QpvAI0C0MjgZjUbhyonRAHxxMA+j2XWC4ACMXCJETKvzIXNrr3fzzLcpfLBHBPMXJobyj59MdH4w31gNGVvEeqKcgLYXvQro9+3bx/33389PfvITrrrqqjO+JBLJwKPFj97lAb1NGM/+fuADgYNWH9hJQ/1dMwCz6Uxleyf0YWs1CmOs59+RXBf70fdyQik5T4x7XLS/ayoMUr4Ry+EXgcHL+ceXSOyETRivX/TR26p0sreLSTODl9CnODs7qjPAxJvE+sH3RM/9AGdPVjlF1U14u+mYPzLUuQe3WODQe2J97DUuq3i4KDEUf089Z
bXN/HCyxCVjaEHv0Vpqf/hDUZXVQ4prGtmVUY5Ggd9dksDKhQmuaSVI3wjmZggcBiGOtSE8n+hxQP/hhx8yc+ZMUlJS+OKLLzAajRw7dozvv/8ePz8XlYhKJJI+YeujP5bv4oDK9gBVlSt66c8jTGYLh2wBvav659O+E316HgEw6nKnHdZWdu/ygL5F6f6IeKjsJrZxj7WHuF1PaaqBDGtbnBQXkgxwEsPFvehEfwjog0eCpk2l1Kx72vekB2G75R0G9WVw/CvnjM+B2KzqLhkdhrveyUHf6d1QkSX6q51QJdYRBp2mpdT+8wN5WCxObDloj9FXCv2GklTRg95DjlrvUyNCfZw/SWOjtgT2vy3WE5edt+LHjqDHAf1TTz3FCy+8wNdff43BYODFF18kNTWVa6+9liFDhjhijBKJxMEkRfiiKFBQ1UhFXbPrBuLu22rLdLZd0CDnRFENDc1mfNx1DHeF2rvZ2KpsP+EGp6qkj41urRAxu/KhKShelK031UBlVrfeUtNobLHcc0lAn75RZAQDYnvU9y+R9Efiw7zRKFBW20xprYsz3TZdF4C4uRB/ScfbavUw6Waxfug9MDY6fnwOIrusjiO5VWgUWDbOBVZ1aevFMnGZKDN3IUvGhONh0JJTXs/erHKXjgXPQBi5WKwf/qDj7fIOCBG903vOmJi2TTzbJtCdjsUMm56ExiphQ5jkvKTB+UCPA/pTp06xbJnIAhgMBurq6lAUhXvvvZd///vfdh+gRCJxPN5uOoYGiVJd1/fRn59+9AdyKgGYEOOisu2TbbLzTr7RxgV54e2mo8FoJr3YhX30Wl1rUNzNPnrhDgHRAR4EeDm5NFRVW8vtky6T2Q7JgMddr225F/WLLP30/4OxK2DufV3/fyUsAZ8IUV2WMnCz9Lbs/PRhQYT62McirduYjSJDDzD8Quceux283HQsGRMOwGcHcl08GmDsteI8zNkF5Zln/qwqD757BL65Vwg0rr0P3l8Bu/+NWpHFUWtr2FhXBfR73xDPdQYvWPDkoBGP7C/0OKAPCAigpkZcZKOiokhOFlm0yspK6uvr7Ts6iUTiNEbb7OsK+kkfc/Z2yNgqbIPOAw628Z93Om17FifcaDef2+6i0SiMiRLn35HcSqce+xxsE0qp30BNUZeb2x6S7OIl31NKTkBZuijD7Cx7KJEMIBIjRFY2paAfCOOFj4GZd4nqsa7Q6s7K0jc4dmwOoLrRyOYTol/8MldY1RUcgeY6MbEckuT847fD8glR6LQKKQU1rtcZ8o+B2Dli/chHYtlcD7tfg09uhawfQdHA0FmiuqGuFA69h+nDW7ix4GlmGbczKtAFE7/ZO+DQ+2J93gPg50LXgEFKtwN6W+A+d+5cNmzYAMCKFSu45557uOOOO7j++uu5+OKLHTNKiUTicEZF2ProXXzDipwgbkhVubDhMXjnSvjoJtj6HJxcDzWFrh2fA6huNLZkpicO8Xf+AAoOCfVcg5fI9LoAW0BsE5hzGSMuFmX3pWnwyU8hdW2nAkQtWQ9XBPSp1ux83JzuBRwSyQDA5kffLzL0PSV+EfhGQkMl7Pl3j7Q4+gPrjxXRbLIQF+zVMsnvVLJ/FMuhM51uVdcRgV4GLrL2nPeLLP3468UybQMc+QQ+vEEEy2YjRE+Ba96CxU/BTZ/Dwj/C0FnUGy0MMWVzY/NnuH95p1CadxY1RbD5KbE+5ioYNs95xz6P6PZ/y7hx45g2bRpjx45lxYoVADzyyCOsXLmSoqIirr76at58802HDVQikTgWmzBeVmkd9c0m1w3EOxSWPifEcGxev5WnRfCy+S/w/nXiJjaIOJRTiarCkCBPgr3dnD+Ak+vEcvhFTs/O2xgX7Q/A8YJqTK60CPIfAle/ISyrjPWw9VlY9xDUlZ2zaU2jkcxS0T/fYYa+vhy+/zPk7LbvOJvrIX2TWJfWP5JBxEirMN6pklrX24X1FK0OpvxcrCd/Dt/8FmqLXTqknvB9q
qhKumx8pPOt6lQVsneK9aGznHvsLrhqcjSKAnsyy8kpc3E1ctgoIeBqMcHOl0SLh180LHoKlv6t9blJZxDB8+KneG/483zpcQUa72BhB3j8S+eM1WwUffNNNULRftqvnHPc85BuB/Rbt25l9OjRPP300yQlJXHrrbeyfft2HnzwQb766itWrVpFQICLlJklEkmfCfZ2I8zXDYvaD0odo6fAnJWw4m249Wtxoxr/EyGkApD8Wa9sW/orB6398y5Rt2+uF60NAAmLnX98K0MDPfFx19FotLi+SsQ/Bi7/J0z7pRC7ytkpyhnTN55x3tn656P8PQjsqH/+xxdEJmX9o1B03H5jzNgiJhz8ooU3tkQySIj0c8fbTYfRrLZMmA0o4hfAhY8IlfaCw/Dpz1qvsf2YyvpmTpc3oCgwfVig8wdQniF0XLQGiJrs/ON3QpS/BzOGBQH9JEs/8WbRS6/3hOm/gmv+A7Gz2tV5UFWVPcWw1e1CGifdLl5M/gxMThBA3vNvKDomyv8XPCH75h1ItwP6OXPm8NZbb1FQUMA///lPsrKymDdvHgkJCTz77LMUFg6+MliJ5HxjVH/xo2+Lu6+4UU3/FVz+kriB1RQIX+BBgKqqHGjpn/d3/gAytoCpUQSGLlRJ12gUJsT4A/Dk18d4e3smdU0urBTRaGDC9XDV60LpuqkGNv0JNj4uJkGga5GhnF2Q+YNYNzfDdw93qy+/W9jK7aX1j2SQoShK//Kj7w0Jl4hKn9Akce3Y8Bj88Fy/7qu3Xc9ig7zwcdc7fwDZO8QyeorLKsU64+rJ0QBsOVniegeGmAtEEH/DhyLZ0UmgnF/VSFltMzqtQvSUZcJesaGitTLPUWRugyMfi/X5D4KvCxwTziN63KDi5eXFbbfdxtatWzl58iQrVqzg5ZdfZsiQIVx+ubQgkEgGMuOsZcPb08tQ+2MGXO8u7IOg1dpmgJNZWkd5XTN6reKankXbTX3kEpcHhrfNimNMlB9Gs8pnB/L4xTv7WXOkoNdWdmaL2vcHr8A4WP4KTLkNNFqRadv6LKhq5/3zxkb48e9ifdRyCBouHqK+e6hlQqDXlGeIrIdG69KqConEUSRZhfFSXV0t1hf8osQk9IQbxbU15Rv4/E4oTXf1yNqlVeDTRXoc2dvFcuhM1xy/CxLCfBgT5YfForL6YJ6rhyPuTe5da7cctQrNJob74GZwg3HXih8c+chxGg+1xbDlGbE+7lqIne2Y40ha6JPixIgRI3j44Yd59NFH8fHxYc2aNfYal0QicQGzRgTjodeSV9lAcl4/fZCyqXmf2iz6swY4646J6qapsYG46bTOPXhVnigJVRQh5uRiQnzceOrKMTy6LIlIf3eqGoy8uvUUd71/gD2Z5T2eZHp7Rxa3/Wcvm1L6mBXX6mDyT+HSF0Cjg4wtNBz8uPP++YPviEoS71BRur/oaaHcXHZK9NR39SBVeRoKk0V272xsVnVDZwlvYolkkGHroz9eUE1x9cD1dEerg2l3wrJV4BkElTmw+lf9Mqg/Zr3nu8Sxo64MilPE+pD+GdADXDNZqLOvP1ZETePAeP6w+c+PjfIXLyQuEyXwVbmQ9YNjDpr8OTTXir75C37hmGNIzqDXAf0PP/zAT3/6U8LDw7nvvvu46qqr2L59uz3HJpFInIyHQcu8kSEAfJtc4OLRdEDkRPFg1FQDp/e4ejR9oqHZzJZUYRG0ZKwLytHSvhPLqCngHeL847eDoihMGxbEyzdM4hfzhuHjriO3ooE/fXOcZ75NxdLNbL3ForI5VYhRvbEtk6p6Ozx8RYwXrR+AeccrDDFmtd8/X5kDhz8U6zN/AwZP8AkTWhBag8hE7X61/WNUZIny3I9ugi9/DW9fCu9cBd+shO3/gONftVanSDE8ySAlIcwbjUahrLaZn/93H3d/cJB3d2WTXlzTP6vHuiJqslAfjxgv2m9slmP9hMr6ZnLKReWQSwL6HGu5fWgSeAU5//jdZNKQA
GKDvWgwmnnki+T+1Z7YDmqbSrJxttYwvQeMvlKsH/rA/npEpmY4sVasT7pFTGpJHE6PAvr8/HyeeuopEhISmD9/Punp6fzjH/8gPz+f119/nenTpztqnBKJxEksHhMOwI5TZVTWO0E0padoNMJaDAZ82f2WE8U0GM1E+rszvqM+bEdhscBJa0DfD8u2dVoNl46L5PVbpnDVpCi0GoUdp8pIzu+erV1KYTVVDSKIr20y8Z8dmfYZ2JirYdh8GpqauLX+P0wOP+s2qqpCCM9igiHTWz2DQagTz39QrB/5CFLbVLVVF4gSxU9uE2X9igJe1kmW+jLI2y+EjLatEpNZ3mEQPdU+v5NE0s/wNOj43cIExkT5olFEa9JHe09z70eHue3tvbyyJZ2igZa59/CH6f8n1jM2CweMfoKtIm9okCe+ruyf72fq9mejKAq/mDsMLzctmaV1PPDZEZ5ff4Lyuo6flU6V1PLy5nRueH0X//7hlBNHC7kVDVTWG9FrFRLCfFp/MOYqMblckgr5B+170Kxt0Fgl7l9DZFzoLLo9bbJkyRI2btxIcHAwt9xyCz/72c8YOXKkI8cmkUhcwPAQb+LDvEkrqmVTSnGLEEy/Iv4SIbaSvQOaasHN29Uj6jGqqrI2WZTbLx0b4XyLoIJDUFMovOf7cX+bl5uO22bFUd1gYmNKEdvTy1os7jpjV4Z4WB4W4kVGSR2bUopZkBTW9+yTosC8+8k/vB9/SwFLS94Eyz9aPZPTN0HeAfGwNOuec3UJRlwsMvj73xbBud4DCo5AytdiEgDE32PqzyFwGDTXQUW2yNzbvmqLYNLN/canWSJxBHMTQpibEEJVg5H92eXszijnQE4FZbXNfHu0kAPZlfzz+ol4GJzcqtQXQhNFGXJJqshiTrzJ1SMCaJkobVcPxNEYGyF3n1jvp/3zbRkT5cdrN03hnV1ZrD9exOYTJezKKOeGaUO4dFwEOq2GJpOZbSdL+Ta5kJNFrW1Ta44Wct3UIfh5OGfSxJadT4rwxaBrc7/wCIDEpXBsNRz+AKIm2e+gKV+JZeIyofMicQrdfhrQ6/V8+umn5Obm8uyzz8pgXiIZxCwZI8q/v00u7HaJs1MJGgEBsaJ0MdNBPWAOJrWwhqzSOgw6DRclhjp/AC3e8xf2S0Xhs5kdL8owd5wq7fKcVFWV3RnCN/7aKTEsGh0GwL+2nLKLx32NxcAr+lswKnrCqo6IfnkQmfOdL4n1SbeAb2T7O5j8UxHYW8yw8Uk49oUI5qMmw5WvwqK/iGAexIRL2Cjx8DXj/2DpX+Ha/8KIBX3+PSSSgYCfh56LEsN4aGkS790+nccvG0WIjxtF1Y28td1OlTfOxFbufPwrx4mS9ZCjuS4M6PP2i3u5T0Trda+f4+ep566L4vnbivHEh3rTYDTz5o+Z3P3hQV7deopb39rDi5vSOFlUg1ajMCc+mOgADywWlR/TSp02ziOd/V3HXQeKRrQu2kvToTIH8g+J/cqWMKfS7YD+q6++Yvny5Wi1crZFIhnszIkPxsOgpai6kcNWhdR+haK0BjTpG1w7ll6y9qjQKJgbH+J8iyBjfetESMIS5x67l4yL9sfLTUtlvZHjXShfny5voKCqEb1WYdKQAG6dGYuvh46c8nq+PJTf57Ecz68mXxPJ98E3o9MosP8/kLsf9r4plOz9Y8TDUkcoCsx7AMLGiO/DRsOlz4uv0KQ+j08iGawYdBqmxAby2wXxAKxLLmRPZv8pXe8Wwy8Sdqy1Ra294y6kqt7Y0j8/2hUBfVt1+wFmwZkQ5sPfVoznNxeNwNdDx+nyBtYcKaCuyUyYrxu3zBjK27dN5f7FiS3tjFtOFDtlbKqqktyZtapvJAybL9YPf2Cfg6Z8LZZDZvQbXZ7zBVmvJ5FIzsFdr23JGq+zloX3O2wBff5BqC1x7Vh6SFW9kR/TxSz90rHhzh9Axlbhh+xi7/meoNdqmD5MZOltn11H7MoU2flx0f54GLT4u
Ou5bWYcAB/syaG4pm+9t7YyRmXkIpGFUFXhT3/8S7HB7Hs79QUGQOcmVPOveQuWvyyy8xKJpFuMi/Zn+QRRAfPP79PsI3rpLHQGGLlMrB/7wrVjAY5Zy+2HBHk6rRS8BYtlwPTPd4RGo3DJ6HBevWkyV0+K4sKRITxx+Sj+ffMUVkyJwd9T3AvmxIegUUR1XmGV4/UfcsrrqWow4qbTnNk/35YJN4jlqe+FjktfMDXDiW/FetJlfduXpMfIgF4ikbTLEuts8q6MMsr66uXtCHwjIGKcCKbSN7p6ND1iY0oRJrPKiFBv4ju60ToQJa2NGN4AyojMGhEMCMHGzsrud1nL7acPa7V0uzgplNGRvjSZLLz+Q0afxpHc1n9+1t2iBaSpBlSLmGjqbnCuMwh/+gH0N5BI+gu3zIhlSKAnlfVGXtmSPrDU70ctF//3ufuERaULOZLnwnL7khRR1WTwEvfzAYyPu56fzopj5SUjmTw0EI3mzOt6oJehRf/FGVn6tv3zem0H4V5wPERPEfeuox/37YCZW1sFW2Om9W1fkh4jA3qJRNIuQ4O8SIrwwaKKALRfMmKhWA6gsnuLReXbNmJ4zkZTW9TqPZ/geu/5njAhxh9Pg5aKumZSCtsvuy+rbSKtqBaAC+Ja7Y8UReH/5o9Ao1HYlVHeYZlubZOJvVnlHWb8aptMZLT1n9e5wcI/Cl9fd99WFWuJROJQDDoNKy9JQGN1wNjspFJmu+AbIcqSAY6vdulQjroyoLdl52OmgdYF6vpO5sJEUYa++USxwyegWnQRunLQGW/N0qeugYbK3h/wDDE8GV46G/mJSySSDrGJ4313rKh/iuMNmw8aHZSdEl8DgIOnKyiqbsTLTcuc+GCnH98tZ4tYiZoM3i4Q4+sDeq2Gaday++0dlN3vzRKBekKYzzn+8EOCPLnCWqb77x9O0Wg0A5BX2cDqg3k89PlRbnx9F3/8+jh3f3iQtDbqxDaO5VWhqpzpP+8XBde9A9f+r197KEskg43hId7ceMEQAF7dmtHndhqnMuoKsTyxTrRAuYCqBiM5Zdb++Uhf5w8g60exHKDl9j1lxrBgDDoN+ZWNpBfXOuw4FovaIog3rquAPmoSBCeAqQn2vgHmXrSvVGQJtxZFAyOX9vz9kj4jA3qJRNIhM0cE4e2mo6SmiQM5Fa4ezrm4+7b6nA6Qsvu1R0V2/uLEMNz1ThYZVS24ZW8W6wNEDO9sZg23BfTtl93b7Oraltu35foLhlgVspv44zfH+eU7+/nlO/t588dMkvOqsKjgoddSXtfMA58d4YeTZ+ozHO1IZMgjQHxJJBKncvXkaEaG+9DQbObvG9P65+Rze0RPFTomzbWQ5poqs2PW69mQQM+WXm+nUZUnAkFFAzEXOPfYLsLDoG25NzmyoiSrrI7aJhMeei0jQrqw9VWUVvvElK/hs58L69WekPqNWA6dKcXwXIQM6CUSSYe46bRcnCSyuN/2V3G8+EvEMm1Dv7EA6ojimkb2WTPINsVbu6OqcORj2P4P2PkK7H4N9rwO+/4DO19GW18svM/7sfd8Z0wcEoCHQQTcJ87KoDc0m1tcGWwCemfjrtdy51xhjXQ0t4q8ygY0GoXxMX7cMXcYr98yhbd/NpXJQwMwmlWe++4E7+7KbgkSkl1ZniqRSM5Bq1FYuTABN52Go7lVfH2k704WTkGjEb30IMruXaABYJugHOPKcvuI8WJy/jzhwpHimeqHk6V2sVFtD9vfdVSkL7qO+ufbMmwezH8IPPyhIhu+uRc2/RHqumGxZ25GSVsv1pMu7/2gJX1C5+oBSCSS/s2i0eF8eSiffVnllNQ0EeLj5uohncmQGWDwhroSKDwMkRNdPaIO+S65EIsqsrsxgZ6OOUj+Qdj5crs/skn0qMMuQhkA3vPtYdBpmB4XyOYTJWxPLyUpovVBcH92BSazSqS/O9EBHh3uY/qwIG6ePpT8qgYmDw1g0pAAvNzOvB0+duko3t6Rx
RcH8/ho72lOl9dzx9xhZ/bPSySSfkGkvwe3z4nj5c2n+O+OLNx0Gi5OCutYDKy/kLBYlDmXnYLCo04XhkvOF1okLpmgzBnY6va9ZUKMP34eeqoajBzOrWTy0ParyUBMUhstFnx7aG3bqf98R4xcLDLs+94Sji3pmyB7J0z+KYy5GrTth4yG3J1CDM8nXFSdSFxCP7/SSSQSVxMT6MmYKD8sKqw/3g+z9DqDmF0Gl5Utdgej2cL640JccOkYB4rhFR0Ty6ARMP56GLsCRl8Joy5HHbmUxmFLxA16ADPTqnb/Y3rpGeW1u612ddPiglC6UI6/dmoMv12QwJz4kHOCeRBWRD+bHcc9F8ejtYpu3fvRoXP75yUSSb9g0ehwpsYGYjSrvLz5FHf+bx9rjxbQbOrHlVvuvq0WrE62sKtuNJLVMkHp5Ax5fbkQZwURRJ5H6LSaFv2cLSc6ttwtrW3i1+8f4I7/7qOouvvaEBZLq/98l/3zZ+PuC7N/C1e+JixtjfWw6xVRhn/8K2g6t+/fPdPqmpN4qRTDcyHyk5dIJF1is7Bbf6wIc3/sT4y3qt1nbHG5BVBH7Mooo7LeiL+nvsP+brtQkiqWCYth+i9h5l3iBj3ndzD3Puom3gFezhfjsyeThgTgoddSVtvMyWJRdm8yW1oE8Toqt+8NC0aF8dSVY/Hz0FNpVb7vUjVYIpE4HUVReHBJInfMHUaAl4HS2mb+teUUd76zj2+O5PffwH70lWKZ+YMIdJ2ELeiLCfRwbv+8xQJbngGLGUIShajoecZ8a9n9zlNlNDSbz/l5Q7OZP359nJKaJuqbzfxvZ1a3951RWkd9sxkPg5ZhXfXPd0RIAlz+Esx7ANz9hNbBtlXwzpWw6U+Qu1/8HSuy0JelgqKVYnguRgb0EomkS6YPC8LPQ095XTP7s/uhOF74eOF92lwHH90En/4cDrwDlTmuHhkAjUYzn+3PBUQWqVs9bb2lOEUsQ0Y67hguxqDTMDVOCND9mCZ6/I7lV1PXZMbPQ09iuI9djzcq0pdV145naJBok5gW58AJGYlE0msMOg2Xj4/kjVumcOfcYQR6GSirbea1rRnc8b99bE7th9Z2wfEQNgYsJiFK5iSO5Ylye6e3Dx39GE7vBq0B5j/g3GP3ExLCvInwc6fJZGFXRtkZPzNbVJ5dl0pmaR3ebjoURfTbnyg813WlPY7mVQIwJtIPrabzSrVO0WggcSlc9y5M/xUExIK5WQgQr1kJH1yHsvVZse3QmdLhxcXIgF4ikXSJQadhboLI6u441Q2RFGej0cAlfxZKuRotlKWLvsSPboZPboP9/4W6sq734wAajWb++M1xTpXU4WHQOk4MD6C2BOrLhGpwcLzjjtMPmDXCdj6Woapqy0PR1NhANH15iOmAMF93XrhuAq/cOInJQ6WavUTSnzHoNFw2PpLXb5nCL+YNI8jbQHldM89vOOlQu7BeM/oKsUz52mnieC2CeJFODOj/v737Do+iWh84/t2SSjopm4SEhJpQpVcFBAVULCheuEhRighcRMTCRdpVpPzEgnrBcgUU7CJiAUFK6AECCS0klIRQEhJI78nu/P4YsrIQIECSTXk/z5OH3Z0zM2c2h+y+c855z8VjsPdT9XHXf4FHg8o7dxWi0WjMvfRbr8p2rygKn247TcSZNGx0GmY/2pz7Q9Ryn28/fcu16zPzi/glUk0K2TqgnH6v9i7QejAMWq4OxW/2GNg5Q3ayeUSgEjqgfM4l7li1CehTU1MZOnQoLi4uuLm5MWrUKLKzb/5HuWfPnmg0GoufcePGVVKNhahZujRQA6jw06lVc9i9VxN46P9g2M/qMLGATmpwn3paTfKy9l9gLK7UKhUUG3nr92McPpeBg42O/zzWHE+nCkwqmHKld94jWM1kX4O1q++OvY2WlKwCYi9mEx538+XqyoONTkuAh+Mt5+cLIaoGW72WR1r58emw9nS9suTlV7cxfLnSBPcAvb2a3DUtr
sJPl5VfRPxldf58pSXEK8hSM6ebjNCwF9TyILBnU3V5t8iz6aTlFAKwNuoCfxxORKOBqQ82panBmWc618dOr+V4UhY7T964Y0JRFD7efJLL2YX4uznQt3k5dx5oNOAdAvdOgWdWQ++ZENiV/KA+4N+ufM8lblu1CeiHDh3K0aNH2bhxI7/99hvbtm1j7Nixt9xvzJgxJCYmmn8WLlxYCbUVouZp5ueCi4Oe7IJi89y7KsneVR0m9tBCGLbmyhwwF8g8D2d2Vlo1CotNzP09mqizajA/57HmhBgqOPFQSoz6r1doxZ6nCrDT6+gQpAbvK/ecISWrAFu9ltYBbtatmBCiyrHVaxnZLQitVsOBhHQOXVnessrQ2/6d4f521wC/A0cvZKIoUM/dAffKSPCpKLDtHchKBGdfuHeqGiDWYn5uDjQ1OGNSYNuJFHadusT/dqg3c57tFmRO/urpZMfAtvUAWL4r7oa5IDYcu8iuU5fRaTVM7dsUextdxVVebwuNeqP0nUtOu/HqqEBhVdXiNxAdHc369ev5/PPP6dSpE927d+fDDz/k22+/5cKFm6836ujoiMFgMP+4uNSetS6FKE86rYZOwWoPx+7T1hm+ftvsXdTgPuRKT0AlzU8sLDbx9h/RHExIx95Gy6xHm1ksr1ZhasH8+at1v/KFJ/JsOgBtA90q9kuMEKLa8nV1oN+VXssVu87ccvhypfNrq/5bCQH9kcpefz76VzVprVYHfWaB3R0ma6thSnrpf426wKINsSgK9G9p4PF7LBMFDmzrj3sdWy5mFvBr1PVxz7m0XD7bdhqAYZ3r08hb3t/aplqsQ797927c3Nxo3769+bU+ffqg1WoJDw/niSeeuOG+q1atYuXKlRgMBgYMGMCMGTNwdLzx+s8FBQUUFBSYn2dmqklDTCYTJlMVzZCKWj9FUap0HUXlqog20SnYnQ3Hkth96hKjr/R2VAtNH0IT9TWc24eSfg5c/CrsVIXFJuavO87+hDTs9DpmPBxKqMG54v9vKiY0JfPZvELUDLTXqGl/J9oEuGKr11JQrGYJ7hjkXmOurTLVtHYh7l5NbRNPt/fnr+iLxFzMZPepS+W6IsZd870HDcCFgyjFRWrwW0GizqajoNDCr+yfTXfcJlJPo9m1GAClw2jwbFrq51Nt1K1hXT7bdpqkK8vStQ90Z0z3YBRFsbjhZKvTMKxTAB9sPsl3+xK4P8QLVwd1bfoio4l3/owhv9hIK39XHmvtW2n/b2vq34mqpKzvbbUI6JOSkvD29rZ4Ta/X4+HhQVLSjdfF/uc//0n9+vXx8/Pj0KFDvPbaa8TExLB69eob7jNv3jzmzJlz3espKSnk55d9HcjKZjKZyMjIQFEUtLIOpKBi2oSfnQk9Ji6m5xJ+PIGGntVlnrYeZ4/m2F48SN6+b8htOaxCzlJsVPhoxzkiz2djo9Mwsasf3jYFJCdXfGZlbdZ53HPTUbS2pBY5QinnrIl/J5p52bIvIQutBoLqGCvlva5pamK7EHenJreJXg2c+O3oZT4PiyWoTgO0VWXot+KKu8YWbV4G6bHhGD0aVchpcgqMnLyYgaKAj21Rmf9m3lGbKM7HbfO/0RXkUujTliyfnqV+NtVmoV52RJ7PJtDdjpFtPbh8qfS16Zt7gJ+TjoS0Aj7fEs2w9upok+8OJnP8Qjp1bHU8c487l26wf0WoyX8nqoqsrLKtbmDVgP71119nwYIFNy0THR19x8e/eo59y5Yt8fX1pXfv3pw6dYqGDRuWus+0adOYMmWK+XlmZiYBAQF4eXlV6eH6JpMJjUaDl5eX/KcSQMW1iS6NMth+8hIx6Qpdmnnfeoeqos3TaP46gnPiTpx6TgKdTbmfYvWB8xxJzsfR3pYZD4dW7nzujEg0ehvwaYa3ofQRCDXx78TDbfQcTIymTYAbDQMrbuRFTVYT24W4OzW5TYy4z4OdZyJIzi3maCr0Dq06n2OawI5wZgd18+PBu2uFnGNvX
Co6vR4/Vwea1C/738zbbhOKAtsWosm7CC4GdP3/g4OD251XvIb614PObIpO5uFWvrg73jyfwQu97Zjxy1G2x2fzdGcnUnML+etkJnobPVP6hhASVLkjTmry34mqwt7evkzlrBrQv/zyy4wcOfKmZRo0aIDBYLjuDmJxcTGpqakYDGXP4tipUycATp48ecOA3s7ODju767NQa7XaKt9YNRpNtainqDwV0Sa6NvJkx8nL7DmdyrPdgqtPxu+gblDHE3IuoUnYCQ3vL/dT7D+ThgYNz3YLpk39Sl6r/NKVhHjeoWhu8vuuaX8nOjf0ZP7AVgS4O9aYa7KGmtYuxN2rqW3C2cGWQe0DWLYznm/2naNHUx9s9VXkGuu1hTM70CQehLbPlPvhFUXhl6gLaNDQJtD9tn+3t9UmIr+B2PVqwrT730BTp5I/E6sJf/c6DO8aXKaybQI96BRcl71xqXy6PY7z6XkA9GtuoFtjr4qs5g3V1L8TVUVZ31erBvReXl54ed26AXbp0oX09HQiIiJo105dGmHz5s2YTCZzkF4WkZGRAPj6+t5RfYUQ0L6+BzY6DYkZ+SSk5lK/bh1rV6lsdHpo+hAc+BKOrS33gD6v0Eh0kjo0qmOwFb641KIM99eqtMROQoga4eFWvqyNukBKVgHrjiTy2DVJyKymJDFe0mEoLlSziZejHScvceR8JrZ6LU+2rcBrPh0G4UvVx10ngn/bijtXLfNstyD2n0nj0Dk1saG/mwOj721g5VoJa6sWt1NCQ0Pp168fY8aMYe/evezcuZOJEycyePBg/PzU4ULnz58nJCSEvXv3AnDq1CnefPNNIiIiiI+PZ+3atQwfPpz77ruPVq1aWfNyhKjWHGx13BPgDsCuU9Uk232JkEfU3oILByH9bLke+uiFDEwmBR8XO3xcyjZEqtwYi+DSCfVxLclwL4QQd8pOr2NIx0AAvt9/ltzCYivX6Ar3IHBwh+ICSD5WrofOLzLyxZVl0Z5sWw/vivqcSj4OW+aqj5s/AS2erJjz1FL13B15qIU6OrlSlqgT1UK1COhBzVYfEhJC7969eeihh+jevTuffvqpeXtRURExMTHk5uYCYGtry19//cWDDz5ISEgIL7/8Mk8++SS//lo5y1YJUZN1aXhl+brqFtA7+0BgZ/VxOS9hV7J02j3WWAc9NQ6MhWDnDK71Kv/8QghRzfQJ9cHfzYHMvGLWHLz5EsiVRqP5uzf7QvkuX/fTgXNcyi7Ey9mOgRXVO591Eda/rt6QCOwMXf9VMeep5YZ2rk+vEG+mPNBElqgTQDXJcg/g4eHB119/fcPtQUFBFks8BAQEEBYWVhlVE6LW6RjsgVYDcZdySMrIx+BayT3SdyN0AJzZBTF/QIfR5TakMerK8LdKTYRXIqVk/fkQ9QuhEEKIm9JpNQzvUp95646z5uB5HmppwO0WSckqhX87OLlJXY++/XPlcsiLmfn8FHEOgFHdgyumR7cwB9a/Bnlp4NEAes+q0KX3ajMnOz1THmhi7WqIKqTa9NALIaoOVwcbml+Zt7z79CUr1+Y2BXQGJ28oyIK48rnpl55bSPylHABa+buVyzFvS7K6/rwMtxdCiLLr0rAujb2dyCsysvrAeWtXR1Uyjz75GBTllcshv9gRR5FRoYW/K10bVkAmdJMR/pqjjhZz9IB+88HWsfzPI4QolQT0Qog70rW6DrvXatW59ADRa29cLjcV0hPKdMiS5DTBnnVwdSz/5fBuqaSH3rtZ5Z9bCCGqKY1Gw+Nt1OHnh89nWLk2V7j4grOvGiQnHrrrwx06l86uU5fRauD5+xqU/8o0igK7FsPZcNDbQd956vQ2IUSlkYBeCHFHOjdQA/rjSVmk5hRauTa3KeRhNTle4iG1R+Fq+RmwZyl8/Q/4fkSZvlCVzJ+3ynD7wlxIO6M+9gqp/PMLIUQ11sBLXanlbGouJpNyi9KVxK+N+u+Fg3d1GKNJ4ZNtpwHo39KXIM8KWJXm6
Go4ukad7tVrOnjL55AQlU0CeiHEHfF0sqOxjxOKAuGnq1kvfR1PdV16+Ds5XmEuRCyHrwdD1DdqkjnFpL52E4qiEGVOiGeF5dMuxar1rOMFdSpgKKUQQtRgvq4O6HUaCopNpGQXWLs6Kn91iea7TYy37kgiCZdzcbLTM7RTYDlU7Bpn98Guj9THHZ+HBj3K/xxCiFuSgF4Icce6NvQEYHd1C+gBQh9V/z2xASK/gW8Gw/5lUJQLdRtBj9fUhD7nI+Di0RseJikzn+SsAnRaDc39rBDQp1yZPy+9IkIIcdt0Wg3+bg4AJKTmWrk2V5T00F+KhfzMOzpERl4Rq/ao08aGdamPs305TwdLT4C/Zqs3lJv0g9aDy/f4Qogyk4BeCHHHSpavizqXQVZ+kZVrc5v826vzFAuyIHypOtTetR70mQUDP4OQh6BxX7Xsga9ueJiS3vlQX2frrAVbEtB7hVb+uYUQogYI9FATuCVcriIBfZ264F5fnZ+eGHlHh1i55wzZBcUEe9ahX3ND2XZKjlZvcN/qJkJ+JqyfBoXZ4NMC7n1ZVlgRwookoBdC3DF/NwcCPRwxmRT2x6dZuzq3R6uFVoPUx07eao/8019Cw/vVbQBthqpz7RN2Q0psqYeJPHtlubp6bpVQ6VKYM9xLD70QQtyJgCsB/dm0KhLQw9/Z7s/f/rD7mKQsNhxNAmDsfQ3QassQbJtMsOEN9Qb3j8/B2b03KGdUe+YzzqmfnQ++WW7Lvwoh7owE9EKIu1LSS18th903HwiDlsE/Vqk98teumetaDxr1Vh8fvL6X3mRSOHQuHbBSQry8NMhKVB/LknVCCHFH6pf00FeVIfcA/lcC+ttMjFdYbOKDTbGYFOjV1IsW/mWcCnY+AnKuLEObkwJ/vALb371+6bw9H6tl9fZqRntHj9uqnxCi/OmtXYGawmg0UlRkvSHHJpOJoqIi8vPz0WrlPk1tZWNjg05XucO+uzasy3f7zhIel0rU2XTrBLZ3SqMBjwY3L3PPUDj5F8Rtg9TTFuXjLueQlV+Mg42Oxt5O5V+/7BSI+UOdn1jaMkApMeq/bgFgVwHnF0KIWqCkh/5cah6KopT/0m53wvce9TMqLR5yLpc56enX4Wc4m5qHm6MNY+67xefb1U78qf7b9CGwcYAjP8GxX+Dcfug1DbybY3d6A5qjP6vlev0bPBvd1iUJISqGBPR3SVEUkpKSSE9Pt3o9TCYTWVlZVeODSFiNm5sbBkMZ58uVg2DPOtzb2JPtJy4x949oFj7ZqmKWxrEWj2AIvg9Oh8HBldB7pnlTZEI6AC3ruaLXlfONtKI8+GOq+mXu+O/w6GJwvub3mnxl/XmZPy+EEHfM19UenVZDXpGRlOwCvJ3trV0lsHeBuo3VxHgXDkLjPrfcJfZiFj8fPA/AhF6Nyp4IrzBHvWkN0OxR8A6F+t0gbD5knoe1k6Dxg9SJ/gN0WugwSjLaC1GFSEB/l0qCeW9vbxwdHa0WTCuKQnFxMXq9XgL6WkpRFHJzc0lOTgbAx6eUHt0KoNFomNynCak5hRy9kMnsX4/yzqDWeDrZVcr5K0Wb4WpAf2oLtHtW7REHoq4Mt29Vr5yz2ysKbHtHDeYBsi/Cb1NgwAfg5PV3OclwL4QQd02v0+Lv5kBCai5nU3OrRkAP6vJ1l2LV5etuEdAXFpv44K8TmBTo0cSLzg1uYxnT02FQXKB+tpXkY6nXDp5aBrs+hNj1aGLXo1GM0KA3tBl2FxclhChvEtDfBaPRaA7m69a17vrPEtALAAcHdemd5ORkPD09K+28tnot0x8O5bWfDnE2NY9Za4+y8MlW1LGrIX9iPBupvRVndqq99L2mUVhs4ugFNRNwmwD38j3fsTXqMH+NFu5/A/Z+pvaS/P4SDFiszllUFMlwL4QQ5STAw5GE1FwSUnNpV7+KzAv3awNR35QpMd63+xJISM3FzdGGsT1uY6g9/D3cvkk/y2z1dk7qcPug7rDzfYpsP
dH1eE2+ZwpRxchk67tQMmfe0dHRyjUR4m8l7bGyczo429swe0Bz3BxtSLicy9w/oikymiq1DhWq7ZUeiRMbIDORmKQsCotNuDnaEODhUH7nSY6GXR+pjzuNU5PyPfKemk04/Sz8PgXy0iErSf1Xq4O6Mo9RCCHuxt9L1+XdomQlMrRU/8ZnJUJm4g2LnbiYxU8R5wAY37MRLrez5nxmIlyIVAP5Rg+UXib4XpQh35N572zQ16DRd0LUEBLQlwO5UymqEmu2R28Xe+Y82hwHGx2Hz2WweNMJTCbFavUpV96hUK8DKCaIXEXkWXWZvnsC3MrvPc/PgI2zwFSszttv9bT6uouvGtQ71oXUOHVu/bl96ra6jWTJICGEuEuBVXHpOltH9bMH1GH3pSgsNvH+laH29zXxNK88w8Vj6k3gWzm5Uf3Xr03pyVdLaDSy1rwQVZQE9EKIctXAy4nXHwpBq9WwNSaFr/acsXaVyk/b4eq/Mes4FR8HlOP68yYTbJ6rzpd3rQc9XrP88uRaTw3qHdzh0gnY8Z76uqw/L4QQdy3wqqXrFKUK3Yj2b6f+e/hHKLz+ZsN3Vw+1v6+h+uLJv2DNC/DTaHW9+BtRFIi9ari9EKJakoBeCFHu2ga6M+l+dRj4jxHn2HjsopVrVE58W4HfPRiNRQSd/QW4jfXniwvUoY2mG0xDOPglnA1XhzM+8J/Sl6Fzrw8PL1KzHytXjiMBvRBC3DVfN3u0GsgrNHI5p9Da1flb6KPqjdzU07D1bYvPkBMXs/jxylD7F3o2xNXBBpKPw9YFaoHifNg6/8afOxePqgG/3h6C7q3oKxFCVBAJ6GuhkSNH8vjjj1f6eZcvX46bm9styxmNRubPn09ISAgODg54eHjQqVMnPv/884qvpCg3vUN9GNIxEIDlu+LILSy2co3KSdvh5BUZ6VK4kzaOyXg5l2E+YXEh/DIBvhkMyx+CNRNg+yJ1jd+LRyF+J0QsV8t2nwJ1G974WHUbwkOLwM5ZTZrn26pcLksIIWozG50WPzc1H0pCahUadl/HE/rOBZ0NxG2HiC8AMJkUPtx8EpMC9zb2pGtDT8i5BH/+G4yF6hB6G0dIOqyuKV+akmR4wfepw/uFENVSDUlBLWqSOXPm8Mknn/DRRx/Rvn17MjMz2b9/P2lpadaumrhN/+gQQFhsMhfS8/n54HmGdqpv7SrdPb+2nKzTjjrpOxmW/w0UP3LrOex7P1WHyYO6vvzFI+rPtUIfgaZlGPbo1QSe+gJyU9Wh+EIIIe5aoIcj59LyOJuaS9vAcl695G74NIf7XoEtb8OBr8A9mA35zYm7lEMdOx3P39dQHQX253TIvQzuQepNgJOb1JvHez+FwE7gFvj3MYsL1aVYQYbbC1HNSQ99OVIUhfwio1V+7ma+V8+ePZk0aRKvvvoqHh4eGAwGZs+ebVFGo9GwZMkS+vfvj4ODAw0aNODHH380b9+6dSsajYb09HTza5GRkWg0GuLj49m6dSvPPvssGRkZaDQaNBrNdecosXbtWsaPH8+gQYMIDg6mdevWjBo1iqlTp5rLmEwm5s2bR3BwMA4ODrRu3dqiPgB//PEHTZo0wcHBgV69erF8+XKLOs6ePZt77rnHYp/333+foKAgi9c+//xzQkNDsbe3JyQkhP/+97/mbfHx8Wg0GlavXk2vXr1wdHSkdevW7N692+IYO3fupGfPnjg6OuLu7k7fvn3NNyjKci3VlU6rYXiXIADWHDxPem4VGsZ4pzQavtIPJFvrhK/xAhxYcfPyZ/fB4R/Ux33fhqdXQO+ZcM8/IaCjugQdgHcz6Ppi2evh5C3rzwshRDkKMGe6r0I99CWa9IXWQwAwbpnH5u3bABjSMRBXBz2ELVCXMrVzVj9rbOtA6ACo117tsb926H3CLijIgjpeam++EKLakh76clRQbGLQ0t23LlgBvn++M/q7SD66YsUKpkyZQnh4OLt372bkyJF069aNBx74ewmTGTNmM
H/+fD744AO++uorBg8ezOHDhwkNvfUa2F27duX9999n5syZxMTEAODkVMocYcBgMLB582bGjx+Pl5dXqWXmzZvHypUrWbp0KY0bN2bbtm0888wzeHl50aNHD86ePcvAgQOZMGECY8eOZf/+/bz88su3/b6sWrWKmTNn8tFHH9GmTRsOHjzImDFjqFOnDiNGjDCXmz59Ou+88w6NGzdm+vTpDBkyhJMnT6LX64mMjKR3794899xzfPDBB+j1erZs2YLRaCzTtVR3XRvWpbG3EyeSs/lu31me73GT4eTVQFJGPrEZen5yeJo5tt9B5Nfq3MPSguv8DNg6T33c/HEI6qY+dg9Sl6O7upxNHdDJn2QhhLCWqxPjVUkdx0JaPKnHwng6/xO+rTedh1v6QuQqtTdeq1NzsLj6q+U1GjXB6g8j1eldh76De9SbAsRuUP9t/CBopX9PiOpMvj0KAFq1asWsWbMAaNy4MR999BGbNm2yCOgHDRrE6NGjAXjzzTfZuHEjH374oUWP9Y3Y2tri6uqKRqPBYDDctOy7777LU089hcFgoHnz5nTt2pXHHnuM/v37A1BQUMDbb7/NX3/9RZcuXQBo0KABO3bs4JNPPqFHjx4sWbKEhg0bsmjRIgCaNm3K4cOHWbBgwW29L7NmzWLRokUMHDgQgODgYI4dO8Ynn3xiEdBPnTqVhx9+GFCnDDRv3pyTJ08SEhLCwoULad++vcX71Lx58zJfS3Wn0WgY3jWIGWuOsO5IEo/d44/B1d7a1bpjK3bHA6Bp0AOd82X1S9TWt2Hg55ZD7xUFtr2jDn90C4ROL9z4oPauFVtpIYQQt3RtpvsqtyyxVsuF9lM5HXkYH1MS07RfoY9TYO9n6vZuL4J/W8t9nLyhy0S1B3///yCws/qZc3aPur3Jg5V7DUKIcicBfTmy02v5YVwXq5zbVqcx9/jeiVatLBNr+fr6kpycbPFaScB59fPIyMg7PueNNGvWjCNHjhAREcHOnTvZtm0bAwYMYOTIkXz++eecPHmS3Nxci5sNAIWFhbRpow4bi46OplOnTjet/63k5ORw6tQpRo0axZgxY8yvFxcX4+pqGYBd/f75+voCkJycTEhICJGRkQwaNKjUc5TlWmqCewLcuCfAjciz6XwdfoYpDza1dpXuyJHzGew4cQmtBp7rHgxOL8KFg5B2BvZ/AZ3H/V04Zh3EbQOtHu6fATbV9yaGEELUBn5uDmg1kFtoJDWnkLpOZUh6Wsk+35PMaYcxzDJ+QN3c07DpP+qG5k9As8dK36lpf4gLg4Q96tD7hveDyaiukuIeVGl1F0JUDAnoy5FGo8HeRmeVc9/tmqk2NjYWzzUaDaYbLXNSCu2V4VpX16OoqOiO66PVaunQoQMdOnRg8uTJrFy5kmHDhjF9+nSys7MB+P333/H397fYz86u7B++Wq32uvft6jqXnOezzz677uaATmf5e776/Su5o1/y/jk4ONywDuV1LdXBiK71ifwuna2xKTzRth7BnnWsXaXbYjIpfLrtNAAPNjf8Xf97p6pZhQ99C0HdwdACMs7DrsXq9g6j1CR2QgghqjRbvRaDqz0X0vM5m5ZX5QL6iDNp7ItPRWvjhWO/ubD9DTUw92+r9sLfiEajJtX7YaQ6zz71lPp6k76VUm8hRMWSSTOizPbs2XPd85L58yVz3RMTE83br+29t7W1veNRBM2aNQPUXvNmzZphZ2dHQkICjRo1svgJCAgAIDQ0lL179960/l5eXiQlJVkE9VfX2cfHBz8/P06fPn3deYKDg8tc91atWrFp06YbXtetrqWmaOTtTPfGnigKrNgVb+3q3LYNxy6aMwo/c3W2/qBuaoZgRVHnyxfmwJa5ajZ739bQarD1Ki2EEOK21K+r3qytavPoi40m/rdDvak8oJUvPiFd4YE3ocVA6DPn1jlY6nhC10nqY2OROnqs4f0VXGshRGWQHnpRZj/88APt27ene/furFq1ir179/K//
/0PwByAzp49m7lz5xIbG2uev14iKCiI7OxsNm3aROvWrXF0dMTR8fp1T5966im6detG165dMRgMxMXFMW3aNJo0aUJISAh6vZ6pU6fy0ksvYTKZ6N69OxkZGezcuRMXFxdGjBjBuHHjWLRoEa+88gqjR48mIiKC5cuXW5ynZ8+epKSksHDhQp566inWr1/PunXrcHFxMZeZM2cOkyZNwtXVlX79+lFQUGBeQm/KlCllet+mTZtGy5YtGT9+POPGjcPW1pYtW7YwaNAgPD09b3ktNckzneuz6+QlIs6kceR8Bi38q8fc8eyCYlbuOQNcySjsaDmiha4T4XwEZJyDn8ZA5nmwdYJe0yXZkBBCVCMB7g7sBs5WckC/5XgyWQXF9GrqhbO9zXXb/ziSxNnUPFwc9AzueGX5uaBufydbLYvGD6hD7+N3QP0u4OBWPpUXQliVfNMUZTZnzhy+/fZbWrVqxZdffsk333xj7jm3sbHhm2++4fjx47Rq1YoFCxbw1ltvWezftWtXxo0bxz/+8Q+8vLxYuHBhqefp27cvv/76KwMGDKBJkyaMGDGCkJAQNmzYgF6v3oN68803mTFjBvPmzSM0NJR+/frx+++/m3vOAwMD+emnn1izZg2tW7dm6dKlvP322xbnCQ0N5b///S8ff/wxrVu3Zu/evRZL4wGMHj2azz//nGXLltGyZUt69OjB8uXLb6uHvkmTJmzYsIGoqCg6duxIly5d+OWXX8p8LTWJv5sDDzZXkyKu2BV/11NFKsu3exPIyCuinruDmlH4WnbO6nBGUIN5gHtfAmefyqukEEKIu2aNpevOp+fx7sZYPtt2mhFf7GXxphOcTsk2b8/IK+LrcPWm8rDO9XGyu8P+OI0Gek6DzuNvb5lUIUSVplGqyzdqK8nMzMTV1ZWMjAyLnluA/Px84uLiCA4Oxt7eugmvFEWhuLgYvV5fIVlZNRoNP//8M48//ni5H7uybN26lV69epGWloabm5u1q1NhStpl/fr1yczMxNvb25zjoCq4nF3A2K8iKCw2Mf3hUDo3qGvtKt3UubRcJnx9EJNJYfajzWlX3/3Ghbe9A9G/qr0g979ReZUsI5PJRHJycpVrE8K6pF2Ia9XmNnE6JZsXv43EyU7P12M6VUqm+/VHEvl4yym0Wg0m099fy5v7ufBIKz+izqWz/kgSwZ51eP8f96DVVn72/drcJkTppE1UvJvFoVeTIfdCiEpV18mOR1v78WPEOb7afYZG3k54VrHEQ1f7fHscJpNChyCPmwfzAN1fgoa9wNC6cionhBCiXPm7q5nuswuKSc8twr2O7a13ukuHzmUA8I/2AbQOcOW3Q4nsOnmJoxcyOXoh01xuzL0NrBLMCyGqNgnohRCV7sl29Vh/JImE1FyeXbaPYM86dAhyp32QB019nKvMF5b98alEnElDp9Uw6t4yTIHQ6sC/XcVXTAghRIWw0+vwcbEnMSOfhNTcCg/oFUXh8Hk1oG9Vz5XmfurPpewC1h9J4s+jSaTnFtG9sSct61WPvDNCiMolAb0ok5owM6Nnz5414jpqAic7Pa/3D+HL3Wc4kZxF3KUc4i7l8P3+czjZ6WlX351+LQxWTZpXZDTx2XY1o/Cjrf3wd7vx8oNCCCFqjgAPRxIz8jmblkvrALcKPde5tDzSc4uw0Wlo4uNsft3TyY5nOtfn6fYBnEjOorG3802OIoSozSSgF0JYResANxYFuJGRW8SBBHVt3QMJaWQXFBMWm8LeuFS+HNURexudVer3S+QFLqTn4+Zowz861KwlBIUQQtxY/bqO7I1L5UwlJMYr6Z0P9XXBVn/9PGRbvZbmftIzL4S4sWqTwWDu3Ll07doVR0fHMic0UxSFmTNn4uvri4ODA3369OHEiRMVW1EhxG1xdbShV4g3r/YLYdXozswb2BJvZzvyioyEx6VWen2MJoVlO+NYsSseUJfaq3OnGYWFEEJUOwHuaqb7c2kVH9CXzJ9vWU2Wc
RVCVD3VJqAvLCxk0KBBvPDCC2XeZ+HChSxevJilS5cSHh5OnTp16Nu3L/n5+RVYUyHEndJpNbTwd6VXiDcAYTEplXr+9NxC3lhzhNUH1KXnHrvHjwdCZek5IYSoTcxL11XwWvSKonDkSg+9zI8XQtypatPtNGfOHACWL19epvKKovD+++/zxhtv8NhjjwHw5Zdf4uPjw5o1axg8eHBFVVUIcZd6NPHiu31niUhIIzO/CBd7mwo/57ELmcxff5y0nEIcbHRM6t2Y7o09K/y8QgghqpZ67g5oNJCZV0x6biFujhWTGO9sah4ZeUXY6rUW8+eFEOJ2VJuA/nbFxcWRlJREnz59zK+5urrSqVMndu/efcOAvqCggIKCAvPzzEx1uRCTyYTJZLIoazKZUBTF/GNtJXWoCnUR1lPSHkva57Xttjrwd7Mn2NOR05dy2BGbQr8Whgo7l6IorI1KZPmueIyKQqC7I6/3b0o9d8dq+d7dTHVuE6LiSLsQ16rtbcJWp8Hb2Y6kzHzOXM7Bxb5ivi5HnU1DQSHE4IxOQ5V+v2t7mxDXkzZR8cr63tbYgD4pKQkAHx/L4bI+Pj7mbaWZN2+eeTTA1VJSUq4bql9UVITJZKK4uJji4uJyqPWdUxQFo9EIgEZTNZb8EtZRXFyMyWQiNTWVnJwcFEVBq602s2vM2hjsiE3M4M9DZ2nrXTH1zykwsmJfEnsT1Bt3neq78GxHX2yLsklOzq6Qc1qTyWQiIyOj2rYJUTGkXYhrSZsAT3s4d7mYI3FJ+NgU3HqHq+QXmbDTa275fWzPiUSKi4oJctGQnJx8N9WtcNImxLWkTVS8rKysMpWzakD/+uuvs2DBgpuWiY6OJiQkpJJqBNOmTWPKlCnm55mZmQQEBODl5YWLi4tF2fz8fLKystDr9ej1VePeiI1NxQ9NFlWbXq9Hq9Xi4eGBjY0NXl5e1fIP7cOOrqw+ksaptEI0Dq54OduV27FzC4tZG5XImoPnyS0yYmdrw+juwTzU0lCjb4iZTCY0Gk21bROiYki7ENeSNgFN/XM5kpxPpskGb2/vMu2TlV/EFzvj2XQ8mafa1mN4l/o3LGsyKZxOi0dvo6drSD28vV1uWLYqkDYhriVtouLZ29uXqZxVo9CXX36ZkSNH3rRMgwYN7ujYBoM6RPfixYv4+vqaX7948SL33HPPDfezs7PDzu76wEGr1V7XWLVaLRqNxvxjTYqimOtws7rcqp6zZs1i9uzZ5Vm1MtNoNPz88888/vjjNy0XFhbGnDlziIyMJD8/H39/f7p27cpnn32GrW3FzHOrTkraY0n7LK3tVgfeLg608HfhyPlMdpy8zJPt6t31MQuKjfx+KJEfI86Rla+Oqgn2dGJ8z4aE+lbtL1PlpTq3CVFxpF2Ia9X2NlHfsw4aNCSk5pXpPdh18hJLwk6RnluEBg1roy7wRJt6uDqW3tFyJjWH7IJiHGx0NDW4VIv3uba3CXE9aRMVq6zvq1UDei8vL7y8vCrk2MHBwRgMBjZt2mQO4DMzMwkPD7+tTPk1TWJiovnxd999x8yZM4mJiTG/5uTkdFvHKywsrNQg+tixY/Tr149//etfLF68GAcHB06cOMFPP/1knnIgao4eTbw4cj6TsNiUuwroi4wmNhy9yHf7z5KWUwiAn5s9QzvVp3sjT7TamtsrL4QQ4vaVdem6tJxCloadYtepy4CaUE/dL48/jiQypGNgqfuVrD/fzNcFvU6CISHEnas2f0ESEhKIjIwkISEBo9FIZGQkkZGRZGf/Pc81JCSEn3/+GVDvGE2ePJm33nqLtWvXcvjwYYYPH46fn98te4DvmKJAUZ51fsqYCM9gMJh/XF1d0Wg05uc5OTkMHToUHx8fnJyc6NChA3/99ZfF/kFBQbz55psMHz4cFxcXxo4dC8Bnn31GQEAAjo6OPPHEE7z77ru4ublZ7PvLL7/Qtm1b7O3tadCgAXPmzDHnHggKCgLgiSeeQKPRmJ9fa
8OGDRgMBhYuXEiLFi1o2LAh/fr147PPPsPBwcFcbseOHdx77704ODgQEBDApEmTyMnJMW9PTk5mwIABODg4EBwczKpVqwgKCuL9998HID4+Ho1GQ2RkpHmf9PR0NBoNW7duNb925MgR+vfvj5OTEz4+PgwbNoxLly6Zt/fs2ZNJkybx6quv4uHhgcFguG4ERHp6Os8//zw+Pj7Y29vTokULfvvttzJfS03W9UqwHXcph4TLd7Z8UHZBMZO/jWRp2CnScgrxdrZjUu/G/HdoO+5r4iXBvBBCiOuULF2XnltERl7RddsVRWHz8YuMX3WAXacuo9XA0+3r8cHgNgy+EsT/cTiRwuLSk1odPpcOQMt6bhVSfyFE7VE1Jn6XwcyZM1mxYoX5eZs2bQDYsmULPXv2BCAmJoaMjAxzmVdffZWcnBzGjh1Leno63bt3Z/369WWej3DbivPhi34Vc+xbeXYdaO5u/nx2djYPPfQQc+fOxc7Oji+//JIBAwYQExNDYODfd5jfeecdZs6cyaxZswDYuXMn48aNY8GCBTz66KP89ddfzJgxw+LY27dvZ/jw4SxevJh7772XU6dOmW8GzJo1i3379uHt7c2yZcvo168fOp2u1DoaDAYSExPZtm0b9913X6llTp06Rb9+/Xjrrbf44osvSElJYeLEiUycOJFly5YBMHLkSC5cuMCWLVuwsbFh0qRJt52QJj09nfvvv5/Ro0fz3nvvkZeXx2uvvcbTTz/N5s2bzeVWrFjBlClTCA8PZ/fu3YwcOZJu3brxwAMPYDKZ6N+/P1lZWaxcuZKGDRty7Ngx8/WX5VpqMhd7G9oFurMvPpWwEykMq3vj+Yg38sP+sySk5uLqYMPgjgE82MyArb7a3MsUQghhBfY2Onxc7LiYWcCMNUew1WsxmRQUwGhSyC8ykpihJktu4FWHF3s3poGXOsqxW8O6LHey5VJ2IVtjknmwueVKLSaTwpHzajLWlv6y/rwQ4u5oFFnj7KYyMzNxdXUlIyOj1KR4cXFxBAcHqzcJivKsFtArz66jWGODXq8v83z+5cuXM3nyZNLT029YpkWLFowbN46JEycCak96mzZtzCMhAAYPHkx2drZFr/IzzzzDb7/9Zj52nz596N27N9OmTTOXWblyJa+++ioXLlwAyjaH3mg0Mnr0aJYvX47BYKBz58707t3bPGIAYPTo0eh0Oj755BPzfjt27KBHjx7k5OSQkJBA06ZN2bt3Lx06dADg+PHjhIaG8t577zF58mTi4+MJDg7m4MGD5ikb6enpuLu7m28ivfXWW2zfvp0///zTfJ5z584REBBATEwMTZo0oWfPnhiNRrZv324u07FjR+6//37mz5/Phg0b6N+/P9HR0TRp0uS6673VtZR2c6qkXdavX5/MzEy8vb2r9dymsNgU3vkzBh8Xez4b3u628lUkZ+YzbmUERUaFGY80o2OwRwXWtOozmUwkJydX+zYhype0C3EtaROq//vzONtiL91wu41Ow5COgQxsWw/dNaO9Vh84x7Kd8QR6OPLRP9tYfHadTsnmxW8jcbDR8c3YztftWxVJmxDXkjZR8W4Wh16t2vTQVwt6e3huvXXOrbODu5xDnp2dzezZs/n9999JTEykuLiYvLw8EhISLMq1b9/e4nlMTAxPPPGExWsdO3a0CPCjoqLYuXMnc+fONb9mNBrJz88nNzcXR0fHMtVRp9OxbNky3nrrLTZv3kx4eDhvv/02CxYsYO/evfj6+hIVFcWhQ4dYtWqVeb+SdTLj4uKIjY1Fr9fTrl078/aQkJDrpgjcSlRUFFu2bCk178CpU6fMAXqrVq0stvn6+ppHA0RGRlKvXr1Sg/mSc9zsWkJDQ2+rztVRp2AP7PRaLmbmE3sxm6YG5zLvu3LPGYqMCi38XekQ5F6BtRRCCFHTTOjViHsbe2FSFHQlCWc1mP8N8HDE06n0FVgebG7g273qCLEDCWm0q//3DWXz/Hk/l2oRzAshq
jYJ6MuTRgM2DrcuVxHKYaDF1KlT2bhxI++88w6NGjXCwcGBp556isLCQotyderUue1jZ2dnM2fOHAYOHHjdtjuZAuHv78+wYcMYNmwYb775Jk2aNGHp0qXMmTOH7Oxsnn/+eSZNmnTdfoGBgcTGxt7y+CV3Gq8ewFJUZDmHLjs7mwEDBpS69OLVKytcu5SgRqPBZFLn1F097780t7qW2sDeRkfnBnUJi00hLDa5zAH9yeRstsSkADCqe5DVV6IQQghRvTja6uncoO4d7etkp+fB5j78EnmBNQcvWAT0h86pAb0MtxdClAcJ6IXZzp07GTlypLm3PTs7m/j4+Fvu17RpU/bt22fx2rXP27ZtS0xMDI0aNbrhcWxsbO4oU727uzu+vr7mRHFt27bl2LFjNzxXSEgIxcXFREREmIfcx8TEWEw9KFl9ITEx0Zyv4eoEeSXn+emnnwgKCkKvv7P/Sq1ateLcuXPExsaW2kt/q2upLXo09SIsNoXtJy4xunuDWyayUxSFZTvj1H2beNHIu+y9+kIIIUR5GNDaj1+jLhB5Np34SzkEeda5Mn9eDehb1ZOAXghx92TCgzBr3Lgxq1evJjIykqioKP75z3+ae5Jv5l//+hd//PEH7777LidOnOCTTz5h3bp1Fj2iM2fO5Msvv2TOnDkcPXqU6Ohovv32W9544w1zmaCgIDZt2kRSUhJpaWmlnuuTTz7hhRdeYMOGDZw6dYqjR4/y2muvcfToUQYMGADAa6+9xq5du5g4cSKRkZGcOHGCX375xZwHoGnTpvTr14/nn3+e8PBwIiIiGD16tEVvuYODA507d2b+/PlER0cTFhZmUVeACRMmkJqaypAhQ9i3bx+nTp3izz//5Nlnny3zjYkePXpw33338eSTT7Jx40bi4uJYt24d69evL9O11BZtAtxwtteTnlvEofMZtywfcSaNQ+cy0Os0DOty+4n0hBBCiLvl42JP54ZqD/+ayPMAnL6UTW6hEQdbHQ29bm+pYCGEKI0E9MLs3Xffxd3dna5duzJgwAD69u1L27Ztb7lft27dWLp0Ke+++y6tW7dm/fr1vPTSSxZD6fv27ctvv/3Ghg0b6NChA507d+a9996jfv2/g61FixaxceNGAgICzL3i1+rYsSPZ2dmMGzeO5s2b06NHD/bs2cOaNWvo0aMHoPZ6h4WFERsby7333kubNm2YOXMmfn5+5uMsW7YMPz8/evTowcCBAxk7dize3t4W5/riiy8oLi6mXbt25iUQr+bn58fOnTsxGo08+OCDtGzZksmTJ+Pm5nZbyUF++uknOnTowJAhQ2jWrBmvvvqq+YZAWa6lNtDrtHRr5AlA2JVh9DdiMiks2xkPwIBWfvi4VNCqFkIIIcQtPNHGH4CtMSmk5hSah9u38HOVZVOFEOVCstzfwm1lubciRVEoLi6+rSz3FWnMmDEcP37cIrt7VRcUFMTkyZOZPHmytatyV2palvsSR85nMG31YRxsdawc1emGS89tOJrEh5tP4mSn59Ph7XC2v7vlHGsSyUgrSiPtQlxL2kT5euWHKI4nZfF0hwBOp2SzPz6NUd2DefxKsF8dSJsQ15I2UfHKmuVe3n1RLt555x2ioqI4efIkH374IStWrGDEiBHWrpaoQZr5uuDpZEteoZHNxy9S2r3I/CIjK8PVVRn+0SFAgnkhhBBWVxK4/3EokaMl68/L/HkhRDmRgF6Ui7179/LAAw/QsmVLli5dyuLFixk9erS1qyVqEK1WQ48marLCj7ecYsLXB/gl8jxZ+X+vPrDm4HnScgrxcbHjoZa+NzqUEEIIUWm6NKiLj4sd2QXF5BUZqWOnI7ju7a8YJIQQpZEs96JcfP/999auwl0rS0Z/YV2DOwaSU2hky/Fkzqbm8fn2OFbsiqdbI0+6N/Jk9QE16dDwLkE3HJIvhBBCVCatVsOA1n58vl1dfUXmzwshypN84xVCVBv2Njom9GrEl6M68kLPhgR71qHIq
LA1JoW3fo8mr8hIY28n7m3sae2qCiGEEGYPNjPgaKsDZLi9EKJ8SQ99OZC8gqIqqQ3t0dFWz0MtfenfwsDJ5GzWH0li24kUCo0Ko+4NrhKJIYUQQogSDrY6xvVsyLbYFHqFeN96ByGEKCMJ6O+CjY2acCs3N9diDXMhrCk3Nxf4u33WZBqNhsY+zjT2cWb0vQ3IKSzG08nO2tUSQgghrtOrqTe9mkowL4QoXxLQ3wWdToebmxvJyckAODo6Wq1nsKotWycqn6Io5ObmkpycjJubGzqdztpVqlQOtjocbGvXNQshhBBCiNpNAvq7ZDAYAMxBvbUoioLJZEKr1UpAX8u5ublhMBhqxdB7IYQQQgghajMJ6O+SRqPB19cXb29vioqKbr1DBTGZTFy+fJm6deui1Uquw9rKxsbG3DMvAb0QQgghhBA1mwT05USn01l1iLPJZMLGxgZ7e3sJ6IUQQgghhBCiFpDITwghhBBCCCGEqIYkoBdCCCGEEEIIIaohCeiFEEIIIYQQQohqSObQ30JJYrHMzEwr1+TmTCYTWVlZModemEmbENeSNiFKI+1CXEvahLiWtAlxLWkTFa8k/rxVomsJ6G8hKysLgICAACvXRAghhBBCCCFEbZKVlYWrq+sNt2sUWdvqpkwmExcuXMDZ2blKr++emZlJQEAAZ8+excXFxdrVEVWAtAlxLWkTojTSLsS1pE2Ia0mbENeSNlHxFEUhKysLPz+/m46CkB76W9BqtdSrV8/a1SgzFxcX+U8lLEibENeSNiFKI+1CXEvahLiWtAlxLWkTFetmPfMlZMKDEEIIIYQQQghRDUlAL4QQQgghhBBCVEMS0NcQdnZ2zJo1Czs7O2tXRVQR0ibEtaRNiNJIuxDXkjYhriVtQlxL2kTVIUnxhBBCCCGEEEKIakh66IUQQgghhBBCiGpIAnohhBBCCCGEEKIakoBeCCGEEEIIIYSohiSgF0IIIYQQQgghqiEJ6GuIjz/+mKCgIOzt7enUqRN79+61dpVEJZk3bx4dOnTA2dkZb29vHn/8cWJiYizK5OfnM2HCBOrWrYuTkxNPPvkkFy9etFKNRWWaP38+Go2GyZMnm1+T9lA7nT9/nmeeeYa6devi4OBAy5Yt2b9/v3m7oijMnDkTX19fHBwc6NOnDydOnLBijUVFMhqNzJgxg+DgYBwcHGjYsCFvvvkmV+dKljZRs23bto0BAwbg5+eHRqNhzZo1FtvL8vtPTU1l6NChuLi44ObmxqhRo8jOzq7EqxDl6WZtoqioiNdee42WLVtSp04d/Pz8GD58OBcuXLA4hrSJyicBfQ3w3XffMWXKFGbNmsWBAwdo3bo1ffv2JTk52dpVE5UgLCyMCRMmsGfPHjZu3EhRUREPPvggOTk55jIvvfQSv/76Kz/88ANhYWFcuHCBgQMHWrHWojLs27ePTz75hFatWlm8Lu2h9klLS6Nbt27Y2Niwbt06jh07xqJFi3B3dzeXWbhwIYsXL2bp0qWEh4dTp04d+vbtS35+vhVrLirKggULWLJkCR999BHR0dEsWLCAhQsX8uGHH5rLSJuo2XJycmjdujUff/xxqdvL8vsfOnQoR48eZePGjfz2229s27aNsWPHVtYliHJ2szaRm5vLgQMHmDFjBgcOHGD16tXExMTw6KOPWpSTNmEFiqj2OnbsqEyYMMH83Gg0Kn5+fsq8efOsWCthLcnJyQqghIWFKYqiKOnp6YqNjY3yww8/mMtER0crgLJ7925rVVNUsKysLKVx48bKxo0blR49eigvvviioijSHmqr1157TenevfsNt5tMJsVgMCj/93//Z34tPT1dsbOzU7755pvKqKKoZA8//LDy3HPPWbw2cOBAZejQoYqiSJuobQDl559/Nj8vy+//2LFjCqDs27fPXGbdunWKRqNRzp8/X2l1FxXj2jZRmr179yqAcubMGUVRpE1Yi/TQV3OFhYVERETQp08f82tarZY+ffqwe/duK
9ZMWEtGRgYAHh4eAERERFBUVGTRRkJCQggMDJQ2UoNNmDCBhx9+2OL3DtIeaqu1a9fSvn17Bg0ahLe3N23atOGzzz4zb4+LiyMpKcmiXbi6utKpUydpFzVU165d2bRpE7GxsQBERUWxY8cO+vfvD0ibqO3K8vvfvXs3bm5utG/f3lymT58+aLVawsPDK73OovJlZGSg0Whwc3MDpE1Yi97aFRB359KlSxiNRnx8fCxe9/Hx4fjx41aqlbAWk8nE5MmT6datGy1atAAgKSkJW1tb8x/bEj4+PiQlJVmhlqKiffvttxw4cIB9+/Zdt03aQ+10+vRplixZwpQpU/j3v//Nvn37mDRpEra2towYMcL8uy/ts0TaRc30+uuvk5mZSUhICDqdDqPRyNy5cxk6dCiAtIlariy//6SkJLy9vS226/V6PDw8pI3UAvn5+bz22msMGTIEFxcXQNqEtUhAL0QNMmHCBI4cOcKOHTusXRVhJWfPnuXFF19k48aN2NvbW7s6ooowmUy0b9+et99+G4A2bdpw5MgRli5dyogRI6xcO2EN33//PatWreLrr7+mefPmREZGMnnyZPz8/KRNCCFuqqioiKeffhpFUViyZIm1q1PryZD7as7T0xOdTnddhuqLFy9iMBisVCthDRMnTuS3335jy5Yt1KtXz/y6wWCgsLCQ9PR0i/LSRmqmiIgIkpOTadu2LXq9Hr1eT1hYGIsXL0av1+Pj4yPtoRby9fWlWbNmFq+FhoaSkJAAYP7dy2dJ7fHKK6/w+uuvM3jwYFq2bMmwYcN46aWXmDdvHiBtorYry+/fYDBcl4C5uLiY1NRUaSM1WEkwf+bMGTZu3GjunQdpE9YiAX01Z2trS7t27di0aZP5NZPJxKZNm+jSpYsVayYqi6IoTJw4kZ9//pnNmzcTHBxssb1du3bY2NhYtJGYmBgSEhKkjdRAvXv35vDhw0RGRpp/2rdvz9ChQ82PpT3UPt26dbtuOcvY2Fjq168PQHBwMAaDwaJdZGZmEh4eLu2ihsrNzUWrtfwaqNPpMJlMgLSJ2q4sv/8uXbqQnp5ORESEuczmzZsxmUx06tSp0ussKl5JMH/ixAn++usv6tata7Fd2oSVWDsrn7h73377rWJnZ6csX75cOXbsmDJ27FjFzc1NSUpKsnbVRCV44YUXFFdXV2Xr1q1KYmKi+Sc3N9dcZty4cUpgYKCyefNmZf/+/UqXLl2ULl26WLHWojJdneVeUaQ91EZ79+5V9Hq9MnfuXOXEiRPKqlWrFEdHR2XlypXmMvPnz1fc3NyUX375RTl06JDy2GOPKcHBwUpeXp4Vay4qyogRIxR/f3/lt99+U+Li4pTVq1crnp6eyquvvmouI22iZsvKylIOHjyoHDx4UAGUd999Vzl48KA5Y3lZfv/9+vVT2rRpo4SHhys7duxQGjdurAwZMsRalyTu0s3aRGFhofLoo48q9erVUyIjIy2+cxYUFJiPIW2i8klAX0N8+OGHSmBgoGJra6t07NhR2bNnj7WrJCoJUOrPsmXLzGXy8vKU8ePHK+7u7oqjo6PyxBNPKImJidartKhU1wb00h5qp19//VVp0aKFYmdnp4SEhCiffvqpxXaTyaTMmDFD8fHxUezs7JTevXsrMTExVqqtqGiZmZnKiy++qAQGBir29vZKgwYNlOnTp1t8MZc2UbNt2bKl1O8PI0aMUBSlbL//y5cvK0OGDFGcnJwUFxcX5dlnn1WysrKscDWiPNysTcTFxd3wO+eWLVvMx5A2Ufk0iqIolTceQAghhBBCCCGEEOVB5tALIYQQQgghhBDVkAT0QgghhBBCCCFENSQBvRBCCCGEEEIIUQ1JQC+EEEIIIYQQQlRDEtALIYQQQgghhBDVkAT0QgghhBBCCCFENSQBvRBCCCGEEEIIUQ1JQC+EEEKIGm/kyJE8/vjjNy2zdetWNBoN6enplVInIYQQ4m5JQC+EEEJUgpSUFF544QUCAwOxs
7PDYDDQt29fdu7cae2qVRkajcb84+rqSrdu3di8eXO5HPuDDz5g+fLl5uc9e/Zk8uTJFmW6du1KYmIirq6u5XJOIYQQoqJJQC+EEEJUgieffJKDBw+yYsUKYmNjWbt2LT179uTy5cvWrlqVsmzZMhITE9m5cyeenp488sgjnD59+q6P6+rqipub203L2NraYjAY0Gg0d30+IYQQojJIQC+EEEJUsPT0dLZv386CBQvo1asX9evXp2PHjkybNo1HH33Uotzo0aPx8vLCxcWF+++/n6ioKItjzZ8/Hx8fH5ydnRk1ahSvv/4699xzj3l7aT3Pjz/+OCNHjjQ/LygoYOrUqfj7+1OnTh06derE1q1bzduXL1+Om5sbf/75J6GhoTg5OdGvXz8SExMtjvvFF1/QvHlz7Ozs8PX1ZeLEibd1LaVxc3PDYDDQokULlixZQl5eHhs3bgQgLCyMjh07ms/3+uuvU1xcbN73xx9/pGXLljg4OFC3bl369OlDTk4OYDnkfuTIkYSFhfHBBx+YRwTEx8eXOuT+p59+Ml9jUFAQixYtsqhvUFAQb7/9Ns899xzOzs4EBgby6aef3vI6hRBCiPIgAb0QQghRwZycnHBycmLNmjUUFBTcsNygQYNITk5m3bp1RERE0LZtW3r37k1qaioA33//PbNnz+btt99m//79+Pr68t///ve26zNx4kR2797Nt99+y6FDhxg0aBD9+vXjxIkT5jK5ubm88847fPXVV2zbto2EhASmTp1q3r5kyRImTJjA2LFjOXz4MGvXrqVRo0ZlvpaycHBwAKCwsJDz58/z0EMP0aFDB6KioliyZAn/+9//eOuttwBITExkyJAhPPfcc0RHR7N161YGDhyIoijXHfeDDz6gS5cujBkzhsTERBITEwkICLiuXEREBE8//TSDBw/m8OHDzJ49mxkzZlgM3QdYtGgR7du35+DBg4wfP54XXniBmJiYMl+nEEIIcccUIYQQQlS4H3/8UXF3d1fs7e2Vrl27KtOmTVOioqLM27dv3664uLgo+fn5Fvs1bNhQ+eSTTxRFUZQuXboo48ePt9jeqVMnpXXr1ubnPXr0UF588UWLMo899pgyYsQIRVEU5cyZM4pOp1POnz9vUaZ3797KtGnTFEVRlGXLlimAcvLkSfP2jz/+WPHx8TE/9/PzU6ZPn17qtZblWkoDKD///LOiKIqSk5OjjB8/XtHpdEpUVJTy73//W2natKliMpks6uTk5KQYjUYlIiJCAZT4+PhSjz1ixAjlscceMz8v7X3asmWLAihpaWmKoijKP//5T+WBBx6wKPPKK68ozZo1Mz+vX7++8swzz5ifm0wmxdvbW1myZMkNr1MIIYQoL9JDL4QQQlSCJ598kgsXLrB27Vr69evH1q1badu2rbm3NyoqiuzsbOrWrWvu0XdyciIuLo5Tp04BEB0dTadOnSyO26VLl9uqx+HDhzEajTRp0sTiPGFhYebzADg6OtKwYUPzc19fX5KTkwFITk7mwoUL9O7du9RzlOVabmTIkCE4OTnh7OzMTz/9xP/+9z9atWpFdHQ0Xbp0sZjf3q1bN7Kzszl37hytW7emd+/etGzZkkGDBvHZZ5+RlpZ2W+/NtaKjo+nWrZvFa926dePEiRMYjUbza61atTI/1mg0GAwG83slhBBCVCS9tSsghBBC1Bb29vY88MADPPDAA8yYMYPRo0cza9YsRo4cSXZ2Nr6+vhZz2UvcKpnb1bRa7XXDzIuKisyPs7Oz0el0REREoNPpLMo5OTmZH9vY2Fhs02g05uOWDIW/kbu5lvfee48+ffrg6uqKl5fXTcteTafTsXHjRnbt2sWGDRv48MMPmT59OuHh4QQHB5f5OHeitPfKZDJV6DmFEEIIkDn0QgghhNU0a9bMnLStbdu2JCUlodfradSokcWPp6cnAKGhoYSHh1scY8+ePRbPvby8LJLXGY1Gjhw5Yn7epk0bjEYjycnJ153HYDCUq
d7Ozs4EBQWxadOmUreX5VpuxGAw0KhRo+uC+dDQUHbv3m1xs2Lnzp04OztTr149QA2ku3Xrxpw5czh48CC2trb8/PPPpZ7H1tbWope9NKGhodctK7hz506aNGly3c0QIYQQwhokoBdCCCEq2OXLl7n//vtZuXIlhw4dIi4ujh9++IGFCxfy2GOPAdCnTx+6dOnC448/zoYNG4iPj2fXrl1Mnz6d/fv3A/Diiy/yxRdfsGzZMmJjY5k1axZHjx61ONf999/P77//zu+//87x48d54YUXLLK2N2nShKFDhzJ8+HBWr15NXFwce/fuZd68efz+++9lvqbZs2ezaNEiFi9ezIkTJzhw4AAffvhhma/ldo0fP56zZ8/yr3/9i+PHj/PLL78wa9YspkyZglarJTw83JwsMCEhgdWrV5OSkkJoaGipxwsKCiI8PJz4+HguXbpUao/6yy+/zKZNm3jzzTeJjY1lxYoVfPTRRxbJAYUQQghrkiH3QgghRAVzcnKiU6dOvPfee5w6dYqioiICAgIYM2YM//73vwG1d/mPP/5g+vTpPPvss6SkpGAwGLjvvvvw8fEB4B//+AenTp3i1VdfJT8/nyeffJIXXniBP//803yu5557jqioKIYPH45er+ell16iV69eFvVZtmwZb731Fi+//DLnz5/H09OTzp0788gjj5T5mkaMGEF+fj7vvfceU6dOxdPTk6eeeqrM13K7/P39+eOPP3jllVdo3bo1Hh4ejBo1ijfeeAMAFxcXtm3bxvvvv09mZib169dn0aJF9O/fv9TjTZ06lREjRtCsWTPy8vKIi4u7rkzbtm35/vvvmTlzJm+++Sa+vr785z//sVgCUAghhLAmjXLtRDshhBBCVBuzZ89mzZo1REZGWrsqQgghhKhkMuReCCGEEEIIIYSohiSgF0IIIYQQQgghqiEZci+EEEIIIYQQQlRD0kMvhBBCCCGEEEJUQxLQCyGEEEIIIYQQ1ZAE9EIIIYQQQgghRDUkAb0QQgghhBBCCFENSUAvhBBCCCGEEEJUQxLQCyGEEEIIIYQQ1ZAE9EIIIYQQQgghRDUkAb0QQgghhBBCCFENSUAvhBBCCCGEEEJUQ/8PNWHTDC5RfXsAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x, y = next(dataset(1, 1))\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.subplot(2, 1, 1)\n", "plt.plot(x[0], label='Input Sequence', alpha=0.8)\n", "plt.plot(y[0], label='Target Sequence', alpha=0.8)\n", "plt.xlabel('Sequence Position')\n", "plt.ylabel('Value')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "id": "yE5UnHOoIsbc" }, "outputs": [], "source": [ "# Training parameters\n", "batch_size = 8192\n", "num_steps = 501\n", "losses = []" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 406 }, "id": "OwSJkENmaCpK", "outputId": "def74c1d-670d-4b39-c1f5-165b133ea216" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data sharding\n" ] }, { "data": { "text/html": [ "
  TPU 0  \n",
              "         \n",
              "  TPU 1  \n",
              "         \n",
              "  TPU 2  \n",
              "         \n",
              "  TPU 3  \n",
              "         \n",
              "  TPU 6  \n",
              "         \n",
              "  TPU 7  \n",
              "         \n",
              "  TPU 4  \n",
              "         \n",
              "  TPU 5  \n",
              "         \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "step=0, loss=1.8279016017913818\n", "step=100, loss=0.08409085869789124\n", "step=200, loss=0.021983487531542778\n", "step=300, 
loss=0.01996900513768196\n", "step=400, loss=0.015401830896735191\n", "step=500, loss=0.015486905351281166\n" ] } ], "source": [ "# The training loop is somewhat slow in wall-clock time because the dataset generates new samples on the fly at every step;\n", "# the slowdown is not caused by FSDP 😅\n", "\n", "for step, (x, y) in enumerate(dataset(num_steps, batch_size)):\n", " # shard data\n", " x, y = jax.device_put((x, y), batch_sharding)\n", " # train\n", " loss = train_step(model, optimizer, x, y)\n", "\n", " losses.append(float(loss))\n", "\n", " if step == 0:\n", " print('data sharding')\n", " jax.debug.visualize_array_sharding(x)\n", "\n", " if step % 100 == 0:\n", " print(f'step={step}, loss={loss}')" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "id": "CcuXaSshahKZ" }, "outputs": [], "source": [ "# Gather the (sharded) model and optimizer state back to host memory\n", "state = nnx.state((model, optimizer))\n", "state = jax.device_get(state)\n", "nnx.update((model, optimizer), state)" ] }, { "cell_type": "markdown", "metadata": { "id": "p9cPI47sIsbc" }, "source": [ "We can now visualize the training loss curve, as well as the predictions of the sequence model trained with FSDP."
] }, { "cell_type": "code", "execution_count": 51, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 607 }, "id": "Irbzmkq1Isbc", "outputId": "1aa0d6f5-db87-4376-ab9b-e2b069d67251" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAJOCAYAAACqS2TfAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAy9VJREFUeJzs3Xd4W/XZxvFbsmx527Gzd8IIBEgYGWzChlLKni1ltLS0gZcWaN/St2W00EFLt7ug7EIZZQYaKCEhjISEhCSEDLK3d7ytfd4/ZMnnSEe27NixLH8/19Wr1tGRfOzIxree5/f8HIZhGAIAAAAAAD3O2dcXAAAAAABAuiJ0AwAAAADQSwjdAAAAAAD0EkI3AAAAAAC9hNANAAAAAEAvIXQDAAAAANBLCN0AAAAAAPQSQjcAAAAAAL2E0A0AAAAAQC8hdANACnI4HJo1a9Y+PceCBQvkcDh0zz339Mg1AUBvmjVrlhwOxz49x2OPPSaHw6HHHnvMcnz8+PEaP378Pj03AHQXoRsAEnA4HF36Hzo3fvx4ZWdn9/Vl9Ihdu3bpzjvv1NFHH63i4mJlZWVpxIgROu+88/TYY4/J5/P19SWmlNdff13nnXeehg4dqszMTA0ePFiHH364brjhBr3yyit9fXkD0j333BP9/XXHHXckPO9///d/o+fxJh4AdJ2rry8AAFLV3XffHXfsd7/7nerr623v60lr165Vbm7uPj3HjBkztHbtWg0ePLiHrgoRzzzzjL72ta+ptbVVxxxzjL7yla+oqKhI5eXleuedd3T99dfrySef1Lx58/r6UlPCvffeq3vuuUe5ubn64he/qPHjxysQCOizzz7Ts88+q88//1wXXHBBX1/mgOVyufTUU0/pF7/4hVwu65+GgUBATzzxhFwulwKBQB9d4b7jZxFAXyJ0A0ACdhWdxx57TPX19b1e7TnkkEP2+Tlyc3N75HlgNXfuXH3lK19RcXGxXnnlFZ155pmW+w3D0Msvv6yHH364j64wtWzdulU/+clPNGbMGC1evFgjR4603N/a2qqPPvqoj64OknTuuefqtdde05w5c3ThhRda7nvjjTdUXl6uL33pS3r11Vf75gJ7wAEHHNDXlwBgAKO9HAD20datW+VwOHTddddp7dq1uuiii1RaWiqHw6GtW7dKkl566SVdddVVOvDAA5Wbm6uioiKddNJJ+ve//237nHZruq+77jo5HA5t2bJFf/jDH3TIIYfI7XZr3LhxuvfeexUKhSznJ1rTHVnb2NTUpFtvvVUjR46U2+3WlClT9MILLyT8Gq+44gqVlJQoPz9fp5xyihYuXBhtT12wYEF3vnUdam5u1t13361DDjlE2dnZKikp0XnnnacPPvgg7lyPx6MHH3xQU6dOVVFRkfLy8jR+/HhdfvnlWrlyZfS8UCikhx9+WDNmzFBJSYlycnI0evRonX/++Ul9DcFgULNnz1YoFNJzzz0XF7il8L/dRRddpBdffDF6rKPvk90a1I5eU2vWrFFBQUGHIWLKlCnKyclRQ0ND9JhhGHrkkUd0wgknqLCwULm5uZo2bZoeeeSRTr/ufbFkyRKFQiFdfPHFcYFbknJycmznF3T1emtra3XTTTdp2LBhys3N1fTp0/XSSy/Zfn87mndg/t7Hqqys1He/+10deOCBcrvdGjx4sC655BKtXr067tzu/Jz5fD799re/1fTp01VQUKD8/HxNnjxZt912m/bu3dvta+nMxRdfrOLiY
tvv7SOPPKJBgwbpoosuSvj41atX6/LLL9fQoUPldrs1YcIEfec731FNTY3t+e+//75OOeUU5eXlqbS0VFdccYV27NiR8Pl74rVrt6bb/HP59NNP68gjj1ROTo5GjBihW2+9Va2trXHPEwgE9POf/1wHHHCAsrOzdeCBB+rnP/+5Nm/enPB1AwBUugGgh2zcuFHHHnusjjjiCF133XWqqalRVlaWJOnOO+9UVlaWTjzxRI0YMUJVVVV69dVXdemll+oPf/iDbrnllqQ/z/e+9z29++67+uIXv6izzz5bL7/8su655x75fD7df//9ST2H3+/XWWedpb179+qSSy5RS0uL/vWvf+nyyy/X3LlzddZZZ0XP3bVrl44//njt2bNH55xzjo466iitX79eZ555pk477bSufZOS5PF4dNppp2nJkiU6+uij9Z3vfEcVFRV69tln9eabb+qZZ57RZZddFj3/2muv1XPPPacpU6bo+uuvl9vt1o4dOzR//nwtXbpUU6dOlRT+d3jggQd0wAEH6Oqrr1ZBQYF27dql999/X2+//Xanw+vmz5+vzZs36/jjj9fpp5/e4blut3ufvw92r6ni4mJdcsklevzxx/Xhhx/q+OOPtzxm5cqV+vTTT3XFFVeosLBQUji0fPnLX9Yzzzyjgw46SFdffbWysrL03//+V1/72te0Zs0a/frXv97n67VTWloqSdqwYUPSj+nq9ba0tGjWrFn69NNPddxxx+mUU07Rjh07dMUVV1hey/ti06ZNmjVrlnbu3KmzzjpLF154oSorK/Xvf/9bb775pubNm6eZM2daHtOVn7PW1ladeeaZ+uCDD3TQQQdFX8cbNmzQ3/72N331q1/VoEGDun0tHcnOztZVV12lhx56SBUVFRo2bJgkqaKiQq+//rq+8Y1vJJzF8P777+vss8+Wz+fTpZdeqvHjx2vRokX6/e9/rzlz5mjx4sWWJS7z5s3TueeeK6fTqSuuuEIjR47UvHnzdMIJJ0S/PrP98dr905/+pLlz5+qCCy7Qaaedprlz5+oPf/iDqqur9c9//tNy7g033KAnn3xSEydO1OzZs+X1evXb3/5WixYt2qdrAJDmDABA0saNG2fE/urcsmWLIcmQZNx11122j9u0aVPcscbGRuOII44wioqKjObmZst9koxTTjnFcuzaa681JBkTJkwwdu/eHT1eVVVlFBcXGwUFBYbX640enz9/viHJuPvuu22/hgsuuMBy/ttvv21IMs4++2zL+V/5ylcMScb9999vOf6Pf/wj+nXPnz/f9uuONW7cOMPtdnd63r333mtIMr785S8boVAoenz58uVGVlaWUVxcbDQ0NBiGYRh1dXWGw+EwjjnmGCMQCFieJxAIGHv37o3eLikpMUaOHBn3/TYMw6ipqen0uu655x5DkvGjH/2o03PN7r777oTfp0cffdSQZDz66KPRY529piL/Vt/61rfi7rv99tsNScacOXOix/7+978bkozrr7/e8Pl80eNer9c4//zzDUnGxx9/3KWvKVmNjY3G2LFjDUnGeeedZzz55JPG+vXrLf+usbp6vZHv74033mh5nrlz50a/j+bvb6KfDcNo/95fe+21luPHH3+8kZGRYcydO9dyfP369UZBQYFxxBFHWI539ecs8u92zTXXxL2O6+rqjMbGxm5fSyKR79szzzxjfPzxx4Yk44EHHoje/8ADDxiSjGXLlhnPPPNM3PcsGAwaBxxwgCEp7lq+973vGZKMG264wXL+xIkTDYfDYbz33nvR46FQyLj66quj/1ZmXX0t2P08GUb432PcuHG2X39RUZGxbt266PGWlhbj4IMPNpxOp7Fr167o8ci/3ZFHHmn5HbJ7925j2LBhtq8bADAMwyB0A0AXdBS6hw8fbvnjOhkPPvigIclYsGCB5XhHofuRRx6Je57IfatWrYoe6yx0b9682fbrKykpid72eDyG2+02hg4dang8Hsu5oVDImDRpUq+E7
okTJxqZmZnGjh074u678cYbDUnGE088YRiGYdTX1xuSjBNOOKHDIGcY4dA9fvz4uK8lWTfddJMhyfjrX//apcd1N3Qnek0Fg0Fj1KhRRmlpqSWIBINBY8SIEcaQIUMMv98fPT5lyhQjLy/PaGlpiXuuVatWGZKM22+/vUtfU1csX77cOOyww6KhKhJ0vvjFLxovvvhi3Pldvd4JEyYYWVlZxp49e+LOP/300/c5dC9fvjwuQJrddttthiTj008/jR7rys+Z3+83CgoKjKKiIqO2ttb2c+zLtSRiDt2GEf6+H3roodH7Dz30UGPq1KmGYRi2oXvhwoWGJOPcc8+Ne+7GxkajpKTEyM7Ojr6G3333XUOScf7558edv3XrViMjIyPu92tXXwvdCd12b2xF7nv11Vejx6677jpDku1r9mc/+xmhG0BCtJcDQA+ZOnVqtJ08VmVlpX7xi1/oP//5j7Zt2xa3VnD37t1Jf55jjjkm7tjo0aMlSXV1dUk9R3FxsSZMmGD7POY2yfXr18vr9WratGlx7dIOh0PHH3+81q9fn/S1J6OhoUGbN2/WoYceGv26zE499VQ99NBDWrFiha655hoVFhbqC1/4gt544w0dffTRuuyyyzRr1ixNnz5dmZmZlsdeeeWV+vOf/6zDDz9cV155pU499VQdd9xxysnJ6dGvoackek05nU59+ctf1gMPPKA33ngjOvl73rx52rNnj2655ZboFOqWlhZ9+umnGjlypH75y1/GPZff75ckrVu3rtPrsVsD/Z3vfEfFxcUdPu6oo47Sp59+qkWLFmn+/PlatmyZ3n//fc2ZM0dz5szRl7/8ZT355JNyOBxdvt6GhgZt2bJFkydP1vDhw+POP+mkk/Z5cvXixYslhdut7b4HkWtZt26dDj/88OjxZH/O1q1bp8bGRp1xxhm2LdY9cS3JuOGGG/Sd73wnem1r167V73//+4Tnf/LJJ5JkuywjPz9f06ZN01tvvaX169friCOOiM5XOOmkk+LOHzdunMaMGROdgyH17Gu3I8n+To1c/4knnhh3/gknnLBP1wAgvRG6AaCHRNZBxqqtrdX06dO1fft2nXDCCTrjjDNUXFysjIwMrVixQq+88oq8Xm/SnyeyTtcsErCCwWBSz1FUVGR73OVyWQayRQZxDR061Pb8RF/zvoh8zkTPPWLECMt5kvT888/rZz/7mZ5++mn93//9n6Tw9+n666/Xz372s+j2a7///e81YcIEPfroo7rvvvt03333KTs7W5dffrkefPDBTrdXi4S6Xbt27dsXmaSOvr/XXHONHnjgAT311FPR0P3kk09G74vYu3evDMPQrl27dO+99yZ8vubm5k6vx+7x1113XaehW2p/kyayBt0wDL3yyiv66le/qn/+85+65JJLdNFFF3X5evfHa7S2tlZSeK/x119/vdNrikj256y+vl6SNGrUqF67lmR85Stf0fe///3ogLKsrCx9+ctfTnh+V39WI19nR/9W5tDdk6/djiT7O7WhoUFOp9P290Rv/C4EkD6YXg4APcThcNge/8c//qHt27frpz/9qd5//3398Y9/1E9/+lPdc889OvbYY/fzVXZN5I/RyspK2/srKip67XMmeu7y8nLLeVJ4e7T77rtPmzdv1ubNm/WPf/xDkyZN0u9//3t997vfjZ7ncrl0xx136LPPPtOuXbv09NNP66STTtITTzzRYbiIiFSzulo5dTrD/7m12+c4EkTsJHpNSdLhhx+uI488UnPmzFF9fb1aWlr00ksvadKkSZo+fXr0vMj36ZhjjpERXlZm+7/58+d3+nXYPS52InSyHA6HLrzwwui/zzvvvNOt6+3Oa7Sr/x6Rz/HHP/6xw2u69tpru/Q9iIi8aZHMmzm9eS2lpaW64IIL9Oyzz+rZZ5/VhRdeGB2E19G1JPuzGnkTItl/q5587faEwsJChUIhVVdXd3rtAGBG6AaAXrZp0yZJilYjzd577739fTldM
mnSJLndbi1btiyuGm8YRq9M7C0sLNTEiRO1ceNG2xAS2XbryCOPtH38hAkTdMMNN+jdd99Vfn5+wr2FR44cqauuukpz587VgQceqLffftt2iyCzU089VRMnTtSHH37Y6R/65u9XpGXY7uuJtOh2xzXXXCOPx6MXXnhBL730kpqamvSVr3zFck5BQYEOPfRQrV27NunlB/tTfn6+5XZXr7ewsFATJkzQxo0boyHPzO5nrKv/HpFJ4L01oXrSpEkqLCzU0qVL47YG29/XcsMNN6ixsVGNjY264YYbOjz3qKOOkiTbrfCam5v18ccfKycnR5MmTZKk6C4Cdv8m27Zti9s2LNVeu5Hrt9u28MMPP9zflwOgHyF0A0AvGzdunKTw1jpmTz/9tN54442+uKSkud1uXXrppaqoqNDvfvc7y31PPPHEPq+lTOTaa6+V3+/XnXfeKcMwosdXrVqlxx57TEVFRbrwwgslSVVVVbZ7E+/du1derze61ZHX67X9w7i5uVlNTU3KzMyMVkATycjIUFlZmZxOpy6//PJodTbWa6+9pksvvTR6O1J5fuKJJyxtxYsWLYrbkqgrrr76amVkZOjJJ5+MromODd2S9D//8z9qaWnRjTfeaNuKu2XLFktbb09asmSJnnjiCXk8nrj7qqqq9PDDD0uyrpPt6vVec8018vl8uuuuuyznvfXWW7ZdCZMmTVJBQYFeffXVaLu2FK5W3nfffXHnz5gxQzNnztQzzzyjZ599Nu7+UCikd9991+arT47L5dI3v/lN1dfX69Zbb41bJlJfX6+mpqb9ci1nnXWWXn75Zb388su2+9CbnXDCCTrggAP0n//8R2+//bblvvvuu081NTW66qqronMJTjzxRE2YMEFz5syx/D40DEM//OEPbZfH9OVrN1akG+YnP/mJ5Q268vLyDte+AwBrugGgl11zzTX65S9/qVtuuUXz58/XuHHjtHLlSs2bN08XX3yxXnzxxb6+xA79/Oc/19tvv60f/OAHevfdd6P7dM+ZM0fnnHOO5s6d22lYNfP7/bruuusS3v/YY4/p+9//vl5//XU9+eSTWrt2rU4//XRVVlbq2WefVSAQ0EMPPaSCggJJ4WrlUUcdpalTp2rKlCkaNWqUampq9Morr8jv9+uOO+6QFN4H+YQTTtDBBx+sY445RmPHjlVTU5PmzJmj8vJy3XHHHUntrX3OOefoySef1Ne//nWdfvrpmjZtmo477jgVFBSooqJCCxYs0KZNm3TGGWdEH3PsscfqhBNO0DvvvKPjjjtOJ598srZt26ZXXnlF559/vl566aWkv39mw4cP1xlnnKG33npLTqdTJ554om279ze/+U0tXrxYjz/+uD744AOdccYZGjlypCoqKrRu3Tp99NFHevrpp7vdKt6R3bt369prr9XNN9+sk08+WYcccohcLpe2bdumOXPmqKmpSeedd55l3/WuXu/3v/99vfjii3rooYf02Wef6eSTT9aOHTv03HPP6bzzzotb+5yVlaVbbrlFP/vZz3T00UfrggsuUGNjo1577TWdcsop0e4Us2eeeUannnqqrrzySv3ud7/T0UcfrZycHG3fvl2LFi1SVVWV7RsLyfrJT36ixYsX68knn9TixYt17rnnyu12a/PmzZo7d67ef//9aHdHb16L0+m07cpJdO5jjz2ms88+W1/4whd02WWXady4cVq0aJEWLFigAw44QL/4xS8s5//973/XF77wBZ1xxhnRfbrfeecd7dmzR1OmTNGqVassn6MvX7uxzjjjDF199dV6+umndcQRR+jCCy+U1+vVc889p5kzZ+q1117r0u9CAANIL0xEB4C01dGWYR1tFbNixQrjrLPOMgYNGmQUFBQYp5xyivH2228n3N5GHWwZtmXLlrjnt9uSqqMtw2K3zok45ZRT4r4+wzCMzZs3G5dddplRVFRk5ObmGieddJLx7rvvGjfffLMhyfjkk08Sfu2xn1umbaPs/hfR1
NRk/PjHPzYOPvjg6N7c5557rmV/X8MwjL179xr33HOPcfLJJxsjRowwsrKyjJEjRxrnnHOO8Z///Cd6ns/nM375y18aZ511ljF69GgjKyvLGDZsmHHyyScbTz/9dKfbjcXauXOn8b//+7/GUUcdZRQWFhoul8sYNmyYcc455xiPPvqoZSsvwzCM6upq46tf/apRUlJi5OTkGMcee6zx5ptvdrhlWDLbDz311FPR793f/va3Ds999tlnjTPOOMMYNGiQkZmZaYwaNcqYNWuW8eCDDxpVVVVd+vqT1dDQYDz11FPGNddcYxx22GFGcXGx4XK5jCFDhhinn3668Y9//CNuX+ruXG9NTY3xjW98wxgyZIiRnZ1tHHPMMcaLL76Y8GcsGAwa99xzjzFmzBgjKyvLOPjgg43f//73xubNmxN+72tra40f/ehHxuGHH27k5OQY+fn5xkEHHWRcffXVcdtIdefnzOPxGL/+9a+NI488Mvr8kydPNm6//XbLfvNdvZZEYrcM64jdlmERq1atMi699FJj8ODBRmZmpjFu3Djj1ltvTfiaWrhwoXHyyScbOTk5RklJiXHZZZcZ27ZtS/h9MYzkXwvd2TIs2a38DCO8vdtPf/rT6DZ1EydONH72s58ZH330kSHJuPXWW22vH8DA5jAMU98eAABdcOKJJ2rRokWqr6+PW5sLpILHHntM119/vR599NEOOyyAffHwww/rxhtv1J///Gd961vf6uvLAZBi6IEBAHRqz549cceeeuqpaLsngRvAQFBeXq7YetWuXbt03333KSMjQ1/84hf76MoApDLWdAMAOnX44YfrqKOO0uTJk6P7iy9YsEAFBQX69a9/3deXBwD7xS9+8Qu9/vrrOumkkzR06FBt375dc+bMUWNjo+655x6NGTOmry8RQAoidAMAOnXTTTfptdde08cff6zm5mYNGTJEV199tX784x/rkEMO6evLA4D94pxzztGaNWv0+uuva+/evcrOztaUKVP07W9/W1dffXVfXx6AFMWabgAAAAAAeglrugEAAAAA6CWEbgAAAAAAeglrursgFApp9+7dKigokMPh6OvLAQAAAAD0EcMw1NjYqJEjR8rpTFzPJnR3we7du5lKCQAAAACI2rFjh0aPHp3wfkJ3FxQUFEgKf1MLCwv7+GoSC4VCqqqq0pAhQzp8xwXYH3g9IpXwekQq4fWIVMNrEqmkP7weGxoaNGbMmGhOTITQ3QWRlvLCwsKUD90ej0eFhYUp+wLFwMHrEamE1yNSCa9HpBpek0gl/en12NnS49S+egAAAAAA+jFCNwAAAAAAvYTQDQAAAABALyF0AwAAAADQSwjdAAAAAAD0EkI3AAAAAAC9hNCdhLKyMk2ePFnTp0/v60sBAAAAAPQjhO4kzJ49W2vWrNHSpUv7+lIAAAAAAP0IoRsAAAAAgF5C6AYAAAAAoJcQugEAAAAA6CWEbgAAAAAAegmhGwAAAACAXkLoBgAAAACglxC6AQAAAADoJYRuAAAAAAB6CaEbAAAAAIBeQugGAAAAAKCXELrTVFWTTyt31MkwjL6+FAAAAAAYsAjdSSgrK9PkyZM1ffr0vr6UpDR6/Lriic900V8W6bVVe/r6cgAAAABgwCJ0J2H27Nlas2aNli5d2teXkpQNlU1q8YUkScu37e3jqwEAAACAgYvQnYZCpo7yYIj2cgAAAADoK4TuNBQyBe0ga7oBAAAAoM8QutNQyBS0Q1S6AQAAAKDPELrTkLmlnPZyAAAAAOg7hO40ZFnTTXs5AAAAAPQZQncaMmgvBwAAAICUQOhOQ+bqdpDMDQAAAAB9htCdhszFbSrdAAAAANB3CN1pKMQgNQAAAABICYTuNBQy2KcbAAAAAFIBoTsN0V4OAAAAAKmB0J2GLO3lVLoBAAAAoM8QutOQpb2cSjcAAAAA9BlCdxoKEroBAAAAICUQupNQVlamyZMna/r06X19KUkxd5QTugEAAACg7xC6k
zB79mytWbNGS5cu7etLSYo5aIdY0w0AAAAAfYbQnYZY0w0AAAAAqYHQnYYs7eVkbgAAAADoM4TuNGRpL6fSDQAAAAB9htCdhmgvBwAAAIDUQOhOQ+aczSA1AAAAAOg7hO40RKUbAAAAAFIDoTsNWUI3lW4AAAAA6DOE7jQUCpk/JnQDAAAAQF8hdKchKt0AAAAAkBoI3WnIHLTNVW8AAAAAwP5F6E5D5o7yAKkbAAAAAPoMoTsNmddxB8ncAAAAANBnCN1pyLymm326AQAAAKDvELrTkLm9nH26AQAAAKDvELrTkLm9nC3DAAAAAKDvELrTEFuGAQAAAEBqIHSnIXPQpr0cAAAAAPoOoTsNmYvbDFIDAAAAgL5D6E5DISrdAAAAAJASCN1pyBy0Q4ZkUO0GAAAAgD5B6E5DsRmbYjcAAAAA9A1CdxLKyso0efJkTZ8+va8vJSmxLeW0mAMAAABA3yB0J2H27Nlas2aNli5d2teXkpTY4WmEbgAAAADoG4TuNBSbsdmrGwAAAAD6BqE7DVHpBgAAAIDUQOhOQ7GhO0ToBgAAAIA+QehOQ6GQ9Tbt5QAAAADQNwjdaYhKNwAAAACkBkJ3GorbMoxKNwAAAAD0CUJ3GorN2AxSAwAAAIC+QehOQ/Ht5X10IQAAAAAwwBG601BsOznt5QAAAADQNwjdaYj2cgAAAABIDYTuNBQbsisbPXpx+U7Vtfj66IoAAAAAYGBy9fUFoOfFrum+4bGl8vhDmjVpiB67fkYfXRUAAAAADDxUutNQbOj2+MOT1Basr+qLywEAAACAAYvQnYZYwg0AAAAAqYHQnYZiK90AAAAAgL5B6E5DIUrdAAAAAJASCN1pqKPMTSAHAAAAgP2H0J2GOtqX2xsI7ccrAQAAAICBjdCdhowO1nR7/MH9eCUAAAAAMLARutNQRx3kVLoBAAAAYP8hdKehYAeV7qpGr/6+cJMWb67Zj1cEAAAAAAOTq68vAD2voy3Dfj9vg95eW6HcrAwt/uHpKszO3I9XBgAAAAADC5XuNNTRhPI1u+slSS2+oPbUefbXJQEAAADAgEToTkMdrene2+KPftzkDeyHqwEAAACAgYvQnYY6qnS3mqaXNxO6AQAAAKBXEbrTUEdrus0I3QAAAADQuwjdaaij9nKzRkI3AAAAAPQqQnca6mjLMDMq3QAAAADQuwjdacggdAMAAABASiB0p6FgKLnzmrzBzk8CAAAAAHQboTsNMUgNAAAAAFLDgAzdF110kQYNGqRLL720ry+lVyTbXs4+3QAAAADQuwZk6L711lv1xBNP9PVl9JpgktPLCd0AAAAA0LsGZOieNWuWCgoK+voyek0oyT3DIu3lizfX6Kzfvqvf/Pfz3rwsAAAAABhw+l3oXrhwoc4//3yNHDlSDodDL7/8ctw5ZWVlGj9+vLKzszVz5kwtWbJk/19oH+rKmm7DMHTl3xfr84om/WHeBtZ5AwAAAEAP6nehu7m5WVOnTlVZWZnt/c8++6xuu+023X333Vq+fLmmTp2qs88+W5WVlfv5SvtOkoVuNXkDWr59r+UYoRsAAAAAek6/C93nnnuu7rvvPl100UW29//mN7/RjTfeqOuvv16TJ0/WX//6V+Xm5uqRRx7Zz1fad4JJpu4mb0DPLNlhOeYNJLnfGAAAAACgU66+voCe5PP5tGzZMt15553RY06nU2eccYYWLVrU5efzer3yer3R2w0NDZKkUCikUCh1w2my08srGrx6fdUey7EWrz+lvzb0P6FQSIZh8LpCSuD1iFTC6xGphtckUkl/eD0me21pFbqrq6sVDAY1bNgwy/Fhw4Zp3bp10dtnnHGGVq5cqebmZo0ePVrPP/+8jjvuuLjn+/nPf65777037nhVVZU8Hk/PfwE9JNlKtyS1+oOW23sqq1XkaO3pS8IAFgqFVF9fL8Mw5HT2u+YapBlej0glvB6RanhNIpX0h9djY2NjUuelVehO1ttvv53UeXfee
aduu+226O2GhgaNGTNGQ4YMUWFhYW9d3j7bl/eCcguLNXTooB67FiAUCsnhcGjIkCEp+wsTAwevR6QSXo9INbwmkUr6w+sxOzs7qfPSKnQPHjxYGRkZqqiosByvqKjQ8OHDu/x8brdbbrc77rjT6UzZf3gp+fZyO75A6r6ThP7L4XCk/M8NBg5ej0glvB6RanhNIpWk+usx2etKzavvpqysLB1zzDGaN29e9FgoFNK8efNs28fTVVfay2N5A0Hb48GQsU9hHgAAAAAGon5X6W5qatLGjRujt7ds2aIVK1aopKREY8eO1W233aZrr71W06ZN04wZM/S73/1Ozc3Nuv766/vwqvevfcjc8vjjm9M3VTXpyw99pGGFbj1303FyuzL24eoAAAAAYODod6H7448/1qmnnhq9HVlzfe211+qxxx7TFVdcoaqqKt11110qLy/XkUceqblz58YNV0tX+1qNtqt0v75qj8obPCpv8OjjrXt1woGD9+lzAAAAAMBA0e9C96xZszoNljfffLNuvvnmHvucZWVlKisrUzBo33qdSvaltVyyr3SbJ5y3+lL/ewAAAAAAqSKt1nT3ltmzZ2vNmjVaunRpX19Kp/Yxc9tWuv2B9iAeSOF98gAAAAAg1RC600yoG+3lI4vaR93bVboDpiTvDzJMDQAAAACSRehOM90J3eNK86If21a6g1S6AQAAAKA7CN1ppjtruscPbg/ddpVuc+im0g0AAAAAySN0p5nurOkeX5ob/diu0h0wBe0AoRsAAAAAkkboTkJZWZkmT56s6dOn9/WldKo7W4Z1Wum2rOmmvRwAAAAAkkXoTkJ/ml6eTHu5y+mw3B7XSaXbPL2c0A0AAAAAySN0p5lk2ssH57sttwuyM6Mfe22nl5sHqdFeDgAAAADJInSnmWSmlw8uyLLczna1vwzsKt0+y5puKt0AAAAAkCxCd5pJJnTnu1268aQJKsh26bdXTJU7MyN6n+0+3UwvBwAAAIBucfX1BaBnJdP9nZvl0v+dN1l3nnuonE6HJVR7/J1ML2efbgAAAABIGpXuNBNKInXnZIUr2862gWquDKcy2j72BuJDtc8UytkyDAAAAACSR+hOM8m0l+ea2skjIuu6bSvdIdrLAQAAAKA7CN1J6E/7dCezZVhuVnzojqzrtqt0014OAAAAAN1D6E5Cf9qnO5k13TlZ8Uv5O6p0+xikBgAAAADdQuhOM0YS7eVDC9xxx5KudLNlGAAAAAAkjenlaSaYIHQPK3TrvCNGqrLRo4uPHhV3v7uDSrffUukmdAMAAABAsgjdaSbRkuvMDKfuOn9ywseZK92GYcjhcETvM7eU+5PpXwcAAAAASKK9PO0kml6emdHxP3VkTbcU32JuHp5GezkAAAAAJI/QnWYShW6X02F7PCLbtI1YbOj2B9inGwAAAAC6g9CdZhJtGebqpNLtNle6Y9Z1m1vKaS8HAAAAgOQRutNMokycmdH9Sre5pZz2cgAAAABIHqE7CWVlZZo8ebKmT5/e15fSqURbhnXWXm6udJsnmAdDhiXI014OAAAAAMkjdCdh9uzZWrNmjZYuXdrXl9Kp7raXJ6p0x24R5k80Hh0AAAAAEIfQnWa6216eqNIdG7qpdAMAAABA8gjdaSbx9PLuVbpjQ3ZsCAcAAAAAJEboTjOJ9+nuZqU7FLtnN5VuAAAAAEgWoTvNJMrE3a10+2Mq3UwvBwAAAIDkEbrTTCjhILVOKt2Z9pXu2JAdG8IBAAAAAIkRutNM4vbyTirdrvZKt8efuNLNmm4AAAAASB6hO80k3DKss326TZVub6CD6eWs6QYAAACApBG600zCNd2dDlKzr3QzvRwAAAAAuo/QnYSysjJNnjxZ06dP7+tL6VT3twxLUOmOnV7Omm4AAAAASBqhOwmzZ8/WmjVrtHTp0r6+lE4lDN3drHT7A7Ht5VS6AQAAACBZhO40k6i9vNNBaqZK91/f3aSH39usUMiIW8PtDxoyE
gR7AAAAAICVq68vAD0r0ZZhTkfylW5Juu/1tTpkeGFce7kUHtbWWeUcAAAAAEClO+0kai/vZHi5pdId8emuets13EwwBwAAAIDkELrTTKItwzqtdGdmxB2ra/EpYDOtnAnmAAAAAJAcQneaSbTcurNKd65N6C5v8MhnE7CZYA4AAAAAySF0p5lE7eWOTirdg/KydNohQy3Hyus9tgHbbp03AAAAACAeoTvNBBOu6e588NnDX52m9fedo4Ls8Hy9igaP7RZhfirdAAAAAJAUQneaSTTjrLP2cklyOh1yuzI0vDBbkrSn3iOf3SA11nQDAAAAQFII3Wkm4ZZhyaTuNsOLwqHbGwiputEbd/+f3tmo255boSqb+wAAAAAA7dinO80kWtPdFcPaKt2StHNva9z9zy/bKUkaW5Kr75xx8D5/PgAAAABIV1S600x3twwzG24J3S0Jz9tT50n+wgAAAABgACJ0J6GsrEyTJ0/W9OnT+/pSOtXdLcPMhhV1XOmOaPT6k39SAAAAABiACN1JmD17ttasWaOlS5f29aV0KlF7eXcr3bvqOgjdnkDyFwYAAAAAAxChO80k2jKsC5nbEro70uQldAMAAABARwjdaSZxe3nXp5d3hko3AAAAAHSM0J1mEg9SS/45SvOylJnR+QOaCN0AAAAA0CFCd5pJtKbb0YVKt9Pp0NCCzqvdtJcDAAAAQMcI3WkmQaFbBdld25J9SIG703OavIGElXUAAAAAAKE77YRMIfjbsw5QlsupiUPydP7UkV16nuLczKTOa/ZR7QYAAACARLpW/kTKM7eXn3BAqb558gHKc2fIldG191cKs5ML3U2eQNLnAgAAAMBAQ+hOM+Ytw5wOqSjJinWsopwkQzfrugEAAAAgIdrL04x5jpqzKyPLYxTmJPd+TKPH3+3PAQAAAADpjtCdZsyDzbqyN3esZCvd7NUNAAAAAIkRutOMeU13xr5UupNd0017OQAAAAAkROhOM+b28n0odFPpBgAAAIAeQOhOMz3VXl6Y7CC1mNDtDQRlmJK/xx/U3NV7tLuutdvXAgAAAAD9FaE7zVjay/fHmu629vJQyNBv//u5pt77lr76yJJo8P7tfz/XTU8t1+V/W2R5QwAAAAAABgK2DEtCWVmZysrKFAwG+/pSOmUO3fvSXp7smu73NlRpa3WzPq9o1LryxrZj1dpR26qxpbn628LNkqSde1tV0eDRyOKc7l8UAAAAAPQzhO4kzJ49W7Nnz1ZDQ4OKior6+nI6FAq1f7wvg9SSrXR/sr1On2yvizve4o9f693iS/03LQAAAACgJ9FenmbMle59WdOdn71v78d4/SHL2m5JamBPbwAAAAADDKE7zQQtobv7z5PhdKigC8H7smNG65snT4ze9viD0fXeEQ2thG4AAAAAAwuhO82Yi8vOfUndSn5dtyQdNrJQblf7y8kTCKmywWs5p4HtxQAAAAAMMITuNNNTW4ZJ1nXdnT3VgUML5M7MiN72+oOqaowJ3VS6AQAAAAwwhO4001NbhklSYU7y7eUHDs2Pq3RXNcVWugndAAAAAAYWppenmZ7aMkyyVrqNDrbYzne7NKzQreyYSndsO3lDK+3lAAAAAAYWKt1ppqe2DJOSX9M9YXCeHA6HJXR7AiFVNnos51HpBgAAADDQELrTTE9tGSZJee7kGiGGF2VLkqW93G5NdyOD1AAAAAAMMITuNNOT7eW5WRmdnyRpWKFbkqzt5YFQXOiuafLq52+s1W/++7lCoQ761QEAAAAgTbCmO82Ys+y+tpcnG7pHFOVIkrIzTYPUbCrdH26q0YebaiRJhwwv0BeOGLFP1wcAAAAAqY5Kd5rpyS3DcrISvydz4ZEjJUl5WRm6YvoYSZLbZVrTbRO6zeauLt+nawMAAACA/oBKd5qxrunet+fK66DSfd0JE3TBkaM0elCOBudH2svb38Np8gZV2+JL+PiubEcGAAAAAP0VySfNmLf22vdKd+LQne926
cgxxZZj5jXdu+taO9xmrCDJyegAAAAA0J/RXp5merK9fEiBO+F9duu9zdPLd+5t6fC5Qx0lcgAAAABIE4TuNGNpL9/Hf93jJpbquImlysnM0OM3zLDcZxe6zZXuHXtbO3zuFm9w3y4OAAAAAPoB2svTTE/u0+1wOPT0jTPlDYQsgVqybz3PNg1S8wVCHT53s5c9uwEAAACkPyrdacayZdi+btStcPCODdySlJUR/9JxZ9q/nIpz49dvN/sI3QAAAADSH6E7zZjXdPdA5k7IYfPk5jXdZpF9vM2aaS8HAAAAMAAQutOM0dZe7nTYB+N9cWXbftznHTHC9n6Hw2EbvEcUZccdo9INAAAAYCBgTXeaiRS6e6PK/bOLjtA1x43TpGEFCc9xu5zyxqzntg3drOkGAAAAMABQ6U4zkfbyfR2iZsfpdOiwkUVy2aznjrBb/z2ymPZyAAAAAAMToTsJZWVlmjx5sqZPn97Xl9KpkKm9vC/YDVMbWUx7OQAAAICBidCdhNmzZ2vNmjVaunRpX19Kp9pDd9+kbvO2YRH2g9QI3QAAAADSH6E7zUTWdPdVpTu2vdzhkIYWuOPO8weNTvfyBgAAAID+jtCdZvq60h07vTw/y6WSvCzbc6l2AwAAAEh3hO40E2ordfdR5o6rdOe5XSrOzdKPzjtUx04s0fjS3Oh9idZ117f49X8vfarHP9zam5cKAAAAAL2OLcPSTKS9PKOv1nTHDFLLc4dD+NdPmqivnzRR//vCKm2taZGUeIL5N5/6WIs310qSZk0aonGleb14xQAAAADQewjdaeba48errtmrgLe1Tz6/O2aQWn52puV2rrv9frtK9/aalmjglqRdda2EbgAAAAD9FqE7zXztxAkKhUKqrKzsk88fu2VYvjsmhLvbX3J2a7r/tnCT5XYL+3kDAAAA6MdY040eFbemO8v6vk5uVuLQXdfi0/PLdlqONTFsDQAAAEA/RuhGj4qbXu52xdw2tZfHVLE3VDbFbSMWG7oNw1Blo6cnLhUAAAAAeh2hGz0qttKdn91BpTtmTbddu3nssZueWqYZ98/T397dFHcuAAAAAKQaQjd6VLYrfsuwRLdjK92tvvj12+ZKdzBk6M3PKiRJP//Pun2+VgAAAADobYRu9Kj4QWqxodvcXh5T6e4kdHv86TFUzTAMrd5Vr5YE+5QDAAAASB+EbvSo7Jg13XlZiSvfse3lrTYh1BzMvTHrvfurPy/YpC/+8X1d8pdFCkU2VgcAAACQlgjd6FHuuDXd1n26O9oyrLNKtzdgvT926Fp/8as310uS1u5p0Pbalj6+GgAAAAC9idCNHpXdyT7duabKd2zIbrEN3e3HvH5ryK5r8XX7OlNFg8ff15cAAAAAoBcRutGjOhuk1lGl2669vMkUSmPby/e29P/A2tDKum4AAAAgnRG60aM6G6Rm3jKsJWZ6uV17uXnCeewgtb1pUOlupNINAAAApDVCN3pUbKU7NnRnuZzKygi/7JriKt2drelOv/by+lZCNwAAAJDOCN3oUbGD1GLbyyUpt22dd+z0cnO7eWTtd0eD1Ozay+etrdDpDy7Qw+9t7uKV9410aJEHAAAAkBihGz3KHbtlmE3ozmtrMW+OaS9vNbWPDylwt50TkGGEt9WKHaRm117+tcc/1qaqZt33+tro41JZOrTIAwAAAEisR0P35s2btXbt2p58SvQz2bFbhtmF7kil2xSoJev08qFtoTsQMqJt5Z6YSnddJ1XiVn98u3oqcDjaP97bTOgGAAAA0lm3Qvcf/vAHXXnllZZj119/vQ466CAdfvjhmjZtmiorK3vkAtG/xG4ZluF0xJ0zrDBbUjgUf7KjLno80l6e5XKq0LS/d6TFPK7S3Ulgja2kp4oC0xsRtJcDAAAA6a1bofvhhx/WsGHDorfffPNNPf744/rGN76hP/7xj9q8ebPuvffeHrtI9B/umEFqdr40dWT046cWb4t+HKlM52VlKD87fmuxr
m4Z1mKzBVlPavEFFAiGOj8xRoHpDYV0GAYHAAAAILFuhe5t27bp0EMPjd5+7rnnNGHCBP3lL3/Rt7/9bd1888164403euwi0X/EVrrtnD91pIpywsFzzqo90Yp1pDKdm+WyrAVv9ERCd2x7eceBNXY6ek9avate0+97W6c+uKDL4d5c/WdNNwAAAJDeuhW6YwdUvfXWWzr33HOjt8ePH6/y8vJ9uzL0S7FruhOdc+kxoyVJvkBI/16+U5LU2hZec7IyLC3YiSvdHQfWFpstyHrKm5+Vq9kX1I7aVi3ZUtulxwZD7T8/na1LBwAAANC/dSt0H3zwwXrppZckhVvLd+/ebQndO3fuVHFxcY9cIPoXl9OhI0YVSZKuO358wvMumzY6+vEn2+tkGIZaTO3l5kp3pGLt8XdtkFpzL1a6zVV0u/3FO2IJ3a3+fjFlHQAAAED3xI+WTsIdd9yhq6++WoMGDVJzc7MOPfRQnX322dH733nnHR155JE9dY3oRxwOh/5540yt3FGnmRNKE543ZlBu9OOaZq88/pAi2TMnQeiOrXRHAqvDET+sTerdQWrmoN3VKekBU+gOhgw1eALRdnsAAAAA6aVbofvKK69UaWmp3njjDRUXF+vb3/62XK7wU9XW1qqkpETXXHNNj14o+o/C7EyddNCQDs/JzcpQdqZTHn9Itc0+y7ro3CxXTHt5ONTGTi+PDayhkLVi3NyLg9TMretdbWMPxVS261p8hG4AAAAgTXUrdEvSmWeeqTPPPDPueElJiV588cV9uiikP4fDodI8t3bVtbaF7vbgmhtX6Q63kccOUpOsgdUXM0m8N9vLzdcb2/bemdiJ57XNPo0rzeuR6wIAAACQWrq1pttOS0uLHnnkEf3lL3/Rtm3bOn8ABrySvCxJ4dBpXiOdG7NlWFOk0h2I357LvG1YbOjuagX6zws26rQHF+iddRWdntvq7/6a7piCPMPUAAAAgDTWrUr31772NX300UdavXq1JMnn8+nYY4+N3i4qKtI777yjo446queuFGmnND8cukOGtKe+NXo8N8ulfHf7FPQmj/0gNUnR7cak8CR0s65Uuhs9fj0wd70k6ZtPLtOG+7/Q4fkt+7Smu2tT2AEAAAD0X92qdM+fP18XX3xx9PbTTz+t1atX65///KdWr16t4cOH69577+2xi0R6ilS6JWnnXnPotraXJ9oyTLIG1n0J3Z/tboh+7A92Pk28dV/WdMd8GXupdAMAAABpq1uhu7y8XOPHj4/efvnllzVt2jRdddVVmjx5sm688UZ99NFHPXWNPWrOnDmaNGmSDjroID388MN9fTkDWmkHoTvfvKbblzh0VzZ6ox/Hhe4uhOFPd9Ynfa60j2u6Y1J3HZVuAAAAIG11K3Tn5eWprq5OkhQIBLRgwQLLlmEFBQWqr+9aiNkfAoGAbrvtNr3zzjv65JNP9Ktf/Uo1NTV9fVkDVkmeO/rxjtqW6Mfh9nJT6G5rL/fahNttNe2Pi1/TnXyl+9Nd1terPxgf8K3P3b32csMw4tZ0014OAAAApK9uhe6jjz5aDz30kD755BPdf//9amxs1Pnnnx+9f9OmTRo2bFiPXWRPWbJkiQ477DCNGjVK+fn5Ovfcc/XWW2/19WUNWJE13ZK0Y685dCffXr61ujn6cWylu6kL+3THhu7Ohpu1+ro3SC0Ym7hFezkAAACQzroVuu+//35VVlZq2rRpuvfee3XJJZdoxowZ0ftfeuklnXDCCT12kRELFy7U+eefr5EjR8rhcOjll1+OO6esrEzjx49Xdna2Zs6cqSVLlkTv2717t0aNGhW9PWrUKO3atavHrxPJMbeX76htby/PycpQZoZTblf45RmZbB5p43a7nBqUG94mbFuNKXTHVrqTXNPd4PFriym8Sx1Xnw3DUIu/e5XuoBEfuiOVfAAAAADpp1uhe9q0aVq3bp1efPFFzZ8/X88991z0vrq6On3729/WHXfc0WMXGdHc3KypU
6eqrKzM9v5nn31Wt912m+6++24tX75cU6dO1dlnn63KysoevxbsO/MgtfrW9mpvXla4yl3cFqxr2yaURyrZbpczuq/17npPNIzHV7qTC7Ord8UvhahtThy6vYGQzNl5XyvdsdcNAAAAIH10a8swSRoyZIguuOCCuOPFxcW69dZb9+miEjn33HN17rnnJrz/N7/5jW688UZdf/31kqS//vWvev311/XII4/oBz/4gUaOHGmpbO/atctSoY/l9Xrl9bYP6mpoCE+4DoVCCsWOoE4hoVAovHY4ha9RUrRaHSs706lQKKShBdmqaPCquskrnz8gTyDYdn+GxpXmasWOOknStuomHTSsQF6/NWS3+AJJfQ9WtT2PWW2TN+FjmzzWdvBWfzDp77U/EB/QfYHkH98f9ZfXIwYGXo9IJbwekWp4TSKV9IfXY7LX1u3QLUnvvvuuXn/9dW3btk2SNG7cOH3xi1/UySefvC9P2y0+n0/Lli3TnXfeGT3mdDp1xhlnaNGiRZKkGTNmaPXq1dq1a5eKior0n//8Rz/+8Y8TPufPf/5z263Pqqqq5PF4ev6L6CGhUEj19fUyDENOZ7eaGfaLUIIKsaepQZWVQQ1qm7MWMqR1W3dH11G7HIaGZLdXjFdu3qMiR6sqa+osz9Po8Xfa5bClplV/WbAx7vj2ihpVDrX/3u1p8FpuN7Z6k+6mqLdpJW/2+NK6G6O/vB4xMPB6RCrh9YhUw2sSqaQ/vB4bGxuTOq9bodvn8+mqq67Syy+/LMMwVFxcLCncWv7ggw/qoosu0jPPPKPMTPtKZm+orq5WMBiMG+A2bNgwrVu3TpLkcrn04IMP6tRTT1UoFNL3v/99lZaWJnzOO++8U7fddlv0dkNDg8aMGaMhQ4aosLCwd76QHhAKheRwODRkyJCUfYFK4bXRbpczbkDa6OGDNXRogcYOqZI2h1u/A1n5iiydznVnavLYIdLiPZKkukCmhg4dqtxK6/O0+kMaOnRows/f6gvqu4+8q3pPfPj3Z2QnfGy9Yf3h8occHX4es4wmb9yxkMOZ9OP7o/7yesTAwOsRqYTXI1INr0mkkv7weszOzk7qvG6F7nvvvVcvvfSS7rjjDt1+++3RoFtZWakHH3xQv/rVr/STn/xEP/3pT7vz9L3qS1/6kr70pS8lda7b7Zbb7Y477nQ6U/YfPsLhcPSL6yzNy9LuemvXQJ47U06nUyOKc6LHKhu98ra1ZrszMzRhcH70vm21LXI6nQrErJf2+EMy5FCG02H7uRdtrlJ5W9X68FGFuv3MSbr+saWSwtPLzd+7zVVNynI5NXpQrjwB6+dp9QeT/j4bir8WXyCU8v9O+6q/vB4xMPB6RCrh9YhUw2sSqSTVX4/JXle3rv7pp5/WtddeqwceeMBSWR46dKh++ctf6qtf/aqefPLJ7jx1tw0ePFgZGRmqqKiwHK+oqNDw4cP367UgeSWmbcMiItuFDStsf+dod51H/mA47LpdTo1vG6Qmte/VbbelWEd7ddc0t1edr54xTmNKck3P2awnF23V5xWNWrmjTqc9+K5O+dUC7ahtsezRLXVtkFrsGwMSg9QAAACAdNat0L1nzx7NnDkz4f0zZ85UeXl5ty+qO7KysnTMMcdo3rx50WOhUEjz5s3Tcccdt1+vBckrzYvvJMjNypAkDTeF7u217ft4u10ZKs7NVGF2OJxvbds2zC68NnewV3dtc/tAtJK8TMs09bfXVurHr3ymLz/8kT7cVCMpPHl86dZatcYMbPMGQrZTye3YTi8PEroBAACAdNWt0D169GgtWLAg4f3vvvuuRo8e3d1rSqipqUkrVqzQihUrJElbtmzRihUrtH37dknSbbfdpoceekiPP/641q5dq29961tqbm6OTjNH6jHv1S1JDoei+3MPL2oP5Ob9uLMznXI4HJowOFzt3lXXKm8gaB+6O6h0m/fiHpSbpaKcTDliur+rGr2qamyviLf6g3GVbql9D/HO2
IVuuwo9AAAAgPTQrTXd1157re6++24VFxfru9/9rg488EA5HA5t2LBBv/vd7/T888/bTv3eVx9//LFOPfXU6O3IkLNrr71Wjz32mK644gpVVVXprrvuUnl5uY488kjNnTs3brhaV5WVlamsrEzBYPJtxEhOSUzozstyydGWfM3t5dtiKt2SNKQgW1K9DENqaA3Ib1Mxbu5gr27zXtyD8rKU4XSoOCdTe1usW4I1edtvt/qCysyIf6+q1R+MtsV3JGjQXg4AAAAMJN0K3T/84Q+1adMm/f3vf9dDDz0UXUAe2Uvt2muv1Q9/+MMevVBJmjVrlgyb0GJ288036+abb+7Rzzt79mzNnj1bDQ0NKioq6tHnHuimTyjRw+9vid4+bGT7VPiC7EzlZWWo2RfU9hpT6M4Mv97y3RnRY83eQJfby+tiKt2R/48P3e3BPWHoTnJdd6JKt2EY0TcbAAAAAKSPboXujIwMPfbYY7rtttv0xhtvWPbp/sIXvqApU6b06EUifZ01eZie++ZxWrunQZJ07uHWoXfDirK1uarZMoAs0n6ea6osN/sCtmujOxqkZq50F+eGt7cblJclVTdbzms07a3d6g/KlaDSnYxEa7/9QUNZLkI3AAAAkG66FbojpkyZYhuw33jjDb388sv6+9//vi9PjwHA4XBoxoQSzZhQYnv/8MJw6DaLtJfnm0O3135Nd1MH7eWRinZhtitavY5UvM0aTKG7xReUy2YLsn2pdEvhYWpZrtTcCgEAAABA9/XKX/mffPKJ/vGPf/TGU2OAMa/rjshuay+PTDmXwu3l9luGdTS9PFzpNq8rL8nLjDuv0dPebu5JMEhtXyvd5jcMQiFD72+ojlb/AQAAAPRf+1TpBnqbXei2rXT7ujZILRAMqb41HKYHmUJ3YbZd6La2lzvtKt1Jhm67fbola+h+a025bnpquVxOh97/39M0vCj+ewAAAACgf6CfFSlteGH8Pt6RNd15lvbyrg1Sq2s17dFtaik3H48wV7pbfEHbVvJk28tDCQYBegPtj1+5s15SOKCvLafaDQAAAPRnhO4klJWVafLkyZo+fXpfX8qAY1flddu0lzd5g10apLY3Zruw6Oezqax7/CHTx0Hb50w2dAeCnVe6zXt++9lODAAAAOjXCN1JmD17ttasWaOlS5f29aUMOOMH58Uds2svb0lQ6U40SM08udy8pvv6E8ZrZAft3K2+fVvTnbjSHbL92O6NBAAAAAD9R9Jrur/0pS8l/aQbN27s1sUAsSYMzlOG02EZQNY+SK395dvksw/dC9ZXqdUXVI6pKi5Je1vitwuTpNJ8txZ+/1Q9tXib7nltTdzzJRrMlnSlu4Pp5RGWSjehGwAAAOjXkg7dq1atksOR/D7CY8eO7dYFAWZuV4bGl+Zqk2nbMPtKt7W9/MgxxVqxo0676lp11yurdeohQ3XCgYNVlBMO2LXN9mu6JcmV4bQEejNPgop20pXuBKHb67evdPsD9ucDAAAA6B+SDt1bt27txcsAEjt4WEFM6I4MUrNuGWaudP/84iP0pT+9L3/Q0PPLdur5ZTt18sFD9MQNMyRZK93mNd0RifbMbvUHZReDuzO93OV0RG+b3zDwmp6L9nIAAACgf2NNN1LeQcMKLLcjg9TM08ubvAFLQJ00rEDfPPkAy+MWbaqW0bamem+CNd3Rz5EgdLf49m2QmrlN3tzybh2kZqp0E7oBAACAfo3QjZR38LB8y+1Ie7k5dLf4gtHg6nI65HQ69N0zD9bPLjoieo4/aEQHqNWaK925Xat02w5S607ozrQP3ebtw+zWqQMAAADoPwjdSHkHDY2pdLcF4txM85Zh7e3lkcCc4XTo6pljdcW0MdHzyhs8kpKpdGfEHZPCIbi5bSJ6YXZ76E+2vTxoml5u3vLMF2x/PJVuAAAAIH0QupPAPt19a0LMtmHZbWHb6XREg2uzNxANqLFV6mGmLcAq2kJ3bUt4kJrDoehwNbNElW5Ji
hSrS/Pd0WNJh+5Qe4jONr1p4PXbTy/3JdjXGwAAAED/QOhOAvt0963YAJzhbJ+iH2kxN7eXZ2VYzx9e2B66y+u9ktor3cU5mZbni0i0ptvMXCFPNNU8lrlwba10J5heTqUbAAAA6NcI3egXzK3c5q3C8tqCq3mQWmxIH17UXpGObS+3m1xu9xx2zKE70f7dscyVbvO2ZNZBaqZ9ulnTDQAAAPRrSYfuyspK+Xy+zk+UVFVVpYULF3b7ooBYz37zOB05pljfmnWAxpTkRo9HKt3N3kC0Qhxb6R5mqnRX1HvkC4TU2LYuO3aP7ohkKt0Fblf0cy3btld/emdDdL13IubCtXl6uTdB6GbLMAAAAKB/Szp0jxgxQi+88EL0dn19vSZPnqyPPvoo7ty33npLp556as9cISDp0BGFenn2Cfrfcw6xHM9rqxYHQkY08MZVus3t5Q0e1ZkmlxcnCN3JVLqLc7PkNJ3267c+15OLt3X4GHOlO/H08s7by/3BkBo9/k6vEQAAAEDfSjp0G4Z1oFMgENC6devU3Nzc4xcFJCvP3R5cIwPOYgNzSV5WtCJd0eCxbBdWkhc/RE1KPL3crDQ/S06HdT34uj0N2lTVpD8v2Kide1viHmPeMizXptJtGIYldPsC1p87wzD0wrKdmnbf25px/zx9urO+0+sEAAAA0HdY041+zbxXd0Rse7nD4dDQwvC67vIGj/Y2t1eI93VN99FjB1mOVTV5ddlfF+mBuet1+3Mr4x4TMO/TnRVf6fbGrOGOrXTf+9oa3fH8StW3+tXqD2rOqt2dXicAAACAvkPoRr+Wl2UTum0Cc6TFvK7Frz31rdHj+7KmuyQvSw9cOkU//EJ7y/uuva2qbRvS9tGW2rjHhDrZp9u8dZgUH7pfXL7TcntbTXw1HQAAAEDqIHQngX26U5ddpTszI/5lbd6re115Y/TjhJVum+eIVZqXpZHFOfrGyQdoTEmOJGlrJyHYUum2WdPtCVinoPtiKt+emFC+tYblHQAAAEAqi08sHWhublZtbbh6F/n/xsbG6McRTU1NPXR5qWH27NmaPXu2GhoaVFRU1NeXAxPzmu6IjirdkrR2T0P040SVbqfTocwMh/xBw/Z+yRrYhxZka0dta8JzI0KW9vL2H79IW3nsft/m6eWhkBE3zXxbTYsMw5DDEb/XOAAAAIC+16XQfdNNN+mmm26yHLv44ovjziMEYH+xXdPdhdCdqNIthYep+YOJtwArtYRud8LzzDqrdHe0pttu+7BWf1CVjV7LtmgAAAAAUkfSofvuu+/uzesAusUudLs7aS+vbjJPL08curNcTslrf1+G06HC7PbJ50MShO5QyJDT6bDcjsi1GaQWW+k2V9pj13tHbK1uJnQDAAAAKYrQjX4tL6vr7eVmidrLpY6HqQ3KzbKE6USVbm8gZJlSnnB6ebDzSrc3Zr13xLaaFs2cWJrwWgEAAAD0HQapoV9LdpDauNLcuGNOh1SQnfh9p462DSuNqZAPLbAP9bGV62CC6eUJ13QHQnHnSFKh6boZpgYAAACkrqRDd3l5uRYuXBg3JM3v9+uuu+7SAQccoNzcXB199NF69dVXe/xCATvJbhk2tMAdVxWPrVbH6rDSnZdpuT2k0L7SHTuNPGhqF8/OtAvd1kq3L2gfuicNL4h+zLZhAAAAQOpKOnT/4he/0GWXXaasLGuF7/bbb9f999+vvXv36rDDDtP69et1ySWXaOHChT1+sUCsZKeXOxwOTRiSZznW0RC1RM8TUZpnDdlD8hOE7pgQba50Zzgd0a3J2gepxa7ptm8vnzg4XxltbxhQ6QYAAABSV9Kh+91339X5559vCd1VVVX685//rEMPPVSbN2/W0qVLtWbNGg0ZMkQPPvhgr1wwYJZvN708wR7bEwbnW253tJ5bCk8vTyR2ANvQBJXu2BAdNK3pdjkd0Wq6r+282JDuD5gGqZkq3Xlul0YPCu8NHtk2DAAAAEDqSTp079ixQ
4cddpjl2Jw5cxQKhXTHHXeouLhYkjRu3Dhdf/31+uijj3r0QgE7uUluGSZJEwbHVrozbc+LPk+C8C7Fh+7SPLfsOtXjKt2m0O10OKLXGmkjj59ebqp0m57LnenUuNLw19PkDaim2ScAAAAAqSfp0O3xeJSfb60Uvvfee3I4HDr99NMtxw844ADt3bu3Z64Q6EBJbpZcMWk3UVieGBO6O9ouTAoH20RK862PzXA6NNimxTxukJq50p1hCt0J9um2DlJrfy63y6lhponptYRuAAAAICUlHbonTJigFStWWI7Nnz9f48aN05gxYyzHm5qaVFJS0iMXmArKyso0efJkTZ8+va8vBTFysjJ01mHDLMcSVbonxq7p7qS9vCuVbsm+xbyj0J1hqnQnnF6eYJBalsupAtM+4Y0ef8JrBQAAANB3kg7dF198sR5//HE9++yz2rFjh+6//35t27ZNl19+edy5ixcv1sSJE3v0QvvS7NmztWbNGi1durSvLwU2vjxznOV2otA9vsuV7g7WdNsEdrtharGVa0vothuk1lF7uem53K4MFea0t9Y3tAYSXisAAACAvpN4k+IY3//+9/Xaa6/pqquuksPhkGEYmjRpkv7v//7Pcl5NTY1effVVfe973+vxiwXsHH9AqeV2s9c+gBZmW9dw71OlOz/+scMK4/fq7mif7gynI9rCnqi9PGSEg3qG02EJ5G6XUzJ9PQ1UugEAAICUlHTozsvL05IlS/TSSy9p8+bNGjdunC688EJlZ1uDxq5du3Tvvffq0ksv7fGLBew4HA7dctqB+uM7GyVJk0cWJvW4zA62BJM6XtNtVyW/6KhRenXlbrX42sOxN2aQWiBBpTsQMhQMGXEhXQpXuzOcGTGVbqelot/oodINAAAApKKkQ7ckuVwuXXbZZR2eM2XKFE2ZMmWfLgroqu+ccbCyMzPkdjl17ITShOeNLcnV9toWSdLgzvbp7qDSbVclnzmxVMt+dKZe/3SP7nh+pSTJE7NlWCg2dJuCsy8Qipt2LoWr39mZGZahau7MDGVltA+QI3QDAAAAqSnpNd1AKstwOjT71AP19ZMmymm3d1ebv3/1GI0vzdW5hw/XcQckDudSfKX7uuPHy+V06KoZY5SZIJDnZGUox7QWPLZyHVfpNu0F7guE4vb1ltrXdcdWugtpLwcAAABSXtKV7i996UtdemKHw6FXXnmlyxcE9KZDhhdqwfdOTepcd0yw/tqJE/SDcw9RdgcD1iQp2xTWN1Y26dZ/faIZE0r05ZnjrJVuh8NSTfcGg7aV7vbQbV3TzfRyAAAAIPUlHbrnzJmj7OxsDR8+XIZpGFQiDkfiaiPQH8ROL3e7nJ0GbkmWc577eKckac6qPTrnsOGWSrfL6bRU0xNWugPhx8ROLy/Ibv/xpb0cAAAASE1Jh+5Ro0Zp165dGjx4sK6++mpdeeWVGj58eG9eG9CnYtd0J9qKLJbb5rxgyFBlo1ch0xtWTqe1mp5oTXdkr27zULbwPt2EbgAAACDVJb2me8eOHZo/f76OOuoo/fSnP9WYMWN0xhln6NFHH1VjY2NvXiPQJ2JDttvVeZVbUsJqeJM3oEDQWuk2fw5vIBQ3eE0ybyeWuL28oZX2cgAAACAVdWmQ2imnnKK//e1vKi8v1wsvvKDS0lLdfPPNGjp0qC6++GK98MIL8nq9vXWtwH4VW7FOttKdnWCrsSZPwLJPt9OpuOnlsVuMSfaD1LIzw4E98rmodAMAAACpqVvTyzMzM3XBBRfo2WefVUVFRTSIX3HFFXrggQd6+hqBPmEOxBlOhzI6mIpulqgi3ugNKBizptvcwu4L2le67aeXhz9HpNrNIDUAAAAgNe3TlmFer1dvvvmmXnnlFX3yySfKzs7W+PHje+jSUkdZWZkmT56s6dOn9/WlYD8yh2e7ddoJH5eg0t0cE7qdDsUPUutwTbe1vVxSdF03lW4AAAAgNXU5dIdCIb355pu67rrrNGzYMF111
VVqbW3VQw89pMrKSl1zzTW9cZ19avbs2VqzZo2WLl3a15eC/chc6U62tVzqYE23pz10ZzgdcjgcyspoP9cbCCbYpzv8mEj4lmwq3TGBHgAAAEBqSHp6+Ycffqinn35azz//vGpqanTsscfqZz/7mS6//HINHjy4N68R6BPm6nZXKt3ZSbSXZ7RtqRe7ptt2enkgfnp5pEJeaJpg3uQNqCgnUwAAAABSR9Kh+8QTT1ROTo6+8IUv6Kqrroq2kW/fvl3bt2+3fczRRx/dIxcJ9IXuVrozMxxyOqTYwnNspVuyDl1r9Qfl8Xe0prv9vsha8ELTBPNGj5/QDQAAAKSYpEO3JLW2turf//63XnzxxQ7PMwxDDodDwWB8gAD6C3N1O3bP7o44HA5lZ2aoxWd9/Td5/dHp5ZHQbQ7JNU0+BWxaxGMHqWVlOOVsezx7dQMAAACpLenQ/eijj/bmdQApJ8vSXp7cHt3t5zttQnd8pXtQblb0/ooGj+1zte/THYq7LnPoZq9uAAAAIPUkHbqvvfba3rwOIOWYg3ZX2sulyDA1awhutGkvH5TXHrr31NuH7sggtUh7ubkCb20vp9INAAAApJp92jIMSGfubq7pluwnmNtXuttDc3mC0O1rC9uRQWruBJXuRi+VbgAAACDVELqBBLo7vTzR+ZZBam3Ty4tz7Svd5se3V7rbQrcp0BdQ6QYAAABSGqEbSKAoN1PFbZXoCYPzuvTYRJXuQEyluzDbFf14V11r9NwRRdnRj30x08sTVbpZ0w0AAACkni5NLwcGErcrQw9/dZoWb67RVTPGdvGx9pXuzLbjkaDtcDhUnJOpmmaf5dxRg3K0taZFUnh6uWEY0YFqljXdOVS6AQAAgFRG6AY6MG18iaaNL+ny42wr3b6A8h3hHzlXW+iWpOJcm9BdnBP92B8MKRAyovt+mwe8WSrdhG4AAAAg5dBeDvSC7Mz4Hy3DCLeYS4rusy1Ztw2LGFWcG/3YFwhF13NLkjvTfnr5u+srtb68cd8uHAAAAECPInQDvcCu0i2Fg7cUW+m2Cd2DzJVuQ15/+57fidZ076736JzfL9SybXu7fd0AAAAAehahG+gFnU07dzraQ3dJXmbc/eb2cl/QWunOsoTuTMvnMgxp2bbabl0zAAAAgJ5H6AZ6QaJKd4QrI3F7udvl1OD89mP+2PZy05ruDKdDd557iOXxtc1MMQcAAABSBaEb6AWdhW5zpTu2vXxoodtSzfYHQ9HtwqT4Kvp1J0zQf249KXp7b8xQNgAAAAB9h9CdhLKyMk2ePFnTp0/v60tBP5HdSXu5yzJIzdpePrQgW5kZ7Y/3BUPy+s2V7vjnLslrD+61Lfah2x8M6clFWzXrV/N19UOL5TGtEwcAAADQOwjdSZg9e7bWrFmjpUuX9vWloJ9wd1bp7mCQ2tACtzV0B4yY6eXxz21uUberdAdDhi776yL9+JXPtLWmRR9uqtGrK3d3/oUAAAAA2CeEbqAXBCObaifQcaXbrawMa3u5L9BxpTvL5VSBOzzJ3K7Sva2mWSt21FmOvb5qT4fXCAAAAGDfEbqBXuAPhjq8P8McuvNi13RnK9PVfn9na7pjn8eu0t3ii28l/2BjNeu/AQAAgF5G6AZ6ga8Lobs4ptI9JKa93B9MPL3cLBK661r9cZV2u/XbgZChuZ+Vd3idAAAAAPYNoRvoBeZ2cDsZjsRbhg0tcMvldChyii8QU+nOtP+xLWkL74Yh1cW0mHtMg9hOPnhI9GNazAEAAIDeRegGesHgfHf0Y7t2cHOlOzOjfT22FJ5e7nA4otVuX9CwTC83r/c2M7ep740L3e2hffq4QRpWGL6+1bvrk/p6AAAAAHQPoRvoBdcdP16jB+WowO3So9fFbzVnDt2SVJzX3mI+tC0QR8J1XHt5wkq3aduwZr/lPo+pUp6TlaGSvPDnsFvrDQAAAKDnuDo/BUBX5bldWnDHLAVChhyO+PtjQ/eIohztqG1VXlZGN
DxnZoTPiR+kZr+muyTfHLoTt5e7MzOUmxV+Dl8gpEAwJFeC6jkAAACAfcNf2kAvcWU4lZ2ZIbcrQ3lZ1qAcG7pvP/NgHTexVPdecHh0D+9Ie7k/ELK0lyeaXm6udMe2l7ea2suzXc5o6JakFpshawAAAAB6BpVuYD8YUuBWc01L9HZs6J45sVTPfKPUciwSuj2Brk0vl+Ir3V5z6DZVuiWp1RdUYbZ1gjoAAACAnkGlG9gPhhZkW25n2PWcx4gMO6tt9qmutT1EJ1zT3UHo9sSF7vb321jXDQAAAPQeQjewHwwpcFtuuzI6D90TBudHP169qyH6caL2cvPWY/94f4uu+cdH+qxtOrl5TXd2plM5pkp3szfQ6bUAAAAA6B5CN7AfxIZuZxKV7olD8qIfr9xZF/14zKBc2/PNlW5Jem9Dtf48f5Mkm0p3pqm9nDXdAAAAQK8hdAP7QVyl25lE6B7cHroNI/z/wwuzLWu3zYpyMuMmpb/+6R5J1i3Dsl0ZynXTXg4AAADsD4RuYD+Iq3QnE7qH5McdO2REQcLzM5yOuIFog/PDnze2vdw6SI32cgAAAKC3ELqB/WBoNyrd40pz4yrXh44o7PAx9a1+y+3I+u/4QWqmLcOodAMAAAC9htAN7AfdqXRnZ2ZoVHGO5dghwxNXuu00esIh3Fzpdmc6lWNa091M6AYAAAB6DaEb2A+6s6Zbim8xn9xJpfvbsw6w3G7yBmQYhryBxFuGRdrLH/9wq05+YL5eWbErqWsDAAAA0DlCN7AflOZZQ3cy+3RL1mFqWS6nJphu27n9rEl6+saZ0Yp4yAhXss3t5TmZGcp1x7eX3/3qZ9pe26Jb/7UiqWsDAAAA0DlCN7AfZMRUtjOcyf3ombcNO3hYvlwZHT8uw+nQ8QcM1vjS9sc1evzR9vIMp0OZGU7rlmG0lwMAAAC9htAN9IEku8stle1DhnfcWm5WkN3ePt7oCUQr3dltg9XM7eUtvqCCIcPyeMOw3gYAAADQPYRuoA80epPbpmvGhBIdNrJQ+W6XrpoxNunnLzBtHdbo8Uf36c5uq3DnZJkHqQXkC4Qsj2/1U/0GAAAAeoKr81NQVlamsrIyBYMEEXRfXlZGdFJ4TZMvqce4XRmac8uJ8gVDcrsyOn9AG3Olu8ETiLaXR0K3dZ9u65pvSWpoDViq4QAAAAC6h0p3EmbPnq01a9Zo6dKlfX0p6MdK8rOiH9c2e5N+nMPh6FLglhK3l7szwz/yeTHt5d6YSnfsft8AAAAAuofQDewnJaYJ5nW9HGoLY9rLvZFKtyu+vdy20u0hdAMAAAA9gdAN7Cc/Pu/Q6MffO3tSr34uc6W7rsUvXzDSXh7+kc9yOaN7hbf4A9E13xENVLoBAACAHsGiTWA/mTa+RE9+bYb8wZCOm1jaq5/LPEituqm9lT3btFVYTlaGGj0BtXiD0Up4BO3lAAAAQM8gdAP70UkHDdkvn8dc6a5qtA/duZHQbTtIjdANAAAA9ATay4E0lDh0t//IR4aptfgCNoPUktvSDAAAAEDHCN1AGjK3l1eZ28td1vZyKbwnN4PUAAAAgN5B6AbSUMJKd5a1vVyS/EFDTV5rZZs13QAAAEDPIHQDacjtciozIzydvNHTHqitle72YL63xRqyk13TvWJHnb7++Md6beXufblcAAAAIG0xSA1IQw6HQwXZmapt9lmOm9d055qGqu2NOS/Z9vIH5q7Th5tq9PbaCk0dXayxpbn7cNUAAABA+qHSDaQpc4t5hGV6udsUulusoTvZQWofbqqJfvzbtz/v6iUCAAAAaY/QDaQp+9BtqnRnJQ7dybaXD8ptH9j28opdWrunoauXCQAAAKQ1QjeQpgrcmXHHrPt0t4fy2Db0ZEO3K6P9V4hhSM8u3dHVywQAAADSGqEbSFO2lW7zIDXLmm5ryG70BhQMGbbPW9fi00ebaxQKGXFbjdXEhHcAAABgoGOQGpCmzHt1R7iTbC+XpEaPX8W5W
ZZjwZChC8s+0NaaFt1+5sHy+kOW+70xIRwAAAAY6Kh0A2mq80Fq5i3D4kN3g80wtapGr7bWtEiSPtpSK18wJnQHQnGPAQAAAAYyQjeQpgo7C92mj/3B+Fbyept13U3e9iBuF9S9ASrdAAAAgBmhG0hThTk2g9Rc9u3lduz26m42h26b9dtUugEAAAArQjeQpg4aVhB3zFzpzuksdNtUupt97aG71q7S7Sd0AwAAAGaEbiBNHT22WE6H9Zg5aOe5O56jaNde3uxtbx/32ARs2ssBAAAAK0I3kKYKsjN1yPBCyzHzlmF2g9bM6mxDd/xwNTPaywEAAAArQjeQxqaPH2S5nW3aMqzQZksxs892N8QdayJ0AwAAAF1C6AbS2LTxJZbbbtOa7iK7QWuZTuW1taAv2lQtw7BONe+00s0+3QAAAIAFoRtIY9M6qHTnZmXIFbPoO9/t0rETSyVJ1U0+ra9otNxPezkAAADQNYRuII2NKMqx3M7KaP+RdzgccdVutytDxx84OHr7g401lvubvB1Xsr2BUFx1HAAAABjICN1Amvv5xUcoO9OpG06YIIfDWtmODd3ZmU6dcGBp9PaHG6st97f4Oq50S5IvSLUbAAAAiOh4fDGAfu+qGWN1+bQxyojdP0xSoU2le9KwAg3Oz1J1k08fbamVPxhSZluFvLNBalK42u12dbwHOAAAADBQUOkGBgC7wC3Fh+7sTKccDodmtq3rbvIGtLmqOXp/Z2u6JcnHum4AAAAgitANDGB2a7oladKwguixzVVN0Y+bE6zpzjFNRWeYGgAAANCO0A0MYEU51hUmkenmE4fkRY9trm6vdCdqLzeHd7YNAwAAANoRuoEBLFGle+Lg/OixTeZKd4JBapbQTaUbAAAAiCJ0AwOY3fRySZow2FTpTmJNN6EbAAAAsDcgQ/dFF12kQYMG6dJLL+3rSwH6VGG2faU7JytDo4rDe3xvrmqK7r2daE13US7t5QAAAICdARm6b731Vj3xxBN9fRlAn0tU6Zbaq90NnoBqm30Khgy1JgjUVLoBAAAAewMydM+aNUsFBQWdnwikubg13aYp5LHD1BKt5459HkI3AAAA0C7lQvfChQt1/vnna+TIkXI4HHr55ZfjzikrK9P48eOVnZ2tmTNnasmSJfv/QoE0ELdPt6v9V8JEy7rupg736C62hG77angwZOiHL32qbzzxsWqbfd29ZAAAAKBfSbnQ3dzcrKlTp6qsrMz2/meffVa33Xab7r77bi1fvlxTp07V2WefrcrKyug5Rx55pA4//PC4/+3evXt/fRlAv9Bxpbt9gvnmquYOQ7d1Tbd9pfvfy3fq6Y+26601Ffrbwk3dvWQAAACgX3F1fsr+de655+rcc89NeP9vfvMb3Xjjjbr++uslSX/961/1+uuv65FHHtEPfvADSdKKFSt65Fq8Xq+8Xm/0dkNDgyQpFAopFErdFtpQKCTDMFL6GpEaCrIzLLezMhzR182E0tzo8QWfV2nmhBLb58jKcMhtqpB7/AHLay/yenxt5Z7osZc/2aX/PXtSj3wNQFfw+xGphNcjUg2vSaSS/vB6TPbaUi50d8Tn82nZsmW68847o8ecTqfOOOMMLVq0qMc/389//nPde++9ccerqqrk8Xh6/PP1lFAopPr6ehmGIacz5ZoZkEJCbVPJI/yelmjXSIZhaHhBlsobfVpf3qgbHv/Y9jncLqe8ze17eVfvbbB0nkRejztq288ZnOuynCNJgZAhl9Oxz18T0BF+PyKV8HpEquE1iVTSH16PjY2NSZ3Xr0J3dXW1gsGghg0bZjk+bNgwrVu3LunnOeOMM7Ry5Uo1Nzdr9OjRev7553XcccfFnXfnnXfqtttui95uaGjQmDFjNGTIEBUWFnb/C+lloVBIDodDQ4YMSdkXKFLT4EFFGjp0aPR22Zez9I0nl6umgzXYOVkuDSktjt7Oysm1PEfkHcCqpm3tj3FnWc5549M9+t9/f6pzDh+uX
106pSe+FMAWvx+RSng9ItXwmkQq6Q+vx+zs7KTO61ehu6e8/fbbSZ3ndrvldrvjjjudzpT9h49wOBz94jqRWnKyXJbXzDHjS/WP66brwrIPEj4mOzND2Vntv0p8gfh3I8sb/ZbtxmqafZZzbn5mhSTp38t36UfnTdagvKx9/VKAhPj9iFTC6xGphtckUkmqvx6Tva7UvPoEBg8erIyMDFVUVFiOV1RUaPjw4X10VUD6yHLF/0o4bGShsjIS/6rIycywrOm2m16+pqLZcru60Rt3TkRLgr3AAQAAgP6oX4XurKwsHXPMMZo3b170WCgU0rx582zbwwF0TSBoxB3LzHDqkBGJ97XPznTK7WofyGa3T/faihbL7UZvQJ4E4bqjKekAAABAf5Ny7eVNTU3auHFj9PaWLVu0YsUKlZSUaOzYsbrtttt07bXXatq0aZoxY4Z+97vfqbm5OTrNHED3Jdpj+/BRRVq1s972PncSle51MaFbkqoavRpTkqtQyBr0Gz2EbgAAAKSPlAvdH3/8sU499dTo7cggs2uvvVaPPfaYrrjiClVVVemuu+5SeXm5jjzySM2dOzduuFpPKisrU1lZmYJB2l6RfpwOKZJ7c7PsfyUcPrIo4eOzMzOUnWkK3TH7dNc2+7S6vCn2YapuCofuBo/fcpxKNwAAANJJyoXuWbNmyTDiW1zNbr75Zt1888376Yqk2bNna/bs2WpoaFBRUeLwAfRHT994rK5/dKnGD87TGYcOtT3niFHW132G06FgW1LP6aS9/PFF2+QNxP9MV7Wt646djE7oBgAAQDrpV2u6AfS8YyeWatmPz9Drt5woV4KBaQcPz7fcLjFNF8/uoL282RvQk4vCW4W5nA7dctqB0fuqm8Jhu6bJGrprmn26+5XV+sV/1sW1ngMAAAD9TcpVugHsf4nayiPMlWxJGpLvjlaqs10ZCSvdLy7fqbrWcPv4+VNHaOro4uh9kcfXNlsnmf/qzfWqb3vMwcPydfHRo7v41QAAAACpg0o3gKT87KIj5HBIsyYN0cji7OjxnKwMuROs6f5sd0P046tnjNXggvZ976ub7NvLI4Fbkuatrey5LwAAAADoA1S6ASTl6pljdd4RI1SY49L//GtF9Lg702nZx9sXbA/djab12cMK3crIaK+IR0J3bUx7uZndvuEAAABAf0LoBpC0otxMSeHhaRHZrgw5nQ5lZTjlC4Ysa7rN238VZGcqx9TGnmiQmllWgjXmAAAAQH/BX7RJKCsr0+TJkzV9+vS+vhQgJZjXgOdkhavXkaq0ub28ybQdWF5WhrIzM1SQHX5stNLdUeim0g0AAIB+jr9okzB79mytWbNGS5cu7etLAVJCdmZ7m3h2WzCOTDA3D1KLVLpzMp3RyehD2tZ1R6aXE7oBAACQzviLFkCX5WaZQndbAG8P3e3t5U1ta7rzTecPzndH72v1BTtsLw8EQwnvAwAAAPoD1nQD6LI8d/uvjty2j91t4duu0m0O6UNNE8x31bXEbRlm1uoPJryvP3pi0Vb9d02FPP6gTj1kqL4968DOHwQAAIB+jUo3gC47a/IwDc7P0qjiHJ1y0BBJpkp325ruUMhor3S720P3oSMKox+v2FHfYXt5qz99Kt1bqpt11yuf6b0N1Vq6da8emLteGyoa+/qyAAAA0MuodAPosjElufrwB6crw+lQhtMhydpebhiGmnztk8vzstrf3ztqbHH04/c2VMkfNBJ+nlZf+lS6y+s9ccfWlTfqoGEFfXA1AAAA2F8I3QC6JXbImdsVrmaHDCkQMizbhZkr3VNGF8vhkAxDmre2ssPP4Umj9nLzWveIrdXNfXAlAAAA2J9oLwfQI9ymvbu9gZCaPOZKd3vozne7NKmtuhtpP0+kP63p9ncy9M281j1iSw2hGwAAIN0RupPAPt1A59ymyrfXH1RjzB7dZuYW8470h/ZywzD07X8u05R73tLc1XsSnmcXurfVtPTmpQEAACAFELqTw
D7dQOci7eVSOGA2eu3byyXpqDGDknrO/tBevqW6WW98Wq5Wf1A3PbU84Xlem6+F9nIAAID0R+gG0CMsle5AyLKmO7bSfWRMpfvoscX6whHD456zpR9UunfVtSZ1nsem0l3T7FODqSMAAAAA6YfQDaBHWNd0By1rumMr3QcOydf40lxJ0iHDC/TodTM0KDcr7jn3dU335xWNWrO7YZ+eozPba5NrETdXuvNN+5xT7QYAAEhvTC8H0CPM7eXf+dcKjSnJjd6OrXQ7nQ49+bWZ+nhbrc6aPFx5bpfys+N/He1L6F5X3qBzf/+eDEN66dvH66ixybW0d1XsumzDMORwOOLOM6/pnjS8QMu27ZUUbk+fMrq4V64NAAAAfY9KN4AeYW4vX1feqP+uqYjejg3dUniv74uOGq28tqpvflZ86PYFQgqGEu/j3ZH7X18ro+2hv/nv5916jmRsi5lA3phgInts6G5/PMPUAAAA0hmhG0CPKMzJTHifXeiOO8dt33jT3WFq5rXWdpPDe0psaK5p8tmeZ96nO7JlmkR7OQAAQLojdAPoERceNUrTxtm3cMeu6bY/xz50d7fFvKG1veJckOC5zeatrdDlf12kNz5NvO1XLMMwbEK31/Zcr789+B80LD/6MXt1AwAApDdCdxLYpxvo3KjiHL3wreP1P6cfFHdfMqE7UaU7dq9uw0iu3dw8FTyZ4P7A3PVasrVW97++Nqnnl6SqRm/cc1cnrHS3h+7inCwV54Y7A2qb7c8HAABAeiB0J4F9uoHkjTUNUIvIy+r8V01egmBubi9/7IMtmnLvW/rTOxs6fK66Fp98ppBbnaD6bFbZ6JEklTd4kg72W23WY9c0J6h0m9rL3ZlO5WaGv97YNxUAAACQXgjdAHpUbOjOcDqU7er8V02i9nLzXt33vLZGjZ6Afv3W56prSVwh3lRlbdmuauw8dDd7w58nGDISDkOLtdWmNTzxmu72NwHcLqey29a57+u2aAAAAEhthG4APSo2dOe7XbZbaMVK2F7eFkpjq8+vf7pHS7fWaofNPtmbq5ost/e2+OUPJh6m5g0E5TPdX9fsT3iu2Xa7SncSa7rdrgzltFW6uzsoDgAAAP0D+3QD6FFDC9xyu5zRym6Bzf7bdjobpNYS04b9fy+tliQV52Zq3m2nqDTfreomr/6yYJPmra2Ie56aJp+GF2Xbfo5IlTtib4tPY0vj2+Rj2VW6qxOs0Y5tL4+Ebn/QkD8YUmYG74ECAACkI/7KA9CjnE6Hxpiq3Ykq2LFK87OUa7O1mKctbCcaOFbX4tfCDVWSpAffWq9/vL/Fdq11Ry3mTR5rO3lda3KVbrvnTFjpjmkvzzF9rVS7AQAA0hehG0CPG17YXlFONlDmZrn0+yuP0lePG6dvnjIxejxS6a7pYMr34k21kqSXP9md8JyOhqk1xazh7mi9uN3jXE6HsjPDv06TWdOdleFUdmZ76GZdNwAAQPoidAPocYPysqIf7+3CllhnTh6mn1xwuA4c0r6PdSSQ1iaYCi5Ji7fUqLze02F47bDSHRO6k73myOMKsl0anO+WlPjNAW/btbldTjkcjmh7efi+xOvNAQAA0L8RugH0uJK2PaglqcGT3CRwM3PrdWRLLXMF+drjxumpr83UjPElkqRtNS2as8pa5XY4pBMOLI3eruqg0t0cW+lOsr080paen+1SaVvo3tviUzAUv+VYZAszd9sk9xwq3QAAAAMCg9QA9LghBe59erw5kHr88Wu6Z0wo1YkHDdaSrbVasjXcWv6n+Ruj9z96/XRNGzdIn1c06YONH0rquNIdu0VYXUtyobsxErrdmRrcVt03jHDwjlS+IyLt5e62r83ujQUAAACkHyrdAHrclTPGRoei/frSKV1+vDl0t9gMUitpC7jHTiyJHosEZYdDOnrsIBVkZ2qIKfh2pdK9N4k13eZtxgrcLpXmt7fU263rjkwvj1S6WdMNAAAwMFDpTkJZWZnKysoUDPKHM
ZCMwfluzbnlRO2p92jm+EGqrq7q0uOzs+IDqXmtdCTgHj12kLJczmjrtiQdPLRARTnh9vbBBe1B2K7S/cHGan26qz5uO7JkKt3miefm9nIpMsG8wHJ+ZN027eUAAAADC6E7CbNnz9bs2bPV0NCgoqKivr4coF+YOCRfE4fkKxTq+pCwztrLI5Xu7MwM3XLqgXrwv59H75s2flD049wsl/LdLjV5A3HTy/c2+3TDY0stU8Ujkplebh6+lu92qdQ0PK6i0RN3frS93BVpL29vNPLQXg4AAJC2aC8HkHIsVWCftdLtcEiDctsD7i2nH6TfX3mkCtwuZWY4dOkxoy3PNbitKh5b6d5c3WwbuCVpbweV7i3VzXpq8TbtqG2NHsvPdmn0oJzo7Z2m+yQpFDKirejuTCrdAAAAAwmVbgApJ8emvTyyZVhxTqYynA7L+RccOUpnHzZcHn9QxaZALoWHum2taVGjJyCPPxhdS93RYLVElW7DMHTdo0u0rabFErIL3C6NHpQbvb1jb4vlcZHALbGmGwAAYKCh0g0g5VgDaTiw1rYNJyvJy0r4mNjALUnDi9rD8e669gp0R4PVGjwBBYLxVfBGb0DbasKBeufe9ucqyHZpbKkpdMdUus37cEfay7NtqvkAAABIP4RuACkn11Tp9viC8viDam4LpqV5XduOzNL2bQrKVQ3x667N6m326q5rtm87z3e7VJidGR3gtr3WWumOTC6X7Aepeah0AwAApC1CN4CUk5nhlKuthbzVH7QdopashKG7g0q3JNW1+mUYhnbVtcowjLZj9m3n+dnhsD2mJPy59tS3ym+qlJvXjtvu051E6N5R26Ly+o7fKAAAAEDqIXQDSEmRSnCLL2DZ97okv6uhu73te1ddewW6ozXdUnhd9y3PfKITfvGO7n1tjaTEA9by3eHxGGNLwp8rZFhb2e0q3db28o4nvM9fX6lTf71As349X9trWjo8FwAAAKmF0A0gJUX26vb4Q6ppbg/IpT1V6e4kdG+qatbrn+6RJD29ZLsaPf6EA9YKssOhe8wg+3XdHn/8ILVkp5c3ePy6/tGlCoQMefwhfbSlpsPrBgAAQGphejmAlBQJpfvaXj6quHuh+/VVe9TWVS5fIKR5aytt13lL7ZXu0SX2E8wt7eWu+PbyjtZ0PzB3neV2C0PXAAAA+hUq3QBSUiR0N3kDlqpxV0N3dmaGhhSEh6/tbAvChmF0uqb73c+rLLfnrNqjvQkq3fnRSnd7wDcPU7O0l9vt050gSNe3+PXMkh2WY42exHuIAwAAIPUQugGkpHFtW3D5AiH99u3Po8fHmKrJyYq0mFc0eOUNBFXf6pc/aHTpORZ+XmWplJsVxKzplsKDzyKsle7k28s/r2xUMGS9zkZPoEvXDQAAgL5F6E5CWVmZJk+erOnTp/f1pQADxu1nTbIEU0k6+eAhOmpMcZefy9xivrvOk7C1fPKIQrUNTY/jC4b0wrKdtvcVtE0vHzUoR462x+8wBXTbfbqz2n/9JgzdFY1xxxoI3QAAAP0KoTsJs2fP1po1a7R06dK+vhRgwJg0vEC/uOSI6O1Dhheo7Oqj5HAkSMUdME8w37m3JWHoPmBovm457SDLsZMOGtzhc2c4Hcpuaxl3uzI0rCBbUmylO356eVaGMxrwE63p3lDRFHeM9nIAAID+hUFqAFLWBUeOkmFIn+6q102nHBCtKHdV7ATz3KwM2/Py3Rn6n9MP0sfbavXBxho5HdL1J4zXexuqEz53vttleSNgXGmuyhs8qm32qa7Fp+LcrJh9usOh2+FwKCczQ82+YMI13Rsr7UI3lW4AAJAaAsGQnA6HnIlaBSGJ0A0gxV141ChdeNSofXoOa+hu0aBc+2Fs+W6XMpwO/eUrx+jh97Zo0rACnXjgEDkcik4yt3uM2UHD8vXRllpJ0ucVTZoxocR2erkUnmDe7AsmbC/fUBluLy/IdkXDNpVuAACQCrbXtOjiv3yowmyXXr3lx
Li/idCO9nIAac/cXr5qZ33C9vK8tv9YFGZn6rYzD9Z5U0Yoy+WMtozbiezRHXHwsILox5HQ7PXHt5dL4cnqknUf74j6Vr8qGsLXeeiIwuj69kSV7mDIUJM3/j6PP6jnP96htXsaEn4NAAAAXfW9F1aqusmrzdXNevi9zX19OSmN0A0g7Y0tyVVp21Zj722o1t8W2v+HIS/L/h1ac6U8Vlyle6gpdLetybabXi61TzC3W9O9sbJ9iNpBQ/Oj4d4udHv8QZ3123c17b7/aklblT3igbnr9b0XVunyvy7S3mb7Lc8AAAC6auXOuujH5fWevruQfoDQDSDtZbmcuveCwzo9zx+KrzhL4ankieTHVbrzox9Hpo9b13Rb28ul8PRyI6Z/3TxEzRq649vLX1y+S5uqmuXxh/S9F1ZGjzd7A3ru4/A+343egOavr5Qkfby1Vhf86X39ZcGmhF8XAABARxIVFRCP7w6AAeGLU0bqgiNHWo4Nzndbbtu1eUvWLcdixY4NKc13R6vqn0cr3R23lwdDRty+4evKTZXuYQXRIXLNvmDc3t1bqtsD+raa9qnpr6/aY2k5f3tthSTpT/M3auXOev3qzXWqb2WNOAAA6DpzvcCdaT+kFmGEbgADxn0XHq4rp4/RiQcO1rmHD9cfrjzSEqiLcuyno5vXhMfyBeOD+kFt1e7qJq/2Nvti9umOby+XrHt1//rN9Xp80VbL85nXjjfFtJg3tNqv835m6XbL7XfXV8kbCGpn2x7iIcO6tRkAAEB3ZDC9vEOMmAMwYBRkZ+oXl0yxHPvLV47WFX9brNL8LF01Y4zt4zpqL/cF4kP3wcMKtHhzZIJ5Y+Lp5abQ7fEHVZSTqfc3VOtP8zdGj196zGgNLchWoWm7tAaPX0W57be3m4JzZI35xsomfbK9znJdzb6gFm+uVUWDx/LYw0cVJfz6AAAAOmM3nwbtCN0ABrQpo4v18Y/OkNvllCvDvvmno/Zyu9B9kGWCeZO1vTzTVOk27Rce2at78eaa6LFvzTpAd5w1SZJ1SvqOvS3a2+LTEaOK5HA4tNnUXu4LhGQYhtaVt08rP2JUkT7dVS9Jem3lbsswtu1UugEAwD5q8RK6O0J7OYABL8/tShi4pY5Dd6FNS/pBQ9uHqc1fV5mwvTzbpr3cvJb76hljo+1a5tB99UMf6Ut/+kBPLt6mRk/71mJSuN29rsWvmqb2SeWXTRsdfZ63Piu3XCvt5QAAoKtiiw7NPvulbggjdANAJ8wVaSm8NjwrwymX06G7vjg57vwpo4ui68PnravU65/uid6XqL08ErrXV4Qr1HlZGZawX5AdH+7veuUzbalujjte2ehVjWl7sDEluRpRFN5rvCFmPTiVbgAA0FV1LdZtSFt8VLo7QugGgC46dEShFt15mj78wWmWVvKI3CyXHrh0is0jY9vL2z/2+IJq8ga0ozY85Ozg4QVymoaSFGTbrwZ6e01F3LHKRo9qmtqr34Pz3BpbYj8MLjJUDQAAwKyuxae/LNikJVtq4+6rjQndzV4q3R0hdANAEszbi5XmZak0362hhdkJzz/7sOG64YQJccc7ml4e2ddbkg4Zbg3zdpVuSSqz2Wu7ssGrWlOluyQ/S2MSTGDfubclbgsyAACAP72zUb+cu05fe2ypZQtSSdrbbN1ylEp3xwjdAJCEx66frimji3TTKQdo/OC8pB5z82kHKitmrbj5tnlNd4PHr3V72kP3pGGxodu+0m0XmCsbvZY13aV5WRpbah+6/UHDMs0cAABAkjZVhQe1NnoD2lZjXc62N7bSzZruDjG9HACScPioIr1684ldekxJXpbOPny4Xlu5O3rM4WhvGTevFf/usystj500vNByO1HotlPR4FFNc7i9PC8rQ9mZGRrdwbZn22tbNLKDYXEAAGDgMVevKxo8Omxk+xaj5o46ienlnaHSnYSysjJNnjxZ06dP7+tLAdDPJNr7W7K2l8eKbS8vTNBeHjHBVH2vM
g1SK21ri0+0pltigjkAAIhnDt176q1dcXubqXR3BaE7CbNnz9aaNWu0dOnSvr4UAP3McRNLNTg/S5J02Ehr9TpR6M5wOjQoL8tyLLbSfdTYYv30wsN125kH6w9XHaUXv3V89L7d9a2qawmvtSppe54xHYTuf7y/RbvqGKiG/mFXXaue/mi7qk3DAgEAPa/FFKQrYkN3S/yabsNgRkwitJcDQC9yOBx6/qbj9eLynfrS1JGW+8xrw6eOLtKGyia1+IJx50nxg9QmlObpmmPHxZzjUqMnoPWmvb4jgb80L0u5WRnRd61dTocCbevB15U36tRfL9CcW07UwTbT2IFUctOTy/TprnotWF+pv391Wl9fDgCkrdaOKt0xa7qDIUPeQMgyrwbtCN0A0MsmDM7T7WdNijt+6IhC/ebyqapp8uma48ZpT71H72+o0hen2IVu669ru8FoQwvcavQELO1gkUq3w+HQmEG5Wt82If2I0UUqr/dE/yPqC4T06Adb9POL7bc6A1JBMGRo9e56SdKKHXV9ezEAkOZa/O1/T5THDF2NXdMthbcNI3Tbo70cAPrQxUeP1o0nT1R2ZoYmDM7TNceNj2stl6TMmCno42xC9zCbLcxKTVudmVvMRw/K1UvfPkH3XXi48t3hQP/Kit1q9PjjniORT3fW60t/el+3PPOJ1pU3JP04oLvqW/2KdC9WN3kVCIb69oIAII2Z38Qv76TSHXs+rAjdANAP2Q1GG1rgjjtWagrwY0raJ5QPK3BreFG2vnLsOF1wZLiy3uIL6uUVu+OeI5E/zd+gVTvr9drK3Trnd+/pnx9t68qXAHRZbXP7Ou6QIVU3xf/RBwDYd4FgSL5A+xubsZVuu9DNMLXECN0A0A+NLYnfK3xcafyx0vz20G0O6kML2wP61TPHRj9+5qPtSV/DJ9vrLLefWZL8Y7tj594W3f3Kas1fX9mrnwep4+01Fbr8b4s0d/UeSVJts7UTI/aPQABAzzC3lktSoyegZm97qN7bHN8Z18y2YQkRugGgn5g1aYik8HTzwfnxLegzJ5TEHSvJaw/Xp04aqqwMpzKcDs2aNDR6/LCRRZo6Orz35po9Ddpc1dTptZTXe1TZaJ0eva2mpVuTS8vrPfr2P5fpT+9s6PC8O55fqccXbdPN/1wuj5//sA8EX3/iYy3ZUqubnlouKX4NYQWhu0cEQ4bmrt6j5dv39vWlAOgDK3fU6T+f7rEs2Wm1aRWPvNHpDQTV5I2vardQ6U6I0A0A/cT9Fx2hW08/SK/MPkEOhyPu/qPGDlJWzNpvc3v5+MF5+vDO07ToB6fFTSk/b8qI6Mdvrano9Frshlg1egKqb41/5/vTnfX61lPL9NZn5bbP9dd3N+mNT8v167c+19o99mvD61v8Wry5VpLU7AtqWw17i6e72Ddw6lv8XQrdwZChqka2FUvGqyt36aanluvyvy5i+0BggCmv9+iSv3yob/1zuV76ZFf0uN367Mi67roW+/kvVLoTI3QDQD8xqjhH3z3zYB0+qsj2/pysDE0dY71vcL477vZQm4FrZ00eHv04Eo4bPH69vaZC9Tb/cV25sy76sTnYR8JwbbNPy7bVqr7Vr+sfW6r/rC7X//zrE9tBbZ+YAvzKBBOp562zvhGwrabZ9jxIW6ub9bM31mrZtv5dtYz9g29zdVPcGsJEoTsUMnTZXz/U9Pvf1nNLd/TaNaaLlTvCE+EDIUNrdjMUEQPTpqom3fTkMj25eGDNJ1lb3hDdQtT832O7qnUkdDfYvMGe6DEII3QDQBo5dmKp5fagvMwEZ1qNH5ynSW3V7+Xb6/Tcxzt05m/e1def+Fg3PvFxXNXRHI7PN+0rvr22RQ0ev87/4/u65C+LNPXet1TdFK42evwhvbPOuh47EAxpnam6vSam0u0LhLSholFvfLrHcnx7LZXuRO5+9TP9feFmXfX3xf16undNzJC0LdXNcZXu8nr7SvbWmmYtb5s58OrK5IcDDlTmD
hW74UjAQPDL/6zT3M/K9eOXVw+oN3bNAdrcHdRRe7n5d0ZxbvvfGc1ML0+I0A0AaWTmBGvodruS3y/zrMOGRT/+/gurVNEQ/o/vkq21+sxU/QqFDH26M1wZG1bo1vTx7WvJt9e26A9vb0jYovr6Kmt43lzdLK9pOqr582yuatI5v1uoM3+7UG+vtYb1LdXNev7jHfqfZz7R5X9dpJdNLXEDxZ76Vu3cG//mw7ufV0mSfMGQ1pU37u/L6jE1zdZAvbkqPnRXNtpXus3LDxKdg3Z1pqC912bv3f7ow03V+t7zK7V6V31fXwr6CfPSqkWbavrwSvav+gSh2y5ARyvdpq614abuuRabdd6xfIGQbnpyma742yLVNA2cJUCuvr4AAEDPOWbcoG4/9uzDhuuP72y0ve8P8zaoNN+tUMiQIUONbf9hnTq62LJn+IL1lXFTzc0WfF6lJm8gujd47B/Ea/c0KBAM6bVVu3XPq2ts14hL0j8/2q5/miatf7qrXl84YoSyXAPjveRNVU0667cLFTIMvfE/J+nQEYWSFFfZXrq1NuFyhFRnV+mOHdwTu2+s+dyIyJtHSKzOUum2/5nrb773/CrtqmvVttoWPffN4/r6cpBCqpu8qm32xc02MVu2ba+unDE24f3pJHGlOz5A76mPr3SPLM6JvsGbTKV7/vpKzW1bxvbCsp365ikHdO/C+xlCNwCkkZysDN1+5sF65IMt+v45h3TpsYePKtKd5x6i+esrNawwW1NGF+uBuevkDYQSDlc7ZtwgjTWF7qVb29cRXzVjjPxBQwXZLrX6gvrX0h3yBUKat7ZCFxw5SpK1si2F1/HO/Nk81ZiqbSV5Wapt9snpCO/NbKfVH9TnFY39NmB21YL1VQq2fTPe/bwqGrpj963+eOteXX/ChP1+fT0httK9qaop7k2VRGu6za2h9a1+efxBZWcm3/Ux0JjnNtSlQXu5LxCKdttsZ+giTOpb/Jr1qwVq8gb00Fen6czJ4Q4v837UkvTRltq+uLw+Yal0N3llGIYcDof9ILWG8M9VQ2t7IB9R1LVK91bTm6LbBtBSMUI3AKSZW04/SDefdqDthPPOfPOUAyzvOn+2u14vLo9v3Xa7nDpvyghdPXOsCrIzVZybaZlmWpqXpbu+eJhyssJBZ9GmGv2rbaDVH+Zt0IwJJVq5o07vb6iOe25z4D77sGH69WVTVdfil8MhfeOJZXHrvs3XOlBCt7mtfLeplT923+olW2ujf0D1N7FvIGytaVZpnnUwYIMnoFZfMPo6az/X+odcVaNXY0z71O9P/mD4jaYJg/M1aXjiylpvq2z0aHCeW05n/GuhLs3WdNe1+mw/Bj7ZsTfaMfPBxupo6I59A297bYv21LdqRFHOfr/G/c0cun2BkBo8ARXlZCaYXh5+M7QhptId8cyS7XI4pG/NOlAlefFbm0qyLD/btXfg7JYwMPrwAGCA6amQdZWpve6AIXl69hvH6smvzdDSH52h31x+pAqywwNUxsUEmvOnjrQEoZkTSqJ7gW+qatYJv3hHNz21XOsr7Nccu11OPfm1GfrbNdNUkJ2pMSW5Gj0o19LKHv462z9evat/T11+f0O1vvbYUs1fX9npuTtq2/9Q2V3X/sdieb31D5iqRm/KDJ17Zsl2Xfn3RVq2zb6CtK68Qd97fqXe2xBekx7bXu7xh2xnBdhVu7fGDEHqy/28n1q8TTc9tVwX/fmDPlsv/fB7mzXj/nm68qHFcUMRQyEjZk13/28vN78B6PGH5PEz3Alh5jeVGj3tVdndNr9bPtqcXtVubyCoDzdWxy3TiV3GFWkxtxukVt3klS8QsjzGvKa72RfUQ+9t0UPvbU54HTtNQXsgbVFI6AYAJDR9fIl+cfER+sbJE/XcN4/TzImlOumgISrMtk5Fj60iXnjUKMttp9OhX182NdoeHNsmHmmPjvjmyRN10kFD4q5nbEzovu748dGPV+/ueGBSg8evS/6ySBc/8qlln
/FAMKQ/zNugafe9re8+u8L2D4394c6XVmneukr97wurtKW6WZf/bZHuefWzuJAkdVDptlnjvCQF2iQ9/qDufvUzLd5cq1+/+Xnc/YZh6GuPfaznl+3Ud/61QoZhqLY5ubXYsdV9fzBk+aNO6tt13ZGBTC2+oNaW980bQ/e9vlZS+LWwfLt1K7kmX8Dy85gOle7YNzcSzYbAwFNrelPJvIXlHpvfnYs3dzxM7b0NVfr7wk2WoWKp7K6XP9PVD3+kGx5dajmeKHSbK93Zme2RsbLRY/maze3lEX9ZsCnhdcT+98vuv3HpiNANAOjQlTPG6odfOFSlMXt+m8VOSY9Utc0OGlagH57bvs78EFOr7SVHj9KYknCL2oiibN00y36wyriSPMvtMw8dFq1+R4awJfL0R9v1yY467W7w6Zp/LNHmqiY1eQO68u+L9Zv/fq7qJq9e+mSXrn10iRo9fgVDhjZXNe2XbbfK6z3R6nVlo1fff2Gllmyp1WMfbtXCmBZ8wzAsoXJPvbm9PD5cfry1a/t1B0OGfvvfz/Wbt9b32Ne+q641umZyQ2V8d8NnuxuiFY+aZp8qG72WZQYdia1i79rbGl3vnuic/ck8ST3R4Lee0uILxE14j7392krrDgL1MYPT0mGQWuwbB+nwRgJ6hrmrwxwcd9fHV1w/T9CJJYUrvl97/GP97I11uvqhxXHV41S04PNwF9WSrbWWNxzM67Ol8Lpuybrn9oTB+dGPy+s9lqA+oji+Bd/hkDZUNOpXb66zfB9j//vV4gvqxeW79IN/r7IMwExHrOkGAOyzGRMG6d/Ld0qSLp82OmF7+3UnTNBRYwcpP9uliYPzNH99pfbUe3TpMaM1fXyJXl25W9ccO065Wfb/eRofU+k+auwgHT6ySNtqWuTxh7S5utl2Im0oZOhp07TzZl9Q1z26VCcdNFgfb7OG0iVbavXA3PXKzcrQ3xZu1vTxg/TU12cm3H5td12r3t9QrXOOGB7tADAMQ95AKOnhXZ/EVB/NA+leWLZTpxwcrvpvrmpSZobT8gfe3hZ/dF1zbHu5JK3cWZfUNUTMXV2u38/bIEkqys3S1060DmLb2+xTUU6m7drgRMwt7tVNPjV6/NGlCZL0UsyWb1urm+PWdJsVZLuiraE7Ytrnt9jsr1vZ2DeVbsMwLF97bFW+J+2qa9X5f3xfzd6Afn3ZVJ0/daQkaXnM63vOqj360XmHypURrrvUxYTsuhZfv50DEBH7xkHs14iBy/wmlLm9fE9d/M9mR6+bLdXN0TcSV+9q0Dee+Fj//PrM/fZzs6e+VR9urNEZhw5TUW5mp+e3+AKWjp/NVc2aOqZYUnKV7olD8rS2bZ5KeYPHEtSHFsS/IV+a59btz6/Uqp31en9DtV65+URJ4Z/N2LXitz+/UlK47fypr8/s9Gvpr6h0AwD22flTR+qUg4fopIMG6//Om9zhuVPHFOuAIflyOBw67ZBh+vLMcXK7MjR1TLF+/MXJGj84L+FjY9vLc7IydNio9tb01bvqVdvs0z8/2mapbn6wqTpubfP22pbotmNZLqd+eckRymkLya+t2q2/LQyvSVu6da9+89/4lmgpHOave3SJvv/vVbrmH0sUChny+IP6+uMf69C75uqJRVs7/F5EfGJqd4/15mflqm/x6+8LN+m0B9/VSQ/MjzsnUqUxh7qxbS3/n1c0WioWnXnj0/ZK6E/nrLFUjcvmb9RRP/2vvvXPZV1qCdwZ8703V38DwZBeWbHbcv/Wmubo/q2DcjOjW8xFnGxaejB/fZX1uW2qJZV9VOmuavSq1bSeuDcr3Q+/t1m1zT55AyHd8swnmr8uXNWKfVOpuslrmcwcO2gsEGrfErAjqbxOOrayTehGRDJruiOt1LUddEjEdpB8uKnGsmypJy3dWqvv/OuT6FKhpVtrdfZvF+r251fq/17+NKnniP3v36aqpujHDTGhu7Ix/HvKErpN/102V7ozMxzKzYp/c7m22as1b
buTrK9ojP73wtxaHmvR5pq0XgpC6AYA7LPcLJcev2GGnvzaTBXldP6ue3eNHpSri44apYJslx766jRJ0uEj21vZ564u16V/+VD/99JqffnhjxQIhocomdeX/c/Jo5Xvtv6RcMMJE3TF9LE67dChkuL/SP/7ws16+ZNdcUFz+fa9+rwi/MfLyh11ennFLn3vhfDabMOQfvPfzzsMJ5/trtcrK3Zp6dbE6659gZD+tnCTfvbGuoTnRKo0kUpGgdul4w8olRReP2/emq2uxac7X/xUj36wxTY4b6u1htb/mraL+/Vb6yVJb35WoddWWduUO7IjZo21uY3w/Y3Vqm7yxtzfEv2jdkRRjk47ZKjl/qPHDdJBQ8Ptjsu27bWE6tjJ5ZJU0di9sLuvaw1jr6W3QnezN6AXPt5pOfadZ1eoxRewHVw3x/RvZxdI65r9avYG9Iv/rNOrK3fH3X/bsyt0xD1v6l9Ltsfdt6/+9u4mnfv79/TBxvidDZJlV70HJOugQGt7efhn0+V06KCh4W6p+lZ/wiU2saFb6p2fb8MwdPVDi/Xyit3633+v0upd9frKwx+poe0Ng8jMiEaPX5urmhJew7aa+NBtGIYCwVDcm2zRQWr+9uMTh1hDd+R7V5idaVvdDxnhN/Ck8DDDyP7dsfM2zIIhIzpIMx3RXg4A6Fd+e8WRCoWMaHvzUWOLle92qckbsOwnvrGySX9esEn/WV0ebYsbWuDW5VOHatywEt3+/CpJUnFupr7Vtob8rMnD9LpNmDSMcIh587Ny/faKI6Nt47GB5LbnVlpu17X4NXd1edxgOSlclb/4zx/Kl8S66T93MJRGah9GE1nfPawoW1PHFEe3aVu5o07Tx5dIkn46Z210KYDL6VCrP6jqJp9OO2SoDh9VpLV7rOsYH3l/i845fLjK6z0yZ9Bf/medzpo8TC99skt/emejjh43SDMmlGjx5hodNaZYXz9pYvTc2BZw8z6t734e/0fWp7vqon+wleZn6ZzDh1u+16V54WMb3tkoSXprTYW+cuw4SbJsKedwhP/t7Aapbaxs1JOLtqmiwatBeZn6wbmHWt4w+mBjtW791wqdeGCpfnvFkV1qG61u8uqddZVxbybEtpev3dOguSvKdf0pxSrJjx9GlKwXP9kV94dzfatfn+6s18qd4QGDQwrc2tvsUyBkWIJ4nU1laW+LT08u3qqH3tsiSRpZlK1pba+fmiavXmxbDvDk4m260rTDQTJCIUO+oP3Si7oWn345d51ChvSrN9frhAMHd+m5I2IDkd3XiIEjGDI0b22Fhhdlx1W6I0spor87C7M1OD+81ZVhhH+O7OaZ2IXujirjHbn7ldWav75Kd557iM49YoTlvk1VzfIHw78Lt1Q36+8LN8tr2lO8ptmnNz8r1/8880n0+F1fnKwbYpYFbYtZdvP6qj167uOdKsyOj4K27eWmNd17GjzR6nhhkm+yVzd6le92dVjplqR31lXqi1NGJvWc/Q2hGwDQ75jXExdkZ+reLx0WXRdmZm4Lz8506v6LDpcrQ7rwyJHaVNWs1z/dox+dNzkatk49ZKgyMxzRP3IKsl06+eAh0SD+n9XlavEt0zmHD1cgGNJTi7d1eq1PL9keF7pbfUH9z78+6TBw52Vl6IjRRVqcxLY1v/nv53p5xS55/OHnG16YrSmmYXaR4FXf6o8Gbkn68SufRT/++8LNysxwxA0hW7K1Vsu21catsd5V16r7X1+rZ5fukC8Y3s7rtbZg/PqqPTp63CAdPXaQJGlHzB9a5gqw3XR18/C3wfluzZpknWRflJupsw8brj+2he43PyvXV44dp63VzdHnG1eaK8MIt1XatZff/PQnWldufYPh5xdPiX78h3kbVN3k1csrdmv2qQfqIJtZAYnM/udySwt3hLkKFQiGdMPjH6uiwau3NjToldknJD0DINZTi9pfh5ceM1ovLAv/G7+8Ynd03elJBw7W55WNWr2rQRsqw0ME890u1
dsEhdoWXzRwS9KDb32uZ75xrCRZtvnbWt3cpfXfHn9QF/zpA+3c26Knvj5TR7W9PiI+2V4XnaS+ZneDvIFgwlkKHYmtbNNePrC9smKXbntupeV3uxQO4y2+oByO9tfIyOJsDTLtL723xZd06O7OloDl9R493vbz+61/LtdDX50W3Ttckj7cZO34sKsEP7V4myWIv7xiV1zoju26idyuspl3YRe6J5gq3XvqWqNv8kVC96xJQ7RgfeIqdXWTV+MH53W6L/e766ssb6qnE9rLk1BWVqbJkydr+vTpfX0pAAAbFx89Sl+a2v7ueKnpjyZJmjA4T6/MPlGnt7UpOxwOff+cQ/Tu9061/IFTmJ2p4w9or66dc9hw/emqo1R29dHRdcXvfl6lO1/8VD9+5bNoQDhwaH60YjB1dJEevX66Dmj7I2XJllqt2lmnzVVN+s+ne+QNBHX/G2u0uSp+7bE5KE+fUKKHvjpNFx8dXyWPVd7g0Yeb2re3GV6UrYOHFUTXJq5sW2v46opddg+PMv9BOn18eyD6/byNWmUzkO3JxdsSvnHw9EfbdUHZBzr5gflxe6hH9tFu8PijlelDRxRGB+WZ/4AsyctSbpZLg0zDgobku3XYyEKNHhSemvvhphp9vLVWz5jana+aMTa6f2yDJ2DZCm5bTXNc4P738l3RynSjx69lprXQnW1HZ1YTs2barKrJK3/b92tTVXO0Ar++vFE/nbOmW+3sGyubokH46LHF0QFqkqJvgkjhlvwpo4slhSt4q3eFvya7QLqxoslye9Hmmmi3yHrT963ZF7T9oz2RJVtqtb6iUc2+oOXNnwjz99wXDGndnsTTozsSO0itvjW5MBQMGVq5oy6l16uj6yLb5Jl/v0U0ePyWJTQji3NUkmsO3f7oG1dmtpXubuxxv6vOGoZnP73csr48dpmF3e4C62N+l9m1mMdWujtSHTO9PMPpUIG7/XfwhsqmaNdT5L97P/nS4fr2rAP0lWPtO18ib9p21F4uhSv37+3D0pJURuhOwuzZs7VmzRotXbq085MBAPudw+HQLy+Zou+ecbB+cfER+tPVR0fvmzyiUM/fdJwmDU+uUnnBke2h5aKjRsnhcOi8KSP0yHXTo4PWYt125sF6+/ZTtOCOWXrl5hN16qShusrUdvvlhz/Smb9dqG/9c7m+8Pv39NTicDjMznTqi1PC7YSZGQ796LzJ0XB/4f+3d9/hUdXZ/8Dfd3r6pPdKGiEkgYTESIcIBlZRQdFFRVTQFdeC64quK+r+XF3ddVEX29eCuqsC9gZKDSItIQRCCwQChFRCSCV97u+PydzMnZmEgIZMwvv1PDwPmXsz3Bk/iXPu+ZxzkgLholPj5ZuS8PH8NLx1WzIezoju1Wvwc9VBrVRI9e4nq8/hvS1FWL71uHSOKTkZ7uWEJ6cPhWWy8plr46WgdvPh0/jQLJs63WILJAD8cVIkZqcES19/tusU9hTXWDXwAbq2l+ccr5Y+vKWFe9hsoufZudXzkwVXIMLbCdcmBmBYgCsEQcCs5CAAxmDp9vd2Ss3v1EoBNyYHwce1K0NVaVbXbSsj09retXNh69Ez0vZ2AFY3DXqyrYfZvqLY1Ul9v0Ug/78dJ5H5ys9SAzRz24+dwfu/FKGl3ToYNA8YMuP9EenTtQ3UvMt9UrBeNsrPdBPF1tbr1fusSyzeyjKWOFiOUbqQMT/mwYRljSkgD7rNr/FCWY0M62Uw9PQ3+zFj2S+Y+97Oi/p3yT71dGPozuU5ePDTPOlrfzcHWab73Z+LEP/0j1i0Mk/2fbaC7ovpHWA5H7y13SBlszsMolSz3RPL6QzmN/cOltVhZU4xjljcSOtJVUMr/v7DQSmYd1QrIQgCfDtvYpo3oDPtEgvxdMSfr47F5KG+1k+IrkDeFHRrVPIQ1Hwyyf0f50o3BQcTbi8nIqJBwUGjxIMZUdLXr90yAkVVjbhjdJg0yqs3rh8RiHaDaGxGZlZTmhrugY/np+HTn
cUI83JCRV0zPtt1CnH+rpg81Me4DdYsrp+TFopv95Riz6la2YeUo2YZ7r/+Lg63jArB9OH+8HbRIiXMAz8+PA4Vdc0Y0TnOBYCUffdx0eLf62x3UjdnCjYTg/VS9+pnvzsgHU8McsOKe9Kx91QtEoLcoFMrUVBej1W7urKPsX4uWDgxEo9/YeyOa3oNXs5a/GNWAnafPCs1H5o/NhyPTIkBYOyk/vORnjMVZxpbUdfcJssIp4V7dP5NHhB7OWk7r8cVGx6ZIDt234RI5Bw/iy2FVbKtkJnx/vB01kofEgFj9ifU0xjUbyroCmzfuT0F9/x3FzoMIj7adgL3jh+CzRZ15vtKanGgtA4HOmfBj4/xhr+b9WxaAPilsOcPyeW1zQjUO8ia25kcKq/Hwo9zkfvXq6St5sXV53D7ezvR2m5AQXk9XpiZIPuetQfKpb9fFeeLADcdHDVK2fuhVSkQ4+cClbLrzsqe4lpU1jfbDBRyT9ZYPbZ6Xzmea2m3yqoVVTUiLcKzx9dsYh50m3Y7dBhEKBUC2jsMVt2f95yqxW29emY5q0Zqvcx0/9T5Xu4oqkZlXTN8XC++zp7sR0/jBw+WyX8O08I9ZL0X1uw3rokvckvwxLSh8Orcav5b1XTbGlV2rPNG1v7SWqlh2oUw3dxz1qrw+//bbjM7fj5vd97ABADHzsaj/m46qx1CljXdvi62f2aqGlogiqJUahSod5DdsHtkSgw+2XkSW4+eQX1zOxZ8mINNj06EahClhwfRSyEiIupyTWIAHpgcdUEBN2DMmt+UEmzV0AYwzgX/x6wE/GHCEDx97TDsXTIFK+65wmbdqYNGiY/uTpNmoaqVAlzMmtZkDPXF71NDoFAIyBzuLzWqCtQ7YGSIu8062YQgvdRFdk6afBufKSsNAMmh7tI5lrPNAeCe8UOgUyuRGu4hBXePTo2RsuzTE/yhUAiYOTII4RbZ56RgNzhrVXhxViIc1EqEeDhiwbgh0vFpNt43W97bUoS3sro+2I0K90CojWv1sCgVMKdRKfD6rSMRbzY2TqNSYME4YxO3AH3Xe/Jx59bz5rYOaSu+r6sWk4f6IDPeD4DxZkDW4dNWzd12FFVj2qs/40+r9mDxF/mY/K8sq9rKlvYOlNc2W9VgWjpSUY/qxlZZpvu564ZJI3nOtXbIssmf556StreuzCnGMbNRP6frW6Rxc1E+zgjzcoIgCBji3ZXtBoBhAa5QKxWI9HaWdmt8n1+G1OfWY91B68y6iUoh4IbOfgQt7QasO1ghdes3uZBMd4lZgFFytgl/++4Ahi1Zgw+2Hseh8nrZeDWgqyzCpLmtAwaDCINBxKvrj2DBhzmyQB4wNmq7mJru2qY2WcM905ZkGvh6WwKxYsEVmBDjLStlMVdY2bX2TUG3h5NG2iV0MTXdlpluACjqvDH75e7uy4H83Xq+IVRe24Sfj5y2CrjVyu5rpV1sNFUDjNNJAGPZkiXL/7/6ulrXvwPGoLuirkW6GRjm6Yh5o8MAAM5aFcZFeeP/bk+R/n9ZWtssuzk6GDDTTUREdJHO1+zFVafGigVXYOOhSsQFuOJcawf+tGoPnLQq/GPm8AvqiA0Ya+s+nX8F8oprMC7aG5X1LdL24v/elYaCino4a1UY1rmtPMLbGRv/NAH5JbXYU1wDdycN4vxdEWERlAGAj6sOn//hSmwqqMTMzm3bGpUC/56dhOuW/SKdNzxQDwAYE+WFnCczoFYqZFsFp8T54i9f5sNgozw51s9FypQsXXdEenyItxO8nLVW28tddCqp63p3XHVqfPGH0ThSWY+WdgNCPBylbNTvEvyxdN1h1De34+u8UkyI8YaLVi3VjE+I9oEgGG8umMZoLdtYeN66w3OtHbhzeTaWz0vF6EgvbDlShYdW5Fl1K7dl8Rf5eOqb/VIg7e2kxi2pIWhuF/G3zt0I+0rqsO3oGThqlLLaZ4NofN9evWUEAOPWctP2fPPeBJE+zsg3255pquVWK
RWID3RF9vHeBZSJwXrcmBIsdSt/e/Mx2ZZ1wFjf+X+bj+GbPaUoqmpEoN4BwwJdkRikx/pDlThd34KXZiUgPtBNFiAbRODdLUXS894zPgKWCk93NXzbdaIaCz7cBUEARoa4S5MKXB3U+OeNiQCM0wTyzJqxmfRm9q95QAUYs/1Xx/fuBhLZt978XP4+LUTaseHuaPtG39HTDbii8xxT0O3lrIEoijh7rq3bjHLuybNoaG7HuGhvq2Pldda/a4qqGrHt6BmpHEilEGTlLoDxZ+D7/O7HNpbVNmODjVKVMZFe2NhNw7MHJ0ch+3g1ftxfIXvcdKPOz9V6d4/liFB3R41VwzoAqKpvlc0GH+LtjD9OikKAmwNGhurh1nmj46GMKMx731jO+3nuKWQMlY+LHMgYdBMREfUhnVopy5p//8DYX/V8Pq46TBlmzMw+MiUaaqWAybG+CPNyslkTLQgCEoL0UuDVkxg/F6va96RgPR6dGoOXfiyAIAATY7s+ODpprT9GeDprcf/ESLy/9TgemBSF5344KB27MSVYCizNTe8cERPr5wKFYAzI3B3VeH9eqvRhrCcalUK60WDO11WH/3ddvFSz+cjKPVLWBoDUFf3KSE+4aFWob2nH3lNdwaqXs1b2gT3ATYe4AFesO1iJtg4Rr6w/gvLaZvzpsz2w7IF2x5Vh+CG/DFUNLbhn/BDZrHjzxkzRPsbs/lCz9/2Zb/fLmsmZ+2ZPKZy0SswbHY6X1xZIj5vWBABZXTdg/G9okhiktxl0m16/uTGRXkgN95DeB1tb4jccqpR9uC+oqEdBRT2+yO3K0r2+qRCvz0lGaa3tmxklNU3YfLhrh0BCkBv2nqqFKAL5p2oRF+CKBz7Jw5nOQMd8NOCP+8rx0qwE7C+twwOf7Lb5/JY13iZtHQYoBQEKhYDCSvm2Wcv68sHmQrrO26PjVY3YUXQGV8f7WwV+5hpb2mWlFt0x3xbd3e4a042ZptYOaVeGu6MG7YbOoNtGpvtwRT1ueH0rAODfsxNx/Ygg2XHzTHe4lxOKqhpxpLIBD6/Ik36nPDo1Bl/klsgmB4wMtQ66vV20Ula/tKYJWTaC66nD/BDs4Yg9xTVobO2Q3WyK83fF3WMj8NKPh7BsY9fvK0eNMegO0NvIdDvI/x+gUAjwcdGhxGIHypnGFnnQ7eMMN0c15o+T32wbG+klvY4NhyptbuMfqLi9nIiIaICK9XPF63OSpcx0X7lvwhC8fVsy/ntXWq+C90VTYpD/9FTMHxchNYrTqBSYOTIQn92bjienD8UjV0XjmWuH4cM7U/HgZGMtvr+bA/49OwnzRofh2z+OkQWLF2tGUiBmjjS+Pwaxq7nYiBA9Mjqzw1qVEpMtMiouOhWemBYre2ze6HC8eWsyQjyMgXLuibN4fvVB6cOx+dbNG0YGYtOjE7Dt8cmyBnOWoryN2aOh/l1b5G0F3KlmGf9PdhZjyr83S7Wqk2J9ZE3SLLeXm3fFv25EIJQ2dmgEusuzWF7OGtxxZRiUCgHThvtZnW9LoN7B5vbVrUfPoL3DYLN+1cS0lVSjUsiaEGYdPo0lX++z+hBvUt/SjuNnzvWY9WtuM1h1JF+VU4yov6zGvf/dBYNBtGo0lV9Sa7NrtYnBIGJldjHe+fmY9Zi9omrc8PoveHHNIZvN7yxlH6/Gl7tPob2HEYKWas+1YdeJ6gv6HsAYhF7/+i+Y+M9N552ZbK86DCLmvLMDj32ej+e+t76JZ643WW4AsqaL7ucJus1rtz2dNVJmvL6l3WrNmN+QenjFHqu1Yuo07uWsQYzZWEJTXXlauAfmj41AiFnpjVIhICnY+ibj8MCux37aXyHdpPJ00kCrUkDvqMakoT54dkY8vr5/jNXvPFN9drTFeESHzqA7Ldy6d4OtGx62tphXNbTiaKU8022LSqnA9Z0lLW0domwCw0DHTDcRERH1SBAEWSb1Qjw7Ix6BegeMCvOA3lGDlDAPqX7dlhlJgZiRd
P4xaRfixVkJCPFwxCvrD8MgAqMjPfHWbSlQK7tyD1fH++GrvK4PeAvGRmC0WSM9ALhpVDBUSgXGR3vjo+0n0G4QpcB3bJQX/j07CR/vOIkgdwfp5oSjRtXjCKpob+OHaXcnDfxcdbImTiYOaiXevSMFX+SW4MU1h9BolrnzdNLgHzMTZFlL80y3i06FMM+uHRDxgW7I/ksGlIKAa/6zReou39JukDVge/raYVLwccPIIFn3esA4B928A/mjU2OwcGIkWto7kF10FgfL6vDK+iNoaGlHzbk2/Hykqse59Kbts0P9XDAlzhdPfrUPHQYR/9txQmri56JT4Y+TIrHuYCVKa5qkMoBtR8/IurjbUtvUJvUvEEURS74xzqj/6UAFfjpQjiMW28tb2w3YX1prNUscMGbI//zZXqnm1kGjxJy0UADGgPDPn+3B8TPnkHuyBj8fqcJ7d4yCt4vtWtcTZxpxy9vb0W4Q0dDSgduuCO3xdQBAWW0Trv3PLzhd34LFmbG4d/yQ836PyXd7S7G7s1HeR9tP4PHMob3+XntRXtcs3YTZcp6mjb2t5zYPFPXdZM5NYx6rzRqzuTtq0NreFUi/u6UITa3tuG9iJHRqpdWNjTX7yjG980Zke4dB6jzu56aTzcI2mTc6HAqFgFCPrqA7yN0BPjYalsUHuklBfo7ZTo1Hp8Ygc7g/lApB6tsBQPZ7AegKoC13Ozl17g4K8XREhJeT1OgNsK7pBiBrYGlSVd8i+74hNl6rycyRQVIjt7UHK5EZGdbtuQMJM91ERETUZzycNHh82lApq9wflAoBD2ZEYfWD4/Dmrcl4/45U2YdPABgf7SNto/Rw0mDemHD4uGgxsXML+qNTY6QPpWOj5ME4YBzx5uWsxQOTo3DDSPnOA51aKTWJC/FwlOaHA8BQ364P00P95R92gz0cMH24P169ZQRcdGrMvTIM6x4ZjwXjIuDhpIGzVoWlNydZBXShno7SFtn0CE+r3gMeThq4OapxRUTXzY+iqkbc2hnwzUoOko2FSwrWY/m8URgdacx0pYS6m3WbN7ot3fi9WpUSY6K8MH9cBB6dGiMdX7Wr2Oo9syU+0A2ezlqM63yPzTv/P3Z1LBaMG4KV96Tjtc66dgD4344TVjXZlsy3mBdVNcq2HL+yvtDm93e3xfzJL/fJmlz9ZFYDu/ZAOY6b3YzIL6nFSz8e6va6thRWSTccfrZo4NdhELH92BlZtra9w4CF/8uVgsmV2b17X03Mu+tndzNP/mLVN7fhno9ycNNb23Cmlxnmi3HKbAxhaW1zjw3Mep3pNgtiVUqFzQxuSU0TGlvacaax6zk9nTTwcOo69x9rDuHVDYX4z4ZCANa9At7IKoTYuTWmqqFVynz7uTpYNa1UKgRc2fkzZ95kMtzLyeYW+GEBrlaPAcCEGB+4OaitfudZNq40lfJEeMmz0KbfiwCs6tJtZ7qtg+76lnYc6CxPcXNQ99ggM9rXWWoUt/tkDdptzFcfiJjpJiIiosuCrZp1EweNEi/flIQV2SfxhwmR0gfUd+eOQvW5Vqk5GwCkD/GEUiFIH5jVSgEZ3cynNVn2+5H4dm8pbhkVgpZ2A5Z8sw+jh3jC3yzDNtTfVdbk6E9TYqyy/v5uDnhi2lA8nhmLDoMIldI6f6JWKvD2bcnYcKgSc3rInN56RShW5hgbtc0fG44npg3FveOHwN1RbVXvOyHGBxNifNDQ0g4HtRJf7i6RvvemlCCbGS9TkA4AP+SXWx23xbRF9roRgbL3wt9NhxtTgmTnOWtVaGhpt1lrbvoeU81sdUMrPtx2HDnHz8qCCEA+NsrdUS01xHoz6xiuGxEo+29f29RmdQNhR9EZNLd1QKtS4E2zjvwmPY2RM+/Qbjmb+M2so3jpxwI4aZR45eYRyIjzxUs/FshGuh2raux1fbYoilLnfsB4Q6CptUPaPtwbO46dwTtbipB74iwWjIvAPWZZ9qXrjkhNuN7afAxPTBuKDoOIl
34sgEEU8eepMTbX64Uqtmh0eLCsTjbe0VzvM93yQNHDSWOzAd+wJT/KvnZ30tgsB/nPxkI8fFU0Civl3f33ldQhv6QWCUF6lJn1OPB300kTDExGhuilnyvzfh0RXs5w1CihVSlk/3aAmwO8nDWyEWlx/q42u44D1plu586MtuUMbfOvJ8R4Sw3eAOuRYQCQGe+H/+04AU8nLaL9XKQRjKbt7kO8nXpcr4IgIDXcA1/nlaKprQMFp88hYBD0NGTQTURERATjFvOr4+Xb6BUKQRZ0AYCLTo2RIV0Nya4c4nXehm/xgW6IN6u5/HRBOgwGAyoru2o+zeu6AVhtbzcnCIJs7ral823jB4xdzV+clYA9xTWYP9bY0KinDBSArrFyw/2x4ZAxwPrr7+JsnjvE2xk+LlppC21vmN6jjKG+cFArpYZVC8ZFyEbzqZQKpIZ72OzQbBLi4SgF3b9/Z0ev/v2r4nxR1dCKDYcqUdXQggc+2Y1nZ8RLW/Z3FlVbdUdvbjMg5/hZKARIs8Zj/Vzg6qDGzqJqlNQ0dTv323w2eWltM840tMCzc72ZOtc3tnbg7g9zkBrmgZ3HrbPTlfUtNrOLlgorG2SZ37YOEbuLz+LKIfJ11m4QUdfUBr2TfN2vzi/DH/6XK339wppDSIvwRFKwHoWV9fjALBhbd6ACT0wbiq92l+DNLGNTrqH+LlaNxGxd49d5JZiRFIBIH9s3yCy3bB/oKejuYUa3iVIhwNNi3et70cARMP68dNf0cM2+cpuZ9m/3lCIhSC/VcwOAv15nNVViXFRXVjk9whNjo7xw/Ewjfp8WDEEQ4OGkkTVi83LRwM9NJwu6zRtfWvKzWDPmO2LMf27Nb1yYurebuNoYM5YW4YmdT2TAUavE//vuoNXx7uq5zY0KMwbdAJBXUo+JCef9FrvH7eVEREREF8j8A3FvG42dj+X2cstgvy/clBKM564fbjMg7ImDRonX5yTj9TnJcLGR5QaMNwZ6unFgSaNUSE2cnLQq3JxqbEAX6umIm0eFWJ1vORM+KVgvu2kwxKfnD/d+rjpMjpU3k4ryccGLsxKk937r0TPIeDkL93+ci5pzrbI57DOSAqS/rz9UIdWJA8A94yMw0qwe3DRPvb2jq6lbfXObVS25KWtfUtMk1RCbmAfcWrPso3kTuDMNLfjrV/vw2vojVk3czLPcJtlF8i30jS3tWLDiEJL+tg4rsk/Kjn1n0axOFIEnv8pHh0HEs98dlI21OlbViKOnG7DJbMt87okaq3/f0oOf7sZrGwqliQO2FFfLM90HymzvdAB6l+n2cdFal2B0MzbMkqeTtttz//lT13SB3yX4Q9X5b3y3twwGgygLmP3ddFbzwc23cquUCnx0Vxo2PzpRuhlhOdrM00lrNdZrUmz3I7cUiq4dOpY/B8FmNeTFZjc5dGolxnT+THs5a7rtHO/upIFWpbT5O+x8P5cAkGpWvpJX0nPpyEDBTDcRERHRBbr9yjDsOVULV53qvNm73hri7YyxUV7YdvSMNHt6oLthZKCs/hkAFl0VjZfXHoaXsxZjo7yk47H+LrKtrH+ZNhSTY30R6+9icwv0DSMC4e6oRlFVI9RKBTLj/XCmsRULP87F8EA3jI30wsc7ugJHy+ZvoyO98Pcb4vHAJ7ulbdGmEWmv3pyE+R/mSE3rvttbhuzj1VL2TyEYX8c3e0ohisD7vxyXnjcxyA3XJgZinaar1ntTwWnkHK/GypxTaGxpx9u3J0OnUlqNmssvqcW4aG9sOdIVrDprVdCqFNL23MeujoWHkxqPfZ4PADhSWY8xUV5o6zDgno92SU20Nh0+jZdmJWDvqVq890uRbByeyb/XHUZrRwfmXhkGHxcdnvnuAA5UGN+j574/KI3kEkUROZ1Bv1alQKC7A46dbsS+kjos+DBH2kJs7qf9Fdha2HWTYl+p9b8PGDvBG0QRtU1t0k2H/aV1ONfaLhvxZ1Jsmekuret2i313Qbf5LGlbN
5x6u+Xe3UmNc63tNo8VmTUOSwl1x7nWDmw4VImy2mb87fsD+MXsvfFzdYAgCJgQ441NBacRqHeQ7YwxMX+Nns5dQbebgxoalUKqhQaM2fqkYOtGgOaW3pyE7KJqpEXId8XMHBkk9TS4JiFAduzl2Yn4eMdJjI/2Pm+5gPk1mlhuo7cl0tsZekc1as61YU9JAwwGEYoBnipm0E1ERER0gdwc1Hhnbspv+pyCIODDO1PR1NZhM9gYiMZGeeOaxADZ6J/7JgzB5KE+8HdzwPf5ZVLQbRlkqJQKjLHRtM5EoRAw2aKW3sdVhw2PTABgzCSHeTriTEMrFoyLwPxxEfjpQAUe+GQ3BAG4MSUIWpVSqrd30aqR2Dmm7spIL/z82CR8kXsKr20oRG1TGyrqugK44YFuCPV0QkKgG/aYBbOCAPztungoFQJGmI28+2SnPGv84poCXJMoD2YAYH9nYLrZrCv3B3emIiHIDduPnYGDWomUMA9Zk7cjlQ0QRRHP/3BI1rV614mzmPSvLKt/Q++ohk6llDrlL9t4FFuOVOHOMeH4bFfXDZK65na8+/MxLJoSg1Nnm6TXnxrugYcyojDrzW0QRWC92Rb/R6fG4KUfCzqft1Aa0QcYa6/bOwyyQO1sYyvmLc/GofI62ag4ADhc0WBzbGCJRU33ofJ6xD31I65NDMALM4fLAlPT9m5BAD6Yl4r/bj+B29PDcOu7XeUGvjY6y5tPHPBz1WFMlBc2FVQiQO8gu3nh6aRFUy/mgEf5ukDvqJHKIcxv0gCQguW/Xz8cX+SeQkacr83RfubMM91encGtef32+Gjv8z6Hs1aFiTay4TePCkZhZQNqmloxd3SY7JiPiw4PZUT3+Lxd12X93sZ10/DNnEIhYFSYB9YeqEBdSweOVDZgaID1TYiBZHD8RiciIiIaBARBGDQBt8mSa+Kw9kA5mtsMiPVzgUqpwLDOD9DjorygUghoN4iY8ht3uHfRqbFu0XgoBEHaPnxtYgCC3R1gEIHkUGMW0Dgb2Hq3goeTBnePjcDvEgJw5/Js2Tbm9M466Hmjw7FoZZ5U5z1/bIQ0Ls7HVYdAvYPNGeOHyutxpvG41eP5JbXoMIhSFtRFp0JikBtUSgXGmpU0mI+F21dSi0dW7sEXnTcv1EoB7o4aq1p6U/O/+WMjUFx9Dp+adT7fc8r4HCaCYNw+/u6WItx+ZRiyzba2jwrzQHKoBxZlRONfaw9Lj6eFe+C+CUPw5e4SFFY2yAJuwFj7fqyqEYF6B/x77WG0dRiwr7ROqmu3DEQPl9dbBd1tHQZZAzKTprYOrMgpxg0jA5FmVndsynS7O2owLtrbqvs2YLvbdmNLVyDtqFXinzcmQhRFFFc3YdxLG6Vj7k5qNLbKs7nTh/tj7YEK2Yi8SB9nJAXr4aJTyTryA8ayClOwHKB3wP2ToqyuxxbzUgpTcGvewTwz/uK7jykUAp66xnavhgsR7Ws2vlCrwl+mD0WQu2MP39EltTPoBoyz7Bl0ExERERF1w8tZi68WjsYnO05iVnKw7FiopxN+fHgc6pvbbWY1fy1b219tzd7uiZ+bDm/cOhLjX9okPXblEGNgd92IQIyL9kZFXTNUCkEWDAPG5njmQfft6aHSzHNTQOiiU2GItzPyimtQXN2E29/bgZrODuqjh3jZfA1uDmr4umpRUdeCvadqZdnXp64ZhunD/bEypxibCirR1GbAg5MjMTbKGy3tBjhrVag51wpPZw2yi85KteKmmuwpMR7wcnPCxzuL0djagSVf74erQ1fIkBJmfP/unxSJg+V1+CG/HGqlgKevHQZBEHBTShD+/oPtMWn7SmrxZtZRfJFbYvO4uYKKeunvre0GLN9ahINl9VaN7My9kXUUcQGuUCoEOKiVUqbbu4f+CLayseOivbHtmLEGflpn8CoIAkI8HREf6Ip9JXUI9XSEVqW0qq2eMswXwR6OUgM5wFg3LggCls9LxZp9ZXBzUONMYysOV9Rj2
nB/aYb8hZAF3Z3Z+vHR3vhHZ7Z/6rD+G9NoEuHtjP/8fgROVp/DTSnBF9SnIi3CA3H+Lhjmq+t2HNpAwqCbiIiIiPpUrJ8rnpkRb/NYb7oZ97dQTye8d0cKFq3cg2EBrkgf0pVN9XDSdNv1fVKsD9YdNGbr/jQlGneOCcfKnGI0t3VlQe8eE4HG1nYp42s+Yqyn+fZRPi6yLe+OGiVempWI6QnGIPHe8UNwr9lIL8A4Tg4A9I4aPDo1FqIo4rZ3d2JLZ2Y92N0Bj00KgbPeA6v3lePsuTZ8b9ZATaUQMKKzTlgQBLxy8whMii3FEG8nqfv+XWMi8PORKvxstkXe5F8/HbaZ+bflcGfQXV7bjPv+t0s2Kg0w7lrYerQKLe0GKXu8qeA0hj/9E/SOajx33XCps7iXS/eN0Zy01gHvvNFh2HvK+O8tnBgpO/bGnGSsyinG1M5JB5bNxIb6u2LyUF98ufsUKupaMCnWR9rynhzqLu2w+LXczdac6aaCIAiYbaPpYH/6XYJ1GUVvJATp8d0fx6CyshI+Pr/Ne9afGHQTEREREZ3HpFhf5D55lVWn657clBKEqoYWOGtVmHtlGJQKATcmB+Oj7cZs9w0jA/HHSZGoa25DdWMrvskrRWuHAV7OGvxxUhRuGBHY7XOHezlJwTJgrP0edZ4xcZYEwbiN+Ka3tgEAXr05CU7aVni7aPHsjHj88ZPdsvPjA91kTcbUSgVmJcu35isVAl67ZQRmvrEVR083Yk5aCP7X2dDOPOAeG+WFME8nRPk646mv98NSQXk9RFHE/A9zkF9i3YRtVJg7/nVTIkQR+DT7pOw5as61YeHHXePNesqw2mqaplMr8catyTbPD/ZwxKIpMbLXay7cywlqpQIr70nHhkOVFx10no+3WZMyH9e+n3RAvw6DbiIiIiKiXriQgBswbm9/YLK8RvdPU2LQ2NIOXzcdFl0VDYVCgN5Rg3/emIjFmbHYX1qHlFB3OGl7/pg+IcZbCt7/MXP4BQfcJtG+Lti2eDIMoggHtUKaHf+7BH+sP1iBr/K6muBN6eWWZb2jBt8/MBZHTzcg1s8VGw9VotRsRNaUOF+8dVsyBEGAKIo2g+7K+hZsPXrGZsANAEEejlLm/sbkYPx3+wkcrmiAk0YpdZ03sWy4NzHGGxsLjB3XEztr8H+N6Qn++H5vGTLj/aRrCvV0wrzR4b/6ubszPtoHwwPdUNfchhlJ3d+cIfsgiKLlsALqTl1dHdzc3FBbWwtXV/utLTAYDJ1bMXygGOj99WnA43oke8L1SPaE65F+DVEUsenwaThrVRcdcFuyXJMdBhG7T57FsapG6NRKWVB5IR77bC9W5BRDEIDbrgjF4sxYWcPA/2w4gn/+ZGzKlhrugZ1FxjrzUWHuyD5u7Mj+yFXyxm3rFo2TZlYDQENLO8prmxHq6Yh//XQYH2w9juRQdzyUEYUUi/enpKYJf/v2ABKC3XDfBPn28YvR1mHA3lM1GBbgdlH12RfLFMbZGpc2GAyE35G9jQ+Z6SYiIiIiGmAEQcDEGOtxT78lpUJASpiHVdB6oZ783VAkh7kjIcgNsX7WgcmCcUOgUAgIdndEXXObFHSbAm5BAGanBkPvqMZfv96POH9XhHvJewE4a1VSI7vFmbF47OqYboPRQL0D3rzN9vbxi6FWKpAc+tvc+LgQgzXYHowYdBMRERERUZ9x0alxU0pwt8c1KoWUcd5ns37bAz4uOtyWHoar4vzg4aQ57wxqBqRkT+wzT09ERERERJed+EA33GjRnG1aZ6dwwDjCTaNiCEMDC1csERERERHZjedvGC6NPnPRqTCt8+9EAxW3lxMRERERkd1QKRV47eYRmDUyCCGejvBx0fX3JRH9Kgy6iYiIiIjIrigUAibG9m2jOKJLhdvLiYiIiIiIiPoIg24iIiIiIiKiPsKgm4iIiIiIiKiPMOgmIiIiIiIi6iMMuomIiIiIiIj6CINuIiIiIiIioj7CoJuIi
IiIiIioj1x2QXdxcTEmTJiAuLg4JCQkYNWqVf19SURERERERDRIqfr7Ai41lUqFpUuXIikpCeXl5UhOTsa0adPg5OTU35dGREREREREg8xlF3T7+/vD398fAODn5wcvLy9UV1cz6CYiIiIiIqLfnN1tL9+8eTOuueYaBAQEQBAEfPXVV1bnLFu2DGFhYdDpdEhLS8POnTsv6t/atWsXOjo6EBwc/CuvmoiIiIiIiMia3QXdjY2NSExMxLJly2weX7FiBRYtWoQlS5YgNzcXiYmJmDp1KiorK6VzkpKSEB8fb/WntLRUOqe6uhq333473n777T5/TURERERERHR5srvt5ZmZmcjMzOz2+Msvv4z58+dj3rx5AIA333wT33//Pd577z0sXrwYAJCXl9fjv9HS0oLrrrsOixcvxpVXXtnjeS0tLdLXdXV1AACDwQCDwdDbl3TJGQwGiKJo19dIlw+uR7InXI9kT7geyd5wTZI9GQjrsbfXZndBd09aW1uxa9cuPP7449JjCoUCGRkZ2LZtW6+eQxRF3HHHHZg0aRJuu+22Hs99/vnn8cwzz1g9fvr0aTQ3N1/YxV9CBoMBtbW1EEURCoXdbWagywzXI9kTrkeyJ1yPZG+4JsmeDIT1WF9f36vzBlTQXVVVhY6ODvj6+soe9/X1xaFDh3r1HL/88gtWrFiBhIQEqV78o48+wvDhw63Offzxx7Fo0SLp67q6OgQHB8Pb2xuurq4X/0L6mMFggCAI8Pb2ttsFSpcPrkeyJ1yPZE+4HsnecE2SPRkI61Gn0/XqvAEVdP8WxowZ0+ttAFqtFlqt1upxhUJht//hTQRBGBDXSZcHrkeyJ1yPZE+4HsnecE2SPbH39djb67LPq++Gl5cXlEolKioqZI9XVFTAz8+vn66KiIiIiIiIyLYBFXRrNBokJydj/fr10mMGgwHr169Henp6P14ZERERERERkTW7217e0NCAwsJC6euioiLk5eXBw8MDISEhWLRoEebOnYuUlBSkpqZi6dKlaGxslLqZExEREREREdkLuwu6c3JyMHHiROlrUyOzuXPnYvny5Zg9ezZOnz6Np556CuXl5UhKSsKaNWusmqv9lpYtW4Zly5ahvb0dQNfoMHtlMBhQX18PnU5nt/UPdPngeiR7wvVI9oTrkewN1yTZk4GwHk1xoSiKPZ4niOc7gySnTp1CcHBwf18GERERERER2Yni4mIEBQV1e5xB9wUwGAwoLS2Fi4sLBEHo78vplmm0WXFxsV2PNqPLA9cj2ROuR7InXI9kb7gmyZ4MhPUoiiLq6+sREBDQYzbe7raX2zOFQtHjHQx74+rqarcLlC4/XI9kT7geyZ5wPZK94Zoke2Lv69HNze2859jn5ngiIiIiIiKiQYBBNxEREREREVEfYdA9CGm1WixZsgRarba/L4WI65HsCtcj2ROuR7I3XJNkTwbTemQjNSIiIiIiIqI+wkw3ERERERERUR9h0E1ERERERETURxh0ExEREREREfURBt2DzLJlyxAWFgadToe0tDTs3Lmzvy+JBqHNmzfjmmuuQUBAAARBwFdffSU7LooinnrqKfj7+8PBwQEZGRk4cuSI7Jzq6mrMmTMHrq6u0Ov1uOuuu9DQ0HAJXwUNFs8//zxGjRoFFxcX+Pj44LrrrkNBQYHsnObmZixcuBCenp5wdnbGzJkzUVFRITvn5MmTmD59OhwdHeHj44NHH30U7e3tl/Kl0CDwxhtvICEhQZorm56ejtWrV0vHuRapP73wwgsQBAEPPfSQ9BjXJF1KTz/9NARBkP2JjY2Vjg/W9cigexBZsWIFFi1ahCVLliA3NxeJiYmYOnUqKisr+/vSaJBpbGxEYmIili1bZvP4iy++iFdffRVvvvkmduzYAScnJ0ydOhXNzc3SOXPmzMH+/fuxdu1afPfdd9i8eTMWLFhwqV4CDSJZWVlYuHAhtm/fjrVr16KtrQ1TpkxBY2OjdM7DDz+Mb7/9F
qtWrUJWVhZKS0txww03SMc7Ojowffp0tLa2YuvWrfjggw+wfPlyPPXUU/3xkmgACwoKwgsvvIBdu3YhJycHkyZNwowZM7B//34AXIvUf7Kzs/HWW28hISFB9jjXJF1qw4YNQ1lZmfRny5Yt0rFBux5FGjRSU1PFhQsXSl93dHSIAQEB4vPPP9+PV0WDHQDxyy+/lL42GAyin5+f+NJLL0mP1dTUiFqtVvzkk09EURTFAwcOiADE7Oxs6ZzVq1eLgiCIJSUll+zaaXCqrKwUAYhZWVmiKBrXn1qtFletWiWdc/DgQRGAuG3bNlEURfGHH34QFQqFWF5eLp3zxhtviK6urmJLS8ulfQE06Li7u4vvvPMO1yL1m/r6ejEqKkpcu3atOH78ePHBBx8URZG/H+nSW7JkiZiYmGjz2GBej8x0DxKtra3YtWsXMjIypMcUCgUyMjKwbdu2frwyutwUFRWhvLxcthbd3NyQlpYmrcVt27ZBr9cjJSVFOicjIwMKhQI7duy45NdMg0ttbS0AwMPDAwCwa9cutLW1ydZkbGwsQkJCZGty+PDh8PX1lc6ZOnUq6urqpAwl0YXq6OjAp59+isbGRqSnp3MtUr9ZuHAhpk+fLlt7AH8/Uv84cuQIAgICEBERgTlz5uDkyZMABvd6VPX3BdBvo6qqCh0dHbIFCAC+vr44dOhQP10VXY7Ky8sBwOZaNB0rLy+Hj4+P7LhKpYKHh4d0DtHFMBgMeOihhzB69GjEx8cDMK43jUYDvV4vO9dyTdpas6ZjRBciPz8f6enpaG5uhrOzM7788kvExcUhLy+Pa5EuuU8//RS5ubnIzs62Osbfj3SppaWlYfny5YiJiUFZWRmeeeYZjB07Fvv27RvU65FBNxERDRoLFy7Evn37ZPVhRJdaTEwM8vLyUFtbi88++wxz585FVlZWf18WXYaKi4vx4IMPYu3atdDpdP19OUTIzMyU/p6QkIC0tDSEhoZi5cqVcHBw6Mcr61vcXj5IeHl5QalUWnX3q6iogJ+fXz9dFV2OTOutp7Xo5+dn1eCvvb0d1dXVXK900e6//35899132LhxI4KCgqTH/fz80NraipqaGtn5lmvS1po1HSO6EBqNBpGRkUhOTsbzzz+PxMREvPLKK1yLdMnt2rULlZWVGDlyJFQqFVQqFbKysvDqq69CpVLB19eXa5L6lV6vR3R0NAoLCwf170gG3YOERqNBcnIy1q9fLz1mMBiwfv16pKen9+OV0eUmPDwcfn5+srVYV1eHHTt2SGsxPT0dNTU12LVrl3TOhg0bYDAYkJaWdsmvmQY2URRx//3348svv8SGDRsQHh4uO56cnAy1Wi1bkwUFBTh58qRsTebn58tuBq1duxaurq6Ii4u7NC+EBi2DwYCWlhauRbrkJk+ejPz8fOTl5Ul/UlJSMGfOHOnvXJPUnxoaGnD06FH4+/sP7t+R/d3JjX47n376qajVasXly5eLBw4cEBcsWCDq9XpZdz+i30J9fb24e/ducffu3SIA8eWXXxZ3794tnjhxQhRFUXzhhRdEvV4vfv311+LevXvFGTNmiOHh4WJTU5P0HFdffbU4YsQIcceOHeKWLVvEqKgo8ZZbbumvl0QD2B/+8AfRzc1N3LRpk1hWVib9OXfunHTOvffeK4aEhIgbNmwQc3JyxPT0dDE9PV063t7eLsbHx4tTpkwR8/LyxDVr1oje3t7i448/3h8viQawxYsXi1lZWWJRUZG4d+9ecfHixaIgCOJPP/0kiiLXIvU/8+7losg1SZfWI488Im7atEksKioSf/nlFzEjI0P08vISKysrRVEcvOuRQfcg89prr4khISGiRqMRU1NTxe3bt/f3JdEgtHHjRhGA1Z+5c+eKomgcG/bXv/5V9PX1FbVarTh58mSxoKBA9hxnzpwRb7nlFtHZ2Vl0dXUV582bJ9bX1/fDq6GBztZaBCC+//770jlNTU3ifffdJ7q7u4uOjo7i9ddfL
5aVlcme5/jx42JmZqbo4OAgenl5iY888ojY1tZ2iV8NDXR33nmnGBoaKmo0GtHb21ucPHmyFHCLItci9T/LoJtrki6l2bNni/7+/qJGoxEDAwPF2bNni4WFhdLxwboeBVEUxf7JsRMRERERERENbqzpJiIiIiIiIuojDLqJiIiIiIiI+giDbiIiIiIiIqI+wqCbiIiIiIiIqI8w6CYiIiIiIiLqIwy6iYiIiIiIiPoIg24iIiIiIiKiPsKgm4iIiIiIiKiPMOgmIiIiIiIi6iMMuomIiC4z+fn5mDVrFkJDQ6HT6RAYGIirrroKr732mnTO3//+d3z11Vf9d5FERESDhCCKotjfF0FERESXxtatWzFx4kSEhIRg7ty58PPzQ3FxMbZv346jR4+isLAQAODs7IxZs2Zh+fLl/XvBREREA5yqvy+AiIiILp3nnnsObm5uyM7Ohl6vlx2rrKzsn4siIiIaxLi9nIiI6DJy9OhRDBs2zCrgBgAfHx8AgCAIaGxsxAcffABBECAIAu644w7pvJKSEtx5553w9fWFVqvFsGHD8N5778mea9OmTRAEAStWrMATTzwBPz8/ODk54dprr0VxcXFfvkQiIiK7wkw3ERHRZSQ0NBTbtm3Dvn37EB8fb/Ocjz76CHfffTdSU1OxYMECAMCQIUMAABUVFbjiiisgCALuv/9+eHt7Y/Xq1bjrrrtQV1eHhx56SPZczz33HARBwGOPPYbKykosXboUGRkZyMvLg4ODQ5++ViIiInvAmm4iIqLLyNq1a5GZmQkASE1NxdixYzF58mRMnDgRarVaOq+7mu67774bP/zwA/Lz8+Hp6Sk9fsstt2D16tUoKyuDg4MDNm3ahIkTJyIwMBAHDx6Ei4sLAGDVqlW46aab8Morr+CBBx7o+xdMRETUz7i9nIiI6DJy1VVXYdu2bbj22muxZ88evPjii5g6dSoCAwPxzTff9Pi9oiji888/xzXXXANRFFFVVSX9mTp1Kmpra5Gbmyv7nttvv10KuAFg1qxZ8Pf3xw8//NAnr4+IiMjeMOgmIiK6zIwaNQpffPEFzp49i507d+Lxxx9HfX09Zs2ahQMHDnT7fadPn0ZNTQ3efvtteHt7y/7MmzcPgHUztqioKNnXgiAgMjISx48f/81fFxERkT1iTTcREdFlSqPRYNSoURg1ahSio6Mxb948rFq1CkuWLLF5vsFgAADceuutmDt3rs1zEhIS+ux6iYiIBiIG3URERISUlBQAQFlZGQBjRtqSt7c3XFxc0NHRgYyMjF4975EjR2Rfi6KIwsJCBudERHTZ4PZyIiKiy8jGjRthq4eqqcY6JiYGAODk5ISamhrZOUqlEjNnzsTnn3+Offv2WT3H6dOnrR778MMPUV9fL3392WefoaysTGrmRkRENNixezkREdFlJD4+HufOncP111+P2NhYtLa2YuvWrVixYgWCg4Oxe/du6PV6TJ8+HVlZWXj22WcREBCA8PBwpKWloaKiAmlpaTh9+jTmz5+PuLg4VFdXIzc3F+vWrUN1dTUASN3Lhw8fDkEQMG/ePFRUVGDp0qUICgrCnj174Ojo2M/vBhERUd9j0E1ERHQZWbNmDVatWoWtW7fi1KlTaG1tRUhICDIzM/Hkk0/Cx8cHAFBQUIAFCxYgOzsbTU1NmDt3rjQ+rLKyEs8++yy++eYblJeXw9PTE8OGDcPs2bMxf/58AF1B9yeffIK9e/fi3XffRX19PSZNmoTXX38dISEh/fUWEBERXVIMuomIiOg3Zwq6V61ahVmzZvX35RAREfUb1nQTERERERER9REG3URERERERER9hEE3ERERERERUR9hTTcRERERERFRH2Gmm4iIiIiIiKiPMOgmIiIiIiIi6iMMuomIiIiIiIj6CINuIiIiIiIioj7CoJuIiIiIiIiojzDoJiIiIiIiIuojDLqJiIiIiIiI+giDbiIiIiIiIqI+wqCbiIiIiIiIqI/8f2PaPDZ/p
6ZEAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot loss curve\n", "plt.figure(figsize=(10, 6))\n", "plt.plot(losses, linewidth=2)\n", "plt.xlabel('Step', fontsize=12)\n", "plt.ylabel('MSE Loss', fontsize=12)\n", "plt.title('Training Loss Curve - Sequence Modeling', fontsize=14)\n", "plt.grid(True, alpha=0.3)\n", "plt.yscale('log') # Log scale to better see the decrease\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 312 }, "id": "1FrQzqIoajtL", "outputId": "1aa099e9-b000-4843-df1e-47d09b041767" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA/QAAAEnCAYAAAAHEAjKAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XV4FNf6wPHvrGXj7p5AiOAWtKVIKVUq1EuN2u29FVpqt97euis1ar96C7QFCqW4O4FAIO7unrX5/TFJIGiA7G4Szud58iQ7O7vn7GZWzpz3vK8ky7KMIAiCIAiCIAiCIAg9isreHRAEQRAEQRAEQRAE4dSJAb0gCIIgCIIgCIIg9EBiQC8IgiAIgiAIgiAIPZAY0AuCIAiCIAiCIAhCDyQG9IIgCIIgCIIgCILQA4kBvSAIgiAIgiAIgiD0QGJALwiCIAiCIAiCIAg9kBjQC4IgCIIgCIIgCEIPJAb0giAIgiAIgiAIgtADiQG9IAiCIFjB6tWrkSSJ1atX27srQg/w7LPPIklSh20RERHccsstXdbGLbfcQkRERJfdnyAIgmB/YkAvCIIg2NxXX32FJEnH/dm8ebO9u9gtlZWVcf/99xMbG4ujoyN+fn6MHDmSRx99lPr6ent3r0c7/PhTqVQEBQVx/vnn97gTMoWFhTz77LPs3r3b3l0RBEEQbEBj7w4IgiAIZ6/nn3+eyMjIo7b36dPHDr3p3iorKxk+fDi1tbXcdtttxMbGUlFRwZ49e/j444+55557cHFxsXc3e7QpU6Ywc+ZMZFkmKyuLjz76iIkTJ7J48WKmTZtm8/4cPHgQlerU5l4KCwt57rnniIiIYPDgwR2u++yzz7BYLF3YQ0EQBMHexIBeEARBsJtp06YxfPhwe3ejR/jiiy/Izc1lw4YNjBkzpsN1tbW16HQ6O/Ws94iJieHGG29sv3z55ZczcOBA3nnnneMO6Jubm9HpdKc88O4MBweHLr0/rVbbpfcnCIIg2J8IuRcEQRC6rWeeeQaVSsWKFSs6bL/zzjvR6XQkJSUBYDAYePrppxk2bBju7u44Ozszfvx4Vq1a1eF22dnZSJLEG2+8wYcffkhUVBROTk6cf/755OXlIcsyL7zwAiEhITg6OnLZZZdRWVnZ4T4iIiK4+OKL+fvvvxk8eDB6vZ74+Hjmz5/fqce0ZcsWLrjgAtzd3XFycuLcc89lw4YNJ71dRkYGarWaUaNGHXWdm5sber3+tNpZv349I0aMQK/XEx0dzSeffHLUeu625+2rr7466vaSJPHss8922FZQUMBtt92Gv78/Dg4OJ
CQkMG/evA77tOUY+Pnnn/nf//5HSEgIer2eSZMmkZ6eflQ7W7Zs4cILL8TT0xNnZ2cGDhzIu+++22GfAwcOcNVVV+Hl5YVer2f48OH88ccfR91XZw0YMAAfHx+ysrI69PnHH3/kySefJDg4GCcnJ2pra9v7eLrP+bEcaw19dXU1Dz74IBERETg4OBASEsLMmTMpLy9n9erVjBgxAoBbb721fQlB2//tWGvoGxoaeOihhwgNDcXBwYF+/frxxhtvIMtyh/0kSeLf//43CxcupH///u3/16VLl57q0yoIgiB0ITFDLwiCINhNTU0N5eXlHbZJkoS3tzcATz75JH/++Se33347e/fuxdXVlWXLlvHZZ5/xwgsvMGjQIECZof7888+57rrruOOOO6irq+OLL75g6tSpbN269ajQ4++++w6DwcB//vMfKisree2117j66quZOHEiq1ev5tFHHyU9PZ3333+fhx9++KjBaFpaGtdccw133303N998M19++SUzZsxg6dKlTJky5biPd+XKlUybNo1hw4a1n6z48ssvmThxIuvWrWPkyJHHvW14eDhms5lvv/2Wm2+++YTPa2fb2bt3L+effz6+vr48++yzmEwmnnnmGfz9/U94/ydSUlLCqFGj2geAvr6+/PXXX9x+++3U1tbywAMPdNj/lVdeQaVS8fDDD1NTU8Nrr73GDTfcwJYtW9r3Wb58ORdffDGBgYHcf//9BAQEkJKSwqJFi7j//vsB2LdvH2PHjiU4OJjHHnsMZ2dnfv75Z6ZPn85vv/3G5ZdffsqPpaqqiqqqqqOWgLzwwgvodDoefvhhWlpa0Ol0NnnO6+vrGT9+PCkpKdx2220MHTqU8vJy/vjjD/Lz84mLi+P555/n6aef5s4772T8+PEAR0V0tJFlmUsvvZRVq1Zx++23M3jwYJYtW8acOXMoKCjg7bff7rD/+vXrmT9/Pv/6179wdXXlvffe48orryQ3N7f9NSsIgiDYmCwIgiAINvbll1/KwDF/HBwcOuy7d+9eWafTybNmzZKrqqrk4OBgefjw4bLRaGzfx2QyyS0tLR1uV1VVJfv7+8u33XZb+7asrCwZkH19feXq6ur27Y8//rgMyIMGDepwv9ddd52s0+nk5ubm9m3h4eEyIP/222/t22pqauTAwEB5yJAh7dtWrVolA/KqVatkWZZli8Ui9+3bV546dapssVja92tsbJQjIyPlKVOmnPA5Ky4uln19fWVAjo2Nle+++275+++/7/A4TrWd6dOny3q9Xs7JyWnftn//flmtVsuHf0Voe96+/PLLo/oFyM8880z75dtvv10ODAyUy8vLO+x37bXXyu7u7nJjY2OH5ycuLq7D/+7dd9+VAXnv3r2yLCv/28jISDk8PFyuqqo66rG2mTRpkjxgwIAO/yuLxSKPGTNG7tu371H9PtbjuP322+WysjK5tLRU3rJlizxp0iQZkN98880OfY6Kimp/HG3tdPVzLsvKsXbzzTe3X3766adlQJ4/f/5R/W9rd9u2bcf9X918881yeHh4++WFCxfKgPziiy922O+qq66SJUmS09PTOzw/Op2uw7akpCQZkN9///2j2hIEQRBsQ4TcC4IgCHbz4Ycfsnz58g4/f/31V4d9+vfvz3PPPcfnn3/O1KlTKS8v5+uvv0ajORRkplar29eQWywWKisrMZlMDB8+nJ07dx7V7owZM3B3d2+/nJiYCMCNN97Y4X4TExMxGAwUFBR0uH1QUFCHGV83NzdmzpzJrl27KC4uPuZj3b17N2lpaVx//fVUVFRQXl5OeXk5DQ0NTJo0ibVr154wYZm/vz9JSUncfffdVFVVMXfuXK6//nr8/Px44YUX2kOkO9uO2Wxm2bJlTJ8+nbCwsPZ24uLimDp16nH7cSKyLPPbb79xySWXIMtye9vl5eVMnTqVmpqao/4ft956a4f1/22zypmZmQDs2rWLrKwsHnjgATw8PDrctm1ZQGVlJStXr
uTqq6+mrq6uvc2KigqmTp1KWlraUf/DY/niiy/w9fXFz8+PxMRENmzYwOzZs4+KKrj55ptxdHRsv2yr5/y3335j0KBBx4w2OLLkXWcsWbIEtVrNfffd12H7Qw89hCzLR70WJ0+eTHR0dPvlgQMH4ubm1v6/EgRBEGxPhNwLgiAIdjNy5MhOJcWbM2cOP/74I1u3buWll14iPj7+qH2+/vpr3nzzTQ4cOIDRaGzffqws+ocPpoD2wX1oaOgxt1dVVXXY3qdPn6MGUDExMYCy3jwgIOCoNtPS0gBOGC5fU1ODp6fnca8PDAzk448/5qOPPiItLY1ly5bx6quv8vTTTxMYGMisWbM63U5LSwtNTU307dv3qOv79evHkiVLjnv74ykrK6O6uppPP/2UTz/99Jj7lJaWdrh85P+i7fG3PecZGRmAcmLneNLT05FlmaeeeoqnnnrquO0GBwefsP+XXXYZ//73v5EkCVdXVxISEnB2dj5qvyOPKVs95xkZGVx55ZUn3OdU5OTkEBQUhKura4ftcXFx7dcf7sj/FSj/ryNfH4IgCILtiAG9IAiC0O1lZma2D5r27t171PX/93//xy233ML06dOZM2cOfn5+qNVqXn755fYB4eHUavUx2znedvmIBGGno232/fXXXz9qTX+bzpadkySJmJgYYmJiuOiii+jbty/fffcds2bN6nQ7LS0tne778WZ/zWZzh8ttbd94443HHdwOHDiww+WueM7b2n344YePO9PdmVKIISEhTJ48+aT7HT47f3j7Xfmcd0fWfH0IgiAIp0cM6AVBEIRuzWKxcMstt+Dm5sYDDzzASy+9xFVXXcUVV1zRvs+vv/5KVFQU8+fP7zD4fOaZZ6zSp7YZ4cPbSk1NBTgqi3ibtlBlNze3Tg0aOysqKgpPT0+KiopOqR1fX18cHR3bT5Qc7uDBgx0ut82aV1dXd9h+5Ayur68vrq6umM3mLnuMbY8nOTn5uPcZFRUFKGXZuvK57SxrPOfHayc5OfmE+5xK6H14eDj//PMPdXV1HWbpDxw40H69IAiC0L2JNfSCIAhCt/bWW2+xceNGPv30U1544QXGjBnDPffc0yE7ftvM4eEzhVu2bGHTpk1W6VNhYSELFixov1xbW8s333zD4MGDjxluDzBs2DCio6N54403qK+vP+r6srKyE7a5ZcsWGhoajtq+detWKioq6Nev3ym1o1armTp1KgsXLiQ3N7f9+pSUFJYtW9bhNm5ubvj4+LB27doO2z/66KMOl9VqNVdeeSW//fbbMQeeJ3uMxzJ06FAiIyN55513jjqh0Pb/9vPzY8KECXzyySftJzbOtN1TYY3n/FiuvPJKkpKSOhx7bdqei7YlAkc+V8dy4YUXYjab+eCDDzpsf/vtt5EkiWnTpp30PgRBEAT7EjP0giAIgt389ddf7bOBhxszZgxRUVGkpKTw1FNPccstt3DJJZcA8NVXXzF48GD+9a9/8fPPPwNw8cUXM3/+fC6//HIuuugisrKymDt3LvHx8cccYJ2pmJgYbr/9drZt24a/vz/z5s2jpKSEL7/88ri3UalUfP7550ybNo2EhARuvfVWgoODKSgoYNWqVbi5ufHnn38e9/bffvst3333HZdffjnDhg1Dp9ORkpLCvHnz0Ov1PPHEE6fcznPPPcfSpUsZP348//rXvzCZTLz//vskJCSwZ8+eDu3PmjWLV155hVmzZjF8+HDWrl3bHpVwuFdeeYVVq1aRmJjIHXfcQXx8PJWVlezcuZN//vmHysrKU3quVSoVH3/8MZdccgmDBw/m1ltvJTAwkAMHDrBv3772gfCHH37IuHHjGDBgAHfccQdRUVGUlJSwadMm8vPzSUpKOqV2T7WP1njOjzRnzhx+/fVXZsyYwW233cawYcOorKzkjz/+YO7cuQwaNIjo6Gg8PDyYO3curq6uODs7k5iYeMxcEpdccgnnnXce//3vf8nOzmbQoEH8/fff/P777zzww
AMdEuAJgiAI3ZR9kusLgiAIZ7MTla2jteSWyWSSR4wYIYeEhBxVmq2ttNlPP/0ky7JSsuull16Sw8PDZQcHB3nIkCHyokWLjirT1VZ+7fXXX+9wf23lyH755Zdj9nPbtm3t28LDw+WLLrpIXrZsmTxw4EDZwcFBjo2NPeq2R5ata7Nr1y75iiuukL29vWUHBwc5PDxcvvrqq+UVK1ac8Dnbs2ePPGfOHHno0KGyl5eXrNFo5MDAQHnGjBnyzp07j9q/s+2sWbNGHjZsmKzT6eSoqCh57ty58jPPPHNUCbXGxkb59ttvl93d3WVXV1f56quvlktLS48qWyfLslxSUiLfe++9cmhoqKzVauWAgAB50qRJ8qeffnrS5/x4JfLWr18vT5kyRXZ1dZWdnZ3lgQMHHlUuLSMjQ545c6YcEBAga7VaOTg4WL744ovlX3/99YTPrSwrZdnuvffeE+5zvD636ern/MiydbIsyxUVFfK///1vOTg4WNbpdHJISIh88803dygT+Pvvv8vx8fGyRqPp8Fwe+XqQZVmuq6uTH3zwQTkoKEjWarVy37595ddff71D+b0TPT/H6qMgCIJgO5Isi0wmgiAIgtBZERER9O/fn0WLFtm7K1bz7LPP8txzz4lkZ4IgCILQzYk19IIgCIIgCIIgCILQA4kBvSAIgiAIgiAIgiD0QGJALwiCIAiCIAiCIAg9kFhDLwiCIAiCIAiCIAg9kJihFwRBEARBEARBEIQeSAzoBUEQBEEQBEEQBKEH0ti7A92dxWKhsLAQV1dXJEmyd3cEQRAEQRAEQRCEXk6WZerq6ggKCkKlOv48vBjQn0RhYSGhoaH27oYgCIIgCIIgCIJwlsnLyyMkJOS414sB/Um4uroCyhPp5uZm594cn8VioaysDF9f3xOewRHOHuKYEI4kjgnhWMRxIRxJHBPCkcQxIRxJHBPWV1tbS2hoaPt49HjEgP4k2sLs3dzcuv2Avrm5GTc3N/GiEgBxTAhHE8eEcCziuBCOJI4J4UjimBCOJI4J2znZsm/x7AuCIAiCIAiCIAhCDyQG9IIgCIIgCIIgCILQA4kBvSAIgiAIgiAIgiD0QGJALwiCIAiCIAiCIAg9kBjQC4IgCIIgCIIgCEIPJAb0giAIwllFlmW+3ZTN5+sysVhke3dHEARB6AUKq5s4WFxn724IZyFRtk4QBEE4q6w+WMbP2/MBiPJ1ZmKsv517JAiCIPRkZovMEwv2Utlg4I0Zg4jxP3HdcEHoSmKGXhAEQThrVDca+HRtZvvlLzdk02Qw27FHgiAIQk+3v7CWinoDsgzfb8m1d3eEs4wY0AuCIAhnjblrMqlvMRHp40ygu57qRiM/b8+zd7cEQRCEHmxzZkX73ztyqkgpqrVjb4SzjRjQC4IgCGeFjRnlbEgvRyXB/ZP7cvu4SAAW7i6gqKbJzr0TBEEQeiJZltnUOqAP8tAD8N2WnE7dtqrBwJ9JhTQbRaSYcPrEgF4QBEHo9eqajXy8OgOAK4eFEO3rwshIL4aEeWAyy8xbn2XnHgqCIAg9UWZ5A2V1Leg0Kp68KB61SiIpr4bkgpoT3s5gsvDMH/v4dG0mC3YVdE1nmqohdRmYTV1zf0KPIAb0giAIQq83b3021Y1Ggj0cuXZEGACSJDFrXBQqCTZnVrI7r9q+nRQEQRB6nE0Zyuz8sHBPQr2cmBKvJFr97iRr6b/ZlE1WeQMA27Iru6YzG9+DVS/Bnp+65v6EHkEM6AVBEIRebXdeNf+klCBJcN+kvug0hz76wryduGhgIACfrc3ELMrYCYIgCKegbf38qCgvAK4eHopGLZFcUMOe/Opj3mZHTiW/7y5sv5xeWk9Nk/HMOmIyQM4m5e+DS0AWn2dnCzGgFwRBEHqtJoOZD1amAXDRgEDig9yO2ue6kWG46jXkVjbyV3KRrbsoCIIg9FBFNU3kVDSikmBEhDKg93V1YGpCAAD/tzkH+YiBdXWjgXf+UT6XLh4YS
ISPM7IMu3KrOtVmeX0L936/k5+3HZHQtSgJjI3K3zX5UHbgDB6Z0JOIAb0gCILQa/3f5hxKalvwc3Vg5uiIY+7jqtdy06hwAL7bnEtt8xnOkgiCIAhnhbbZ+QEh7rjqte3bZwwLQauWSCmqY9dhy7ksFpm3l6dS3WgkwseZW8dGMjTMA4CdOZ0b0P+9r4TcikZ+3ZGPwWQ5dEXOho47pi47rcck9DxiQC8IgiD0SilFtfy5RwlpvHdiHxx16uPuOzUhgAgfZ+pbTKKGsCAIgtApbevnR0V5d9ju7eLAhQOU5Vzfbc5tn6X/I6mQnbnV6DQq5pzfD51GxbBwTwB25VVj6cSyrw3p5QA0Gc0kF7Ym3pNlyNmo/J1wufI7YwWYxQnqs4EY0AuCIAi9Tnl9C68uPYAsw6Q4P4aGeZ5wf5VK4o7xShm7v/YWkd2aqEgQBEEQjqW60cCB4joAEiO9j7r+yqEh6DQqUkvq2JFTRUZZPV9tzAbgjvGRhHk7ARAX6IajVk11o5HMk3z25FY0klvZ2H55S2ZrMr3KTKgvAY0DjLwTnLyguRbytnTBIxW6OzGgFwRBEHqVJoOZFxbtp6LeQKiXI7PGR3XqdgNDPBgT7Y1F7nwNYUEQBOHstDmzElmGvn4u+Lo6HHW9p7OOi1pn6b/dnMPrSw9itsiMjvZuX2MPoFWrGBjiDpw87H596+y8m6MGgC1ZFcrsf9vsfPBw0DlBnynKZRF2f1YQA3pBEASh17BYZF5fdpDMsgbcHbU8c0kCLg6aTt/+2pFKSbsdOVU0GczW6qYgCILQwx3Kbn/07HybK4eGoNeqyCxroKC6CW8XHf+e2AdJkjrs1xZ2v+MkA/q2cPubRoXjoFFRUW8go6zh0Pr58DHK777nK79zNykz9UKvJgb0giAIQq/x+fpMtmVXolVLPHlxHP5u+lO6fYS3EwHueoxmmZ2dzDgsCIIgnF0aDSaSWkvSnWhA7+6k5ZJBQQBIEjw0pR9uhyXPazO0dUB/oLiW+hbTMe8rr1IJt1erJMb19W2/za6DGVCaouwUNlr57dMHvKOVNfSZq0/jEQo9iRjQC4IgCL3CH0mF/JmklJ176Px+xAYcXaLuZCRJYnTrl7O2ZEeCIAiCcLjt2VWYzDJBHnpCvRxPuO8VQ0MY28eHf02IZkBraP2R/N30hHg6YpFhz2FZ8Q/XFm4/ONQDFwcNiZFKmbzqlDXKDn5x4HzYyYW2Wfo0EXbf24kBvSAIgtDjbcms4PN1mQDcMiaCsX18Tvu+RkcrX4i2ZVdiNFtOsrcgCIJwtmkLtx8d5X1U+PyRXBw0PDYtlgv6B55wv7aw++3HCbtvG9CPa/18Gx7hhUoC38odGC3yoXD7Nn0mg6SC4mSoKTjpYxJ6LjGgFwRBELq9tJI6liYXsTmzgtSSOsrrWzC1DrbTS+t4fdlBZBmmJvhzxdDgM2qrn78rHk5aGg1m9hbUdEX3BUEQhF7CYLKwPVsZdI+KPn64/aka0lqNZWduVXuZuzZ5lY3kVijh9olRysy8u6OW/v56YkwHaWgxQdgRA3pnHwgepvyd9neX9VPofjqfKUgQBEEQ7KCm0cjj8/fSYuo4Wy5J4KbXYjBZaDFZGBzqwd3nRp90tuRkVCqJUVHeLE0uZlNGxUlL3gmCIAhnj70F1TQZzXg664jxc+2y+x0Q7I6uNdFdbmUj4d7O7ddtOCzc3vWwNfjnexSglY0UW/zw8I4++k5jpkL+NkhbDsNuUT44hV6nR83Qr127lksuuYSgoCAkSWLhwoUnvc3q1asZOnQoDg4O9OnTh6+++srq/RQEQRC6zuK9RbSYLHg66+jr54K3iw6VBLIMNU1GmoxmwryceGxaLBp113ysjWqdAdmSVYnFIp9wX1mW+Xrf13y972uMZmPHKw2NsPMbyFgJppYu6ZsgCIJgP5tba78nRnqhUnXdAFmnUTEgWFljf2S2+7Zw+yOXkw2W9yt9Mvej/liVWSLGg
dYRagugZF+X9VXoXnrUDH1DQwODBg3itttu44orrjjp/llZWVx00UXcfffdfPfdd6xYsYJZs2YRGBjI1KlTbdBjQRAE4Uw0G80s3lsIwKxxkZwT4wso5elqm41UNhioaTLS198V51MoT3cyA4I9cNSqqWowcLCkjrjA4yfYkyQJB7UDC9IXkFWTxSMjHsFJ6wQWC6x88VA5IZ0zRJ0HMeeD/wBQ9ahz6oIgCGc9i0U+tH6+C8Pt2wwN92RHThU7cqq4YmgIAPlVjeRUNLZGj3kd3hncS7ZRrVGRrO5PZE4V57Z+RrbTOkLkuZC6VPkJ6H90o2YTlCSDVxToTz2ZrGB/PWpAP23aNKZNm9bp/efOnUtkZCRvvvkmAHFxcaxfv5633377uAP6lpYWWloOzaLU1iq1Gy0WCxZL902OZLFYkGW5W/dRsC1xTAhH6onHxD/7i6lpMuLvqmd0lFeHvrvpNbjpD32MdeXj0qhgWLgH69LL2ZRRTj9/lw7XJ5Ul4apzJco9CoCBPgNZkLaAlIoUntn4DI+OeBSvfX8g5WwAtRYcPaG+FA4sUn5cA5D7TFGyELuHdFm/T0dPPC6EzqluNLAxo4Lz+vnhqFN3+nbimBCOJI4JRUpRLVWNBpx1GhICXbv8+RgS6o6MzL7CGhqajTjq1KxLK0NGZnCIB8469aE2yw4gNVbg4OhMuiqKzRnljO9zjJMMfSYjpS6FjFXIo/8Nap2y3dQMB/9C2vMT1JdAyHDkaa93uq/imLC+zj63PWpAf6o2bdrE5MmTO2ybOnUqDzzwwHFv8/LLL/Pcc88dtb2srIzm5uau7mKXsVgs1NTUIMsyKjHrIyCOCeFoPe2YsMgyv2zNxmQ0MSHSmYryMpu2H+ulZpXRxJqUIqZFOyJJEuXN5fyW8xtJlUmEu4Qzp/8cVJIKL7yYEz+HD1I+IKMygyeW3sN/i/MIklXUD76blrBz0ZTvxyFnNQ4Fm5Cq8mDbPNj2Jc3h59HY/wZkvX3W6ve040LovPfX5bMjr441KQU8eG4oqk6unxXHhHAkcUxAi8nCF2vzMRlNxAU7UVVR3uVtaGQZTweJsnoja5OzGRLiysrkAkxGEwk+GkpLS9v3ddy3DCeTEaPfcJqKZVZnbaFhzY/kNGRzaeiljPUf23qnwXhq3FA1VlC35y+MvgPQZy5Fn7YYleGwxK/Zm6jK2ovF2b9TfRXHhPXV1dV1ar9ePaAvLi7G37/jQenv709tbS1NTU04Oh5dN/Lxxx9n9uzZ7Zdra2sJDQ3F19cXN7fuG4ZisViQJAlfX1/xohIAcUwIR+tpx8TGjAoqmi14uui5IrHvKc0wdoXJHt58tb2MimYLLVpXMhu2Mm//PIxmIzqtjkGBg/Dy8ULXOtvhhx+v+L/CyxueprhoO8/rZR4Jv4S+I65R7tA/ABImKrMi2RuQ0pZB/jZcCtbhUrodechM6H/FodkTG+lpx4XQOeX1LewtSUej1ZBS1sKaPAPXDA/t1G3FMSEc6Ww/JhoNJt5elEJqRQvOeh3Xje6DXxcmxDvc6L71LEkuJqsO+utcKWow46DTMnVIZIeEeNL6vTRoNOwODaCh5Wua5Ap2VTripFPzc+7PVElV3BR/E1qVFhIuQkr6Ac+DP0LSJ2BsUu7EMxR54LVImaugKAmfqp0QeXOn+nm2HxO2oNfrO7Vfrx7Qnw4HBwccHByO2q5Sqbr9wSpJUo/op2A74pgQjtRTjglZllm4uxAJiQsHBOJ82JcYW3HWqxgU6sGOnCoW7t/K1rrPMctm+vv055b+txDqevTgKEDnyvNVtbxqkcjQOfCTo4qnJKlj5n2dE8RMUX5K9sHG96E0BWnrJ0o4/uh7lXrCNsxG3FOOC6Hz/kkpQ5bBw1FHTZORH7bmERfozuBQj07dXhwTwpHO1mOipsnIc3/sJ620HiedhmcvSaDfCfKqnKlh4V78lVzCztxqfFz0SEgMCfPE3emw8UldCVRksEdl5
v+qdqJ1MNHc5ECQdjQT+4awIH0BmTWZqKTW/1fMVEj6AeqKldt7R8Pg6yHqPCSVWsnxUpSEdIrZ8M/WY8JWOvu89uoBfUBAACUlJR22lZSU4ObmdszZeUEQBKF7SCmq42BxHRq1xCWDguzWj9HR3mzJzWZB9rcE6CsZ29zCf9wTkYzGo3e2WGDVS7jXFPG0UzjfxU9kRv+ZJy6j558Al30E6cthyydKJuJlT0DIcBhzH3iGW+/BCb2W2SLz937li/sd50SRlFfN8v0lvL7sAO9eOwQfl6MnLgRBOFpFfQtP/76P3MpG3Bw1PH9Zf6J9XU5+wzMwMMQDjVqipLaFP/coSWHbstsXNxRT2VxJfNEBAEb4DmKAXxBeAfEs3uKBocKFq/uNIMYrhjDXMLTq1pPhXpEw6FqoyoGEyyF0JGbZQmFDoXJyOnI8rHdqzYafDAEDrPoYha7Vqwf0o0ePZsmSJR22LV++nNGjR9upR4IgCEJnzN+ZD8DEfn54ONk2BP1wiZFe1G9YjcFYRbipkruMTkj7f4f9vytfeOKnQ+Q5oNHBrm8gez2odeinvsTtfrEd7utA5QH6efY7eoDfNnsSMR52fwd7foL87bDoAbjuJ+W+BeEUbM2qpKLegLujltFR3oyK8iK9tJ6s8gZe/esAL18xoMtKPApCb1VS28x/FyRTUtuMt4uOFy7rT6iXk9XbddSpSQhyIymvhupGIyqVRHyIhi+Tv2R5znI8HDx4p8URHaCNGM+Tg6+j2WhmxY4tlNcbyCxvYIjfkA73+Vvqb/iGJDA+8W6ya7NZu/8bNhZuxGA28OmUT9FqHZEjz6E0dQn+qUvFgL6H6VHv5vX19ezevZvdu3cDSlm63bt3k5ubCyjr32fOnNm+/913301mZiaPPPIIBw4c4KOPPuLnn3/mwQcftEf3BUEQhE7Ir2pka7ZS53f6kGDrNCLLYDlGzd4jeDjpGOV1NaOaZO6s0+DgP0AZwEsqKN4LK1+A72fA6ldh+5fKjcY9CEcM5lfkrOCZjc/w5b4vscjHyVqrc4KRd8DV3yhZ8RsroWTvmT5S4Sz0V3IRAJPj/NBpVDho1Dx+YSxOOjUHiuv4amO2fTsoCN1cbkUjj/y6h5LaZvzd9Lx65UCbDObbDA1TkqRaMODhu5MnNsxmadZSzKZmQtFQX7RL2TFcmaTUa9UMaV1OsyWzssN97a/Yz8+pP/Ph7g+5b+V9PLbuMZZkLaG6pRq1pKagvgCD2cAH6noe0dVTmLEcTC0IPUePGtBv376dIUOGMGSIctZp9uzZDBkyhKeffhqAoqKi9sE9QGRkJIsXL2b58uUMGjSIN998k88//1zUoBfsptFg4vN1mewrrDn5zoLQC8myUsM3s6z+uPv8vrsQWYaRkV7W+QJlscD8O+Hnmcqg+SQu9izj31UVODQD586B81+A639W1hk6+0JTNRxsjQZLuBxiLzzqPgwWAwDLspfx7s53MZqPEbLfxi0IQkYof+fvOLXHJpz1imqa2JVbjSTBBf0D27cHujvy4JQYQHmNrU/r+gzdgtDTybLM2tQyHpu/h8oGA2FeTrx65QD83TqXnKyrDAv3xIW14PACtc0/0FyWQmRZBk+WlPB4RhJeZrNS8tTj0LKsxCilZN2WrIoO9xXrFcvlfS4HoLSpFK1Ky+jA0cwZPoePp3xMhHsEGpWGSo2WZrWGd+QKDJmrbfZYhTPXo0LuJ0yYgCzLx73+q6++OuZtdu3aZcVeCULn/bA1j993F7Ilq5JPbhyGSmW7pFeC0B1szKjglb+UtX/nxPgwc3REhy9K1Y0GVqQouU8ut9bsfEkylKcqf//zLFz0Fqg7fhzuKdtDTm0OF4dfQGLx9+QBf6vGcpU+GHcAF18YfisMnQk5G+DgX+DgqiS0O4ZpkdNw17nz4e4P2Vy0mVpDLXOGz8FJe5wTFiEjIO1vyN8GiXd21SMXzgJLk5W180PDPAlw7zgIGRXlzRVDg5m/s
4D3VqQR4eNEiKftZh0FoTsrrW3mo9UZ7MipAqCvnwvPXpaAmx2Ssqoq12DQ/4QOmWCThuvMDoy26FGpHVoH8qGQcEWH5HUjIjyRJMgsa6CsrgVfVyVXhkpScW3stSR4J1DVUsVw/+FHffaoJBX/GXofjxZuJaeukG/2fMqsmBNMgFossOEdnGur4YKn6WFzxL1OjxrQC0JPVtVgYMleJQyyuKaZvQU1DOpktmFB6C2W7Stu/3ttajkb0iu4eGAgV48IxU2v5c89RRjNMjH+riQEWSmLcNbaQ38XJcG2z2DUPe2bShpKeHfnu9Qb63Eq3M2k+jwsDm4sc5hKeFYlk+MPK4eqUish+JHnnLTZMcFjcNW58sb2N9hfsZ9nNz3L4yMfx/NY9edDhiu/K9KgqUoJwReEkzCYLCzfr5wQu6B/wDH3mTk6gtSSOpILanl5yQHevHoQeq1tS0IKQnditsgs2lPIt5tyaDFZ0KglZgwL5aphIeg0dhiolqcRsO49rpP05Lj4cl/8TWg9o5RBvLOfknflGDycdMQGuJJSVMf27EqmDQjscP0A3xOvi/fSe/HvoQ/y8po5LK/LJD5rGWMijzOoT/4Vaf/v6E1GSB8L/UT0sz2J0ymCYCO/7czHYDq0dvbwgY0gnA1K65rZnVcNwOMXxjIkzAOzReb33YXM+no7P2/LY8ke5aTXFUODT5wd/nTJspK4DqDfNOV30o+QuQaAmpYa3tj+BvXGevq4hHJO+iYASvrdRJPkxKbMimPda6cN8B3As2OexV3nTk5tDu/sfOfYkWdOXuAVpfS3YOcZtSmcPTZklFPXbMLbRceICK9j7qNWSTwyNRYPJy25lY18sibTup2SZchYBfWl1m1HEE5DRlk9D/+SxOfrsmgxWUgIcuO9a4dwfWKYfQbztYWwZA5aYxNX+Y9m9pW/oR14DYSOANeA4w7m2wwM8QDgQHHdaTU/KGoK050iAZlPdrxLccMxvqtWZMDWT9svSju/ArPptNoTuoYY0AuCDVTUt7TPzt88JgKATZkV1DSdYB2tIPQyK1NKkWUYEOLOmGgfnr+sP89flkCkjzNNBjPfbs6hvsVEgLue0a1rAbtcRQbUFYHGAcY+AIOuA8Cy+mWW7fuOB1Y9QG5dLu46d2bL7mgNDeDdh5DRVwGwK7eKuuYze91GukfywtgXCHAKIMYzpv3EhdFsZFHmIkobWwc+7evot59Re8LZY+le5cv3BQkBqE+wpMvTWcecqf1QSfBPSgkrD5Qcd98zlrZcWdqyZI4SpisI3cQfSYXM/mk36aX1ODuo+ffEPrx0+QCbJr/roLGSxsUPYWqqBO8+cP6LSNpTKzHZ108pqZdeevw8NSczo//NxFo0NDdV8N7OdzsmcjUZYOWLYDZC6EgsDh7KSYiDS457f4L1iZB7QbCBX3fkYzTLxAW6cuXQYNallZFZ1sDqg6VcNthK64QFoRuxWGT+aV0bPyXuUMj6kDBPBoV4sCa1jP/bnENpXQvXDA+1Xn6JLGUmnpARoNXDyDvIK9rBR1U7ydzxHniGEekRzb/CLsJ72VPKvmPvI9LXFX83B0pqW7j+sy14OGnxd9Pj7+ZAgJsePzc9A4LdCfJw7FQ3/J39eWn8SxjMhvZtyRXJfLv/W77d/y1R7lFc5NqHcaCso5flDmslBeFI2eUN7C+qRSXBlMOXhRzHwBAPrh0ZxvdbcvloVQZ9/VytM5DZ/7vyuyobMldBn0ld34YgnCKj2cLXG7OxyDCurw93jo/C09mOJUINjbD0ceY1ppPrqOPusfcS5XDq9e5j/F0ByKtqpMlgxlF36stp1NETuW/Te7xkrOG6wPGopMPmf7d9BpWZ4OiJfO5jNO1aiG7fN7DzG6X8qubUTkAIXUPM0AuClZXXt7C0Nbz+hsRwJEliaoKytvHvfSUnTPQoCL1FcmENJbUtOOrUjI7uOPuuUkmcF+vHxzcOY+5NwzquUe9q2euU35HntjauR
jv+IfLUKpzMBm7Dg5fG/o+wpF+UQXT0RAgchCRJXDMiDGcH5ctRdaORg8V1rE0t5+ft+XywMp0Hf9pNi+nkpfDaOGudO6yf16l0DPAZgAoVmTWZvJ+3jDS1BA1lUJ17gnsSBNo/Z0ZFeePt0rkv1dcMD2VQqDstJguvLD1As7Hzx2+nVGYqSSjb7PxGzNIL3UJaST0GkwV3Ry2PTO1n38G82QT/PMPe8mTWaSzkuvtj0Z9eDhlPZx0+Ljpk+Qxm6fVueIeP53WjMwNKD1uSk78D9vys/H3uI+DoSXPkFKXaS0PZoZN3gs2JAb0gWNkv2/MxmWX6B7sxMMQdgHNjfHHQqMitbDztdU6C0JP805qo69wY3+Mm4NJpVAR3cob7tFTnQWUWqNRkeRxKFhTgE8v9Ix7mHZM7UwtTUS1/WkmWp3GAxLvb95sS788Pd4zi+zsSefuaQTw2LZZbxkQwbUAADhoVjQYz+VVNp929BJ8Enhz1JHOnzCUxMBEkiS+ddViQlVl6QTiOJoOZlSnKUo0jE2GdiEol8fD5/ZT19BWNfLq2i9fTp/yp/A4ZrlSBaJulFwQ7S24tH5wQ5GadfC2dZbHAmlcx5m3hC60B3EI4P/pi+nj2Oe27bJulTy05g++XMVNRIUH6P2A2UV6dTeOq/ynXxV0C4WOUv9U65GG3KH/v+hYMDaffpnDaxIBeEKyotK65Pfld2+w8gLODhnF9fQCRHE/o/RpaTKxPV2pedyYU2Gpak+Hl+sXw320vk1qV2n7ViPhrcE+8t8N+DLoOXDv2V5IkXPVa+vi5MraPD1cOC+FfE/rQ118JjcypOPMvM+4O7tyWcBt6tZ4MtcwalREKRD164fjWpJbRZDQT5KFnYLD7Kd3Ww0nHw+f3Q5Jg+f4SVh3oouR1xmZl/TzAwGthwAzl751fi1l6we72FSgD+v6n+Hrpcnt/gbS/Wag2UuTuj6dLINf2u/aM7rJP6zr6tDNYR0/ISKW6SlMVG/d8xUN/3cp3LflKybwjy7P2napsb65VHo9gc2JALwinaX1aObN/3s3qg6XHDZv/ZXs+ZovMgBD3oz402sLu16WV09AisoMKvde6tDKMZpkwL6f2hD12kb0OGZlvtEbMspmVuSs7Xj/gKiXEHsDFvz1hXmeEezsDkFPR2CVd9dB7cFXMVaBzZrHagFy4U0lCJAhHkGWZv5KVpKsX9A84rfwTg0I9uHZEGAAfrU4nv6oLjuOsNdBSB66BEDwM+l/ZOkufA5krT357QbASs0UmpUiZvbbrgN5sgj0/UYiZhR4eoHNmZsLMo2rEn6q2Gfr00jOYoVdroM9kADz2/U5zcxX/qI3sG3YdaI+IpFOpYfhtyt9JP0Fzzem3K5wWMaAXhNNQ22zk/ZVppJXU8+bfqby4OIWK+pYO+5TUNvN3a5jx9SPDjrqP2ABXwrycMJgsrEkts0m/BcEe2l4H5yf42y+0sb4MSvaRJJnYa6pGo9Iwvc/0jvtIkrIucMQsmPqSkjSvk8Jbk4l11YAe4ILIC7g2/iaeVwchGZuhZF+X3bfQe6QU1ZFZ1oBWLTEp7vQjYK4dEcrAEHeajRZeW5raoczq6XXsD+V37EVKqS0HFxh4tbJth5ilF+wno6yeJqMZFwdN+3u3XWSvRW4s5wu9jMnBlcG+gxkdOPqM77Zthr6ktoWaxjM4ERxzAQDxDbVMMevAyZtP8v6mxdxy9L5R5ymZ+Y2NsPuH029TOC1iQC8Ip+HHrbk0Gsx4OutQqyS2ZlXyr+92snz/oSR3P2/Lw2KRGRR69Ow8KKG75ycoX75E2L3QW+VUNJBWUo9KJTEhxs+OHVmPCZlvnB1ApWFaxDQCnAOO3k/rCENvAp9TW78Y5t02oO+69YNalZbLY67Eqb18nVhHLxzth61KwsQJ/fxw02tP+34OX0+fXdnAj7vOIPS+MguKk0FSQb8LD21vm6WvzoUMMUsv2MfefGUGOT7IzXoVVTpj30IaA
bNHKFq1jtv639YlJ72dHTTt+WjSzmSW3qcPeEcDcL33ELy8+lDSWMIvB48RVq9SKSfDAZJ/g4by029XOGViQC8Ip6iguonFrbV+Z0+J4d1rB9PXz4VGg5n3VqTxzB/72Jtfwz+t6xBvSAw/7n2dF+uHRi2RWdZwRjVDBaG7Wt46Oz8q0gt3p9MfbJyxrHWsUBkp0Gpw1blyed/Lu/Tuw1pnecrrDV2/hCZkBDIye3NXi6oYQgf7C2vZnVeNSiVxzYjQM74/T2cds6fEALA6vYqC003y2JYML2IsOB9W1ULnDAOvUf7e+ZWYpRfsYl9hLQAD7BluX5kJRUk4S2qemfgeL417CX/nrssxE9Oa1yW15Ay/W465D/pMxmny89wx8E4AFmcuJqM64+h9w0aBf38wG5SKFoLNiAG90KvIssyqA6Vds/7vOL7ZmI3FIjMs3JPBoR6Eezvz+oxB3DImAq1aYlduNU8s2IvFIjM0zIO4wOOXHnHTaxkdpXzZEbP0Qm9jNFtYdVA5sWXVUnQn01xLQ+EOflE3g4MrM2Jm4Kx17tImXPVavF2Uske5lV37/iMHDeMlTSMv1u1je96aLr1voWf7fmsOAFPi/PB36/wSkRMZEubJyAgvLDJ8v/U0yiWaWiDtb+Xv2EuOvr7/Fa2z9HmQseLMOisIp8hikdlX2JYQ7/RKw3WJfQuV3xHjkFz9CHM7emnmmejbFZnuAYIGw6SnwNWfof5DGRs0FgsW5ibNxWg5IpxfkmDkHcrfBxZBbeGZtS10mhjQC73KxowK3lqeyuPz91LdaOjy+08uqGFjRgUqCW4bG9m+Xa2SuHJYCO9dN4S4QNf27defYHa+TVtyvDUHy7q+BrAg2NG2rEpqm0x4OusYGuZ58htYS+4mkiQjdVodwe4RTA6bbJVmDq2j79qyPZKrH9GOynKFb5I+wWDu+vc2oedJLqghKa8GtUri6uFnPjt/uOsTlcHFuvRyMstOcYYvsy0ZXgC0LRc53OGz9Du+ErP0gk1lVTTQaDDjqFUT5WOnJK2GBkhbTqZkpqHfVKs00VZ5Jb20vksju25OuBl3nTuD/QbDse42aLDyureYYdd3XdaucGJiQC/0Km3ldqobjby9PBWLpevexCwWmXnrswA4PyGgfc3s4UI8nXjlioE8PDGUxyaH0S/A9ah9jjQg2J0Adz1NRjPr0sSaI6H3WJ6ihNtPivVDbc91illrGWPR8nLUVdw18C7UKrVVmglrzXSf3YWJ8dpMD5uCl6yitL6IRZmLuvz+hZ7nuy3K7PmUeH/8umh2vk2UjzOJ4crs5f9tPsVZ+vZkeBeDSoXZYmZj4UbKmw77fGubpa/JV+pcC4KNJBd0g/XzqcuQjQ287Qh3Jn/M/or9Xd5ElI8LKpVETZORsrpjJLE7Te4O7rw78V1uiLsBrfrQMjqLbKHOUEdBfQEHIkezSzLSmLely9oVTkwM6IVeo77FxI7cKkCZMd+ZW83vSQVddv/r08tJK63HUas+Ztb6NiqLkXP3PsHYDbfDrv87aakplUri/NZw5L9tEXZvaoFfb4cFdyt/C4IVVNS3sDNHeT3aNdze2AR5WwGI6nc5/bz6Wa2pCO+uz3TfRh82mpvMejA2sCBtQcfBkdCjyLLMoj2F7G9dx3s69uRXk1xQg0bd9bPzbS4f4ItKktiWXUlKUSf7WpkFxXtbk+FNQ5Zlvtz3Je/ufJenNzyNse3zsMNa+q+V2TxBsIFke9efl2XYv5BUyUyp3hmNSkOUe1SXN6PTqIhs/Uw6o3r0x+CoOVS27rODn3HX8ru4YfENzPp7FrNXz+aZjB95RdvEBy3ZIjmejYgBvdBrbMqowNRa6/ruc5U3x6835pB2puuHAIPJwtcbswG4clgwns664++cvU6ZdTAbYOtnMP8OJdvvCUyO80elkjhQXNf+YWM1B/+CinQoTRFJSwSrWXmgFIsM8YFu7dl27SEndRHF5ialFnZrtl5rCT8s0
32XJ68LHMRo9MQZZQyGer5LEaGMPdWWrEo+WZPJc3/uO62lYbIs833r7Pz58QH4ujp0dRcBCHDTMSm2danHppzOHdMHWqNHwseAsw9Ls5eyPGc5AJf2ubTDjF57xvuafPj+Glj3lnLyzSSWlAjWoayfV05O2W39fOFOqMphvRbQuzEiYAR6TddG2LTpsnX0J9BkbqLWUIsFZemMs9YZZ50raHS0AKbiPVZrWzhEDOiFXmNtay33c2N8mZoQwJhob8wWmdeXHaTJcGZn//9MKqS0rgUvZx2XDQ4+8c6pS5XfISNA767MWPx+L6x7U1lXeAyezjrGRivJ8Z76PZlFewqtk83aYoY9Px26nPSD0j9B6EKyLPNPa7i9PWfnLbKFT/Z/w0O6ejb5RSoJe6woxNMJSYK6ZhM1TWdQ+/dYtI5IAQO4xaxHMjaysXAjaVVph663WJTQ5W2fiwFRN7f6oPJZ1Wgwt4fNn4qk/Br2FdaiVUvMGB5y+h3Z8B4s++8Jj5drR4SiUUvKev38k5xsNrVA6jLl77hL2V26m2/2KSeNb4i7gQsiLui4v84Jxj2olIpsKIP9v8OSOfDNZfDPs5C2HFpE9Reh6+RVNVLXbMJBo6KPr53Wz+9biBGZTU7OIKkYHzLeak319euiTPcnMCNiBq+Of5VPJn/Cdxd+x7yp85g3dR4fhV7CUyZnNKUHrNa2cIgY0Au9QnWjgT351QCMj/FBkiT+PbEPPi46imqa+XjNMcprdFJNk5Gft+cBcNOocPTaE6y/rS89VCt63INwzbeHavDu/wN+nqnU3j3GYP3e8/owOtobk1nmkzWZvLbsII2GLi5/lblayTqqd4fQRGWAv/YNkZRI6FJZ5Q0UVjejVUuM6+Njt35syFtLRkMhGiTiY46RbbuL6bVqAlrXMltjHT3Bw4mQ1ZyrciXMNQyL3Pq6LT0Af/wHVrwAO7+FA392fdtCl2gymNmaVdF+edm+4lNKOqfMziuZ7acmBODjcpqz8+VpSq3o7PWQdfzKCb6uDlzYPxBQKryc8ERzWzI8F3/y3AN5e8fbWLAwIXQCl0Qdev0V1Rexq3SXcqHPJJj5B0x7FeIuAScvMDZCxipY+SL8crMI2RW6zN7WCMjYQFc0ajsMgerLIHs9SZKJOgcnPBw86O/d32rNxbTO0GeU1ndpTqnDBToFEuYWhofeA41K077dO6g1IWbJPqu0K3QkBvRCr7A+vRyLrGT1DHRXwntd9VoentoPlaQky2tLmHeqftyaS6PBTKSPMxNbww+PK3WpMlgPHATuwcrAecKjcMm74BEKjZXwz3PwzzNg6PiF39lBw+PTYpk1PhKVSmJ9WjkP/rSbrPIGyN+hnBDI3wF1xac3AJdl2P09zch87BfIhz4+FGt1UJJ8KExSELrAxgxlwDIs3BNHnXUS0J1Mk6mJ75I+AdnMdJUn7iGjbNLu4WH3Xa41Y/gttfW8OvZ/9NP7wupXYeHdyuu4Tcaqrm9b6BKbMyswmmWCPPSM7+uDLMNn6zI7HZG1K6+alKI6tGqJq4adwez8/oXH/vsYZgwPwUGjIq20ni1ZlcfeqbaoPfqrpu8kXtv+Os3mZuK84pg1YBZSa3RMXm0ej69/nHd2vENenXKiHI1OqV99zsNww29w+VwYciO4+CmD+dUvi5POQpdILrBz/fmU30G2sN7dEzQOjAkaY7UkrQChXk44aFQ0Gc0UVDdZrZ1j8osHoKY8heLaPNu2fRbSnHwXQej+1hw8FG5/uIQgd64dGcb3W3L5eHUGsYGu7QP+umYjKUV1HCiuJaWolvJ6A7IsI8tgkWVkwCJDTesax9vGRZ44I6rFoqxPB4i9qON1QYPhynmw+zslUV7mGqjKgfNfVAb6rSRJ4rLBwcT4u/Lq0gNUVFWzbd5DeDvsxU2vwSJDi9lMvVlFjcqHAtmbHJMXNf2HMPOcabg7nOBDKn8bjRVpvKJr4aCxAkqrGdZvCgHJi2HLJxA+F
py9O/V8C8KJbGod0I+Jtt/s/IK0BVTVF+Enq7g4YiqobHP+Oszbmc2ZleRaY4beJwYcXHFsqYON7ynRPobWEwd9z1cyhy+4W0lKVl8GLr4nvj/B5ta0Lw3zY0q8P1uyKkkuqGVDegXj+p749XL42vkLBwTifbqz8y11kHZYZvniZKjIOG6OCQ8nHZcNDuLn7fl8uzmHkRFehz4LG8ph17eQsggsJtDoUfedis+BYpBg9vDZaFWH1s0HuQTRx6MPe8v38ub2N3lp3Es4aQ+rGKNSgV+c8tP3fCUHTf52JZpg4IzTe7yCgPL6aas/nxBkhwG92QgpizAgs0unDL+sGW4PSoLoaF8X9hfVklpSR6jX0dWZrMY9hJV6DZ9bKhi++2Nmn/OS7do+C4kZeqHHK61t5kBxHZLEMcN7rxkeSv9gN5qMZl5ecoD3VqRxz//t4PrPtvDCov38sj2f5IJaimuaKaltobSuhfJ6AxX1BqoaDFhkGB3tzeBQjxN3pHiPEs6udYLIc46+XqOD4bcqs/VO3lCVrXz5zt181K5xgW68P1HH/8xvM6RlG8W1Bv6pD+G7ejUvWhqYo6vGVJVNUPVOAlr+ZkXWa/xn2Wzy6/KP273anV/zgraBg45OODu4cWHkhSSOegh8Y8FQT+66VzBZujjEXzjr5FU2klvZiFolMSLSyy59KKovYnH679BSy80mPdq+1qnzeyxtme6tEnKvUkHIcOXvlD8xGOpZ4OHJilG3wMT/KoOggNbwzROEUQv2UdNkZFdeNQDnxPjg6+rQPss+b0MWLaYT53rZmVvFweI6dBoVVw49g9n51GVgagavSIg6V9m2//cOu+TW5lLaVNq+rGP6kGCcHdTkVjSyJq0Mmqph88fww3Wwb6EymA8eBpe+h4tnBE8kPsEzo5/BTdcx8ZhapeY/Q/6Dt96booYiPtr90aGlI61KG0uZnzafZ1O+5K2QPhiRYesnykkHQThNBdVNVDca0aql9lB0m8paA01V6Jx8eOf8z7lz4J1EukVavdm2evRdnen+pCSJaO8EzMD2kh3UtFg54fNZTszQCz1eW+32hCD3Y85YqFQSs6f0474fdpFV3qCEsLcK9nAkLtCNuEBXgj0dUaskVJKESlJmyyVAo1IR5NGJDKQHlyi/oycqSX6OJ6A/XPEZLH9aCZNd+hgMv10JMZQkZV37rm9R7fgSk7OJZCcPfrDEUqitwKxyRC054aSxsCrwYs5xdsE15Qe8zQUUlibz2Nr/8ljiHPr7dFyTVZW/hRcrNpMvWXBzD+O/o54mwj1CufKch2mcfwfPF/6D69JSbhx2H0P9hraHSArCqWibnR8c6oGLg30+Yr7Z/w2mpgoGmVUM84pTlsDYSLiXUos+r7IRi0Xu+jrHUROUkHpHDzb0HceP5dtxLd7AqPircdY6K+8/xcnKPgOu6tq2hTOyMb0ci0UmyteZEE/lxM/lQ4JZvr+EsroWFuws4NrjlEQtrmnm83VKAtNp/QNOXGnlRGQZ9i1Q/o6frkSIZa5REtAl3o2sdeS3tN/45eAvGE1G3A+6c3nfy7k0+lKuGBLCD5sOkrf8A+qdNlNnbKRakqn1Dqc66hzU3n2Y7KuUhdSqtPg4HjviwN3BnYeGP8TTG59mW8k2fk39lakRU9sjzPLr8vnp4KHkrZ4+AdxaXgIrnocrPgWNdbL6C71bWwWhfgFu6DR2mM9se93FXYKnsy+TnCfZpNkYG2S6P57w4JH0Kd1AurGRtflruSTa+rlszlZiQC/0eIdCGI8frujr6sDjF8aycFch4d5OxAa4Ehvohruj9ri3OSWGBiXhHEC/aSff39kbLnlHCZvd/wds+5zMoh04DbqWgB3fQnEym1UGPnXWgosXKksp/hYZL70fE8MncF7oeYS5KV/8DANG4fLVzXzh0sj+0mz+t/l/3DXoLs4JVqIEKpoqeHH9UxRLZjyd/Hhy3P8IcT1sdsenL/kxk5Ay51NYvp/Xt
r7CpX2mc0PcDV3z3AhnlQ0Zygm2MdH2Wb5hkS3EuEdzsOk3bjHpkQZeY/Xs9ocL8tCjVkk0Gc2U1bfg79bF5YiiJsCVX4BbIOdo9Py5dg4F9QUsTF+ovGYjz4WN7ysnC+tKwNV+VQaEjtamHb00TK9Vc+vYCF5bepBfduQzKc7/qDJ0G9LLeXdFGk0GM+6O2jNbO1+wUykTp3VSQtq1juAeAjX5WNL+5hu5mr+ylaVjGklDo7ERvVo5hi8ZFETjhkeZr09moUVC46hHcvYFBzMUroLCVRTUF3BT/E2opBMPmKI9orkt4TY+3fspv6X9hsli4vq46wFI8EkgMSCRIJcgFqQvYKnWwgBHR4ZXZcOWuTD2/tN//MJZq239vF3K1ZWnQXEyskqFFGfbQW3bDH1WeQNGswWtLZMB+icw0awj3djEytyVXBx1sZgsshIxoBd6tLzKRrLKG1CpJMacJJv2wBAPBoZ4WKcjGauUkj0eoeCf0LnbqLUw/iFk7xhWbHqVL0tXc9tfWwiw6EDrRMjQ2/Cr2E6QSxDBLsH09ezLcP/hHev4Ajr/GIImP8Ldf73ID+Yatjg48nHSxxTXF3Oux7m4NFbh2ViJBRVPjnsJf9ejvwzGjHmId/OTWNCcxx8NFfyZ8SfD/IcR6xXbFc+OcJYoqW0ms6wBlQSJkfYZ0KskFZdLrkxr1qN39oXo82zavkatIsTTkZyKRnIqGrt+QA/g0wcANXBj3I28uu1VlmQtYUr4FPyc/SBgIBQlKSGeA6/u+vaFU1Ze39Je/3p83465Dcb18WFxUBH7Cmv5emM2D09VZrkNJgtfrM9iyd4iAGIDXJlzQT88nE5zdh4OJcCLOV8pGwcQPx3TpveZm/Qx69w8ALgl4RYG6gdicDLg5agsnXEs2IivLh0kKMMLo9mDaK0/UR6+uDu44+7gTh+PPicdzLeZFD6J7Nps/s75m5LGkvbtDmoHZg+fDYDRYmRR5iI+dnfl1aYGfJLnQ+goCEs8/edAOOvIskxy6/r5/vZYP7/7ewAW+Ueye8/HXBJ9CYP9Btuk6QA3PS4OGupbTGSXN7TXprcJ31jGyA58Y2mmsDaX1KpU+nn1s137ZxGxhl7o0dpmPIaGeeCm76LZ9tPRlgyv30WnNBtoMBuYa8znMy9vTCoNtZIM/v3hqnn0G3wz7096n8cTH2dmwkxGB40+ajDfxnvodPSxF3B9jYpJpdU0NRtYkL6AHRU7cNi3gEeMTjwbcB7+wcOP3RGdE07jH+IGs54JDQ3IxkbmJs3FYBb1rIXO29g6O58Q7I67k51ej7IMe35GjwQJVygnzmzMqpnujzDEbwj9vftjspj48cCPysaoCcpvke2+21iXVoYsQ0KQ21Ez8JIkccc5UUiSEnG2v7CWguom5vya1D6Yv2pYCC9fMQA/1zM4QdRaMgtQwu3bxExlqwbWtZSgNrXwnyH/YWrEVNQqNeFu4UoovKERNrzLNLWe1yNuYqDTK/jWP0Rtzo041V7PrIT7uGPgHYwNHnvCLsiyTEltMxvSy/lmUzbZGSPxr5/NlRF3HXP/a2OvJdo9mnq1ig/8ApCRlaz3jcfJti8Ix1Bc20xFvQG1SqJfgI3Xz1fntUdwrnN0ILkimfIm25VilCSJGHuto9c64ujdh1FmLZiUWXrBOsSAXuixZFlmbWu4/TkxdszmXJWthLdKKiWEsZNKG0t5euPTrM5bjUrrxA0jH2b6lHfg0vfALfDU+iBJRFz2JCrPSKbXGZlUYGS43yiGO4Ujpf2NExLeQ2898X2Ej4aoCcw0OeBZlU9R4TZ+2fyaKBckdNrG9Lbs9rafna9oquC/6//LvgPzoSIdNHqlrrUdtK2jz620QmK8I0iSxI3xNyIhsaFwA+lV6cqAXpKgdL9STkywu7ZKLMf7rIr2dWFKnLI84u1/Unnwx91kljXg5qjh2UvjuXlMxJnXzT7wJ8gWJaeE12HJuPRujI6axmVmBx52iGBc8
Lijb7t9HjSUoXILpu+E//DONSO4bmQYKpXExowK/vXdTta1nmAHaGgxkVXewJbMChbtKeSL9Vk8uXAv13+2hVlfb+eVvw7wy/Z8dufVkFVm5NN1x054p1VpuW/offg5+nHxyIeQvKKgqQrWvq6cvBOETmgLt4/xd0GvtXEp1aQfQbaQEzSAnJZKNCoNowJtU0a1TZ/WWfm0EhsP6AH8+zPRogVjE9tLtmM0G23fh7NAjxvQf/jhh0RERKDX60lMTGTr1q3H3ferr75SEpsd9qPXWyH8UbCLjLJ6Cqub0WlUjLJTeC8AB5cqv8NGd7rsW1JZEo+ve5ysmixcda48kfgEl8ZdixQ+Ck63JqnWkbBr3kDSOnFhbRExe2pwTF+iZB8OGHAo+/WJjJ+Nc9QEZpmdwNBAzoGFWH66USkZZLD+4EToucrrWzhQrCTdGR1l+9fjdynfkV6dzo97vlBm8fpNA70d1kpy+Ay9bV4zke6R7eWPvj/wPTh5HUoEKLLd211BdRMZZcrSsLEnWBp20+hwHHVqimuaaTKa6R/sxnvXDmFYeBdUizAbIeVP5e+E6QDUtNTQZFJqU0sJ07nerGdowT5oPiIbdVmq8hkAMO5B0OrRqlVcnxjG21cPIsLHmbpmE68tPcid32znmk82ce2nm7nvh128uDiFT9ZksnBXAUl5NdS3mFCrJPr4uTA1wZ9Z4yPRqCWS8mrY3VoB4EgBzgG8c947DA8eBROfArUOcjZC0g9n/rwIZ4W2hHg2rz9fXwapynfE9X7hAAz1G4qLzsWm3YjxU9qzR2I8/BOIkdXcrfLlnfPeOW6kqXBmetQa+p9++onZs2czd+5cEhMTeeedd5g6dSoHDx7Ez8/vmLdxc3Pj4MGD7ZdFMobeY22qErI0MtILR52Nz7i2MZva36zpd0GnblLRVMFrW1/DJJuIdo9m9vDZx80GfKocfCNxm/o4tYufJr74T5pr9bi5aGBwJxPc6d1hyvMMr72Hp3Z8QkLONqTaAtjwHmz7AvpMAq8ocA0EF39wDThxRn/hrLE5U5mdjw1wPf362Kdpf8V+NhRuQDIbubWmFknSQP8rbdqHw0X4tGa6r2rEZLac+cxqJ1zT7xqMZiPX9LtG2RB1HhTuVsLuB11r9faF42ubnR8S6nHCRKweTjruOieKL9ZnceGAQK4bGYa6q6okZK9TwtSdvCBiPNk12by54038nfx5dMSjaH1jwScGylOVJWQDWnMvyBZY94byO3oihI7scLdRvi68dfUgft6ex8/b8iiqaW6/zsVBg5+bA36uDvi56gnzdiLa14Vwb6cOiblKa1v4I6mQrzdmM+jqQcf8nqZuO9HtHU350BvQbvsC9y2fKKVix95vl6U1Qs/RVn8+3tbr5/f8BBYTloCBrK/LBDh2BIyVta2bz69qpMlgtu13Zv8EJCTOq6kAlahQYS09akD/1ltvcccdd3DrrUro8Ny5c1m8eDHz5s3jscceO+ZtJEkiICCg0220tLTQ0tLSfrm2VgnTsVgsWLpx6LHFYkGW5W7dx65kscisSS1DRmZctLf9HnfuJqSmKtB7IIckHjM83WQxkVyeTH+f/mhUGjwdPLk65mqKGoq4NeFWtGptl/bfZ9AFVGduR9r/OzWNTeh8YnEPGXlqofMu/sSf+zSysRE5bTnSvvlQnatk5D+S3h1cA5ADh8Dg68HBDvVdhU6x5vvEhrRyZGRGR3vZ9PVotpj5MvlLkGGi5EKkXIEcNhrZLdhuy0W8nbQ4aFQ0m8wUVDUS6uVk9Ta9HLy4b8h9gPJ/JmI80oZ3oewAcnU+uAUd97Zn2+eHLcmyzJrUUmRkxvc9+WfVef18mRDj0zqolbFYuiasXNq3UOlPv4tYnb+OefvmYTQbkWWZquYq5aRy3KVI696A/b9jib8CWZaRk+cjlx0EnTPyqHuP+ZpSS3DdiFAmxPhQWN2Mr6sDvi4OJxw0HP48XDk0iL/3F5NWWse6tDLGn
SCKYU/ZHt4rXUvf4GgeLcxCTvmD+oo0qsb8C5POmSj3qNN/koQT6qnvE2V1LRTXNqOSJGL9XWzX/+YapBTlO9PfQTFUFq/HSevEYJ/BNn8OPRw1eDnrqGhoIa2klv5dFKnQqWPC2R/J0ROaqpDLDoB/f8wW86GTdMIJdfZY6TEDeoPBwI4dO3j88cfbt6lUKiZPnsymTZuOe7v6+nrCw8OxWCwMHTqUl156iYSE42chf/nll3nuueeO2l5WVkZzc/MxbtE9WCwWampqkGUZlarHraQ4ZQdLGympbsBRqyLMyUhpaald+uG6ez46k5GmgNE0VlS1b7fIFjLqMthevp2dFTtpMDVwb+y9JHgqx95Il5FIrhJVh92mS/uVeDOleXtwrM3gw8qRXJVbhIfj6b3cmzwG83toNudFTCGyugB1YymqhlJUjaWojA1QX46xvgx10V7Y9ycNA27GEHaOTUuFCZ1jrfeJ2mYTSbkVWGTo6ybb9PW4s2InGZUZOKscuLIwG5PFRE3IZEx2ek9o4+ckkVlhIimzEAeT7UP/t5dvJ94tgtCKgzQk/Ulzv8uPu+/Z9vlhSzmVzeSU1aFVS0S7WOzyWaWuzcMjbzsGJD5tKmbdziUAJHgkcEvfW7DUWSitKwW3AXihRarMoTr5H+pwx2vLXFRmI/UDrqGl3gT1x++/Ggh1BExG6qrrOZXg3knRrizcW84Xa9KIcjGjOU5kgqnBRENLA9ssRm739qapNhtL1Wb4azsB3gN5YsTLp9CqcCp66vvExqwaTEYTUd566qorTum4PBOO+3/EqbmecrdQvixYj1E2cl7QeVb73ncyIa4qSqpN7EgvxE/bcvIbdEJnjwlXlwh0daVsT17Ab/u+Icw5jKsjRQWWzqir69wR22MG9OXl5ZjNZvz9O9bU9ff358CBA8e8Tb9+/Zg3bx4DBw6kpqaGN954gzFjxrBv3z5CQo5dx/Xxxx9n9uzZ7Zdra2sJDQ3F19cXNzf7rMfsDIvFgiRJ+Pr69qg32tP1W0oGGq2Gc2L9CA7sfARGl2qqQipLAo0Wl2EzcPFSln3UGep4eevLZNVkte/q7eyNxkVz3KUh1uB62+e89t1idhrDqNtRwQuXJZxW6O9nez9jU8VmSjyieH7S8x3PqrbUQX0J+3LX8NG+edzcUs3o3R8iFa9HHnu/Ep4vdBvWep/Yta8YlUZDjK8LCVFnUCP7NBzMP4hWo2Wqgz9eqkLwi8crfoLdTyjFBNWQW1tKrVln09c9wOq81Xyb9S0BjvCiRoN7+Q7cxh87izicfZ8ftrQ4LRuNVsOYaB/Cgu30WZX2A+UaNW97uJBZuxedVsdVfa9iep/pR5eYi78Iaf9CvEs24NjUhE4yQdAg3BNvUBK/WslN471Zl7ODiiYje8plLujvf8z9/PDjdtXtfJH8BY0AnuFItQW4mox4le/Dr3YP9JlMRVMFC9MXMjN+prJm19gIdcXg7CuiyE5TT32fOLCjEo1Ww/Bof9u9FxsbkXL/AY0W/9F38birFztKd3BD/A12W/o7MMJAUnETJc2qLnseOn1MRI5AKtuFY3MJeVRQYargTu870anPoATnWaKzud96zID+dIwePZrRo0e3Xx4zZgxxcXF88sknvPDCC8e8jYODAw4OR6/xUKlU3f4NTJKkHtHPM2WxyGzKqERC4tx+fvZ7vBkrQDaDbyySTzQATaYmXtv+Glm1Weg1ekYFjmJs8FgSvBNsHl7k5OLGFZPGk7Iyn/1FdXy7JY/bx0We/IZHuLrf1Wwp3kJmTSYLMhbg5+TH5qLNhLmGcX3c9eDoToSrHy0la3mvvpj1DXXcXrwbn/l3KOuYh98KOmcrPELhdFjjfWJzpvJ6HNPHx6avR6PZyO6y3YDM6OI0JICB1yCp7R/KF+nrgnSgjLyqJpu/Rw0LGIZPug/FDSW8rW3iifJUNHWF4H78ky1ny+eHLVksMuvSy1s/q+w0CDI0Iqct4z1NI5laD
1x0LvxnyH+OXwM7YbpSqz53Aw5GA2j1SOc8jKS27tdFZ72Ka0eE8enaTH7cns/EOP/jZiMPdxxNgMHCqChvzo+Nwh01mtWvQO4mWPU/LEVJvFaxidzmcsr3/czDRj3aZmX5JFpHmPA4RJ1r1cfTW/W094mK+ha2ZCmfT5Pi/G3X7wOLlQkP9xCkqAkMVqkY7D/YNm0fR78ANyQk0kvru/R56NQx0ZqUeWBVEb7BQZQ1lbOpaBPnhZ3XZf3orTr7v+oZr0jAx8cHtVpNSUlJh+0lJSWdXiOv1WoZMmQI6enp1uiiYCP7i2qpaTLi4qBhUIiHfTohy8obNijZtFt9vvdz0qvTcdG68OK4F7ln8D0M9B1ot7VCAW467p/UF4CFuwrYkH7qtU899Z7MjJ8JwG9pv/Fx0sfsKt3F+oL1yK1lg5y1zlzS5zI0zn7s9ArmITc1S6UmLHt/hp9uUpJzCb1SXbOR3flKwqETZfC2lrsG3sVk53CimuqU2be2Gux2Fu6tnMTKtkEt+iO5O7jz6IhH0Wud2eegY566GVnUpLe5/UW1VNQbcNSpGd4VmepPR+pSJGMTdzpFERcwjFfGv3L8wTwo5ewCB7WXhJMHzADvaJt0dWpCAP5uDlQ1GPgzqfCY+6w8UMLj8/eQWejCTxsN1DXo0Th6wNSXYMiNAKgOLOKW8hJ0LXXsbinjNVMRBmSllKWxCZY/DTu+EiVZzwJ/7y/BIkN8oFv7e7LVmQy0JP3Ae5pGiuIuhG5y8qNva6b7ktoWsspt/LnkGwsqNaqmSs73V0r2LUhfgNlitm0/erHucZR1gk6nY9iwYaxYsaJ9m8ViYcWKFR1m4U/EbDazd+9eAgNPsca30K1syaoEYESEZ9dlAD5VpfuV+vMaByXze6urY64m3C2cJxKfINQ11D59O8KYaG+uGBoMwLv/pJF3GrWxzw05l+H+wwEIdQ1lRswMnkh8oj10TKvWclXMVbxyzivE+MTT7OrPl94+POMEeU2l8PeTJ1x7KfRc27OrsFhkwrydCPawbcUDrVrLmKDR3FFZgYQE/a/oNtmuw1sT4RXXNNNisv2XljC3MO4fej8qBzdWqA0sSf3V5n04261trcs+Jtobncb2X7eaa/KU6iRA2IDreGb0M/g6+Z78hglKvgWzkx8MnWnNLnag06i4IVEp7fXbznzqmg/VqzZbZD5fl8nby9MwmmVcHDSYLTJv/5OKyWxRBk0j74DzX4ToiSTEz+DxuFtw8Ihgj3cIr/U/j5aZC2DAVcodbv8SVjyrDPCFXslskVm2rxiAaQNst9zFcnAJHxoL2KCRebViS7cZtDo7aBgTrZSU/WBlepcl3OwUjQN4K5NL52t9cNO5UdJYwtr8tbbrQy/XYwb0ALNnz+azzz7j66+/JiUlhXvuuYeGhob2rPczZ87skDTv+eef5++//yYzM5OdO3dy4403kpOTw6xZs+z1EIQzJMtye3msUXaodd2ubXY+akKH9Xj+zv68Mv4Voj1sM6PRWTNHR9A/2I0mo5lX/jpAk+HUPmAkSWL2sNnMnTyXN859g6tiriLE9ejw3VDXUJ4b8xy39b8NvaM3qR4B/OTipISe/fOcUuZP6FXaoj7GRtt+dh6AoiSoyFBm32Ivtk8fjsHDSYurXoNFhrxK+wwahvoP5cYBtwES3zblsDNjqV36cTZqNJhYm6oM6M+J6cQguoulV6byn8U3ssNYBX5xEH9Z59fuRk1Anvw8tec8b/OypOfG+BLh40xDi5lfd+QDShTQc3/u4/fdyqz91SNC+eD6Ibg4aMgsa+Dn7fmH7iByPEx+BsY9QPyo+3n8nJfRO7izt+oAr21/k5bEO+HcR0Glgcw18Pu/lbX1Qq+zLbuSinoDbo5KDgubsJj5cddHbFEZ0Tj5cOegu7tVNvc7zonCUacmtaSOxXuLbNu4fzwA+vI0Lo2+FID56fMxWQ77Xpi9AXK32LZfvUSPGtBfc801vPHGGzz99
NMMHjyY3bt3s3Tp0vZEebm5uRQVHTpAq6qquOOOO4iLi+PCCy+ktraWjRs3Eh8fb6+HIJyhvMomimua0aolhoR52qcThkbIWAmAHDON71K+Y0fJjvarj0oy1A2oVRKPTI3F01lHbmUj769Maw+X7/x9qPHUn/w5V0kqpkZM5a0JbzEmeCy3THpTWUNfkkzt5g8xmA2n+zCEbqbJYGZnrpKxt+3Mv62kVqUyP20+hSkLlQ19JoG++yQulSSJcG9llj638tjhjbIs09Bi3ZNcF8ZcxSSXSGQgOWOJVdsSDlm0p4iGFjPBHo4MtvHSsJKGEl5dPYfalhr+0VqQJzwBpzKokCSIHI/F2bbJHAFUKomZo5VZ+j+TCtmdV83DvySxK7caB42KRy+I5aZR4Xi7OHD3BOXE+U/b88goqz/m/cV5x/FE4hPo1XqSK5LZUrQFYi+ES94BR0+oSIcFdyknBoVe5a/WAevkOH+bRcis2vY+vxtLQaXmrhFziPfuXuMNHxcHbhkTAcC3m3Ioq+uabPed4t9aYaxkH1PCp+Cuc6e0sZStxVuV7fv/gGVPKD8ttqpF0Ht0v5HHSfz73/8mJyeHlpYWtmzZQmJiYvt1q1ev5quvvmq//Pbbb7fvW1xczOLFixkyZIgdei10lbbZ+cGhniescWtVmauUMD33EBY0ZPJHxh+8tf0tyhrL7NOfTvJ01vHYBbGoVBLr0sr5amM2SXnVFFY3YTB1/VpCb0dv7h96Pz7+A5QZEeDrA//H/UtuZmXuym4Thiacvu05lRjNMoHu+vbBq62syVvDTwd+YnG+cnKNqO6XXKdtzWZOxdHLXJqNZp79Yx/Xf7aZA8W1VuuDJEnc1v9WnjM6M7O0ENL/AZM4qWZNzUYzv+8uAOCaEaGobLg0rN5QzysbnqK2JpcIWc19wx5E8gyzWftdYXi4JwlBbhjNMk8tTG6vbf/aVQMZ1/fQTOs5fX0YE+2NxSLz9vLU436O9fPqxxOJT+CkcWKAzwBlY8AAlo24njddHfi7pZTiRfcj7/wWTDYc4AhWU1zTzM7caiQJpvW3zTJbo8nA16k/A3Bl0LmcEzHpJLewjwsSAogNcKXJaObj1RmnPLlz2vyVxHhUpKFHYmbCTGYPm82owFGQuRrWv6VcbzEpJ9qEU9Krs9wLvU/bgD4xyk4JhqA93P6fgD781PrmfUPcDZ1bm2hn8UFu3DY2gs/XZTF/ZwHzdxa0X+fhpMXXxQE/Nz2XDQ4iLrALZzujzqUp/lJS076jsiKVT3Z9yKLMRVwfez3DA4Z3XTuCTW1rzWcxJtrbpqV4LLKFbcXbwNhIYosZHDwhqPudrG1bR3/kgL7RYOL5P/ezr1AZyK9NLSM2wHrRBZrIc4ndPg8aymDFC8oyob7nQ+xFNkt4djb5K7mI2iYTAe56m4bbG81G3tj2GoUle/CWJR71G4fjgJ5X61mSJG4eE8Ejv+4BICHIjcemxeLhpDtqv3smRJNcWENORSM/bcvlptERx7zPfl79eHfiu7jpDr3OttWkstfVg600QUs1PrvfJj75C+JjLiWh//X4OvvbrcSYcGb+SlZm54eGeRLg3rmyX2dqd8YSmkyNeKPhqnFP26TN06FSSfxnYl/u+3EX27IrWZ9ezvi+NnifcvEHJ29orICyA4wLHqdsz9+ufC7JMqh1YDZAeXq3/EzvznrcDL1w9iqvbyGttB5JgsRIOw3oK7OgZB+pKgtf1h0A4Mq+V3Jh1IX26c9puHRQELPGRzI41IMQT0ccWkPRqhuNpJXWsyG9nP8u2Muu1lDqruI4+j+85T6MmUYtLnUlFNTl8/r219lQsKFL2xFsp21AOtjGy18OVB6gxlCDi7GZeFkNkeeClctqnY5DM/SHQu7rmo08uTCZfYW1tE3c7syptm5H9G5w+acw7BbKnb14y1xCefIv8OttsOBuOLBImRURzliLydx+onTGsBCbJW61yBY+SvqIlPwNOJpaeFzli
9d5Tynh8z1QXKAb957Xh5vHRPDC9P5HDebbeDjpuHdCHwB+3ZFPWsnxQ3UPH8wDXB97PdfGXkdcxCQ0bsGUq1WsNVczN+UbHl5wOaa8rV33gASbMZgs/JOiVMS6oL/tkuHtTl8EQKJ7X1SOHjZr93SEeTsxY7iSB+nTtZkdElBajSQdFna/X/ldmgJ/P0mzxYgxcjwMvl7ZXp5q/f70Mt3vG5AgHMfW1tnAfv6ux/1wt7oDi6nBwtsuOkwSJAYkMiNmhn36cpokSeKywcFcNljJfC/LMnUtJsrqWiiva+Gv5GJ25FTxwqL9/PeiOIZ1VbkljQ7tlOe56LdZnNdQx/95urJCrmde8jziveMPrc83m6A6B5y8lDWOQrdUWtdMaV0LKkl5TdrS1qKtIMOw5mY0qLpNqbojhbUuQyivN9DQYsJkUUKIs8obcHHQ8PiFsTy1MJmC6iZKapvxd7PiTJKLLwy/lc9MRewuNNJoVvHf6nqk0hSk0hScIi+EgDnWa/8s8fe+Eqobjfi6OnBerO3WoK/LX8fGnJWoGyt5yOhE6KRHwdlOiSq7SGcHY2P6+DC+rw/r0sp5+59U3rlmSKfWTEd5RBHlEcXlfS+n2dRMWvk+9u39nv0FG3A2GNH+9QiEJsKou9llqmGgj/1K0AqdtyGjnNomE94uOkZE2Gjyx2Lh9soqxhqd8Yq71jZtnqEZw0JZn1ZOflUTX27I5r7WEsdW5Z8AWWuhJFmpFLVkDivMNfzoquOaPuOYrHZX9hMD+lMmZuiFHmOLvbPbmwxY0pbxvqaJSp0jgc6B3D3o7h4fkidJEm56LdG+LiRGefPfi+IYFeWF0Szz4uKU9ue9S7gFwoRHcULitvw0ItUu1DdX8dX6Z2HjB7DwXvjqQmXm8OeboTKz69oWulTb7Hy0r4tN81lYZAtbireAsZGRLWYlfLybhua5OGjwdlFOPiblVfPE/L1klTfg4aTl5SsGMDDEoz3UvqsjYo7nlv63otV7sNfRkZUT7muv3e2Qt7a99rhwegwmC7/tVDKuXzUsBK3adl+xxvkN4/yGRu4y6RkQfQFET7RZ293B3ROi8XDSklfZxHdbck759nqNngEBw7h2yps8f81fPNLvJiWRYN4Wls6/kVdWP8IHu97vmJFb6JaW7lWqFlyQEGC70sbFSagay4nXehDQd5pt2jxDOo2Kf09UoluW7y9hb36N9Rttm6Ev3gOLH4aWOprdg6h19WNB1p8YvSKV66tzwdhs/f70ImJAL/QIDS0mklrfbEbZOJt2u5wNmJtr8Ne64ODoyUPDHsJJa9tEYLagVSuZhMf08cZklnn5rwNszCjvugYiz4EBM9Ag8a+8g8SX53Jtxg7Y+4ty1tbUApIKmmtg0WzljV3odva3Dujjg2ybWT6zOpPK5kr0xiYGyRrleOqG4fZtIlrD7l9bdpDcyka8nHW8dPkAInyU7cPClSiUHTm2GdAHugRybawyg/RN+gLK4y4CjQOq5iqozLBJH3qrlQdKqKg34OWsY3Kcv03atMhKIjj11k+5vcHAuY7BMPZ+m7Tdnbjptfz7PGVwsmBXAemlx8563ymOnkjj7oerv4HI8XhbQNNQxsaD83lv+1sYLTYITxZOS3Z5A/uLlOVM5yfYLtxeTl2u/BF1LmjsFEF6GhKC3NsjYT5YlWaVBMkd+PRTSkY21yo5XTzDmXLhx3jqvSlvKmd1xR4lMlO2iAmdUyQG9EKPsCOnCrNFJsTTkWAP29bFbXdgMVok7oi9gTcmvEmoW6h9+mEDGrWKR6bGMr6vD2aLzKt/HWB9WhcO6hPvAv8EwmQ1z8ieBPoPhP5XwHn/hWu+hZkLwbsPNFXBogehJv+kdynYVtuAPiHI3abtFjcWo1frGdrcghap24bbtwlrTYxntsj4ujrwypUDCPU6dCJwSJgHAEl5NZjMVv4y1erCyAvp59mPZnMzc/fNQw4cqlyRu9km7fdGJrOlvW76FUODbVIm6
5fUX3h/1/uYC3bA/t+VjRMe71blG20pMcqbMX28kWVY0bqG+oy4h8CUFxgx5mEeMrmgaally8H5vL3pfxjNYlDfHf2VrMzOj4ryxsvZNgNrQ0s9D2b/yjx1E02R423SZle6ZUwEns46Cqub+WVHnnUb0+jAJ0b528UfLnwTnYsfl/W5DID56Qsweisn5qhIs25fehkxoBd6hM12Drevq8zAUrBNuRB7EX5Otq/Pa2tqlcRD5/fjvH6+WGR4fdkB1qR2UWk+tRYufhtmfAW3LoHpH8HY+8kLiEN2DwW9O1z0JnhGQEO5MqivLeqatoUzVttsJLdSydxu6xn6ccHj+Lz/vcxspjXcfqhN2z9VsYFKfoEAdz2vXDmAQPeOJySjfV1wc9TQZDRzoNg2tXdVkop7Bt2DVqVlb/le/nFT/odS3habtN8brTpYRkltCx5OWqZaeWZQlmV+Pvgzv6b+ysaCjexe95JyRfylEDLMqm13d5NilciITZkVWCxdsIREkqD/FQy96H3mqPzQmlrYkb6IN9c+Lgb13UyTwcyqA6UATBtgm1J1AEn7f6JINrBdq0IfPNJm7XYVZwcNt49TQt1XpJRav4zdsJshYhxc9IaS2wWYHDYZT70nlc2VrHRoHZqKdfSnRAzohW7PaLawvTUc1R7l6oxmIy+vfZyX1A3UBA4AtyCb98Fe1CqJBybHMCnOD4sMb/19kN151V1z5xoH8IpUBvfAjwd+ZM7aOazJX6Nc7+ihDPo9QqG+VBnU15d2TdvCGWmbnQ/zcsLdUWvz9rXZ6/FE1e3D7QFGRXrzwvT+vH3NYPxcj056p1JJDAlVwu5ttY4elND762KvA2C1oQwLMpTuV0IhhVNitsj8sl2Z2Zo+OBi91no5JWRZ5seDP/Jb2m8A3OQWx7CacuXk1ohZVmu3pxgc6oGjVk1FvYH0sjMIuz9S0BAGX/41jzrFoLOY2ZW9gtf+vgezxQwWC9QWQs4mSPoRVr8KS+Yo5bgEm1mTWkqT0UyQh55BIbaLHNuU+RcAo3wHI6l7ZtLExEgvNGqJsroWCmusvHY9bBRM/R94hLVv0qq1TO8zHYCFjTk0ISul64ROEwN6odvbW1BDk8GMh5OWGD/bZtMG+Dr5KzJqssiSzBj6nGfz9u1NpZK4b2Lf9pn6D1am0WQwd3k7jhpHZGS+2vcVFU2tificvOCit8EtGOqKlEF9Qxcm6RNOyz47rZ+vN9Qjm81Kllzo9uH2oLx+Bod64OJw/BMPtl5H32Za5DRu7X8rz4x/BYtbuLJuMX+bTfvQG6xNK6OophlXvYYLrTgzKMsy3x/4noXpCwG4ue8MLs5qHTSOmKVENp3ldBoVwyOU19PG9C5cJgbgFsiAK77icb/x6IH4gmTUP8+EeVPhh+tg6WOw+WM4uATytsKmD0SiSRuRZZklrcnwpvUPtFmyYkNjJTtqlbXeo2Kvskmb1qDXqm2eoPVIk0In4evoS7VsYp9kUnK6mEUSys4SA3qh29uSqZSrGxXljcpWGUtb7SzZyfKM35EsRv6j8sY35hKbtt9dqFQS90zog5+rAyW1LXyzKbvL27g46mL6ePShydTEJ3s+ORT25eKrzNS7Bipr6Rc9AI2VXd6+0Hn7CpQElbYe0D+z8RlmL5tFdnN5jwi376y2dfQZZQ1UNxps1q5KUnFBxAVo1VqMAWId/emwHDE7b82KDz8d/Ik/Mv4A4Nb+t3JhaS601IF3NMRdarV2e5rRrYlzN2VWdH34sNaR+Avf5c2Eu7ncolc+k8wGUOvAKwqiz1NCitU6qMyCCpFo0tpkWeanbXlklTegVUtMirPdksik5O9pxoKPxpm+ET27ssSQUA9AqchiD1q1ln8N/hf/O+c1hms9wGxUShgLnSIG9EK3JssyW7La1s/bNtzeIlv4LuU7aK7hIrMDg/te2qOyl3Y1R52ae1tLnCzeW8S+wq4tcaJWqfnX4H+hVWlJKktiSdaS9gzOuPorg
3pnXyXr/ZpXu7RtofOaDGYyWkNZE2w4oC+oLyC/Pp+S2mx8ZRVEjO/24fad5eGkI8pXyXq/K7faLn1o8h/CFpUROW+zEkIsdMr69HLyKptwdlBz0UDrzc7XG+rbB/OzBsziApc+kPKncuWY+0Alvs61GR7uhVYtUVjdTF5lU9c3IEn4jLwLrvoSLniZpqu+YNWUx2HGlzD5WRh+G4SPVvZNW9b17QvtZFnmi/VZfLdFqYZzQ2I4rnrbLQPblP03AIn+w5B6+GuwPUFrfg3mrsg/cRriveOJ8oxWkiIDlIvEeJ3Vs48+oddLL62not6Ao1bNgGAPm7a9Ln8d+TVZOLc0cLnZAWIvsmn73dHQME8mx/kjy/DeijRaTF0beh/sEsw1/a4B4Jv933DX8rtIqUhRrnQLhAtfV0qe5G6GnI1d2rbQOQdL6rDI4OvqcMw14daytWgryNC/xYAzkjIT1osMDbP9Ovo2FtnCK8WLeVvbzFZDBZQdsHkfeqLqRgOfrVPCbS8dFIzzCZZVnCknrRP/TfwvM2JmMCVsMmx8V1kiET0RggZbrd2eyFGnZkjr66lLS64eySsSU+hIXkz5krl7P2Fx5uJD1/WdqvxO/wcsXb9ETVByV7y3Ip3fdxcCMGt8JFcOC7FZ+4baInY0KNE5o+OvtVm71hLt64KLg4Ymg5nUEtskaD0unxjyJDMbc1fatx89iBjQC93a5iwltHpIuIdNygC1MZqN/Lx7LlTnMN2kw8UvQQlrFLh9fGR7iZPvt3R9jfiLoi7iosiLcNG6UG+oJ8jlUBLC3aYaNkaNwogMG98Hk+3CkwVFW2SGLWfnAbYUbwFjI4kGS68Kt2/Tto5+Z25112TnPgUqSUW8ZwLonPlB3YwpV5wsOxlZlnl/ZTrVjUbCvZ24ysoDCZWkIsEngatiroKMlVC0R0ksOuoeq7bbU7VVxNmYYd2cKxqVhsTAREA5Cb2xoPW1EzpSKR/YWAkFO6zah7ORwWThtaUH+CelBJUED0zuy2WDg23bh4wVTDbriHPwpU/waJu2bQ0qlcTA1mSC9gq7b5Pl7MFj2nrmlmygtFEkQ+6M0xohmUwm/vnnHz755BPq6pSzOIWFhdTXd2FGUUHATuXqjM2Ur30Jh7KDeJllLvBMgIlP2q79bs7FQcO9E5STGwt3FZDWxWdyVZKKmQkz+WTKJ7w47kXcHQ4lelqYvpB3G1N53tFMTW0+7PmxS9sWTm5fe/152w3oc2pzyKrJQm2oZ7hF06vC7dv0C3DFUaumpslIZnmDzdufEjQFN0cfiiQLK1uzNgvHt2xfCVuzKtGolfKeNjvhbGxSEq8BDL4BXHp/CdXTMTLSC5UEWeUNFFs5a/clUZcwNUKZkf8w6UP2VexTqrdET1J2SPvbqu2fbZoMZl5YtJ+NGRVo1BKPTotlUpy/zfvhkrWOm8x6nh14j82S8FlbW9h9l1UzOk0RIaOJsWhoMTXxxZ7PrV9Krxc45U+gnJwcBgwYwGWXXca9995LWZlSl/rVV1/l4Ycf7vIOCmevopomcisaUUkwvHX2yurK02HBnQSmruR1owtP970O3eVzwd22Z367u8Qob86J8cEiwzsr0jCau37NrUalIdrjUFSELMvEesXirHMl1cmN/2obyNv1FdSVdHnbwrEZzRYOttZKTwiyXUbtv7P/BhlGtJhwR9UjstufKq1a1T47stPG2e4B9Go9V8RdD8CvDZk01RbavA89RWF1E1+sV0LtbxoVTqSPs1Xb21a8ja/3fU1mdSbs/g4aypQkoYN6fpivtbg7aukfrLyeNmVaMewekCSJWxJuITEgEZPFxBvb3iCvNg/6nq/skLUODI1W7cPZoq7ZyFO/J7M7rxq9VsUzlyQwJtrH9h2pzlOWJkmqXrX8a3BrCdWU4jqrVDPqLMkzgjtlN7SyzO7ibWwo3GC3vvQUpzygv//++xk+fDhVVVU4Ojq2b7/88stZsWJFl3ZOOLutTVVOF
g0Icbd+khOLBfb8AgvvhqoccPJGfdFbBI6f014nXejozvHRuDlqyK1o5Jft+VZvT5Ikro29lhfHvkiAVwxlOj1PqqrYteZ5q7ctKNJL6zGYLLg5agjxdDz5DbpAo7GR9QXrwdjI1JbWcPvgYTZp29aGtp643JVnp7JBfS4jQOdGjSSzaPendulDd2e2yLy1PJVmo4X+we5Mt0GY76q8VSzJWsKOnJWQ9JOycdQ9Ssi9cFxt2e43plu/1KlKUvHvIf+mn2c/Gk2NvLT1JSpc/cA9BEzNkL3O6n04G7z5dyoHi+twcdDwwvT+DG7NzG5rmft+IkkyYQoeBo42mnCygQB3Pf5uDlgsMnsLujbx8SlRqQn0juEKswOYW/h639fUGey8rr+bO+UB/bp163jyySfR6Tpm+46IiKCgoKDLOiac3TZmlLevzx7f19e6jckyrHweNn1AtbmFxX7hGC+fCyHDrdtuD+fupOWuc5QZ9J+2KyVjbCHIJYgXx71IfPBomoHXyjawePsHIiTLBva31Z8PdLNZiKGjxpFHRj7ChTo/4mR1rwy3b9OWGG9/UR2NBtvX39WoNFwbpMw2LSpYTXVztc370N39sj2Pg8V1OOnUPDilr9VLqTYaG0kqSwJgVOF+pURa8DCIPMeq7fYGbUv1DpbUUdlg/XwrOrWOR0Y8QrBLMGaLmRpj7aFZ+lSR7f5MGUwWdrWGgj93WUJ73XSbk2UWZf3FS9oGfnG1bnSOPbQllLT3Onq8+3CJWUeI5ECtoZb/S/k/+/anmzvlAb3FYsFsPjoMIz8/H1dX1y7plHB225FTyWtLD2KRYVKcH1OsvTaqKAkyVoFKw2/RI/hG28L7B8QbR2eM7+tDYqQXFovMm38ftFkNbVedK0+c+woTfQZhAXYc+BWLSJBndYfWz9su3F6SJBJcwrm5shwJqVeG27cJcNcT5KHHYpFJyrPP7MiouBn0kdVEtjTTZBR5cQ6XVlLHD1uVE813T4i2SZWHHSU7MFlMBDsFEJK3U9k46l/QS9bsWpOPiwP9AlyRZdiSaf1ZegAXnQuPj3yc/437H1HuUYcG9IU7ob7MJn3orXIrG7FYZFwcNPT1c7FbPwzFe9lhrARJxfDYq+zWD2tpi3qw9zp6fGLQInGn2g8JidV5qzlYedC+ferGTnlAf/755/POO++0X5Ykifr6ep555hkuvPDCruybcBZKLqjhf4tTMFtkxvX14b6J1p8BYc/PABRHn8OKlmIApkZOtW6bvYQkSfzrvD64OWrIqWjk4V+SyK2wzVpBrUrLnVPe406VD7MbLKhTfgfAaDHapP2zjcUik1Jk+4R4gJIErKlaCV/tpeH2bdpmR3baoXwdgOQXz+PqAJ5p0RFYb5tBUE/QbDTzxt/KiebxfX2YEGPlyLFWW4q2AJDoHIoky+DsKyqunILRNsp2fzhfJ198nVqPD7dAKvz6KZGA6f/YrA+9UWaZcoIxytfZrknodid/TzMyPo4+9PHtb7d+WMvAEHckSTmBUl7fYr+O+MQA0K+6iIujLuLGuBvp49HHfv3p5k55QP/mm2+yYcMG4uPjaW5u5vrrr28Pt3/11Vet0UfhLJFaUsfzf+7HaJYZHuHJ7Ckx1h/MV2VDzgaQJH7SqzDLZgb7DibBO8G67fYiXs46Xr1yIAHuekpqW3j41ySbndmV9G5MSpyNCxJs/xIaKnht62u8uvVVcSa3i+VWNlLfYkKvVRHla5vZkU+SPuHrDS9QceAPZUby3Ed6bbh9m/bydTlV9llGolLhEjpaiYbI3Wz79rupeRuyKKxuxttFxz0Tom0yoGgyNbG7bDcAo+TWZY7+CWJ2/hSMal1Hv6eghrpm25/s3Va8jQcthSxTGUS2+zPUVv3DVp8/x2Q2sbJQyYeQGDK+12S3P5yrXkuf1ud4T361/TriFaUkHWyu4cawC7gk+hLUKrX9+tPNnfKAPiQkhKSkJJ544gkefPBBhgwZwiuvvMKuX
bvw8xPlU4TTk1lWz9O/J9NkNDMwxJ3HpsWiVdugDNCeX5T2gweysTIZgOtir7N+u71MiKcTb8wYRHygG00GM8/8sY9l+4pt03jMBeAXB8ZGSte/TnLpLnYWb+PpDU/x7MZn2V26W6yv7wJt4fb9AlxRW/tEG1DRVMHq3BUsSf+dBkmGhCsgcJDV27W3AcHuaNQSpXUtFFQ32acToaMAqM/ZwHcp35FRnWGffnQTuRWN/LVXeT97YHKM9ZO0ttpVsgujxUigcyBhVa1VB/x734ygNQV7OBLu7YTFIrMtu9Lm7efX5dOi1fOVppndVQeVSjrCackoPTRDby/ZB/9gl6UBlUrD+QNvs1s/rG1wW/m63Gr7dUKjA88I5e/y1PbNRouRRqOoGnGk0xoxaTQabrzxRl577TU++ugjZs2a1SHjvSCcivyqRp7+fR8NLWZiA1x58qJ4HDQ2OAvXWAlpSqKaHxyVWb+xQWOJcI+wftu9kLujlhem9+fcGF8sFpkPVqbz5YYsLBYrD6ZVKhj7AEgSftmbeKuomPNKc9GUp5Jy8HdeXnI7j307jj0rn7JuP3q5fYXKmm5brZ9fmbcSS30pcSYIcw2DkXfYpF1702vV7UsadtihfB0AoSNAkvih9gB/pM7n94zf7dOPbqKt6sDQMA+bZtU2WAx46b0YFZCIVLpP2RggBvSnqq2smS2y3R9pep/pnBs2EYuDC+9om8jb94vN+9AbWCxye+LdaB/7zdD/vv87AEZ5xRPg2nvLGbe9z+3Kq7bvhIhPX+V364D+YOVBHlnzCN/s/8Z+feqmTjl28ZtvTvwkzpw587Q7I5x90kvreHFxCjVNRiJ9nHnm0gQcdTYKqUn+DcxGkrzD2NNYgEbScE2/a2zTdi+l06h46PwYgjwc+WFrLvN3FlBc08yDU2LQa634f/WLhaEzYd8CAo1N3G1WM8PswCJ1CyvVRrLNDbya8ydPpwyjX9x06/Wjl5Jl+bCEeNZfP2+ymFhxcD40V3O+xUkJtdeePSeNh4R6kpRXw67cai6zQVm0o+jdwS+BC0qT+MdQz7aibRQ3FBPgHGD7vnQDe/KVk1kDQzxs2u6E0AmcE3IOxsosaJ4Lah14izWkp2p0tDc/bM1lZ24VzUazdT+LjiBJEncMuIPS0r2kFGzi1ezf+d/Iu3HvRaXObKGguokWkwWdRmWzkqlHMjVWUl6nJMW8bOAsu/TBVmID3NBpVFQ3GsmpaCTCx05REd59gWVQcSiypbChkKKGIqaETyHaQ+QTaXPKA/r777+/w2Wj0UhjYyM6nQ4nJycxoBc6RZZl/kgq5MsN2ZgtMqFejrxwWX9cHGy0PtbQCPuVWSff+CsY1pCGv5M//s5Wzqh/FpAkiesTwwh01/PeyjQ2ZlTg5pjFvedZ+Yvo8NuUHwCzEW9jEzebmrmisYxPNjxPY2UGgckLoN+lyqy+0GkltS1UNhhQqyRi/K1fzWR73jqqKtNxlyVGxF0NQYOt3mZ3Mizck682ZrO3oAZD65dYmwtLJLQkmcE4sBsLizMXc/uA223fDzuzWGSSC9oG9Lar7tBGJalwaAs39e0HatuE+/cmEd5OBLjrKa5pZmdOFWP6+Ni0fa1ay0Pj/8eTP0+j2NLMw//czf8mvY+fk7JMdXPRZpalL8JQX0S0Z1+uG34/jpqz5wRmZ7Stn4/0cbZ+bqXj0GSv43mDIzmefYkI791lI3UaFf2D3NiZW01SfrX9BvRHzND38+rH+ODxrCtYx1f7vuK5Mc+hksT3OTiNkPuqqqoOP/X19Rw8eJBx48bxww8/WKOPQi9T22zkxcUpfL4uC7NFZlSUF69eORB3Jxt+UTm4BFrqwD2EoNjLeGTEI9wYd6Pt2j8LnBfrx2MXxAKw6kCpbetqq7WgdwMXP1z9Erjv/A94Qu2PW2U2pC61XT96ibZw+75+LjaZ3fp7xwdgMTJJ54c28S6rt9fdhHs74emsw2CysL+1soDNta6jv
7SuHmSZ1XmrqTXYqS92lFleT6PBjKNOTbQNk3Hl1eZhtrSWCC5R8ruI9fOnR5Kk9mz3a1LLKKltpqyuhfL6FqoaDFQ3Gqhvse7nk6ujF4+GXYyLLFFbV0R5Y5mSlHf3D1RveJv96YtJL97BspSfeGz5v8iqybJqf3qawzPc203qMiQkImKn268PNtS2jn6XPdfRtw3o60uVSjfA9XHXo1frSa1KZX3Bevv1rZvpktMaffv25ZVXXjlq9l4QjrSvsIb7ftjF1qxKNGqJu86N4okL42yWZAgAixn2/qr8PfDq9tlakT2z642M9CLE05EWk4W1qfarwatz9kUz9GblwrbP2Fu0TSTKOwVt4fbxNgi3z09dwr6aDFTApLFPnFWh9m0kSWJw62zwbjuVr8O7Dzh5EW80ESXpMVgMLMteZp++2FFbuH3/IHebzQy2mFv474b/ctc/d1HeVA7FrQN6sX7+tI2OPlS+btbX27ntq23c+uU2Zs7byk1fbOW6TzfzyK9J7G39f1tDUPwVPG9y5spmM15/PwM/3wxb5jKoooD7TY7cK3njLUsUl+7lybWPsTR7qficapXRNqC30/r5vVn/UF+SrGRd7zPZLn2wtcGhyrKQ5NZIMbvQOSvlaqE97N5L78UVfa8A4PuU72ky2Sl5bDfTZXEKGo2GwsLCrrq74/rwww+JiIhAr9eTmJjI1q1bT7j/L7/8QmxsLHq9ngEDBrBkyRKr91E4msUi8/O2PJ6Yv5eKegNBHnremDGIiwcG2b7sR+ZqqCtigV7Fp4YCqpurbdv+WUSSJM5PUJYx/L2vxL6diZ8ObkF801LIi2sf44+MP+zbnx7EZgnxDA04bf2MS8wOnOszGJ/Ic63bXjfWVo/eViUgj6JSQf8rkZC4pDQHTMqAvsVsx7rEdnBo/bztwu13l+6mxdyCo9oRb0mnzOQC+MXbrA+9TT9/V4aFe6LXqtBpVGjUEiqVxOHnaFKK6nhiwV6eXLiXg8V1Xd8Jv3iC3cK52qghoK5MiSQLHUng2NmMueY3zrn+T15zjmOYCUw1uXy5d55I/oWyRDOzrDUhnh1m6OsN9byx7XXu1dVRFDQAnLxs3gd7CPdywsNJS4vJQmqJFV4PndWWN6Q8rX3ThZEXEuAUQFVLFfPT5tupY93LKS9Y/uOPjl+CZVmmqKiIDz74gLFjx3ZZx47lp59+Yvbs2cydO5fExETeeecdpk6dysGDB49ZMm/jxo1cd911vPzyy1x88cV8//33TJ8+nZ07d9K/vzjTbUvvrkhj5YFSAM7r58s9E/rYLvnd4WQZ9vxEFRYWOjnSnL+G/r6DGRM8xvZ9OUtM7OfP1xtzSCutJ7Os3n41ZDU6SLybwBWPQVMVP+z/llDXUIb6D7VPf3qI6kYDhdXNSBLEBVp5/fyOr/BqqORGtyi44GPrttXNDWrNMpxZ3kBNo9G2S5LaO3E9FO4iMX8bwXXlxIVPocXcgoPawfZ9sQOzRWZ/a3TKABsO6DcXbQYgMTARqTRF2egectYMJKxBpZJ49tKEY14nyzKVDQZ+2ZHP0uRikvJqSMpLYmSkFzckhnXdZ5YkwXlPKJMKAQMgeDjonDrs4jL1FeYsuJO/mqr4SVvJ+ODxXdN2D1Zeb6Cu2YRKgnBv2w/ol2b9RXNTJeGyioB+l9q8fXtRqSQGhXiwJrWMXblV9A+2fQ4RQAm7z1zdoXSdVq1lZsJMXtv2Gvl1+ciybPvJwW7mlGfop0+f3uHniiuu4Nlnn2XgwIHMmzfPGn1s99Zbb3HHHXdw6623Eh8fz9y5c3Fycjpuu++++y4XXHABc+bMIS4ujhdeeIGhQ4fywQcfWLWfQkcGk4XVB5XB/H8m9uHBKTH2GcwDFO2GsoP8ojX9P3tnHR7FucXhd9aS3bgbCUmAEDQEd6dIS2mh1AtUaLn1Uvf2tpQadeHWqRstFVrcLbgEiZGQAHG3zcrM/WOSLSkWICsJ8z5PnqzMzHdmd3Zmz
nfO+R2Mbp60923PgPABzrHlIsHHoLWlOy51dpQ+Zihjgnsx2qJBqirgnV3vcLTyqHNtcnFWHJR/u9EBHvYtjTHXwqHF8uOB916UqfYn4u+ho22AAUmCPUfLnGOESgUjnkTtEcRrNQIzS0vx1tpfFNFVyCisotZsxcNNTYyDHAmT1cTO/J2A7ND/Uz9/amdU4cIRBIEATzdmDWvH/27qxehOIagE2JpZwn3f7+b1pSnNl3Ic0gUG3AUxQ09y5gHwjUQY8RQTRDfeLzMRm/+PE5Nemo4oOSn12Yk01M9H+hscLhBqtBhZkvoziGauEHwQYi6uCZYT29c5jcA4+f8JDj1Az+CePN/7UR6xeiMsfhAqnXx/6WTOOUIvis45mZhMJnbs2MHjjz9ue02lUjF69Gg2b958ynU2b97M7NmzG702duxYFi1adNpx6urqqKv7J6WwokKenRdF0Wn73hREUUSSJJe0MauoCqsk4eWmZVR8EJIkOa0uTNj9HTmClVV6dxDU3Bh/o1PtsSeudExc0imY9WmFrEkpYMaAKNwc2DboJPrfyfRFezlaW8Oh2lJe2/YaLwx8AU+d83rbOopzPSYqjWZ+3p6DhMSkHmF2PZak1OV8bS0i0TOMLhG9EVzguHU2CW18yCquZueREga3D7DbOGc8Ltx9YcRTqBbPRkpbhhTaHeIvtZstrsSenFIkpPpWjRKiaP/rxK78XRgtRgL0AcR6xyLlfwCAFNwFHPibcKXrhyMJ8tRxz8h2TO4ZzrdJ2axPL2JNagEdgj2YmBDuGCOiBkDP6XjsXIC0/nUk37ZkubnxzKZniPON484edxKod6xSPzjvmEgvqERCIibQw+FjrziygsqqPEIkFX1iL0FUaR36O3Q23dt4IyGRXlBFeU3dSZP6Djkm/NshAFQcQ6qrAq0BJBFS/qLjtk/BWIYEsOVDpFHP2M8OJ9HUz9ZBPcIunKKiIqxWKyEhjduKhYSEcOjQoVOuk5eXd8rl8/LyTjvO3Llzef755096vbCwEKPReB6WOwZRFCkvL0eSJFQu1pJrd0YZFrOFUH8dhYXOE0ZTV+Tgm7mBr3R1mLV+9PDuip/Fj4KCAqfZZE9c6ZgI0Un4ugkUVRlZvCODwbG+TrTGH8/wQdyTvYanqkrIQc09K+7hhtgb6O7f3Yl22Z9zPSZ+3lNAWbWRCB83OvlKdv2t5O1awGJVHUvUNczJy8GgOUX06iKjraeExWxha0YB+V297ZZSeNbjQh2GPm4qhv1fc3jdq6wvTObSuBmtPsVxa3oeFrOFtp6CQ64TkiTxy8FfMFvMdPbsTGF+Hv7H9iJYzJRpQrE68FrlStcPZ6AFpif6EeQm8uPuAtYcOEa/MAfeMrcZh1f2bnR527EufpS0ntORLBJ7C/bywMoHuC72OnoH9nacPTjvmNh3pBCL2UKgzurQ+zWLaOGXgz8jGSuYYNJQHtgPSyu9XzwTQXoVuRUmVu7JYmBM47R7Rx0TfhpvVMZiytO2gSTiseczNPWdIKweYdRUHycpYzG92lyC6BdrNzucQWVl0/QLmnR2+neU+0y88cYbTV7WFXn88ccb7W9FRQWRkZEEBQXh7W1/hefzRRRFBEEgKCjI5S6+pSlVaLQaOkUGnlLrwGGkfs0+rcBegx6duye39ryVYA8n2mNnXO2YuKyHia+Tskk6Vsfk/k7+3Ifdh8eP23i81sgbfgbyJCNtQ9oS7FdvlyRBVX36lleo8+xsZs7lmCipNrH68GE0Wg23DetAaIj9IsQUHuL7umwElYph7S8lOjzafmO1IIb4BfDh5nwqzCIWnTcRfvYpQ2jScRF0O9VVabxQuAbzkV8ZEDeGTsGtV3/CYhXJLJWP/0GdIwl2QB/mrIosjtQeweBmYGrXqQTVliNgAb0PAe17yQrbDsLVrh/OYpybN7/sLyGj1Ize28+xHXkufRFh0X/QlB/lkpx1dBv+Ku/vnU96WTpfZX5FpimTm7vcjEHrm
MlPZx0TeTXZaLQaEtuFERzsuDruNTlrqDbm4ycJjPSKQdNpqKyDcJExuquR77bl8OuBMkYlRDf6DTjqmBDCu0D2ZgL2fQwVx+QXDb5IPadjjb+UR365nEpjCW3SvqLLZe/bzQ5n4O7u3qTlmuTQ79q1q0kbs+dsfWBgIGq1mvz8xjUS+fn5hIae+oY7NDT0nJYHcHNzw83tZMEflUrl8hc1QRBc0s6s4loEBGIDPZ1nmyjCkY38pq4DfSjjYsYR7uWg9Dkn4krHxJjOoXy3NYdDeZUcLTUSFeDECKx3KCRcS9udX/F6hZXkntfQsSANDv4NJYdZV5JMjMlEpKCDy96CsNYTuW/qMfHTjmOYLCLxod70bxdo1/N7WfLPbFOZwc2bsR0mucTx6gro3VR0Dvdm79Fy9hwrJ9KOddxnPy5UeI1+jqE/XcFKSymLN86hy5RfWu0NbkZ+FXUWEW93LTGBng5pWRfrG8urQ18lqyKLEM8QOLJJfiOkC4La8QmVrnT9cBbhfgaiAzw4UlzDzuxyRsQ7cDLa3RsueREW3QnHdxO++QOeH3gfv+Su49e0X9lwfAOppancnXg3Hf07OsQkRx8TFUYzxVUmBATah3g59FgsMhahqqvgMqsOXcfxoL44WxtP7R3FhvRijpbW8umGLGZf0vhYc8gxEdgBsjfLzryggk6XQe9bEPR+uAMD21/G0uSvWFi4jW6FB1uV5khTP9cmLbV69eom/a1ateqCjD4TOp2OXr16sXLlSttroiiycuVKBgw4tajZgAEDGi0PsHz58tMur9D8SJJEVpHcbiTGCe1GbBQehJoSZqsCubzjNbYelgqOI8DTjd7RskrzsgOnL3txGAnXg8EfbeVxEte+Deteg/2/UpC7i4+kUh7RVvGZUIVl3etgtTjbWoeSV25kyX75O5o+INq+qdXGClYfWYYFiAvuToxPjP3GaoHYRImyy5xqBwDuPlw2+BkEBHZUZnJsp32FcJ3JvmNyu7puEY7rPw8Q5R3F0DZD5Sf5++X/rejmtCXSL1bOTtpyuNjxg/vHwIjHQaWGI5vQLLyFq63uPNf/GYL1wRTUFpBSmuJ4uxxEQ7u6UB93DDrHTmpdHTGCd6o1jBZ1EDfWoWO7EjqNivtHx6ESYHVKIUnO+B20HQQaN4joCVM+gSEPgt7P9vakrtPRuPtyUGXhwOaWnSl+vrSoadfZs2fz8ccfs2DBAg4ePMh//vMfqqurufnmmwGYNm1aI9G8++67jyVLljBv3jwOHTrEc889x/bt27n77rudtQsXHUVVJqrqLKhUApF+TozIZm0AwDNyADd0uQkv3cWj1OxKjO0iZ8esOlTQfKrB54vOAAPuliOMel+I6AXdr0Y16F56tr8MMaA9S7VW/q5IhX0/OddWB/Nt0hFEUSIxytfu7bqsqUtYRg1o3BjT8Wq7jtUSSYzyBWDf0XIsVueLMYXHjKBXiFy7++fez1rtZNe++v7zjmhXV2osJacy5+Q38hoU7pU2u85kQKw8Eb0zu5Q6i9XxBsQMhckfQ3BnMFXD+nl03PQhryTcw/Xx13NZ7GWOt8lBZBTICvexzggIpa8gSAJ9aHfwbv0ZnWeiY6gXVyRGAPDe6nQqjWbHGhAcDzf/DZe9CQHtTno7QB/AiA6XAwK/FO+Bozsca58LcF7TXdu3b+fHH38kOzsbk8nU6L1ffvmlWQw7Fddccw2FhYU888wz5OXl0aNHD5YsWWITvsvOzm6UmjBw4EC+/fZbnnrqKZ544gk6dOjAokWLlB70DiSzPjrfxk/v8HYjJyJlrpdVMqMHO80GBejV1o8ATx3FVSa2HC5maFyQcw1qPwpihsnRj/oodCAwm6ksP7KcT7a9wc9iHoN3fIpfuxGtqp7+dGQVVbMmVRavnDagrX0HkyR2Jn9LiSDi5RGqtJA8BbGBnni5a6g0WkjJr6RLuJN6AZ/A5X1ns/3PG1kv1XBT/j4M4YnONqlZMVtFDubW9
593QO/lbw99y4ajG5jeZTrjYsbJL9aUQGWufF4K7mR3GxROT7sgT9t1a09OOX1j/B1vREA7mPQ+7P8Ftn0CuXswLLqLST2nySWF6hYVn2syh4tkh75doOO60JTXlVNjqiYsdYn8wkUcnT+RG/q1ZWtmCUdLa/l43eGTUu/tjurMJQ+TOt/IqvQ/2FdbQsqWt+k4ZUGrLQk7Fed8Bvj+++8ZOHAgBw8e5Ndff8VsNrN//35WrVqFj4/9L3x33303R44coa6ujqSkJPr162d7b82aNXzxxReNlp86dSopKSnU1dWRnJzMhAkT7G6jwj/Y0u0d1MP3lJTl8ENVGs9pa9hraP2tyVwZtUpgdCd5As4l0u4B1JpTnvRHRY2ifWhPjFo931AOG99xgnGO56stR5AkGNQ+kPbBds5kyd2NUFVAhKBlZNyVaNUOFJxqIahUAgn1afe7ndkL+ATiAuKJcA/AjMT2jMXONqfZSc2vpM4i4qPXEuVv38yytNI01h1dh4hIB78O/7zR0H/eLwbclOuWMxEEgf71afdOSTduQKWCblfBVZ9DZF+wmmTn/tc7KCk9zId7PqTEWOI8++xAQ8p9u2DH3UMuzVrK/Svv5JuKQ6DWQewIh43tyrhE6v0ZCDIEMaz95SCo+KX8wD8aJBcJ5+zQv/TSS7z55pv88ccf6HQ63n77bQ4dOsTVV19NVFSUPWxUaMFkFtc79A5QCD4dYuZ61qnMHHTTUY3zU1Yvdi7pHIIgwJ6ccnLLa51tzmlRCSpu6XoLgmcI61UWUrLX2ko3WisHjlewNbMElQA39nfA+fzA7/SWtMyLuZqpnW6w/3gtlMQGh94V6uiRHZwBwT0JllRQkulsc5qd5Pr6+a4RPnbVjxAlkS/2fwHAsDbDaOd7QiqpUj/vUvSrj8pvzSpBFCXnGuMdBuNfhZFPycJ5xel8sOI+1uSs4esDXzvXtmbEaLZyrEy+R4h1UITeKlpZlbMKTFXESGpo00eZUDsBp6fen4UrOl2HzhBIoKTCuvVjOXvlIuGcHfqMjAwuvfRSQBaqq66uRhAEHnjgAT766KNmN1ChZZNZP7sa7USH/uDhJRQLInp3P3qF9HKaHQoywd7uNgdl+YH8My/sZNr5tmNEzDhUhgBSBascpTe77iTEhSBJEl9tyQJgdKcQ2thb86KmBDLXASB0UaLzZ6JHfR19an4lVXWuUbN+RecbeMfsydDS/FZ307T36D+CeHbBVAMrX2DDxldIL0vHXe3OdfHXNV6moX4+tJt9bFA4J7pG+GDQqSmrMZOS37S+0HZFEKDDGLkLi0rD9aWlqIyVbDy+kf1F+51tXbOQWVSNJIGvQYufh84hY+4s2EmpsRRvcx19RA1E9Tv7ShcZN/RrSxs/PWU1Zj5e71oTuiEeIcwfv4CZ6iDUpZlw2H5i7a7GOTv0fn5+tib3ERERJCfLF52ysjJqamqa1zqFFo3RbLVFYGOd5dDXlLCu7BAA/aNGoFM75qKgcGYaxPGWH8h3jsjQOXBt/LW8NPo9JnpEy73pdyxwtkl2YWd2KcnHKtCqBa7rZ//ofN3B31lNDXXBnSCwvd3Ha8kEe7kT7uuOKMHeo2XONgcAbXBnBK0B6iqh1LVu6i4Ek+Wf+vnu9hLE2/UVtenL+Tb1J6gt44r2V+Dn/o9iMxYTFNUrlysRepdAq1bRp75Li1PU7k9HQDvoNYNYSc2YqkqwWvg0+VMsomtM/F0IGYX19fNBjouQrziyAkSR4UYzWgRo09dhY7cUTky9X5NayM6jLjDBdQIeXmGQcK38ZPvnrVa49d802aFvcNyHDh3K8uXLAbk+/b777mPmzJlcd911jBo1yj5WKrRIsktqEB08u/pvTFkbSFKZQePO0NjxTrFB4WT6xPjja9BSVmPm5s+38dmGTFtqnavh4+ZDTEA8DLpPfmHfj1By2LlGNTOSJPHV5iMATOgWRqCnm30HFEU2HvyR+ZpantcZ7TtWKyExS
nb4XKWOHpUaQrpiRiIlY6mzrWk2UvMrMVslfA1a2vjpm3+AilzY9xO/q+soFUSCKwu5jH85LEWpYDXL3Te8I5rfBoXzol/sPw69JDk57f5EelwPQfFcXSfgXV3Msapj/JX5l7OtumBs9fMOUrgvqClgT+EeMFczyqIB3yi5vEHhJE5MvV+wNc/5XYv+TdcpZLkb+KMyAxrEDVs5TXbou3fvTr9+/ejWrRtTp04F4Mknn2T27Nnk5+czZcoUPv30U7sZqtDyaFC4bxvgvHZ129P+oBaJQI8Q4v3jnWaHQmO0ahX3jepAgKeOSqOFX3cdY9ZXO3hq0T42pBVhdoH2XCfRdgDHInuxlFpY/0arSjPef7yCjMJqdBoVU3tH2n08KXsLS035IKgZ0OFyu4/XGujhYnX0ADXBnfiPtpJn07+j1FjqbHOahNFsZfWhAspqTKd8/8R0e7vUzyfNB6sZH7926A2B3Gh1R7v2FTi2859lbPXzXS8qlWZXp1dbPzRqgeNlRo6WutAEtEoNwx/DU+3GDdV1YCxnYepCimtdKJPgPDjs4Aj9quxVSEh0E/SEooZIJd3+TNzQry0BHjrKjRb2uEjmWAPF1loeN4h8rTFyZMfFUUvfZId+7dq1dOnShblz59KpUyemT5/Oxo0beeyxx/j999+ZN28efn5+Z9+QwkVDg0Mf48B2I40wG1lfsg+AwVGjUAmts61LS6V3tD+fTu/DU5d2one0n00o75Ulh7jli20sSXYRFfx6CmsKecR6lM+1dWTm725Vs75/7DkOwMj4YHz09q9lT0/+hizBilbvx/DoMXYfrzXQvY0PKgFyy43kV7hGVoOhTW/CJBWSuZYtxzc725wm8daKNN5Ynsp/vt7Juvr2jCey71gZYKd0+9y9cHgNCCrGDX+B9yb9Qt+okXI0fumTUJQuL5ev9J93RQw6DQltfAEXS7sH8I+B3rcwVNQSV1mK0VTFL2n2ayNtbyxWkSMlchlvrAMcekmS2Ja3DYBR1fXlw4pDf0Z0GpWt+8PmDNf6PQToA+gfMxYENb/UHYe8Pc42ye402cMZMmQIn332Gbm5ubz77rtkZWUxbNgw4uLieOWVV8jLc62bbwXn09Cyzmn188e2M9gi0F3jzZC4K51jg8IZUasE+sUG8OzELnwyrTdX94nEz0NHWY2ZD9aku1QafpAhiH4Rg5EMgXymrkVK+p9c69rCKag02m5OL+3mgPTCyjyW5ss3TgPbjsJLZ+fWeK0Eg05Dx1D5s9qV7SLR8KBODEQPooVNR1Y425qzsuNICRvTiwCoqrPw2tIUXv77EOU1slKzySKSkifXg3ard9yaDVGkaOOb5GKF+AkQ0A5Pd2+EUc9AWAKYa+Dvh6Hi+AkOvVI/72r0t6Xdu2B7uO7XoArpxq1mLZeZ4MYW3Dkku6QGi1XCoFMT4m3nEjDkzh1zh8zl7nZT6F1TDRo3+XepcEYafg9JmaVYnd394V9MjpsKbp4kqczkpf7pbHPszjmHLD08PLj55ptZu3YtqampTJ06lffff5+oqCguv1xJnVSQkSTJFqF3msJ91kYGiTqejL2KNt72TyNWuDCCvd25qX9bPpvem15t/ZAk+Hn7UWeb1YgbOt2Au2coqRoVa0wFkL7c2SZdMH/vy0OUoFsbH4f8Viv2fs9mlQm0BsbGX2338VoTPSLlLLhdrlJHr9HRP6ArKiC16AAFNQXOtui01FmsfLgmA4DLuodxXd8oVCqBjelF3P3dTjZlFHEorwKzVcLPQ0e4j3uzjl9+8HderNjDs25GsuPH/vOGRgdjX5LFzWpK4Pd75f8qDQR1bFYbFC6cBmG81PxKSqpdbEJXpYLhjxGtNnBTSRH61GXOtui8aaifjw3ysGvryBPRqXUMqbPKYnhhPeTfpsIZ6Rrug4dOTWWdmf3Hy51tTiMivSNJDE5EAv6uFztszVxQDnL79u154okneOqpp/Dy8mLx4sXNZZdCC6ewso4ak
xW1SrCPsNDZEEXI3iQ/jh7k+PEVzhuNWsW1feUJmFUpBRS4SHoxyGlck+OmgN6P+ZpaXtj+OttytyJKLfNCUWexsnS/nF01sXu4/Qc01bA29VcsQGxgl8Z9txXOSmJ9+7q9OeXO74Vdj19EbzqLGjDXsNmF0+5/3JZDfkUdAZ46pg2I5vp+UcybmkCUv4GyGjNz/zrEWyvSAOjezPXz1TXFzNnxKrmCiNa7DQavf/3W3DzlvuJeoVBdXwYQGCdHCRVcigBPN+JC5EyZrZmulWYMgG8k9L1dfpw0H6n8GGXGMqeadD4cLnJc/bzZav5H5DAnSf6vtKtrEmqVQM828nfkamn3ABO6TgNBzRqxguqjSc42x66ct0O/bt06ZsyYQWhoKA8//DCTJ09m48aNzWmbQgumITof6W9Aq3Z87XrekbX8XpdPsZsBQpW0qZZGfKg33dv4IIoSv+w65mxzGjEhdgLDYiegElQkm0t5L+llai2uUxpwLqxLLaLSaCHYy41+Mf72H/DQYvKstaDWMSp+qv3Ha2XEhXih16mpqrPYWjo5nbAeDBS1YKll43HXvAfIKalh4U75PHL70Fj0OjUA7YM9efOaHkzt3QaVIE9Eg5yt0lwYLUZeWX43R6w1+KjdeWrU2wTqA09e0CMQJrwO7vVjK+n2LotLp90DdJkMYQnkmat56s8beXbtI1hF124P+29OjNDbm4VpC3lgzQMkZa+DPFl3SWlX13R6RcoTXJsPF7vMRHMD3YITaeMRihGJ1fu/cbY5duWcPK3jx4/z0ksvERcXx/Dhw0lPT+edd97h+PHjfPzxx/Tv399ediq0MGyCeE5SuF+X8jPfaIx87KUHtcYpNihcGNf0kaP0y/bnUepCqY1alZY7e9/Pu7HXMsnqxniLBg+tfNMhSRK/pP1CtbnayVaeHUmSbGJ4E7qFoVLZOa1RtELyQmZa9bzT/R4GtRli3/FaIWqVQEK9s7nLVdTugzvTV9Khtpo5UppBXrVr6elIksQHa9KxihJ9ov0ZUC/i1IBOo2LagGhevSqBSH89ep2a3m2bR+DXbDUzb/OLpJSm4CEJPJl4P2E+bU+/gm8kXPYmdLkCul/TLDYoND8NQmB7jpZRY3LBHtcqFQx/HB+dJ/l1ZeQd38b69S+0mH7coiid0LLOvhF6i2hhdc5qcqtzoTgNRIvcKtJXKdNsKl1CPdBr1RRXmUh3lYnmegRBYELsZfhJKvSFqa067b7JDv348eNp27Yt7777LldeeSUHDx5kw4YN3HzzzXh4OKlGWsFlsTn0DuofeiKSKLK+cBcAQ9qOcvj4Cs1Dtwgf4kO9MFslFu12rSg9QGCPm7heNHBtcT4Uy7W5qaWp/JDyA5/uc/0WnvuPV5BZJLequ6RLiP0HzFoPlbng7k1I16vRa5xQitMKSKhvX3cgt8K5hjSgM+AV1InbLXpei5lCqEeosy1qxOqUApKPVaDTqLhjWOxpU+k7hnrx3nU9+erWvgR4Nk+q+8f7Pmbv0fW4SxKP+nSnbdcmOOkB7WDwA+AZ1Cw2KDQ/bfz0hPu6Y7FK7DjiIgKV/8Y7DP2Uz5jkEw9ILMz4Hcsvt0NhqrMtOyu5FUZqzVa0aoE2fvYNCu3I30FZXRk+Oh96VdZnXEQq0flzQatW0at+EnRTveioKzG0+3TeE0IZVVv3j+BoK6TJDr1Wq+Xnn3/m6NGjvPLKK3TsqIi1KJyerGLntaxLzV5LgaUGd1T07nKdw8dXaB4EQeDq+ij93/vyqDSanWzRv/AOg5ih8uN9PwPYUu83H9/s0gJhAH/s/adVnZe7/VvV1e75nlJE6HyFUht8ATRErA4XuVAWSFgCw0UdUWXHnW1JIyqNZj7dkAnAdX2jCPE+s9CdSiXgplE3y9hGi5F9x7eAsZzZZgMdBz+q9JRvJQiCYIvSJ7lq2j2AdxhjLv8MH792FKgE1pUdhF/vgC3zwew62jT/pqH/fHSAB2o7Z44tPyIL2
46IHIHm6A75RcWhP2cG1JehbD5c/I8egYug1RrQRA+Wnxxe41Rb7EmTHfrff/+dSZMmoVY3z8VOofViNFvJLZcvFjEBjo/Qr0uRnat+nlG46R1QF6xgN3q39SMm0INas5U/9uQ625yT6V6v0p6+HGpK6BHcg+6B3RER+fOw67ZJKaysY0uGA1vV5SWztmg3d+qq+ErbMtI+XZXoAA8EAUqrTZTVuEgpSmh3+X+ua/X6/WJjFhW1FqL8DVzRwwGijyfgrnHnPWsAL5g9SGg3DkI6O3R8BfvS4NBvzSpxvcnmE3DX6pnU/Rbwj+EXby/MkhX2fAc/3wJF6c427yQkSWJtiiwM2S7YvgGh41XH2Ve0DwGBUb7xcgaZWgvhiXYdtzXSu60/GrXA8TIjOSUuqCkUOxwrElsO/01O+RFnW2MXHK9WptDqySquRpLA16DFx2D/yN+JmK1mthTKN5VDokY7dGyF5kcQBK7uLUfp/9hznFqTiwn7hHSBkK5gNcP+XwG4vJ3cvnN19moqTC6SFv0v/k52bKs6ae+PrFKbEd29CfSJsvt4rRm9Tk1ofaQ501Wi9KHdAMgoz+Sdra+yMHWhkw2CA8crWHYgH4A7R7RD42hxVrMRde5u4iQN9L7ZsWMr2J2OIV5E+RuoNVn5crNrOwhj2o7BVx9AoYcfa3tdDR5BUHEMlj0Fphpnm9eIVYcKSMosQa0SGN/VvuU7y7Lktn6JwYkEF9ZPboR2B61SDnau6HVqEuvbqm7KcL20eyJ686VO5E1LLr/t/czZ1tgFxaFXaHay6m8yY53Qf/7bfZ9QZa7CT1LRpbOiot0aGNgugAhfPVV1Fv7a54pR+vrj7MAiMBvpGtiVWJ9YTKKJpVlLnWraqTBZRJbulx0dh7Sqq8jl8JHVHBGsaD2CGBKhiOFdKDH151aXcejdvcE/lkJBZGPOGlbnrHZq2qVVlHh/jXyDPqZzCF3Cm0+1vinUmGsQi1JAEsEQIItsKbQqVCqB/wyX224uSc7joKtoWpwCnVrHFe2vAGCbtRKmfg5eYXJEevN7zjXuBPIrjPxv7WEAru8XRawdBfFqLbWszlkNwLiYcXB0q/xGpNKu7nwZ0E7OWtnkgu3r0OgYGj4QgM3HN1BidOFSmfNEcegVmp2G2k5HRP7+TduSo6iBm7w6oPJygNCXgt1RqQSm9m4DwKLdx6izuFiUPnqofHNkrIC0ZQiCYIvSL8lcgtHiWrWKW45UUFlndlyruuSFrBLqQOdB3zZD8dQ5XlejtdHg0Ge5ikMPEJZAoqjB3WKmsLaQtLI0p5lyMLeC7OIaPNzUzBgU7fDxvz30LXdufpZNKjMEd1Jq51spXSN8GNUpGIAP12RgdbGWXScyOmo09/e8n0f7PgpuXjD8Mfm4PLQYspzfblIUJd5akUqt2Up8qBdX9Wxj1/Hc1e480vcRxkaPpZtvHBzfLb+h1M+fN31j/FEJ8kRzfoVr3fcAtOs4iY6iGouxnOWZy5xtTrOjOPQKzU7DTWaMox362lKGZ+3kLZMXg3rd6dixFezKsLggQrzdKKsxs7w+jdZlUKmg21Xy430/gijSL6wfIYYQ3DXuLtXGS5IkVqTKM9MOaVVXV0XtoT/YoDaD3o9RSteJZqHh3OpawnjdcUOgt1W+rVh5ZKXTTEnNrwSgextfvB0g+HgiZquZzcc3U2oswUsSICjeoeMrOJabB8bg6aYhs6iaP/e6lijkiWjVWgaED0Al1N/2h/f4pzXiuteg1rlq/Yt2HyP5WAXuWhWzL4mz+7VJEAS6BHThlq63oMrdC1YTeAaDX7Rdx23N+Oi1dImQs6E2u2KUvk0fLhW8QbSwPOM3TFYX0aBpJhSHXqFZEUWJrCK5JivWQQr32/K2UV5XDru/BXMNwYHx/6iPK7QKNGoVk+tn7BfuOIrZ6mK9RDtOAJ0HlOVAThIqQcUT/Z7g7RFvE+0TDRW5sPcn+RgVnZdhcCC3guzSOtzUDmpVd
2gxW6wVGNU6Qn3b0dlfEQZrDhoc+pzSWkwWF/kthCYAcEl1NYgia46uIaUkxSmmpObLKtkd7CyqdSp2FOygylyFv8VMF0kNwcox35rxMWhtWSDfbMmmqKrOuQY1AZPVRFZ5FvS+FfxjZGd+3evgpDKZzKJqvtoi6xDcNiSWMB8H17DnJMn/I/sp2TQXyEBb2r0L1tFrdPSOGk6QpKKyKp/1x9Y726JmRXHoFZqVgso6as1WNGqBcN8ztwhqDrbmbuWN7W/w7LrHqNj/i/xin9uUk3IrZHSnEPw8dBRVmVh1yMVawukM0Gmi/HjfjyBJhNbVoNn1Nfx8K3x3rVyrmPQ/2al3Eov3ydkCwzsG2b9VnWiF5J9ZpzKDwY+RUSNP2wNc4dwI8nLDw02NKEocLXURUSuPAPBpQ0dRzXCfOAA+2fcJFtHxXQ3S6iP0HUO9HD722py1IFoZahJRIUBQnMNtUHAsYzqFEB/qRa3ZysfrDjvbnDOSU5HDvavuZe7WuZgEYMRToNJA1gZIXeJwe0wWkTeWp2KxSvSN8eeSznaYaJYkebK9fjL9g90fsGD/Aopr66PINodeSbe/UBq6PxzKq6Sk2vUi4Op2Ixlr1UFdJX8d/svlWuxdCIpDr9CsHC6SIyNR/ga7qwrvK9zH27veRkSkY20VXlYzhHZVTsqtFJ1GxZSesrjUl5uzKK91sVZBXSaDoIJjO+H76+Gnm2H751iK00hSWZAC62/sd3zhlHZBZTUmWxrceEe0qju8BqoKeFATzs2JdzOszTD7j3mRIAiC6wnjAYTJUfobNMF4ab3Irsxm/VHHRkHKakwUVNYhCNDewRH68rpydhfsBouRYVYt+EbK9coKrRqVSuDOEe1RCbIg2PYs1xXcCvUMRaPSUFZXxsNrH+bVzF9Y0LYLS1R17N74Gtbyow6155ukI2QVVeOj13LPyPb2mfRNmg8/3AjfXk3BhjdYf2Qlf2X+hdFqhIrjUH4UVGqI6NX8Y19kBHq60SHEE0mCpMMumHYf2ZeRah/cRQsGi8llOxGdD4pDr9CsNKTb27t+PrU0lde3v45FtNDPvwt35B9FQIA+M5XofCtmQrcwogIMVNRa+GS9i0VCvEKg3Qj5ccVxUOsQowbySHgb3ggJY8/gOyF6MIgWWPOS3OrOgSw7kI9VkmgXqLd/B4qqQti5AADPrpMZ1+4yfN197TvmRYYrO/TehSnM6DqDm7vczLBIx07kpOTJ0flIPwMGncahY284tgERkfZqD8JRQ1Anh46v4DxiAj2Y1EOecJ6/NgOj2cXEW+vRqrRc3fFqAPJq8tiRv4O/6vL5XC/wmlCCsOYVEB1TxpN8rJxfdx0D4O6R7fE16Jp/kKPbYc/38uPqIpYf+gGxJINulaVE5OySJ55Bbj+rc7yQc2tkYLtAADa7okOvccMjaiDzTJ684NkZHzfHdkCxJ4pDr9CsZBXbXxDPIlp4Y8cbGK1Gugd25x6jgEoUoU0fWehFodWiVau4b1QHVAKsSSl0vUjIgLsh8UYY/SxM+w3V+LkkxI4FlZrfD/8BQx6UW3wVZ9gcXkcgihJLk+V0+5Ed/Ow7WNZG+PlmpNIsOTrZeZJ9x7tIianXKHEphz60u/y/8BCDg3sxLmbcPyJcDiK1QM4SiwtxQrr90bUADBXrHZNgxaG/mLiubxSBnjryK+r4aXuOs805LUPbDOX9Ue/zeN/HuaXrLUyInUCv6DF0xx1V3l7Y9xOAXaOXhZV1vLk8FUmSy+kaUrWbFWM5rJkrP+40EdOop1ll0AMC4yqrZDHApP/J7yvt6pqNhvZ1e46WU2l0sUxKgNjhBKKSJ3NaUcq9Y6evFVo9hwvt79AnFyVTaizFW+fNg+2moP11lvxGn9vsNqaC6xAX4sXEhHB+232c91en88ENvdDr1M42S8bgD31nNnrp0thLWZq1lP3F+8kwldJu8GxY8Rzs+gbaDibb3cDeor3E+8XT3q+9XczamV1KQ
WUdXm4a+kTaydGxmGDr/2DfzwBs8g1hcWA4l5UeZKB+oH3GvIiJCTQAskMvSZJr6BN4hcpK0VUFkH8A2sgprCarifyafCK9Iu1uQmp9hD4uxPGCeLO6z2Ld0bUM2iH/BhSH/uJCr1Nz+9B2vPTXQRbuPMbwjsFE+hucbdYpCdQHEqgPbPyif29YPw+2fcwKrcTX2Ut4vO/jdPTv2Kxjp+VX8t8/D1BWYybE242ZQ2POuLzVasVsPkfHUJJg43ywqiCoOyTOJKlwNzq/TrTXetHZvw/GrHVQmQcIENYPjK7Xas3VEUURs9mM0WhEpZInbwPcBbqF6cktM7LzcAH97DFZcyEE9QCPSDDXwfGDEBDrVHO0Wi1q9YXfwyoOvUKzUWOy2HpP2rMH/cbjcs/UAeEDcN/1jXzijhkCwUp7oIuFG/u3ZcvhYvIr6liwOYtZw9o526TTEqgPZFDEINYdXcf3h74nMTiRftEDCcjaBGteYk/iZL5O+Q5PrSfvjnwXg7b+BlCSIGMVHNkoqxH7RJy3DX/XR+dHxAej09ghYlqWAyv/C0Wp8vNuU1khFZJReohjVceafzwFovw9UAlQabRQXG0i0NPN2SbJ5U6h3SF9BRzfCW16cazqGK9sfQWLaGHe8HnoNfZTsBZFibQC2aHv4IQIfaxvLLEqd9j4hSw05u+65yUF+9A/1p8+0f5syyrh/dXpvHRlN/u3B20uOk2EzHWIR7ey5eCP1LppeSnpJZ7s/yRxfs0j7rgpvYh5y1MxWUTaBhh45rLOpy2NkSSJvLw8ysrKzn0gcy149wXvfvJE+9HjaI1argu5Dg+NB0e0BujU9Z/St1IzlGae/45dpEiShCiKVFZWNppUvrK9jpo6FdraIjIzXbBOvctdYDZCYRVUOP979/X1JTQ09IIm5hWHXqHZOFIs188HeOrs2vt3eufpdPLvRDtRBZkL5JvI3rfabTwF18Ndq+bukR14elEyf+3LZUiHQLqEu24t1OWxl7Pu6Dr2Fu1lb9FePDvfzNCCQ1B6hPjjyQBUmatYnLmYqXFToSQTNr4Nx3fJGxDUMPLJ8xq7oNJoK00Y3yUUzJXNsk82UpfBhjfBXCOXEwx/nK06DQd2zEMtqBneZnjzjqcAyCKRbfwMZJfUkFVU7RoOPUBUf9mh3/8rdJ1CoD4QSZIoNhazMHUhN3a+0W5D51YYqa6zolULRAc4KTJacFD+H9AeNHaoCVZwaQRBYNawWPYdK2P/8QqWHchjXFcHiJA2B4IAfW5DdXQbDxWX8ErHfhyozJKd+n5P0s7n/CeoJEli4c5jLNiUBUCvtn48Mq7jGXUuGpz54OBgDAZD050dqwkqjoHkDno/0PtRZ6lDXaNGEATCPcJRq1wkq6+FI0kSFosFjUbT6PsxW0SOl9ciSRDkpcPDzc5ddc6VuiqoLgCNO3iHO80MSZKoqamhoEDu3BQWdv7nihbj0JeUlHDPPffwxx9/oFKpmDJlCm+//TaenqdPqxs+fDhr165t9Nodd9zB/Pnz7W3uRUlDLWd0gH2FRTx1noyMGgl/PSK/0H603EtV4aKiR6QvozuFsOJgPu+uTOed6xLtE31uBiK9IxkfPZ6teVuJ9onG2zMUhjwES5+gQ8oKHhh0B29mLuLP9N8Zl5+N18E/QRJl1XxJhJwtslCR6tz3b+n+fEQJurfxIcJPT0FBMzr0yQth4zvy47AEGPk01W4GPlvzIACXt7ucIENQ842n0IjoQNmhP1xUTe9of2ebI9NulFx2UXgItnyI28gnuaXbLby89WUWH17MkDZDaOvd1i5DN6TbtwvytHuXlRPZU7iHTcc3MTJyJB0bHHol3f6iJdjbnRv7t+WT9Zl8tjGLPtH+BLjKhNvZCI6HyH645yTxqOjLy/6dOFhykDlJc3ii7xN4433OmzRbRT5YncGKg/kAXNo9jJlDYlGfIXPBarXanPmAgHNI2ZYkKM8HjQBaD/AJBUFAI2oIUAcgI
eGhV8TvmovTOfTuQCBqSqpNVJpV+Hm5uVamik4HBg/ZoXdyuZpeL2etFRQUEBwcfN7p965593sKbrjhBvbv38/y5cv5888/WbduHbfffvtZ15s5cya5ubm2v1dffdUB1l6cHDgup9XEBjngZJm7V+4dKqig1wz7j6fgktwyOBpfg5ZjZbX8sC3b2eackRldZ/DB6A94pM8j9AjuAdGDoON4kCT67vmdtugwFh7kj5QfZSc+ejBc85UsLGesgIID5zymxSqybL+cbj++uaNEVjPs+lp+nHAdXPYWeAbx7cFvKa0rJcwjjCkdpjTvmAqNaBDGy3IlYTyVCgY/IN8kpS2D47vlMpPQfoiIfLLvE0TJPiraKU7qP78yeyVrctawJXcLFCoOvQJM7B5OhxBPak1W/ufivelPouc0ANzTl/Nopxl08u9EraWWuVvncqTqyDltqtJo5tnf97PiYD4qAW4fGsusYe3O6MwDtpp5g+EcM21qSuR0e0Ela3rUO2salYYAfcDJugEKdsPPoEOrVmG2ipTWuFhPepUKtHqnO/MNNBzn56wVcQItwqE/ePAgS5Ys4ZNPPqFfv34MHjyYd999l++//57jx4+fcV2DwUBoaKjtz9v73GcXFc5OrcnKlvoWFX1j7BMpqjJV8czGZ/j78F+IWz+WX4y/FHza2GU8BdfHy13Lf+rr53/eeYzDhVVOtugcGXA3eAShqjzO1cczQLSwRAflo5+BsXPkY7tNH3nZ7M3nvPmkzBLKasz4GrT0i23m3+XhtfLNk0egLEipUnGg+AArslcAcHv329GqXSzNrpVxojCeSxEcD/GXyY83vgVWC9O7TMdd7U5qaSqrs1fbZdjU/AZBPMc59FWmKnbk7wBgWPhgKEqT3whSNF0uZlQqgXtHdkClEticUcym9CJnm9R0QrvKPdlFK/rkhTza91Hi/eOpMdeQWp7a5M1IksQLfx5g39Fy9Fo1T1/WmYkJ55befE41xeZaqKn/nD1DQK2z2aHgeFQqwVYKVlpjxmRxTDvElkhziNq2iJT7zZs34+vrS+/evW2vjR49GpVKRVJSEldeeeVp1/3mm2/4+uuvCQ0NZeLEiTz99NNnnPGrq6ujrq7O9ryiQo46i6KI6KDenOeDKIo2cQpnsDG9EKPFSriPng5BHnaxY8vxLaSUpFBbcZRxeRlIai1Sjxsd1jO1peHsY8JR9I/1Z0CsP5sOF/POyjReu6r7WWf/XQatAYY+grDkUXqqDMT6R3JYJfG3uZCrG763yP4IGavgyCakc9SK+GtfLhISozsFoxaa95gQ6lsbSZ0ul6Mhosi23G0gwaioUcT7xbf6Y8/ZtPU3ICFxrKyWWpMZN835perZ5VzR+1aEw2ugJBMpeSF+3aYyNW4qXx34im8OfkPP4J7N2gPYZBHJLKxGQqK9na5Bp2LjsY1YrBaivKOIsliRLHWgNSB5R7Toa9PFcv2wJ1H+eqYkhvPjjqN8uDaDrhHeeLq1iNtuSLwJ4dgOOPQ3bj1u5JHej7AuZx0J+oQmHxNbDhdzILcCN42auZO7EhPY9N9lw/HX8HdWJAkqcwFJzmpz8wZJos5aR1FtEUH6IHRqRdPCHjR8P6f6njzc1BjcNNTUWSisNBLuaz9R1JZMw3F+Kl+zqb+ZFnFmycvLIzg4uNFrGo0Gf39/8vLyTrve9ddfT9u2bQkPD2fv3r08+uijpKSk8Msvv5x2nblz5/L888+f9HphYSFGF25pIYoi5eXlSJJkax3hSP7anY3FbKFXuDuFhYV2GWNFxgrMFjN9S3OxWMzUxoynpgaoKbDLeC0dZx8TjmRyZ2+2ZxZy6HgZP29JZUR7O/dab050UaiHv4ao82R8XQHHa44zyHuQTSRFcI/Fz2JFKEihNOsAoqFpKYN5FSZ2ZBYhCNArRENBQUGzHROaklR8cvchqbSUBvRDqrd1bOBYQlWhdPDuYLNfwX5IkoReLVFptLA77SgxAed3s2Svc
4Vbx2vw3Pkh0paPKPXuRqIhkWW6ZfjqfMkryKPOre7sG2kih4trqa0z4emmRjCWU1DnGGXlv9P+xmwxk+iVSHnaFjwtZsx+bakobEER2VNwMV0/7MmItm6sOqAir7yG95fv5+a+LUQgTx2Gt28c2qL91G78mJoet5GgT6C8vBwAk2QirSKNbn7dTrm6KEl8ujYTi9nCuDgfPMRqCgqanklkNpsRRRGLxYLFYjnr8oKpCpWlDgQNVrcAsFiQJIn82nzMopmS2hIC3ZV0++ZGkiSsVitw+iizn7uKmjqorrNQXlOHh6u0GXYhLBYLoihSXFyMVts4s7Gysmm6R0516B977DFeeeWVMy5z8ODB897+iTX23bp1IywsjFGjRpGRkUG7dqdW63z88ceZPXu27XlFRQWRkZEEBQW5dLq+KIoIgkBQUJDDL75FVXWkFZvQaDVc3juWYG/3Zh+juLaYrNostNZahtVWodH74Dn4DjzdXVfZ3Nk485hwNMHAtEHwyYZMlqRWcEXf9ucdrXQK9ROWQZyqNVAwQkQC5CcTWJsO0Z2btMk/UjPRaDX0butH5xi55V1zHRNC8keg0UKHsQRFNbZ5TPCY896uwrkTF1bEnqNllEvuJ018NxW7nSuCrkPIXQ+FKQQfXog04knmDJ+DQXMOitVNZFteLhqthq5t/AgJCWnWbZ+OrPIscuty0ev0TOg8Ae8t/0PQaFG3TcT9PL8LV+Fiun7YmwfH6Xn812Q2HqliQqIb3SJayH3LwDsQ/noQr6Nr8Bx8B6J7IIIg4Bvgy+s7Xie5KJlpnacxPmb8SauuOlRAfo2VmeJPTMo7jqr/++DV9MkMo9FIZWUlGo0GjaYJroqxTq6HdvdGo5PTvEuMJVgkCxqVhmCPYFSC/Y/jm2++mbKyMn799Ve7j3UiX3zxBQ888AClpaVnXM5qtfLaa6+xYMECjhw5gl6vp0OHDtx2223cdttt5z3+v53QE9FowN8DSqpNlNZa8dLrULlI7bqroNFoUKlUBAQE4O7e2If69/PTbsMehjWVBx98kBkzZpxxmdjYWEJDQ0+K9lgsFkpKSggNDW3yeP369QMgPT39tA69m5sbbm4nq5GqVCqXv6gJguAUO9enybXzXcN9CPO1T6ugpPwkJEkkvraaYFSQcC2CoQVFYZ2Es44JZzChWzi/78mlsLKOv5PzmdyzZWsrWEUrddY6uS992wGQn4yQkwRdrjjrunUWK6sOFSIgcGn38Ebf/wUfE9VFkLlGftztKkQkfkr9ifEx45s1hVqhacQGebL3aDlZxTUXNkljl3OFCgbPhkWzIH0FQqeJeIX3aLSEJEnN4tynFVQhIBAX6u2w893qo6tBgL5hffF195WV/QEhuPN5daRwNS6m64c96dbGj/Fdw1iSnMcHazJ457rEljHh3KYXhHSVrz17f4D+dyIIAlq1lijvKJKLk/ny4JdUW6qZGjfV9js2W0W+35ZDrOUwI9S70Ri1soDq8EebPLRKpUIQBNvfGZEkMNVH/3UeIAgYLUbK6+RsgiBDkMPb1DX3hGVTxzvbuP/973/53//+x3vvvUfv3r2pqKhg+/btlJaWnpfNJ56/z7S+n0FHpdFSL5Bndp02qy5Cw3F+qvNtU8+/Tj1LBwUFER8ff8Y/nU7HgAEDKCsrY8eOHbZ1V61ahSiKNie9KezevRu4sD5/Co2RJIlVh+TJlhHx9otIbDy2EYzlDDSaQe8L3ababSyFlolOo+K6vlEA/LT9KNV1Z0/Tc1WSi5J5cO2DfHXgK/mFqAHy/2M7wHL2NOVN6cVU1VkI8nKjV1QzT3wd+A1EK4R2g6A4/sr8i1/Tf+XpjU9jFa3NO5bCWYkNlLuKuJTS/YmcQiAPoMJUwYd7PmRR+qJmGeYfhfvTt7JtbsI9wwnWBzMyciSYaqA0S35DUbhX+BczBkbj76HjeJmRH7blONucpiEI0Gu6/Pjg71ArR39VgoppnadxTcdrAFiYt
pDPkj+zda9Yuj+P/HIjV1r+xldfH7lNXQIVZxaxPhuSJGE0W0/+M9ZirKvDaAaj4EaNyczRinyMZisawYAa91Ov18S/CxHVGz58OPfeey+PPPII/v7+hIaG8txzzzVaRhAEPvzwQ8aPH49eryc2Npaff/7Z9v6aNWsQBIGysjLba7t370YQBLKyslizZg0333wz5eXlNsfw32M08Pvvv3PnnXcydepUYmJiSEhI4NZbb+Whhx6yLSOKInPnziUmJga9Xk9CQkIjewD++usv4uLiMBgMjBkzhi+++KKRjc899xw9evSwLa9SCfzwxXxG9O5KWY0Zk0W+V/jkk0/o1KkT7u7uxMfH88EHH9jWycrKQhAEfvnlF0aMGIHBYCAhIYHNmxsLBG/cuJHhw4djMBjw8/Nj7NixtkyFpuxLa6FF1NB36tSJcePGMXPmTObPn4/ZbObuu+/m2muvJTxcVsw8duwYo0aN4ssvv6Rv375kZGTw7bffMmHCBAICAti7dy8PPPAAQ4cOpXv37k7eo9bD4aJqsktq0KoFBrW3T33S8arjHC7LQFVTTH9RL7dU0dknE0ChZTMyPphfdx0lp6SWX3Yd46b+9ul3bW+0Ki251bnkV+czqf0kQv1jwTMYqgrg+C6I6n/adSVJ4q99uQCM6xLavL1fLSb5xg6g6xTyq/P5MeVHAK5of4XDoyAKEFPv0GcWVTdbtLvZ6TsTMtdCSSbs/xW6TyW5MJk1OWvQqrQMCB9AqEfTs+3+TaXRzPEyWeOmfbDjFO7Hx4xnbPRYBATI3SO3m/QIkjs/KCicgIebhv8Mb8ecxQdZuOMoPaP86NoSUu/b9JEnqAoOwt4fIUYWoRYEgckdJuOp9eSz5M9YdmQZJtHE9E638cO2HOIsKSRqs1FpDBDQXl5/19cw7JHzNqXOIjJ1/im6vYgW+U9QgboIi2jBKskOo06tQyDjvMcE+GnWANy1539tW7BgAbNnzyYpKYnNmzczY8YMBg0axJgx/5SnPf3007z88su8/fbbfPXVV1x77bXs27ePTp3OPjk4cOBA3nrrLZ555hlSUlIA8PQ89cRmaGgoq1at4s477yQoKOiUy8ydO5evv/6a+fPn06FDB9atW8eNN95IUFAQw4YNIycnh8mTJ3PXXXcxc+ZMkpKSePTRs2df6NQqBEG+RymsNLF68UKeeeYZ3nvvPRITE9m1axczZ87Ew8OD6dOn29Z78sknef311+nQoQNPPvkk1113Henp6Wg0Gnbv3s2oUaO45ZZbePvtt9FoNKxevdpW13+2fWlNtJg8qm+++Yb4+HhGjRrFhAkTGDx4MB999JHtfbPZTEpKCjU1NQDodDpWrFjBJZdcQnx8PA8++CBTpkzhjz/+cNYutEpW10fn+8UG2E291Spa6a32oqdFwMcrAuIn2mUchZaPWiVwYz/Zif999zHKXK33aRPp6N+RHkE9EBH5OfVnOVLS4MT/q32dySKy/3g5C3cc5YU/D3Djp0kcyqtEpRIY07mZa4kPr4baMvAIoiS0Cy9vfRmTaKJLQBdGRI5o3rEUmkQbPz0atUCNyUpBZfOJzDUr7j7Q9w758fbPYPd3DPDtSLfAbphFM5/u+/SComBpBXK7ylAfd3z0jm2VqBLk1OCGdHuClXZ1Cqemf2wAQzoEIkrw3z8OkJbfNLErpyIItr70woFFCP8Sm7wk+hLuSbwHFSrW5Kzh2dUfUVZtYoq4BG+9FjpPktuzQn2UPrf5bazPDEBQARIS8nOtSiNPtjmZ7t278+yzz9KhQwemTZtG7969WblyZaNlpk6dym233UZcXBwvvPACvXv35t13323S9nU6HT4+PgiCYGvRfTqH/o033qCwsJDQ0FC6d+/OrFmz+Pvvv23v19XV8dJLL/HZZ58xduxYYmNjmTFjBjfeeCP/+9//APjwww9p164d8+bNo2PHjlx//fWNHPDTIQgCapWcQVBjsvDss88xb948Jk+eTExMD
JMnT+aBBx6wjdPAQw89xKWXXkpcXBzPP/88R44cIT09HYBXX32V3r1788EHH5CQkECXLl24++67CQwMbNK+tCZaRIQewN/fn2+//fa070dHRze6IYiMjGTt2rWOMM01KMtGMFmQ5cEcg1WUWJsqK9qP6Gi/cSN1PjxckI9k0UPvW0CjtB5ROD0D2gXQIdiTtIIqftiWwx3DTq2X4epc3fFqdhfuZuOxjXjrvNGojPQWLMQd2QyD7ue35IP8sn8bZcWxiGLjmxatWuDKnm3w82jG34okQfJCAIrixvBC0hzyavLwd/dnVsIs14wMXwRo1Coi/QxkFlWTWVRNiB1ESZuFjhMgbSnk7oWk+QjbPubWiAQeNhvZW7iXzcc3MzBi4HltOs3Wf94x6fZ51XkcLj9Mn5A+aNX1EwgF9QK+QUq6vcLpuW90B0przCQfK+eZ3/Yzd3I3ouuzbFyWqAEQ2AGK0nBP+wMiH2j09qCIQVglK18kf0l6dhhdLPvpqs1FpfGEHjeAwR/a9Iaj22H31zD04fMyw02j4qdZAxq/KFqhOAOQwC8GNDokSaLWUivrzzQDbpoLi33+Oys4LCzsJF2wAQMGnPS8oUy4OencuTPJycns2LGDjRs3sm7dOiZOnMiMGTP45JNPSE9Pp6amplH2AIDJZCIxMRGQxcr/Xe78b/vPhLe7htyiMg4fzuDWW29l5syZtvcsFgs+Po0zV078/BpKpgsKCoiPj2f37t1MnXrqEtym7EtrosU49ApnRlg9h5+Ld1Ki98LDIwiDVzgGnygM3pF4uHmREJRAuGd4s465K7uUshozPnotPaN8m3Xbjdj7A9RVIvjFQPvR9htHoVUgCALTBkbz9KJk/k7O44rECNd1cs5AO9929A7pzfb87SzOXAySRJBaRVxVPnWFGXy0aTd52h9x13YkWj2ZrmEBxId50SnMm9hAT3QXeBNyEvn7oTCFIrWa58t3UVBXSrA+mKcHPE2woWUrerd0YgI9bA59/9gAZ5tzalQqGP8apC+HQ39BwQHCcnYySW3kZ62VBZvn0GP0hxh8o85506n5coQ+LsQx6fbLspaxOHMxgyMGc0/iPfKLSoReoQm4adQ8c1lnnv4tmZS8Sp7+LZmXp3QnwpX7cwsC9JwOy55Cn/Y7xA+XBfNOYGiboaQeCeL32kKuEr/Cy0MDXafIzjxArxmyQ5/yNyTeBF7nXmIjCMLJqe91NaAVQO0G+n8+Q73OcaU3Z+PfCvCCIDS5tzj8I4p2YtDSbDaftz0qlYo+ffrQp08f7r//fr7++mtuuukmnnzySaqq5HPp4sWLiYiIaLTeqQTDzzTGv7OuGmx216mpqZE1Xz7++OOTJgfU6sbf8YmfX0PgoOHz0+tP/7tprn1pKbSYlHuFMyCKYKkjVS2SbCkjqTyN1UfXsnj/V/y05WW+2Phfnl15H1WmqmYdtkEMb2hcIBq1fQ6lvdnrydv3g/yk7+2tQjlYwf70iPQlIdIHqyjxbVL2aZers1jZk1OG2dr0i6sjmdltJle2v5JLYy5lQuylRAV3BeDIzhVYLCq0ag0hITlExS3mjpEhXJnYhvhQ7+Z35gGSZSEZ93Yj0bt5E2II4dmBzyrOvAsQG/RPHb1Lo3WHThPhyg9h6ufQbSpXaIMJs4qUVR7j+9+ng/Hc+sdLkkSqLUJv/5t4s9XMuqPrABgUPkh+saYEKvNkxyewo91tUGjZ6HVqnp3YmZhAD8pqzDz16z4KKozONuvMRA+G6CEIohlh2ZNQmNro7dJqE0v2ltLDvJt4twLSdFpWBJzgtId2g4heckR919fNZ1e9un2dxp2i2iKbMF9LY8uWLSc9b6ifb6h1z839p1zh39F7nU5nqxs/Vzp3llvhVldX07lzZ9zc3MjOzqZ9+/aN/iIjIwFZ12zr1q1ntD8oKIi8vLxGTn2DzXqtmsCgYEJCw8jIyDhpnJiYmCbb3r1795PKF07cr7PtS2tCidC3BlQqp
KlfcOX+vxBMx6gtO0xt2RFqKo9TYzGyUTJyRVEuhrpq0DVPSmJ1nYUth+V2dSPtpG4vSiIfbJlDqaqIZwK606Xt+aVjKlycTB8QzeycPaxOKWBKzzZEBfyTfmeyiCw7kMcP23IoqzEzsF0Aj42Pd7m0cV93X66Nv/aEV3zg+AFqMjagE2dxZeQ9pJm/53D5YZ7c8CQP93mYWJ/Y5jekqhAOyyVMnt2u5SnvYMxWMwF6F40GX2REB8gO/eFCF3foT8Q/FgbejbbfHdya/A0v7vmALZYyrtvwJvrRzzZ5M4WVdZTVmFGpBNvERrNTXQR/PwreYWyLG0aluRJ/d38SghLqjaiPzvtGgZvjVPYVWi5e7lr+O6kLj/+yj6OltTy5KJlXpnTHvznLpJoTQUAa+RTmsgI0ZSnw98Nw+XvgKztGP2zPwWw2cxXLqNKpmOOpwXjoGzRu3gyPHC5vo9d0uVOLLUp/gRovJ7SrK8FKTX2bukB9yxOl/Omnn+jduzeDBw/mm2++YevWrXz66acANgf0ueeeY86cOaSmpjJv3rxG60dHR1NVVcXKlStJSEjAYDBgMJxccnDVVVcxaNAgBg4cSGhoKJmZmTz++OPExcURHx+PRqPhoYce4oEHHkAURQYPHkx5eTkbN27E29ub6dOnM2vWLObNm8fDDz/MrbfeytatW1mwYEGjcYYPH05hYSGvvvoqV111FUuWLOHvv//G29sbjUpAo1Zxz8NPMOepR/D19WXcuHHU1dXZWujNnj27SZ/b448/Trdu3bjzzjuZNWsWOp2O1atXM3XqVAIDA8+6L60JJdzZiogP6s2QxJmMG/UKV075nhumrWHmlJ/4zHcAl1rUqHZ+2WxjbcooxmyViPTX0y7IPjcwB46sobQ6H09JIK7/vXL0Q0GhiXQI8WJguwAkCb7akgWAxSqyJDmXO77azv/WHqasRk4B25RRzIqDBWfYmosQNQBRAo/SFAxiNVO69ufFwS8S4RlBibGE5zY9x7a8bc0+7NHdC1glGCEsAQLb463zVpx5FyKm3pHNrzBSY2ph7RrVWrolzODOxLt50+KNPmMVZG1s8uoN6fYxAQb79fZOmg/F6ZC5nlXr54DFxIjIEf90dWionw/ubJ/xFVolvgYdL1zRlRBvN/LKjTy9KJny2vNPpbY7ah2VAx+TVetry+Cvh6G6iPwKI0uS8+hl3k57tzJC3H0Y2XEKAP/b8z82H68Xcg1LgPBEWZF+dzNE6a0mEM3UIlAjmhAQ8NZ5X/h2ncDzzz/P999/T/fu3fnyyy/57rvvbJFzrVbLd999x6FDh+jevTuvvPIKL774YqP1Bw4cyKxZs7jmmmsICgri1VdfPeU4Y8eO5Y8//mDixInExcUxffp04uPjWbZsGRqNHON94YUXePrpp5k7d66ty9jixYttkfOoqCgWLlzIokWL6NGjBx9//DFz5sxpNE6nTp344IMPeP/990lISGDr1q221nhy6YSKq2+Yzlvvz+fzzz+nW7duDBs2jC+++OKcIvRxcXEsW7aMPXv20LdvXwYMGMBvv/3W5H1pTQjShUjLXgRUVFTg4+NDeXk53t6ue6IQRZGCggKCg4Nt9TY28vbBb3eDoMI05RPwjUSnvrBZ4Md/2UvysQqmDWjL1N7Nn7pypDyLN/6eSV5tIaM92zFz6sJmH6O1c8Zj4iIhp6SGu7/diSjBdX2jWHUon/wKWQk8wFPH1b0jqag1801SNnqtmrev60GYjwvXMgIFn11P+bEUFgfdyn13/EdWjDXX8OaON9lbtBcBgdeGvUak18m/y3M+JiSJI1mrmbPuUcpFM/d1n8XAXrPssFcKF8qMz7dSXGXilSnd6Rx+btcqlzlXbJkPe74DQwBM/QLcz74fn23I5NddxxjfLZQ7h7dvfpts10+BfL0391qyEQQ17wyeS3D7S+Rl/noYcrbC4AegyxXNb4MTcJlj4iIgv8LIowv3UlxlokOwJ69c1R2tn
coYLwTbMeGlRfXHvVB+FPxj+D70Ib7fVcRcy6t08qiGfrOQEq7lo70fsSpnFSpUTI6bzBXtr0Cbtx/+uA9UGrjuO7kd6ykwGo1kZmYSExODu/tpNHBqipGqCzmuVmFUqfHWeRNkOHUrNldGEAR+/fVXrrjiCmebcs5IkoTFYmHDhg2MHDmS0tJSfH19z7peWY2Jwso6PNw0hLuyfoSDONPx3lQ/1PXOGArNT2g3aDuQQ5h4eMV/bH2jz5eCCiPJxyoQBBhuB3X7dUfX8dTKe8irLSQINZcPfLzZx1C4OIj0NzAyXk7r+25rNvkVdfgatNw2JIaPburNhG5hXN07kq4R3tSarcxblopVdO05zt2CLLo10nDYViJg0Bp4rO9jXNL2Eq5of8UpnfkmI0lQcAi2zCf92yv47+oHKRfNRKs96dbl2rOvr+AUGtLuXb6O/kz0vhnJO4I9tfmIm95r0iq2+nl79J8XRdj4tvw4/lJWJ0wCrZ5uVoHgVXPlrg+SdEKEXlG4Vzh3QrzdefGKrni5a0grqOL7bTnONunM6P1gwuuy4F1JJrE75jDEtI4Idbn8WpcrEQSBmd1nMrzNcFv71SfXP0mmhw+E96iP0n9zYXaYqqlBwojsFPu5+zXH3ik4gAZxQ6PZekEtSxX+QXHoLxb63Ea1AHk1BSxOXcjh8sPnvak1KXKrum4RPgR5Na9S5NcHvub9nW9jqjhOgqhhbrc7CYno3axjKFxcXNcvEm+9Bi93DTMGRvPxtN5M6hFhE45TqQQeGB2HQacmJa+SH1z4ZqrWZGVJRTQAHer2ywJD9ahVam7tdivXdJgqtwarKTm3jVfkQtJH8P318OsdHNj7FS+YsqlSCcR5RvLM+I/wcvdtvp1RaFYa6seziluuQy+pdbwaGsFL2mo2ZPwBRzafcXmrKJFeYEeF+0N/QlGarD3T5zaOm8rBJ5JR4YPl3tcb34HlT0NdJah1si6AgsJ50MbPwF0j5AyTn7fnuH6Peu8wmPA6otYD/6pULq/9DXetSq6N18oRRpWgYlbCLO7reR9eOi+OVB5hR/4OWfEe4NBiWZ/lfBCtSOZaSgQJBBU+Oh80KkUWrKWg06gQBAGrKLmsKHFLQ3HoLxYC2tGr3XgGilrE6gLm75mPRTz3WktJkmzq9vboPR/rE4tQmc9VFg2P+fbEq+eMZh9D4eIi2MudT6b1YcEtfZnSq83JbW+AYG93/jNc7lf/w7ZsDuWdm9K2o9iaVUKG0Bar1hN3sUZuJdeAaIW05QgLb4Hf76H8l9v4eu/HmK1NqMmsKoBfb5cjJhXH2a1R8ZKXBqN3GF3jLueJK3/EI0ipD3ZlWqQw3r8QBIGOkYNA78/36jpM616TneXTkF1SQ51FRK9V08avmdM2jRWw7WP5ce+bQe/H7N6zeX34PHpd8ir0v1PWdclcLy8T2AHU2tNvT0HhLAxqH8jQuEBECd5ckUqd5fxUyx1GQDsO93oSMxo0agGtTxjEX9ZoEUEQGBg+kHnD5nF5u8u5ov0Vch19WAIWqwl2nae2k7mGKkRMCKhVGnzdfC94d5yFJEktMt3+RIYPH44kSU1KtwdQ1dfRAxjNikPfHCgO/cVE71uYIXrgWVfDkaKD/JX51zlvYv/xCqpLjuOpNjGoffMoidZaam2PB1aVM68apuKFasTjoLKTyJHCRYVepz5rTeLwjsEMiwtClGDeslRqTa53M7UhrRBRUGMO74MAkL0ZLCY48Dv8cCOsehFKs5CQmGM6wh/7FvBL6ln0J6wWWPG87MD4RZM7+B5eCwnF7BVKr6jhPNbvSfQapcbN1YkJlB36I8XViC5eNnImJsRMICAgjmKNhr+Nx2HzB6ddNiVPdvY7hHiiUjWzaOqOz22/CTpfYXs50isSrVoHCdfA2LmgrVeSVtLtFZqBWcPa4WvQklNSy9dbTt9y1VXYYYrkM8OtGD3bIgy+HzSn1mfycfPhhk432KLo5sQbeVpbzRup3
7NwzZMk5SaRU5mDWWyiKKCpGncEvDRu+Lr5/iNQqdBiaAiu1Jpd716rJaI49BcT3uH4dJrEjVZ3qCnix5QfyavOO6dNrFq7iicrX+AZ4TP0mgu/gTpYfJDZa2ZTaiyVU682vUuEpIbet4B/61OhVHBtZg1vR5CXrDj80brzL0uxBzUmCzuOlAIQ1GWE/GLqEvj+Olg/DyqOg7sP9LkNYeI7TJE8wFTFb8mfk1N5hjKCbZ9AfjLoPGDcXMK6XMWkDlcyMHwgs3vNRqtEHVsEEb56dBoVdRaRXFfvaX0GdGod13a6HrxC+VVTR3nKn5CddMpl0+zVf744A/Yvkh8PvJeU8gz2Fe47ebm2A+DKD6HH9dBd0ZdQuHC83LXcO6oDAL/tPkbysXInW3Rmko9VcEjbmYzh78I5tBberrZy2MOXJJWZHzMX88aGZ3ho7UNM/3s6s9fMZuOxM3S6qG9XpwWC9cH4uPlc+I4oOBz9CXX0CheO4tBfbCROY7jKi651Zsy1pXy89+MmC1LszSmha9aXaLDSnmzIWn9BplhEC58mf0qJsYSlmUtg3WtyT9HgztD9mgvatoLC+eDppmH2mDgEAVYczGdTepGzTbKRdLgEs1WijZ+ekC6DQVDJdfLVReARBAPvget/hJ43QXgP+vZ/kN6iFmtVAR8lvYIonSKt7cgm2PMdEhK1gx8A73AApsZN5Z7Ee5SaxBaESiXY0u7Xp55nXaqLMDhiMNEBnajV+/KLuk6+NtRVNVpmx5FSNtT/PjuENGPrVEmCTe/KNfIxQ0l2d+OlpJd4bftrHC47xSSfXzT0uwM8W566toJr0ifanzGdQ5AkeGtFmktmi4GsYdGQJdMl/Nyc6gHhA3h21DtcHzaUoaKWdpVFuJtqsUpWjlUdO2OkXqpvVwcq0Bps4rAKLYuGlHuTRcQiKmn3F4ri0F9seAQgdJvKTIsebU0xWpUGo/Xs0RxJkti5/Acirdl467VoVYKckngBP8IlWUvIqczBS+vFpVYd5CTJwkLDH1NS7RWcRtcIHyYnRgDw7qp0CipdI9rZ4LwM7hCI4O4DCddCUDwMfRiu/Ra6XWUTIwIQOl/OLZFjcAdSj21hZfofjTdYmQ+rX8KMxPzwWF7IW22rtxcEAZWgXB5aGpclhAHww/YcsotrnGzN+aMSVNzU+SbwCGS5DnKr8+TrDSCKEj9uy+H5P/ZTY7LSMdSL3m39m2/wzHVwfBeodWxvP5iXt76M0Wokzi+OcM/w5htHQeEM3DYkhiAvN/IrjHy+KdPZ5pySw4VV1JqteLipaetvOOf1Owd2YdLYt7kr/iZeMnvyRVE570dfzRP9nqBbYDfbclWmKkqNpbZJ6dKaQvIRMWvdQWmp2GJRq1Q2cWKljv7CUX4JFyMJ1xLq5sOrNRoe9e1pq4/Nqcyh2nxqQaU9mbl0z/0RQQBD3+my6m9JJmSuOS8TimuL+SnlJwCuj56A19ZP5Df63Ap+bc9rmwoKzcUN/dvSPtiTqjoLL/550OkpYVV1FnZmy+n2Q9rXRwL73QGT/wedLjt13aIgEDD8Ka7ThYNo4Zttb1BSW59xIFpg5X8pqivnGU8Na6ghszyT/cX7T96OQotheFwQvdr6YbFKvLMqrUXX0ncN7EpiSC/CgrpQIwAHfqemLJ+5fx/kqy1HkCQY1zWUl67sZrspvGDMRtj8PgCb2g/kjYMLMItm+oT04dE+j+KuOU0/bAWFZsag09hS7//el8eu+vO/K7H/uCwe2znM5/w1LARBzi7rPAlBgsBN75FQUUyAPgAAURIpNhZTYizhaOVRKk2VlJurqBIk6k5Tr6/QcnBX0u6bDcWhvxhx84Ie1xOOGmHH52A1I0oi7+x8h3tX3csfGX80UsaWJInDy+bjJVag8YvEc+BtcjQQYMeC84rSf33waznq4RvH8LQNYK6BkK7Q7erm2ksFhfNGq1bx+IR4fA1aMouqeWN5qlOdo62ZxVisElH+BqICz
iESojNwydi3aI+W2rpyFq5/Xn592yckF+zmMbdaDnv64qXz5ol+T9AjuIdd7FdwDIIgcNeI9ui1cgvGP/Yed7ZJF8RdPe7i1Us+pl1wAnUmI4u/eoMth0vQqAXuHtmeu0a0bz5nHmDnl1CVz0qDnncqD2KVrAyOGMz9ve5XtCQUHE6PSF8u7S5n3by9Mo3qunPvTGRP9h+X6/u7hHtf2IYEAQbdD/GXyqUuK1+Ao9vktxDwd/dHo9JgFs0U1BQgSiJuCHgofedbPEodffOhOPQXK10mgyEAKnNh26eU1pYgSiJV5iq+Pvg1966+l59Sf+K7Q9/xwso5+JT9jUoQCBh9P9sKd/Odu4Do5gmlWXB41TkNva9wH5uOb0KFilt1EaiO7ahPtX9USZ9ScBmCvdx5YkInNGqBzRnFfJN0xGm2rE+TI+vn01lCFdiB23vcyRirjhuz9+Oe/C1/Jn/JHG01lV4hxPjHMXfIXLoFdTv7xhRcniAvN24eFA3AV5uPkFfuGiUj54OXzgu1WsPekEkcLamlc/kaovRGXpnSnbFdQptvILMR1rwMu79ht2DmI4MKSRAY03YMd/W4S9GSUHAaMwZGE+bjTnGVied+309hZZ2zTQLk0peGCH2XiAt06EG+9xvyEMSNlZ36Te+BsQzBWI6XoCHSMwJfN18ESQQkAgQdgsbtwsdVcContq4Tm6jnpXBqFO/pYkXrDv1myY/3fEfAnh94beirzOo+C393f0qMJfyc+jOL0hex5cifFKktmEIT8YobhiiJLMr6my/CopGQYMcX5xSlX5UjTwBcEphA9J4f5Rf7/wd8o5p5JxUULoxOYd7cM7I9AD9uP8qalAKH21BpNLMruwyAwefZKrJtjxncFjMRgySxOONbvtEYEfV+DG93Gf8d+F+CDIqgV2tibJdQukZ4U2cReXdVWpOFT12RTelFPLHNje+8fPjT18RrcQebV9W+LAcW/QdS/gZBRbfes+gbNZKJsRO5teutipaEglNx16p58JKO6HVqDuVVcu93u0g6XOxsszhaWkul0YJOo6JdUDOJUqpUMOwxaD9KdurNRqgtgbIjqIozCKitJFJS00ZSo9c1c2eLZkAQhDP+Pffcc061bdGiRWddbu3atYwcORJ/f38MBgMdOnRg+vTpmEwmu9ilVatQqwQkScJkUeroLwTlSnUxE3cJDLpPfrz3R1Qb32JEm2G8PeJtpneezpCIIfRQd+DSciOhooqISx8BQcAqyqkxS00FfOsGUlk2pK9o8rD3JN7DzPgbuDp9K4hWiB0OXa60ww4qKFw4I+NDmNxTFsl7Z2WaTdXXUSQdLsEqSkQFnGO6/YkIAgx5EHzaMNyqwUfrya19H2FWwix0aqUOsbWhUgncM7IDOo2KvUfLWXYg39kmnTcLdx7DrC5kvb+GnR4SO9J/lrs7NAcZq+GX2zlUmkqluzdcOg91z2nc3+sBbuh0g6KereASdAz14p1rE+nQoOuy+CCfrD+M2eo8B6gh3b5jqBdadTO6EioVjHxaFkd28wStAQQNIIHViNZqwg3kNqsuRm5uru3vrbfewtvbu9FrDz300Dltz15O9Ok4cOAA48aNo3fv3qxbt459+/bx7rvvotPpsFrtkxIvCILSj76ZUBz6i52uk2HYo/IN/4HfYe0r6AQ1E2IncGe3WVxyIJlLqtV4Rk7EO0wWaBkYMZDbut0GgorfDW78qjbBzgWyc94EVBKMTtuIR3Ux+LSBYfJEgYKCqzJ9QDR9Y/wxWyVeXHyAoirHpT02qNsPOc/ovA2dAWn8a3h2uoG3Jv7AJbHjFYelFRPuq+fG/nLW06cbMh16zDYXWUXVpOZXoieMG3rMAI07nwgV5O38/MI2bDXDxneoXvEsH4vFPGuQ+LrzCIjoCYBapVZ+GwouRaiPO69c1Z1JPeROC7/tPs6jP+8lt7zWKfY0pNt3Pcd2dU1CECCkiyy+7BUKAe3Avx14hYPGAFpPuW2ru
dYxf03McAoNDbX9+fj4IAiC7Xl1dTU33HADISEheHp60qdPH1asaBwIi46O5oUXXmDatGl4e3tz++23A/Dxxx8TGRmJwWDgyiuv5I033sDX17fRur/99hs9e/bE3d2d2NhYnn/+eSwWi227AFdeeSWCINie/5tly5YRGhrKq6++SteuXWnXrh3jxo3j448/Rq/X25bbsGEDQ4YMQa/XExUVxQMPPEB19T+C2gUFBUycOBG9Xk9MTAzffPMN0dHRvPXWWwBkZWUhCAK7d+8G5CyUivIy/D3cWLNmjW07ycnJjB8/Hk9PT0JCQrjpppsoKvqnlfDw4cO59957eeSRR/D39yc0NPSkLIiysjLuuOMOQkJCcHd3p2vXrvz555+n3JfIyEjuvffeRvvSklAKwxQgfoKskr1qDqQuAWsdjHiK1LXf4V2TQ53ag7hL7220ypi2Y6iz1vHV/gX8UFuCW2UGl6Ytg47jTzvMhmMb6BvaF92+nyF7i1w3P/p5l5xpVVA4EZVK4KFLOvLwz3s4UlzDi38e4OUp3W0zy/ai0mhmV04ZcH718yfhFUptp6l4eQVf+LYUXJ5JCRGsTysiLb+KD9dk8NSlnVqUo7p0fx4A/WL9uaFLfw4fX8vB41t4J+0Hnk+cjtbjPH4T1cVIS58kqXgvn2trKfPwB0MgKp0HoiQqKfYKLotWreK2IbF0i/DhrRVppBVUcd/3u7l1cAzhPnoEQfaFVYJg+982wICbpnmvU5IkkVwfoe98oYJ4TUEQQK2Vu7P8eJP9x/s3tywBrf7sy52BqqoqJkyYwJw5c3Bzc+PLL79k4sSJpKSkEBX1T7np66+/zjPPPMOzzz4LwMaNG5k1axavvPIKl19+OStWrODpp59utO3169czbdo03nnnHYYMGUJGRoZtMuDZZ59l27ZtBAcH8/nnnzNu3DjU6lMfD6GhoeTm5rJu3TqGDh16ymUyMjIYN24cL774Ip999hkFBQXcfffd3HPPPXz+uTzROmPGDI4fP87q1avRarXce++9FBScvlxRr/3nnNtQHlZWVsbIkSO57bbbePPNN6mtreXRRx/l6quvZtWqf3S7FixYwOzZs0lKSmLz5s3MmDGDQYMGMWbMGERRZPz48VRWVvL111/Trl07Dhw4YNv/f+9LYWEhd999N3fffbdtX1oSikOvINN+tOxgr3geMlYjmWsR9iYBUNb5Rrr6nNzn97LYy6iz1PHj7vl8KRbivv09RrUfA+qTD6vtedt5d9e7hKkNvJaTiRZg0L0Q2N7OO6ag0DzodWqevqwzD/64h4zCal5fmsIj4+KbV2X7Xyzbn48oSkQHehB5Hn1+FS5uVCqB+0fFce/3u9iaWcK6tCKGxbUMvYQ6i5XV9ZoVY7uEolapuXvIizz6yyQyzNX8uP5Zbhj3/rltVJI4tuJJvindzg6tBF7hhPl3YGb3mXQJ6GKHvVBQaH76xQbwznWevLb0EAdzK3lvVfpplw3ycuPVq7oT6Nl8AnIFlXUUV5lQqQTiQ12vlt0VSUhIICEhwfb8hRde4Ndff+X333/n7rvvtr0+cuRIHnzwQdvzJ598kvHjx9vS9ePi4ti0aVOjKPPzzz/PY489xvTp0wGIjY3lhRde4JFHHuHZZ58lKEg+5/v6+hIaenox0alTp7J06VKGDRtGaGgo/fv3Z9SoUbaMAYC5c+dyww03cP/99wPQvn173nzzTUaNGsWHH35IdnY2f//9N1u3bqVPnz4AfPrpp3Tq1Om047pp/8mIstZ3E3rvvfdITEzkpZdesi332WefERkZSWpqKnFxcQB0797dNvnRoUMH3nvvPVauXMmYMWNYsWIFW7du5eDBg7blY2Njbdv797506NCBd955h2HDhvHhhx/i7t6y2pQqDr3CP8QMhbEvwbKnqErbgMpkpFgbRuLY6addZXKHyRhNlSzb8T7h1WWQtpTymMEcKD5AVkUWWeVZZJZnUm4qB9FK75IjaCUJO
lwC8Zc5bt8UFJqBEG93Hp8Qz1OLkknKLOHFxQd4YkKnZo/Ui6LEt1uz+WFbDgCjOykRdYXzIyrAwDV9Ivk2KZs3l6dyKLeCa/pE4mtwbe2ETenFVNdZCfZyo0cbXwACDYHM6nILr+9+l9/zNtPl6CZ6tBnY5G0mJb3FmyVbkNQqNP6xTOp4NVe2v1JpSafQ4gjycmPu5O78uD2HjelFiJKEKIKEhCjJkc4Ko4XCyrpmzyg7kCun27cP8rR7llojNO5ytNzRaC7csauqquK5555j8eLF5ObmYrFYqK2tJTs7u9FyvXv3bvQ8JSWFK69srDHVt2/fRg79nj172LhxI3PmzLG9ZrVaMRqN1NTUYDA0LRigVqv5/PPPefHFF1m1ahVJSUm89NJLvPLKK2zdupWwsDD27NnD3r17+eabb2zrSZKEKIpkZmaSmpqKRqOhV69etvfj4+NPKhE4EZUg2AIjdfXCeHv27GH16tV4ep4suJiRkdHIoT+RsLAwWzbA7t27adOmjW3Zf3O2fTnTJIQrojj0Co2J6kfuwOeo+O0RVEBVz//goT/9zK4gCFzfZTpj6qwEb/8Sdn7JXr077+37qNFyKgS61dYwpdYCftEw+AGlbl6hRdIl3IdnJ3ZhzuID7Mou4+lFyTwzsTNe7s3jFNRZrLy1Io0N9a3qJveMYGL38GbZtsLFyVW92pBeUMXWzBL+3JvLioP5XJ4QzuSebfBwc83bgGUH5HT7MZ1DUKn+uVb06XELlxz8mdV1eVSk/gVncOglSaLCVIGPmw9U5tHtwBIMCMSH9uK6wc8Q6RVp9/1QULAXapXAdX2juK7vqTsE5ZUbefCn3WQUVvPm8lQeHRff6Ld0vuw/Vl8/3xzt6s4FQbjg1Hdn8dBDD7F8+XJef/112rdvj16v56qrrjpJ+M7D49xLUKuqqnj++eeZPHnySe+dT5Q5IiKCm266iZtuuokXXniBuLg45s+fz/PPP09VVRV33HEH994rl+FKkoTFYkGj0dC2bVtSU1PPun1VfXvqE7uvqJEd+Qal+6qqKiZOnMgrr7xy0vphYWG2x1pt4/suQRAQ67tunVj3fyr+vS8ncmIZREvBNa/kCk4j+Vg5L25QoTM8SpyXmQeHjzvrOoIgENz9BjjwJ1TmEbv0WdppzcRoPIjW+tDWPZC2gg63slzQ6Ovr5pX0YYWWS49IX168ohvP/b6fQ3mVPP7LPv47qSv+HhcW9SytNvHC4gOk5VehVgncNaI9YzqHNJPVChcrWrWKpy/rzO6cMr7clEVaQRU/bj/KX/vyuKpXGyZ0da1j7GhpDcnHKlAJMPrfx78gcFPfB7lk1XNEHtkJxgp2VxwmuzIbAQGVoEItqLFKVtYeXYtG0DBn0IsI617DYDbyTkhfPMd+KKtpKyi0YkJ93HliQieeWpTMpoxivkk6wk0Doi94u/vrI/Sdwxzs0LdgNm7cyIwZM2zR9qqqKrKyss66XseOHdm2bVuj1/79vGfPnqSkpNC+/elLWLVa7Xkp1fv5+REWFmYTiuvZsycHDhywjXWiQy8IAvHx8VgsFnbs2GFLuU9JSaGsrMy2zYYSgNzcXBITEwFIO7APAJPFahtn4cKFREdHo9FokCQJs1VCqxaarAPTvXt3jh492ihF/0T+vS8tHcWhV7CxKb2I15elYLZKREW05b7LOqPXNTGdSusO/f4D614lwgovmdRgsgIlUH5Ci6HBD4B/jF3sV1BwJB1DvXh5Sjee/m0/R4preOTnvbx4RVdCfc4vPe9wYRUv/HmAoioTnm4anry0E10j7KAgrHDR0iPSl4SrE9icUczXSUfIKanli01Z/L7nGLf0DiLYRSo7lu2X2+z1aut/ytpfXexIInd9DSWZkLyQzVoja3LWnHJb7mp38vd+S+jR7aDW4Tn8KcWZV7ho6BLuwz0j2/Pm8jR+3H6UNn4GRsSf/w+93GjhWFktAoJjBPFaCR06dOCXX35h4sSJCILA0
08/bYskn4l77rmHoUOH8sYbbzBx4kRWrVrF33//3cipfeaZZ7jsssuIioriqquuQqVSsWfPHpKTk3nxxRcBWel+5cqVDBo0CDc3N/z8/E4a63//+x+7d+/myiuvpF27dhiNRr788kv279/Pu+++C8Cjjz5K//79ufvuu7ntttswGAzs27ePVatW8f7779OxY0fGjRvHHXfcwYcffohGo+H+++9vFC3X6/X079+fl19+mZiYGAoKCnjx+ecAMIsSVlHirrvu4uOPP+a6665j9oMPIbl5cigllWV//Mo3X35+WmG/Exk2bBhDhw5lypQpvPHGG7Rv355Dhw4hCALjxo07aV88PDw4cOAAy5cv57333jvr9l0N5aqmAMCfe4/z8pJDmK0S/WP9eeGKrueeQhx3CdyyDKb/Add+C1fOh/Gvyj1FB90LY/4LHc8e8VdQaCm0DfDg1SndCfF2J7/CyCML93Kk+NxbnmzNLOGxhfsoqjIR4atn3tUJijOvYBcEQWBg+0Deu64n943qQJCXG8XVJj5LykUUm9aeyZ6YrSKrDjWI4Z0mc0Clgp7T5MfJC4nX+DKszTCGRAxhUPggBoQNoF9oP67teC3v93uW0J31NZJ9Z4KvkmavcHExMj6Eq3q1AeCdVWkcqG85dz6kFdYA0DbA0GxlZhcDb7zxBn5+fgwcOJCJEycyduxYevbsedb1Bg0axPz583njjTdISEhgyZIlPPDAA41S6ceOHcuff/7JsmXL6NOnD/379+fNN9+kbdu2tmXmzZvH8uXLiYyMtEXF/03fvn2pqqpi1qxZdOnShWHDhrFlyxYWLVrEsGHDADnqvXbtWlJTUxkyZAg9e/bk+eefJzz8n7LAzz//nPDwcIYNG8bkyZO5/fbbCf7XbPFnn32GxWKhV69e3H///cyZI088IIHRbCU8PJwNGzZgrDMzbtw4Rg3qw5xnHkPv6YXJ2vTr1MKFC+nTpw/XXXcdnTt35pFHHrFlKvx7XxITE3nmmWca7UtLQpCkJjZYvEipqKjAx8eH8vJym8qjKyKKIgUFBQQHB9vqU5qCJEl8teUIP20/CsC4rqH8Z1i7ZqmzUnAu53tMKJw7JdUmnvktmSPFNXi6aZhzZVdig04WczkV6QVVPPjjbkQJEiJ9eHRcvN1ulJRjQuHf1Jqs3PLFNkqranlkfGdGdnJu+v2GtCJeWXIIfw8dn83og/p01yJRhF9ug+IMUKmhx/WQOE1uwdqAJMHfj0DOVgjpCpe/q0Tnm4hyrmhdiKLEy0sOsTmjGB+9lnlXJxDiLTuFRrOVQ3mV7D9eTvKxCurMVu4b3YG2AR7/2obIm3/vY/XhCi7tFs5/hrezm71Go5HMzExiYmJanNq4vZk5cyaHDh1i/fr1zjblpJT70xEdHc39999vU5Q/FXnlRiqNZvw9dHi4aSioqKOuPgVfr1OjEgSq6yzodWoifPUtqgXr2TjT8d5UP7TFnKXnzJnDwIEDMRgMZ1RLPBFJknjmmWcICwtDr9czevRo0tLS7GtoC8JiFXlrRZrNmb+xfxR3DleceQWFc8XfQ8fcyd3oGOpFVZ2F91alNzna+eXmLEQJ+sb489zELkrUQ8Gh6HVqJvWQRYZ+2J7j9Ch9Q+/50Z2CT+/Mg+yYj38NogeDaIWdX8HCWyFv3z/LpPwtO/NqHQx/VHHmFS5aVCqB2WPiiA3yoLzWzH//OMBnGzKZ/eNurvloC08vSub7rTkkHysnraCKxxbu42DuyZH8lAI5Qt9FSbd3GK+//jp79uwhPT2dd999lwULFtha1LUmGvrRl9eaySmtoc5iRa0SCPF2J8JXT5CXG4IgUGuyUms6dz2A1k6LubqZTCamTp3Kf/7znyav8+qrr/LOO+8wf/58kpKS8PDwYOzYsRiNRjta2nL4YXsOqw4VoBLgnpHtuaZPVKua8VJQcCRe7lqeurQTeq2atIIq1qcXnXWdfUfL2ZVdhkolMHNILBp1izklK
7QiLusehodOzbGy2iYdt/Yiv8LI7pwyAC7pcvp+yTY8AuCSF+VyLr0flGXDb3fDhjehNAs21/ep730L+LY81WIFhebEXavm6cs64+ehI7ukhl93HSMtvwpRlAjycmNExyDuGtGO+PqJ6acWJbPjyD8aSNV1Fo6WyffPikPvOLZu3cqYMWPo1q0b8+fP55133uG2225ztlnNTkMLRKsogQReei1R/ga89VoEQUCrVuGjlwMexdUmLiTBXJKkC1rfFWkxonjPP/88AF988UWTlpckibfeeounnnqKSZMmAfDll18SEhLCokWLuPbaa+1laovAZBH5a18uAPeM7HCykrCCgsI542vQcVWvNny15QhfbspiQGyArb/qv5EkiQWbswAY3zX0vMX0FBQuFINOw9h4f34/WMoP27IZ0j7QKZlay+qj8wmRPrZ04LMiCBA7DMITYcuHkPIX7F8EB36TU+6DO0P3a+xntIJCCyLQ041nJ3bm842ZhHq70yXchy7h3gSf8Hsb3jGYl/8+xI4jpfz3z4PcP7oDIzoGcyivElGCCG93Ak4hVqlgH3788Udnm3DBNEXRX6dRodeqsUryBJNBd7KL6mfQUlFrxmi2Ul1nwfM8MxqNZiu55Ua89dpTCq+2RFqMQ3+uZGZmkpeXx+jRo22v+fj40K9fPzZv3nxah76uro66ujrb84oKOeVIFMUmKVI6C1EUkSSpyTauTS2gvNZMoKcbw+ICXXrfFM6Pcz0mFJqHid1DWbwvl/xKI3/sOcaViRGnXC7pcDGH8ipw06iZ2ivCId+TckwonApRFBnVwZeVGRVkl9SwPq2QIR0CHWqDVZRYcTAfCYkxnULO/RjVecLQh6HdSIT186AyF9RapKGPyO8rx/w5oZwrWi8xAQb+e3mXRq+d+D3r1AJPjO/I2yvTWZtWyLxlKVTUmCiplnumdwrzsvtx0XD8tcZIamuj4ftpju8pwu8fNfxTbU+tEvA1aCmpNlFUZcLgpuF8pp6r6ixYRQmLVXSJ46vhOD+Vr9nU31qrdejz8uSZ/pCQxpHnkJAQ23unYu7cubZsgBMpLCx06VR9URQpLy9HkqSzCthIksTCrVlYzBYGR/lRXFToICsVHMm5HBMKzctl8d58uiWXbzYfJiFQhadb4xYroiTxydpMLGYL4+J8MFeVUVBlf7uUY0LhVIiiiKmmkuExnvyWXMyXG9Lp4G1F5cASrF1HK8kvq8HTTU2sp4WCgoLz25A2Eoa+gnvWCqxekZjNejjfbV3EKOcKhRsSfFBZjSxPKeXD1Wm4aQSsVgsRBun8f59NxGw2I4oiFosFi8Vi17EUzh9Jkmyq8Y4q2fXUCZTVyD3ry6qMeLmfmysrSRKVRjOSJOGuEVzi+LJYLIiiSHFxMVpt46yDysrKJm3DqQ79Y489xiuvvHLGZQ4ePEh8fLyDLILHH3+c2bNn255XVFQQGRlJUFCQy6vcC4JAUFDQWS++KXmVHK20YHDXMaV/B1tNikLr4lyOCYXm5crAINZm1ZBVXM2qI0ZuGxzT6P3VKQXk11jx83Rn2tB4PN0ccypWjgmFU9FwXFwX5ceazF3k11jIqFQzqL3jovTbthWj0WoY1z2ciLAm1M+fjfCbL3wbFzHKuUIB4P5xwUQEHeXrpGysEqjVEgPi2xDs73H2lS8Ao9FIZWUlGo0GjabVxh5bDf92Qu2NvwcUVdVRXifi66E+p8mEOouIVQSVSoW3XucS2mEajQaVSkVAQMBJKvdN7fLg1F/Jgw8+yIwZM864TGxs7HltOzRUviHIz88nLCzM9np+fj49evQ47Xpubm64uZ1cT6FSqVz+oiYIQpPs/Cs5DwGBIR2C8PNoHbUjCqemqceEQvOiUsGtg2N45rf9/LUvj4kJ4YT5yKlkZqvId1tzEBCY0isSb73uLFtrXpRjQuFUCIKAt17H5T3C+X5rDj9sP8qg9kEOqaVPy69kx5FSBATGdglTjk0XQTlXKABc2
7ctvgYd76/OIMRLR4Sfwe7HhEqlQhAE25+CayJJku37ceT35GPQUl5rxmwVKa+14OfR9Puo6jo5Iu/hpnaZc1vDcX6q821TbXSqQx8UFERQUJBdth0TE0NoaCgrV660OfAVFRUkJSWdk1J+a6OsxsT6NFnF+LLuYWdZWkFB4XxJjPKjZ5QvO7PLWLDpCI+NlzONlu7PI7+iDj8PnfIbVHA5Lk8I57fdxzlSXMOWw8UMtHOUvqiqjhcXH0SUYEC7ACL9DXYdT0FB4dwZ1zWMLmHe1FSWKg62gtNRCQL+HjryK4yU1pjw1mvP3Ob0BKrqHXpHZUY6CteYmmgC2dnZ7N69m+zsbKxWK7t372b37t1UVf1TeBofH8+vv/4KyLMd999/Py+++CK///47+/btY9q0aYSHh3PFFVc4aS+cz9L9eVhFiY6hXnQI8XK2OQoKrZqbB8WgEmBjehEHcyuoNVn5YVsOANf1ibS1aVFQcBW83LVMTAgH4Ptt9u1LbzRbeWnxQUqqTUT5G7h/dAe7jaWgoHBhRPjp8TnHemUFBXvh5a5Bp1FhFSXKakxNWsdksWKyyOVEp1LRb8m0GIf+mWeeITExkWeffZaqqioSExNJTExk+/bttmVSUlIoLy+3PX/kkUe45557uP322+nTpw9VVVUsWbKkyfUIrQ2LVeTvZFkQ8FIlMqigYHeiAz0YGS8Lc362IZPf9xyjrMZMqI87Y5RWkQouyhU9wtFr1WQWVZOUWXL2Fc4DSZJ4Z2UaaQVVeLlrePqyzq3uBktBQUHBXsyYMaNRgHL48OHcf//9Drdj7dq1qFQqysrKHDquIAgE1Kfal9WYsVjPrgbfEJ036NRNjui3FFqMQ//FF180amHR8Dd8+HDbMpIkNarJFwSB//73v+Tl5WE0GlmxYgVxcXGON95F2HK4hOIqE74GLYPaObYlkYLCxcoN/aNw06g4lFfJt1vl6PyN/duiUbeY06/CRYYcpZcnfb/flm2Xtj4/bMthfVoRKpXAExM6EepzcU60KygotB5mzJhhq4fW6XS0b9+e//73vw5RUv/ll1944YUXmrTsmjVrEATBYU74nj17uPzyywkODsbd3Z3o6GiuueaaC+6W4OGmwV2rRpQkSmvMZ12+qk5W5G9t6fbQghx6hQtn8b7jAFzSJRSdRvnqFRQcQaCnG1fU96IXRYmYQA+GOFA9XEHhfJiUGIG7VsXhwmq2NnOUflN6Ed8kZQNw5/B2dI3wadbtKygoKDiLcePGkZubS1paGg8++CDPPfccr7322imXNZmalireFPz9/fHycr1S2sLCQkaNGoW/vz9Lly7l4MGDfP7554SHh1NdXX1B2xbqa+kBKmrPHKU3W0XqzFYEAQxura/cUfHqLhIyi6pJPlaBSoDxXZuhJZCCgkKTmdKzDb4Gua3L9IFtHaIcrqBwIXi7a7m0W0OUPqfZovTpBVXMW54KwKQe4YztolyPFBQUmo7RYjztn9lqbvKyJqvprMueD25uboSGhtK2bVv+85//MHr0aH7//XfgnzT5OXPmEB4eTseOHQHIycnh6quvxtfXF39/fyZNmkRWVpZtm1arldmzZ+Pr60tAQACPPPLISefkf6fc19XV8eijjxIZGYmbmxvt27fn008/JSsrixEjRgDg5+eHIAi27GZRFJk7dy4xMTHo9XoSEhL4+eefG43z119/ERcXh16vZ+TIkRw5cuSMn8fGjRspLy/nk08+ITExkZiYGEaMGMGbb75JTMw/LX2Tk5MZP348np6ehISEcNNNN1FUVGR7v7q6mmnTpuHp6UlYWBjz5s1j+PDhPPHIg7YovVajZtGiRY3G9/X15YsvvrCl25fk53L9tdee9rNu+I5ef/11wsLCCAgI4K677sJs/ufYOt1n29R9sQetL+dA4ZQs3itH5we0CyTQU2lVp6DgSPQ6NS9P6U5BhZHEKD9nm6Og0CSuTGzDn3tzSS+oYvuRUvpE+1/Q9kqqT
by4+AAmi0jPKF9uHhRz9pUUFBQUTmD6kumnfS8xOJHH+j5me3778tups9adctlO/p14buBztud3r7qbSlNlo2V+uOyHCzMW0Ov1FBcX256vXLkSb29vli9fDoDZbGbs2LEMGDCA9evXo9FoePHFFxk3bhx79+5Fp9Mxb948vvjiCz777DM6derEvHnz+PXXXxk5cuRpx502bRqbN2/mnXfeISEhgczMTIqKioiMjGThwoVMmTKFlPfK0AkAAB6CSURBVJQUvL290evltrpz587l66+/Zv78+XTo0IF169Zx4403EhQUxLBhw8jJyWHy5Mncdddd3H777Wzbto2HHnrojPsfGhqKxWLh119/5aqrrjpll4SysjJGjhzJbbfdxptvvkltbS2PPvooV199NatWrQLg4YcfZu3atfz2228EBwfzxBNPsHPnTnr06IGfh47cslpAnpQ4FVVGC2azmWlXT2LQwIGn/awBVq9eTVhYGKtXryY9PZ1rrrmGHj16MHPmzDN+tk3dF3ugOPQXAZVGM6tTCgGlVZ2CgrOI8NUT4at3thkKCk3Gx6Dl0u5h/LLzGN8lZdO7rd95t6zKKqrmlSWHKK4y0cZPzyPj4ludKJGCgoJCA5IksXLlSpYuXco999xje93Dw4NPPvnE5jx+/fXXiKLIJ598Yju/fv755/j6+rJmzRouueQS3nrrLR5//HEmT54MwPz581m6dOlpx05NTeXHH39k+fLljB49GoDY2Fjb+/7+8uRscHAwvr6+gBx1fumll1ixYgUDBgywrbNhwwb+97//MWzYMD788EPatWvHvHnzAIiLi2PPnj28/vrrp7Wlf//+PPHEE1x//fXMmjWLvn37MnLkSKZNm0ZIiCwO/N5775GYmMhLL71kW++zzz4jMjKS1NRUwsPD+fTTT/n6668ZNWoUAAsWLKBNmzbyZ6pT46aR0+ir6+vkT8QqihjNVv76bSGSJJ3xswY5c+G9995DrVYTHx/PpZdeysqVK5k5c+ZZP9uz7Yu9tNwUh/4iYOXBAkwWkbYBBrqEezvbHAUFBQWFFsKViREs3ptLWkEVO7NL6dX23KL0kiSxdH8+H63LwGyVCPDU8fRlnfFohaJECgoK9mfBuAWnfU8tNK6N/mjMR6ddViU0rjp+b+R7F2ZYPX/++Seenp6YzWZEUeT666/nueees73frVs3mzMPsmBcenr6SfXvRqORjIwMysvLyc3NpV+/frb3NBoNvXv3Pm0p1O7du1Gr1QwbNqzJdqenp1NTU8OYMWMavW4ymUhMTATg4MGDjewA2WE/G3PmzGH27NmsWrWKpKQk5s+fz0svvcS6devo1q0be/bsYfXq1Xh6ep60bkZGBrW1tZhMpkZj+/v720oW5Fp6uayx2mTBKkqNJozrzHLUPv3QfjLO8Fk30KVLF9Tqf46lsLAw9u3bB5z9sz3bvigOvcJ5UVpt4s+9uQBc1j38vKMrCgoKCgoXH74GHeO7hbFo1zG+TcqhZ1TTo/Q1JgvvrUpnfZqcitirrR8PjI7Dp15PQkFBQeFccdc0vSOGvZY9EyNGjODDDz9Ep9MRHh6ORtPY1fLw8Gj0vKqqil69evHNN9+ctK2goKDzsqEhhf5cqKqqAmDx4sVEREQ0es/N7cJLdQMCApg6dSpTp07lpZdeIjExkddff50FCxZQVVXFxIkTeeWVV05aLywsjPT09LNu38NNgyAIcl/6WhMBHrLNZrOZOosctTfV1jTps9ZqG1+jBEGwpfKf7bM9277YC8Whb4WIosSunDKW7c9jS2YJoijh4aZmeMfzOzEoKCgoKFy8TOkZweK9x0nNr2RXThk9m6ADkV5QxatLDpFbbkSlEpjWvy1XJkYogpAKCgqtGg8PD9q3b9/k5Xv27MkPP/xAcHAw3t6nzqINCwsjKSmJoUOHAmCxWNixYwc9e/Y85fLdunVDFEXWrl1rSws/kYYMAav1n/T0zp074+bmRnZ29mmjz
506dbIJ/DWQlJR09p08xfjt2rWzqdz37NmThQsXEh0dfdIECEC7du3QarUkJSURFRUFQGlpKampqTZbBUEgMCiIwvw8ymrM+Op1HM6Qsw5MVjmToU/vXiz65eczftZn42yf7dn2xV4oKvetiLJaCz9uz+H2r7bz3O/72ZRRjChKxId68dSlnXHXtr42DQoKCgoK9sXXoGNCveL9d0ln7ksvSRJ/7j3Owz/vIbfcSJCXGy9P7saUXm0UZ15BQUHhX9xwww0EBgYyadIk1q9fT2ZmJmvWrOHee+/l6NGjANx33328/PLLLFq0iEOHDnHnnXeesYd8dHQ006dP55ZbbmHRokW2bf74448AtG3bFkEQ+PPPPyksLKSqqgovLy8eeughHnjgARYsWEBGRgY7d+7k3XffZcECucxh1qxZpKWl8fDDD5OSksK3337LV199dcb9+/PPP7nxxhv5888/SU1NJSUlhddff52//vqLSZMmAXDXXXdRUlLCddddx7Zt28jIyGDp0qXcfPPNWK1WPD09ufXWW3n44YdZtWoVycnJzJgxA5WqsRs7auRIvvn8I5L37GbNxi3MmvX/9u49Lqoy/wP4Z5iBARxgFIQBEWG9rKCIIklIFxVM2kpNs41IMW+/uLTgLVEjbSlRi0zNJWtLe21truY1zS0WBZIUFRE1EcjwEhfHGwIiijPP/tF6fg0iTgqMA5/36zWvF+ecx3O+z5yvw3x5znnOK7+OtgtAqZAjcsL4u77Xd3O39/ZufWkpLOjbgGs3dEjeeQIzthTj85wzOFd1HR2Ucjzj54oPXhyAd8b58Tm/RER0z8b4u8NSLsOJimrk/3Kl0TZ19Tq8+10hVmf+jJs6gUCvTlj+Qn94u3LuFiKixtja2iIrKwseHh4YM2YMvL29MXnyZNTV1UmjyDNnzsT48eMRGRmJoKAg2NnZ4dlnn21yv6mpqXjuuecQHR2N3r17Y+rUqdKIeJcuXfDmm28iISEBLi4uiI2NBQAkJSUhMTERycnJ8Pb2RlhYGHbs2CE9Xs7DwwMbN27Eli1b4Ofnh9WrVyMpKanJOHx8fGBra4uZM2eif//+ePjhh7F+/Xr8/e9/x/jx4wEAbm5uyM7Ohk6nwxNPPAFfX1/Ex8dDrVZLRfs777yDRx99FM888wxCQ0PxyCOPYODAgQbHSklJgYeHB14cFYYpE8djxoyZsLGxBQCorOVGvdfGaOq9NaYvLUEmmuvhsm1UVVUVHBwccOXKlXu+PKOlCSHw6pd5+KniCvp5dMKTfV0xuIejNOMjtU96vR5arRbOzs4t+iFC5oM5QY0xNi8+yjqJr/PL4eNqj8VjfQ3upT9XVYe3dxSg5MJVWFjIMCnYEyP9OG+LueJnBTXUmjlRV1eHkpISeHl5wdq6ee5tp+YnhMDNmzehUChM8lk/ZMgQ9O/fH++//75BTKcv1qJep0enDla4XFsPIQQ8HG0f2LqoqXw3tg7lPfRtgEwmw/895oUbV6swoGdX/vIlIqJmN9bfHf8+VoHj5VU48ssV+HVVAwDyz1Ziyb9PoLruJtS2lpgT1ptXhRERUav7dcZ7K5yrqsOl2huAAKwUFrCSt+3aqG33rh3p4+aALg73PwslERFRYxxVSjzRRwMAWHfg13vptx4uxRtbj6G67iZ6OKvw3vP9WcwTEZHJ2FkrYCm3AP53DfqtGfDbMo7QExERkVGeG+iOb3+swLHSKry+5RiO/O9++qG9nREztPsDe0kjERG1LRkZGY2ul8lk6NjBEtqq6wAAlbLtl7scoSciIiKjOKmUGO7jAgA48ssVWMiAKY96YXpoTxbzRET0QLCztoStlQIqawWUirZf7rb9P1kQERFRsxk3sCu+L7oAmQyYE9ZbupeeiKg5cd5uulcWMhm6dLQxdRhGaY48Z0FPRERERutsp8TqCQOhVFhwVJ6Imp2lpSUAoLa2FjY25lGUEd2r2tpaAP+f9/eCBT0RERH9LvbW9/7Fg
4ioKXK5HGq1GlqtFsCvz2pv65OamSNTP7bO3AkhUFtbC61WC7VaDbn83v9AzoKeiIiIiIgeGBrNr0/UuFXU04NHCAG9Xg8LCwsW9PdBrVZL+X6vWNATEREREdEDQyaTwdXVFc7Ozqivrzd1ONQIvV6PixcvwtHRERYWbX/iuZZgaWl5XyPzt7CgJyIiIiKiB45cLm+Wgoean16vh6WlJaytrVnQmxjffSIiIiIiIiIzxIKeiIiIiIiIyAyxoCciIiIiIiIyQ7yH/i6EEACAqqoqE0fSNL1ej+rqat7HQhLmBDXEnKDGMC+oIeYENcScoIaYEy3vVv15qx69Exb0d1FdXQ0A6Nq1q4kjISIiIiIiovakuroaDg4Od9wuE3cr+ds5vV6PsrIy2NnZPdDPWKyqqkLXrl1x9uxZ2NvbmzocegAwJ6gh5gQ1hnlBDTEnqCHmBDXEnGh5QghUV1fDzc2tyasgOEJ/FxYWFnB3dzd1GEazt7fnfyoywJyghpgT1BjmBTXEnKCGmBPUEHOiZTU1Mn8Lb3ggIiIiIiIiMkMs6ImIiIiIiIjMEAv6NkKpVGLBggVQKpWmDoUeEMwJaog5QY1hXlBDzAlqiDlBDTEnHhycFI+IiIiIiIjIDHGEnoiIiIiIiMgMsaAnIiIiIiIiMkMs6ImIiIiIiIjMEAt6IiIiIiIiIjPEgr6NWLVqFTw9PWFtbY3AwEDs37/f1CFRK0lOTsZDDz0EOzs7ODs7Y/To0SgsLDRoU1dXh5iYGDg6OkKlUmHs2LE4d+6ciSKm1rR48WLIZDLEx8dL65gP7VNpaSleeuklODo6wsbGBr6+vjh48KC0XQiBN954A66urrCxsUFoaCiKi4tNGDG1JJ1Oh8TERHh5ecHGxgbdu3dHUlISfjtXMnOibcvKysIzzzwDNzc3yGQybNmyxWC7Mef/0qVLiIiIgL29PdRqNSZPnoyamppW7AU1p6Zyor6+HnPmzIGvry86dOgANzc3TJgwAWVlZQb7YE60Phb0bcC//vUvzJgxAwsWLMChQ4fg5+eHESNGQKvVmjo0agWZmZmIiYnBvn37kJaWhvr6ejzxxBO4evWq1Gb69On4+uuvsWHDBmRmZqKsrAxjxowxYdTUGg4cOIDVq1ejX79+BuuZD+3P5cuXERwcDEtLS+zcuRPHjx9HSkoKOnbsKLVZunQpVqxYgQ8//BA5OTno0KEDRowYgbq6OhNGTi1lyZIlSE1NxQcffICCggIsWbIES5cuxcqVK6U2zIm27erVq/Dz88OqVasa3W7M+Y+IiMCPP/6ItLQ0bN++HVlZWZg2bVprdYGaWVM5UVtbi0OHDiExMRGHDh3Cpk2bUFhYiJEjRxq0Y06YgCCzN2jQIBETEyMt63Q64ebmJpKTk00YFZmKVqsVAERmZqYQQojKykphaWkpNmzYILUpKCgQAMTevXtNFSa1sOrqatGzZ0+RlpYmHn/8cREXFyeEYD60V3PmzBGPPPLIHbfr9Xqh0WjEO++8I62rrKwUSqVSfPnll60RIrWyp556SkyaNMlg3ZgxY0RERIQQgjnR3gAQmzdvlpaNOf/Hjx8XAMSBAwekNjt37hQymUyUlpa2WuzUMhrmRGP2798vAIjTp08LIZgTpsIRejN348YN5ObmIjQ0VFpnYWGB0NBQ7N2714SRkalcuXIFANCpUycAQG5uLurr6w1ypHfv3vDw8GCOtGExMTF46qmnDM47wHxor7Zt24aAgACMGzcOzs7OGDBgAD7++GNpe0lJCSoqKgzywsHBAYGBgcyLNmrw4MFIT09HUVERACA/Px979uzBk08+CYA50d4Zc/737t0LtVqNgIAAqU1oaCgsLCyQk5PT6jFT67ty5QpkMhnUajUA5oSpKEwdAN2fCxcuQKfTwcXFxWC9i4sLTpw4YaKoyFT0ej3i4+MRHByMvn37AgAqKipgZWUlfdje4uLigoqKChNESS1t3bp1OHToEA4cOHDbN
uZD+/Tzzz8jNTUVM2bMwLx583DgwAH85S9/gZWVFSIjI6Vz39jvEuZF25SQkICqqir07t0bcrkcOp0Ob7/9NiIiIgCAOdHOGXP+Kyoq4OzsbLBdoVCgU6dOzJF2oK6uDnPmzEF4eDjs7e0BMCdMhQU9URsSExODY8eOYc+ePaYOhUzk7NmziIuLQ1paGqytrU0dDj0g9Ho9AgICsGjRIgDAgAEDcOzYMXz44YeIjIw0cXRkCuvXr8cXX3yBf/7zn+jTpw8OHz6M+Ph4uLm5MSeIqEn19fV4/vnnIYRAamqqqcNp93jJvZlzcnKCXC6/bYbqc+fOQaPRmCgqMoXY2Fhs374du3fvhru7u7Reo9Hgxo0bqKysNGjPHGmbcnNzodVq4e/vD4VCAYVCgczMTKxYsQIKhQIuLi7Mh3bI1dUVPj4+Buu8vb1x5swZAJDOPX+XtB+zZ89GQkICXnjhBfj6+mL8+PGYPn06kpOTATAn2jtjzr9Go7ltAuabN2/i0qVLzJE27FYxf/r0aaSlpUmj8wBzwlRY0Js5KysrDBw4EOnp6dI6vV6P9PR0BAUFmTAyai1CCMTGxmLz5s3YtWsXvLy8DLYPHDgQlpaWBjlSWFiIM2fOMEfaoJCQEBw9ehSHDx+WXgEBAYiIiJB+Zj60P8HBwbc9zrKoqAjdunUDAHh5eUGj0RjkRVVVFXJycpgXbVRtbS0sLAy/Bsrlcuj1egDMifbOmPMfFBSEyspK5ObmSm127doFvV6PwMDAVo+ZWt6tYr64uBj/+c9/4OjoaLCdOWEipp6Vj+7funXrhFKpFGvXrhXHjx8X06ZNE2q1WlRUVJg6NGoFUVFRwsHBQWRkZIjy8nLpVVtbK7V55ZVXhIeHh9i1a5c4ePCgCAoKEkFBQSaMmlrTb2e5F4L50B7t379fKBQK8fbbb4vi4mLxxRdfCFtbW/H5559LbRYvXizUarXYunWrOHLkiBg1apTw8vIS165dM2Hk1FIiIyNFly5dxPbt20VJSYnYtGmTcHJyEq+99prUhjnRtlVXV4u8vDyRl5cnAIj33ntP5OXlSTOWG3P+w8LCxIABA0ROTo7Ys2eP6NmzpwgPDzdVl+g+NZUTN27cECNHjhTu7u7i8OHDBt85r1+/Lu2DOdH6WNC3EStXrhQeHh7CyspKDBo0SOzbt8/UIVErAdDoa82aNVKba9euiejoaNGxY0dha2srnn32WVFeXm66oKlVNSzomQ/t09dffy369u0rlEql6N27t/joo48Mtuv1epGYmChcXFyEUqkUISEhorCw0ETRUkurqqoScXFxwsPDQ1hbW4s//OEPYv78+QZfzJkTbdvu3bsb/f4QGRkphDDu/F+8eFGEh4cLlUol7O3txcsvvyyqq6tN0BtqDk3lRElJyR2/c+7evVvaB3Oi9cmEEKL1rgcgIiIiIiIioubAe+iJiIiIiIiIzBALeiIiIiIiIiIzxIKeiIiIiIiIyAyxoCciIiIiIiIyQyzoiYiIiIiIiMwQC3oiIiIiIiIiM8SCnoiIiIiIiMgMsaAnIiIiIiIiMkMs6ImIiKjNmzhxIkaPHt1km4yMDMhkMlRWVrZKTERERPeLBT0REVErOH/+PKKiouDh4QGlUgmNRoMRI0YgOzvb1KE9MGQymfRycHBAcHAwdu3a1Sz7Xr58OdauXSstDxkyBPHx8QZtBg8ejPLycjg4ODTLMYmIiFoaC3oiIqJWMHbsWOTl5eGzzz5DUVERtm3bhiFDhuDixYumDu2BsmbNGpSXlyM7OxtOTk54+umn8fPPP9/3fh0cHKBWq5tsY2VlBY1GA5lMdt/HIyIiag0s6ImIiFpYZWUlvv/+eyxZsgRDhw5Ft27dMGjQIMydOxcjR440aDdlyhR07twZ9vb2GDZsGPLz8w32tXjxYri4uMDOzg6TJ09GQkIC+vfvL21vbOR59OjRmDhxorR8/fp1zJo1C126dEGHDh0QGBiIjIwMafvat
WuhVqvx7bffwtvbGyqVCmFhYSgvLzfY76effoo+ffpAqVTC1dUVsbGxv6svjVGr1dBoNOjbty9SU1Nx7do1pKWlAQAyMzMxaNAg6XgJCQm4efOm9G+/+uor+Pr6wsbGBo6OjggNDcXVq1cBGF5yP3HiRGRmZmL58uXSFQGnTp1q9JL7jRs3Sn309PRESkqKQbyenp5YtGgRJk2aBDs7O3h4eOCjjz66az+JiIiaAwt6IiKiFqZSqaBSqbBlyxZcv379ju3GjRsHrVaLnTt3Ijc3F/7+/ggJCcGlS5cAAOvXr8fChQuxaNEiHDx4EK6urvjb3/72u+OJjY3F3r17sW7dOhw5cgTjxo1DWFgYiouLpTa1tbV499138Y9//ANZWVk4c+YMZs2aJW1PTU1FTEwMpk2bhqNHj2Lbtm3o0aOH0X0xho2NDQDgxo0bKC0txZ/+9Cc89NBDyM/PR2pqKj755BO89dZbAIDy8nKEh4dj0qRJKCgoQEZGBsaMGQMhxG37Xb58OYKCgjB16lSUl5ejvLwcXbt2va1dbm4unn/+ebzwwgs4evQoFi5ciMTERINL9wEgJSUFAQEByMvLQ3R0NKKiolBYWGh0P4mIiO6ZICIiohb31VdfiY4dOwpra2sxePBgMXfuXJGfny9t//7774W9vb2oq6sz+Hfdu3cXq1evFkIIERQUJKKjow22BwYGCj8/P2n58ccfF3FxcQZtRo0aJSIjI4UQQpw+fVrI5XJRWlpq0CYkJETMnTtXCCHEmjVrBADx008/SdtXrVolXFxcpGU3Nzcxf/78RvtqTF8aA0Bs3rxZCCHE1atXRXR0tJDL5SI/P1/MmzdP/PGPfxR6vd4gJpVKJXQ6ncjNzRUAxKlTpxrdd2RkpBg1apS03Nj7tHv3bgFAXL58WQghxIsvviiGDx9u0Gb27NnCx8dHWu7WrZt46aWXpGW9Xi+cnZ1FamrqHftJRETUXDhCT0RE1ArGjh2LsrIybNu2DWFhYcjIyIC/v7802pufn4+amho4OjpKI/oqlQolJSU4efIkAKCgoACBgYEG+w0KCvpdcRw9ehQ6nQ69evUyOE5mZqZ0HACwtbVF9+7dpWVXV1dotVoAgFarRVlZGUJCQho9hjF9uZPw8HCoVCrY2dlh48aN+OSTT9CvXz8UFBQgKCjI4P724OBg1NTU4JdffoGfnx9CQkLg6+uLcePG4eOPP8bly5d/13vTUEFBAYKDgw3WBQcHo7i4GDqdTlrXr18/6WeZTAaNRiO9V0RERC1JYeoAiIiI2gtra2sMHz4cw4cPR2JiIqZMmYIFCxZg4sSJqKmpgaurq8G97LfcbTK337KwsLjtMvP6+nrp55qaGsjlcuTm5kIulxu0U6lU0s+WlpYG22QymbTfW5fC38n99GXZsmUIDQ2Fg4MDOnfu3GTb35LL5UhLS8MPP/yA7777DitXrsT8+fORk5MDLy8vo/dzLxp7r/R6fYsek4iICOA99ERERCbj4+MjTdrm7++PiooKKBQK9OjRw+Dl5OQEAPD29kZOTo7BPvbt22ew3LlzZ4PJ63Q6HY4dOyYtDxgwADqdDlqt9rbjaDQao+K2s7ODp6cn0tPTG91uTF/uRKPRoEePHrcV897e3ti7d6/BHyuys7NhZ2cHd3d3AL8W0sHBwXjzzTeRl5cHKysrbN68udHjWFlZGYyyN8bb2/u2xwpmZ2ejV69et/0xhIiIyBRY0BMREbWwixcvYtiwYfj8889x5MgRlJSUYMOGDVi6dClGjRoFAAgNDUVQUBBGjx6N7777DqdOncIPP/yA+fPn4+DBgwCAuLg4fPrpp1izZg2KioqwYMEC/PjjjwbHGjZsGHbs2IEdO3bgxIkTiIqKMpi1vVevXoiIiMCECROwadMmlJSUYP/+/UhOTsaOHTuM7tPChQuRkpKCFStWoLi4GIcOHcLKlSuN7svvFR0djbNnz+LVV1/FiRMnsHXrVixYsAAzZsyAh
YUFcnJypMkCz5w5g02bNuH8+fPw9vZudH+enp7IycnBqVOncOHChUZH1GfOnIn09HQkJSWhqKgIn332GT744AODyQGJiIhMiZfcExERtTCVSoXAwEAsW7YMJ0+eRH19Pbp27YqpU6di3rx5AH4dXf7mm28wf/58vPzyyzh//jw0Gg0ee+wxuLi4AAD+/Oc/4+TJk3jttddQV1eHsWPHIioqCt9++610rEmTJiE/Px8TJkyAQqHA9OnTMXToUIN41qxZg7feegszZ85EaWkpnJyc8PDDD+Ppp582uk+RkZGoq6vDsmXLMGvWLDg5OeG5554zui+/V5cuXfDNN99g9uzZ8PPzQ6dOnTB58mS8/vrrAAB7e3tkZWXh/fffR1VVFbp164aUlBQ8+eSTje5v1qxZiIyMhI+PD65du4aSkpLb2vj7+2P9+vV44403kJSUBFdXV/z1r381eAQgERGRKclEwxvtiIiIyGwsXLgQW7ZsweHDh00dChEREbUyXnJPREREREREZIZY0BMRERERERGZIV5yT0RERERERGSGOEJPREREREREZIZY0BMRERERERGZIRb0RERERERERGaIBT0RERERERGRGWJBT0RERERERGSGWNATERERERERmSEW9ERERERERERmiAU9ERERERERkRn6L01O1Rk5Yk+wAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_test, y_test = next(dataset(1, 1))\n", "y_pred = model(x_test)\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.subplot(2, 1, 1)\n", "plt.plot(x_test[0], label='Input Sequence', alpha=0.8)\n", "plt.plot(y_test[0], label='Target Sequence', alpha=0.8)\n", "plt.plot(y_pred[0], label='Predicted Sequence', linestyle='--', alpha=0.8)\n", "plt.xlabel('Sequence Position')\n", "plt.ylabel('Value')\n", "plt.title('Example Sequence Prediction')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)" ] },
{ "cell_type": "markdown", "metadata": { "id": "2kmQDk0Fh3Q_" }, "source": [ "## 5. Fully Sharded Data Parallelism (FSDP)" ] },
{ "cell_type": "markdown", "metadata": { "id": "L8qlRjWzz3Gs" }, "source": [ "### 5.1 What limits Data Parallelism?\n", "\n", "Data Parallelism involves a lot of duplicated work and memory. After the gradients are AllReduced, every device stores and updates:\n", "\n", "- the full optimizer state (identical on all devices)\n", "- the full set of model parameters (identical on all devices)" ] },
{ "cell_type": "markdown", "metadata": { "id": "SmsCdxzEIsbd" }, "source": [ "### 5.2 ZeRO (Zero Redundancy Optimizer)\n", "\n", "The paper *ZeRO: Memory Optimizations Toward Training Trillion Parameter Models* (linked below) removes this redundancy by sharding not just the data along the batch axis, but also the optimizer states (e.g. the momentum vector), the gradients, and the parameters.\n", "\n", "The rows of the figure below show the per-device memory consumption of:\n", "1. **Pure Data Parallelism**: Parameters, gradients, and optimizer states fully replicated\n", "2. **ZeRO-1**: Optimizer states sharded\n", "3. **ZeRO-2**: Optimizer states and gradients sharded\n", "4. **ZeRO-3**: Parameters, gradients, and optimizer states all sharded\n", "\n", "![FSDP Memory Comparison](https://github.com/jax-ml/scaling-book/blob/main/assets/img/fsdp-figure.png?raw=true)\n", "\n", " Image Source: [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054) \n", "\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "i9SmDdu0z_KK" }, "source": [ "### 5.3 FSDP Theory\n", "\n", "**Definition**: ZeRO-3 is also known as Fully Sharded Data Parallelism (FSDP). Activations are sharded along the batch dimension, while weights, gradients, and optimizer states are sharded along one of their own dimensions over the same device axis. Weights are gathered just-in-time before use.\n", "\n", "**Mathematical representation**:\n", "$$\\text{In}[B_X, D] \\cdot_D W_{\\text{in}}[D_X, F] \\cdot_F W_{\\text{out}}[F, D_X] \\rightarrow \\text{Out}[B_X, D]$$\n", "\n", "where the batch dimension $B$ of the activations and the $D$ dimension of both weight matrices are sharded over the same mesh axis $X$ (i.e. across $X$ devices).\n", "\n", "![FSDP](https://github.com/jax-ml/scaling-book/blob/main/assets/img/fsdp.png?raw=true)" ] },
{ "cell_type": "markdown", "metadata": { "id": "D1Dl8cbLIsbc" }, "source": [ " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] },
{ "cell_type": "markdown", "metadata": { "id": "KFqILDPodUce" }, "source": [ "### 5.4 FSDP Algorithm\n", "\n", "As mentioned earlier, data parallelism involves a lot of duplicated work. Fully Sharded Data Parallelism (FSDP) addresses this by sharding the model weights, gradients, and optimizer states (in addition to the activations) across devices. Instead of AllReducing the full gradients as data parallelism does, FSDP splits the AllReduce into its two constituent collectives: **ReduceScatter** and **AllGather**. See the previous tutorial for why an AllReduce is equivalent to a ReduceScatter followed by an AllGather.\n", "\n", "Splitting the collective allows us to gather the weights from all devices just-in-time before use, and to reduce-scatter the gradients back to the devices that own them after the backward pass. This adds weight AllGathers that data parallelism does not need, but the gradient AllReduce of data parallelism becomes just a ReduceScatter. Since an AllReduce is itself a ReduceScatter plus an AllGather, the overall communication cost is the same as in data parallelism.\n", "\n", "In the steps below, $\\{U_X\\}$ marks a partial sum that is still unreduced over the mesh axis $X$.\n", "\n", "**Forward pass:**\n", "1. $W_{\\text{in}}[D, F] = \\textbf{AllGather}(W_{\\text{in}}[D_X, F])$ (*not on the critical path; can be done during the previous layer*)\n", "2. $\\text{Tmp}[B_X, F] = \\text{In}[B_X, D] \\cdot_D W_{\\text{in}}[D, F]$ (*can throw away $W_{\\text{in}}[D, F]$ now*)\n", "3. $W_{\\text{out}}[F, D] = \\textbf{AllGather}(W_{\\text{out}}[F, D_X])$ (*not on the critical path; can be done during the previous layer*)\n", "4. $\\text{Out}[B_X, D] = \\text{Tmp}[B_X, F] \\cdot_F W_{\\text{out}}[F, D]$\n", "5. $\\text{Loss}[B_X] = \\ldots$\n", "\n", "**Backward pass:**\n", "1. $d\\text{Out}[B_X, D] = \\ldots$\n", "2. $dW_{\\text{out}}[F, D]\\{U_X\\} = \\text{Tmp}[B_X, F] \\cdot_B d\\text{Out}[B_X, D]$\n", "3. $dW_{\\text{out}}[F, D_X] = \\textbf{ReduceScatter}(dW_{\\text{out}}[F, D]\\{U_X\\})$ (*not on the critical path; can be done asynchronously*)\n", "4. $W_{\\text{out}}[F, D] = \\textbf{AllGather}(W_{\\text{out}}[F, D_X])$ (*can be done ahead of time*)\n", "5. $d\\text{Tmp}[B_X, F] = d\\text{Out}[B_X, D] \\cdot_D W_{\\text{out}}[F, D]$ (*can throw away $W_{\\text{out}}[F, D]$ here*)\n", "6. $dW_{\\text{in}}[D, F]\\{U_X\\} = d\\text{Tmp}[B_X, F] \\cdot_B \\text{In}[B_X, D]$\n", "7. $dW_{\\text{in}}[D_X, F] = \\textbf{ReduceScatter}(dW_{\\text{in}}[D, F]\\{U_X\\})$ (*not on the critical path; can be done asynchronously*)\n", "8. $W_{\\text{in}}[D, F] = \\textbf{AllGather}(W_{\\text{in}}[D_X, F])$ (*can be done ahead of time*)\n", "9. $d\\text{In}[B_X, D] = d\\text{Tmp}[B_X, F] \\cdot_F W_{\\text{in}}[D, F]$ (*needed for previous layers; can throw away $W_{\\text{in}}[D, F]$ here*)" ] },
{ "cell_type": "markdown", "metadata": { "id": "8S2s_dEwIsbc" }, "source": [ "![FSDP](https://engineering.fb.com/wp-content/uploads/2021/07/FSDP-Graph-2.png)\n", "\n", " Image Source: [FB Engineering Blog](https://engineering.fb.com/2021/07/15/open-source/fsdp/) " ] },
{ "cell_type": "markdown", "metadata": { "id": "-KATtRorh3Q_" }, "source": [ "### 5.5 FSDP Compute versus Communication Bound\n", "\n", "**Communication Analysis**:\n", "\n", "FSDP has the **same roofline as pure data parallelism** because:\n", "- AllReduce = ReduceScatter + AllGather\n", "- The total communication volume is identical\n", "- The compute-bound condition is the same: $\\frac{B}{X} > \\frac{C}{W_{\\text{ici}}}$, i.e. the per-device batch size must exceed the ratio of per-device FLOPs/s $C$ to ICI bandwidth $W_{\\text{ici}}$ (about 2550 for TPU v5p)" ] },
{ "cell_type": "markdown", "metadata": { "id": "75jVrHq7h3Q_" }, "source": [ "## 6. Fully Sharded Data Parallel (FSDP) Training with Flax NNX\n", "\n", "Now let's implement FSDP using the Flax NNX API."
] }, { "cell_type": "markdown", "metadata": { "id": "p29FesRM0dBw" }, "source": [ "### 6.1 Mesh and Sharding" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "id": "n0JFd67Dh3Q_" }, "outputs": [], "source": [ "# Create a 1D mesh for FSDP\n", "fsdp_mesh = jax.sharding.Mesh(\n", "    mesh_utils.create_device_mesh((8,)),  # 1D mesh\n", "    ('fsdp',)  # Single axis for FSDP\n", ")" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "id": "tQt63uBteJOt" }, "outputs": [], "source": [ "# A helper function to quickly create a NamedSharding object\n", "# using the globally defined 'fsdp_mesh'.\n", "def named_sharding(*names: str | None) -> NamedSharding:\n", "    return NamedSharding(fsdp_mesh, P(*names))" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "id": "mlNWdixaeRlU" }, "outputs": [], "source": [ "# A dataclass to hold sharding rules for different parts of the model/data.\n", "# Makes it easy to manage and change sharding strategies.\n", "@dataclasses.dataclass(unsafe_hash=True)\n", "class FSDPRules:\n", "    \"\"\"Rules for FSDP sharding: weights and data are sharded along the same mesh axis.\"\"\"\n", "    weight_0: str | None = 'fsdp'  # Shard first dimension of weights\n", "    weight_1: str | None = None  # Don't shard second dimension\n", "    bias: str | None = None  # Don't shard biases (they're small)\n", "    data: str | None = 'fsdp'  # Shard data along batch dimension\n", "\n", "    def __call__(self, *keys: str) -> tuple[str | None, ...]:\n", "        return tuple(getattr(self, key) for key in keys)\n", "\n", "fsdp_rules = FSDPRules()" ] },
{ "cell_type": "markdown", "metadata": { "id": "wSiKXyz305_8" }, "source": [ "### 6.2 Build the Sharded Model" ] },
{ "cell_type": "markdown", "metadata": { "id": "APnOCx9SXSnG" }, "source": [ "We attach the sharding rules defined above to each layer's initializers with `nnx.with_metadata`. This annotates every parameter with the mesh axes it should be sharded over; these annotations are later used to tell the JAX compiler how to lay out the parameters across the devices of `fsdp_mesh`."
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "MACPHZwAeg0N" }, "outputs": [], "source": [ "class MLP(nnx.Module):\n", "    def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):\n", "        # Create Linear layers with custom initialization that includes sharding\n", "        self.linear1 = nnx.Linear(\n", "            din, dmid,\n", "            kernel_init=nnx.with_metadata(\n", "                nnx.initializers.lecun_normal(),\n", "                sharding=fsdp_rules('weight_0', 'weight_1')  # ('fsdp', None)\n", "            ),\n", "            bias_init=nnx.with_metadata(\n", "                nnx.initializers.zeros_init(),\n", "                sharding=fsdp_rules('bias')  # (None,)\n", "            ),\n", "            rngs=rngs\n", "        )\n", "\n", "        self.linear2 = nnx.Linear(\n", "            dmid, dout,\n", "            kernel_init=nnx.with_metadata(\n", "                nnx.initializers.lecun_normal(),\n", "                sharding=fsdp_rules('weight_0', 'weight_1')  # ('fsdp', None)\n", "            ),\n", "            bias_init=nnx.with_metadata(\n", "                nnx.initializers.zeros_init(),\n", "                sharding=fsdp_rules('bias')  # (None,)\n", "            ),\n", "            rngs=rngs\n", "        )\n", "\n", "    def __call__(self, x):\n", "        x = nnx.relu(self.linear1(x))\n", "        return self.linear2(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "g7-wKf0C1DxF" }, "source": [ "### 6.2 Build the Sharded Optimizer" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "YZ6MhzDeh3Q_" }, "outputs": [], "source": [ "# Define a custom type for SGD momentum state, inheriting from nnx.Variable.\n", "# This allows it to be tracked as part of the NNX state tree.\n", "class SGDState(nnx.Variable):\n", "    pass\n", "\n", "# Define the SGD optimizer using the NNX API.\n", "class SGD(nnx.Object):\n", "    # Constructor takes the model parameters (as nnx.State), learning rate, and decay.\n", "    def __init__(self, params: nnx.State, lr, decay=0.9):\n", "        # Helper function to initialize momentum buffer for a given parameter.\n", "        def init_optimizer_state(variable: nnx.Variable):\n", "            # Create momentum state with zeros, same shape and metadata (incl.
sharding)\n", "            # as the parameter it corresponds to.\n", "            return SGDState(\n", "                jnp.zeros_like(variable.value), **variable.get_metadata()\n", "            )\n", "\n", "        self.lr = lr\n", "        # Store a reference to the parameter State tree.\n", "        self.params = params\n", "        # Create the momentum state tree, mirroring the structure of 'params',\n", "        # using the helper function. Momentum will have the same sharding as params.\n", "        self.momentum = jax.tree.map(init_optimizer_state, self.params)\n", "        self.decay = decay\n", "\n", "    # Method to update parameters based on gradients.\n", "    def update(self, grads: nnx.State):\n", "        # Define the update logic for a single parameter/momentum/gradient triple.\n", "        def update_fn(\n", "            params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState\n", "        ):\n", "            # SGD with momentum (exponential-moving-average form).\n", "            # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)\n", "            momentum.value = self.decay * momentum.value + (1 - self.decay) * grad.value\n", "            # θ_{t+1} = θ_t - α * v_t\n", "            params.value -= self.lr * momentum.value  # NOTE: Direct mutation of param value!\n", "\n", "        # Apply the update function across the parameter, momentum, and gradient trees.\n", "        # This performs the update in-place on the parameter values referenced by self.params.\n", "        jax.tree.map(update_fn, self.params, self.momentum, grads)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "yuHnHlJy1MY6" }, "source": [ "### 6.3 Enforce Sharding in Model and Optimizer States" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 898 }, "id": "-R9nkEnoesPZ", "outputId": "e2839dc9-0ee9-4c44-8e58-41e2bff2157e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear 1 kernel sharding:\n" ] }, { "data": { "text/html": [ "
                                     TPU 0                                      \n",
              "                                                                                \n",
              "                                     TPU 1                                      \n",
              "                                                                                \n",
              "                                     TPU 2                                      \n",
              "                                                                                \n",
              "                                     TPU 3                                      \n",
              "                                                                                \n",
              "                                     TPU 6                                      \n",
              "                                                                                \n",
              "                                     TPU 7                                      \n",
              "                                                                                \n",
              "                                     TPU 4                                      \n",
              "                                                                                \n",
              "                                     TPU 5                                      \n",
              "                                                                                \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Linear 2 kernel sharding:\n" ] }, { "data": { "text/html": [ "
  TPU 0  \n",
              "         \n",
              "  TPU 1  \n",
              "         \n",
              "  TPU 2  \n",
              "         \n",
              "  TPU 3  \n",
              "         \n",
              "  TPU 6  \n",
              "         \n",
              "  TPU 7  \n",
              "         \n",
              "  TPU 4  \n",
              "         \n",
              "  TPU 5  \n",
              "         \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Linear 1 kernel momentum sharding:\n" ] }, { "data": { "text/html": [ "
                                     TPU 0                                      \n",
              "                                                                                \n",
              "                                     TPU 1                                      \n",
              "                                                                                \n",
              "                                     TPU 2                                      \n",
              "                                                                                \n",
              "                                     TPU 3                                      \n",
              "                                                                                \n",
              "                                     TPU 6                                      \n",
              "                                                                                \n",
              "                                     TPU 7                                      \n",
              "                                                                                \n",
              "                                     TPU 4                                      \n",
              "                                                                                \n",
              "                                     TPU 5                                      \n",
              "                                                                                \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;214;97;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;162;82m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 3\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148mTPU 6\u001b[0m\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;231;203;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207mTPU 7\u001b[0m\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;107;110;207m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148mTPU 4\u001b[0m\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;165;81;148m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49mTPU 5\u001b[0m\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;140;109;49m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# JIT-compile the model and optimizer creation function.\n", "@nnx.jit\n", "def create_model():\n", " # Instantiate the MLP model. 
rngs=nnx.Rngs(0) provides PRNG keys.\n", "    model = MLP(128, 2048, 128, rngs=nnx.Rngs(0))\n", "    # Create the optimizer. nnx.variables(model, nnx.Param) extracts\n", "    # only the nnx.Param state variables from the model object.\n", "    optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)\n", "\n", "    # === Explicit Sharding Application ===\n", "    # 1. Extract ALL state (model params + optimizer momentum) into a nested State pytree.\n", "    state = nnx.state(optimizer)\n", "\n", "    # 2. Define the target sharding for the state pytree.\n", "    # This function maps state paths to NamedSharding objects based on stored metadata.\n", "    def get_named_shardings(path: tuple, value: nnx.VariableState):\n", "        # Assumes params and momentum use the sharding defined in their metadata.\n", "        if path[0] == 'params':\n", "            # value.sharding holds a tuple such as ('fsdp', None) for kernels or (None,) for biases\n", "            return value.replace(NamedSharding(fsdp_mesh, P(*value.sharding)))\n", "        elif path[0] == 'momentum':\n", "            # Momentum states have the same sharding as their corresponding parameters\n", "            return value.replace(NamedSharding(fsdp_mesh, P(*value.sharding)))\n", "        else:\n", "            # Handle other state if necessary (e.g., learning rate if it were a Variable)\n", "            raise ValueError(f'Unknown path: {path}')\n", "\n", "    # Create the pytree of NamedSharding objects.\n", "    named_shardings = state.map(get_named_shardings)\n", "\n", "    # 3. Apply sharding constraint. This tells JAX how the 'state' pytree\n", "    # SHOULD be sharded when computations involving it are run under jit.\n", "    # It doesn't immediately move data but sets up the constraint for the compiler.\n", "    sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)\n", "\n", "    # 4. Update the original objects (model params, optimizer momentum)\n", "    # with the constrained state values.
This step makes the sharding\n", "    # \"stick\" to the objects themselves for subsequent use outside this function.\n", "    nnx.update(optimizer, sharded_state)\n", "\n", "    # Return the model and optimizer objects, now containing sharded state variables.\n", "    return model, optimizer\n", "\n", "# Call the function to create the sharded model and optimizer.\n", "model, optimizer = create_model()\n", "\n", "# Visualize sharding - now using the Linear layer's kernel parameters\n", "print(\"Linear 1 kernel sharding:\")\n", "jax.debug.visualize_array_sharding(model.linear1.kernel.value)\n", "print(\"\\nLinear 2 kernel sharding:\")\n", "jax.debug.visualize_array_sharding(model.linear2.kernel.value)\n", "print(\"\\nLinear 1 kernel momentum state sharding:\")\n", "jax.debug.visualize_array_sharding(optimizer.momentum.linear1.kernel.value)" ] }, { "cell_type": "markdown", "metadata": { "id": "BU2toONE1Yuh" }, "source": [ "### 6.4 Distributed Training" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "gayPl3wvevDX" }, "outputs": [], "source": [ "# JIT-compile the training step function (same as before)\n", "@nnx.jit\n", "def train_step(model: MLP, optimizer: SGD, x, y):\n", "    def loss_fn(model):\n", "        y_pred = model(x)\n", "        loss = jnp.mean((y - y_pred) ** 2)\n", "        return loss\n", "\n", "    # Calculate loss and gradients w.r.t. the model's state (its nnx.Param variables).\n", "    # 'grad' will be an nnx.State object mirroring the model's Param structure.\n", "    loss, grad = nnx.value_and_grad(loss_fn)(model)\n", "\n", "    optimizer.update(grad)\n", "\n", "    return loss" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "mGJOwDhQeyAa" }, "outputs": [], "source": [ "def dataset(steps, batch_size):\n", "    \"\"\"Generate batches of 128-dimensional sequences with an underlying pattern.\"\"\"\n", "    for _ in range(steps):\n", "        # Each input sequence is a random mix of three shared sinusoids plus noise;\n", "        # the target is a shifted, non-linearly transformed version of the input.\n",
"\n", "        # Generate base signal\n", "        t = np.linspace(0, 4*np.pi, 128)\n", "        base_patterns = np.array([\n", "            np.sin(t + np.random.uniform(0, 2*np.pi)),\n", "            np.cos(2*t + np.random.uniform(0, 2*np.pi)),\n", "            np.sin(3*t + np.random.uniform(0, 2*np.pi))\n", "        ])\n", "\n", "        # Create batch of sequences\n", "        x = np.zeros((batch_size, 128))\n", "        y = np.zeros((batch_size, 128))\n", "\n", "        for i in range(batch_size):\n", "            # Mix base patterns with random weights\n", "            weights = np.random.randn(3)\n", "            signal = np.sum(weights[:, np.newaxis] * base_patterns, axis=0)\n", "\n", "            # Add noise\n", "            x[i] = signal + np.random.normal(0, 0.1, 128)\n", "\n", "            # Target: input shifted by 5 positions plus a quadratic term, with noise\n", "            y[i] = np.roll(x[i], 5) * 0.8 + 0.1 * x[i]**2\n", "            y[i] += np.random.normal(0, 0.05, 128)\n", "\n", "        yield x.astype(np.float32), y.astype(np.float32)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1bBXHvjoYwmV", "outputId": "23151cf7-c8b6-4859-fac9-c771a232b1bb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 0: Loss = 1.761734962463379\n", "Step 100: Loss = 0.13825155794620514\n", "Step 200: Loss = 0.09147088974714279\n", "Step 300: Loss = 0.0810324028134346\n", "Step 400: Loss = 0.06444019079208374\n", "Step 500: Loss = 0.05355063080787659\n" ] } ], "source": [ "# --- Training Loop ---\n", "losses = []  # To store loss values for plotting\n", "# Iterate through the dataset generator for 501 steps (0 to 500).\n", "for step, (x_batch, y_batch) in enumerate(\n", "    dataset(batch_size=8192, steps=501)\n", "):\n", "    # CRITICAL: Place the NumPy data onto JAX devices AND apply sharding.\n", "    # named_sharding('fsdp') -> shard the batch (first) dimension along the 'fsdp' mesh axis (8 devices).\n", "    # Each device along the 'fsdp' axis gets a 1/8 slice of the batch (1024 rows).\n", "    x_batch, y_batch = jax.device_put((x_batch,
y_batch), named_sharding('fsdp'))\n", "\n", " # Execute the JIT-compiled training step with the sharded model, optimizer, and data.\n", " loss = train_step(model, optimizer, x_batch, y_batch)\n", "\n", " # Record the loss (move scalar loss back to host CPU).\n", " losses.append(float(loss))\n", " # Log progress periodically.\n", " if step % 100 == 0:\n", " print(f'Step {step}: Loss = {loss}')" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 490 }, "id": "XL_FOFuzYy0i", "outputId": "ce2bbbc6-25ee-4b80-8958-bad0e7413cf9" }, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'MSE Loss')" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkAAAAHHCAYAAABXx+fLAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAV/1JREFUeJzt3XtcVGX+B/DPzMDMcB3uNwUB7ze8i1imrSiSuVpuqZvrpbvWdqGb9Ftv1a7pbq21mbaVodZqVoqVShqFZqGmgvcLKgLK/TowwAzMnN8fyNERVECYMzCf9+s1r9/MOc85fOfQTz77PM95jkwQBAFERERENkQudQFERERElsYARERERDaHAYiIiIhsDgMQERER2RwGICIiIrI5DEBERERkcxiAiIiIyOYwABEREZHNYQAiIiIim8MARESSmzNnDoKDg1t07JIlSyCTyVq3ICLq8BiAiOimZDJZk15JSUlSlyqJOXPmwNnZWeoyiKgFZHwWGBHdzOeff272ef369di9ezc2bNhgtn3cuHHw9fVt8c+pqamByWSCSqVq9rG1tbWora2FWq1u8c9vqTlz5uDrr79GRUWFxX82Ed0ZO6kLICLrNXPmTLPP+/fvx+7duxtsv1FlZSUcHR2b/HPs7e1bVB8A2NnZwc6O/5QRUfNwCIyI7siYMWPQr18/HD58GPfccw8cHR3x+uuvAwC2bduGiRMnIiAgACqVCl27dsWbb74Jo9Fodo4b5wBdunQJMpkM//rXv/Df//4XXbt2hUqlwrBhw/D777+bHdvYHCCZTIZnn30W8fHx6NevH1QqFfr27YuEhIQG9SclJWHo0KFQq9Xo2rUrPvroo1afV/TVV19hyJAhcHBwgJeXF2bOnIkrV66YtcnNzcXcuXPRuXNnqFQq+Pv7Y/Lkybh06ZLY5tChQ4iKioKXlxccHBwQEhKCRx99tNXqJLIl/J9NRHTHioqKEB0djenTp2PmzJnicFhcXBycnZ0RExMDZ2dn/PTTT1i0aBG0Wi3++c9/3va8//vf/1BeXo6nnnoKMpkMK1aswIMPPoiLFy/ettdo37592LJlC+bPnw8XFxe8//77mDp1KjIzM+Hp6QkASElJwYQJE+Dv74+lS5fCaDTijTfegLe3951flKvi4uIwd+5cDBs2DMuWLUNeXh7ee+89/Prrr0hJSYGbmxsAYOrUqTh58iT++te/Ijg4GPn5+d
i9ezcyMzPFz+PHj4e3tzcWLFgANzc3XLp0CVu2bGm1WolsikBE1ETPPPOMcOM/G6NHjxYACGvWrGnQvrKyssG2p556SnB0dBSqq6vFbbNnzxa6dOkifk5PTxcACJ6enkJxcbG4fdu2bQIA4bvvvhO3LV68uEFNAASlUimcP39e3Hb06FEBgPCf//xH3DZp0iTB0dFRuHLlirgtLS1NsLOza3DOxsyePVtwcnK66X6DwSD4+PgI/fr1E6qqqsTt33//vQBAWLRokSAIglBSUiIAEP75z3/e9Fxbt24VAAi///77besiotvjEBgR3TGVSoW5c+c22O7g4CC+Ly8vR2FhIUaNGoXKykqcOXPmtuedNm0a3N3dxc+jRo0CAFy8ePG2x0ZGRqJr167i57CwMLi6uorHGo1G/Pjjj5gyZQoCAgLEdt26dUN0dPRtz98Uhw4dQn5+PubPn282SXvixIno1asXtm/fDqDuOimVSiQlJaGkpKTRc9X3FH3//feoqalplfqIbBkDEBHdsU6dOkGpVDbYfvLkSTzwwAPQaDRwdXWFt7e3OIG6rKzstucNCgoy+1wfhm4WEm51bP3x9cfm5+ejqqoK3bp1a9CusW0tkZGRAQDo2bNng329evUS96tUKixfvhw7d+6Er68v7rnnHqxYsQK5ubli+9GjR2Pq1KlYunQpvLy8MHnyZHz22WfQ6/WtUiuRrWEAIqI7dn1PT73S0lKMHj0aR48exRtvvIHvvvsOu3fvxvLlywEAJpPptudVKBSNbheasHrHnRwrhRdeeAHnzp3DsmXLoFarsXDhQvTu3RspKSkA6iZ2f/3110hOTsazzz6LK1eu4NFHH8WQIUN4Gz5RCzAAEVGbSEpKQlFREeLi4vD888/j/vvvR2RkpNmQlpR8fHygVqtx/vz5Bvsa29YSXbp0AQCcPXu2wb6zZ8+K++t17doVL730Enbt2oUTJ07AYDDgnXfeMWszYsQI/P3vf8ehQ4fwxRdf4OTJk9i0aVOr1EtkSxiAiKhN1PfAXN/jYjAY8OGHH0pVkhmFQoHIyEjEx8cjOztb3H7+/Hns3LmzVX7G0KFD4ePjgzVr1pgNVe3cuROnT5/GxIkTAdStm1RdXW12bNeuXeHi4iIeV1JS0qD3auDAgQDAYTCiFuBt8ETUJkaOHAl3d3fMnj0bzz33HGQyGTZs2GBVQ1BLlizBrl27cNddd2HevHkwGo344IMP0K9fP6SmpjbpHDU1NXjrrbcabPfw8MD8+fOxfPlyzJ07F6NHj8aMGTPE2+CDg4Px4osvAgDOnTuHsWPH4uGHH0afPn1gZ2eHrVu3Ii8vD9OnTwcArFu3Dh9++CEeeOABdO3aFeXl5fj444/h6uqK++67r9WuCZGtYAAiojbh6emJ77//Hi+99BL+9re/wd3dHTNnzsTYsWMRFRUldXkAgCFDhmDnzp14+eWXsXDhQgQGBuKNN97A6dOnm3SXGlDXq7Vw4cIG27t27Yr58+djzpw5cHR0xNtvv43XXnsNTk5OeOCBB7B8+XLxzq7AwEDMmDEDiYmJ2LBhA+zs7NCrVy9s3rwZU6dOBVA3CfrgwYPYtGkT8vLyoNFoMHz4cHzxxRcICQlptWtCZCv4LDAiohtMmTIFJ0+eRFpamtSlEFEb4RwgIrJpVVVVZp/T0tKwY8cOjBkzRpqCiMgi2ANERDbN398fc+bMQWhoKDIyMrB69Wro9XqkpKSge/fuUpdHRG2Ec4CIyKZNmDABGzduRG5uLlQqFSIiIvCPf/yD4Yeog2MPEBEREdkczgEiIiIim8MARERERDaHc4AaYTKZkJ2dDRcXF8hkMqnLISIioiYQBAHl5eUICAiAXH7rPh4GoEZkZ2cjMDBQ6jKIiIioBbKystC5c+dbtmEAaoSLiwuAugvo6uoqcTVERETUFFqtFoGBgeLf8VthAGpE/bCXq6srAxAREVE705TpK5wETURERDaHAYiIiIhsDgMQER
ER2RwGICIiIrI5DEBERERkcxiAiIiIyOYwABEREZHNYQAiIiIim8MARERERDaHAYiIiIhsDgMQERER2RwGICIiIrI5DEASqDIYpS6BiIjIpjEAWdiXv2ei96IE7DieI3UpRERENosByMJe++Y4AGD+F0ckroSIiMh2MQARERGRzWEAkoidXCZ1CURERDaLAUgidgoGICIiIqkwAEnEXsFLT0REJBX+FbYgQRDE9wxARERE0uFfYQsq19eK7zkHiIiISDoMQBZUqqsR3xtNwi1aEhERUVtiALKgkkqD+L7iut4gIiIisiwGIAu6PgDpa02oMZokrIaIiMh2SRqA9u7di0mTJiEgIAAymQzx8fG3bD9nzhzIZLIGr759+4ptlixZ0mB/r1692vibNE1pZY3ZZx17gYiIiCQhaQDS6XQYMGAAVq1a1aT27733HnJycsRXVlYWPDw88NBDD5m169u3r1m7ffv2tUX5zXZ9DxDAYTAiIiKp2En5w6OjoxEdHd3k9hqNBhqNRvwcHx+PkpISzJ0716ydnZ0d/Pz8Wq3O1lJyQw8QAxAREZE02vUcoE8//RSRkZHo0qWL2fa0tDQEBAQgNDQUjzzyCDIzM295Hr1eD61Wa/ZqC4/eFYydz4+C7Ood8BwCIyIikka7DUDZ2dnYuXMnHn/8cbPt4eHhiIuLQ0JCAlavXo309HSMGjUK5eXlNz3XsmXLxN4ljUaDwMDANqnZzVGJ3v6u6O3nCgCo0Bvb5OcQERHRrbXbALRu3Tq4ublhypQpZtujo6Px0EMPISwsDFFRUdixYwdKS0uxefPmm54rNjYWZWVl4isrK6tNa3dW1Y08VlSzB4iIiEgKks4BailBELB27Vr85S9/gVKpvGVbNzc39OjRA+fPn79pG5VKBZVK1dpl3pSzuu6ycwiMiIhIGu2yB2jPnj04f/48Hnvssdu2raiowIULF+Dv72+ByprG3bEutOWXV0tcCRERkW2SNABVVFQgNTUVqampAID09HSkpqaKk5ZjY2Mxa9asBsd9+umnCA8PR79+/Rrse/nll7Fnzx5cunQJv/32Gx544AEoFArMmDGjTb9Lc4R4OQIA0gsrJa6EiIjINkk6BHbo0CHce++94ueYmBgAwOzZsxEXF4ecnJwGd3CVlZXhm2++wXvvvdfoOS9fvowZM2agqKgI3t7euPvuu7F//354e3u33RdpphAvZwDAxcIKiSshIiKyTTJBEPhUzhtotVpoNBqUlZXB1dW11c9/OkeL6Pd+gZujPVIXjW/18xMREdmi5vz9bpdzgNq7YE8nAHWPxijRGW7TmoiIiFobA5AEHJQKBGjUAICLhTqJqyEiIrI9DEASCfGu6wW6WMB5QERERJbGACSREK+6AJTOHiAiIiKLYwCSSOjVO8EYgIiIiCyPAUgi9UNgDEBERESWxwAkkdDrhsBMJq5EQEREZEkMQBLp5OYAe4UM+loTssuqpC6HiIjIpjAAScROIUegR90jMTKK+EgMIiIiS2IAklAnNwcAwJVS9gARERFZEgOQhOoDUDYDEBERkUUxAElI7AEqYQAiIiKyJAYgCQXU9wBxEjQREZFFMQBJqJM7e4CIiIikwAAkIXEOUFk11wIiIiKyIAYgCflp1JDJAEOtCYU6vdTlEBER2QwGIAnZK+TwdVEDALJLqyWuhoiIyHYwAEmM84CIiIgsjwFIYgFcC4iIiMjiGIAkxtWgiYiILI8BSGKd3OrmADEAERERWQ4DkMQ4B4iIiMjyGIAkxtWgiYiILI8BSGL1c4BKK2ug09dKXA0REZFtYACSmIvaHq5qOwC8E4yIiMhSGICsQP0w2GUGICIiIotgALICnd25FhAREZElMQBZgfoeIN4JRkREZBkMQFagE1eDJiIisigGICsQwNWgiYiILIoByApwMUQiIiLLYgCyAl5OKgBAaVWNxJUQERHZBgYgK+CgVAAAqmqMMJkEiashIiLq+BiArICTqi4ACQ
JQXWuUuBoiIqKOjwHICqjtFOJ7nZ4BiIiIqK0xAFkBuVwGx/phMAMDEBERUVuTNADt3bsXkyZNQkBAAGQyGeLj42/ZPikpCTKZrMErNzfXrN2qVasQHBwMtVqN8PBwHDx4sA2/ReuoD0A6Ax+ISkRE1NYkDUA6nQ4DBgzAqlWrmnXc2bNnkZOTI758fHzEfV9++SViYmKwePFiHDlyBAMGDEBUVBTy8/Nbu/xW5aiseyBqJXuAiIiI2pydlD88Ojoa0dHRzT7Ox8cHbm5uje5799138cQTT2Du3LkAgDVr1mD79u1Yu3YtFixYcCfltqn6HqBK9gARERG1uXY5B2jgwIHw9/fHuHHj8Ouvv4rbDQYDDh8+jMjISHGbXC5HZGQkkpOTpSi1ycQhME6CJiIianPtKgD5+/tjzZo1+Oabb/DNN98gMDAQY8aMwZEjRwAAhYWFMBqN8PX1NTvO19e3wTyh6+n1emi1WrOXpTmp6jrjqmrYA0RERNTWJB0Ca66ePXuiZ8+e4ueRI0fiwoUL+Pe//40NGza0+LzLli3D0qVLW6PEFnOwZw8QERGRpbSrHqDGDB8+HOfPnwcAeHl5QaFQIC8vz6xNXl4e/Pz8bnqO2NhYlJWVia+srKw2rbkx9T1AyReKoNOzF4iIiKgttfsAlJqaCn9/fwCAUqnEkCFDkJiYKO43mUxITExERETETc+hUqng6upq9rK0+sdhbD+eg0fjfrf4zyciIrIlkg6BVVRUiL03AJCeno7U1FR4eHggKCgIsbGxuHLlCtavXw8AWLlyJUJCQtC3b19UV1fjk08+wU8//YRdu3aJ54iJicHs2bMxdOhQDB8+HCtXroROpxPvCrNWTsprq0EfSC+WsBIiIqKOT9IAdOjQIdx7773i55iYGADA7NmzERcXh5ycHGRmZor7DQYDXnrpJVy5cgWOjo4ICwvDjz/+aHaOadOmoaCgAIsWLUJubi4GDhyIhISEBhOjrY2Dsl1NxyIiImrXZIIg8PHjN9BqtdBoNCgrK7PYcNhHey5g2c4z4udLb0+0yM8lIiLqKJrz97vdzwHqKBxV7AEiIiKyFAYgK1F/GzwRERG1PQYgK1FVc239H6Udfy1ERERtiX9prYTRaLr23sRpWURERG2JAchKPDikM4I9HQHUBaCa6wIRERERtS4GICvhqrbHDy/eI36+fkiMiIiIWhcDkBVRKuSQyereVzMAERERtRkGICsik8nEu8GqDRwCIyIiaisMQFZGfTUAcQiMiIio7TAAWRmxB4gBiIiIqM0wAFkZlX3dr4Q9QERERG2HAcjKsAeIiIio7TEAWRk1AxAREVGbYwCyMtd6gHgXGBERUVthALIyvAuMiIio7TEAWRn11UnQHAIjIiJqOwxAVsaBPUBERERtjgHIyqg5B4iIiKjNMQBZGQdlXQB6PzENv18qlrgaIiKijokByMrYK2Ti+w9+Oi9hJURERB0XA5CVKa2sEd97OCklrISIiKjjYgCyMjNHdBHfy2S3aEhEREQtxgBkZXr7u+LNyX0BAFUG3glGRETUFhiArBAXQyQiImpbDEBWqP5OMPYAERERtQ0GICvEJ8ITERG1LQYgK8TVoImIiNoWA5AVEofAGICIiIjaBAOQFeIcICIiorbFAGSFxCEwBiAiIqI2wQBkha6fAyQIgsTVEBERdTwMQFZIfXUIzCQABiOfCk9ERNTaGICsUH0PEABUGxiAiIiIWhsDkBWyV8jFp8LzTjAiIqLWxwBkpeofh1FpqJW4EiIioo6HAchKcTFEIiKitsMAZKXq1wLi4zCIiIhan6QBaO/evZg0aRICAgIgk8kQHx9/y/ZbtmzBuHHj4O3tDVdXV0REROCHH34wa7NkyRLIZDKzV69evdrwW7SNa2sBcRI0ERFRa5M0AOl0OgwYMACrVq1qUvu9e/di3Lhx2LFjBw4fPox7770XkyZNQkpKilm7vn37IicnR3zt27evLcpvU3wcBh
ERUduxk/KHR0dHIzo6usntV65cafb5H//4B7Zt24bvvvsOgwYNErfb2dnBz8+vtcqUBOcAERERtZ12PQfIZDKhvLwcHh4eZtvT0tIQEBCA0NBQPPLII8jMzLzlefR6PbRardlLavUBaEXCGUz+YB8q9LwbjIiIqLW06wD0r3/9CxUVFXj44YfFbeHh4YiLi0NCQgJWr16N9PR0jBo1CuXl5Tc9z7Jly6DRaMRXYGCgJcq/pfrVoC+XVOHo5TJ8fzRb4oqIiIg6jnYbgP73v/9h6dKl2Lx5M3x8fMTt0dHReOihhxAWFoaoqCjs2LEDpaWl2Lx5803PFRsbi7KyMvGVlZVlia9wS74uarPP9XOCiIiI6M5JOgeopTZt2oTHH38cX331FSIjI2/Z1s3NDT169MD58+dv2kalUkGlUrV2mXckxMvR7LNS0W6zKhERkdVpd39VN27ciLlz52Ljxo2YOHHibdtXVFTgwoUL8Pf3t0B1rSfEy9nsc3UtJ0MTERG1Fkl7gCoqKsx6ZtLT05GamgoPDw8EBQUhNjYWV65cwfr16wHUDXvNnj0b7733HsLDw5GbmwsAcHBwgEajAQC8/PLLmDRpErp06YLs7GwsXrwYCoUCM2bMsPwXvAPBN/QAVddwPSAiIqLWImkP0KFDhzBo0CDxFvaYmBgMGjQIixYtAgDk5OSY3cH13//+F7W1tXjmmWfg7+8vvp5//nmxzeXLlzFjxgz07NkTDz/8MDw9PbF//354e3tb9svdoQCNg9lnrghNRETUeiTtARozZgwEQbjp/ri4OLPPSUlJtz3npk2b7rAq6yCXy8w+sweIiIio9bS7OUC25OXxPcT37AEiIiJqPQxAVuzZP3TH43eHAGAAIiIiak0MQFZObc+nwhMREbU2BiArp7av+xVxDhAREVHrYQCycmIPENcBIiIiajUMQFZOxSEwIiKiVscAZOXUdhwCIyIiam0MQFaOk6CJiIhaHwOQlbs2B4g9QERERK2FAcjKOVwNQHr2ABEREbUaBiArV38bfBUDEBERUathALJynANERETU+hiArBwXQiQiImp9DEBWTmXHHiAiIqLWxgBk5eqHwPS1JgiCIHE1REREHQMDkJWrHwID6kIQERER3TkGICtX3wMEcBiMiIiotTAAWTl7hRwKuQwAJ0ITERG1FgagdqB+McRKQ63ElRAREXUMDEDtgJujPQCgpNIgcSVEREQdAwNQO+DppAQAFFUwABEREbWGZgegqqoqVFZWip8zMjKwcuVK7Nq1q1ULo2s8rgagYh0DEBERUWtodgCaPHky1q9fDwAoLS1FeHg43nnnHUyePBmrV69u9QIJ8HBSAQCKGICIiIhaRbMD0JEjRzBq1CgAwNdffw1fX19kZGRg/fr1eP/991u9QAI8net6gEoYgIiIiFpFswNQZWUlXFxcAAC7du3Cgw8+CLlcjhEjRiAjI6PVCyQOgREREbW2Zgegbt26IT4+HllZWfjhhx8wfvx4AEB+fj5cXV1bvUC6Ngl6S8oVrPzxnMTVEBERtX/NDkCLFi3Cyy+/jODgYISHhyMiIgJAXW/QoEGDWr1AujYEBgArf0yD0cRnghEREd0Ju+Ye8Kc//Ql33303cnJyMGDAAHH72LFj8cADD7RqcVSnfhJ0vcIKPXxd1RJVQ0RE1P61aB0gPz8/DBo0CHK5HFqtFvHx8XBxcUGvXr1auz7CtSGwetmlVRJVQkRE1DE0OwA9/PDD+OCDDwDUrQk0dOhQPPzwwwgLC8M333zT6gUS4OVs3gOUU1YtUSVEREQdQ7MD0N69e8Xb4Ldu3QpBEFBaWor3338fb731VqsXSICDUoG1c4aKzwRjDxAREdGdaXYAKisrg4eHBwAgISEBU6dOhaOjIyZOnIi0tLRWL5Dq/KGXL2aOCALAHiAiIqI71ewAFBgYiOTkZOh0OiQkJIi3wZeUlECt5sTctuSvcQAA5JSxB4iIiOhONPsusBdeeA
GPPPIInJ2d0aVLF4wZMwZA3dBY//79W7s+uk6AW13AzC5lDxAREdGdaHYAmj9/PoYPH46srCyMGzcOcnldJ1JoaCjnALUxb5e6AFRYoZe4EiIiovat2QEIAIYOHYqhQ4dCEAQIggCZTIaJEye2dm10Axd13a9Lp6+VuBIiIqL2rUXrAK1fvx79+/eHg4MDHBwcEBYWhg0bNjT7PHv37sWkSZMQEBAAmUyG+Pj42x6TlJSEwYMHQ6VSoVu3boiLi2vQZtWqVQgODoZarUZ4eDgOHjzY7NqskbOqPgAZJa6EiIiofWt2AHr33Xcxb9483Hfffdi8eTM2b96MCRMm4Omnn8a///3vZp1Lp9NhwIABWLVqVZPap6enY+LEibj33nuRmpqKF154AY8//jh++OEHsc2XX36JmJgYLF68GEeOHMGAAQMQFRWF/Pz8ZtVmjZyuBiCD0QR9LUMQERFRS8kEQWjWg6VCQkKwdOlSzJo1y2z7unXrsGTJEqSnp7esEJkMW7duxZQpU27a5rXXXsP27dtx4sQJcdv06dNRWlqKhIQEAEB4eDiGDRsmLtZoMpkQGBiIv/71r1iwYEGTatFqtdBoNCgrK7OqB7waTQK6vr4DAHD4b5HwvGGBRCIiIlvWnL/fze4BysnJwciRIxtsHzlyJHJycpp7umZJTk5GZGSk2baoqCgkJycDAAwGAw4fPmzWRi6XIzIyUmzTGL1eD61Wa/ayRgq5DI7KusUQOQxGRETUcs0OQN26dcPmzZsbbP/yyy/RvXv3VinqZnJzc+Hr62u2zdfXF1qtFlVVVSgsLITRaGy0TW5u7k3Pu2zZMmg0GvEVGBjYJvW3hvphsHJ9jcSVEBERtV/Nvgts6dKlmDZtGvbu3Yu77roLAPDrr78iMTGx0WDUHsTGxiImJkb8rNVqrTYEuajsUFCuR0U17wQjIiJqqWYHoKlTp+LAgQP497//Ld611bt3bxw8eBCDBg1q7frM+Pn5IS8vz2xbXl4eXF1d4eDgAIVCAYVC0WgbPz+/m55XpVJBpWof82nqe4B0BgYgIiKilmrRbfBDhgzB559/jsOHD+Pw4cP4/PPP0alTJ/zjH/9o7frMREREIDEx0Wzb7t27ERERAQBQKpUYMmSIWRuTyYTExESxTXtXfyt8OXuAiIiIWqxFAagxOTk5WLhwYbOOqaioQGpqKlJTUwHU3eaempqKzMxMAHVDU9ffbfb000/j4sWLePXVV3HmzBl8+OGH2Lx5M1588UWxTUxMDD7++GOsW7cOp0+fxrx586DT6TB37tw7/5JWwFnNtYCIiIjuVItWgm4thw4dwr333it+rp+HM3v2bMTFxSEnJ0cMQ0DdLfjbt2/Hiy++iPfeew+dO3fGJ598gqioKLHNtGnTUFBQgEWLFiE3NxcDBw5EQkJCg4nR7VV9D1AFJ0ETERG1WLPXAbqZo0ePYvDgwTAa23/PhLWuAwQAC+NPYMP+DDz3h26IGd9T6nKIiIisRpuuA0TSqh8Cq+AQGBERUYs1eQjs+tvEG1NQUHDHxdDtcQiMiIjozjU5AKWkpNy2zT333HNHxdDtXQtAvAuMiIiopZocgH7++ee2rIOaiLfBExER3TnOAWpnvFzqFmzMKauWuBIiIqL2iwGonenh6wwASC/UQV/LidBEREQtwQDUzvi5quGisoPRJCC9UCd1OURERO0SA1A7I5PJ0MPPBQBwNrdc4mqIiIjaJwagdqiHb10AOpfHAERERNQSTQ5AK1asQFVVlfj5119/hV6vFz+Xl5dj/vz5rVsdNap+HtC5vAqJKyEiImqfmhyAYmNjUV5+rcchOjoaV65cET9XVlbio48+at3qqFE92QNERER0R5ocgG58ZFgrPUKMWqB+DlBmcSWqDLwTjIiIqLk4B6gd8nJWwcNJCUEAzudzGIyIiKi5GIDaqfp5QGc5DEZERNRsTX4UBgB88skncHau+8NbW1uLuLg4eHl5AYDZ/CBqez
18XbD/YjHS8nndiYiImqvJASgoKAgff/yx+NnPzw8bNmxo0IYsI8DNAQCQr9XfpiURERHdqMkB6NKlS21YBjWXt3PdM8EKKxiAiIiImotzgNop76sPRS0oZwAiIiJqriYHoOTkZHz//fdm29avX4+QkBD4+PjgySefNFsYkdqWlzMDEBERUUs1OQC98cYbOHnypPj5+PHjeOyxxxAZGYkFCxbgu+++w7Jly9qkSGqovgeouNKAWqNJ4mqIiIjalyYHoNTUVIwdO1b8vGnTJoSHh+Pjjz9GTEwM3n//fWzevLlNiqSGPJyUkMsAQQCKdQapyyEiImpXmhyASkpK4OvrK37es2cPoqOjxc/Dhg1DVlZW61ZHN6WQy+DhVNcLlM9hMCIiomZpcgDy9fVFeno6AMBgMODIkSMYMWKEuL+8vBz29vatXyHdVP0wGO8EIyIiap4mB6D77rsPCxYswC+//ILY2Fg4Ojpi1KhR4v5jx46ha9eubVIkNY53ghEREbVMk9cBevPNN/Hggw9i9OjRcHZ2xrp166BUKsX9a9euxfjx49ukSGqcl3Pd9S+s4BwgIiKi5mhyAPLy8sLevXtRVlYGZ2dnKBQKs/1fffWV+JgMsgz2ABEREbVMs54FBgAajabR7R4eHndcDDVP/WrQBZwDRERE1CxNDkCPPvpok9qtXbu2xcVQ84iToNkDRERE1CxNDkBxcXHo0qULBg0aBEEQ2rImaiL2ABEREbVMkwPQvHnzsHHjRqSnp2Pu3LmYOXMmh70kxjlARERELdPk2+BXrVqFnJwcvPrqq/juu+8QGBiIhx9+GD/88AN7hCRSH4DKqmqgrzVKXA0REVH70aynwatUKsyYMQO7d+/GqVOn0LdvX8yfPx/BwcGoqKhoqxrpJjQO9rBXyAAARbwVnoiIqMmaFYDMDpTLIZPJIAgCjEb2PkhBJpPxqfBEREQt0KwApNfrsXHjRowbNw49evTA8ePH8cEHHyAzM5NrAEmEj8MgIiJqviZPgp4/fz42bdqEwMBAPProo9i4cSO8vLzasjZqAk+nutWgOQRGRETUdE0OQGvWrEFQUBBCQ0OxZ88e7Nmzp9F2W7ZsabXi6PbcHOsCUEaxDpdLKtHZ3VHiioiIiKxfk4fAZs2ahXvvvRdubm7QaDQ3fbXEqlWrEBwcDLVajfDwcBw8ePCmbceMGQOZTNbgNXHiRLHNnDlzGuyfMGFCi2qzdm6O9gCAVT9fwN3Lf0a+tlriioiIiKxfsxZCbAtffvklYmJisGbNGoSHh2PlypWIiorC2bNn4ePj06D9li1bYDBcG+4pKirCgAED8NBDD5m1mzBhAj777DPxs0qlapP6pebmoDT7fPxKGca6qiWqhoiIqH1o8V1greXdd9/FE088gblz56JPnz5Ys2YNHB0db/pIDQ8PD/j5+Ymv3bt3w9HRsUEAUqlUZu3c3d0t8XUszt3J3uyz0k7yXykREZHVk/SvpcFgwOHDhxEZGSluk8vliIyMRHJycpPO8emnn2L69OlwcnIy256UlAQfHx/07NkT8+bNQ1FRUavWbi00DuYBqLrGJFElRERE7UeznwbfmgoLC2E0GuHr62u23dfXF2fOnLnt8QcPHsSJEyfw6aefmm2fMGECHnzwQYSEhODChQt4/fXXER0djeTkZCgUigbn0ev10Ouv3Uau1Wpb+I0sr34SdD2dvlaiSoiIiNoPSQPQnfr000/Rv39/DB8+3Gz79OnTxff9+/dHWFgYunbtiqSkJIwdO7bBeZYtW4alS5e2eb1twd3RvAeonAGIiIjotiQdAvPy8oJCoUBeXp7Z9ry8PPj5+d3yWJ1Oh02bNuGxxx677c8JDQ2Fl5cXzp8/3+j+2NhYlJWVia+srKymfwmJ3TgJmj1AREREtydpAFIqlRgyZAgSExPFbSaTCYmJiYiIiLjlsV999RX0ej1mzpx5259z+fJlFBUVwd/fv9H9KpUKrq6uZq/2wu2GSd
AV1QxAREREtyP5LUMxMTH4+OOPsW7dOpw+fRrz5s2DTqfD3LlzAdStPxQbG9vguE8//RRTpkyBp6en2faKigq88sor2L9/Py5duoTExERMnjwZ3bp1Q1RUlEW+kyW5qMxHMSvYA0RERHRbks8BmjZtGgoKCrBo0SLk5uZi4MCBSEhIECdGZ2ZmQi43z2lnz57Fvn37sGvXrgbnUygUOHbsGNatW4fS0lIEBARg/PjxePPNNzvkWkAymczsM4fAiIiIbk8mCIIgdRHWRqvVQqPRoKysrF0MhwUv2C6+j+7nh9Uzh0hYDRERkTSa8/db8iEwunPvTR8ovucQGBER0e0xAHUAkwd2wkd/qev14RAYERHR7TEAdRDOVydDV+hrUWM04b0f03A0q1TaooiIiKwUA1AHUR+AdHojPv7lIv794zlMXvWrxFURERFZJwagDsLpagAqr67B/ovFEldDRERk3RiAOggX9dUeIIMRpZUGiashIiKybgxAHUR9D5DRJCC3rFriaoiIiKwbA1AH4aRUwO3qg1Hzy/W3aU1ERGTbGIA6CJlMhqFdPBpsN5m4ziUREdGNGIA6kOEh7g22VdcaJaiEiIjIujEAdSDDghv2AOn0DEBEREQ3YgDqQAZ0dkN0Pz+zbZUGrgxNRER0IwagDkQul2H1zCHY/tzdUMjrnhLPHiAiIqKGGIA6oL4BGnR2dwAAVNXUwmQSkKflrfFERET1GIA6KAd7BYC6HqAXN6ci/B+J+CWtQOKqiIiIrAMDUAdVvzBipaEW21KzAQAf7bkoZUlERERWgwGog3JUXusBqqey46+biIgIYADqsJyUV3uAaq4LQPb8dRMREQEMQB1WfQ9QRfW12+BVdgqpyiEiIrIqDEAdlKOqLuzkllWJ25QK/rqJiIgABqAOq34I7ErptQBUYzJJVQ4REZFVYQDqoByuDoFdLrkWgCq5KCIREREABqAOq7EeoOsnRBMREdkyBqAOys3RHgBQft0k6Eo9nwtGREQEMAB1WMFeTg22VRrYA0RERAQwAHVYXTwdG2zjk+GJiIjqMAB1UN7OKnEtoHrsASIiIqrDANRByWQyBLqb9wIxABEREdVhAOrAinR6s8+VhloIgiBRNURERNaDAagDG9fHFwDQ2d0BAGASAH0tF0MkIiJiAOrAFkT3xqsTeiL+mbvEbRwGIyIiYgDq0DQO9pg/phu8nFVQ2dX9qnVcC4iIiIgByFY4qepWhs4v19+mJRERUcfHAGQjaq7O/Zm6+jd89mu6xNUQERFJiwHIRpRfN/T11vbTKNEZJKyGiIhIWgxANuLFyB4Ivfp4DKNJQF55tcQVERERSYcByEY8H9kdP708Bt19nAEARRXsASIiIttlFQFo1apVCA4OhlqtRnh4OA4ePHjTtnFxcZDJZGYvtVpt1kYQBCxatAj+/v5wcHBAZGQk0tLS2vprtAuezkoAQGEFJ0MTEZHtkjwAffnll4iJicHixYtx5MgRDBgwAFFRUcjPz7/pMa6ursjJyRFfGRkZZvtXrFiB999/H2vWrMGBAwfg5OSEqKgoVFdz2MfTWQWAPUBERGTbJA9A7777Lp544gnMnTsXffr0wZo1a+Do6Ii1a9fe9BiZTAY/Pz/x5evrK+4TBAErV67E3/72N0yePBlhYWFYv349srOzER8fb4FvZN286wOQjj1ARERkuyQNQAaDAYcPH0ZkZKS4TS6XIzIyEsnJyTc9rqKiAl26dEFgYCAmT56MkydPivvS09ORm5trdk6NRoPw8PCbnlOv10Or1Zq9OipPp7ohMPYAERGRLZM0ABUWFsJoNJr14ACAr68vcnNzGz2mZ8+eWLt2LbZt24bPP/8cJpMJI0eOxOXLlwFAPK4551y2bBk0Go34CgwMvNOvZrXqh8AKGYCIiMiGST4E1lwRERGYNWsWBg4ciNGjR2PLli3w9vbGRx991OJzxsbGoqysTHxlZWW1YsXWpX4S9I+n8/BLWoHE1R
AREUlD0gDk5eUFhUKBvLw8s+15eXnw8/Nr0jns7e0xaNAgnD9/HgDE45pzTpVKBVdXV7NXR+V1NQABwBPrD0lYCRERkXQkDUBKpRJDhgxBYmKiuM1kMiExMRERERFNOofRaMTx48fh7+8PAAgJCYGfn5/ZObVaLQ4cONDkc3Zknk4q8X11jQlGkyBhNURERNKQfAgsJiYGH3/8MdatW4fTp09j3rx50Ol0mDt3LgBg1qxZiI2NFdu/8cYb2LVrFy5evIgjR45g5syZyMjIwOOPPw6g7g6xF154AW+99Ra+/fZbHD9+HLNmzUJAQACmTJkixVe0KkEejhjf59r8qCKuB0RERDbITuoCpk2bhoKCAixatAi5ubkYOHAgEhISxEnMmZmZkMuv5bSSkhI88cQTyM3Nhbu7O4YMGYLffvsNffr0Edu8+uqr0Ol0ePLJJ1FaWoq7774bCQkJDRZMtEVyuQz/nTUU4f/4EXlaPfK0evi48roQEZFtkQmCwDGQG2i1Wmg0GpSVlXXY+UB//GAfjl0uwyezhmJwF3d4OClvfxAREZEVa87fb8mHwEgaPi51vT4f7b2AwW/uxqf70iWuiIiIyHIYgGyUr2vdZOjfL5UAAN78/pSU5RAREVkUA5CN8uW8HyIismEMQDaqvgeIiIjIFjEA2agANwepSyAiIpIMA5CN6hegabCNNwQSEZGtYACyUe6N3PZerq+VoBIiIiLLYwCyYRoHe7PPxXxCPBER2QgGIBs2oa/5w2GLdAxARERkGxiAbNjf7u+Nx+8Ogb1CBgAoZgAiIiIbwQBkw1zU9vjb/X0wqrs3AD4YlYiIbAcDEInPAeMQGBER2QoGIEKguyMAIL1QJ3ElRERElsEAROjm4wwAOJ9fIXElRERElsEARGIAulBQwcUQiYjIJthJXQBJL9jLEXIZUF5di4JyPbYfz0FuWTVeGt8TSjtmZCIi6ngYgAgqOwWCPBxxqagSRy+XYel3pwAAOWXVeH/GIImrIyIian38n/cEAOjt7woAeOXro+K2b49mQ8fHYxARUQfEAEQAgGf/0A1KhRyllTVm27NKKiWqiIiIqO0wABEAoG+ABrMiujTYnlVcJUE1REREbYsBiETThgU22JZZzB4gIiLqeBiASNTd1wX3h/mjs7sDZgwPAgBkMQAREVEHxLvAyMwHfx4MAPjiQAYABiAiIuqY2ANEjQryqHs8RgYDEBERdUAMQNSoHr4uAOpWh84qrsSqn88jp4wToomIqGPgEBg1ytdVjV5+LjiTW45RK34GABzOKMHaOcMkroyIiOjOsQeIbureXj5mn386k885QURE1CEwANFNTezv32DbqBU/Y++5AgmqISIiaj0MQHRT/TppsPmpCEwd3Nls+4b9GRJVRERE1Do4B4huaXiIB4aHeCDxTJ74mIzqGqPEVREREd0Z9gBRk7z9YH/x/dnccgkrISIiunMMQNQkE/r548TSKABAfrkexTqDxBURERG1HAMQNZmzyg6BHg4AgOc3pWBh/AkIgiBxVURERM3HAETN0i9AAwD4Ja0QG/ZnIKOIt8UTEVH7wwBEzTKki7vZ5wp9rUSVEBERtRwDEDXLsGAPs8/a6hqJKiEiImo5qwhAq1atQnBwMNRqNcLDw3Hw4MGbtv34448xatQouLu7w93dHZGRkQ3az5kzBzKZzOw1YcKEtv4aNqFPgKvZZ23VtR6g6hojfkkrgL6Wt8kTEZF1kzwAffnll4iJicHixYtx5MgRDBgwAFFRUcjPz2+0fVJSEmbMmIGff/4ZycnJCAwMxPjx43HlyhWzdhMmTEBOTo742rhxoyW+Todnr5DjtQm9xM/a6hrUGk0wmgS8vfMM/vLpQaxIOCthhURERLcnEyS+jSc8PBzDhg3DBx98AAAwmUwIDAzEX//6VyxYsOC2xxuNRri7u+ODDz7ArFmzANT1AJWWliI+Pr5FNWm1Wmg0GpSVlcHV1fX2B9ig5zam4Nuj2X
j87hBs+j0LY3v7YFtqtrj/0tsTJayOiIhsUXP+fkvaA2QwGHD48GFERkaK2+RyOSIjI5GcnNykc1RWVqKmpgYeHuZzU5KSkuDj44OePXti3rx5KCoquuk59Ho9tFqt2YtuzdWhbhHxr49cRoW+1iz8EBERWTtJA1BhYSGMRiN8fX3Ntvv6+iI3N7dJ53jttdcQEBBgFqImTJiA9evXIzExEcuXL8eePXsQHR0No7HxuSnLli2DRqMRX4GBgS3/UjbCVW0PAOLjMW5kMnF9ICIisl7t+llgb7/9NjZt2oSkpCSo1Wpx+/Tp08X3/fv3R1hYGLp27YqkpCSMHTu2wXliY2MRExMjftZqtQxBt+HqYH/L/TnaanRyc7BQNURERM0jaQ+Ql5cXFAoF8vLyzLbn5eXBz8/vlsf+61//wttvv41du3YhLCzslm1DQ0Ph5eWF8+fPN7pfpVLB1dXV7EW35qK+dXY+caUMedpqC1VDRETUPJIGIKVSiSFDhiAxMVHcZjKZkJiYiIiIiJset2LFCrz55ptISEjA0KFDb/tzLl++jKKiIvj7+7dK3XRtCAwAlHYN/zN6asNhjFrxMy4V6ixZFhERUZNIfht8TEwMPv74Y6xbtw6nT5/GvHnzoNPpMHfuXADArFmzEBsbK7Zfvnw5Fi5ciLVr1yI4OBi5ubnIzc1FRUUFAKCiogKvvPIK9u/fj0uXLiExMRGTJ09Gt27dEBUVJcl37IiuHwLrF9B4j5mh1oTNh7IsVRIREVGTST4HaNq0aSgoKMCiRYuQm5uLgQMHIiEhQZwYnZmZCbn8Wk5bvXo1DAYD/vSnP5mdZ/HixViyZAkUCgWOHTuGdevWobS0FAEBARg/fjzefPNNqFQqi363jsz1uiGwgYHuGNXdG+8lpkGpkMNgNIn7Pky6gEqDEY+EB6G7r4sUpRIRETUg+TpA1ojrAN3ehYIKjH1nDwDgPzMGYUI/P/ySVgBvZzUmfbCvQXs3R3v8tuAPcFRKnrmJiKiDajfrAFH75WLWA+QGe4Ucf+jl2+BRGfVKK2sQn8K1goiIyDowAFGLeDurENXXF1MGBqCz+7Xb3RVymVm7QA8HzB/TFQDw5e+Z4vb0Qh3WJ1/iekFERCQJjkdQi8hkMnz0l8bvwNvx3Cj8klaAx0eFQoa6NYE+TLqAE9laVNcYobZX4N5/JYnn+cuILpYrnIiICOwBojbQJ8AVT43uCoVcBrlchgCNGl7OShhNAk7l1IWgeqmZpeL7U9laFJTrJaiYiIhsDXuAqM3JZDL076TBz2cL8OCHvyGy97VHn8hlwHdHs7Hut0s4lFECF5UdPpw5GKO6e0tYMRERdXTsASKL6BugEd//ePrayt/704vw3KYUHMooAQCU62ux5NuTFq+PiIhsCwMQWcTY3j6Nbs8qrkL9QgzPje0Oe4UMFwp0uFhQYcHqiIjI1nAIjCxiUJA79seOhYeTEnvOFeBKSSWWfHdK3P/2g/0xfXgQUjJL8EtaIXafysNTo53NzlFYoYe+1oQAjRoymezGH0FERNRk7AEii/HTqKG0k2NcH1/Migg22xfdv+45beP71M0Pik/NhiAISL5QhO3HcnCpUIcx/0zCXW//hBe/TDU7dlvqFTy3MQVVBiOIiIiagj1AJAm5XIYevs44l1eBmHE9oLn6bLFJAwLw1vbTOJ2jRUjsjkaP/fZoNpb+sR80jvYoq6xB7JbjqDQY4aRSYGwvX0T28W30OCIionrsASLJrJw2CO88NADP3ttN3ObmqMTEMP9G2yvt5FDayWESgCkf/orTOVp8fiADlVd7fjYezMLj6w8ht6zaIvUTEVH7xR4gkkyfANdGH53xt4l90NndESNCPVBlMOKxdYcAAFF9/eDjosKn+9KRXqjDnz/ejwp9bYPjz+aVw0+jvunP/fZoNpQKGSb0azxoERFRx8ceILI6Hk5KxIzrgZFdvczWA5o2NBDR/fzEzyWVNagxCrC74fEbZ3
O1OJldhimrfsWWI5fN9mUU6fDcxhQ8/fkRnMsrBwAYak04mF6M2uueYk9ERB0bnwbfCD4N3roculSMK6VVmDywE4C6FaNLqwx4afNR5JRV46VxPfDO7nONHiuTAaffmIDqGiNe33ocO47nivvCQzzQr5MGO47niOd5bFQISitr8N+9F+GgVODVqJ6844yIqJ1ozt9vBqBGMAC1D4ZaE7JLqxDs5YTgBdtv2u7NKf2w62QufkkrvOX5vJxV6OLpiMNXF2UEgG/mjUSJzoBe/i7o7O4Io0lAhb4WaXnlKKwwYMJ1PVK3UlppgNpeAbW9omlfjoiImo0B6A4xALU/R7NK8duFIixPOAMAeCWqJ2qMJqz8Ma1BW4Vchj8PD8KG/Rmwk8sweWAnfHPDUNmN1PZyPDCoM3afykVhhUHc/sXj4birm9dNj6uuMeJcXjmm/3c/+nfSYNOTI9ijRETURprz95uToKlDGBDohgGBbjAJAlR2cjw+KhRGk4ATV7TiozfuD/PHU/d0hc5Qi/AQD0T380OotzP8NGr4uKqwOukCAGBEqAemDOyEBVuOi+evrjFh48HMBj/3ifWHEN3PH1nFlRgR6oH593ZDQbkeB9KLcXc3Lzz8UTIyiysBAAfSi/HaN8dwNq8CwZ6OKNYZEN3PHzOGBzIUERFZGHuAGsEeoI5DX2vEez+m4bcLRXjn4QHo6u3caLtaowlLvjuJn88U4PPHwxHi5YStKZfx85kCRHT1xLIdp6GtroWLyg7Pje2O3y8V4+ez+agxmv+/j5NSAV0zF2R8bUIv9Ovkim2p2fjpTD6mDwvE2N6+6OPvCgdl3ZDZ5ZJKGGpNCL1af43RhBNXytC/kwZ2irp7GYwmAUln8+HupMTgIPfmXioionaPQ2B3iAGIbpSSWYJ//nAWr0T1xKCr4SJPW40fT+ehtLIGFfpabEu5guxG1iCaHdEF3Xyc8fcdp+Fgr0BJZU2TfmaIlxMGBbnhckkVDqYXQyaruxNuVHdvvLP7LC4W6DBnZDCW/LEvBEHAnM9+x55zBQCAe3p4w2gyYcbwIAzo7IafzuTjwcGd4KK2h7a6Bve/vw9BHo7Y8NhwmIS6YcEbCYKAlKxSdHZzgKuDPecvEZHVYwC6QwxA1BK1RhMOZ5SgrKoG4SGeOJWjhYvaDv06aQDUTdq2k8sgADiXVw4fFxVGvv0T9LV1t9/3DXDFyWxts3/uqO5ecHdU4tuj2bdsNzzEAxseG47tx3IQs/koAOD1+3phecJZvDS+B+aP6YbvjmZj5Y/nUKwzQONgj0tFdcN3chngrLJDZG9fjO3ti/TCCsy9KwROqluPou88noPy6loUVxqgkMnw8LBAcdVvIqLWxgB0hxiAyFISTuTibG45nh4TCpWdAi9+mYqtKVcAADHjeuBgejH2na+7e21cH1/kl+txNKsUnk5K1JoElFXdvDfJUakQV8mu9/jdIcjVVuP7YzkN2od6OeFioa7JtQ8OcsOamUMw/4sj8HJWwUllB5W9HAMD3WAnl8FJZYenNhw2O6Z/Jw3WzhmGtPxybEvJxh8HBuCubl44k6uFnVwGlZ0CHk7K2waretrqGriqGaiIqA4D0B1iACKpZJdW4dWvj2HmiC7iLfapWaXILK7EpDB/mARg96lc9A3QIKOoEm9tP4VAD0fsv1CEqhoj1j06HAu3nUCIpxM+mT0UhzJKcDSrFBoHe7zy9bEm1TChrx9SskpgJ5ejrKqm0dW2W9OIUA8cTC+G6eq/RMODPfDlUyNQpDNgTdIFXCmtQr9OGnT3cYa3iwqDgtwhCAI+35+BhdtO4rUJvTBvTNcW//zMokos/+EMnhnTDX0CXPH8phQcv1yGr56OgKezqpW+JRFZAgPQHWIAovamqEKPYp0B3X1dbtrmnV1n8Z+fzgMAhgW7Q2knx6/niwAAY3v5YFiIB8JDPDAw0A21JgEyAG98fwrrkzMAABf+cR9e/foYtqRcxp+HB+GLA+
Z3xclldcNs+y8Wi9sGBblBX2OCp7MS4/v4YuG2k2bHyGRAc/4FUshlmNDPD7tP5sFw3crd9csRlOgM+PMnB+CitsNL43pgWLAHDEYTFHIZSitrkJpViq7eTuJkcgD4y6cH8EtaIRyVCnz77F2IfHcvgLpgtnhSX/T2v/m/AYIg4OvDl+HlosK9PX2a/kWIqE0wAN0hBiDqqPZfLEJ+uR739fNDjVHAhv2XMKq7903/yOdrq/HkhsN4cHAnzIoIRq3RhCKdAb6uany+PwPLd55BoIcjpg0LxPAQD/T2d4UgCDiZrUVKZgkeHhYIld21ydPHL5fB01mJ/HI9wjpp8OPpPDx5dZjMX6NGTiOTyOeN6Yrtx3LE5QQa42CvwOyRwdhzrgCnc8znUcllgFwmQ63p2j914SEeePTuEBy/XIYPfj5/y2vW09cFj94djGnDgvDFgQx8ui8dPi4qXCzQIb9cL7Z7cHAnLJjQC6VVNejq7SxOLE/NKoWjUoEeV8NpZlElOrs7QN7IxPOmyNdWY/4XRzB5YAD+EhGMnLIq/Ha+CJMHBoh3BJ7MLoOvqxpe7MEiG8MAdIcYgIiaxmQSYBQE2Cta/ljBn8/ko1xfiz8OCEBZZQ3eTjgjrrn05D2heP2+3jCaBBhqTfjPT2k4lFECHxcVPJ2UeGhoIJYnnGmwyre7o32T77ZrqvF9fJF4Jh9G0+3/yQz1csJd3bxwqUiHX9IKobKT44M/D8YvaQVYn5wBF5UdZkZ0wdy7ggGhrqfNUGuCvZ0c2qoaTOzvjyFd3LH213Sk5VXgdI4WGgd7PHlPKM7lV+B/V3vfTr8xAX/8YB/S8ivwYmQPPB/ZHZ/9mo6l352Ci9oODw0JxOGMYjw2KhR/HBAAo0lARpEO/hoHcYmF6+WUVeFwRgkiQj2bNPxnMgmQyeqWYNhy5Aq6+zqLd0kSSYEB6A4xABFJy1BrwoH0IkSEeoq9GjdTazThmyOXcSC9GI5KBaL6+mFUd2/sOVeARdtOYFJYAOJTr+BySRVeieqJBwd3wvKdZxCfeu2uuUfCgxDZxxeVeiPsFTLc08Mb+loTPtpzAT+dyceZ3HKxrZ+rGs+N7Q43R3u8/NVROCrt8K+HwvDcxhRoq5s3X8peIWuwllRzhHXW4NjlMvFz3NxhmBv3e6PDimp7Oapr6oYN3RztseyB/ijUGfDhz+fh6axEH39X7DpVt6yDvUKGAZ3dkFNWjUFBbvjPjEEQBOBcfjncHJTYe64AlYZabPo9Cwq5DGGd3cTQOm9MV9zV1Qu7TuVCqZDD3UmJp0d3FXvE6kPTzRb/NJkEHMksQS9/Vzir7FBpqEVhuQFBno4or66BQi6Do9IOR7NK0d3XGY5KrudL1zAA3SEGIKKOJau4Er+kFeLhoZ1hp5DDZBKwNaWux8LHRQ1vF1WjayHVe+3rY/jyUBYGB7lh/WPhcL56l1qethoKuQxezioczijG2zvPYNKAAOhrTPj3j+dQaTDiyXtC8cX+DHGBzE5uDvByVuJEtrbR3qR7e3rj57MF4udF9/fBXd28kHQ2H8t2nmnQXi4DbjxN/XILPi4quDsqcTpX26y5VjcaGOiGCn0tzudXtOj4sb188OK4HliddAE7TuSgu48znFV2KK+uRRdPJzw4uBPkMhl6+Dpj6XensOdcATq5OWDKoABsOXIFOWXVmDE8CDtP5MBVbY9H7wrGku9OIcTLCaN7eGPasED09nfF2dxyFFXoUagzwFBrwthePnBW2+H4lTK4Oyrh56rGldJKnM4ph5ujPUZ1925Qa2v0ajbX4YxiVBlMuLv7zR+rAwAXCyqgkMvQxdPJQpW1PwxAd4gBiIiuZzIJOJRRgrDOmiYvCJlVXInM4kqM7OqJYp0B9nZyqOzkUCrkkMlkqDWacKlIB19XNZyUdvjPT+fhr1Hj4WGB+PL3TLz1/Wm8OaUfpgzqJJ7z57P5WLTtBDQO9sgoqsSgIHf8bWJvONgrcN97v6
D86h17m5+KQIBb3Rwgtb0CxToDvjqUhVM5Wjw3tjtW/Xwee8/VTfy2k8vMlj+I7O2DZ//QHet/u4QtV5dkuFGwpyOySqrMAlxEqCci+/jiHztON2mYsLU1NuzporKDs9qu0bllQN3SEkO6uCOjSIeD6cWoNQko0RkgANj05AgoFXW9Zv/cdRZyWd1dmmVVNXjqnq6YFdHFrHcyu7QKRzJL4GCvgFwmQ6i3Ezq5OeBSkQ6ZxZXwclah/9U1wS4UVKDSYET/Thq8n3ge//7xHADg1Qk9MX9MtwZ1VhpqsfN4LhZsOQYHewV+XfAH5JRVY/G2k3hqdCjGXDcB33j1Bga5XAaTSWjxXLP2igHoDjEAEVF7cy6vHJsOZsFPo8ITo0Kb9Xw5QRCwLTUbn+y7iJXTBqGbz7VHrqxJuoC88mpcKanCXyK6oF8nDbydVSitrEFljREBGjUKKvRwc1BCaSdHWWUNKgy16OTmAJNJwKJvT2DH8VwU6wxwVCqw6pHBSL5QhB3HcxDi5dRg/lZkbx/MG9MVv18qwYkrZejXSYPqGiM++/XSLde9up7aXg4PR6W4MrtCLoNCJhPvCGyNgObrqkIXDydkl1WhymBEkc5w22OCPByhspMj7WpPWlhnDY5fKTPrnVt0fx94OCnhrLLDxoOZ6Bvgiq8OXzYLcfeH+Zut5TVnZDD2XyyC2l6B9EId/FzVmDemK5Z+dxIju3lh8aQ+OJtbjn3nC/Hb+SJMHx6IYE+nusfpdK67voXlBvhp1CjS6XFff38AwJJvT0Iuk6G7jzM0jvYY3cMHHk5KAEBppQGXS6rg46qCj4tarCU1qxTrfruEuXcFw0llh1AvJ8hkMlTXGLE15Qr6BrgirLMbdPraJq/31RwMQHeIAYiIqPUYTQK+P5aNnn4u6OVn/m/q37efQlp+BeaN7goPJ+VNl3KoNZpgEoCfzuThQoEOEV09cSSjBJMGBEAQgONXylBjrBtGUshkUNnJsTetAOXVtRjdo26oK1dbjR4+LpDJgK0pV3AqW4uzeeUI8nDEH3r5wM3RHtU1Jry0+ShytdcCh51chifuCcWQIHccu1KGDcmXGp1k39vfFQo5YDQBZ3O1MAl1C5IGeTgio6gSVTWNPyfwgUGdYKg1YfvxhguUSqGLpyPcHJU4mlVqtl0uA8JDPCGXA/svFotBMtjTEcNDPBDq7Yx/7z4nrm4PAN19nKG0k+NSoQ46g/HqnDENLhbo8N+/DEF4qGer1s4AdIcYgIiIbFdZVQ1yy6rRxdMRRzJKEOLtBH+Ng7hfX2tE4ul8FFbo0dXbGSeulKGHn4vZWlCXCnXIKqnE8BAPqOwU0Olr8UtaIWqMJkR09cRHey7gk33pGBbsgQ/+PAgV1bX44we/AgBMgiCu4h7i5YQ8bTXWzhmGYE8nRK3ci1qjCf07azC0iwf2nCtAaZUBUX38sP14ToPhvhuHBl3Udgj1dkZ5VQ0KyvUo19fCwV6BYC8nnLnJXDEXlR18XFW4UGC+UrynkxLFlYYWzy+7r78fPnxkSMsOvgkGoDvEAERERG3NaBLMJt+X6AxQ2ctRYxSQcCIHkb194eGkRI1RgNKubr5RWWUNlHbyRpcxMNSaUFVjhJNSgcIKAzydlag1CjiUUYxgTydklVQiItRTHB41mQT8cr4QPX1d4KepG8YqrTRg/8UiZJfWBcDhIR5wsFfATiHHzuM5+N/BTIwI9cSksAAEeTqirKoGRzJKcPBSMU5cKUNPXxdMDPPH4YwSjOnpgwsFdcN9fq5qdPd1xjdHrmBN0gX8cWAAYsb1aPXJ5gxAd4gBiIiIqP1pzt9vy93nR0RERGQlGICIiIjI5jAAERERkc2xigC0atUqBAcHQ61WIzw8HAcPHrxl+6+++gq9evWCWq1G//79sWPHDrP9giBg0aJF8Pf3h4ODAyIjI5GWltaWX4GIiIjaEckD0JdffomYmBgsXrwYR44cwYABAxAVFYX8/PxG2/
/222+YMWMGHnvsMaSkpGDKlCmYMmUKTpw4IbZZsWIF3n//faxZswYHDhyAk5MToqKiUF3d+GqgREREZFskvwssPDwcw4YNwwcffAAAMJlMCAwMxF//+lcsWLCgQftp06ZBp9Ph+++/F7eNGDECAwcOxJo1ayAIAgICAvDSSy/h5ZdfBgCUlZXB19cXcXFxmD59+m1r4l1gRERE7U+7uQvMYDDg8OHDiIyMFLfJ5XJERkYiOTm50WOSk5PN2gNAVFSU2D49PR25ublmbTQaDcLDw296Tr1eD61Wa/YiIiKijkvSAFRYWAij0QhfX1+z7b6+vsjNzW30mNzc3Fu2r/+/zTnnsmXLoNFoxFdgYGCLvg8RERG1D5LPAbIGsbGxKCsrE19ZWVlSl0RERERtSNIA5OXlBYVCgby8PLPteXl58PPza/QYPz+/W7av/7/NOadKpYKrq6vZi4iIiDouSQOQUqnEkCFDkJiYKG4zmUxITExEREREo8dERESYtQeA3bt3i+1DQkLg5+dn1kar1eLAgQM3PScRERHZFjupC4iJicHs2bMxdOhQDB8+HCtXroROp8PcuXMBALNmzUKnTp2wbNkyAMDzzz+P0aNH45133sHEiROxadMmHDp0CP/9738BADKZDC+88ALeeustdO/eHSEhIVi4cCECAgIwZcoUqb4mERERWRHJA9C0adNQUFCARYsWITc3FwMHDkRCQoI4iTkzMxNy+bWOqpEjR+J///sf/va3v+H1119H9+7dER8fj379+oltXn31Veh0Ojz55JMoLS3F3XffjYSEBKjVaot/PyIiIrI+kq8DZI24DhAREVH705y/35L3AFmj+kzI9YCIiIjaj/q/203p22EAakR5eTkAcD0gIiKidqi8vBwajeaWbTgE1giTyYTs7Gy4uLhAJpO16rm1Wi0CAwORlZXF4bU2xOtsGbzOlsHrbBm8zpbRltdZEASUl5cjICDAbP5wY9gD1Ai5XI7OnTu36c/gekOWwetsGbzOlsHrbBm8zpbRVtf5dj0/9bgSNBEREdkcBiAiIiKyOQxAFqZSqbB48WKoVCqpS+nQeJ0tg9fZMnidLYPX2TKs5TpzEjQRERHZHPYAERERkc1hACIiIiKbwwBERERENocBiIiIiGwOA5AFrVq1CsHBwVCr1QgPD8fBgwelLqld2bt3LyZNmoSAgADIZDLEx8eb7RcEAYsWLYK/vz8cHBwQGRmJtLQ0szbFxcV45JFH4OrqCjc3Nzz22GOoqKiw4LewfsuWLcOwYcPg4uICHx8fTJkyBWfPnjVrU11djWeeeQaenp5wdnbG1KlTkZeXZ9YmMzMTEydOhKOjI3x8fPDKK6+gtrbWkl/Fqq1evRphYWHiYnARERHYuXOnuJ/XuG28/fbbkMlkeOGFF8RtvNZ3bsmSJZDJZGavXr16ifut8hoLZBGbNm0SlEqlsHbtWuHkyZPCE088Ibi5uQl5eXlSl9Zu7NixQ/i///s/YcuWLQIAYevWrWb73377bUGj0Qjx8fHC0aNHhT/+8Y9CSEiIUFVVJbaZMGGCMGDAAGH//v3CL7/8InTr1k2YMWOGhb+JdYuKihI+++wz4cSJE0Jqaqpw3333CUFBQUJFRYXY5umnnxYCAwOFxMRE4dChQ8KIESOEkSNHivtra2uFfv36CZGRkUJKSoqwY8cOwcvLS4iNjZXiK1mlb7/9Vti+fbtw7tw54ezZs8Lrr78u2NvbCydOnBAEgde4LRw8eFAIDg4WwsLChOeff17czmt95xYvXiz07dtXyMnJEV8FBQXifmu8xgxAFjJ8+HDhmWeeET8bjUYhICBAWLZsmYRVtV83BiCTyST4+fkJ//znP8VtpaWlgkqlEjZu3CgIgiCcOnVKACD8/vvvYpudO3cKMplMuHLlisVqb2/y8/MFAMKePXsEQai7rvb29sJXX30ltjl9+rQAQEhOThYEoS6syuVyITc3V2yzevVqwdXVVdDr9Z
b9Au2Iu7u78Mknn/Aat4Hy8nKhe/fuwu7du4XRo0eLAYjXunUsXrxYGDBgQKP7rPUacwjMAgwGAw4fPozIyEhxm1wuR2RkJJKTkyWsrONIT09Hbm6u2TXWaDQIDw8Xr3FycjLc3NwwdOhQsU1kZCTkcjkOHDhg8Zrbi7KyMgCAh4cHAODw4cOoqakxu9a9evVCUFCQ2bXu378/fH19xTZRUVHQarU4efKkBatvH4xGIzZt2gSdToeIiAhe4zbwzDPPYOLEiWbXFOB/z60pLS0NAQEBCA0NxSOPPILMzEwA1nuN+TBUCygsLITRaDT7xQKAr68vzpw5I1FVHUtubi4ANHqN6/fl5ubCx8fHbL+dnR08PDzENmTOZDLhhRdewF133YV+/foBqLuOSqUSbm5uZm1vvNaN/S7q91Gd48ePIyIiAtXV1XB2dsbWrVvRp08fpKam8hq3ok2bNuHIkSP4/fffG+zjf8+tIzw8HHFxcejZsydycnKwdOlSjBo1CidOnLDaa8wAREQ39cwzz+DEiRPYt2+f1KV0SD179kRqairKysrw9ddfY/bs2dizZ4/UZXUoWVlZeP7557F7926o1Wqpy+mwoqOjxfdhYWEIDw9Hly5dsHnzZjg4OEhY2c1xCMwCvLy8oFAoGsx4z8vLg5+fn0RVdSz11/FW19jPzw/5+flm+2tra1FcXMzfQyOeffZZfP/99/j555/RuXNncbufnx8MBgNKS0vN2t94rRv7XdTvozpKpRLdunXDkCFDsGzZMgwYMADvvfcer3ErOnz4MPLz8zF48GDY2dnBzs4Oe/bswfvvvw87Ozv4+vryWrcBNzc39OjRA+fPn7fa/54ZgCxAqVRiyJAhSExMFLeZTCYkJiYiIiJCwso6jpCQEPj5+ZldY61WiwMHDojXOCIiAqWlpTh8+LDY5qeffoLJZEJ4eLjFa7ZWgiDg2WefxdatW/HTTz8hJCTEbP+QIUNgb29vdq3Pnj2LzMxMs2t9/Phxs8C5e/duuLq6ok+fPpb5Iu2QyWSCXq/nNW5FY8eOxfHjx5Gamiq+hg4dikceeUR8z2vd+ioqKnDhwgX4+/tb73/PbTK1mhrYtGmToFKphLi4OOHUqVPCk08+Kbi5uZnNeKdbKy8vF1JSUoSUlBQBgPDuu+8KKSkpQkZGhiAIdbfBu7m5Cdu2bROOHTsmTJ48udHb4AcNGiQcOHBA2Ldvn9C9e3feBn+DefPmCRqNRkhKSjK7pbWyslJs8/TTTwtBQUHCTz/9JBw6dEiIiIgQIiIixP31t7SOHz9eSE1NFRISEgRvb2/eNnydBQsWCHv27BHS09OFY8eOCQsWLBBkMpmwa9cuQRB4jdvS9XeBCQKvdWt46aWXhKSkJCE9PV349ddfhcjISMHLy0vIz88XBME6rzEDkAX95z//EYKCggSlUikMHz5c2L9/v9QltSs///yzAKDBa/bs2YIg1N0Kv3DhQsHX11dQqVTC2LFjhbNnz5qdo6ioSJgxY4bg7OwsuLq6CnPnzhXKy8sl+DbWq7FrDED47LPPxDZVVVXC/PnzBXd3d8HR0VF44IEHhJycHLPzXLp0SYiOjhYcHBwELy8v4aWXXhJqamos/G2s16OPPip06dJFUCqVgre3tzB27Fgx/AgCr3FbujEA8VrfuWnTpgn+/v6CUqkUOnXqJEybNk04f/68uN8ar7FMEAShbfqWiIiIiKwT5wARERGRzWEAIiIiIpvDAEREREQ2hwGIiIiIbA4DEBEREdkcBiAiIiKyOQxAREREZHMYgIiIiMjmMAARUbtVUFCAefPmISgoCCqVCn5+foiKisKvv/4KAJDJZIiPj5e2SCKySnZSF0BE1FJTp06FwWDAunXrEBoairy8PCQmJqKoqEjq0ojIyvFRGETULpWWlsLd3R1JSUkYPXp0g/3BwcHIyMgQP3fp0gWXLl0CAGzbtg1Lly7FqVOnEBAQgNmzZ+P//u//YGdX978JZT
IZPvzwQ3z77bdISkqCv78/VqxYgT/96U8W+W5E1PY4BEZE7ZKzszOcnZ0RHx8PvV7fYP/vv/8OAPjss8+Qk5Mjfv7ll18wa9YsPP/88zh16hQ++ugjxMXF4e9//7vZ8QsXLsTUqVNx9OhRPPLII5g+fTpOnz7d9l+MiCyCPUBE1G598803eOKJJ1BVVYXBgwdj9OjRmD59OsLCwgDU9eRs3boVU6ZMEY+JjIzE2LFjERsbK277/PPP8eqrryI7O1s87umnn8bq1avFNiNGjMDgwYPx4YcfWubLEVGbYg8QEbVbU6dORXZ2Nr799ltMmDABSUlJGDx4MOLi4m56zNGjR/HGG2+IPUjOzs544oknkJOTg8rKSrFdRESE2XERERHsASLqQDgJmojaNbVajXHjxmHcuHFYuHAhHn/8cSxevBhz5sxptH1FRQWWLl2KBx98sNFzEZFtYA8QEXUoffr0gU6nAwDY29vDaDSa7R88eDDOnj2Lbt26NXjJ5df+Sdy/f7/Zcfv370fv3r3b/gsQkUWwB4iI2qWioiI89NBDePTRRxEWFgYXFxccOnQIK1aswOTJkwHU3QmWmJiIu+66CyqVCu7u7li0aBHuv/9+BAUF4U9/+hPkcjmOHj2KEydO4K233hLP/9VXX2Ho0KG4++678cUXX+DgwYP49NNPpfq6RNTKOAmaiNolvV6PJUuWYNeuXbhw4QJqamoQGBiIhx56CK+//jocHBzw3XffISYmBpcuXUKnTp3E2+B/+OEHvPHGG0hJSYG9vT169eqFxx9/HE888QSAuknQq1atQnx8PPbu3Qt/f38sX74cDz/8sITfmIhaEwMQEdENGrt7jIg6Fs4BIiIiIpvDAEREREQ2h5OgiYhuwJkBRB0fe4CIiIjI5jAAERERkc1hACIiIiKbwwBERERENocBiIiIiGwOAxARERHZHAYgIiIisjkMQERERGRzGICIiIjI5vw/8PvxDtBNy/wAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# --- Plotting Results ---\n", "plt.figure()\n", "plt.title(\"Training Loss\")\n", "plt.plot(losses)\n", "plt.xlabel(\"Step\")\n", "plt.ylabel(\"MSE Loss\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 312 }, "id": "RpYtm93FY0nq", "outputId": "e1d39c83-6c37-4816-9827-55948ad5af60" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA/QAAAEnCAYAAAAHEAjKAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XV8FHf6wPHPrMXdjRhOCEEDtDiF0lLXq+vdVe5Xt+tdXa7e3tWVXuWqtKWlxd01QRLiIe6ysc3KzO+PSQIhCSSQ7Cbwfb9evMjuzs58ZzPZ3ecrzyMpiqIgCIIgCIIgCIIgCMKAonF0AwRBEARBEARBEARB6DkR0AuCIAiCIAiCIAjCACQCekEQBEEQBEEQBEEYgERALwiCIAiCIAiCIAgDkAjoBUEQBEEQBEEQBGEAEgG9IAiCIAiCIAiCIAxAIqAXBEEQBEEQBEEQhAFIBPSCIAiCIAiCIAiCMACJgF4QBEEQBEEQBEEQBiAR0AuCIAhCH1i3bh2SJLFu3TpHN0UYAJ566ikkSWp3X1RUFDfddFOvHeOmm24iKiqq1/YnCIIgOJ4I6AVBEAS7W7RoEZIkdflv27Ztjm5iv1ReXs4999zD8OHDcXFxITAwkEmTJvHII49QX1/v6OYNaEdffxqNhtDQUObNmzfgOmSKiop46qmnSEpKcnRTBEEQBDvQOboBgiAIwpnrmWeeITo6usP9gwcPdkBr+reqqiomTJiA0WjklltuYfjw4VRWVrJv3z7ee+897rjjDtzd3R3dzAHtnHPO4YYbbkBRFHJycnj33XeZPXs2S5cuZcGCBXZvT1paGhpNz8ZeioqKePrpp4mKiiIhIaHdYx999BGyLPdiCwVBEARHEwG9IAiC4DALFixgwoQJjm7GgPDJJ5+Ql5fH5s2bmTp1arvHjEYjBoPBQS07fQwdOpTrrruu7fYll1xCfHw8b775ZpcBvclkwmAw9Djw7g4nJ6de3Z9er+/V/QmCIAiOJ6bcC4IgCP3Wk08+iUajYfXq1e3u//Of/4zBYCA5ORkAs9nME088wfjx4/Hy8sLNzY1p06axdu3ads/Lzc1FkiReffVV3nnnHWJiYnB1dWXevHnk5+ejKArPPvss4eHhuLi4cNFFF1FVVdVuH1FRUSxcuJAVK1aQkJCAs7MzI0eOZPHixd06p+3bt3Puuefi5eWFq6srM2bMYPPmzSd8XlZWFlqtlsmTJ3d4zNPTE2dn55M6zqZNm5g4cSLOzs7ExsbywQcfdFjP3fq6LVq0qMPzJUniqaeeandfYWEht9xyC0FBQTg5OTFq1Cg+/fTTdtu05hj47rvveP755wkPD8fZ2Zk5c+aQmZnZ4Tjbt2/nvPPOw8fHBzc3N+Lj43nrrbfabXPo0CEuv/xyfH19cXZ2ZsKECSxZsqTDvrpr9OjR+Pv7k5OT067N33zzDf/4xz8ICwvD1dUVo9HY1saTfc0709ka+pqaGu677z6ioqJwcnIiPDycG264gYqKCtatW8fEiRMBuPnm
m9uWELT+3jpbQ9/Q0MADDzxAREQETk5ODBs2jFdffRVFUdptJ0kSd999Nz///DNxcXFtv9dly5b19GUVBEEQepEYoRcEQRAcpra2loqKinb3SZKEn58fAP/4xz/49ddfufXWW9m/fz8eHh4sX76cjz76iGeffZYxY8YA6gj1xx9/zJ/+9Cduv/126urq+OSTT5g/fz47duzoMPX4q6++wmw287e//Y2qqipefvllrrzySmbPns26det45JFHyMzM5D//+Q8PPvhgh2A0IyODq666ir/+9a/ceOONfPbZZ1xxxRUsW7aMc845p8vzXbNmDQsWLGD8+PFtnRWfffYZs2fPZuPGjUyaNKnL50ZGRmKz2fjiiy+48cYbj/u6dvc4+/fvZ968eQQEBPDUU09htVp58sknCQoKOu7+j6e0tJTJkye3BYABAQH88ccf3HrrrRiNRu6999522//rX/9Co9Hw4IMPUltby8svv8y1117L9u3b27ZZuXIlCxcuJCQkhHvuuYfg4GBSU1P57bffuOeeewA4ePAgZ511FmFhYTz66KO4ubnx3XffcfHFF/Pjjz9yySWX9Phcqqurqa6u7rAE5Nlnn8VgMPDggw/S3NyMwWCwy2teX1/PtGnTSE1N5ZZbbmHcuHFUVFSwZMkSCgoKGDFiBM888wxPPPEEf/7zn5k2bRpAhxkdrRRF4cILL2Tt2rXceuutJCQksHz5ch566CEKCwt544032m2/adMmFi9ezJ133omHhwf//ve/ueyyy8jLy2v7mxUEQRDsTBEEQRAEO/vss88UoNN/Tk5O7bbdv3+/YjAYlNtuu02prq5WwsLClAkTJigWi6VtG6vVqjQ3N7d7XnV1tRIUFKTccsstbffl5OQogBIQEKDU1NS03f/YY48pgDJmzJh2+/3Tn/6kGAwGxWQytd0XGRmpAMqPP/7Ydl9tba0SEhKijB07tu2+tWvXKoCydu1aRVEURZZlZciQIcr8+fMVWZbbtmtsbFSio6OVc84557ivWUlJiRIQEKAAyvDhw5W//vWvytdff93uPHp6nIsvvlhxdnZWDh8+3HZfSkqKotVqlaO/IrS+bp999lmHdgHKk08+2Xb71ltvVUJCQpSKiop221199dWKl5eX0tjY2O71GTFiRLvf3VtvvaUAyv79+xVFUX+30dHRSmRkpFJdXd3hXFvNmTNHGT16dLvflSzLytSpU5UhQ4Z0aHdn53Hrrbcq5eXlSllZmbJ9+3Zlzpw5CqC89tpr7docExPTdh6tx+nt11xR1GvtxhtvbLv9xBNPKICyePHiDu1vPe7OnTu7/F3deOONSmRkZNvtn3/+WQGU5557rt12l19+uSJJkpKZmdnu9TEYDO3uS05OVgDlP//5T4djCYIgCPYhptwLgiAIDvPOO++wcuXKdv/++OOPdtvExcXx9NNP8/HHHzN//nwqKir4/PPP0emOTDLTarVta8hlWaaqqgqr1cqECRPYs2dPh+NeccUVeHl5td1OTEwE4Lrrrmu338TERMxmM4WFhe2eHxoa2m7E19PTkxtuuIG9e/dSUlLS6bkmJSWRkZHBNddcQ2VlJRUVFVRUVNDQ0MCcOXPYsGHDcROWBQUFkZyczF//+leqq6t5//33ueaaawgMDOTZZ59tmyLd3ePYbDaWL1/OxRdfzKBBg9qOM2LECObPn99lO45HURR+/PFHLrjgAhRFaTt2RUUF8+fPp7a2tsPv4+abb263/r91VDk7OxuAvXv3kpOTw7333ou3t3e757YuC6iqqmLNmjVceeWV1NXVtR2zsrKS+fPnk5GR0eF32JlPPvmEgIAAAgMDSUxMZPPmzdx///0dZhXceOONuLi4tN2212v+448/MmbMmE5nGxxb8q47fv/9d7RaLf/3f//X7v4HHngARVE6/C3OnTuX2NjYttvx8fF4enq2/a4EQRAE+xNT7gVBEASHmTRpUreS4j300EN888037NixgxdeeIGR
I0d22Obzzz/ntdde49ChQ1gslrb7O8uif3QwBbQF9xEREZ3eX11d3e7+wYMHdwighg4dCqjrzYODgzscMyMjA+C40+Vra2vx8fHp8vGQkBDee+893n33XTIyMli+fDkvvfQSTzzxBCEhIdx2223dPk5zczNNTU0MGTKkw+PDhg3j999/7/L5XSkvL6empoYPP/yQDz/8sNNtysrK2t0+9nfRev6tr3lWVhagdux0JTMzE0VR+Oc//8k///nPLo8bFhZ23PZfdNFF3H333UiShIeHB6NGjcLNza3DdsdeU/Z6zbOysrjsssuOu01PHD58mNDQUDw8PNrdP2LEiLbHj3bs7wrU39exfx+CIAiC/YiAXhAEQej3srOz24Km/fv3d3j8yy+/5KabbuLiiy/moYceIjAwEK1Wy4svvtgWEB5Nq9V2epyu7leOSRB2MlpH31955ZUOa/pbdbfsnCRJDB06lKFDh3L++eczZMgQvvrqK2677bZuH6e5ubnbbe9q9Ndms7W73Xrs6667rsvgNj4+vt3t3njNW4/74IMPdjnS3Z1SiOHh4cydO/eE2x09On/08XvzNe+P+vLvQxAEQTg5IqAXBEEQ+jVZlrnpppvw9PTk3nvv5YUXXuDyyy/n0ksvbdvmhx9+ICYmhsWLF7cLPp988sk+aVPriPDRx0pPTwfokEW8VetUZU9Pz24Fjd0VExODj48PxcXFPTpOQEAALi4ubR0lR0tLS2t3u3XUvKampt39x47gBgQE4OHhgc1m67VzbD2fAwcOdLnPmJgYQC3L1puvbXf1xWve1XEOHDhw3G16MvU+MjKSVatWUVdX126U/tChQ22PC4IgCP2bWEMvCIIg9Guvv/46W7Zs4cMPP+TZZ59l6tSp3HHHHe2y47eOHB49Urh9+3a2bt3aJ20qKirip59+arttNBr573//S0JCQqfT7QHGjx9PbGwsr776KvX19R0eLy8vP+4xt2/fTkNDQ4f7d+zYQWVlJcOGDevRcbRaLfPnz+fnn38mLy+v7fHU1FSWL1/e7jmenp74+/uzYcOGdve/++677W5rtVouu+wyfvzxx04DzxOdY2fGjRtHdHQ0b775ZocOhdbfd2BgIDNnzuSDDz5o69g41eP2RF+85p257LLLSE5ObnfttWp9LVqXCBz7WnXmvPPOw2az8fbbb7e7/4033kCSJBYsWHDCfQiCIAiOJUboBUEQBIf5448/2kYDjzZ16lRiYmJITU3ln//8JzfddBMXXHABAIsWLSIhIYE777yT7777DoCFCxeyePFiLrnkEs4//3xycnJ4//33GTlyZKcB1qkaOnQot956Kzt37iQoKIhPP/2U0tJSPvvssy6fo9Fo+Pjjj1mwYAGjRo3i5ptvJiwsjMLCQtauXYunpye//vprl8//4osv+Oqrr7jkkksYP348BoOB1NRUPv30U5ydnfn73//e4+M8/fTTLFu2jGnTpnHnnXditVr5z3/+w6hRo9i3b1+74992223861//4rbbbmPChAls2LChbVbC0f71r3+xdu1aEhMTuf322xk5ciRVVVXs2bOHVatWUVVV1aPXWqPR8N5773HBBReQkJDAzTffTEhICIcOHeLgwYNtgfA777zD2WefzejRo7n99tuJiYmhtLSUrVu3UlBQQHJyco+O29M29sVrfqyHHnqIH374gSuuuIJbbrmF8ePHU1VVxZIlS3j//fcZM2YMsbGxeHt78/777+Ph4YGbmxuJiYmd5pK44IILmDVrFo8//ji5ubmMGTOGFStW8Msvv3Dvvfe2S4AnCIIg9FOOSa4vCIIgnMmOV7aOlpJbVqtVmThxohIeHt6hNFtrabNvv/1WURS1ZNcLL7ygREZGKk5OTsrYsWOV3377rUOZrtbya6+88kq7/bWWI/v+++87befOnTvb7ouMjFTOP/98Zfny5Up8fLzi5OSkDB8+vMNzjy1b12rv3r3KpZdeqvj5+SlOTk5KZGSk
cuWVVyqrV68+7mu2b98+5aGHHlLGjRun+Pr6KjqdTgkJCVGuuOIKZc+ePR227+5x1q9fr4wfP14xGAxKTEyM8v777ytPPvlkhxJqjY2Nyq233qp4eXkpHh4eypVXXqmUlZV1KFunKIpSWlqq3HXXXUpERISi1+uV4OBgZc6cOcqHH354wte8qxJ5mzZtUs455xzFw8NDcXNzU+Lj4zuUS8vKylJuuOEGJTg4WNHr9UpYWJiycOFC5Ycffjjua6soalm2u+6667jbdNXmVr39mh9btk5RFKWyslK5++67lbCwMMVgMCjh4eHKjTfe2K5M4C+//KKMHDlS0el07V7LY/8eFEVR6urqlPvuu08JDQ1V9Hq9MmTIEOWVV15pV37veK9PZ20UBEEQ7EdSFJHJRBAEQRC6Kyoqiri4OH777TdHN6XPPPXUUzz99NMi2ZkgCIIg9HNiDb0gCIIgCIIgCIIgDEAioBcEQRAEQRAEQRCEAUgE9IIgCIIgCIIgCIIwAIk19IIgCIIgCIIgCIIwAIkRekEQBEEQBEEQBEEYgERALwiCIAiCIAiCIAgDkM7RDejvZFmmqKgIDw8PJElydHMEQRAEQRAEQRCE05yiKNTV1REaGopG0/U4vAjoT6CoqIiIiAhHN0MQBEEQBEEQBEE4w+Tn5xMeHt7l4yKgPwEPDw9AfSE9PT0d3JquybJMeXk5AQEBx+3BEc4c4poQjiWuCaEz4roQjiWuCeFY4poQjiWuib5nNBqJiIhoi0e7MqAC+g0bNvDKK6+we/duiouL+emnn7j44ou73H7dunXMmjWrw/3FxcUEBwd365it0+w9PT37fUBvMpnw9PQUf1QCIK4JoSNxTQidEdeFcCxxTQjHEteEcCxxTdjPiZZ9D6hXv6GhgTFjxvDOO+/06HlpaWkUFxe3/QsMDOyjFgqCIAiCIAiCIAiCfQyoEfoFCxawYMGCHj8vMDAQb2/v3m+QIAiCIAiCIAiCIDjIgAroT1ZCQgLNzc3ExcXx1FNPcdZZZ3W5bXNzM83NzW23jUYjoE4rkWW5z9t6smRZRlGUft1Gwb7ENSEcS1wTQmfEdSEcS1wTwrHENSEcS1wTfa+7r+1pHdCHhITw/vvvM2HCBJqbm/n444+ZOXMm27dvZ9y4cZ0+58UXX+Tpp5/ucH95eTkmk6mvm3zSZFmmtrYWRVHEOhYBENeE0JG4JoTOiOtCOJa4JoRjiWtCOJa4JvpeXV1dt7aTFEVR+rgtfUKSpBMmxevMjBkzGDRoEF988UWnj3c2Qh8REUF1dXW/T4onMk0KRxPXhHAscU0InRHXhXAscU0IxxLXhHAscU30PaPRiI+PD7W1tceNQ0/rEfrOTJo0iU2bNnX5uJOTE05OTh3u12g0/f5ilSRpQLRTsB9xTQjHEtfEwJZf1cinm3PwdjEwLNidoUEeRPq5odUcPwPuiYjrQjiWuCaEY4lrQjiWuCb6Vndf1zMuoE9KSiIkJMTRzRAEQRCEHimrM/HPXw5QWW8GYFVqKQAGnYbBAe4MDfYgLtSTSdG+JyxxIwiCIAjC6WFABfT19fVkZma23c7JySEpKQlfX18GDRrEY489RmFhIf/9738BePPNN4mOjmbUqFGYTCY+/vhj1qxZw4oVKxx1CoIgCILQY7WNFp74+SCV9WbCfVyYGutHWmkd6aX1NJltpBQbSSk28vPeQv5+3gimxPo5usmCIAiCINjBgArod+3axaxZs9pu33///QDceOONLFq0iOLiYvLy8toeN5vNPPDAAxQWFuLq6kp8fDyrVq1qtw9BEARB6M+azDae+vUghTVNBHg48ezFcfi7q0vDZFmhsKaJ9NI61qaVkZxfy+rUUhHQC4IgCMIZYkAF9DNnzuR4OfwWLVrU7vbDDz/Mww8/3MetEgRBEIS+YbbKPLc0hcyyejxddDxz0ai2YB5Ao5GI8HUlwteVIYEe3PX1HnYdrsZosuDprD+lY1c2VVJYX0h8
QPypnoYgCIIg9LpGSyMlDSXEeMc4uikOJTIYCIIgCEI/JMsKr61IY19BLS56LU9dMIpwH9cutx/k50q0vxs2WWFTRsUpH/+Zrc/w/Pbn2VWy65T3JQiCIAi97f3k93ls02NsKNjg6KY4lAjoBUEQBKGfURSFd9ZmsiWrEp1W4vHzRzAkyOOEz5s1PACAtYfKTvn4JY0lAGwq7LoyjCAIgiA4gtFsZFep2uH89aGvabY1n+AZpy8R0AuCIAhCP/PltsOsSClFI8FD84YxJsK7W8+bPiQAjQSHSuooqTWd9PElSeK20bcBUG+pP+n9CIIgCEJf2F68HZtiA6DaVM3S7KUObpHjiIBeEARBEPqRUqOJ73YVAHDXrMFMHezf7ef6uTsRH+4NwNq0UxulH+w9GIDs2mxkRT6lfQmCIAhCb9pcuBmAYT7DGOE7gjEBYxzcIscZUEnxBEEQBOF0tzWrEoC4MC/mjQru8fNnDQ8gKb+GdWllXD0x4qRr0kd4RKDX6GmwNFDSUEKoe+hJ7UcQBEEQepPZZsZoNgLwf+P+Dz9nv5P+rDsdiIBeEARBEPqRLVlqQruzBp9c6bkpMf68o8uiqMZERlk9Q7ux9v5Yr+x8pd16xMyaTBHQC4IgCP2CQWvgtRmvUdxQjL9L+1lsiqKcccG9mHIvCIIgCP1EdYOZQyV1AEyOObmA3sWgZUrLc08mOZ6syByoOMD+iv2M8htFuHv4SbVDEARBEPqKJEntOprrzfV8kfIFb+x5w4GtcgwxQi8IgiAI/cT2nEoUBYYEuberN99Ts4YHsD69nI0ZFdx6djQ6bff770sbSjHZTBg0Bh6e+DBajfak2yEIgiCcWUqNJl5elsboME+unxKFVtO7o+WNlkZ0Gh0GraHd/Uazkd+zf0dG5mDlQUb5jerV4/ZnYoReEARBEPqJLS3r56fGdj8RXmcSInzwdtVT22Rhb35Nj56bU5sDQKRnpAjmBUEQhB75dmc+6aV1/LinkCd+OYDRZOnV/f+e8zu3r7id37J/a3d/qHsocyPnAvBlypdnVDJXEdALgiAIQj9QZ7KQXFALwJTYk5tu30qrkZg+5ORq0ucY1YA+2iu67T6rbD2ja/wKgiAIJ2Y0WVjXUmFFr5XYV1DL/d8mkVvR0Cv7VxSFLUVbMNlMeOg75oe5fOjlOGudya7NZkvRll455kAgAnpBEARB6Ad25VYjywqD/FwJ83Y55f3NHKYG9NuyK2k0W7v9vNYR+ujaUvjhVr5aejs3/3I5q7P/OOU2CYIgCKevFQdLsdgUYgPceOOqBII8nSk1NvPQD8lsyaw45f0fNh6msL4QvUbPxOCJHR73cvLi4sEXA/C/Q//DYuvd2QH9lQjoBUEQBKEfaM1uP/UUR+dbDQ50J8zbBYtNYUtmZbeeoyiKGtArCtGZ66EyE5fi/Zhrcsnc+C9Y+iDs+x5q8kFRetSeRksjRfVFJ3MqgiAIQj9nkxV+318MwML4UCL93HjjqjGMifDCZJF58Y9DfL09D1nu2WfH0TYXqbXnxwWOw1Xv2uk258Wch6+zLxVNFfyRe2Z0RIuAXhAEQRAczGSxsftwNUBbhvpTJUkSs4aro/Tr0rs37b7J2kS0VzQ+aAhvbgIXbwZHzgStnkzMULATtr4N314HP94GuZu6FdjLiswL21/g/nX3k1SWdApnJQiCINhLZX0zK1NKMVtPvB59e04l5XXNeLromD5U/ezxcNbz9IVxXJSgZqP/3448XvwjlYbm7s8aayUrcts0+qlhU7vczknrxNXDrgbg58yfz4hRehHQC4IgCIKD7TlcjcWmEOTpTLS/W6/td+awQAD2FdRSUX/iNfCuelf+MfkfvOeegB4JYmYSO/Of4BtDqW8kxgm3QNh40OigMhOWPw4//RXyth83sD9sPExGTQYKCh/v//iM+IIlCIIw0L2xKp1/r87g0805J9z212R1dH7+qGAMuiMhplYjcdu0GO6ZMwSdVmJb
dhV/+99eDhTW9qgtGdUZVDRV4Kx1ZlzguONuOy18GrfG3crzZz2PXqvv0XEGIhHQC4IgCIKDHclu74ck9V6JnyBPZ0aGeKIosD6tvHtPslmRDm9Sf46egZvejVC3UNAZyIqIh4Wvw/U/wdjrQO8C5Yfgj4fhl7uhYHengX20VzQvTXsJgPKm8g7ZiQVBEIT+5XBlA8n5atC9dF/xcQPw3IoGDhTWopFgQVxIp9vMHRnES5fFE+TpTHldM3//aT//3ZqLxda9bPSt0+0nBk/sULLuWBpJw7yoeYS4d96W040I6AVBEATBgSw2mR25VcCpZ7fvTOu0+1WppZgstuNua7KaoGgvNNeBizeEjAFgsM9gADJrMtUNnT1h0u1w9dcQfxVoDVB6AJbeD7/eA3UlHfYd5RXFXQl3AfBT5k9UNJ16giRBEAShb/y2Tx1x12vVTuZ/r87o8jNkacva+cmxfgR4OHW5z6FBHvznT2OZOyIIRYHvdxXw0PfJ5Fc1nrA950Wfx5VDr2TOoDk9PZXTngjoBUEQBMGB9hXU0GS24e2qZ1hQxzI8p+qswf64GrQUVDfx1JKDNJk7/0KmKAr/t+b/uHvrExRjg6hp0FKHfrB3S0Bfndn+Sa6+MOVO+NP/YNQloNVDcTJsfgsAs81Mbm1u2+bTwqYx3Hc4zbZmvkj5otfPVRAEQTh1dSZLW8nTR84djp+7geJaE19vzzvuthfEh55w3y4GLffMHcJjC4bj7qQjq7yBe79N4vf9xSjHWboV7BbMZUMvY4TfiG6fx/r89by26zVKGjp2Mp9OREB/GpEVGYss1iUKgiAMJFtbpttPifVDo+m96fatPJz1PHXhKFwMWg4WGfnHzweoM3X8rKgyVVHbXEtlUxl+aCBmVttjw32HMz5oPAmBCZ0fxM0fzr4XLn5PvZ23DRqrWJa7jEc2PsKXKV8CaqK+m0fdjAYN+yv2U22q7uWzFQRBEE7VqtRSmq0yUf5uTIr25c6ZaqfuL0mFpJXUtdt2dWpZ27ajQj27fYypg/15+5qxJER4Y7bKvLcui9dWpB83qO+p9QXr2VGyg71le3ttn/2RCOhPEz9l/sRjux9ja9FWRzdFEARB6CZZVtiWrU63nxrr32fHGRHiyQuXxOHupCO9tI7HfzpAbWP7oD7XmAuWRsJtYHD2gtCEtsciPSN5eOLDLIhecPwD+Q+BwJGgyDSkLeWXzF8AiPCIaNskyiuKu8fezRsz38DH2ae3TlEQBEHoBbKssHRfa/m5ECRJYlK0L7OGBSAr6tT71qz3sqKwdH9Ju217ws/diacvHMVt06LRSLA+vZyiWlO7bSw2C//e82+2Fm3FJh9/2dixxgaOBWBP6Z4ePW+gEQH9acIqW6mz1LGjZIejmyIIgiB0U0qxkdomC+5OOuJ6MLJxMgYHevDipaPxdtWTU9HA33/aT1WDue3xnNocaK4jWtFA1PS26fY9NnQ+AL8e+pZ6Sz3h7uFMC5/WbpOzws7Cy8nrpM9FEARB6Bs7c6soNTbj7qRjRkv5OYDbpsfg7aonr6qR73blA5BcWE9pnanDtj2h0UhclBDGyJbPwP0FNW2PKYrC4ozFbC7azH9T/tvjDoPWgD6lKkXNEXOaEgH9aWJi8EQA9lXsO60vWEEQhNNJ63T7SdG+6LR9/5Ec5e/Gi5eOxs/dQF5VI48t3kd5nVrOLqc2G8x1RMtaiJnR4bmKolDRVNFuTXynYmdTo9Xyu6kQrM1cPfxqNFLX57a7dLeYei8IgtBP/LqvCIB5o4Jw1h/p2PV01vPXGbEAfL+7gOyKBlamV3e67ckYHeYNqGVWAWyyjY/3f8zizMUAXBR70XE/SzoT5h5GoEsgVtnKgYoDp9S+/kwE9KeJSI9I/Jz8sNgsJJcnO7o5giAIwgkoisKWLDXT+9Q+yG7flXAfV/51aTxBnk4U1Zh47Kf9lNebyS3fD7KNaL0nhHas8bu9ZDt3rb6Lj/Z/dPwDOHuy
2DeIZhSGyFomBE3octOvUr/i5Z0v81XqV6d6WoIgCMIpyqtsJDlfLT93/uiOJd/OGuzP1Fg/ZFnhpT8OkVLSgEaSOt22p+LD1Vlb+wtrMVlNvLb7NVblrUJC4ua4mzk3+twe71OSJMYGqaP0p/M6+gEV0G/YsIELLriA0NBQJEni559/PuFz1q1bx7hx43BycmLw4MEsWrSoz9vpCJIkkeCbAMDOkp2ObYwgCIJwQpll9VTUm3HWa0gY5G3XYwd7OfPipfGEeDlTVtfMh9szqaorRAIiB80Ara7Dc6I8owB1rb3F1nUC1tKGUlbTAMCf6hqQjrPmcXLIZCQkNhZuJLUy9ZTOSRAEQTg1raPziTF+BHo6d7rNX2fE4u6ko9iozgieFOXb5bY9MTTIA71WoqrRyN83PM3u0t3oNXruG38f50b1PJhvNS5Q7aDeW7a3VxPu9ScDKqBvaGhgzJgxvPPOO93aPicnh/PPP59Zs2aRlJTEvffey2233cby5cv7uKWOMcZXrRe8u3S3yHYvCH0oraSOUqNY2iKcmq3Z6nT7cZE+OOlObariyQjwcOKZi+LQShJppUamN9lIlPW4DO68xm+QaxAeBg+sspXDxsNd7re0sRQ39yDGaNwYZWqC/O1dbhvrHcvsQbMBWHRw0Wn7ZUsQBKG/q2+2dqv8nI+bgT9Pj2m7fX58cK8c36DTMCLEk0Z9MikVabjr3fnH5H+QGJJ4Svsd6TcSZ60z/i7+NFgaeqWt/U3HLvh+bMGCBSxYcIIMu0d5//33iY6O5rXXXgNgxIgRbNq0iTfeeIP58+f3VTMdJsYjBk8nT4xmIymVKYwJGOPoJgnCaWdLZgUv/nEIHzcDH14//pTXjAlnJrNVZlWq+sXprD7Mbn8sRVHaJRUK9nJm1rAAMpMOcUl5M6GBgRA2vtPnSpLEYO/B7C3bS2ZNJoN9Bne6XXxAPG/N/jcNW9+BlF8h/Q+IOqvLNl09/Go2F24m15hLcnly16XxBEEQhD6zKkUtVTfIz5W4sOMnaZ05LICC6kaqa43Eh/VegtP4cC+SCybhp1F4aurV7SqknCyD1sAH53yAs+7UZxH0VwMqoO+prVu3Mnfu3Hb3zZ8/n3vvvbfL5zQ3N9Pc3Nx222g0AiDLMrIs90k7e4Msy0hInDPoHMyymQDngH7dXqHvybKMoijiOuhFhTVNvLEqHQWFqoZmft5bwJUTTv3Dxl7ENdF/rD1USlVDM35uTiRG+/T570RRFLYVb+OP3D94PPFxnLRObY9dOjaUFbv30WCxUek/AR9JC120J8Yzhr2le0mvTmde5Lx2jzVYGnDTuwHgpHHCacTFKCm/wuGtKI3V4Nz5lz53nTuzImbxR84fLMlaQrx/fC+dtXCyxHuFcCxxTZzeZFnht31FKCgsHB2MoignnDH1p4nhlJeX9+p1oVZ7kTBXTSPEJfTU9qsooMig0WLQGAbktdvdNp/WAX1JSQlBQUHt7gsKCsJoNNLU1ISLi0uH57z44os8/fTTHe4vLy/HZOqnU2wVGU1FGhzextljr0ej1UEDlDWUObplggPJskxtbS2KoqDRDKjVNf1Ss1XmuRW51DU24+2io6bJyrfbc5kYpMPNaWCM0otron+QFYX/bcvGarEyM9qX6sqKPj3eisIVZNVlsb96PwCf7fmMSyMvbXtcZ7MSo0/CbJb5sTKSC8u6/uzwU/ywWC2klKZQdtR2h+sP83bq21wWeRmTAye33OuOl3sEuppsGvb8hGnweV3ud5LHJJbalpJUksTO7J1Eukee2kkLp0S8VwjHEtfE6W1vQR0FlfW4GbSM9KHd+3tXevOayKnLYWv5VuK84tEoOqrrrezJyGeQjzNSsxHnnBVYAsdg9R3Srf3pS/fhvus/yE4e1M54HvRqzNdgbUCv0WPQGE6pvfZSV1fXre1O64D+ZDz22GPcf//9bbeNRiMREREEBATg6dm3NYJPms2C9MfLWBpr0Yy/AE3g
SEe3SOgHZFlGkiQCAgLEh28v+PfqDIobbPh7uvD6lWN45rdUcisbWJ/fzE1ToxzdvG4R10T/sD27koomGS83Z66YMgRXQ99+FGdmZZJWl8a0iGlsK97GhvINnDPkHGK91fJDxvztvOddi8UDZGMIc7VuRPm5dbovV29X9Jl6qq3VuHq74m5wJ706nfcy3sOMmT11e1g4auGR0kKjL0La+jZepVvxnHpTl20MJJDpldPJrs3G1cuVQL/A3n4ZhB4Q7xXCscQ1cXrbvLUcnV7H+QlhRIR2b018b14Ta6vXsr1yO3pnPWMGzWRvfg1FJh0TmnOQ1r8MTVWQ/iPK6Cthws2g7SIgV2TY81+kvf9VR+itRpxyFsPZ9/H+vvfZULCBuxPuZmrw1FNqr704O3dvmcBpHdAHBwdTWlra7r7S0lI8PT07HZ0HcHJywsnJqcP9Go2m/76BaZxQwiciZaxCk78NW+AwDlQeoNnWzOSQySd+vnDakiSpf1+7A8TKlFJWHypHK0k8fO5wAj1duHFqFM/8msJv+4q5MCEMf/eO7xv9kbgmHG/x3iIkJM6PD8HduW9HCWyyjVxjLkhwxbAr0Gq0bC7azAf7P+DFaS+i1+jJy/wdSQI3nS9GyYMfdhfy8LnDO92fp7MnVw67khC3EAw6A4eqD/HSjpcw2UyM8BvBI5MeQXd0hvwhc2H7e1CRjlSTC74xne4X4NbRt+Ksc+5xnWGhb4j3ioFNlhU0GunEG/aAuCZOT/lVjSQX1KKVJM6PD+3R77e3rokDlQdAgtEBoymXvTmQV4H77nfR2DarG7j4QFM10r5vIH8rzHwMAke030ljFax5Fgr3qLcjp8LhLUipSyB2Jl5OXigoJJUncXb42afUXnvp7ut6Wv9FTpkyhdWrV7e7b+XKlUyZMsVBLeo7yiD1nKS8rewq3cW/dvyLr1O/FhmDBeEUZZfX8966TACuTYwkPtwbgAmRPowI8cBiU/h2Z74DWygMJAeLajlUUodOKx03i3BvKawvpNnWjLPWmVD3UG6KuwkPgwf5dfn8kvkLyDK5hdsAGBocB8CmzAryqxq73OdlQy9jathUMqozeHH7i5hsJkb7j+axxMdw0R3TWe7iAy2fT6Qfv8KMq95VBPOC0AuMJgt3fLWbv/1vL1nl9Y5ujtDPrUtTp9dPiPIlqBfKz/VUg6WB7JpsAOL84xjvXsUD9a8SXboCBSDuUrjmW5j/Arj6QvVh+PlO2PERWM3qTgp3ww+3qMG83gVmPQ7nvggjL1IfX/8KCT5qR3VSeRKyMvDW0x/PgPrkrK+vJykpiaSkJEAtS5eUlEReXh6gTpe/4YYb2rb/61//SnZ2Ng8//DCHDh3i3Xff5bvvvuO+++5zRPP7VkQiIEFlJgmu4eg1ekobS8mvE4GGIJyshmYrL/5xCItNYXykD5ePD297TJIkbmyZar/iYAmFNU1d7keWFb7YmssLv6dS2yRKSp7JFu8pBGDO8EB83Pp+DV9mjdoZFeMdg0bS4Gnw5JZRtwDwU8ZP5OeuJcdSA5KGuOizmRTti6LA97sLjrvfvWV7eWnnS5hlM2MDx/LIxEfaJdprZ2hL/eCMFXCcmvStzDYzK3JXYDQbu32egiAc8ePuAopqTORWNPDg98n8klTYrQEeWVawyWIg6EyiKAobMtQ8LjOGBjikDSmVKcjIhLgG45++iuiN9xMil1KLB4WTn4Kz7gGdk1ot5YpFMHiuOrV+75fw059h67uw9EFoqgbfaLjkAxjakrQ18a/gEQx1xQzLWI+LzoU6cx1ZNVkOOde+MqAC+l27djF27FjGjh0LwP3338/YsWN54oknACguLm4L7gGio6NZunQpK1euZMyYMbz22mt8/PHHp2XJOly8sfgOVX8s2sto/9EA7CjZ4chWCcKApSgKb63OoKTWRICHE/fPG9ph+uKoUC8mRPkgK/DVts7rcputMv9adojvdhWwNauSV5enIYsvTGekvMpG
duRUIUlwybjwEz+hF7R+aRnsfaTE3JTQKYwPGo9VtrB3x7/JkWwoejeifYZw9US1asP6tDJKajtPBGuTbdSYavAweDApeBIPTHgAvVbfdSMGTQZnT3U6ZMGuE7b55Z0v88mBT1iRu6IHZyoIAkBVg5nf9hUDMCTQHatN4eONOTz9awo1jeZOn1NnsvDj7gJu/+8u/vTRNopru+6gPpZVtrK7dDcWm+isHoiyyuspqTVh0GmYGOXrkDYcqDgAQFx1MWx7D0m2Uuo7npfdH2anfEwSPGcvmPNPOOcZcPGGqhzY960a4A8/Hy5+H3yOSqpqcIXpDwOgS/mFeBc1P8Desr32ODW7GVAB/cyZM9vKKBz9b9GiRQAsWrSIdevWdXjO3r17aW5uJisri5tuusnu7bYXS8gE9Ye8rUwKngSIgF4QTtaS5CK2ZlWi1Ug8umA4ns6dByzXT1Y/ODZmVJBZ1n5qY32zlSeXHGBrViU6rYRBpyEpv4avtnce/Aunt5/2qqPzU2L8CPPuPI9Lb8uqVQP6GK8ja9clSeLWuFv5h1MU59RWU6KVkF38iPaKZkiQB+Mj1U6q73d1PsNLkiRSqlKYHDKZe8bdg15znGAeQKuHweeoP6f/ccI2z46YDcCy3GWYbZ0HIIIgdO67XfmYrTLDgj147cox/HVGLHqtxO7D1fztf3vZk1fdtm1eZSPvrM3kps92smhLLmV1zTSZbaxPK+/28ZbnLuflnS/zbvK7fXE6Qh/b2DI6PzHKFxeDYyr2HKg4ANZmRlcVqJ8X0x+iJPEfNGg82FdQ2/mTYma0jNbPUafhz/o7zHgY9J0sGQgfDyMuAGBc/n5QZBHQC/2XuTWgL9jFeL84NGg4bDxMWaMoXycIPSHLCt+1BDO3nh3N0CCPLreNCXBvm6b2xdbctvsr65t59Md9HCg04mLQ8syFcfzfHLWn+btdBWzLruy7ExD6nYr6Zta2rFO81E6j87Iio5N0aCVtW0b7Vn6HtzE6ZzuHNQqKRwhezn54Oal14q9qGaVffaiMsrqOo/QaScNdCXdx46gb0Wm6mVu3ddp97iZY8n/w01/hh1vh2+vg66vhi0th9bNgbiAxJJFAl0DqzHWsy193sqcvCGecUqOJZQdKALhhSiSSpCbffP3KBAb5ulLTaOHJXw7yztpM/vHzfu76eg/LDpRgtspE+7sxa5j6WbYjp6rbx/wx40cAthRtwWTtp+WdhU4pisKmloB++hB/h7TBZDVhU2xIzfWMknXqEuIRC4mP8AbgYKGx62UgLj4w5wm4/icYeoLZ15PvAPdAEhqM0FBBdm02NaaaXj0XRxIB/WnE5jkI3APBZsazIoPhfmryh50lOx3cMkEYWFKKjRibrLg76ThvdMgJt7928iA0Gok9eTXsL6ilsKaJh3/Yx+HKRrxd9fzr0tGMDvdixtAALhyjJkJ7fWX6cdfdA+TX5fPWnrfIM+a1f8BmhVVPw4+3gUmsMx4IliQVYZMV4sI8GRbcdQdRb9JIGp47+zk+P/dzAlyOWhtZng6b3gTAP/4aro3/M7NDZrc9PCLEk9HhXthkpW3N/ynzHwL+Q8FmgeJkKEuFykyoyYe6YmishMxV8PMdaOuKOT/mfAB+y/7ttEteJAh95Zsd+dhkhTERXm0JXAGi/N14/aoxLBitTjdedqCE5PxaNBJMjfXjxUtH89bVCdx8VjQAGWX1VNY3d+uYrfW8Y71iMctiRs1Akl5aT1ldMy56LeOjfBzSBmedM2/OepP3tGG4I0HkWQDE+Lvj5qSlyWLrncSOBjeY/jDeaFhY38ifw8/B0FXpuwHotC5bd8aRJJRBU5FSfobDW5gUMYmUypTTLvGDIPS17S2jExOjfNB2o+xPiJcL544K5vf9xby/PouaJjPGJishXs48e3Fcu6yxN58VRWZZPSnFRl5YmsqrV4zpcprbFylfkFyezJaiLXx+7uc465xBlmH9S5C1Rt1o/3cw8bZTP2mh
z9Q3W9tGzS6z0+j80dqtbzcZYeUTYDND5FT8J/6FhUBZWfuZXFdPjGB/QS0rDpZw5YQIfE81gZ8kqRmHi/aCRqfWENYa1OmVWgM018GGV9TsxT/9lZmzH+d7vTuljaXsKNkhSrAKwgnkVzWy5pBaqvn6yVEdHnfSablz5mDGRviweE8BI0M9OX90CIFHfT75uBkYEuRORmk9O3OrOTfu+PXIq03VVDdXo0HDP6f8s2OVC6Ff25ihLq2YFO2Lk84x0+0BqCvFpzJX/ZyIVKuiaDQSo0K92JFTxb6C2uPOlOy2iIkwfCHXH/oNUtfAqOtOfZ/9hBihP90MavnSk7eNs0PP4tUZr/K3sX9zbJsEYQBRFKVtOnxijF+3n3fVxAgMOg15VY0Ym6wMCXTn5cvjO5SA0Wk1PLJgON6uevKqGvnPmowusw+35sIA+PrQ1+oPOz5Us4W3OrBYjNL3c8sOlNBksTHIz5XxkfYbBekwsi3LsPZ5dUTcI0St49tFjdvRYV4MD1bLMq451EvLttz8Ycg5EDtLzVYcMRFCEyBoJAxKVDMTB46A5jqcl/2d+YYgAJZkLhElWPvAztwq1vbW71ZwuK935CEranB2vFlAU2L9eOWKMdx8VnS7YL7V5Gj1c6870+6za9VSY6HuoSKYH2BkWWFTpjrd/mwHTbdXFAWrbIXDLbXmg0ap0+hbxIery8D2F9T03kEn3wFuAWAshJ0f995+HUwE9Keb0LFq/cWGcjzqSonwiECSTjzCKAiCKr+qiZJaEzqtxLhB3Q++fN0MbWXtEiK8ef6S0Xi7dj6q6etm4NEFw9FoJDZmVLAkuajT7eZGzuXxxMcBNfHQ/q1vQvL/1AdnPAJ+sWBugAM/dP8EBbsyW2V+SVKnrV82Lsyu78ePbnyUxzY+Rr6xJbld0peQt00dEZ/3rJp5vguSJDFnhBpQt9Yo7nNufnDBv2HIPFBk5qdvxFBfgZfBE5NNrM3tTRszynn2txReX5lORmmdo5sjnKKs8no2ZVQgSXDd5MgTP+E4Jkarmc6T8qsxWY5fZrJ1OVi0VzSyInOw8iBbi7ae0vEF+0gtMVJZb8bFoO3Rd53eVNJQwq3Lb+X1lM9QUCDy7HaPjw5TA/qUYiNWWy8tvXJyh+kPqT83Vasd3acBEdCfbrQGCBuv/nx4S9vdIlGJIHRP6+j8mHDvHmd8vXpiBO9eO46nLxx1wueOCvXi1rPV9YqfbsrhQGHnmVzjA+KZFzkPTHW8m/pfGlBg0u0w/DwYd4O60f4f1SnLQr+zLbuSmkYLfu4Gpg2xX43fJmsTecY8smuz8XTyVMvF7fpUffDs+9Q17Sdw1mA/dFqJw5WN5FY09HGLW+gMarbixL/gJWl5u9bMI2UluFhFSazesq+ghtdXptM66WFFSqljGyScsi9byqZOG+JPtL/bKe0rys+VQA8nLDaF5Pya42578eCLeXv221w+9HL2lO7hma3P8PnBz7HJx+8I6EyzrZm39rxFcX0xy3KW8fimx6loqjjJsxBOpDW7/eQYPww6x4SDByoPYLI0YqwvQUJSZ24dJcrPDQ9nHSaLTEZZL6yjbzUoES79UC1/18UstYHm9DgLob3Iqer/eVuxylbeS3qPv6z8C5VNIqu2vTQ0W6kziS+gA9G2HPXvZHJMz+uxSpJEhK9rh3r1XbkgPoQZQwOQFXhv3ZFcF4qi8EfOHxTVqyP313qNJLiulCpJZlFIFCRcq24YNR18osBcDwd+7HF7hb63r2Wq4PQhAei19vvIza7NRkHB38UfL40B1jwLigLDF6qdQd3g4axnQssSAbuN0oO6jjLhGpj3PF56d3Xd/Y4P7Hf801h2eT3PLU3FalPaAr8N6eU0W3segAn9Q2qxkV251WgkuCbx1EbnQf0cax2lP9G0e0mSCHANINgtmDGBY/AweFDdXE1SeVKPjqkoCh/t+4gtRVt4aedLbC3eSmZNJpsKN53saQjHIcsKmzMdm90e
YH/FfrA0ECdrwDsCvAe1e1yjkdpG6fd3Vb7uZAUM6939OZgI6E9HES3r6MtS0ZmMlDWVYbKZWHV4lWPbdYZoNFu5++s93PnVHmobRVA/kFTWN5NRqvYCT4zqeUDfU5Ik8ZcZan3wvKpGahrVDMHZtdksOriIxzY+hqUsFedVz3CnxRkfF38SJ/5NDXhA7VluG6X/QZ1+L/QrB4vU/AajQrue3t4XWpOhxnrFqgFxU426bvCse3q0n5nDAgHYkFGB3FXpoL4SdRbMew6AyswV5FSk2vf4p5lSo4mnfk2hyWwjLsyTly+PJ9DDiUazja1ZosN/IFIUhf+2lEudMyKIMO/eWcc+qTWgz63q9t+9XqNnWtg0gB6Xm1x+eDkbCzeiQcPt8bczPXw6ABsLNor8GX3gQFEtNY0W3J10jGkpD2dvsiKTUpECzfWMlnUdptu3Gt2yjn5fYY0dWzfwiID+dOTmBwFqyTrytqnTdYHVeauxyCLA7Gs/7i6got5MTaOFxXsLHN0coQd25qqjEUOC3PFzd+rVfdeZ6yiuL+5wv4eznkG+rgCkFqvT5neV7AJgjPdg9Mv/DpZGhgWP5z+X/MKEkIntdxAzC3wiobmOutb19S0OGw+zvXg79eZjpqrZLFBXCqUHIXu9mlhvx0eQ8guIL0+9pqbRTEG1WppwpJ0D+syaTAC1/nz+dvXOyCnqlPYemBDlg4teS3ldMynFDki+GDaOLR5e3CWV8+mOl+x//AGior6ZJnPXo+y1TWr98eoGM4P8XHn8/JE467VteRJWpYpp9wONTVbYklXJgUIjOq3E1ZMiem3fcaFeuOi11DRayOyiZNjBioO8uvNVVuetbrtvdoRa/nJ36W5qm7s3oppWlcYXB78A4JoR1zDKbxSTQyaj1+gpqC/gsPHwKZ6NcKzW6fZTYv3sOnPsaIeNh6kzG3E2NxKraDtMt28VH+YNqN+PzNbTY717XxBl605XkVOg/BDkbWHi3KfwcfKhurmancU7mRo21dGtO21VNZj5OelIgrPf9hVzydiwLpOjnS42FGzAXe/OuKBxjm7KKdmWrQb0k3uQ3b67/rXjX+TW5vLqjFcJcW9f235kqCd5VY0cKjEyJdaPXaVqQD+hpgwaq8A3BuY/j97pyNpIo9mIh94DSaMhb/gClu54g00H3+OpyESGBI4G1FGS33N+R0IiyjOKURYbI3OTCLA2A10sC/CJgpAxvX7+Z6KUltH5QX6ueDjrT7B178quUbNPx3rFwI6Wjp6Inpd+c9JpmTrYj9WpZaxPLyeuZfqj3UgSw4deiLT/bdKr0ihpKCHY7filtM40GzPKeXV5GhqNRHyYF4kxfkyK9sW/pVPSZLHxzK8pFNY04e9u4OkLR+HupH79mzMikP/tyGNfQS1lRlOnWc8FxzNbZQ5XNpBV3kBWeT3Z5Q3kVja0BTgL4oIJ9Oi9351Bp2FspDdbMivZnlPVacmwlKoUdpbuxEXvwpxBcwCI8IxgsPdgMmsy2VCwgQtiLzjucWqba3ljzxtYFSuJIYksjFkIgJvejXFB49hevJ0NBRuI8orqtXM706kdQWpAP82B0+0PVBwASxMjbaBz9obAUZ1uF+HrgrernppGC+mldfb/DBogxAj96WpQS9BesAudLDMnUn2zXXF4xXGeJJyq/+3Iw2yVGR7swdAgD8xWmR92n96j9FuLtvJO0ju8tPOlAZ18sclsI7llvXNr2Z7eUtFUQWZNJlbFipdTxw+j4S0lhlKLjZQ1lpFXl4dGgXHF6eoGZ98HTke+UG0p3MK9a+/l60Nf89y253jo8M+sM4BVsbFn36K27XycfQh3D0dRZHIKt/Jb3kpeksr4Q2NWa4F7BENQHERPP5IkLV28R/SW1un2caH2/QJS21xLeVM5EhIxkpNapk6jU6ugnITWafebMiqw9Fam4R7wHXkpoxU9WJrYmP6T3Y/fnyXl1/DainRkBaw2hT15Nby3LoubP9vJ/d8m
8e3OPP71xyHSS+twd9LxzEVxbYE+QJCnM/HhXigKrBYl7Pql5PwarvtkO/d/l8w7azNZdqCE9FJ1tNJZryEx2pc/TRp04h31UOvn4M4u1tEf6TSMbXf/rIhZAKzN/gNl+ePw7XWQvrxDNnGbbOPNPW9SbaomzD2MO8bc0a4KSOv0/c1FmzuW4BROWnJBDcYmK14ueuLDvR3Wjv0V+8FcT5ysU3N/dZGcTpKktiB+fxfJgwUR0J++/IeoNX8tTVCcxOyI2WjQkFqVeqSEkdCr8qsaWXGwBICbz4rm2snqB+zv+4uprG92ZNP6TL25ns8OfNZ2u3Xd7kC0J68aq00hxMuZCN/eraebXJ4MwFCfobjqXTs8PjxEnY6dWVbP1sId6n16L9wtpiNB91FKG0tpsDSwJGsJ+yv2o5E0TAmZwrMWN64qzABzIwAXxl7Ia5Me532LN3fXm5kg65Fd/FgRNRb5luVwzbdw8TtqCbMpd6k7z14H1tPzerW3g0Xqlw97r59vtjUzOWQyo/1H41q8T70zJB4MHa+97ogP88LHzUB9s5U9h6t7saXd5OrLdP8EADZkLRVraltkltXzwtJUbLLCWYP9eeeacdw4NYrhwR5IEmSU1fPltjx2H65Gr5V44oKRRPh2vAbmjlSn3a9OLbV/ngThuEwWG/9Zk0GT2day3tmLS8eF8dD8Ybx33Ti+/fMU/rFwZK/NADJZTW2Z5cdH+aCRIKeigTJjx876nNocQC1Zd7SpgRMwNNUiF+zAmLsBavJh7Qvw01+gKKltu0ZrI1bZirPWmfvH39+hjn1CYALuendqmmvU4E/oFRvTj0y313YzgW9fiPePZ6TFRryig8jOp9u3GtOyjn6TI3K5DBAioD9dSRIMmqL+fHgLfi5+TAxW196KUfq+8d+tucgKJEb7MjLUk7ER3owI8cBiU/j+FEbpTRYbTy05yPNLU/rdG9kXqV9Qa1aDFg0aKk0DN7HS9pZydYkxfr1eK7w1oI8PiAfUkmJHC/VyxstFj8WmsC5vGwATmlq+QA2e26Hn+sLYC4n3j8dF58L50efz1uy3uHfO6wz1jAJTrboWHqDsECz+Mz7laUzT+3HXrNcwuAVT1lxNXv0x12TwGHAPUjPmH97cq+d/Jmo0W8lpKfVm7/Xzga6B3Df+Ph6f/Djkqx1EhE866f1pNFJbJuT16eW90cQemxh3Lc5IlDUUk1Zx0CFt6E+Kapp4+teDNFlsxId7cf85Qxnk58rl48N55YoxfH7zJO6ePZiJUb4EeTrx6IIRjAjp/DqcEuOHi0FLqbFZjID1M9/vyqfU2Iy/u4FPb5rIcxeP5uazopk+NIBwn+5XVOmu13a9xv+t+T9yanPwdNYzPFi9Znbkth+lrzJVUd1cjQZN++nwBbtxXXI3L1fV84bZFa/Q8TDhZjC4QUU6/HoPLH8cavLxMHjwxJQneHLKk4R7hHdoi16jZ3r4dCYGTcRNd2ql+ASVxSazNbs1u739yqh2ZqH3CJ5slIjQuED4hONuO21IAC56LXlVjezOc0Cn8gAg1tCfzgZNgdRfIW8bKArnx5xPtFd023QoofekFBnZll2FRoIbp0YB6jSh6yZH8vhPB1h+sIRLx4X1eI2boij8e3UGu1tGxXIrG4gJcO/t5p+U5PJk1uWvQ0LiwQkPMjZwLFpNz+q29xdWm8zOXPU1Tozu3ez2siKra8WAUX6j+GjfR2wq3MRrM1/D30UNkiRJYniwB1tzSsmozsLNYGNCRUsioMFzO+xTq9Hy98S/tz23zdjrYN2/YN834OQOm/8NNrOaNG/e8zh7hnFLs5kxg8YQ4HbMh7lGA0POgb1fQsZKiJ3dq6/DmSa12IisqFOa/Xs5wWK3WZvVDPcAEScf0APMHBbAL0lFbM+poslsw8Vg3791p6hpTN7kyTpbLRsOfsXwmS/a9fj9SXWDmSd+OUhNo4WYADceP39EhzrSPm4G5o8KZv6oE+cbcNZrmTE0
gGUHSlidWmrfrNe7F6lJOc/9FwSNtN9xB4CC6kZ+3FMIwO3TY/r8by7fmM++CnVGz7LcZdwx5g4mRfuSUmxkR04V58UduZZap9uHeYThpHWChkrY9g5kqgnyQlz8YcrdMHiOOsA04kLkXZ+xJv1H9hes4N68LUijLkUfNoGYhjJIXw31ZVBfCg3l0FwH0+7nhpE39HoH+5ksKb+GhmYb3q56u88c6+DwFvX/8AmgP/6sSDcnHQtGB7N4TyGL9xTYpQrRQCNG6E9nYeNBa1DXT1bnMMx3GJcMuQRvZ29Ht+y0oigKn21Wp56dMzKo3ZTG+HBv4sK8sNoUvt/V81H6n5MK27KRAuzNqznl9vaWZmsz7np3zo06lwnBEwZsMA/qWuf6ZiueLjpGdjGKdbIyazJpsDTgpndjmM8wihqKMNlMLM1e2m67ESGeaNCT6PIo/wiaSZAM+A0G3+hO9ytJUscvOoPPAc9QtUTZhlfVYD7yLLj4PbXGKxDnE4efSxc5AoaoFTHI364m4xNOWoqDytUpikJJQ4k6Lb04Wb0G3ALUxIqnIDbAnTBvF8zWIyM8dqXRMiPyHAD2FG07Y9fUNjRbeXLJQUqNJoK9nHn6wlG4Gk59bGZuS7b7TZkVNDRbT3l/3ZL6K+z6TJ1VtOdz+xxzgFAUhffXZ2GTFSZE+TClDxK1HuvobPUTgtQR09bydfsLa9tVUcgxqt95Yjyj4MCP8N31ajAvaSDuUrjqCxgyF4tspayxjCLZxDO6Oj7yD2Kbs4FdSrNaanXZo7Dx9ZaO5BXqe5axSA3oN76OZBIzRnrTxpYZVmcP9u/12R09sb98P8acdeqNyO4l6r5gTChajcSBQiOHShxQcaWfEwH96UzvrAb1AIe3OrYtp7Ft2VUcKqnDoNN0mpjm2kT1vhUppZTUdj9pXFJ+DYs25wJqGTWgLWlbfzApZBKvz3ydq4Zf1e7+gbi+dXuOOt1+YpRvr3/ItU63j/OLQ6vRclHsRQCsyVtDnbmubbvW6bDpJSbiilvqbbcG2N2l1amj9K3GXqfW8TZ0Pl3RbDO3v8MnUi15Kdsga03Pji20c6CwJSGenTPyVjRVcM/ae/jLyr9ga1m+QcQkdZTsFEiSxIxh6qyOdWmOmXY/PP567rK68Hq9gqZh4C7vOVlmq8xzS1PJqWjA21XPMxeN6rUKKkOD3Bnk64rFprAxww6/34LdsOmNI7fztkFNXt8fd4DYmFFBcn4teq3EX6bH9vkotcVmYWPhRgAem/RY2xLNcB8XQr2dsdoU9ubXtG3fZGlCb7MSk7ZanQlmblA/Oy55H866B5w8SC5P5i+r/sKTW57k4Q0Pk1qVirOTJzdOeZzx576h5vXwG6x2Oo+6BBL/AnP+CRf+R+2AbK5Ty6kCJQ0lbCjY0KevwenOZLG1VfKZ5sDp9k3WJl7c+gy3G3dTiXzC9fOt/N2dmNWSoPWnlpkrwhEioD/dRbaso89ao35JR61x/czWZ0ipTHFgw04PNlnhv1tzAbg4IbTT2uVxYV4kRHgjywrf7uxeQsJSo4mXlx1CVmD28EDunTMUgAOFtd2uw1lvrietNq1PA2wvJ6+2RDZLs5dy39r7WJW3qs+O1xcURWF7y4dcYi9ntwcIcw9jTMAYJgSrIx5jAsYQ6RmJyWZiee7ytu1iA9zQaiQ09SVYi/arAdjJTHsfukD9QnXeqzDp9k4zx5Y1lvHC9hd4ZMMjHa+PoS2dCBkre35sAVADr/QytbPG3uvns2rVxJS+Lr5oC1rWz5/idPtWM4aqXwKT82uobjCfYOvep/GJZHrQRFwVBdL/sPvxHe3ttZkcKKzFRa/lqQtHEeLVe8k7JUli7kj1y/LKlD7Odl+dCyufUL+TDJ575Av9gcV9e9wBoqHZykcb1SntV02MINir70sJbiveRr2lHn8X/7ZcL6BeF63Tm7e3ZrtvNnJDZRmLSiuYXVupVmCZ9oA6EyxgWNtzw93DabI0UWWq
wiJbSAhI4JUZr3BezHloBiWqgfvln8C5L8DZ90LCNer1EBKvVnYBOPQbFXmbuWftPbyX9B41ppo+fy1OVxvSy2my2Aj2cm6rrOMIh6oOYWs2Eqho8AuMA9fuT5+/dFwYAFuzKymobuyrJg5IIqA/3UVNA70rVGa29XTuLdvLwcqDrMgVyfFO1cqUEgqqm/Bw1nHpuI5JXVq1Zrxfc6iUwpqmLrcDtRf1+aWp1JmsDA50585ZsUT4uuDjZsBiU0gt7t5Uozf3vslbKW/xTdo33T+hE7DYLDy79Vl2luzs8FijtZGihiIOVR3qtePZQ05FA2V1zei1EmMHeff6/qeGTuXviX9nevh0QP2CdHHsxQAsy1nWVuovtXo/9V6f4in9TJPFppYYcz+JXnRNy5THiIldbuJh8CCtKo2ihqKOHXuxs0GjhfJD6hdvocfSS+uw2hS8XfWE2uHL+NFaK03EOgeq2aUlzZGZWqco1NuFoUEeyApssMcobmeGq3WqOfQ7ss1OU8P7gTqThfVpaqD9+PkjiO2DXCqzhgWikdTrN6+qj74sN1bBH4+qyTeD42DGIzD6MvWx9GXqqOwZ7uvtedQ0Wgj1duaSsV1/r+hNmwo3AWrJOY2koayxjGU5y5AVua2je09uJfqc1Ujf3QCpv6JTwDDsPLjyvzDywg6dx34ufpwXcx4hbiHcnXA3j056lEDXwO41KCQehp0HgP+OTxniPRgZmc1FImHryVp2QK3CdO6oYIdOt08qSzpSri6qe6PzrSJ8XZkU7YuiwM97xSj90URAf7pz9YWZj6g/J/8PcjZwTss6xB0lO6g2iWyRJ8tksfHVdnWK4FUTI3Bz6nod4/BgT8ZH+iAr8O2OrqcVKorCO2szyalowMtFz9/PG4GTToskSSS0lO3ozrT7GlMNB1syQS/JWsL6/PU9OLOu/ZDxAwcqD/Dx/o87ZGof7jMcgLSqtF45lr20jjqMHeSDs94+eQASQxIJcg2izlLH2vy1AOwq3YWircSTDEwWm7oevo+46FyYFq7W+O0wo8LFByIS1Z8zVtJoaeTTA5+ysWBjn7XndHNk/byX3RM6tQX05pYR9KBR6ghaL5nZMu3eUdnuiZ7OVoOWR0wZ/LH3fce0wQH2F9QiKzDI17XPktZ5uxraRmNXp/ZwlN7cAH88At/fDPu+B1MnHc9WM6z4p5rXxzMU5j0POgOEjlOnWFua4NDvvXAmA1d2eT2/7SsC4K8zYjskO+wr942/jzvG3MGsiFlYZAsPb3iYzw5+RkZ1BiNCPPDXm7mu8i0M2/+j5jzwiYIL/w0zHz3uCOv1I6/nzVlvMi18Ws/fCxP/rL53VWZytqwuLWnteBB6Jqu8noyyerQaiTkjutmp0gcOVhxkRc4yMDcyTj5xubrOXNYyeLb6UJlDZor1VyKgPxPEzIT4K9Wf175IlKJlmM8wbIqtXRIUofuarTY+WJ9NTaOFIE8nFsSFnPA517WM0q9LLyejtK7TEnRLkotYl1aORoJHzh1OgMeRKfwJLaPHSd1IjOft7M2icxcR46Emwvpw34ekVqZ248w6pygKKZUpLMlcAsCtcbd2qBk72GcwGjSUN5VT2TRw1re2lqub3AdJh9Kq0qgydUwup9VouTD2QgBW5K5AVmR2le7CVWNlclMD9VYNRE/v9fYcbc6gOYDasVfbfEzioda1+xkrMFkaSa9O5+2kt/k169c+bdPpwlH152VFPhLQV6tBQVvnTC+ZNsQfjQQZpfUnnG3UJ/TO1AWPIleysSH7zJl237p+OaGPM9C31qRfm1aG9ajPqPpmKztyqvhkUw4vLztERX3zkSc118HSB9R18FXZsPVt+OpyWP8ylKer2ygKrHsRSg+oQdq5/wKXlnORJBh9ufrzwcVtywPPNLKs8O66LGRF/TsbO8jHbsd21jkzM2Imfi5+6DV6xgeps3q2FW9DZ2viAdvHDLZmYrTp+XFwIg8G+rPa0sfJMV18YNKf
AZiatRWtrJBdm01B3cmXAT5TtY7OT43167W8Gz1V1ljGG7vfQDbXMU3WMcE9Uu0Y6qGRoZ4MD/bAalP4taXzSxBl684ck/6iTqEt3gcr/8m8STeQVp3GqrxVXDz4YnQacSl0V2ZZPW+sTG+bknjLWdHd6kUfHOjB5BhftmVXcf93yUgSeDjr8HLR4+Wix9NZz7aW4PKWs6MZHd4+mdaYcG/1+OX11JkseDjrj3s8J2MJf9fF8pmvOzuqU3lt12s8f/bzBLkFdftck8qS2FGyg+TyZCqa1A/vxJBEJoV0XJPronNhkOcgco25pFWlMTWse5lLHam8rpms8gYkCSZG9e6XJ0VReHvv25Q1lfHE5CcY5T+q3eMzwmdgNBuZM2gOObU5VJuq8VaaiTVL7DUMJwhn+rLybrRXNLFesWTVZrGhYAMXxF5w5MHIs8DgDvVl+FbnMS5wHDm1OXyZ+iVW2colQy7pw5YNbDZZIbVYnTZs74C+qF6toGDQ6AkvaZkp08sBvbergbGDfNh9uJr1aeVck9gxEWhfmzrmFj4v2UhuQxF55SkMCjj9y53tbam9nNAHy4KONiHSB29XPdWNZpYcqECX2cjBIiM5FQ0cnW7DoNNw79yh6kj80gfUGuNOHuo66IyVamB/aKn6L3AkeIWruXw0WjjnaTUB59EGz4XtH0BdCeRugpgZfXqe/dHK1FLSSupw0Wu59ezOq5v0NlmRkehYMWVyyGQ2FW5ie9FWrs/YQaTtMNmSKx+6/h++bpnkVyRhsVn6voHDF0La73iWpZJgcWO3k45NhZu4evjVfX/s00ST2cb6lkSm58aduIxlX7DIFl7d+Sp1TRXENjXwZ6sLUtRZJ52s9bLx4Ty/NJWl+4q5fHx4r1T6GOjECP2ZQquDOU+pU6OqckjM3IiXwYtqUzW7S3c7unUDgk1W+G5nPg98n0xeVSPernqevGAkUwf7d3sfN0yJIrBl1F1RwNhkJb+qiQOFRrZkVSIrMGtYABeOCe3wXD93Jwb5uqIosK+g61Iu6dXpyOYGpGWP4HHwC+46uJaYhlrqGkp5Zecr2I4z+lFvrm93e2vRVlbnraaiqQK9Rs+EoAncGndrl88f7tsy7b56YEy739AybXh4sEev91qXNJRQ1lSGTtIR6x3b4XG9Vs+lQy7Fy8mLXSW7QIGxzc24arXs0k8krbTv15K2jtKvzlvdPjmezkB15GT154yVXDnsSq4cqs7y+SbtG35I/6HP2zZQ5VTU02Sx4WrQEuXXl10yHWXXqom0ovVeaC0mdYTLb3CvH6c1Od4GB027dw9JYLxTAKCwYd9nDmmDPRXXNlFqbEarkYgL7duqCTqthpktmaSXHKhgSXIR2eVqMB/q7cy0Iern3fr0cmqryuG3+9Rg3tkLLnhLDegv/xQuelsN0jU6KEtRS5IBTHuw85wOOicY0dKpeODMe39pMtvaqtpcO3lQpwl2+8K24m08sP4B1uatbXd/QkACzhonKkv3kVmWhKu7F5/7/I39Tf7sKlY/36O97NDpoNGoCfIkibOrisHcyKbCTWds2cqTsT69jCaLjVBvZ0bbuepKK72ssNCmJ6A6nwcaZAxOnmrehZM0KcqXcB8XGs02lh8s6cWWDlwDLqB/5513iIqKwtnZmcTERHbs2NHltosWLWqr1dz6z9nZvgmK+hU3P5j7FEga9JlrmK1T1z0tyVoyIEuN2VNhTROP/LiPL7YdRpYVpsb68fY145gQ1f3snKAm9Pjkpon8fNdZfHHrJP7zp7E8d3EcD80fxp+nx/Dn6THcPXtIl2vNWqdbJh1VPqZdO+sLeWLzEzz067WY6ktQtM44KfBQTQNhNYX8qTgbbeqv6lrFo+TX5fNe0nv8ZdVfyDceycQ/JXQK50Wfx2OTHuOT+Z/w0MSH8HLq+gNhmI+a4XYgBPTVDWa+3aWe65wR3Z+10F3JFWq5umG+w3DWHf99Z1fpLrA0MsFsQePsQapuRLeT
H56KqWFTcdY6U9xQTErVkeR4WTVZ3FOXxA9aE3L2WrCYuGzoZfxp+J8A+D79e7499K143+jEwZb18yNDPe2eeGiQxyAuir2I6biqd0RM6rTKwalqrU1dWNNEbZMdRumOJUlMiz4XgI3FW7Cd5snx9rYssxoR4oGLoe/zfCyMDyHQw4lgTwPzRwXz0PxhLLp5Ih9cP4GHzx3OsGAPnC1Gar67S0246+KjBvN+LR2XkgTBo9USZNd+r1bb8BusliUbfl7XBx51iTqCX7zvyFT9UyArMvXmeorri0mrSmNXyS6K64tPeb99YX16OfXNVkK9nVkY37FDv6+syVtDYX0hZY3tcyboFYXxDUawNLJNL6E9/1XOnjwFm1RHfm0FKBJRXlH2aWTAMBh5MRNkHS715VhtZpH/qQfakuHFBds9pwsAJQdg8W1Mz9jEG2Y3/KJnwBWfg/fJz+7SaKS2RNS/JBVhsYkOngE1R+Hbb7/l/vvv5/333ycxMZE333yT+fPnk5aWRmBg50kePD09SUs7Elw45GLuT0LGwOQ7YOs7LMjazsHoOM6LWejoVvVryw4U8/HGHJqtMi4GLXfMiGXmsIBTupa0GglvV0P3R4VbEgyNifBmSXJR2xe8Yy3JXIJiNRFUU4gzrlQlPoBPVBy+Kb/wStrvaI2Vau3fnR+jxF9NSkQCS7J/Jak8qW0fu0p3EeEZAUBCYAIJgQndPq9hvsMIcw9jsPdgFEXp139vn27OoclsY0igO+f0RUBfpgb0YwLGHHe77Jps8uryoNnIOFlPYcQ0bCU6uwT0LjoXLh96OW56N2K81HwLNaYaXt31Ks0aHVlOztDYCIc3weC5XDz4YrSSli9Tv2Rx5mLCPcI5K6znSW1OZwePSohnb1FeUeqX7AMtJQd7qVzdsdycdIR6O1NUYyKrvJ5xdlzr2yphzM14HPqaGksDaYfXMDJmnt3bYC+tHbhjI+zzOgd5OvPxDRMoKysjMDAQzTGdQpcNd8aW8R9s9WXIEYPQXPAGO5vL2L53KVcOu7J9JnNXXxh7nfrvRNz8IWYWZK5SR+ln/f2k2v/x/o/ZUbyDOnMdMu2/6LvqXHl37rsdcsA42oqDR4IurZ06AksbStlfsR8JiVmDZh15wGaFVU8xpaaMzXoN2wOiuT5wBOf52PhizyqqZAXF4ouT1j6zCACYeCuG7HXcaypjqN9UXF16P+fN6SijtI6s8gZ0WonZw3v/e85xmRtJ2fgiYVnr8VIkcPFBf/Z9vbacZsbQAL7YdpjKejMb0sv7ZGBmIBlQI/Svv/46t99+OzfffDMjR47k/fffx9XVlU8//bTL50iSRHBwcNu/oKAz+xcOwOgrIGYGXrLMsyXFTPEe1q8DL0faklXBO2uzaLbKxId78fY1Y5k1PNC+r5e5Eb6/Eb65hnjncjQaiVKjiZJaU7vNyhvL2ViwAepKuMRmgJiZWELGg1cEnPV/aK/7Eab+DbzCyTLXcHXyqzyz/kGSypOQkEgMTuS5s547pfXRfi5+vD7zdW4bfVu/vqaS82tYl1aOJMGds2J7fSTVIlvaysEdXdO3M34ufkR5DGKByYY7Eu5xCwBIL6nvNHFib7sg9gJmD5qNi84Fi83Ca7tfo8pURZhHGH8bejUaJEhf0W77m0bdRGJIIpNDJvd5+wYSRVEckhBPURTMtpZsv/Xl6vplSYLwrksXnqrWsmlZZfUn2LJv6F19iXNXR2jSDvdOFY/+yCYrJLcG9H28fr5bGiqYfPAZwiijCi+2jPwnNq9wvk79mo2FG9umbndnSvS6/HWUN6rLNmqba3k/+X2aRrZMu89crZa4Owkmq4lac21bMO+ic2nrZBgbOLZDhRZHyz4qA/nsYfb7jtpaYWW0/+gjnTCyDdY8C4c3M0ZyxdknmlpJoaShBINOw6hBDQCUVXpReXRixL7m5AGT7yRB0eOa9D8oTrbfsQewo5PhebkcP+9Sr2qsouD763gp7zce09VT
FjtDLW/Yi7kxDDoNF7UsT128p9Au35f6swEzQm82m9m9ezePPfZY230ajYa5c+eydevWLp9XX19PZGQksiwzbtw4XnjhBUaNGtXl9s3NzTQ3H3mTMhrV0RZZlpHl/julQ5ZlFEXpfhunP4xUmQ21+Sg7P4FpD/RtAweoH3cXoKBwXlwwf54Wg0Yj2f86yFiJ1PLFxmnlo0z0vYttFU7szati/qgjCU6WZC3B1lhJnNnKYL0vtsQ7URqOum51LjDqUiqjp/Hispuhrh5DXRkzwmdx3qjrCXZT99Xl+RUnQ8EOGHcjaB2TJbU3WGwy767LbPm9hhDj79brv9O0yjRMVhOeTp5EuEccd/8eeg9eDFuAdGgHinsgQYPH47x+J40WK9kV9cT499467OO9TyiKwsf7Pya9Kh1XvSsPjH8AZ6sVJekbKNiJUl/RVp5ofuR85g2ah4T692CV1SnPZ3pyzfyqRmqbLBi0GmL8Xe32XrEmbw1Lspdw15i7GFKWgQQQMBzF4AHdaEOPPz+AGH83NmSUk1lW77DPxrGB49EZi4iqr+7Xn8+n4lCxkQazFQ8nHdF+9rumurompO0fQG0Bzt4h/Md8C57pCk3+ayiqL8Ld4M750eez5vAafs/5naemPIWr3rXT/adWpvJ+8vu46Fx4edrLvJP0DqlVqRw2HubRwGF4lqWhHPwZxt90/HYqMisOr2CI95C2XCWXDr6UhdEL8TB44G5wR6/pGMj0p+tl2YFiFBQmx/jh4ay1S9usspU1eWtAgdkRs48cc8NrSNnrQKNHd86zPObuQ4RHBC46F2RZRjKU4qzTIjUH8cXWXP5vzpA+b2ub2DlIGcuhYBfKb/ezPv4CIodeaL+p/wNMo9nK+vRyFBTmjwzqk+uqq/eJxpQlvNqci0mrJTokEe+Z/0DW6Lr1edQT80YG8u3OPA5XNbDrcBUTIu0/W6yvdff3NmC+fVVUVGCz2TqMsAcFBXHo0KFOnzNs2DA+/fRT4uPjqa2t5dVXX2Xq1KkcPHiQ8PDwTp/z4osv8vTTT3e4v7y8HJPJ1Mkz+gdZlqmtrUVRlA7T47qiG3UTXhueoDHlN3518aHQXMmNg2/s45YOHBnljRwsqEankZgT7UJFhQMSQCkKXnu/Q2e1oGidkOpKudb2Jsnm29mS5sLYAPV3bTQbWZGxFKWhnAvMBmrGXEVTva3Ta0JRFObEXIcmbTHnVBXimraT2rArKGvouu6wvngXHltfRlKs1FsNNMcuOG6zbbKNyuZKAl0cV++0K78drOBweR2ezlrmxbhQVtbDesvHoyhoTFVsKtuExWoh1juWivITl/bx2LcEg9VCU1AijZWVDPLUcbCkme2HCnAf2nsfUF29TzTbmnkn9R0y6zKRkLhu8HVoG7SUocXLIxpddToNe3/CNOSCDvtUFIUvs76kxlzDbUNv63dTWe1pS2Y1VouVwT6uVFf2cUmnFmVNZXyy7xOa5WZ25O4gLD8JJ6uFRq8RNHXz2j6Zzw8fnRmrxUpqQWXv/g31wCj/6Uw9sAy5JIey0tKTzpjcn21MKVevqRD7fgZ1ek0oMj5Zm9BYLdgm3031ZidKSir4aM9XWCULc8LmUFVRxTcHv6GiuYL3dr7H9YOv77DvBmsDryW/htliZrz3eGx1NhYELSCrKou0ijQeV/T809aMd/IPVIfO7bIDubK5ki8yvyDdmE6wSzCPxT+GXqNHgwZnnLGYLFTTv9dam6wyKw8UYbXITAwx2O1vKbkqmYqGCjz0HkRIEZSVlWHIW4/HwZ9R0FA35QEsTtF4W6Cuqo466pBlGQ/Fg1ivYAqrA1h+oIip4U4M8rFjbqox/4e7+S1+L9vADwc/IzJ/A/dN/Q8aTd/nlhho1mRUU9/UTIingQCdibKy3p9R0dVnx/qs3ynChrdLKNeOuJeqipObbdMdkwe5seJQFT/tzGaQS0SfHcdR6uq6lyB5wAT0J2PKlClMmTKl7fbUqVMZMWIE
H3zwAc8++2ynz3nssce4//77224bjUYiIiIICAjA09O+JYh6QpZlJEkiICCg21/ICJiFlDacpsoMfsv9DsXFh0sNl3aakftM9OmeQ+j0OuYOD2TIIPslqWmn9CBSQwE4uaJc8iHS8scIrirkLvMnLKq4F3//ADQaiZWpK6GpjCGKhvjgcZB4LR4KXV4TVwVdDcMXIC2+HRrKCTz0X5TZ/+z8C3H+dqRdb4JWAvR4VyWjTOm646e0oZRHNz6KRtLw8Tkfo+1HH7SlRhN/pGWh0+v4y6whRIf3codDys9Im9/iT4MSiRt7F37uoQT6nuAYlVlIlftBp8c94RLcfQMZF2MirTKf4iapy/wgJ6Or94lGSyOlllL0Oj3XDL+GWbFHraeMvxBp81t4Zf6MZ/hwtaTdUYrqi9hn3IfZZubdrHd5eMLD+J2h6xsLk2vQ6XVMiA3q1d9bV2yyjf9s+w+yRibeP54/xV+Jdt//QKfHY+QcPLrZhpP5/HDxtKDbWERVs4Krly/uTg74OuHng7TVFWyNBDo1n1KSpf4qq6YUnV7H1GGhdrmmWnV6TVRmIsmN4OxBaPxM5hvz+D7tZ0oaaxgdHM5loy9Dr9Vzj9M9PLvtWXZV72KmMrOtpjm0dADu+ZIGuYEIrwjunHSnOiU+MJDnA57nxR0vUm6q5Bk3eNxkJLjuAAw+B4yFUJ2LUpVDdWUau42ZfK0zY9K74ObkxoXDLiQkKASNdPzrV1ZkiuqLCHINQq+14xTkLqxOLcWChnA/V2bERdktkeaew3vQ6/TMjZ5LaHAoGIuQDnwGOj3KuBvxSeiYW8lsNXP9iOsJCAjgtZUZbMys4JdDdTxzYYR9l9hd+DKzt7zF75lfkleXzYG9zzF33pvQD36f/YWiKGxdU4ROr+PCcZF9tty4s/cJWbayqSkXSZK4evgVxIb1bUxx2SR31mQZOVjWjNbVy24VIuylu8ncB0xA7+/vj1arpbS0tN39paWlBAd3r66iXq9n7NixZGZmdrmNk5MTTk4dLwaNRtP9QNlBJEnqeTtHX0Hwun8xzWRmg4vCL1m/8ODEB/uukQNEqdHE9uwqJCQuGRfuuN/9od/U/2NnI/nFwPmv4fTL3UTUFHB19fvklo8hNtiX3MItYK7nEtkdzYyH1DKFLW+0XV4Trj5q1YNf/w+y1yKFjlGzDB8tfwesfAJkC4SNg8I9UJyM1FyrZjbuRJB7EFqNliZrEwUNBfYpbdNNH2/KxWxTGB3mzezhQb37JURRIOVnALzytnN2fTnMf+H4WcbTV8DGV0G2QlAckr/6wTcy1AuJAg6V1PX6tdfZNeHu5M6jkx6lvKmcaWHT2r8uQ89V60qXpSCt+IeagyPxL21fnsI9w3lq6lO8tOMl8uryeGLrEzw26TEGeZ5+wdWJpBTXISERF+5tl/eMn7N+JrMmExe9C3eNvQtdRRqYG8DZEylwZI8y3Pf088PL1YlgT2dKjc3kVjYSH+59kmdxCjROyIEjyS/Zgz5vA6G+N9i/DX2o0WwlvVS9psZF+tr9c6jDNVHYUuI2NAFJ78TckV4syt6GbLZyTsTFOOnV706j/EexMGYhv2b/ykf7P2KY3zA8DeqAyKrDq9hZuhOdRsc94+7BzXBkSdEgr0E8e/azPLftOYqbqnnCVsFDW15j2KbXwWZhl2ThP/omTLSslZW0DIuZx53j72lbMnYij6x/hLy6PJ6c8iQj/Ub2zgt1ClaklCEhMW9UMDqd/Tq/F8YsxEXnwtzIuWgUGdY+p1a/CR6NNP7Gdu8dO0t28n3a98T5xzHPbx4ajYabzopmW04V+wpq2ZNfy8QeVvw5NRr8pz3A1TTyWeZi/le2jcQ/HsBr/ovqWnuB9NI6cisbMWg1zB0Z1KfvHce+TxzMWUOZYsZF0jJ12GV9/r4V6e/OqBAvUoqNrE2r4MqJp9cofXdfv/4doR7FYDAwfvx4Vq9e3XafLMusXr263Sj88dhsNvbv309I
SEhfNXPgiZ0DLj5cbLIimevZWbqTPGOeo1vlcL8mFyErahKiSDvXkm5jMkLWGgC2BcWwOGMx2xsLKZn5EJKzK7HWLOSVzyA11/GP0jKetLgxbvQN4BvT/WMEx0HiHerPW9+BsqOWrxTshuWPg80MUWfDglfAfwgoMuRu7nKXGknDEG91XV1/Kl+3PbuSHTlVaDQSd86M7f0RhfI0qD6sTg919VUTk/30ZyhK6rit1QwbX4e1z4O1Wc1Gfu4LbQ8PC/ZAI0GpsdluiYdG+I1gevj0jq+LwRUu/LcayAPs/x5+uRuMR8o/xXrH8tzZzxHmHkaVqYontjzB/vL9dml3f1FmNFFe14xGgmFBff+lMqsmix/S1Xrdt8TdQoBrAORvVx8Mn9gn5eqO1ZYYr9wxifEAvnWCh/X1/Ja7zGFt6Cv7C2qRW+q/B3n2g5K7rQF92AQAkmvW4exkRScHUFbSvuP2qmFXEe4eTq25lo/3f4yiKOTX5fP5wc8BuHr41cR4d/ys8nfx5+mpTxMVMJo6jUSurRFsFtA54+kTjcnZE617EGE6d6636HnKdVi3g3mAcA91uWVqZerJvAK9Kq+ysaXTVuqTSivHMzpgNPeOv5cQ9xDY/RmUparB8Ox/qqUDjyIrMofrDrM6bzU2xQaolRAubElI9tnmHGwOSEg27+zHiQ4ZT6Mk8UX5dvjlrnafS2eyP/aryfDOHuyPh7N9Zy6sylwCwDS3QTg7udvlmPPj1L+fFSklZ2xyvAET0APcf//9fPTRR3z++eekpqZyxx130NDQwM033wzADTfc0C5p3jPPPMOKFSvIzs5mz549XHfddRw+fJjbbrvNUafQ/+gMMPIiwhQtiS3lhH/K/MmxbXKwhmYrKw6qM0EuSghzXEPSl4PNTJpPKG9k/8S3ad/y+u7XuW/fv7k33IvX/Gx8X7eB7O+vRWqsYKRHJJrxJ5EDYfTlasBus8Cqp6C5Tv3ituxRNZiPPEsdydfqILolQ2nOhuPucrjvcADSqvpHQG+y2PhwQzYAl44NI8K380RNpyRdDSiWBEbw09hLKPeLVjtllt4PKUuObFdXqs6KSPlFvT3+Rjj3JXA+UubM1aBr60hKK+ne+qk+pdXD1LvVGQdOHlB+CH68DbKPZBcPdA3kmanPMMJ3BE3WJl7c8SIbCo5/nZxOWsvVxQa693mt8GZbM2/vfRubYmNKyBSmhU1Ts1NnqVmriUjs0+O3OpLpvsEux+vM4FC1NF+a8bDD2tBX9rZkt0+wU7m647Kaj2QWD1cD+gtiL+DiwZfiaZ7FqtRyGs3Wts31Wj13jb0LraRle/F2Nhdt5qeMnzDLZuL94zk/5vwuD+Xl5MWTZz/H2cMupXH0ZfCnb+DmP4i+dBGvLvySzy/7jdenPM1C2QnNgR/UWSnd1PrZlFrl+IB+eUupusRoX3zcHJRotnA3JH2l/jz9QfDo2LEwNnAszlpnTFYTf9v2N7YXqx2HV0yIwMNZR35VU9u52JNG0nD75MeQvCPZqFM4UJMFv90HzY7rYOwPGpqtbMxQ823Mj+t+Z1dvUBQF98ZqDEjMCZ1mt+NOjfXH1aCl1NhMckGN3Y7bnwyogP6qq67i1Vdf5YknniAhIYGkpCSWLVvWtjYkLy+P4uIjvXPV1dXcfvvtjBgxgvPOOw+j0ciWLVsYOdLx06z6lZEXgVbPJUYjWExsLdpKcf2Z28u5IqWEJouNQb6ujHNUmaCjpm87DT6HWK9YIj0jifGKwaAxoHV25oCTLzsNNpqaWpKNTHsAdCexdkiSYMYj4BECdcVqIL/sMTWYHzSlJZhv6eFtLTlSuEsNVrswzHcYYN+A3iYr5FU2kl1eT3Z5PZll9WSW1ZFRWseiLbmU1TUT4OHEVX0xHctmgazVKCj8LjXyTe5SSs7+GwyeowZaG19TR+TztsHi246Mhix4CSbc0ulo6vAQdZQ3xQ716Lst6iy47GMIigNzvbocY9Ob6vkD7gZ3
Hk98nKmhU/Fz9qOwvrBb5atOB62/J3vUn2+2NRPoGoivs++REpHZ69R1xk4eED29z9sAEBuodjo5coR+aNQcQKLQ1kBDda7D2tEX9uapCd0SIrwd2xCAkv3qZ4KbP/hEAWo5uPsTb2CIVxxNFhsrU9oviYzxiuHSIZciIVHcUMwdY+7gotiLuCvhrhOudXfVu/K3Kf/gkskPgmcIaDTotXoiPCLUrPUxs9ScCc11cGBxt0+jdZp9WlUaFtnSs9egF5mtMmsOqQnw5o+y3+h8vbme79K+I7M6E5pqYM3z6veNEQshZmanzzFoDYwLGtd229dZnV7v7qTjmkR1adXX2/PadejYS6x3LOfELgTvSD51BrmuCDa9rp7TGWptWhnNVplBvq6MDLFv7i8JuL3BwodmD6Kieq9E3Yk467XMGq7mGFl+sPQEW5+eBswa+lZ33303d999d6ePrVu3rt3tN954gzfeeMMOrRrgXH0hdjZR6csZL2vZjcLPWT9zx5g7HN0yu7PJCkuSigC4KCHUcbXUC/dAbQHoXYmKu5rn9M6YbWacdc7IikxZYxl3fbcSr6bNBGhSIf4SCB9/4v12xdlTDdyX3A0lB9T7IhLhnGfUWRytvAepX+aqcyFvKwyd3+nuBnsPRoOGSlMlFU0V+Lv4n3zbuunF31PZnnP8TKp/nh6Ds74PRk/ztoHJSL6LJ9UoGDQGhgfEw+xx4BsLOz9SR+RbR+X9h6qvrWfXy39GBHvyx/4SUov7wQj90TyC4YK3YOfHkPw/OPgTuAXA2GsBdWTub2P/1na9nikOFKr15+PsUH/e0+DJo5MepdJUibvBXf3ymvS1+mDcZaC3T6WB1hH6wpommsy2Pp+Z0BkvjxCCDJ6UmmvJyFlFgs/pMQOvzGiiqMaERoL48L7vJDqhwl3q/2HjabQ24aJzQZIkJEniwjGhvLM2i1+Ti7ggPrRdYreLB1/MmIAxDPFRl2FdM+Ka3mmPRqOWUF3zLOz7Vr3uDSeeeRXmHoaH3oM6Sx3ZNdltnc/2tjmrgvpmKwEeToy14wyMpPIkfsz4kR3FO3jVZIDGSvCJhCl/O+7zxgaOZUvhFgAiPSPb7j93VDC/JRdTWNPEkqQirp5k/9wpVw+/muKGYq4YOQ7N2lcgczWET4Jh59q9Lf3Brly1I3DOiED7f4etK4GGclw0OgjsukR4X5g/Kpil+4rZll1JTaMZb9eBW175ZAyoEXqhD8VdDsAl1RVM8R/DgqjjlyU7XW3OrKCi3oyXi56ZwxxYci21JfAbcg4YXNFImrbgSCNpCHYL5uzwRMqV61gW9wGcfe+pHzNwOEz9P5A0ajA/77n2wXyr1lH6o6ZbH8tZ59xWG9Yeo/QV9c3syFWDeW9XPb5uBvzcDfi7GwjwcCLQw4mF8SFMjumj7OsZywHYExQLkjoKpNfq1dkPY6+Fec9Daz3m4QvhoneOG8wDjGwJDLPK62m22vqm3SdLq4PJfz1y3aX80q6+7NHX65mgptFMQXUTACP6MKCvbVbLA4GaiKitoyx/B1RmqoF83KV9dvxjebuqf2eKAjkVjpt2P9RTXb+dXrTdYW3oba3T7YcGeeDmiAoCxypoDegn8F7yezy+6XFyanMAmDksEHcnHaXG5g6dqjqNri2Y73Wxs8ErXB2lP9i9pYIaScMIvxEAHKrqvOSxPaxomaJ+zsggu2W2B9hTugeAsbIGDm9WZ9/NfgL0x3+/nhIyhcSQRM4LPw/DUWUEdVoNV05U8xJsyrRPqc5juend+MfkfzBsyHnqjDdA2fwG1OQ7pD2O1vpePMLOo/OF9YXk5LTkOfMfdsJrqrdF+7sxJMgdm6ywOtUxpVQdSQT0gipgKISMYYhN4l5NQFswdiZRFIWf9xYCcN7oEAw6B/15NFRC7iY+1TbxjaueZlvnSdESWpYD7Clq6r1jj7wQrl+sTgXvLJiHI+voC3aCubHLXc0dNJerh11tlyz3mzMr1FmDIR58
cWsin98yiUU3T+Kzmyfx6U0T+eSmifxlRh+VTjHVwuGt2FBYKavTrieFTGq/TdRZcOXncNHbMOOhrl/bowR6OOHtqscmK2SU9tM1gcPOU6d415ceSch2FItsIaksiWpT/64Ffapag6+YADc8+ygBUY2phsc2Psa7ye9isR0zVbh1Dezwhe1yMdhD6yh9piOn3Qer04HTa7quYDPQJLVcU2MH9YP18001UJkBQLqHDztKdpBTm4NOo3Y0OOu1LBitrtVdklxov3a1jtID7PvmuJ9HR2tdR59SmdJXLTuugupGDhQa0UhqQG8vsiKTXJ4MFhPjsneod06+A/wHn/C5eq2ee8fdy8KIjuXsJkb5opHgcGUjpUZTbze7ZxKuZVdAJC8plTSserJtOdiZoqbRTFWDGUmCKDsndP4x/UceTfmYxdpmCIm367FbzR+lvg+tSClp6/w+U4iAXjhitDpKT+oSsDj4TdkBUovryCirR6+VOG+0fROJtJO2lAylmeXOWn4q3kRubW6nmyW0lInKrmigtqkXP7RcfDqvR9/KN0YdFbGZ1Wn3XZgTOYdLhlxCqHto77WtC+vT1AQw04cG9PmxOshcDbKV7d5BVNga8TR4Mj2skzXM7oEQPLrbu5UkqW39267c4y8lcBidk1rWDiD11w4Pv7brNV7c8SJbirbYuWH2ldyWvMy7T/ZvsVl4dderVJoqyajOaN/JV3JATVam0UH8VX1y/OM5khjPkevoZwGQYalBbuynfys9IMsKSXk1gFppxeEKd4OioPhE82W2+nc+PWI6ER5H8pGcNzoEjUbiQKGRPXl27MAbPEf9PDIZjyxpOoGEwAQuHnwxF8Re0MeN61zrGt/xkb7427FmdkZ1BvUN5bjXFjLUJkPkVBh16jN6PJz1jApTOxK3ZVee8v5OhVmx8qGLhr1ahceNSRRtft2h7bG37JbR+RAvZ7sugaoz17G9ZDtYGkmQdRDsmIB++pAAXPRaimpMbYlqzxQioBeOiDxbTYxmMnIw6TMWHVhEUlmSo1tlN78kqSMLs4YFOm7tjSwjpyzhU50JnL2ZET6jyzV+Pm4GIv1cURTYZ8+snpJ0JHlOTtfT7u2lqKaJjLJ6NJJaosXu0pejoPBbyzUzL2qeOt2+F5w9RD2fn5KKHJp47LhGXqj+n7cV6ttPc0sISABgW/E2OzfKfhRFaRtN7YuAXlEU3kt+j4yaDNz17jwy8RF13Xyr1rXzQ+eDu/07tGIDHJ8Yb5D/KK5wCuN+iytKycAvl5hVXk99sxUXg5ahdiiBeEIt5eq2+gaTVp2Gk9aJK4de2W4Tf3cnFo5WlxF9sD4Ls9VOyTA1Whh7nfrzvm/UWuonEOYexp+G/4k4/7g+blxHajI8NaC3ZzI8gD2HfoTafOJtEtrQsTD7H8fvvO+BxGg1Ud6J8tj0NYPWwGNTn8TPbyjFkszjWd+QdOBrh7bJnrLL1YA+JsA+5eJabSzciNXaTLTVRoyiVUsiO4CLQcuMYern4LID9q+84EgioBeO0Gja1l/uSv+FP3L/YGfJTgc3yj6Ka5vY2tKz7NBSdXlbWdNUSLYWXN2CuHbEtcfdvDWA2HO4pu/bdrTWLNp52487m6PaVM2Woi3k1/XdWrbW8ixjIrzt3xFTnQvlh7BqtESHTsJd7868yHm9tvuzB/szJdYPWVZ4fUW6/b4k94T3IAhNAEWGQ7+1e2hi8EQA0qvTqWxy7MhNXymobqKy3oxeK7XlPehNP2f+zOaizWglLfeNv0+tG92qKltdBytJMObqXj92d8QGql8c86saHZbrQSNpuDxiNmMUHdrSAw5pQ2/a2zI6PybcC60d11d3SlGgYBdmFL4yqe/jF8VehJ9Lx3wk104ehI+bgaIaEz/tLbBfG4fMA89QdWnA0SVC+wmbrFDbaCG/qpHf9xdjbLLi62ZgQpSv/RqRuYq9hxYDCmP942HBy2DovSnZidHq9XCwsJY6k2OnuUd7RfPi
vPcZ6hVDIwov7X6dpYe+PSOmYGe3dKzG+Ntvur2iKKzJWwOWJubYDGriZDsv/TravJZlLFuyKhx+LdqTCOiF9oadB3pXRjYawdzYL2q12sOvyUUoCoyP9GGQXx/UKO+mugM/8D+tCZy8uHL4VXg5Hf9NsTXJ24aMcmob7fjG5T9Unc1hNXW6drrVV6lf8daet1iavbRPypcpisKGdDURz/QhDphun74CAH3EFG4ffw/vzX3vhL+znpAkibtmDsbbVU9eVSNfbOuntbZHtIzSH1qqlulr4efix1CfoQDsLD09Owdba96ODPXESde7Uxy3F2/nm7RvALgl7paOI4pJ/1P/j56udqw4gJ+bAW9XPbKirqF1mOAx6v/F+xzXhl7Sr+rPGwugvpTfdDYqFDN+zn4sjO24jhrA1aDj1rPVnCnf7sy333pqjRbGXq/+nPy/bi0ZbLI2sad0D+vze3+WWanRxGOL93Hb57u46oOtXPzOZq77ZDt3frWHTzapiQTnjgyyX2dNyhIa1zxLJTYkJ08S5r16ciVujyPYy5lBfq7ICuw67PicKV5OXjxx/n+Z6RSILFv57+5/82HyB6d9GdUjI/T2C+iz6rIorC/EyWrmLFnvsPXzrQYHuhPt74bFprSVhjwTiIBeaM/JHYady3BZC01VFBrzqDXVOLpVfcpslVmVov7RX5TQ9+u9u2Qs5pvSzdRLCoP8R3RrpHdUqCdDAt0xW2V+3Vdkh0a2kKQjo/Q5G7rcrLV27dr8tTyz9RnKGnv3zTW3spG8qkb0WokpsX2Uwb4rsgwZakDfWr7v6Oy/vcXLVc/ds9SkRb8kFbaVR+tXoqaBizc0VMDh9uvlJ4dMBmBb0ek57f7IaKp3r+63wdLAe8nvAbAgagFzI+e238BYDJmr1J8Tjj+Tpy9JktQv1tGbA4ezU7Lwc9W+bidH64+azDZSi9W1nwn9Yf18wS4UFPa7uYOk4doR1+Kk7ToYnD7En/hwLyw2hffXZ9lvVHTIPLWTuam603wex8qvy+elnS/x35T/9nqQ9+W2wxwoNFJqNNFoPtLB6WrQEuTpTEKENxfEH7/KSa/Z+xVsfA1XBT4ccj0vLliEp0vfdBRNbpl2v8PB0+5b6Z3c+ev897hRdkdjbkBblYWEg2e89CGTxUZRrbrkJNaOU+43lW4CYKqsxxXJYevnW0mSxLlxLcnxDpaeETMzQAT0QmdGXYqHpGVQswkqM0n9+hL47kb4/WHY+Brs/fK0KgeSVlJHk8WGt6u+17+U90Ttge/ZpLGAwY2bx96FVnPi0T5JkrhsvFoyZum+YprMdpzy2hrQH94CVnOnm0wJmcKtcbfirHUmtSqVh9Y/xMrDK3vtDXZDujrdfkKUr/1LOxXthYZy1jppyfLq27WQiTF+nDMyCEWBN1am02i29unxekxngKEtpS5T2095TQxJBNQSUbXN/bAz4hTYZIX9Beo59XbyMje9Gw9OeJCpoVO5fuT1HTfY9426zCF8AgQ4ppZ2q9Z19JkODOgtrj686mzhf9omagsH7myQg0W12GSFIE8nQr0cX/pRKtyNhMQ/h1zLQxMeYmro1ONvL0n8dUYsWo3Ertxq+62p1urUEqEAyV+DtfPqMK1ivGJw0jpRb6mnsK7rzPy1zbXsKtnV7c+sUqOp7XPpkXOH8+614/jy1kR+vussvv3LFD6+cQLPXhxnn+VhOz6CHR+qP4+9Du3Z9xHtE9Nnh0tsmTG4O7e63ywPk/xiOW/MrTxlcePmZp3967LbUU5FA4qilu611/JDm2wjuy4bFJk59XXqnQ4O6AFmDA3ASachr6qRQyV1jm6OXYiAXujIOwISrmWEVu3hS5Eb1LXC+dvV9Wk7PoIVjzu2jb0oKV+dHjYm3Nuu9WDbkW14Za3jNbM7Nw++nJF+I7v91CkxfoR6O1PfbGVFih2TgASOBDd/sDRC4a5ON5EkiXlR83h5+suM8B2ByWbi4/0f8/z25zGaTy0D
qaIobevnHTPdfhlGZD5xhr9vfZLsmuw+Pdzt02II8nSirK6Zjzbk9OmxTsqIlozRBTvV0eMW/i7+DPYejIKilkw6jaSXqp2B7k46Yvx7f0Qkzj+Oe8bd07Fzr7EKDv2u/uzA0flWbSP0DkyM56Z3I9xZfR/IOLzOYe04VUcnWHR48CHb1I5LQBMxiQnBE7rVpghfVy4dp+ai+WhDNiaLnTqah54L7kHq38e+7467qU6ja1sOdLDqYKfbmG1mntn6DK/seoV1+eu61YSf9hYiK+rv7+wh/kT4uuLlqrd/LoS87ergCyAn3oEy8bZeS4DXlcEB7ni76mmy2DhQ1I86bwdNZpiiQ1u6H2QZq2zl3aR3u6wgNFC11p+35+i8VqPlyYQneSz6UgbLkvr352HfZI+dcXPStSUVXn7wzEiOd1IBvdVqZdWqVXzwwQfU1ak9H0VFRdTX99MszELPTbqdkfNeAv+hHIqaCOe/BtMfgnHXq+vVqg9DrR2T3vShvX1ccqpbCnZBYyX+zt6cO+HuHj1Vo5G4ZKw6Sv/T3kIsNntlF9Z0a9o9QJBbEE9MeYIbRt6AXqPHaDbionM5pcMfKqmj1NiMi17LxGg7rzU1N0LOBlZqzVicPIj2iibaK7pPD+li0HLv3KFIEqxKLXV4eaAOvMLU0WJF6ZAc74aRN/DK9FeYFjbNQY3rG63B15gIO3cGHvhRLRsZOAJCx9rvuF1oTYyXW9lov/efTgxtqQiSVjZwO47SS9XvVKNCHZdUqpW2Kp1VthpMTm5q3pQeuHJCBIEeagfkd7vsNKNPq4eJt6k/7/2iXcdiZ1o7zg9VHur08V+zfqWgXv2e80vWLyccpa9pNLOiJXi4vGXmXF8rrO9kdoFsg23vqj+PvoL9YSO4e83dfJ/+fZ+2RaOR2rLd95dp94B67epdobkOqrL5KfMn1hes58ktT55WlZzaEuLZcf08qEF9fFODupzBwevnj9Zak35jRoV9c0w5SI8D+sOHDzN69Gguuugi7rrrLsrL1RGyl156iQcffLDXGyg4zgjfESBJNEhgCRkDIxaqH5attbSPSYamKAolDSUDar1KncnStu7TUesV6831pB5oSW4VO0f9UtJDs4cH4uNmoLLe3FaT3S6iZ6j/524C2/GngWskDefHnM/L01/mb2P/hl5zaqXdWqc1To7x7fVkZCeUswGLtYnlBg3onVkYs9Auo2lxYV5cMlYd+Xp7TSY1jZ0vdXCYo5Pj2Y58gA7zHcYgz0GOH3HsZa21whMiejf4+vzg5/x7z787n/XRXA8Hf1J/Tri2z0fduiPQwwl3Jx02WSGvynHr14eGTQEgo6Ggy2VA/ZlNVshqSWo1ONC+Zac6k5T3Bx/rmnjUxYqNnn2uO+u13D5dnd69eE8h+fa6Loaco1bdsDbD5rfUDsYuDPcdDkBKZUqn31suiL2grROyuKGYAxXHr6Dw675iLDaFIYHuxIf3fYdMfl0+D657kOe3PY/lqPdbDv2mzqp09oRxN7C3dC8VTRVUm/o+WV3rtPvt2ZX957ug5qgyasVJnBd9HnF+cZhsJl7a8RKrD692bPt6SVtCvD6YLdYZi83S9juWSlvKhfaD6fathgd7tOWYWpLc9bKa00WPA/p77rmHCRMmUF1djYvLkRG2Sy65hNWrT48/CkHl5eTFO3Pe4e3Zb7evqx2hJrkir31A/3Pmz9yz9h7WFzi+Nnl37S+oRVYg3McFf/fezfraXd+lfMFTpev5VmuCYQtOah8GnYaLxqgJ/RbvLUCW7fRBGhwPLj5qz3fL1MwTCXUPJcIjou12k/XEdYOPZZMVNmW2ZLcf6pjp9hs1Fmqd3fB19m1L/GYP1yZGMsjPldomC2+vyew/X5oAIs8CV181MVXuJke3pk81mW0cahlNHTuod2eIbCvexuaizZhsnWTrzloN5gY1q33kWb163JMlSRKxgS316B24jn5Y+Fmg0ZGJBWtZisPacbLyqhox
W2VcDFrCvE9tBtOpMtvM/FipJricFpzYrZwux0qM9mVilC822Y4J8iQJzr4fNDrI2wq5G7vcdIj3EPQaPbXmWkoaOk7LNWgN3D32buZHqUlPf8/5vct9NZlt/L5PnRFw+fhwu3RefnPoG2RknLROR76jNdfDrk/Vn8ffjOLkwZ6yPQCMDez72Tzx4V446TRU1JvJbpkC3i+EJKj/Fyfhpnfj0cRHmR4+HRmZD/d/SHp1ukObd6psskJupfp6R9tphP637N/429q/sal4PZS2vN+2Dvj9P3vnHd5Wdf7xz9Wy5L33jEfsOHb2DkkgO2FD2COUWeAHJUCBUlYppFBCGYVSKBtKKZsyMyGQvZzEiWM7w3vvqXnv748ryXY8YieSLSf6PI8fydLVPUfS1b3nnPd9v18XQBAELpskjzX/t7f8tLewG/CE/pdffuGPf/wjGk1XwYX4+HhKS0//FZAzjWBdcPcLU6wsckXZHrvwjCiJdnuldw+8O5hdPCWGOt2+oLGANYe/AiRGe8cMOK2xM4szwvHUKCmua2d7wSCluykUED9Tvn9sYAs5xc3F3P/z/Tyw8YEBD/T2lzbS0GbCR6sa/O+uOhepbDffKo3g4ceShCWoFIMnyKdRKbh3fgpKhcC2Y3XsdyXVe6UKUpfK948TxyttKeWl3S/xt11/G4KOOZ79pY2IokSYr5YwX8eJl9W011Cnr0OBghF+PQhYlVj1KpIXyL8/F6Gjjn7oBvER3pF4a3wwIVFYOHwWlm3kWxeIkkK9h07Pxco3+Z9TZ2ohSFJwXuZNJ7UPQRC4dfYI1EqBfSWN/Jw3SNljAXEw5gr5/qaXenU9UCvVJPnLDiKH6uS0e1ES2VS6qYvy/eJ4eaF9T9Ueylt6TuP/8UAFLQYzkf5au52sM8mty2Vn5U4UKLgy7UpESeTLw1+yf/MqaG+QtZDSzqe8tZzKtkpUClV320sn4KFS2gVCtx11obT7CJut5V4QRdQKNbePuZ0ZkfKi6Of5nw9h506d0vp2TBYJnVpJhAOvR32xo2IHte21qFor5RIwra/sQe9CTI4PJDbIk3aThW/39V2CM9wZ8GhAFEUslu4CJyUlJfj4+DikU25cjy4TroAE8AqRf8BlWQAcqOkQlREl0bWihn2QNYQTelESeTP7TUR9I9NFNempl5xS+qynRsWSDNkK59NdJYP3HYyYI98eXiv7svez3WBdMGUtZVS2VVLeOrATra2sYEZSMCrlIE5q9I2w5lH2CmZKtJ5oNd7MjZ07eO1bGRHizfxRsvCMy12kUs+Vj+PS3V3cMBQo2FS2iR0VO2gxDn+9lb3Wc4ej1e1tkaJ4v3i0quMGZqLYkQkTNcGh7Z4qtgn9UCrdC4JAil8iAPkVu4asHydLvvWzSx6kdHtREmkzdUx29WZZtPSZ7c/wZe5/AImrPSLRBMSddBthvlout0bJ/vnzUaqb+1afdxjjrpVt7FqrYdc7vW527ahreX7O88yJmQPAD8d+4KU9L/HM9mfs19AI7wjGhY4jwiuCBkNDt30YzSJf7JEDWpeMj3b6YowkSXyY8yEAZ8eeTZR3FN8f+56PDrzHKwVf04QIU28HpcoenU8LTDtl3Zr+MiXBmnZ/zIV0XkJSQaUFfRM0FADWCO7Iy1CgYE/VHo40HBnaPp4CNkHS+GDPQVkMrG2v5UjjEQRBYLzRuvgVnukSJWCdUSgELp8on3++yiobXCeoQWbAI+EFCxbwwgsv2P8XBIGWlhYee+wxlixZ4si+uXEBTBYTz+14jlvX3Npx4ReEjih9sewtneCXwHWjrgNAb9EPeII2FFQ26alo1KMQIGMQ6t2O55eSX8irzkZr0nONxVOOuJ0i54+JRK0UyK1o5kDZqanI95vIcbIQiqkdNjwFPz4MrSe+kOtUOtKC0gA58tFfjGaRzUfkdPvZg5luL4qw/s/QXIHBK5jA4DTOiT0HT7Xn4PWhE0utizdbj9ZS0zJIg+T+4BPe
UZbTyQ86wjuCWJ9YLJLFPsgczjhrMTC/Ph+A5IDk7k/W5MnlLRqvIbeqOx6bMN6xmhYsg1Xy0wNXjrqWF4zeLGyoln+zwwhbhD4lzPnBEUmSeGrrU7yx/w37YxqlhnWF69hdtRuToZnRopKp0bNPua2Lx0eTFOpNi8HMqtW5g3N8qLUw83fy/f2fQG3Pk7VE/0SivKMQBIHCpkI+PCRPlMeFjeuSnXjn2DtZNWeV/ZrVmZ/zqqlrNRLopWHOyFCHv5Xj2Vm5k9z6XDQKDZemXArAvLh5ROnbqMfCP/19kazn4D2V8rV1fOh4p/fLxqT4QBSCXNNd1dxD2dBQoFRBWLp83xqIAgj3CmdG1AwivCLQm12kryeBrbxhxCAp3O+olK1Bk/2TCWqwuu64UP18Z2YmBdudoL7Pdv25ycky4An9qlWr2LRpE6NGjUKv13PVVVfZ0+2feeYZZ/TRzRCiVqopai6i0dhIbn1uxxMxtgn9dgC8Nd4sHbFUFtJDTgdzdfZYBa1GhvvgqRlcD/NWU6u8wm5o4hKzB0HRU8Dr1NP0Arw0zE2TI7ef7x6kEhiFEpb+TRZMVKigcBN8cn2/ovW2mr6BTOh3F9XTZrQQ5K1hVITvKXV9QOx+Rz7eVR5MWfQ3Xp73KstSlg1e+8cRH+xFeqQvouSCtiw2C7u877v4Qds86beWbx2KXjmM2hYDRXVtCAIOF7+ynTtHBvQwYS+zLoREjpN/dy5EhK8WnVqJySJRUj90wnixsWcRofZBMLZBnXOtJB2J0SxyrFb+3AYjQl/eWk52bTbbyrfZ08sVgpy+fUvGLTwkBXC/wQMhetIpt6VWKrh/4Uh0aiUHypr4aHvRKe+zX8ROlZ1YJBF+eb7PBR6jxcjLe17GLJoZHzqehXELuzzvrfFGIXQfMouixGe7ZCX8C8ZGolE5N2PMIlr46JAsort0xFICtbKqvEd1Hv/X2IIKgZ0eatYVr6fN1EZOXQ4wOPXzNvw81aSGy9dml1K7jxwr35ZndXn4htE38Pyc50kPTh/0LjkKu8J98ODUz++skEu/JoZNQFUjH2OupHDfGYVCYNkEOUr/xZ5SDObTM0o/4DNPdHQ0e/fu5Q9/+AP33HMP48aN4y9/+Qt79uwhNNT5K5NuBh/bivTB2k4iQ1ET5AFlY0kX+7rJ4ZOZHT2bMK+h96E8ER0RtkG2PAM+yf2ERkMjUYZ2FosaSFl44hf1k4vGRaEQYGdRPUX1g7TirFTJloaXvCFHDg3NcrT+h4egpfe6SdsgI6c2p9/ieLY6zJlJwaecWiZKYv+cGQo3wy6rNsRZ90FQIiqFasii8zZsJRY/ZFcMqV1YN2KnyZF6fVMXCzvbhH5v9d4uqb7Djb0lDYDsu+yjPTW3hs4YLUYKmgqAXiL0tvr5qMGLtvUXhaKTMN4Q+tGjUHQIM1UMH/u6ozUtiKKEn05NiI/zBVptWTKjgkZ1maien3g+c4MyGdNUixJlxyToFIn013HHOXK9+n93FrPP+htyOtPvki3LKrMht2dRux0VO7j2+2spbi7GT+PHbWNu64jOl++DzX+XM9CQf6MbSzbaF0G2HqultKEdLw8li0aHO/3t1OprMYtmvNXenJ9odRURRdjyCgmSkitDp4LKg3cPvEtBUwHz4+aTEZxBhHeE0/vWmclW+zrXrKPf1yXY4KX26nGxZrggSVKHwv0gROhbjC32+cAkbTgKUwuoPCCoh2uWizBnZAihPh40tJlYc7ByqLvjFE7qCFapVFxzzTU8++yzvPrqq9x0001dFO/dnF6MCpS9WnNqczoe1HjZ02s+3PUiawrX0GZqY8mIJdw+9na7v6urIoqSvQZ2KOrnUwJSCBAU3KAXUGt8IN5x/tyR/jqmJwUD8F3OINewBY6AC/8Bk2+W7feKtsAny7s5ItiI8IogzDMMs2Q+oSUQyErCthX/U023L2wq5LHNj3H3hrt558A7vW/YWAobngZgb+IMNmjVXQST
hpJpiUH4e6ppaDOx5YgL1SsqFDD2Kvl+1r/tFmLR3tFEeUdhFs3DOu3eblfn4Pr5On0d0T7RBGoDCdEdd3ybjVBhtQaKdL0JPXQSxqsaWnXrTV7ePK9qY0/B+iHtx0DIr5QXQZJCvQdFIX13pfz7mxDWgxZD3g8AmAOTwcNxWVCzU0KYPyoMSYLnVucNju2mdwhMvEG+v+01WTDuODoHK24fezt+Htasm7pj8P0Dcsp+zjdIksQDGx/glaxX2FW5C0mS+NQanV+aETEomX6hnqGsmrOKR6Y+0rGgfGQ9VOWA2pMlc54kMzgTo2jknQPvcE3aNfxx6h+d3q/jmTJCntDvL22k1dC3pe2gEZIGSo3swtJQ2O1pk8XEDwU/UNxc3MOLXZfqFgMtBjMKhUBsoPODDLurdmORLMT4xBDeaE1hD0uXAzsuikqp4NIJ0YCsMeVSARAHMeAJ/Xvvvdfnn5vTD1uE/mjj0a41RjFTqEfkm/It/Gv/v2g0uJDa9gk4WtNCi8GMTq0kJWzw/X6nR03nJc9RZEgqSDwbVJoTv2gAXDJePnFtL2wafOsYhRLGXQMXvyEL0RhbYO3j8uDoOARBGFDa/ZajNRjNIhF+2pP2aW43t/PugXd5cOODdgGyA7UHMFh6qEM36WHNo2BoxhyazjuKVl7b91qf9kWDiVqpsEeFXK42LGWxLJ7ZWmOP0guCwJRwa9p92fBMu5ckiawS+Vw3JtrfofsO9wrn2VnP8tI5L3Wf1FXul4VIPYNcTknYRofS/dCKHuaqlWxTmNhXs79XhXNXYzDr59tMbXZV926p2Ca9PIEF9ImLHN72LbNGEBOoo77VyAtr8wfHYnX0JRCUKGeObX21WynY2TFno8KTCxIvZmzoWPlBQzOs/iPYMokKf0UQBCaFyyUI3x37juzSJvIrW1ArBc6z2sYOBmqFmni/ePkfkx62/1O+P/YqFF7B3D72dnw0PpS2lJLfkD9o/epMdIAnUf46LKLErsL6IelDN1QaCLMGm8q7Z++8mf0mb2e/zRf5Xwxyx06NY9bofGygp9NLPgDifONYHL+Yc2LOQajYB4AU5jp2db0xNy2MAC8NtS1G1h+qGuruOJyT8qHv/Hf77bezfPlybrnlFn73u985oYtuhpoQXQhB2iAsksUu2ARA7BQ2KUyIxlaS/RLtKV2iJHKs8RiVra6b1mKrn8+I9htUlXR7arexDU3BJvl+iuMHTUmh3kyIDUCU4MHP9rPZ6tk+qAQmwAWvyOnBpjZZLE/fXahvcoRcpmEbKPWG2SLyn+3yyvm8tLABR7EkSWJz2Wbu+ekevjv2HSIiUyKmcOPoG1k5cyUeSo/jXwC/Pg+1h0EXwPq0cyhrq8BH48M5MecMqG1nsjA9HIUA2aVNFLiS769KA+Oulu93itJPj5xOlHeUy2fx9EZxXTv1rUY0KgVpTtJwUCt6SOMvtWY0RE1wOSVhG7YJ/dHq1sGZqPVCSswsUGo4ILYiHhgeg3O7wv0gLDDvrd6LRbIQ6RVJuNdxaeK538puHj7hGKNnOrxtrVrJA4tSUSsFdhXW82VW71ovVc16Gtsd4B2tUMre9AB5P8qR+k6T+v0FapqO3UjhMWsNsCjCuiflckJPOdJM+T7QN7EgfgEKFBysPcjb22UNofmjwvH3dOyi/PE0G5tZXbAas9gp2i1JsOttaKkC7zDIvAyAAG0Ad4+7m6dmPDWk51lblN6l6uhtfvSdhPFsLIqXx2KbyzZT1lI2eH06RWxBm4RBqp+P841j+ejlLElYDNYJvasK4nVGo1JwyfgoAD7ZWTKk4q3OYMAzmfr6+i5/LS0t5ObmMnPmTD766CNn9NHNECMIgv2i0KWOPiCBjR4KQOIsXZT94bez3+bBXx5kbdHaQe5p/7HVwDol3V6S4PsH4fNbuyjrtpnaePCXB9lYshHp6E9yTZ5fdIfyqoO5Z34yaWGe6M0WVn5/iPe3Fg7+IFupgrmPyfZBTaWw
7k/dhInSg9K5feztJxTtWX2wkvJGPf6e6pOKhpQ0l/Di7hep19cT5hnGQ5MfYsWEFSyIX4Ba2cMEKudrefAnKGg/+0E+LVoNwKUplw557Xxngr097L7H3+53sSj9yKXWKH21vX41xjeGVbNXsWTE8HRF2VMsR5vSI30dGg2RJAmTpY/Ji31C75rp9gBRATo0KgXtJgtljf3TxHAG6SGjUXuFUihY+HjfG/b6Z1elzWimtEHu42AI4tnKXcaHHXcsWcyw92MApMzLnSa8GBfkxS2zRgDw7pZCcivk7ISqZj3rcir525o8bnxnBze+s5M7PtxNm9EBKdvho2H6/8n39/4HNv4VRJHKJj3vbJazxzYfqZUXRXe+CcXb5PTsRc/IpWSSCEVbCdYFMzVyKgaTyI7q9SgEWbfG2XyR/wVvZr/JS7tfkh9or4cf/yC/F4Apt8h1zFYyQjI6ovhDhM2+bmdhHWZXSXHu7Ed/XKZGvF88E8ImICHxxeHhsRAIHYJ4iSGDM6G3U5MHrTVIik6ZDy7OwvRwfHUqKpv0bMzrXd9pOOKQ0UhycjJ/+ctfuPvuux2xOzcuSGZIJqmBqYR6dggfFjeXUKjWoAKmt3WkNSb5y8I3tpQ+V8NgtnDQaunmlAl9Q6FcO159CL64TbbukiQ+zfuUgqYCPsv7DHPe9/K2KYucFm3z1aq57+xYzrdOfv+7o5g/f5sz+PVsOn9Y8Gd5sFGyA7a/PuBdtBstdmXkyyfFoNMMfKAZ4xvDovhFXJpyKatmr+pIrbRiES18kf8F+6v3Q+VB2GQdOE25la9bC2k0NhLhFcG82HkDbtvZLM2Us2N+yq1ynXpFkKP09lr6D+1R+sGoEXYWzrKrq2yrZPmPy3l88+PdRRoNzfL5BFzOf74zSoVgjxIdqR66bJEAbQC3TL4PlBq+FBvYuPX5IetLfzhc1YIkQYiPh9MjvQB+Gj/8NH7drcyOrIOWStAFyCUzTmRhejgzkoIRRYk/f3uQm97dyY3v7OSFtfmsP1RFldWvvrHdxDZHRXgzLoXZD4CggEPfIq19nFfXHkRvEu2X4W3rv4A9H8j/zP49hKRA3HT5/0I5q25R/CJqW420qw4wLVlHuJ/21PtmaofVj8APf4DiHfbJpiRJZFVl8WPhjwCcE3sOlOyCT2+UxVqVGphxNyTOPfU+OJjUcB98dSpaDRYOWRdthpywdFnfp622i6CzjUuSLwHg15JfqWh1MfeYXjhms6wLdv5i4KbSTRyoPYBFtMBhOWhnjJwMKgf8BgYBrVrJBWOtUfpdxUOaSeZoHBZeUKlUlJU5P0XllVdeIT4+Hq1Wy5QpU9huTXnqjU8++YTU1FS0Wi0ZGRl8951r1L4ON2ZFz+KJ6U9wduzZ9sc2lm4EjRdjRTU+ndKXRgbKdktHG4/2HXEaIg6UNWGySAR5a4gOcIKYoy2SplDJNa8bn6P4x9/z/dFvAbgh4TzU5fvkibwDvOf7QqkQuGlmAvfMT0atFNhRUMe9/907+LZSwUnyQApg70f2C4ENSZI42niUDUUbenz513tLaWgzEearZWH6ySsJ3zD6BpalLOsxIv/1ka/5T+5/+Oeel2lf/UcQzZAwi9rkeXxzVK4BvzL1SlQK1xN+yYjyIzbQE71JZEOui9WGpZ4LXsFyWqhVbAvAYDGwsWQjRU2DZGHlAEwWkexSuX7e0RP6vPo8zKIZi2TpvuBRvleOEPpFg7dru8nYtC1sNeFDxayYOVwYK59f/3n0C3Kr9g9pf/rCJog3GNF5gGtGXcNr81/r6qkuivKiG0DGsi7RXmcgCAL/d04SYb6y8nRlkx6FIGsIXDI+isfPT+dCa+R7U74DS8ZSl8C8x0GppilnLZNyV+GlMHL/wpGEW8oZnfcqRosofwbJ8+XX2ERri7fLi5LGCEztIUiChZDwvH43bbQYeffAu9z/8/28se8Nsqqy5DGSJMFPK+HYRnnR4Lv7aP3vtXy/6WlWbPgd
K7evxCyaGR04ijFHt8J398oT0oA4uOg1GH2xS5bhKBSCXWdkf6mLaCypPCDUetz3UEef6J/I2JCxiIh8efjLwe3bSdCsN1HZJC9+xQc7N3PQIlp4O/tt/rTlTxyqPSiLMQKGGMeJOg8G52ZG4OWhpLiuna1HXUhM+BQZ8IT+66+/7vL31Vdf8dprr3HNNdcwY8YMZ/TRzscff8yKFSt47LHH2L17N2PGjGHhwoVUVfU8gN28eTNXXnklN954I3v27OHCCy/kwgsvJDv7xGrabvpGlER+Lf0VNJ7MkrRd7OvCPMPw0/hhFs0caTxygj0NPjaF6nExAc6JFNq8oicshym3ISkUvFW6AbH+GJN8kxhbZ60ZjBwPPoNj73dOahjPXJJJkLeG0oZ2Vvx3LzsKHFfXtqeonkMV3evju5A0tyNa+/OzUNOhx1Crr+WhXx7i9X2v02LsKqjV2Gbis13yZ3bttDjUA9A8MIkmXtr9Ersrd59QmX5RwiJCdSFUV2XzoaEE/GNgzkP8N/8TjKKRkQEjmRw+ud9tDyaCILA4Q17o+G5/+Ylt+AYTlQbGXCnf3/MBWBf53s5+m1eyXuGHgh/6eLFrkVvRjN4k4qdTEx/k2PRGm0Bjsn8P1j+lu+RbF47O2xhl1RXYUVA35Mfh5TMfZZLKH7No5mjOp0Pal77Iq5IXP5IHQRDPhkJQdLXqKtoM9YWyg82oCwalD14eKp64YDRXT4nlsfNG8dEtU1l12RiWz0hgQlwA56TKi1e7i+odk3ZvY8RsGmY/SVkLjDQf4i/a9zgrzMS90rtoJAN5ymSY+tuO7YNTZDFKUxuUZ/HR9mK8TJPw1arJDEvqd7Pri9bz3bHvKGouYm3RWlZuX8lNq2/i+e9uZGPBGtoUShi5hGa1ltvbDvBO3n8pK/4VbXsjC4PGcnd5EcLej+QFgLTz4KLXZbE/FyY9UnYMOFDmIhN66Jp23wMXJ18MwMaSjVS1udgC+XEU1MjBmTBfD4daqPZEbn0uzaZmvNXepBoMsuCthw+msLFObdfReGpU9rLNj3cWD/l1ylEMONR04YUXdvlfEARCQkI455xzWLVqlaP61SPPP/88N998MzfcIFuQvPbaa3z77be89dZbPPjgg922f/HFF1m0aBH3338/AE8++SRr1qzh73//O6+99lqPbRgMBgyGDrXrpiZ5kiKKIqLoIjVAPSCKIpIkOb2PraZWmoxNeKm8iPWJxWQxMS40HKl8H1LRVkiXT4QjA0ayvWI7h2oPkeKf4tQ+DZQ9RfVISGRG+zr+85JEhFJZrV2KHAeho9istHBw94uoLSauzd2MpLaqgiYv6FZP7kiOPyYSQ7x4flkmf/khl4PlTaz8LodXrhpHmO+ppUrllDfxyFfZ6NRK3v/N5L7riifeiFB7WI50/PgHpAv/CTp/Aj0CifaOpqS5hKyqLKZHTre/5OOdRbSZzIwI9mLGiMABfWfbyrbZU8RePvvlPqPrHgoPblaE8JSpjTUqBZMnXstolZazo8+mtLmUq1OvRpIklz35z0kJ5t3NBRTVtZFVXN+jAvtgnSe6MXIpwp4PoaUSKfd7SD2XmZEz2VC0gU2lm7gm9Rq0wyBlz3buyIjyBSSHpuvl1eWBJE/oj/9+hBJ5Qi9FjnfKOcORx8WEWH/USoHShnaOVDUPii9yryhU/Hbs/zF3yyrGFu5DNLY7PfJ8MuRXNiMhkRji6fTfZnFzMVHeUV0n85KEYE0zl9LOB7XnoJ0rInw9uGxitP3/zu3FBmiJ8NNS1tjOtqO1p2xVakOSJF465Eul52/5nelN4szHkD6+lli1gWxFAC8aL+eFFiNB3p2OldhpCIe+oWr/OnYWzsRTSOWS9AAmhE6w97nd3I5O1XvW39yYueTU5pAenE5hYyG7qnZR31TEtqYytqlg5ZhbiB97PV5Tf0vy2v+jvi6fBQaJs/Rt6OrkhU9J4400635ImGX7wBzymfSHkzkmRkV4IyGR
U9aEwWQe0IK80wgfg8D7ULYHyWLplt2Q7J/M6KDRaBQaTBaTS4/9D1fJ5474IC+n93N7+XaQZGcM4fA6JECMn4WkULn0Z9QTSzPCKW9o58JxUS49rgP6/dkOeEI/VF+a0Whk165dPPTQQ/bHFAoF8+bNY8uWLT2+ZsuWLaxYsaLLYwsXLuTLL7/stZ2VK1fyxBNPdHu8uroavV7fwytcA1EUaWxsRJIkFArnnDB31uzk7fy3SfFL4e5Rd3ND3A0Yo420Hf4Or+JdGHN/ojlEVsUNV4ZjMpvIKs1iqs9Up/TnZGjUm8mvkFeKI7WmXrM7ThZl3WH82+qRVJ7UiQHoy4t4q2g9Zp8ozjeqCKwtw2w0IKl01HmlgoPb70xvx8Rd00J5boOeQ1VtvLYuh9/OOHlBH1GSeHlNAWaTmWaTmV15xSQG913GIIy+Fb+qoygbSjF98wBNZz0GChVJuiSO1R/j12O/kqSSox7VLUa+3l2MWZQ4P9WPmpo+REwkSU6T75RO/9WhrzCZTUz2n0xdTd8ZCZrSraQd/JG5aiWr/YJ4OfdTHtaOxF/pzx2JdyCYBIcfL45mUpQn6/Pr+XTbUSI00d2eH4zzRG9oE5bgte9tLNveosFvHEGKIPxV/lTrq1l9aDVTQ13nPNEb2/IrMJvMJPg69ljQW/QcrTuKiEiAJaDLvoX2OgJrDgMCdepoJCccg44+LtJCPNhV3Mz3ewpYNnaISwQCJ5OmCsHcXEnr9o9oSpiHSqHqOqEdQpr0Zsrq5BpYP9qpqnKeN3u9oZ6Hdz+Mv8afx8c+jkYp1+urqvbjV7YPSaGhPnwOUlXVkJ4rOpMZpqGoppnV+4pI83fMoHtrQSOb86tQKmIwnPMkxqy/oDDUo1Jq+CH8NmobNXz4ax5XjO/IoFP7jsLX/AV1B9ZhVkzhrER/FoWNtP9W6w31PL3vaWaEzmBJ9BI0Sg1t5jZWl65macxSu3PFldFytlKmNpPzNenUb3yQXajID4xFF7HIvr8rxz6Kp+CBR9lWhPz/Ya7PxxQ0ipbJdyN6hjh17NAbJ3NMeEgSWoVEi97I9kNFJIe4gKCsEEKgRURoLKP+WDaid/dMyeVxy+XvrBWqWl33up9dWIXZZCbYw+LU8YkkSWwq2oTJbCJZFY8590UEs4mGwHE0NDQM+XniZLhmjD+IrVRVuZA7UA80N/evfM31ikF7oaamBovFQlhY1x9eWFgYhw71LL5WUVHR4/YVFb0LXTz00ENdFgGampqIiYkhJCQEX1/nWBQ5AlEU7dkSzvpRZeoyUR1TUaIvISA4oMNaSbsA4dBHqBoOoQv0A5UHE9UT+br0a4oNxQSHBLvM4OlQXjUqtYqEIC+SY53gG1u+HkGlhthJhIZHsLlsM61iK5F+MVx61l9RHvoaYfd7SKMvJTQy1vHtd6KvY+LOBd7c8/FedpW1USfqSA0/uVTPtTmVlDSbUanlU0m1Sc200BMN4EPhvL8ifHU7qoY8tPU7Ie18zlKexU9VP3G47bD9mPlwbx4olUyI8+ecMQm979JsQFjzCFRmI138BvhGUdBUQFF7ER5qDy5Iv4BAbWDvr28sRtj7GqjUXD1qGXv1h6lpr2FN7RpuHH3jwD+YIeKyad5sLNjD/ko9Ck8/gr27RiMH4zzRK4HXIBz7FlV7PaGNWZC6lAWJC/hv7n/Z07yH80efP7j9GSCtBjNFTYdRqVXMHh1HiI/jIr37a/ajVCkJ04WREn1cRlN+lnxOCU4hJHqEw9rsjKOPiwWZCvZW5LK30sDtISFDL4I45QaEX/+G+dhX/N18kMzQsVyVetXQ9slKUWE9KrWKKH8d8dERTm1rf9F+1Co1Eb4RREd0LPgJu34AlRop7TxCYuXjb0jPFZ1YNM6LH/Iaya0x4u0fiKfm1Iatje0m/ru/AJVaxVWTYxk1PgaSXoM9H0LiOSy2JLL9m4P8WtjK8jkB+NrSmIPm0r71eVT6ehJ8Krlh9lRCO4nh
bTuyDSNGNlRtYH/zfpaOWMp3x76juq0arZeW60Zd17Uj7Q0I614gFIGRUWchLXq2Z1eBiIth/EXQVoPSMwjtEI6lTvaYGBtXz9ZjtVQYVMw44fhgcBAiM6DyAMGmYgh1fQ/13qhsK0OlVjF2RAShoX2McU6RgqYCmixNeHl4MUunQ40R/CIISJuNuaZ2yM8TpzNabf+yF/t1Zjw+yt0Xzz/v2mqyJ8LDwwMPj+4DNYVC4fIHqyAITu1njG8Mvh6+NBmb2Fm5kxlRVs2EwARZqKmlCqFiH8ROYUTACJaNXMbIgJFynwbpIiRJEluO1OLvqSEtwqfbQHJvSRMCAmNjA5zzOdnq56MnICgUzIyeSbBnMBbJgodGK/vEZiwbtAFub8dEUqgP89LCWJtTydubCnj20swB96nNaOb9rUUICIT5elDZZCCvsqV/n2vQCJh0E2x+Wa4JTD2X1KBUdGodzaZmCpoKUFoi+Dm/BgGB5dPje9+v2QhrH5MV9AHhyDqYsFy2TRRgSsQUgj2De++LsQ3WPCqrDEdk4jXt//ht/SGe3PokRc1F6C16l7Kp64v4YG8yovzILm1i9cEqrpka120bZ58nekWjkzUUtryCkPUhjFzMnJg5fJr3KYfqD1HZVkmEt3MnNKfCgfJmJAmi/HWE+TlWTPNI4xEQZEHRbt9LmVzCQ/REBCd+Z448LqaMCEKrUlLZZOBYbRtJoYNXG94jI5fAng/Iay+jsFpPYUsxKQEpTI4Yek2Mw1WtCAiMDPNx+m9yT/UeEGBC+ISOtqpzoWQnCAqEsVdBpz4M2bmiE4kh3kT56yhr0LOzsIE5I09tQvjGL8do1ptJCPZm2cQY+b35x8DZctnmREliRLA3x2pa+W5/JVdNsS68K7TskZIJYweXBhcRGdD1mnBB8gVE+kTyzoF3qGmv4d2D7wIQ6hXKzKiZXT9DiwnWPS47CvhFw7wn5EW7vhgkvZ0TcTLHREa0H9uO1XGgrJllE11kHB0xFioPIFTsh7Rze92spr2GtYVruST5kp7tbYcQo1mkuKEdAYHEUOeeP3ZV7QIBxoSMQXdso/xg0jwUSpVLnCdOZ/r7ufZrqz179vTrLysr61T63CfBwcEolUoqKyu7PF5ZWUl4eM+q1+Hh4QPa3k3fCIJAWqCsDvrSnpdYV7TO9gTETJHvF28FQK1Qc2nKpWSEZAxqdH53UQMrvz/EA5/t43cfZ7EupxKjWS4TkSSJLKuHtFPs6sxGqJDr4zt7RacGppIe1MlrfqijVVaumRqLh0rBoYpmNh0euNLnxzuKaWgzEemv5bdzZGGevIEoW6edJ9sjNVdA/mpUChWZIZmA7JP87uYCJAlmJgf3LhYlWmD9n2TPYBsFm2g1tcqijcCC+D6cBCRJ9iOuL5BFj+Y9AUoVo4NHc1XqVXirvU8opudqLM2QM0/W5lS6niVL2vnW77wc8lcTpAtiTKgsULShuGeHA1ehw64uwOH7jvKOYnL4ZMaEjOn6hCQNK0E8G1q1kkkJcrToF0cqlJ8sVvvE6aKaJXozSJLLHG/5VkG8RCcr3JssJrJrZEHgLnZ1NmX7pLng63oLaoIgMDNJXpDddPjUjqUtR2r5Jb8GhQB3z03qsZ5bEAQumxgDwP/2ltFutACyVsz69hQEAaYrc3rc/6TwSayavYrzE89Ho9AwI3IGz5z1DEkBxwnnbX5JFmRTe8LCp0DrutmfjsAmjHewvMl1rkknEMYDecz4py1/4ovDX8iuTi5GUV0roijh7aEi2Nu5dpf59bKI8aTgDNkuESDJ9Sx8z2T6NdPasGFDv/7Wr1/vtI5qNBomTJjAunXr7I+Josi6deuYNm1aj6+ZNm1al+0B1qxZ0+v2bk5Mgl9H2vOowFEdT9gn9H3bCDqbvdZBN8DR6lZeWJvPje/u4N/bijhQ1kRNixGVUiA90gkX0KqDYDaALoBDksHl1VGDvD24eLycdvnO5mP2hY/+UNbQztd7ZZvK
G2eOIC3CF0GAyiYDDW39rAFVeXRSP38fRAvjQsYBsKk4i12F9SgUQo9RZkAWA/ppJRz7RfbinfuI7C9ck8dP+V9hsBiI8YmxL0L1SP4a2XpFoZStjDw7UtYuSLqAByY/gLdmCEW9ToLJCYHoNEpqW4yu4/1rQ62FMVfI9w98AcDZMbIVZk27C0z8+sDmjjEmxs/h+54SMYV7J97L7JjZXZ9oLIHWalkXImy0w9t1JmdZJ2G/5te4huBQ6rngGcSsdgPomzhYexCTOLS2qpIkcbhKdvVIcbLC/YHaAxgsBgK1gcT5Ws+pDcWyVRp0nItdkBnWY2lXYb19gj1QmvUmXv3pMAAXj4/uM2tkemIQkf5aWgxmvs8uB+Df24o4qErDR6tB13QMmit7fK1WpeXqtKt5d/G73DX+ru7ZXQe/lv8EAeY+CgHxJ/V+hhMjgr3QaZS0Gy0crXGReuXwDHm80Fze63cpCII9IPDt0W9d4zzWiaPV8meZGOrl9KzPhyY/xMqZK5nUrpftmAPiIKj/Dg9unM+wyo9YsWIFb7zxBu+++y45OTn89re/pbW11a56f91113URzbv77rv54YcfWLVqFYcOHeLxxx9n586d3HnnnUP1FoY9M6Jm4KH0YHzo+K7psVETZN/1xhJ5kIDsubqtfBtf5H8xaP07UCa7Etx0VgLXTYsjyFtDQ5uJj7YX8dDnsg9xWoQvWnUPtWqnijXdXh+RyYt7XmLFTys4WHvQ8e04kIvHRxHopaGyycA3+8r6/bq3fj2G2SIxLtafSfEBeGpUxATKA5fcfk4iC2tb+XfLWMoMHlSXHePbL95nT34gyYrltJXKtkkL08OI8u8hvVmS4NdV8oRcoYT5T8irxWFyJkRwQxlxvnEsjF/Y94Uu5yv5dty1EJHZ7/fvymhUCqaOCALg18N9iAgOFcnWjInafGhvYHzYeF48+0XuGn/X0ParD6qbDZQ2tKMQILMH9wCnUbpTvg0bLS+GDCPGxwWgVSuoajaQX9Vy4hc4G6t9YpykwLe9Ab253R51GiqqWww0tJlQCJAQ7FgbxOPZXSVfn8aHju84J9os0OKmu7T9WUKwF5H+WkwWie0nabf61q8FNLSZiPLXceXkvvVrFAqBSyfIUfovs8rIKm4gq7iBdpUvvvFj5Y0KN/W9j54yE9sbYJvVYWnyLRB3ZgSXFArBbmfpMvZ1Gk/ZjhD6jNKfE3MOWqWW0pZS9lb3vt1QYFscSQh2ftBBEARG+I/A89gv8gNJ81wm29SNzElN6Hfu3Mnvf/97rrjiCi6++OIuf87k8ssv57nnnuPRRx9l7NixZGVl8cMPP9iF74qKiigvL7dvP336dP7973/z+uuvM2bMGD799FO+/PJLRo8eXpEOVyLUM5TX5r3GignH6SpoPDsmREflVEaTaOJvu/7Gf3L/Q6PB+SdxvcnC4Wp54DhtRBDLJsbwr+smcv/CkV1E3ybGOT5lFrCnxn6tkajT1+Hv4U+Sv2uvYGrVSnsE/OMdxTS2nzhilVXcwLZjdSgEuGnmCPvgcKQ1wpTbz7T759fk8dHuaj42TKehzURA7n/46UAje4/4UNFkQqtWcOWkHgZekgRb/g4538gr7Of8UR6QAsSfBcCUujKeOesZ5sbO7b0DTWVQkS1flNLO61efhwsdKaq1rpPiaMMzUI5KSRKU7UGtUBPu5dplULZ0+6RQH7w9HKslW9FaQWVrZc/Rn1KrJscwSre3oVUrmRTvQmn3AGnnodAFkmGygKGZ/dX7h7Q7hyvl61VckJdzFpmtSJLEnkpZi2FcqJwFRXMl5P0o3x/rGgKBvXGqafe7CutZm1OJIMBdc5P7tla1MmdkCMHeGupbjaz8Tk6xn58Wii7ZmkVzggl9j+z5AIytcmQz84qBv34YMzpKzmzKLnWRCT1A5Fj5tjyr10081Z72ccQ3R79xfp8GwFHreHdEiHMXA+201XWUgCX2MbZyMyQMeEL/n//8h+nT
p5OTk8MXX3yByWTiwIEDrF+/Hj8/x6ciHs+dd95JYWEhBoOBbdu2MWXKFPtzP/30E++8806X7ZctW0Zubi4Gg4Hs7GyWLFni9D6e7niqPXsWB0leKN/u+QAaS/BSexHtI6d059blOr1feZXNiKJEkLfGrkCtUiqYlRLCX5eN4fnLxnDH2YmcN8YJ6vbGNqjKoQqRr5vk93rtqGvttkCuzNzUUBKCvWgzWvjP9qI+t7WIEm9sPArA0swIYoM60glHWhdN+pPmXdNi4Gh1K4IAAZMvw9svgJHaRlYklnLdtDhumBHPHfN9MVLfvYZ9x79g/6fy/dkPQOI5Hc/FW4Uay7MQjC196zccXivfRo4Hrz5E84YhY2P88dQoqWs1klPRNNTd6Y5tgmobHFhp0DfQanKRlMxO2LU3Yv0dvu8vD3/JXRvu4tP8T7s+IYodgnidNDmGEzOTbWn31a6xsKTWwuhLyBBVYGhiX82+Ie2OTXMk2cn18wC/Hftbzk88n9HB1oBG1geyzWfkODn92MWxpd3vLKgbUNp9u9HCKxvkVPtzMyMY1c9yO7VSYS9JazNaUCistfVx1mtMWRYYBpB50lRmLzNi6m+7iA+eCYyOskXoh1cdPcDihMUoULC/Zj+FTYWD0LETI4oSBTVtACQ6OUL/2ObHeH7X89TlfgOSCKGjwO/k7Y7dOIcBn1Gefvpp/va3v/G///0PjUbDiy++yKFDh7jsssuIjXWuDZcbFydloTxQNxvg52dBFBkZMBKA3HrnT+ht6fbpkb49plknh/mwaHREj0I4p0zFfhAtfOCpxCQIjAoaxeTwoVdQ7g8KhcCNM2VthO+yKyhtaO912++zyymqa8NHq+qWtmjLgsi3Lqz0xR5rPXJSqDfLZ48iYuZ1BHlpmNP6A+dmBtKkWcsHh1/grg13sfyH5Tz0y0O8sufvfL16BXuy3saIBDN/ByMXddlvjUbLj75+tItmKNrWvWEbkgT5q+X7yX2I5g1TuqTdu0p0tDO2Cb1twgr859B/+O3a33aIbboIoiixt1iOKo1zgpimLe07wfc4W8aaPDA0g8YLQlId3u5gMDEuEJ1aSU2Lsd+ZO05nxBwyRRXz9BYuiFs4pF2xlSIkhzl3QC5Yr0lXp12NVqWFlio49J385ITlTm3bUSQEexHhJ6fd7xhA2v27WwqobjYQ5uvBtVPjB9Tm/FFh+Onk4MWCUWGE+mplRXz/GHkxpLiPa8zx7PiX/JroifLfGUZSiDceKgXNejPF9W1D3R2Z8Aw5Q6+xBGoO97pZiGcIUyLk4OG3R78drN71SXmTnnaTBbVSICrAsa4rnWk0NHKo7hDby7ejOyYLDZPkjs67IgOe2Rw5coSlS5cCslBda2srgiBwzz338Prrrzu8g26GEYIAs+4HtU5e8Tz4JSMDrRP6QYjQH7RO6EdFOD9TpBuluzggmNmmklCgYHn68qH3Xh4AY2L8mRQfiChKvP3rsR63adKb+HCrHMG/ZmocPtquWRoxAZ7o1Er0JpHCur4v2LsK5YjnBFv5Q/rFoPGG+gK0xdtRK9UEagNRKVQYLAaONhxl46FP+bB0PX9Rt7E5bT6kX9Rtv6sLVvOWopWXVG1Q8EvvHajOlbUeVB6QMKvPvg5XbNHRTUdcMO0+cqxcLtFYYhckCtGFICKyoWiDS4kPFdS20thuwkOlsGehOIoWYwslLSUAJAckd33SZoEZMbZnf+phgEalYMoIOe3+VBXKHYZ/DEH+cdxs9mDySQqsOQJRlDpN6AfZ1i/r3/LkMmJMR9qxiyMIQsc5rZ/HUnZpI9/uk8sw7zwnGZ1mYL8jrVrJ/52TxMzkYK6e0mkBO26mfGtT+z4RVYfg8Dp5jDTltgH14XRBpVSQGiEf59mlLpI15uEDsVYdgx8fgtbej6tzR5yLVqnFz2MIxpc9YBOATgj2Rqlw3ljTlpEQrvFDV50rX7c7Z0W6cRkGPKEPCAiguVleaY+KiiI7W7ZB
aWhooK3NRVbd3AwdvhGy2AvA9tcZqZYnbMcaj2G09FP9/CSwiJJdjM0pCvYnar90F++q9KD2ZG7c3A4V4WHEDTPiUQiw7VgdL6zNY9XqXJ743wHu/2Qvt3+4i1vf20WLwUxskCcL07vXPCsUAinhcqSpL2E8i9hhH2if0Ht4Q8alAAh73ufqkVfy7OxneW/Re/ztrL+yQhHC5a16posaEkMymDHjQfv+DtYepKylDJPFxPri9aDx5mxRIzsumHs55mzR+fiZsv7DaYgt7b6+1cjBchcZQNnQeEGo1X3AmnY/LXIaHkoPylrLyKvPG8LOdWVvSQMg14A6OrvncIMcFQr3DO8+UCyxCuIN03R7G7ba518P17jOwtJAJ2ROoKyxnXajHGGLC3TeOajF2MJb2W+RVZUlL5S1VMMhay3wMInO27Cl3e/oR9q93mThpXVy9suCUWEnbVU7ZUQQDyxKxd+zU/mcrbSreBtYzH3vQJI6hPCS5kNwct/bn8aMttrXuYwwHsCch+SMi5Yq+OFBuXyyB5ICknht/mtcnXb1IHewZ9YelBfCZ6U4t1ywoKkAgDiz9TiPmtDFDciN69Dv0Ylt4j5r1izWrFkDyPXpd999NzfffDNXXnklc+e60zDcAKMulAXyTO2Ebn8Lf40/ZsnMkYYjTmvyWE0L7SYLXh5KYp04OOoRfSNi3WEmiioCfCK5fOTlg9u+g4gJ9GTRaNm5YF1OFT/lVrOzoJ5DFc0U17XTYjCjEOC2WYm9rgjbhfH6mNDnVjTTarDg7aEipbN10OhLZF/e2iN2wSGlxUTk5leYUrKfiyVP7p69kqfPfd+u4SBKIv/c+09W/LSCP276I83GZoJ8opjgEQamti4p3XYsZjhiTes+DdPtbaiVCqYl2tTuXSQ62plIqziXNRLtqfZkasRUADaVnYTglJOw2dWNc0L9vG3hIiUwpesTpna5jAeGpSBeZ8bFBthtFF1GzyFuOiISuUU/883hr4ekC/mVNkErb1TOKAOzsq96Hz8W/Mj7B9+Xs8b2/hssJjnd2PYbHCaMCPYivJ9p9//eVkR5o54gbw2/mZnQ57YDJjQddP5ySUzFCZTPi7fL1yGlGibd6Nh+DDNswnj7SxtdJwtL6wuLn5W/z5p8WPcnEHteLNKpnJfaPhCO1bSSX9WCQiEwJyXUqW0VNBYAEF9vFRx3e8+7LP2+imRmZjJlyhQyMjJYtmwZAA8//DArVqygsrKSSy65hDfffNNpHXUzjFAoYNbvQalBKNvNSElOczvS6LwJva1+Pi3CF4Uj04/aG+CbFbJvbG+U7UEtwWW+qbw87zV8NIOcPulAlk+P54rJMVw2MZrfzIznrrnJPLQ4lacuGs0LV4zlvd9MISO695Qzm5dyXh/1sruK5Oj8uFj/rt+V1hdGW50ydr8H+ib47j4o2iqnxi9c2a12q9XUSpRPFBKSfSV5ftx8lAnWCFxPafelu+TvVecPUad3LWNnZWiXiY7a6CyMZx3c2Sb02yu2dxdDHAKMZpFs67lljBPs6uwT+oDjJvT5a2SvX9+oYe9T7ZJ6DmGjMWh9+JNUxfv7/kVFa8WgdyG/yvmCeAaLgS+PfAnA+LDxckpxji06f8Ows53qr9p9fmUzX2WVAnD7nCS8HOxMgUIBsVZnlb6yPESxIzqffjH4uLabh7NJCfNBpRRoaDNR1qgf6u504BsJC58GpQaKtsDml+zXpOORJImc2hyyqrIGt4+dWJcjR+enJgTi59mDQLUDKWwqBLOB+NYG+fNJOMup7bk5efo9of/5559JT09n5cqVpKWlcf3117Np0yYefPBBvv76a1atWkVAgJPswNwMP/xjYPLNAFxWms+Lkx9lacJSpzXXUT/v4HT7I+vlCcevz0Phlh43kUqsSt1R43tW/x9G6DRKrp4Sx7XT4rloXDTzR4UxPSmYzGh/EkO8T3jxsNUYF9W10WLoORVxz/H1853JWCZrMNTkwae/kaOU
Gm9Yugpip3Tb3Efjw+8n/Z7n5zzP/Lj5TA6fzIL4BV1TasXjJoa2dPvEc0Dp4IGeizEmxh8vDyUNbSbXS7sPGy0PENrqoL4AgIzgDLRKLfX6ens6+lByqKIJo1nE31NNXJBjM39ESbQL4iX7d0rDlSTI/ky+n37RsJt09cRZrqbnoFCgi51JsqgCYwv7awbfvs4WoU9xUv28JEm8se8NCpsK8dP4sThhsew7bzHKv71hWsphV7svrEdv6h5JNZpFXlibjyjB7JQQJic4KT3YZpV6eB2U7Op5Apj/I9QdlWu1x13jnH4MIzQqhT2Lz6Xs6wDC0mULXEGAA1/C/k963OzX0l95fMvjvHPgnSFZdDZZRNYfqgJg3qgwp7ZltBgpaykDfRNxkhLipsnlcm5ckn5P6M866yzeeustysvLefnllykoKGD27NmkpKTwzDPPUFEx+Cvcblyc0ZdCWDrRJiPhO9/FWcNSSZLsk5X0SAcLllTl2BqBDU/JIl6dKG4u5pHCr8gVzMM+NdYR+HtqCPPVAnKU5Hga2ox2IahxsT1M6HX+cskGQGs16ALgvBdPaKsU5R3FTRk3ce/Ee/FSe8mppGpPaKuF6kMdGxo7ieWdxun2NtRKBdNGdNQwuxQqjVyaA/Y6erVSzYQw+Xe0o2LHUPXMjk14aGyMv8NFLkVJ5Paxt3N+4vnE+MR0PFG2R17gUGm7uTgMV1xSzyFuOpmSPKHfV3WCtGkHY7aIHLF6SCc5KUK/unA1v5T+ggIFd0+4m0CRjkyzCcuH7UJRYogXYb5ajGaRT3eV8EN2Oe9vLeRva/L445f7uf3DXRTVteGnU3PzrBHO60j0RPAOhfZ6+HYF/O+uriVeZgPssGatjr1azkBzQ7o17f6Aq03oAUbMhqm3y/e3vgpHf+62ycTwiXiqPClvLWd35e5B7iDsOFZHs95MoJeG8T2NoRxIg6GBWN9YAs1GAhAgYbZT23Nzagy4cMvLy4sbbriBn3/+mby8PJYtW8Yrr7xCbGws559/vjP66Ga4olDAbDn1nuJtmPO+5+finx1eO1XWqKehzYRaKTh+cFRtndDb6uVWPyLXtyIvJLyb9Rr55ib+pzR1eJqe4djs63pKu7fZ1SUEexHopen2PACZl4F3GPhFwwV/h+CkgXdCpemI6Bd2qscu+EUeaPlFD1srsIEyM1lOd3bttPuOgdH8uPncnHGzUzN6+ovteD1ZQa2+UClUTImYwtVpV6PsrGJ/4HP5NmWBHNk7Deis5/CLq6TdR08iEy1YjGRX7sLSS92sMyiobcNkkdBplET5O74uN7cul/cOvAfAVWlXkR6U3ik6nz6sbdPktHv5WPp4RzGvbDjCf3cUs/5QFXuLG6lsMqBSCtx+dqLdcs4pqHVw4Wuy9otSA+X74H+/k//K98lZNq3V8rVs9CXO68cwY3Rkhx+9S5KxDNIvlIM46/8MlQe6PK1T6ZgTMweAXZW7Br17q61ieHPTQp2qbg8Q6hnKM5Mf5pU2JQLCsNPcONM4JSWWpKQk/vCHP/DHP/4RHx8fvv3WNfwZ3bgQAfEwYTkSEq9teZpX97zMv/b/y6GpSrZ0+5QwHzQqB4oL6ZtkazOApc/L0eK6o7DxryBJ7Kzcyf7K3aiAa/1GnTaD71MlxTqhP9SDMN7uoj7S7W14BsLl78PlH8gT75Mlroc6+s7e88M0QjVQMqP98fZQuWbafaQ17bc8yy5ElBaUxry4efhr/YesWwDNehOHrVHUMU6Y0PfcaAUUWBeg0i8enDYHCVva/eYjLrKwpPFkRMQkvCWBtrYap2q8HM/2Y7KgW0aUn2M1X6xsLd+KWTIzNWIq5444Vy5rOfiV/OT464f9uW9JRgQjQrxICPZickIgizPCuXZaHCvmp/D0RRm8ef0kpic6V/0bAK8gmHEXXPFvuTxGqZaj9F//H2x/Q95m0o3yArMbAFLDfVEIUNVsoKrJherobQgCTL9LtrOzGGH9U91E
8kYHjwbgUN2hnvbgNGpaDOyxjqHmpjk33d5OxX4UEnIZrVvd3qU56QLSjRs38tZbb/HZZ5+hUCi47LLLuPHGM1vB000vjLkC4ehPpNdls6mlirVFa2kzt3H72NtRK059Bd1mgTLK0XZ11bnyrW8kBCXC/Cfgm3vkmrmQNL5u3gemNpZaPAiLme7Ytocxtgh9bkUzkiTZU5VFUerfhB5kEbxTJXaK7N9dXygvzKh1HZHg5Pmnvv9hglopi5Ktzank18M1XJzmQgtPwSnyQpihWf69hY0a6h7Z2V/SiCRBTKCOYG8HHI/H8WPBj0R6RZIamNqhvXHwK5BEub450MHK3EPMmOMWlmyK10OJIn4G6RU/sc3Ywr7qfd3FCZ3EpiNylsJ0a9aCo7lu1HXE+cYxJWKKfP7d97GcmRSaBjGTndLmYBLqq+XFK1woWugdAjN/B2OuhD3vQ+73IJohKEm2qnNjR6dRkhTqQ15lMwfKmgi1lui5FAolzH0UProcmkrh2M9dvNdHBowEoKy1jEZD46B506/PqUKUZGtmZ2T2HI8oiSjKs+R/IsY6vT03p8aAwpllZWU8/fTTpKSkMGfOHA4fPsxLL71EWVkZb7zxBlOnTnVWP92cgJL6NrYXuVj0zYZCCbPu42xJy92tZlQmPZvLNvPcjucwWAynvHtbhN7h/vNVB+Vbm192xBiYdgcAlVtfJq9yLwpjG4ssGnf9fCcSgr1QKwWa9WbKOynZHqluoandjE6ttE/6nYqHT8dFqOBXWeBQEmVBKN9I57fvQsy0R0drEV3FLgjkspzIsfL90o70xTZTGz8U/MA/9v5jaPoF7LHWzztD3b7V1Mpb2W/x521/ps1s9T02GyDnf/L90yw6D6BSKhgf5w/A3pKGIe2LnbgZZIoqMLWTW71vUJosqW+jqLYNhUJwuGCbrZxNEATmxMyRbbba6mSRLxjWtfPDAp8wmHWfnF025TZY8Gf5HOemC6Oj5LGaywnjdUbj2XEe3vufLqKH3hpvor3l7EGbU4mzkSSJNVZ1+3mDEJ0XJZGbV9/MA0c/pR7RPaEfBvT7TLN48WLi4uJ4+eWXueiii8jJyeHXX3/lhhtuwMvLrXo4lDS2mXj064O8+mspX+8tG+ru9EzISEi/iKmimvvbQCOoyKrOYuW2lbSZ2k56t/WtRsob9QgCjAx39ITeWj8f2ilqmH4xJC/gV4UBmkpIN0sEKj3kukQ3gBwRTgyRtQw6+9Hvsqrbj431d6rvchfibWr3v0Lej/L9Myg6b2NMtJ8cHW03kVt18r83p2BLu+80oRclkfcOvMdPxT/JKrtDQFYnQTxHc6RBTu8O9QztiO4cXidnKniHdShon2ZkWhdH9hW7yEDeO4RJAan8yeTFg0HdXTScwebDtQCMjfbDR+u4Gu/NZZtZtXNV9+tp9mdg1suaITGD8x7PeHwjYOyV8q2bbtiyc7LLXOQ80BvpF8nZgtW5XXReAFIDZQ2ewZrQHyhroqJRj06ttDs9OJPK1kpajE2UGhvwRXBrRA0D+j2qVqvVfPrpp5SUlPDMM88wcuRIZ/bLzQDw0aqYal3p/9evx3h3c4HDheccwqQbwSuEsS0NPOw7Gp1KR05dDs/uePak+2urCY4P8sLbkV6zktQ9Qg8gCEgzV/CL1gMkkVmiWp7Mq10wbWwIsdnX5VZ2n9A7W5m1C7YJfUU21B4GhQoSzx689l0EVSdRsh1F3bUNhhRbdkvlATlKjRwBSQ+SF8m2V2wf9C5VNumpaNSjECAj2vHplLZBoN2urotV3YVyVtNpSKb1s8ytbO7Rcmwo8IufxUhJhbKoZ1tSR7PZmm4/zYE13rXttfxz7z/ZUbmDNYVrOp4wtXfUzo+7xh2dd+MSpEX4IghQ1qCnrtU41N3pHZ0/jFwi39/7UZenlo5YyjNnPcOVqVcOSlfWWMXwzkoORqdx/vWhoKkATO3ESgqUvlFyWYkbl6bfE/qvv/6aCy64AKXy
9BxoDGcUCoGbz0rg0jHyD+7TXSW8uC4fs2XwPTL7ROMF0/8PgNRDa3ks/SYCtYFcnHzxSVtCOa1+vrkc9I3yBDAouctTktqDCyf9jrGCJ5NEtTvdvgds3sp51gh9s95kV723pd0OCt6hcp22jdipoB362t2hwJZ2v7O4CYsriJLZ8I8FrxBZgKgi2/7wlAg5mri9fPAn9DZ1+5HhPnhqHLhQaMXmP2+v2a60LjgpNZA69Or+ziLcV0uIjwcWUXIdgUbbol/xDjA7d3JR0ajnSHUrCgGmjXBc/fyOyh3oLXpG+I2QRfBs5P0gZ334RkHcDIe158bNqeDtoSIhWM7sdem0e4DMy0FQQMkOqDlsfzjSO5J4v3gUgvOzDduMZjZZbWed7T1vo7CpEExtxItKd7r9MMFd3HOaIAgC56YHc9c5SSgEWJdTxVPf5bhMFMROwiw5nVQ0k7Dnv7w050UyQzLtTzcaBnZyt9XPj4pwUrp9UFI3hVqFoGDOyIt5aMmb6NIv6vBNd2PHViN/pKYVg9lCVnEDogSxgZ6E+gxyNoNtwA5nZLq9jcwoP3w8VDTpLa6V6igIsggcdEm7nxQ+CQGBI41HqGkfXKszW4332BjHZ5NIksThBnlgmBxgXSzMtlrVJc07rRecBEGwR+n3WUsahpygJOo8A3hDquepn+5zalO26Hx6lB9+no5Lt99fvR+QF8HsFoiiCPs/le9nXOKu5XbjUoyOtPrRu6p9nQ3fiI6svuOi9IPFxrwaDGaR6ADd4OgP0RGhj5OUEJF5wu3dDD3uM/xpxry0MP6wJA21UmBnQT1//DKbJr1pqLvVgSDAjLtBpYXyvagPr7U/Vd5Szj0/3cN/c/+LKImYLCJHq1t6TcdvM5o5VtMKOCFCb6+fT+t9m9A0mHkPaB3c9mlAiI8H/p5qRFHiSFUruwsbABh/InV7Z5AwSz7uPHwg9vSsTe4PKqWCKdbSnJ0F9UPcm+Ow+9F3TOj9PPzsdYrbyrcNWldEUWKvE+vnK1oraDG1oFaoifWNhdYaWUUZzgi/avuEvsRFFpUEAY+46axXGtlXvc+pi0ebrPXzMxyYbm8RLRyslcvDMoIzOp4o3gaNJaDxhpTFDmvPjRtHYBMxziqud80S0c6MsabVH1kPTeX2h7Nrsnkl6xVWF6x2avNrO4nhnWw260ApaDgKJj3xksIdoR8muCf0pyFTRgTx1EUZeHuoyK1o5oFP91HV7EJ+nz7hMPEG+f7WV6FdnlzsrNxJq6mVz/I/4687/sp72/K4+z9ZvLguv0ff4kMVzYgShPl6ON5WqidBPCC3Lpdvj35Lvd7FJkQuhiAI9pXkQxVN7OqvXZ0zCEyAJavgvBfPeD/gjE41zC6FTRivJg/0HRGbyeGyxdaOih2D1pWjNa0062U3hpQwb4fvP79BTrcf4TdCtu08+JXscxyeAcFJDm/P1ciI8gdk14tWg3loO2PFK2EOiaISjC3sq9rrlDZqWgzkVTYjCNj1LBzB0cajtJnb8FJ7keDXyepw/yfybdq5smK3GzcuRGaMP1q1grIGvV1fx2UJToboibJLzv7/2h8uaSlhY8lGp16fimrbyK1oRiHAOamhTmunM42GRupbyhGQiPUMl8fsblwe94T+NCUtwpdnLskkyFtDSX07K787NNRd6sroS2Vvd0MzbH0NgPMSz+OOsXegVqjZXbWbj44+i1moZV1OFa//crTbKq4tVWtUpINTVC0meWIB3SL0Pxb8yHsH3+PLw186ts3TEFsd/ZqDldS3GvFQKRxfGtFfoifIx9sZTprVCeJIVStGswtpbHiHgH+MPGCy+d4CkyMmo1Ko0Kl0WMTBKR/abR1cZkT7OcWNYUbkDJ456xmuTrtartm2WdWNPv2s6noixMeDSH8touRC9bMRY8lUeIJoZn/Rz92ebjO1yRPnU3Bk2XxEjs6nhvsQ6OW4hcXsGll3YnTQ6I563tojcraLoDgtLRDdDH+8PVQs
Hi27AHy8o3gYROmvkm8PfQvtDQCkBcrjw7z6PKddn2zR+YnxgQQ48LzRFwaLgWmaYDJEFbrI8W4xzWGCe0J/GhMb5Mmzl2SiUAgcrmqhtKF9qLvUgVIFZ90rnyjyfoCyPQDMip7FE9OfIFAbSKOpimrdu5gUlXy7r5x3jlPvP2gTxHP0JLH2iCzQ5eEDftH2h9vN7faV2FnRsxzb5mlIqnXyWFIvH3eZ0f5oVO5TzlAS5uuBr1aJSRQ5Ut0y1N3pij3tvsMeKFgXzL8W/IsHJj/QURvsZHYU1AEw0UnZJEqFkni/eEYGjoSjP8kZSl4hEH/mnFNs9nX7XWVCr9KQGToWgH0VO/gw50MqWyvtT68rWsdDvzzEyu0rT7qJzVZRK0dbTvlofIjzjSMjpFO6va12fsRs2RvdjRsX5MJxUaiVAocqmskudfFa+qjxcqTebICDXwIQ4xODp8oTvUVPYXOhw5uUJIkNuVUALBgkMTyQ7VR/J/rwsNnLXT8/jHCPrk9zQn21ZETJE6ttR2uHuDfHEZYOaefL97f90/5won8i9417AqU5AkkwkJ58DIDPd5fy8Y5iAEwW0e5xnu7w+vlOdnWdViZ3VOzAKBqJ8IpghN8Ix7Z5GpIU6o2i08LuoKrbu+kRQRBICpbTb3NcRWXcRg919AA6lW7QutDUyY1hQryTy0M6W9WNOl9e5DxDyIhysTp6IClpMVoEWvR1fH3ka1kUykqQVk6Rz6vPo6K1YsD7rm812lX9HZluDzAvbh7PznqWebHz5Afa6uCw1bpu9KUObcuNG0cS6KVh/ig5nfvjnUVD3JsTIAgdtfTZn4NJj0JQ2J1KcutyHd5kSX07DW0m1EphcPWHzMaOslN3/fywwT2hPwOYarXH2epqE3qACcvlE2VVDrRU2x+ub1bhY5qJRqkgwFfPTWfJtYEfbiviyz2lHK5qwWSR8NWpiA5w8IC/F0G8X0p+AeCsqLMGTZhkOKPTKIkL8rL/PyT18266kRQs/15sC2IuQ8RYOUW4oajLucBGTXsN7WbnZhntLqxHlCAuyDluDEcajvCPrH+wuXQzVB6A6kNWq7pzT/zi0wibMN6xmlYa211DtFUVO50bLZ5MMFpYGDSGYF1HJH161HS74NzJ1MtuOVqLJEFymLfTXD7s16SDX8llY6GjIHy0U9py48ZRXDI+CoVCYG9xo30x1WUZMQd8ImRL49zvAOzCrYfqHF/WaisrHRnug9oJ5V+9UVO8BdFiAM/ALlmqblybYTOhr6ur4+qrr8bX1xd/f39uvPFGWlr6ThmdM2cOgiB0+bvtttsGqceuw5QEeUJ/qKKZhjbn+uwOGM9ACE2X7xdusj98rKYVD0scC4Mf4o9T/8gFY6O4ZmosAG/+eoy3fpWj9qMifB0/ua7uLohXr6+31yrOjJrZ06vc9MBIqzBepL+WCL/Bi7S66Z3kEPl7OFje5Fp1i1pfCLb6sh8Xpf/7nr9zx7o7nC6OZ1P/nxQf6JT976/Zz08lP7GtYhtkW9Oik+bJ58EzCH9PDbFBcqaIy9TR6/yZFTeX35u9+E3uZhKPK4mdGD4RgF2Vu3p4cd/Y7OocqW4PsmOCwWLoeMBslCf0ABnu6Lwb1yfUV8vZI0MA7BmYLotCCZmXyff3/RdES5cJvaOvp7asHofrRPWB0WLk/3b8md9ommkOG+Wunx9GDJsJ/dVXX82BAwdYs2YN33zzDRs3buSWW2454etuvvlmysvL7X/PPvvsIPTWtQjx8SAxxAtJgu3H6oa6O92x+YR3mtAX1rYhoGJUaIz9scsmxnDJ+ChAXpwAJ9jV6ZugwXpRCUm1P7y5bDMiIikBKYR5uWsS+8tMa73oglFulVRXIT5Qi0oh0NBmoqrZcOIXDCaxU+TbI+u7PBzqKav7OtO+ThQlu9qys7JJ8utlhftkXRgcPXOs6noi0wXT7pn9IISNlsVav71XzhaxMjFMntDn1uXSaOh/nxvbTey3vsfp
SY5Nt39x94v85sffkFWVJT9weK2syeAdKkcT3bgZBlw6IRpBkMenBVYrYpdl5BJ58bm5HI7+RJJ/EmqFGl+NL60mx/bdaTpRfVDcXIxobEUtgXfkhEFr182pMyyK9nJycvjhhx/YsWMHEyfKF9WXX36ZJUuW8NxzzxEZGdnraz09PQkPd/5kwmKxYDINXeqgKIqYTCb0ej0KRfd1mrNG+NHY0kZ2UQ2zEv0Hv4N9ETEZtJ9DTTE01YHGk/rmFoJ1AnH+KvR6PaIkohAUXD4+HCwmfs6TU3JTg3Xo9Q605CvPBW0oeIWBoAXrvtv17YSoQ5gdPtux7TmREx0TA0GtVqNUDlyUbEyMP5/+dhqaQUwXc9M3aqWCxBBv8ipbOFjeRJivc1KAT4qk+bDrXSjZIdcCWyPXUyKm8Fn+Z+yt3ku7ud0pdfWHKpppMZjx9lCR5oQBlCRJHRP66mOyon/k2DPCqq4nMqP9+WZfOftKGoa6Kx1oPGHxX+Cbe6AmH75ZAee/DL4RBOuCifeNp6CpgN2Vuzk79ux+7XLb0VpECRKCvRyapdRsbOZY4zEkJOJ842RNBptVXfrFcjTRjZthQHSAJ9MTg9l0uIZPdhVz/8LUE79oqFBrYdSFsPs9OPQt6qS5vLHgDYdfk2paDFQ2GVAIkBbh49B990VBw1Ew64mTlAiR4watXTenzrCY0G/ZsgV/f3/7ZB5g3rx5KBQKtm3bxkUXXdTraz/88EM++OADwsPDOe+883jkkUfw9Ozdk9VgMGAwdEStmprklBdRFBHF7jZPkiRRWVlJQ0PDSbwzxyKKIs3NPdcgpXhKLM/wRBDMHDl6FFdLohHS7wLRjFRQCCoPFsQqIMYTT0MdWYcKMYtmArQBCAhMDoZ0X18kSUJqruSoI8uujCCk3Q5qLdLRo/aHRzGKtJg0MMDRTo+7On0dEwPF39+fsLCwAZc4qBWC/F25Unr3GYooikiSREqYN7mVzeSUNTI72bFpwKeEbxRCSCpUH0LKXwMZywCI9oomzDOMytZKdlfsZlrkNIc3vf1YLRISY2P8EJAQRccer1VtVTQaGlEKAgnHNiMBUvql0MN1ZbCxHRc9XeOcxagIbwCK69uobmonyNtj0NruE7UXLHoW4ZvfQUMhfLsC6byXwDOIiWETKWwqpLy1vN+f1abDNUhITEsMdOjnm12djSRJRPtE46fxQyzZiVB3FFRapJFLTvm4Gopjwo1r48xj4tIJUfx6uJpf8mu4clIrkf4uXKKXsghh93tQthupsQwPn3CHfyYHShvkxbogL7QqxaD9DgvKtoMkEqf0QfSLOeF5xH2ecD79/WyHxYS+oqKC0NDQLo+pVCoCAwOpqOhdcfaqq64iLi6OyMhI9u3bxwMPPEBubi6ff/55r69ZuXIlTzzxRLfHq6ure4zMNjc3YzAYCA0NRavVDplYmu0HpVAoeuyDJEloW0yYLRJBXmp0ateKmCoMPgiGJiS1FwZNIEKzEYUgEOGrptpQjVk0E+QRhIfSOYM+i2TBIlnQtDegtHgiagOQNIO3KuoMTnRMDGQ/er2eqqoqWltb8fEZ3p/LmYwoijQ2NhKhVWA2mckqqKEqzbW+T23oFLzK92PO/obGsNn2x0d5jaKksYQNRzeQqEp0eLub88oxm8wk+yuoqqpy+P531OzAZDYRI2kQWusxeIbSoEsEJ7Q1UGzHhSRJp5zNMxAifZQU1un55UAR0xMGr060PygmP4DvTw+jrCvE/PmdNM1+knGe4xg7Ziw+ap9+HSOtRgs7j9VgESVG+uHQ42pr4VZMZhPxHvFUVVXhs/09NGYT+th5tDa2A6cmIDlUx4Qb18WZx4Q3kB6iZW9ZC+/9ksdvpkQ4dP+ORYlvwEjU1dm07fqU9lFyXb1FtDjMXnV7fgVmk5k4X+dcj3ojvzwLSZII94qlqrrmhNu7zxPOp79BuSGd0D/44IM888wzfW6Tk5Nz0vvvXGOfkZFBREQEc+fO
5ciRIyQm9jwgfOihh1ixYoX9/6amJmJiYggJCcHXt2sapsVioa6ujvDwcIKCHFsbdzKYTCbUanWvzxsEAw1tJkSlGm9XiYbY8FBCYysIJkQPLSqDgFatxMdHh0FloMnYhKSR8NZ5O6X5RkMjzYZmLCoDKpWARm1BozKiVqhRCko81b1ndbgyJzom+ouPjw8KhXxhCQoKOqn0ezdDjyiKCILAlAhfXt9eRUWrBR//IHQaF/o+fS9COPgBqpYiPNTtEBAHwFzNXDZUbSC3JRe/ID+HLu7VtBgoa7GgVqs5OzMeP92p/2a6tVFdg1qlJr2pAZVKjXL8lYSGuYa2hO24CAkJGdRB2eTENkqbSyluFbot2g89oXDxKwj/uwtVawXanX8lZOkq0PT/GrQhtwpBqWREsCdjk2NO/IIBcOzgMdQqNVPjpxLqYUSo2QdqDV5Tr8PL79Q/y6E6Jty4Ls4+Jq47S8sDn+9ne0krN87xI8THxcapnRlzMcJPufhWbEJ31i08vf0vHGk8wstnv4yvx6mXbBU1laFSq5iUHElo6OBk0YmSSLmhEkEQGJkws1/nZPd5wvlotf0rixzSCf29997L8uXL+9xmxIgRhIeHd1uhMpvN9sl0f5kyRRZcOnz4cK8Teg8PDzw8up9EFApFt4PVaDQiCAJeXl5DbmMmSZK9D731xdtDRUObiTajGfAY8j53Qa0DQQmSBYuxDVDioZIjy55qT5qMTXbLKmf028/DD0k0U2dqxYyAWTTSZuhwBNCpdER4RbjWZ3YC+nNMDATbcW6xWByySOBmaBAEgRBfHaE+WqqbDRyubmVMjP9Qd6sDzwCInQqFmxCOrIXJNwOQFJBEiGcI1e3V7KvZx5SIKQ5rcndRAwICI8N9CPByziCyydQEpjZGtjUhqP0Q0s4FFxoACYLQ43XOmYyN9efLrDL2lza65mDQLwqWPg//uwtq8hBW/xGWPAcqDUaLEY1S0+tLJUnil/xaBARmJAc79P1Vt1VT2VaJQlCQHpyOIudb+YnI8QjWBTBHMBTHhBvXxpnHRHqUP5nR/uwvaeTLrDJune34TCyHMWIObHoRmivQVOfQbGrGKBrJb8xnUvikU9p1i8FMYV0bAgIZUf6D9vurai5Hb2xBjUBU/Ox+t+s+TziX/n6uQzqhDwkJISQk5ITbTZs2jYaGBnbt2sWECbLq4vr16xFF0T5J7w9ZWVkAREQ4NpVnuEzytGolSoWARZRoN1nw1LhQxYUggIc36BsRjK2ALxqVfBDrVDoEBEyiCZNo6nMQdfLNCwQoNPhJSoxKDUZdMEbRiNFixCSa8FR5Dpvv2Vmc6e//dCM13IfqZgOHKppca0IPkDxPdr3IXwMTbwRr2chlIy9DKSjtnuCOYofdrs456vYAd4+/m5vKi1CKO2DUYtB4Oa2t4cKoCD8UAlQ2Gahs0ruWQKONgDhYskqe1JfvpW7Dk7zkqaC0tYzX5r3Wa4rtJztL7K4JZyWdeJwzEGwWqkkBSXL2mM3mMWayQ9tx42awuWxiDPtLGvnxQAWXT4rB39Px4z2HoNZC4jlw6BvI/Z7UoFRKWko4VHfolCf0uRVNSBJE+GkJ8Bq8969uLOVCsxqDUoMyKGXQ2nXjGIbFckpaWhqLFi3i5ptvZvv27WzatIk777yTK664wq5wX1paSmpqKtu3bwfgyJEjPPnkk+zatYuCggK+/vprrrvuOmbNmkVmZuZQvp0hQxAEvDzkSXyrwXKCrYcAayqj2ixbf3hYJ/QKQYFWJQ/0bFF6R9FuaqfV1CoLtpn1KACt2gtfD1+CdcFEekcS5xuHv9bfoe26cTPU2JTcc8odqSrpIOJmyBPelkqo2Gd/eFb0LGZEzXBoCYzRLLK3uAGAiU7ynwegoRivkp1oBQWMvth57QwjdBolKWGyhoNL2dcdT3ASzH8SFEr8jv5McfkOmoxN5Nbn9rj5V1mlvL+1EIDfzIwnNsixJVvpwelcN+o6FsQt
AIsJyrPkJ6Im9vk6N25cnTHRfiSHeWOySPxvb9lQd6dvRi6Wb4/+zEjfeED2oz9VDpTZ/OcHz64OIKiugCstWpZHzHSp7DE3/WPYfGMffvghqampzJ07lyVLljBz5kxef/11+/Mmk4nc3Fza2toA0Gg0rF27lgULFpCamsq9997LJZdcwv/+97+hegsugZfGNqE3u57quNoLCQGFZEaNyR6hB/BUyQOiNlObw5qTJIlafS0VrRU0GZvAZF0sULtglMiNGwdjs8LJrWh2uKL7KaPy6PDRzl/t1KayyxoxmEUCvDSMCHZi1Dz7M/k2djr4RTuvnWFGZrQshrfflezreiJ6AsxcgRKB8Y3VoG9iR8WObpv9eKCCf/1yDIArJ8dy0TjHf9ehnqEsHbGUs6LPgqqD8rVL5w+BIxzelhs3g4kgCFxs/c2szanC4mrXps6EpYN/DJj1pDbJAnLHGo9hsBhO8MK+OWid0KdHDrJQaFmWfBsxZnDbdeMQhs2EPjAwkH//+980NzfT2NjIW2+9hbd3hzhNfHw8kiQxZ84cAGJiYvj555+pra1Fr9eTn5/Ps88+203Y7kzDU6Pkgbtv46ZrL8doHlybiXfeeQd/f//eN1AoMCs9sVgsvPnyc6SPGoVOpyMwMJAFsxbwxQdfoFM7zspEb9ZjsBgQBAFvtReYrSdhJ3hcu3HjasQHeaFRKWgxmCltcGzmi0NIni/fHv0JzB16FnX6Or7I/4JP8z51SDM7C+oAmBgX4LSykld3/o3H8z/igGCGjEuc0sZwJTPaH4C9JY2ut8h8PGnnwtirmCSqobmCnQXruvT557xqXtlwGICLxkWRnlDPDwU/YBGdmBFnS7ePHO+Oqrk5LZicEIiPVkVdq5HdRfVD3Z3eEQRIkaP0IUd/IVAbiEWycLj+8Env0mgWyauUs+YGNUIvimRX7KAR0T2hH6a4z/5nGAqFgMp60W8xmoe4N90xKnU89fyr/POfr/Pkk09y8OBBNmzYwK233gp68Pfwd1hbDYYGAHzUPigtJkCUhfmUbsE3N6c/KqWClDB5UTSnvGmIe9MD4WPAOwyMrXI9vZWqtir+k/sfvj36LSbRdEpNSJJkr5+f6MT6+f3FG8nBgOAXJU+83NhJjfBBpRSoazVS1tjdGtblmHQzmXFnowaqqg9QXCqX+W09Wsvzq3ORJFg0OpxrpkbxzoF3eDv7bV7b9xqi5JgF9KyqLDYUbaBOLy9EUWKd0Ee70+3dnB5oVArOSZUV1lcf6N2a2iVIWQiCAqHyAKmecgnwqaTdH65qwWSR8PdUE+k3eNmiTVX7eVKq5haPVtr9HevI4WZwcE/oHYgkSehNliH5G0hkQ62Uo1C2Ovo5c+Zw11138fvf/57AwEDCw8N5/PHHu7xGEAT+8Y9/sHjxYnQ6HSNGjODTTzsiZD/99BOCINDQ0GB/LCsrC0EQKCgo4KeffuKGG26gsbERQRAQBKFbGwBt6Ph29QZuW34Fyy6+kISEBMaMGcONN97IfffdZ99OFEVWrlxJQkICOp2OMWPGdOkPwHfffUdKSgo6nY6zzz6bd955x95Ho8XIyj+v5LyzzpMXCczyQPKFNz4kPiGhy37+9a9/kZaWhlarJTU1lVdffdX+XEFBAYIg8Pnnn3P22Wfj6enJmDFj2LJlS5d9bNq0iTlz5uDp6UlAQAALFy6kvr6+3+/FjRtnYKujP1ThgnX0CkVHlD5/jf3hlIAUAjwCaDO3caDmwCk1Udaop6JRj1IhMC7GORP62tZq6ppLUAAjMq6Sozpu7HiolKSGy8fhPquWgUujUKA951EytaEgWdjx82Psyy/gmR8OIUpw9sgQfjs7EY1Kw/Xp1wOwsWQjb+5/0yEZCD8W/Mhr+15jU+kmMLTIKfcAURNOed9u3LgKC0bJDlbbC+ppaDOeYOshxCsYomURvLHt7UwIm0CUT9RJ7+5AmawlMirCd1CFiI/u+xCAcG0gugFYc7pxHVxI
5nz4YzCLLHtty4k3dAL/vXUqqn7+9pUKAQEwmCyYLHLU4N1332XFihVs27aNLVu2sHz5cmbMmMH8+fPtr3vkkUf4y1/+wosvvsj777/PFVdcwf79+0lLSzthm9OnT+eFF17g0UcfJTdXFhLqXDJhw2ARCAkN5adft1FdWkhIbFKX582iGYPZwAt/fYEPPviA1157jeTkZDZu3Mg111xDSEgIs2fPpri4mIsvvpg77riDW265hZ07d3Lvvffa92OLzisEBWqlGkzWyJCy60/iww8/5NFHH+Xvf/8748aNY8+ePdx88814eXlx/fXX27d7+OGHee6550hOTubhhx/myiuv5PDhw6hUKrKyspg7dy6/+c1vePHFF1GpVGzYsAGLRV5QWblyZZ/vxY0bZ2GbSB2qcMEIPUDSPNjzARRvhfYG0PmjEBRMCp/E6sLVbCnfwtjQsSe9e1u6/egoX3SantXKT5W8Ax+BxUSMQovWmp7ppitjov3ILm1kb0kjizMc60LjFFQeTJp0J7s2/4lt7VX4ffZ78Pwt05PC+b+5SSgU8sV4bOhY7hp3Fy/veZm1RWvRKDVcN+q6kx6om0UzB2vlCfzo4NFQvhckUdZk8Om/ha8bN65ObJAnKWE+5FU2sy6niksmuLDuSOoSKN7G7PI8Zp/931MqfTlYPgSCeE1lrCv5CQQYHTNr8Np141DcEfozEIUg2AccrQY57T4zM5PHHnuM5ORkrrvuOiZOnMi6deu6vG7ZsmXcdNNNpKSk8OSTTzJx4kRefvnlfrWp0Wjw8/NDEATCw8MJDw/vNqGXJAmDReTRx5+gpq6e8ISRZGZmctttt/H9998jSiJFzUUU1hfy9NNP89Zbb7Fw4UJGjBjB8uXLueaaa/jnP/8JwD/+8Q8SExNZtWoVI0eO5Oqrr2b58uWAPChqMbYAoBSUIEkdgniKrun2jz32GKtWreLiiy8mISGBiy++mHvuucfejo377ruPpUuXkpKSwhNPPEFhYSGHD8t1VM8++ywTJ07k1VdfZcyYMaSnp3PnnXcSHByMwWA44Xtx48ZZjAyXhfGK69pp1p9a+rpTCEyA4BQQLXBkvf3hqRFTAdhZsROzePKlQzusE/pJzlK3bypj34GPARgVMUUW+3PTjQyrMN7e4oZB13Y5WcbHzGJCwkImtWuJNR3mAeUH3Dk7jCe2PMa6oo5r54yoGdyaeSsA3x37jo9zPz7pNo80HEFv0eOj9iHONw5Kd8pPuKPzbk5DFqSHAbDmYKVr62vETgcPH2ithpLuQpn9RRQle/nbqIjBm9BXbP8HOwQjaLxYlHnDoLXrxrG4I/QOxEOl4JPbpg1J2xqlYI/49gdVDxP6zkRERFBVVdXlsWnTpnX7Pysr6yR62zNmUUIUJWJTx7B/49fs3nuQTQeK2PjLr5x33nksX76cP73wJ3KPym4GnbMHAIxGI+PGjQMgJyeHKVOm9Nh/i2hBo9agUqjkSElzBYhGQNGlfr61tZUjR45w4403cvPNN3f002zGz6+r+mjnzy8iQo4wVVVVkZqaSlZWFsuWLevxPR8+fPiE78WNG2fhp1MT6a+lrEFPbkWzc23bTpbkBVCTJ6fdW+3eUgNT8dX40mRs4mDtQTJDBm5F2m60kF0qD56c8r5FC9K6P5MltYNGx9gxyx3fxmlCargvwd4aalqMbDpcw9nW+llXxs/Dj/tnPsnj2SNJF15htHSI1/53Nfme3lS0VTAlfAre1tTVs2PPxigaeSv7Lb44/AWB2kAWxC/otk+zaOZww2EKGgvwUnsxI2oGCqEj7rKvRrZwTA9Olx+3CeK5J/RuTkNmJYfwr1+OUtrQTk5586DbuPUblUYuD8v+HPK+pyYkkc1lmzlvxHkDysYpqmuj1WBBq1YwImSQ0t4bivmucDWSAsZGTSfGx10/P1xxT+gdiCAIaNXOSds8EQNdvVRaJ/TtJgsSoFZ3jUwLgoAo9j9SorCmGHXuh8k0sIif
LTKjUHmgRMukcaOZNHshv7tnBR988AHXXnstd953J22tsnXdt99+S1RU11olD48TR8A8VB6Eeofi6+GLZDGBoREQwDcCk7ljUaSlRY7iv/HGG90WB5TKrt9z58/PdgK3fX46Xe+q+bY2Tva9uHFzqqRF+FLWoCenvMk1J/RJc2Hrq3KtcEMx+MegVCiZHD6ZtUVr2Va+7aQm9HuK6rGIEhF+WqL8neBssftdiqr3UacR0PjFkh482vFtnCYoFQKLR0fw/tZCvtlXPiwm9CBrMOw2J9DkcwfjvF5nR3s5apOO+yf+3j6Zt7EwfiFGi5FfSn9hSuhEKNiEFBBPiUIkuyabfdX7OFh7EL1FLv/SKrWyLZ2VF3e/yK5KeQKfEZwBLdVQXyhrMkS6F3/dnH7oNEpmJoWwNqeS1QcrXHdCDzByCWR/jr7gF+43HqVNNJDgm0BGSEa/d2Hznx8Z7mMfozublh2v85PCCBpvzk2/blDadOMc3Cn3ZygKQUCjUiBJ9NuDeuvWrd3+t9XPh4SEAFBeXm5//vjovUaj6TOLwGCd0GtUSvCwDoYM8oR31KhRAIh6kaSRSWg8NBQUFpCUlNTlLyZGXl1MS0tj+/btvfZfAMJ8dVRUVCJJgE8EePh06XNYWBiRkZEcPXq0WzsJxwnn9UVmZma38gUbo0aNwsPDg6Kiol7fixs3zqSjjt4FhfEAPAPtokMc7hDHmxIxBa1Si1JxcouoOwtlUUqnpNuX74Pd7yMCE6OmMzFquqzV4aZXFqSHoVQI5FU2c7jKRY/F49hf0oBZaCA7eAPfBgSBQsXteoGRG1+A5spu258Xv5ingmfi9+Xt8OMf+NuXl3Hf+rt558A77K7aLafTa3yYEDaBSeGTury2rKUMg8WASqGSdSNs0fngkaB14YmOGzenwPxRctr9r/k1tLmgM5Od4GQISkJrMTNbI4+HP8v/bEC7sAniDZr/fN1R8gvWY0EiLjRD1uVwM2xxR+jPYHy0KmpbjJgsIpZ+TOo/+eQTJk6cyMyZM/nwww/Zvn07b775JoB9Avr444/z1FNPkZeXx6pVq7q8Pj4+npaWFtatW8eYMWPw9PTE09PT/rzRGh2/5fqrOHvGFKZnxBMeFsaxhkM89Ic/kJKSQkZ6BmXtZdx0503cu+JekGDmzJk0NjayadMmfH19uf7667nttttYtWoV999/PzfddBO7du3inXfeAUAULdBazZzJo6murePZ1z/m0iuv4Ycf3uf777/H17djcPTEE09w11134efnx6JFizAYDOzcuZP6+npWrFjRr8/5oYceIiMjg9tvv53bbrsNjUbDhg0bWLZsGcHBwdx3333cc889iKLY43tx48aZpEXIdfR5lc1YRGnQIgMDInkBFG+T0+7HLweFgvSgdN5Y8AYapWbAu7OIkr1+3uF2dYZm2PAUSCIJKUu4/+w/OHb/pyn+nhpmJgXzc1413+6r4O55PkPdpROyr6SRWt37+KmMoPLg8rG3MT17NTSWwNd3wtJV4B8LFjPkr4bd76Fu7lj0jjOZ2d1URuqIhWSGTyAjOIM437guafY27plwDzXtNfh5+BGsC+6Y0Lvt6tycxqRF+BAdoKOkvp2NeTUsGu3C4o8jF8PmlzmvoY41OhU5dTnk1OaQFnRi4WhJkjoE8Qarfn7n24wTVbwasZC6ybcNqqq+G8fjjtCfwfh7atCqlUgStBrNJ5zUP/HEE/znP/8hMzOT9957j48++sgeOVer1Xz00UccOnSIzMxMnnnmGf785z93ef306dO57bbbuPzyywkJCeHZZ5/t8rwtQr9gwQL+9/2PnHft7aRMXcD1y68nNTWV1atXo1ar0al03PPwPax4cAUrV64kLS2NRYsW8e2339oj57GxsXz22Wd8+eWXjBkzhtdee41H//QoAKVNhUjtdaSlJPLqC8/xyutvMmbMGLZv397FGg/gpptu4l//+hdvv/02GRkZzJ49m3feeWdAEfqU
lBRWr17N3r17mTx5MtOmTeOrr75CpZLX05588kkeeeSRXt+LGzfOJCbAE51Gid4kUlDbOtTd6Zn4maDxgqYyOPglAEqF8qQm8wAbDlXR0GbCT6d2bDREkuDXv8m6HL6RMON3jtv3GcDSTFl/5Oe8KppcUaSxE5Iksb+0EU/TBHRqJbOjZ3NR5o1wwSvyJL6lCr66E7I+go+vgZ+fgeZy0AXAtDvgig9Z7BHJW20a/tjYxvkJS0nwS+hxMg8Q7hXO6ODRco2rJHUSxBs/iO/ajZvBRRAEe5R+zcHuWS8uRdI8UKgIqj3KnOAxAHye/3m/XlrVbKC2xYhCIdjFap1KTT4c2wiCgN/k20jwc483hzuC5NLSkUNPU1MTfn5+NDY2doncAuj1eo4dO0ZCQgJarXaIeigjSRJmsxmVSjWgVTazRaS4vh2zRcTLQ0WEn7bH1wuCwBdffMGFF17owF53IEoSR6tbkSSJ+GAv1EqFPPjRN4LSQx4gWVNrW4wtVLZVolFq+i3gIUkSn37/KZctvYxj+VuJ9/MD7zB5cHWacrLHRG+40vHu5uQQRZGqqipCQ0PtuhcAj32Vze6iBm6dPYJzMyOHsId9kP05bHoRVFpY9rY8YbZS3FxMlHdUr5OhzpgtIrd9sIvKJgM3zIjn4vEOtEPKWy1H5wUFFQueQAgZSZhXmOP27yR6Oy4GG0mSuPs/WRyraXX8d+NgimrbuOPfu1Gr4JnLI0kKSOg4z7bXw3e/l8UcbegCYOxVkHY+qK3nz+pc+Pr/wGyA0ZfAjLv613jdMfhkueyacP03siiXg3GVY8KN6zBUx0RDm5Hr396BKEr8/apxxAV5DVrbA+a730PxNqrGXs7d1b8gIvLUjKdICkjq82UbDlXx/Jo8ksO8ef6ysU7vpvT9g5QU/0pM4kKY+8hJ78d9nnA+fc1DO+P+9M9wVEqFfRLfajBT22ockn6YzCKSJKFQCHYFfjyDZRs5iwGaSmW/XUCn0hHqGUqEV//9ivVmPUazLDbkgwBeoaf1ZN6Nm4GQak3xy3XVOnqAURdC5Fgw6+VopyifMx7Z9Aj3/XwfefV5J9oDAOsPVVHZZMDfU80SR3qeN5XJ0XmAiTfwReMB7tpwV78jNG7khWNblP67/RX91ncZCvaVNgAwKsKP5MARXRdNdQFw3gsQMwW8QmDq7XDlR5B5WcdkHiBkJJz9sHw/+zM4+FX/Grel24dnOmUy78aNK+HvqWFqgqx14vJR+gRZyDK0dK9d1PLzwye+BtjS7Qelfr7yIAeLf+E+dSvPaAyubQnopt+4J/Ru0KqVhPnKiur1rUaa2gc/1dGWbu+hVHQMjJRq8IsCQQmmNjmNVZJQKpT4aHxQKfovAdHQXgOSXKOv9AyWhbbcuHEDyEr3gN0D1yVRKGD2A3KEviwLDn6JIAiEecoR8C1lW064C5NF5L87iwG4ZHy041xJJAnWPyWfpyIykcZczd6qvQAk+yc7po0zhNkpIXh5KKls0rO7qH6ou9Mr+0tkAavMKP+eN9B4wZJn4ZpPYczloO7FSWHEbJh0k3z/1xegZNeJGy9x+8+7ObOwpd2vP1Rld0RySeJmyM4T1Ye4MOIsdCodUd5RiFLffT5YNoj18zvf4hulAbS+BPsnuGvnTxPcE3o3APho1QR6ySv9Vc0G2k1d1eglSXJauj2A0WJVuFcfd0iqtNbUWgEMTdBa3a1fJzpRGk3ttBmbmTpjEsaGYvwj3LVCbtx0JiXMG0GAyiYDVc36oe5O7/hGwpRb5fvbX4emMmZEzQBgfdF6qtuq+3gxrMvpiM47VFyp/hhUZoNSA2c/TGFLMfWGerRKLamBqY5r5wxAq1YyL00evH+zr7zPbduNFsoa2gejW10QRYl91gl9RrQDImrjrpF9rCUR1j4GDUW9b2sxQ3mWfN8tiOfmDGF8bABB3hqa9Wa2H6sb6u70jmcghMlq8ZHVebw27zWuTru6z3KwJr2JojrZjtnp
1nzleykr2cpuhQXBM5glCUuc256bQcM9oXdjJ9BLg5eHCkmSKG9ox2QZvFVQY6cIfTc0XuBjHXy318n1ichp9KUtpX0P4iWJhuZSQMJLUKH2iZBXT924cWPHU6Oyp/p9d4JJ1JAz6kKIGAOmdvj5GcYGZZIWmIZRNPLewfd6fVnn6PylExwYnQco3iHfRo4Dn3D2VO0BID043W1XdxIstpZC7C6qp6Kx5wWmwtpWbv9wF7/9YBeHq1oGs3sU1LbSYjCjVStIDvU+8QtOhCDArN/LEwFDM/zwELQ39Lxt1UH52Nf6QmDiqbftxs0wQKEQmJsaCsCagxVD3JsTkDBLvj32C1rVifWGcqzR+ZhAHX46J18vdr7Ft0ojaP2YEDmVCG8Hlp25GVLcE3o3dgRBINxXi0alwCJKVDTqEQeptqaLB31PaP3kmnqQ1YMNLQgIGCwGWkwtGC091/5LbXVIFiMg4O8dYRfWc+PGTVcuHCuLzH2fXUG70XKCrYeQ41LvhUNf85vRv0GBgu0V28mqyurxZetyKqluNhDgpXG89VHxNvk2ZjKAfUI/PtStQH4yRPnrGB/rjyTBd/u7LzDtLW7g/k/3UdNiRJTgf3vLBrV/+0s7/KJVPS1CnwwqDSx4UhZrbSyBr+6Qxe+Ox1Y/HzVB/i24cXOGMH+UfN7eU9xAdbNhiHvTB/Fy7Tzle2VhZ+Bg7UH+d+R/PW5uy/Zxerp9cwVNZbv5WWkCzyCWjljq3PbcDCruq4GbLigUApH+OpQKAb3JQuMg1NNbRBGzLeVe1cch6RkEWn9AgqYyPCQRL7Wsdlqv76HW0tiG0FZDGApiPcPQegySt6cbN8OQSfGBRPnraDNaWO3qERC/KJhyi3x/2z+JRc3ihMUAvJX9FiZL1/OW0Szy8Y6O6LxHbwuHJ4OpHSr2yfdjJtNibCG/Ph+AsaFjHdfOGYZNsHDNwUoM5o4Fpg2Hqnjs6wO0Gy3EBnoC8Et+9aBcq2zY0+2jHCxg5Rko1937RMiT+i9vh4JNXbex29W50+3dnFmE+2kZFeGLJOHaafe+ERCUJJfQFG6hpLmEJ7Y8wb9z/k1Fa9dra5vRzNocWehvUryTtZ0KN7NWacKk1jIiMIW0wDTntudmUHFP6N10Q61UEOwti+TVtRoxi85Nvbel26uVCpSKPtLhBUGOXmi8AREaSwgQJUDqHqW3mKG5DJDAww+1zi2C58ZNXygUAheOk6P0X2WVYXFhhXEARl3UJfX+0qSLCfAIQKfS0WBo6LLp2pxKalqMBHppWJTu4Oh8+V6wmOSyIL8Y9lXvQ0Qk2juaYF2wY9s6g5gUH0iojwctBjO/5NUgSRL/2V7E82vysIgSZyUH87fLx5IY4oXJIrF2kNSvRVEi2xqhz3RE/fzxBMTDRa/Jjg6mNlj9MOx+XxZeNLZC5UF5O7cgnpszkElWtfudhS48oQeInynfFvxCtE80Y0PGIiLy1eGuThbf76+gzWghOkDn9Am9VLiZjQojaLxZOmKpWwzvNMM9oXfTIz5aFR4qJaIoUe9kK7uOdPt+HI6CIEcvVDqQLHjoG/Eym8Biot5WSy9J0FxOu2jGpFSDd6i7bt6Nm35wdmoofjo11c0GNh2uGeru9E2X1Ps9eOZ+z2PTHmPlWSsJ8Qyxb2Y0d9TOL5sY3b/zzECwpdtHTwJBYFzYOO6beB+XjbzMse2cYSgUgr004pt9Zby07jAfbpPF4i4eH8V9C0aiUSnskfzvs8sHxebuaE0LbUYLnholI0IcUD/fEzp/WLIK0i+Sr2c7/gVrH4eirXLUzzdKjgK6cXOGMTFOthveV9LYJXPH5bCl3RdvB1M75444F4A91XvsmxjNIl9Zy4UuHh+Noq+A1qlibEMoz+KvJm8eGP87pkZMdV5bboYE94TeTY8IgkCwj6x639huxujEE6dxIBN6kOvg/WPlQY3akwBJAMlCi6EBY30BtFQgmVqpEkSKBIk2iwvXWrlx40J4
qJR2H/DPd5e4vj9t59T7LX8n4sjGbmrCqx4Bfn8AAEYmSURBVA9WUNtiJMhbw4JRDo7OgzxgA9lzHNCpdEwKn8SUiCmOb+sMY8GocFRKgSPVrazNqUQhwG/nJHLDjAT74HdWSgjeHioqmwzsGgSbO1u6/egov74zyk4VpQpm/g5m3QcKFRz9CTY8JT8X5dZmcHNmEhfkSbC3BqNZtFtHuiRBiXLwyWKE4u0kBSQhIFCvr6fRIPd7Q24V9a3ytWnOyJAT7PAUKd0JFhNq32jGJy4ZkO2zm+GBe0Lvplc8NSq76n1Ni/Oi9HYP+oFEzgQBPHzAPxYP/3i8lFpAoNncDvpGWpEwK1QoFap+qYy6ceNGZklGBGrrJCq71IV96W2kXwyZ1mj41ldh62sYzQY+yfuE7WW7+GRnCQDLJsQ4PjrfVCbXOiuU7kmWE/DzVHNWkly24KFS8PDSUfaIvA2tWsncNFn9+tt+ODSYLSK7CutoaDu5a5ptQu+UdPueSDsPzv2bHLUXrQvrbrs6N2cogiAwMd6Wdu/8BbyTRhAgwRqlL/gVnUpHuJe8oFzYVIgoSny+W742XTg2CrWjxDV7QSrYjIQEcdPdGaunKe4J/RmGIAh9/j3++ONdtg/21iAI0Gow02Y0n3S7da1GSuvbqG7W09hupN1owSJKSJJkj9D76jR8+eWXJ9zXzz//zDnnnENgYCCenp4kj8rg/rsfJ9AjlEBdMJJCRYNKDQolfh5+ffp/unHjpit+OjXzRsk+4J/vKRni3vQDQYBpd8CU2+T/937Et9/fwae5n/DXbf+gtrWVYG8N863vyaHYovNho0HjxYaiDfw397+UNA+Dz22YcN30eM7NjOCZSzOZnNBzjemSTjZ35Y29+9JLksSL6/J5/OuDXP/Wdh75Mpu1BytpNfTv2ma2iBy0Wkw5XBCvLyIy4aLX5ePMN1Iu73Dj5gzFlna/s6DOtbPIbGn3RVvAYibWNxaAgqYCth6tpaxBj5eHkoWO1nU5HlHkSNFG7lG38LnKuSW0boYO90znDKO8vNz+98ILL+Dr69vlsfvuu6/L9hqVEl+rL2ZNs7HbydNoPPHJwWgWqW010Ga00NBmoqrJQEl9G0erWyiobUOUpH6Lcxw8eJBFixYxceJENm7cyP79+3n55ZfRemjRqrwQvEPQ+0ZhsC5Q+GrcyvZu3AyUC8ZGIQiws6Ceotq2oe5O/xh7Jcx5EAQFi8ry8GuqprC+jBb1Vi6f5IToPHRKt5ft6tYVreOz/M/Irct1fFtnKMHeHtw6O5HEPurVIzvZ3H2/v3eHhnU5VfyUW40ggChBVnEDL67L59o3t7Hyuxw2H66xLzD3xOHqFtpNFrw9VMQHeZ3S+xowPmFwwd/h8g9BM8htu3HjQoyJ8UetFKhsMlBc1/sC3pATNhp0AWBohvK9JPgmAFDeUs6nu+RF36WZkeg0TrZTrj7ENlMt5Qoods/6TluGzVf71FNPMX36dDw9PfH39+/XayRJ4tFHHyUiIgKdTse8efPIz893XiclSVZcHoq/fq5ShoeH2//8/Pxk73nr/62trVx99dWEhYXh7e3NpEmTWLt2LYGeGhQKAYPZQnx8Ak8++STXXXcdvr6+3HKLXL/6xhtvEBMTg6enJxdddBHPP/+8/XtqaDeCBBvXfM8lC2eRERfC3CmZvPzcSvQGeUHg7InpAFx00UUIgkB8fHyP/V+9ejXh4eE8++yzjB49msTERBYtWsQbb7yBTqcDoM5Qx84tO7lqyVV4e3kTExPDXXfdRWtrq30/VVVVnHfeeeh0OhISEvjwww+Jj4/nhRdeAKCgoABBEMjKyrK/pqGhAUEQ+Omnn+yPZWdns3jxYry9vQkLC+Paa6+lpqZDTGzOnDncdddd/P73vycwMJDw8PBuWRANDQ3ceuuthIWFodVqGT16NN988439+V9//ZWzzjoLnU7X43tx48bRRPnrmDoiCIAvs0qHuDcDYORiWPgU
WqWW86qaCLJUIXluJSPOCVEciwnKrAJH0ZNpMjZxuOEw4LarGwp6s7mzUVzXxms/HwHgmilxvH7dBK6dGkdMoA6TRWLzkVpWfn+IG9/dQX5lc49t2O3qov2cK2DVG4Lg9p53c8ajVSvJjPYHXFztXqGA+Bny/YKNnBN7Dv+Y9w+mBl5GflULaqXAeZnOF7eUCn5lu8IEGi8mRbrF8E5Xho0qgtFoZNmyZUybNo0333yzX6959tlneemll3j33XdJSEjgkUceYeHChRw8eBCt1gl11WY9vLXI8fvtDzd8D4L6lHbR0tLCkiVLeOqpp/Dw8OC9997jvPPOIzc3l8CgcGpaDFgkieeee45HH32Uxx57DIBNmzZx22238cwzz3D++eezdu1aHnnkEUD2mG9qN7Nj62ZW3HEzL730EmeddRZHjhzhlltuwVen5vcP/ZFt27cTFRHO22+/zaJFi1Aqe16xDA8Pp7y8nI0bNzJr1qxuzzcbm8nNz+U3y37Dn/70J95/532qq6u58847ufPOO3n77bcBWL58OWVlZWzYsAG1Ws1dd91FVVXVgD6vhoYGzjnnHG666Sb+9re/0d7ezgMPPMBll13G+vXr7du9++67rFixgm3btrFlyxaWL1/OjBkzmD9/PqIosnjxYpqbm/nggw9ITEzk4MGD9vd/5MgRFi1axJ///GfeeuutHt+LGzfO4KJxUWw5UsuG3CqunRpHgJdmqLvUP+Kms3HkQ8T+/EdGq1so0JTz8s7neHLWM6iVp3aO7EJltmwrpguAoCT2lcs1irE+sQTpghzXjpt+YbO5q2o28Etejb1sBOQssWd/zMVgFhkT48elE2RF6csmxbBsYjTHalr5Oa+an/OqqW0x8vCX2TxxfjppEV0zvPY7y3/ejRs3A2JifAC7CuvZUVDPxeOjh7o7vRN/FuR8AwW/4jf9blAoeGF3NgDzR4Xj7+n862pJ4U9UCCJqDz/GhY5zentuhoZhs9T7xBNPcM8995CRkdGv7SVJ4oUXXuCPf/wjF1xwAZmZmbz33nuUlZX1q077TGTMmDHceuutjB49muTkZJ588kkSExP5+uuv8fNUy6IdEsycNYd7772XxMREEhMTefnll1m8eDH33XcfKSkp3H777SxevBiAxnYTkiTx6vN/4YEHHuD6669nxIgRzJ8/nyeffJK3/vUGvjo1keHy4Mvf35/w8HBCQnpW/Fy2bBlXXnkls2fPJiIigosuuoi///3vNDXJdY0eSg/++bd/cunll3LvintJTk5m+vTpvPTSS7z33nvo9Xry8vL4/vvveeONN5g6dSoTJkzgzTffpL19YKlbf//73xk3bhxPP/00qampjBs3jrfeeosNGzaQl5dn3y4zM5PHHnuM5ORkrrvuOiZOnMi6desAWLt2Ldu3b+fzzz9n/vz5jBgxgnPPPdf++a1cuZKrr76a3/3udz2+FzdunEVahC+p4T6YLRLf7D+x2JirUN7YzsvZHvzd+3fcqIzEz2LiWPlO3jv4nmMbsqXbR08ChYI9VXK03h2dHxoUCoHF1ij9d8cdr29tOkZBTSt+OjUr5o/sEl0XBIERId7cMCOBf1w9gdFRvrQbLTz6Vbbdbx7AZBHJKZevM4MmiOfGjZsemRgn62kcLG/qtwbGkBA5HtSe0FoD1Yc4Ut3CnqIGFIJsv+l0mivY1nQEEMiMnIpOpXN+m26GhGEToR8ox44do6Kignnz5tkf8/PzY8qUKWzZsoUrrriix9cZDAYMhg6bM9tEURRFRLFrbZ0oikiSZP9D6SFHyocClRbJLJ/U+isSYtvOdtvS0sLjjz/Od999R3l5OWazmfb2dgoLCxGAIG95JXHk6LGYLCIq66AoNzeXCy+8sEu7kyZN4ptvvqGhzQTAoYPZ7Nq+laefftq+jcViQa/X09raiqenp70vffVfoVDw1ltv8eSTT7J+/Xq2bdvG008/zTPPPMO2bduIiIjgaM5R9u3bx2cf
f9blvYqiyNGjR8nLy0OlUjF+/Hh7WyNHjsTf37/r93lcf45/bO/evWzYsAFv7+61nYcPHyY5ORmAjIyMLu8pIiKCyspKJEliz549REdHk5yc3OP73rt3L/v27ePDDz/s8b2kpaX1+lkd3+dTxfa+e/otuBke2M5Z/f3+LhgbSc4Ph/huXxmXjItEq3Zyrd8pIooSL6zJQ2+2EBwzkqQZq7jjfzfzF0MTvx77kQsTLyRAG+CQtgSr/7wUPQmz2cjeqr0gwdjgscPu9zHQ48JVmZcawr+3FZJX1cyh8kZSwnzYcqSWb/bJXs93z03CX6fq9X16qAQeXZrGU98dIqukgce+yubhpWmMjfEnt7wJvdmCn05NtL922H9WJ+J0OSbcOA5XOiZCfTRE+WspaWhnV2EdM61uGC6HQoUQMwWObkA6tpEXKnyo9djOlLAZhHhrnP9ZFm5mu8IMai0TI6c7vD1XOiZOV/r72Z62E/qKClkYJyysq7JxWFiY/bmeWLlyJU888US3x6urq7tFRE0mE6IoYjabMVsn06ea9n6ySGYzFotcN9hfgTnbQWLr+7333su6dev4y1/+QmJiIjqdjiuuuAKDwYDZbMZDISEIoNXpqG5qJ8Q6wbf9mO2fgXXfErIqsEoh0NrSwqOPPsqFF17YrR8qlcr+WovF0mU/vREWFsaVV17JlVdeyWOPPUZ6ejqvvvoqjz32GM3Nzdx8883ccccd3V4XGxtLTk6O/X0rjqtHtL0P22djMpns/bFF8G19bG5uZunSpV0WKWxERERgNpuRJKnL+7N9XrZ9eHh4dPkOjudE76Wvz8rWDvT/mOgL2+dSW1uLWj00x7mbU0MURRobG5Ekqdux3xMjvCUCPASqW/R8vi2feSk9q4y7Cj8eqmNvUR1atYKrxvhTg4akyHO4ofgHRjUpMDUaqGoaWGlNTwjtdQRWHgIE6jzi2HTwG+ra6vBR++Bv8h9w+c5QM9DjwpUZG6Fj87FGPtl6hIszQ3j+x2OYTRYWpwUSozP167u5dXIQLxv07Ctr4bEv93HHzGgK6/WYTWZGROiorq4ehHcytJxOx4Qbx+Bqx0RqkJqC/2/vvsOjqtIHjn/vZDKTXkgPpFETaqghgPS6CihFF1GDjR8CLkgRxAVRpKmIBRd0VWBXdxWlyIKIERJAhEivIQUCoQQSShrpM/f3R8zIhAAJJJkE3s/z5IE798y95955k7nvPeeek5bF9uPnaOxUcxNKnUtzHIt+JudYJAfzm5GnT8DF0b9avidyj2/iNEVg5UKAJqDS91nTYuJ+lJVV9pgupVk0oZ8+fToLFy68bZnY2FiCg4OrqUbw2muvMWnSJNNyZmYmfn5+eHh44ORk/jxdXl4eWVlZaLVatNqacW+kIolWyS9fSd137dpFREQEw4YNA4pb7M+cOYOiKKYyGo1SPI1dgRGHInC00RIcHMz+/fvNzsH+/ftR1eJE0sVeT5s2bUhISLjtZ1lS94qeSw8PD3x8fMjNzUWr1dK2bVtOnDhxy301a9aMoqIiDh06RPv2xdP/xMXFkZ6ejkajQavV4uNT3HUzLS3NVJ+jR4ufe7KysjK18K9Zs4aGDRvess4l0wHeuF6j0Zj2Exoayrlz5zh16hSNGze+6f13OpbyqKzkW6vVotFocHNzq5oxKESVMxqNKIqCh4dHub98H+9g5NMdp9iWdJ2/dmpimcHAyuH8tVx+OH4SrbWW0d0a0Kz+H1MBdZ9Iv2/3wvUU1Gv7IPjhe99Z/H4UrTV4NMHDrxEO6hmcbZ0Z2mgoPt5VP8hRZbubuKiphofZ8vu5wxxIyeVawRXyVYWQui6M6d2sQnM9vzXEk/d+jmd30hWW7b6Iq50OrbWWjo198fT0rMIjqBnup5gQlaOmxUS35jp+OZnJicsFuLt71NjvJlz6oRxcRv6183jatCVVr8Wgz6j6vyOFOVzKSKSXoifXvytBdYMqfRc1LSbuR+W93rZoFjp58mRG
jRp12zL169e/q217exdfzF26dMmUnJUsh4aG3vJ9er3e1Gp6o5IkrPRrN87hbknqDVO/lbcupcs3atSItWvXMmjQIBRFYebMmaZfVlNZwE5XHDZpWfnYWFvx8ssv07VrVxYvXszAgQPZunUrP27ahKIoaBQFZ1trZs2axSOPPEJAQADDhg1Do9Fw6NAhjh49yttvvw1AYGAgW7dupUuXLuj1elxdb+4a++mnn3Lw4EEee+wxGjRoQF5eHv/61784duwYH3/8MYqiMG3aNDp27MjLL7/MCy+8gL29PcePHycyMpIlS5YQHBxM//79GTNmDEuXLkWr1TJx4kRsbW1Nx2pnZ0fHjh1ZuHAh9evXJzU11TTQX0mZ8ePH8/nnn/Pkk0+aRrFPTEzkm2++4fPPPzcNbHer+FAUhe7du9O1a1eGDRvG+++/T8OGDTlx4gSKotC/f/87Hktlx8TtlBxHWb8Lovao6GfYp5k3X8UkczEzn0PnM2gbUPNa6Y1GlQ+3JlJoUGnj78qAFj5/xry9G7R7DnYtQdn7OQed3TiWmcTIkJF3v8Pze4v/9euAotHQN7AvXep2QavR1trfjfvldzvYx4lGno4kpmaTkJqNnU7LtP4h6K0rdrljo9MwfUAwiyLj+TXhMpezC1BQaOXnUuvPUXndLzEhKk9NionmdV2ws9aSkVtE0pUcGnk5WrpKZbNxJN87lMxLUYRqLrPHTseZzDOggEapwvN4YT8+RiP/59QAOs0qniWjCtSkmLgflfe8WvTse3h4EBwcfNsfne7uRoAMCgrC29vbNPgYFLe2x8TEEB4eXlmHcF95//33cXV1pVOnTgwcOJB+/frRpk2bm8rZ6ayw1VlhVFUuZuQR3qkTy5Yt4/3336dVq1b89NNPvDBmPHq9Hidba6w0Cv369WPDhg38/PPPtG/fno4dO7J48WICAgJM2120aBGRkZH4+fnRunXZI3F26NCB7OxsxowZQ7NmzejWrRu7d+9m3bp1dOvWDSgehG7btm3Ex8fz0EMP0bp1a2bNmoWvr69pO8uXL8fX15du3boxZMgQRo8efdPd0i+//JKioiLatm3LxIkTTTceSvj6+rJz504MBgN9+/alRYsWTJw4EReXil3wrV69mvbt2zNixAiaNm3Kq6++auoqX55jEaIq2Vhb0eePEcM3HK6Zg+OtPXCeuItZ2OqseLlno5tvYDV7DFz8Sc29ysIdf2f9yfVsP7f97nZmNMK5PcX/r9fB9LKdtR06q1oyE8B9TFEU0xR2AON7NMTb+e56FGmtNEzp24QeTYoHaXV30FHXRQaVEqImsLbS0NrfBYA9p69ZtjJ3cEDTHFWFvxbG4GTIIq8ol7ScKn5058yu4n8DOlVZMi9qDkWtjNGyqkFycjJXr15l/fr1vPvuu+zYsQOAhg0bmgYlCw4OZv78+Tz22GMALFy4kAULFphNW3f48OEKTVuXmZmJs7MzGRkZZXa5T0pKIigoyOJdkFVVpaioCK1WWy29BYoMRpKv5mAwqjjbWePp+Ofx5xcaePrZ50lKjOe3nb9WqJujJQUGBjJx4kQmTpxo6apUisqOiZoU7+LuGI1GUlNT8fT0rNBNp/PpuYz59z4UBT59ui0+zjUnqTmfnsv4/+ynyKDyt16NTDcfbnJ2D/w4he+0BXzv5o1e78i8LvOo51jBKY9SY2HtGNA5sK//bDRWWkI9Qi3eS+te3G1c1FQFRUaWRCVSz9WWx9v53fP2jEaVLSdSCXSzq7mtgJXsfosJce9qYkxEHr/ER1sSaOTpwPtPhFq6Orc06evd9E1+n3DbcyxwyueU1opJLV8iLPTZW78p6xIkbYf0M2AoKJ4au6gADPlQlA/GouJp8UJHFs95fyOjkbh/P4ySn0nDAR+g8WtfJcdVE2PifnO7PPRGNePB73KYNWsWK1euNC2XtOBGRUXRvXt3oPg56IyMP6eZefXVV7l+/TqjR48mPT2dLl268NNPP0ky
Ugm0Vhq8nGy4kJ5LRk4h/1zyIY/8pT/29vZ8u3Y961b9h3mLPqg1ybwQ4tbqutjSNqB43t8fj1zk+S6V/yze3fr52EWKDCqt/JzpHXKbZxL92kNAZ4ae+ZW4/HyOaHW8v+993gh/A2d9BaYh+2O6uiLfUFYc/xepuamMbTWWbn7d7vFIRGXRaTVM6nPzmCR3S6NRbn2jSAhhMe0Cih/NTEjNJj2noFrmda+o05evk3C1iFOOf6Nrj0z89y3klCGDM3uXEZYSB+HjwOWPG48Z54uT+FPRkHbizhtPi4OLh6Hn38Hmhu+xtBN8V5TGEZ2RZwsu0b9KjkzUJLUmoV+xYgUrVqy4bZnSnQ0UReGtt97irbfeqsKaPbjs9Vpc7XVcu17Azt0xfLh4EVlZWdQLCOTvb7/D+DH/Z+kqCiEqycMtfdh35hqRxy8yMsy/RkxhZzSqRMcVd1v8y43Pzd9K+Dg0537n5fQsptk6cT77PG/teou/d/x78XR2lxPg2Fo4/Su4BkKLYRDQxbz144+EfoejM6lXE3HWORPmE1ZFRyiEEOJWXO11NPR0IDE1m31nrtErpObdeIuKKx5Zvl1gHeyadSFIV8iuA59QkJUFybuKH+Fq1Kf4++dK4p9vVBTwbgm+ocVz2VvpQGsD2j/+zU6FmGXF30lrRkOft8CjCQBZp7ZyTFMEOkdae1dN67yoWWpNQi9qJjd7HXkFBj78bCV6aytsra1IzynA1tqqRlzwV8Tp06ctXQUhaqy2/q54OdlwKTOPbfFp9Gvmbekqcfh8BlevF+Cg19KuPIP1OdeFlk/gfOArZmUXMsfDg3NZZ5kdOY65RY44pN7QIpJyqPjH0QeaD4Xgv4DRAKnHKUJlTVbxhdfABgOx0UqvLyGEsIS2Aa4kpmaz53TNS+iNRpVt8cU3nXsGF/cg61X/L/RtOBBNxjn47ePihDxuU/EbFA34tob63Yq709vd4XvNpxX8PBMyz8MP46HLKxD8F/ad2YoRCHAKwsu+Zp0TUTUkoRf3RFEUvJxtOHs1h/xCA/mFxYO5udjJPOVC3E80GoWHW3rz5a+n2Xg4hb5NvSz+3PjWE8UtH10auaPTlvPxntCREP8TvlmXedPBhznXztC6UMXeYAMaLQR1g8b9i7sxHv8BslJg1xLY+yV4twDVyA4nV1ILM3DWOdMnoE8VHqEQQojbaR9Yh2/3nGV/8jWKDEa0NehRz8PnM7iSXYC93op2gcXJubXVH9fHLv4w4B0481txS71nCAR0BluX8u/ArQEM+RSi5sOZnbBtIZzbQ8z1ZNAohAXJ99ODouZEvai1rP94nv7GZXu93CsS4n7TO8QLnVZD0uXrHE/JtGhd8goN7Dp5Gfiz5aNcdHbQofhxIM+UI8zN0xGhr4fS7nkY+T30fgP8w6DDi8XLD00uvvAqzIGzMcWt89bFNy6ldV4IISyrkacDzrbW5BYYiE3JsnR1zET/0d3+oUYeZd90VhQI7Axdp0DwwxVL5kvoHaHv29D+BVAUck9u4bCmCKxt6ODf/Z7qL2oPSehFpbDXa6ljXzwYSR17ncVb7oQQlc/RxprujYun8Npo4Snsdp+6Ql6hES8nG4K9KzjyeKM+EDIQ/MJw6jMH5clV0DaCQr0jnxz8hNMZp4vLWdtA00EwfGVxS0q99uywsydVay2t80IIUQNoNApt/hgcb++ZqxauzZ/yCg38lngFgO5/TH1ZYuOpjUyKnsRPST9Vzs40GmjzNAx4lwM2OooAH4e61HOo4CwuotaSZlRRadwc9LjY6bDSSDIvxP3q4ZY+/Hz8EjtPXuHq9QLTjbzqVjIYXo9gj4rfQFSU4haRUr5P+J7t57az/9J+hjQagre9N552nnjYemDjHwb+YbimHqTe8X/T3a+7tM4LIUQN0C7AlagTqew+dYVRnQJrRKNSTNJVcgsNeDnpCfE2n24spyiH89nnOZVxqnJ36teei22ehrj/0r7RoBpxHkT1kIRe
VCpJ5oW4v9X3cCDEx5HYlCx+OnqRJ8P8q70O164XcCD5GgDdm1Sgu/0dDGowiONXjhN/LZ5/Hf+X2ToXvQuf9PqEUM9QWnq0xKgaK22/Qggh7l67QFf0Wg0X0vM4npJJM98KTEVaRaL+GOOlWxNPNKWujQOdAgE4nXm60vf7WPNncHTypa1X20rftqi5pMu9EEKICnmkpS8Am46mUGio/sR2e0IaRhWaeDtS18W20rZrb23PjLAZPN74ccK8wwh0CsROaweARtGg1Whv+r8QQgjLstNpeahRcbf2zccuWbg2kJ7z503nssZ4KUnoz2edp9BYWCn7LJm6W1EU+gT0oY5NOWZ+EfcNuSIRVWrUqFGkp6ezbt06ALp3705oaCgffPBBtdYjOjqaHj16cO3aNVxcXKp130Lcb8IbuOFiZ016TiG7Tl6ha2OPO7+pEpWMbt+jElvnS9hqbRnaeKjZa9kF2WQV1KzBloQQQvypX3Mvfom9xM7Ey4zuWh8HCw7OvC2++KZzIy+HMm86u9u6Y29tz/XC65zPOk+gc+C97e/sNn678BsT2kzAztrunrYlaidpoX8AjRo1CkVRUBQFnU5Hw4YNeeuttygqKqryfa9Zs4Y5c+aUq2x0dDSKopCenl61lfrDoUOHGDRoEJ6entjY2BAYGMgTTzxBampqtexfiNrC2krDgOY+QPUPjpd8JYdTadex0ig81Ni9WvbpoHPAx8GnWvYlhBCi4pp4OeLvZkdBkdE0uryllIzxcqsZWBRFwd+x+HG1M5ln7mlfcVfj+OzIZxxMO0jU2ah72paovSShf0D179+flJQUEhISmDx5MrNnz+bdd98ts2xBQUGl7bdOnTo4OlZwROpqkJaWRq9evahTpw6bN28mNjaW5cuX4+vry/Xr1y1dPSFqnH7NvNBoFI6nZJKYml1t+43640KtXYArTjbW1bZfIYQQNZeiKPRr5g3Az8cumbqgV7ezV3NITM1Go1F4qOGte68FOQcB9/YcfVpOGu/tfY8iYxEdvDswIGjAXW9L1G6S0FeBvKK8W/4UGgrLXbbAUFCusndDr9fj7e1NQEAAL730Er1792b9+vVAcQv+o48+yty5c/H19aVJkyYAnD17lscffxwXFxfq1KnD4MGDOX36tGmbBoOBSZMm4eLigpubG6+++upNf1C7d+/OxIkTTcv5+flMmzYNPz8/9Ho9DRs25IsvvuD06dP06NEDAFdXVxRFYdSoUQAYjUbmz59PUFAQtra2tGrViu+//95sPz/++CONGzfG1taWHj16mNWzLDt37iQjI4PPP/+c1q1bExQURI8ePVi8eDFBQUGmckePHmXAgAE4ODjg5eXF008/zeXLl03rr1+/zjPPPIODgwM+Pj4sWrTopmNWFMX0CEIJFxcXVqxYYVq+07ku+Yzee+89fHx8cHNzY9y4cRQW/hlf+fn5vPbaa/j7+5ud2/IeixC34+agp3MDNwAWbDrBlez8Kt+n0aiaWl4qNPe8EEKI+16PJh5YWykkXb5erTeab1TyHdXW3xVnu1vfdA5yDsLf0R9n/a0H8DuUdoh1ievYc3EPGfkZZutyi3J5Z887ZBZkEuAUwLjQcWgUSeseVPIMfRWI+Cnilutae7ZmeofppuXRkaPJN5R9IRxSJ4TZnWablsdvHV/mc5zfPvLt3Vf2D7a2tly5csW0vGXLFpycnIiMjASgsLCQfv36ER4ezo4dO9Bqtbz99tv079+fw4cPo9PpWLRoEStWrODLL78kJCSERYsWsXbtWnr27HnL/T7zzDPs2rWLjz76iFatWpGUlMTly5fx8/Nj9erVDB06lLi4OJycnLC1LX4Oaf78+Xz11VcsW7aMRo0asX37dp566ik8PDzo1q0bZ8+eZciQIYwbN47Ro0ezd+9eJk+efNvj9/b2pqioiLVr1zJs2LAyp/pIT0+nZ8+evPDCCyxevJjc3FymTZvG448/ztatWwGYOnUq27Zt44cf
fsDT05MZM2awf/9+QkNDy/1ZlOdcA0RFReHj40NUVBSJiYk88cQThIaG8uKLLwIQERHBrl27+PDDDwkNDTWd2/IeixB38sJD9Ym/lM2lzDxm/XCMeUNa4Gxbda3mRy9kcDm7ADudFe0CZcAfIYQQf3K0saZzQ3ei49LYfOwijbyqt0eo0agSdcOUqrfTtV5Xutbresv1xy4fY0HMAowUDzw7qe0kwnzCADiZfpJ/Hf8XyVnJOOucebX9qzKN6gNOEvoHnKqqbNmyhc2bN/Pyyy+bXre3t+fzzz83JY9fffUVRqORzz//3JTsLl++HBcXF6Kjo+nbty8ffPABr732GkOGDAFg2bJlbN68+Zb7jo+PZ9WqVURGRtK7d28A6tevb1pfp07xBbunp6dpILv8/HzmzZvHL7/8Qnh4uOk9v/76K59++indunVj6dKlNGjQgEWLFgHQpEkTjhw5wsKFC29Zl44dOzJjxgyefPJJxowZQ4cOHejZsyfPPPMMXl5eACxZsoTWrVszb9480/u+/PJL/Pz8iI+Px9fXly+++IKvvvqKXr16AbBy5Urq1at3p4/BzLfffnvHcw3FPReWLFmClZUVwcHBPPzww2zZsoUXX3zRdG43bdpEv379UBTF7Nze6VgaN25coTqLB1Mdex1zH2vOtNWHSb6aw+z1x5j7WHPsdFXz1RJ1ovhC6aFG7ui00hIhhBDCXL9m3kTHpbE9/jLPd6mPrc6qUrdvNKqcu5aLo40WZ1trsynpjqdkkpaVj63Oig5Bd3/TOS0njcX7F2PESH3n+qiopmfuAY5ePsqJqyew1lgzpf0U3G2rZzwZUXNJQl8FVvZfect1Vor5H5bP+nx2y7Klu84s6bnk3ip2gw0bNuDg4EBhYSFGo5Enn3yS2bNnm9a3aNHClMxD8YBxiYmJNz3/npeXx8mTJ8nIyCAlJYWwsDDTOq1WS7t27W75HNPBgwexsrKiW7du5a53YmIiOTk59OnTx+z1goICWrduDUBsbKxZPQBT8n87c+fOZdKkSWzdupWYmBiWLVvGvHnz2L59Oy1atODQoUNERUXh4OBw03tPnjxJbm4uBQUFZvuuU6eO6ZGF8rrTuS7RrFkzrKz+jCcfHx+OHDkC/Hluu3Yt++7vnY5FEnpRXl5ONswZ3Jzpaw6TmJrNnA3HeWNgM2ysK/ciKq/QwM7E4h4mlTn3vBBCiPtHM18nfF1suJCex/aENNNz9ZVBVVWWbjvJT0cvAqDRKLjb63Bz0OHuoOdiRvFjsJ0buKPXlu870GA0UGgsNLWwFxgKeG/ve2QVZFHfuT5vdnoTnZXO7D0BTgH09u9Ne+/2NHaV6zUhCX2VqEi3l6oqeyc9evRg6dKl6HQ6fH190WrNQ8He3t5sOTs7m7Zt2/L111/ftC0Pj7ubsqqkC31FZGcXPxO1ceNG6tata7ZOr9ffVT1u5ObmxvDhwxk+fDjz5s2jdevWvPfee6xcuZLs7GwGDhxYZku/j48PiYmJ5dqHoig33eS48dn38p5ra2vzrs2KomA0FnfNutO5vdOxCFERfnXseHNQc2asPcLR85ks2HSC1x8Owdqq8lrRf0+6Sm6hAS8nPU19nCptu0IIIe4fiqLQt6k3K347zeZjFys1od987OKfybxS3FqfmpVPalY+8OcjseUd4+W7+O/4IfEHHm34KMMaD0NVVT47/BmnM0/jpHNicrvJNyXzAKGeoYR6hlbGIYn7hCT0Dyh7e3saNmxY7vJt2rTh22+/xdPTEyensi+mfXx8iImJMbUKFxUVsW/fPtq0aVNm+RYtWmA0Gtm2bZupy/2NSnoIGAwG02tNmzZFr9eTnJx8y5b9kJAQ0wB/JXbv3n3ngyxj/w0aNDCNct+mTRtWr15NYGDgTTdAABo0aIC1tTUxMTH4+xd3jbp27Rrx8fFmdfXw8CAl5c+pvhISEsjJyTEtl+dc30nJud2+fTv9+vW7af2djkWIimro6cDsgc2Y
9cNR9p25xns/x/Fqv2CsNDePR1FRRQYjPx8vvojq1sTTrIujEEIIcaNeIZ78e/cZEi5lk3T5OkHu9nd+0x3EpmSybNspACI6BfJY67pcvV7A5ex8rmQXcOV6PmlZ+Xg46mlet3zXbnZaOwqNhaap64qMReQU5aBBw8S2E6UrvSg3eQhRlMvIkSNxd3dn8ODB7Nixg6SkJKKjo/nb3/7GuXPnAJgwYQILFixg3bp1nDhxgrFjx952DvnAwEAiIiJ47rnnWLdunWmbq1atAiAgIABFUdiwYQNpaWlkZ2fj6OjIlClTeOWVV1i5ciUnT55k//79fPzxx6xcWfyow5gxY0hISGDq1KnExcXxn//8x2wE+bJs2LCBp556ig0bNhAfH09cXBzvvfceP/74I4MHDwZg3LhxXL16lREjRrBnzx5OnjzJ5s2befbZZzEYDDg4OPD8888zdepUtm7dytGjRxk1ahQajfmvWc+ePVmyZAkHDhxg7969jBkzxqy1vTzn+k5Kzu3o0aPLPLd3OhYh7kZTXydmPByC1krht8QrfLw1AaPx3qYOOng2nQnfHOTQ2QwUpXgUYyGEEOJWXOx0hNUvfoZ987GL97y9K9n5zPsxFoNRpVNDN4a2qYuVRsHDUU+IjxNdGrkzOLQuLzxUn8GhdcscWLksgU6BwJ9T11lbWTOl3RTe7Pwmzdya3XO9xYNDEnpRLnZ2dmzfvh1/f3+GDBlCSEgIzz//PHl5eaZW5MmTJ/P0008TERFBeHg4jo6OPPbYY7fd7tKlSxk2bBhjx44lODiYF1980dQiXrduXd58802mT5+Ol5cX48ePB2DOnDnMnDmT+fPnExISQv/+/dm4caNpejl/f39Wr17NunXraNWqlelZ+Ntp2rQpdnZ2TJ48mdDQUDp27MiqVav4/PPPefrppwHw9fVl586dGAwG+vbtS4sWLZg4cSIuLi6mpP3dd9/loYceYuDAgfTu3ZsuXbrQtm1bs30tWrQIPz8/HnroIZ588kmmTJmCnZ1dhc51efzjH/8wjfZf+tyW51iEuBtt/F2Z2q8JGgW2xKYy98dYcgsqfpPoUmYe83+MZea6oyRfzcHRRssrvRtTz9Xuzm8WQgjxQCvpah8dl0p+0c3fQXEXs5i57ijjvt7P70lXb7mdQoORBZtOkJ5TiH8dOyb2alzuhP1OApwCAEjNSSWnsLinpkbRyHPxosIU9VYjlgkAMjMzcXZ2JiMj46ZkKi8vj6SkJIKCgrCxsex0EaqqUlRUhFarrbQ/NKJydO/endDQUD744INq3W9lx0RNindxd4xGI6mpqXh6elb5jZvfEi/z3s9xFBpU6nvYM+uRprg53Hmci7xCA6v3n2P1vnMUGlQ0CvylhQ9PhvnjaFN1U+I9yKozLkTtIDEhSqttMWE0qoz+914uZeYzqU9jevzxXPvZqzl8tfsMv528Yla+Y/06vNi1Pp6O5tc3n0Ql8tPRi9jrrXj/8VB8XSo+/tPtvPDzC2QVZOFp68mi7ovKfGa+pqptMVEb3S4PvZE8PCuEEKLSdWrozjwHHW9viOVU2nUmf3eIWY80pb7HzTMrQHEiH3UilVV7z3I5uwCAFvWc+b+u9Qlwu/fnH4UQQjw4NBqFPk29+Gp3MpuPXaRFPWf+G5PML7GXMKrFg9r1CPbE0caa9YcusPvUVQ6eTWdEB38GtfJFa6Xhp6PFg+ApCkzp26TSk3kAH3sfsgqySM1NJbMgU56bF3dFEnohhBBVItjbiUWPt2L2+mOcu5bL9NVHmDagCW0D/pyf90J6Lj8eSSHy+CVy/uia7+Go54UuQYQ3cJMeR0IIIe5KrxAv/hOTzLELmYz+114KDcWdkjsE1SEiPBB/t+JHuHqHePKPqJMcT8lk+c7TbDmRyiMtfPhsR/EgeE+FBdAu8O7nlb+doY2G8l38dzzd9GlJ5sVdk4ReiCoWHR1t6SoIYTFeTja8M6wl8zed4Mi5DN7633H+r1sD
vJz0/O9QCvuTr1Hy4JePsw0Pt/ShXzPvSp/HXgghxIPF3UFPu8A6/J50lUKDSlMfJyI6BdLU17zrcoCbPfOHtGDLiVSW70wi+UoO/4g+CUB4AzeGta1XZXWUKehEZag1Cf3cuXPZuHEjBw8eRKfT3Xb09BKjRo0yjXxeol+/fvz0009VVEshhBClOdpY8+agZnwSlciW2FSW/nGhVKJtgCuPtPShjb+rTEknhBCi0rz4UH1c7awJq+9GuwDXW/b6KumiH1a/Dit2niby+CX83ex4pXdj+V4SNV6tSegLCgoYPnw44eHhfPHFF+V+X//+/Vm+fLlpWa+/86BMFSXjCooHgcS5uBfWVhom9GqEr7Mt/959BludFX1CvPhLSx/qVsFziUIIIYS3sw3jezYqd3knG2v+1qsRIzr442SrRa+V3mKi5qs1Cf2bb74JcMf5xEvT6/V4e3tXQY0wzR2ek5ODra1ckIr7W05O8ZQqJXEvREUpisLj7f3o3sQDRxtrbHVyoSSEEKLm8XCs/AZAIapKrUno71Z0dDSenp64urrSs2dP3n77bdzc3G5ZPj8/n/z8fNNyZmYmUDw1g9FoNCurKArOzs6kpqaiqip2dnYWHcCpsLBQki1hpjJiQlVVcnJySEtLw9nZGUVRbvpdELWD0WhEVVWLf37uDjpTfYTl1ZS4EDWHxIQoTWJClCYxUfXKe27v64S+f//+DBkyhKCgIE6ePMmMGTMYMGAAu3btwsqq7Jah+fPnm3oD3CgtLY28vLybXlcUBa1WS0pKisVHYzYajTIPpDBTWTGhqio2NjYoikJqamol1ExYgtFoJCMjA1VV5W+FMJG4EKVJTIjSJCZEaRITVS8rK6tc5RTVgg/GTp8+nYULF962TGxsLMHBwablFStWMHHixHINilfaqVOnaNCgAb/88gu9evUqs0xZLfR+fn5cu3YNJyenMt8DYDAYKCwsrHCdKovRaOTq1avUqVNHfqkEULkxYW1tfcubYKL2MBqNpKWl4eHhIX8nhInEhShNYkKUJjEhSpOYqHqZmZm4urqSkZFx2zzUoi30kydPZtSoUbctU79+/UrbX/369XF3dycxMfGWCb1ery9z4DyNRnPbYNVoNBbt7m40GsnOzsbOzk5+qQQgMSHKpijKHf+eiQePxIUoTWJClCYxIUqTmKha5T2vFk3oPTw88PDwqLb9nTt3jitXruDj41Nt+xRCCCGEEEIIIapCrbmdkpyczMGDB0lOTsZgMHDw4EEOHjxIdna2qUxwcDBr164FIDs7m6lTp7J7925Onz7Nli1bGDx4MA0bNqRfv36WOgwhhBBCCCGEEKJS1JpB8WbNmsXKlStNy61btwYgKiqK7t27AxAXF0dGRgYAVlZWHD58mJUrV5Keno6vry99+/Zlzpw5VTIXvRBCCCGEEEIIUZ0sOihebZCRkYGLiwtnz5697WAEliYDU4jSJCZEaRIToiwSF6I0iQlRmsSEKE1iouqVDM6enp6Os7PzLcvVmhZ6SymZLsDPz8/CNRFCCCGEEEII8SDJysq6bUIvLfR3YDQauXDhAo6OjhafZ/52Su7g1PSeBKL6SEyI0iQmRFkkLkRpEhOiNIkJUZrERNVTVZWsrCx8fX1v2wtCWujvQKPRUK9ePUtXo9ycnJzkl0qYkZgQpUlMiLJIXIjSJCZEaRITojSJiap1u5b5EvLAgxBCCCGEEEIIUQtJQi+EEEIIIYQQQtRCktDfJ/R6PW+88YZMySdMJCZEaRIToiwSF6I0iQlRmsSEKE1iouaQQfGEEEIIIYQQQohaSFrohRBCCCGEEEKIWkgSeiGEEEIIIYQQohaShF4IIYQQQgghhKiFJKEXQgghhBBCCCFqIUno7xOffPIJgYGB2NjYEBYWxu+//27pKolqMn/+fNq3b4+joyOenp48+uijxMXFmZXJy8tj3LhxuLm54eDgwNChQ7l06ZKFaiyq04IFC1AUhYkTJ5pek3h4
MJ0/f56nnnoKNzc3bG1tadGiBXv37jWtV1WVWbNm4ePjg62tLb179yYhIcGCNRZVyWAwMHPmTIKCgrC1taVBgwbMmTOHG8dKlpi4v23fvp2BAwfi6+uLoiisW7fObH15Pv+rV68ycuRInJyccHFx4fnnnyc7O7saj0JUptvFRGFhIdOmTaNFixbY29vj6+vLM888w4ULF8y2ITFR/SShvw98++23TJo0iTfeeIP9+/fTqlUr+vXrR2pqqqWrJqrBtm3bGDduHLt37yYyMpLCwkL69u3L9evXTWVeeeUV/ve///Hdd9+xbds2Lly4wJAhQyxYa1Ed9uzZw6effkrLli3NXpd4ePBcu3aNzp07Y21tzaZNmzh+/DiLFi3C1dXVVOadd97ho48+YtmyZcTExGBvb0+/fv3Iy8uzYM1FVVm4cCFLly5lyZIlxMbGsnDhQt555x0+/vhjUxmJifvb9evXadWqFZ988kmZ68vz+Y8cOZJjx44RGRnJhg0b2L59O6NHj66uQxCV7HYxkZOTw/79+5k5cyb79+9nzZo1xMXFMWjQILNyEhMWoIpar0OHDuq4ceNMywaDQfX19VXnz59vwVoJS0lNTVUBddu2baqqqmp6erpqbW2tfvfdd6YysbGxKqDu2rXLUtUUVSwrK0tt1KiRGhkZqXbr1k2dMGGCqqoSDw+qadOmqV26dLnleqPRqHp7e6vvvvuu6bX09HRVr9er//3vf6ujiqKaPfzww+pzzz1n9tqQIUPUkSNHqqoqMfGgAdS1a9ealsvz+R8/flwF1D179pjKbNq0SVUURT1//ny11V1UjdIxUZbff/9dBdQzZ86oqioxYSnSQl/LFRQUsG/fPnr37m16TaPR0Lt3b3bt2mXBmglLycjIAKBOnToA7Nu3j8LCQrMYCQ4Oxt/fX2LkPjZu3Dgefvhhs88dJB4eVOvXr6ddu3YMHz4cT09PWrduzT//+U/T+qSkJC5evGgWF87OzoSFhUlc3Kc6derEli1biI+PB+DQoUP8+uuvDBgwAJCYeNCV5/PftWsXLi4utGvXzlSmd+/eaDQaYmJiqr3OovplZGSgKAouLi6AxISlaC1dAXFvLl++jMFgwMvLy+x1Ly8vTpw4YaFaCUsxGo1MnDiRzp0707x5cwAuXryITqcz/bEt4eXlxcWLFy1QS1HVvvnmG/bv38+ePXtuWifx8GA6deoUS5cuZdKkScyYMYM9e/bwt7/9DZ1OR0REhOmzL+u7ROLi/jR9+nQyMzMJDg7GysoKg8HA3LlzGTlyJIDExAOuPJ//xYsX8fT0NFuv1WqpU6eOxMgDIC8vj2nTpjFixAicnJwAiQlLkYReiPvIuHHjOHr0KL/++qulqyIs5OzZs0yYMIHIyEhsbGwsXR1RQxiNRtq1a8e8efMAaN26NUePHmXZsmVERERYuHbCElatWsXXX3/Nf/7zH5o1a8bBgweZOHEivr6+EhNCiNsqLCzk8ccfR1VVli5daunqPPCky30t5+7ujpWV1U0jVF+6dAlvb28L1UpYwvjx49mwYQNRUVHUq1fP9Lq3tzcFBQWkp6eblZcYuT/t27eP1NRU2rRpg1arRavVsm3bNj766CO0Wi1eXl4SDw8gHx8fmjZtavZaSEgIycnJAKbPXr5LHhxTp05l+vTp/PWvf6VFixY8/fTTvPLKK8yfPx+QmHjQlefz9/b2vmkA5qKiIq5evSoxch8rSebPnDlDZGSkqXUeJCYsRRL6Wk6n09G2bVu2bNlies1oNLJlyxbCw8MtWDNRXVRVZfz48axdu5atW7cSFBRktr5t27ZYW1ubxUhcXBzJyckSI/ehXr16ceTIEQ4ePGj6adeuHSNHjjT9X+LhwdO5c+ebprOMj48nICAAgKCgILy9vc3iIjMzk5iYGImL+1ROTg4ajflloJWVFUajEZCYeNCV5/MPDw8nPT2dffv2mcps3boVo9FIWFhYtddZVL2SZD4hIYFffvkF
Nzc3s/USExZi6VH5xL375ptvVL1er65YsUI9fvy4Onr0aNXFxUW9ePGipasmqsFLL72kOjs7q9HR0WpKSorpJycnx1RmzJgxqr+/v7p161Z17969anh4uBoeHm7BWovqdOMo96oq8fAg+v3331WtVqvOnTtXTUhIUL/++mvVzs5O/eqrr0xlFixYoLq4uKg//PCDevjwYXXw4MFqUFCQmpuba8Gai6oSERGh1q1bV92wYYOalJSkrlmzRnV3d1dfffVVUxmJiftbVlaWeuDAAfXAgQMqoL7//vvqgQMHTCOWl+fz79+/v9q6dWs1JiZG/fXXX9VGjRqpI0aMsNQhiXt0u5goKChQBw0apNarV089ePCg2TVnfn6+aRsSE9VPEvr7xMcff6z6+/urOp1O7dChg7p7925LV0lUE6DMn+XLl5vK5ObmqmPHjlVdXV1VOzs79bHHHlNTUlIsV2lRrUon9BIPD6b//e9/avPmzVW9Xq8GBwern332mdl6o9Gozpw5U/Xy8lL1er3aq1cvNS4uzkK1FVUtMzNTnTBhgurv76/a2Nio9evXV19//XWzC3OJiftbVFRUmdcPERERqqqW7/O/cuWKOmLECNXBwUF1cnJSn332WTUrK8sCRyMqw+1iIikp6ZbXnFFRUaZtSExUP0VVVbX6+gMIIYQQQgghhBCiMsgz9EIIIYQQQgghRC0kCb0QQgghhBBCCFELSUIvhBBCCCGEEELUQpLQCyGEEEIIIYQQtZAk9EIIIYQQQgghRC0kCb0QQgghhBBCCFELSUIvhBBCCCGEEELUQpLQCyGEEEIIIYQQtZAk9EIIIYS4740aNYpHH330tmWio6NRFIX09PRqqZMQQghxryShF0IIIapBWloaL730Ev7+/uj1ery9venXrx87d+60dNVqDEVRTD/Ozs507tyZrVu3Vsq2P/zwQ1asWGFa7t69OxMnTjQr06lTJ1JSUnB2dq6UfQohhBBVTRJ6IYQQohoMHTqUAwcOsHLlSuLj41m/fj3du3fnypUrlq5ajbJ8+XJSUlLYuXMn7u7uPPLII5w6deqet+vs7IyLi8tty+h0Ory9vVEU5Z73J4QQQlQHSeiFEEKIKpaens6OHTtYuHAhPXr0ICAggA4dOvDaa68xaNAgs3IvvPACHh4eODk50bNnTw4dOmS2rQULFuDl5YWjoyPPP/8806dPJzQ01LS+rJbnRx99lFGjRpmW8/PzmTJlCnXr1sXe3p6wsDCio6NN61esWIGLiwubN28mJCQEBwcH+vfvT0pKitl2v/zyS5o1a4Zer8fHx4fx48dX6FjK4uLigre3N82bN2fp0qXk5uYSGRkJwLZt2+jQoYNpf9OnT6eoqMj03u+//54WLVpga2uLm5sbvXv35vr164B5l/tRo0axbds2PvzwQ1OPgNOnT5fZ5X716tWmYwwMDGTRokVm9Q0MDGTevHk899xzODo64u/vz2effXbH4xRCCCEqgyT0QgghRBVzcHDAwcGBdevWkZ+ff8tyw4cPJzU1lU2bNrFv3z7atGlDr169uHr1KgCrVq1i9uzZzJs3j7179+Lj48M//vGPCtdn/Pjx7Nq1i2+++YbDhw8zfPhw+vfvT0JCgqlMTk4O7733Hv/+97/Zvn07ycnJTJkyxbR+6dKljBs3jtGjR3PkyBHWr19Pw4YNy30s5WFrawtAQUEB58+f5y9/+Qvt27fn0KFDLF26lC+++IK3334bgJSUFEaMGMFzzz1HbGws0dHRDBkyBFVVb9ruhx9+SHh4OC+++CIpKSmkpKTg5+d3U7l9+/bx+OOP89e//pUjR44we/ZsZs6cadZ1H2DRokW0a9eOAwcOMHbsWF566SXi4uLKfZxCCCHEXVOFEEIIUeW+//571dXVVbWxsVE7deqkvvbaa+qhQ4dM63fs2KE6OTmpeXl5Zu9r0KCB+umnn6qqqqrh4eHq2LFjzdaHhYWprVq1Mi1369ZNnTBhglmZ
wYMHqxEREaqqquqZM2dUKysr9fz582ZlevXqpb722muqqqrq8uXLVUBNTEw0rf/kk09ULy8v07Kvr6/6+uuvl3ms5TmWsgDq2rVrVVVV1evXr6tjx45Vrays1EOHDqkzZsxQmzRpohqNRrM6OTg4qAaDQd23b58KqKdPny5z2xEREergwYNNy2Wdp6ioKBVQr127pqqqqj755JNqnz59zMpMnTpVbdq0qWk5ICBAfeqpp0zLRqNR9fT0VJcuXXrL4xRCCCEqi7TQCyGEENVg6NChXLhwgfXr19O/f3+io6Np06aNqbX30KFDZGdn4+bmZmrRd3BwICkpiZMnTwIQGxtLWFiY2XbDw8MrVI8jR45gMBho3Lix2X62bdtm2g+AnZ0dDRo0MC37+PiQmpoKQGpqKhcuXKBXr15l7qM8x3IrI0aMwMHBAUdHR1avXs0XX3xBy5YtiY2NJTw83Oz59s6dO5Odnc25c+do1aoVvXr1okWLFgwfPpx//vOfXLt2rULnprTY2Fg6d+5s9lrnzp1JSEjAYDCYXmvZsqXp/4qi4O3tbTpXQgghRFXSWroCQgghxIPCxsaGPn360KdPH2bOnMkLL7zAG2+8wahRo8jOzsbHx8fsWfYSdxrM7UYajeambuaFhYWm/2dnZ2NlZcW+ffuwsrIyK+fg4GD6v7W1tdk6RVFM2y3pCn8r93Isixcvpnfv3jg7O+Ph4XHbsjeysrIiMjKS3377jZ9//pmPP/6Y119/nZiYGIKCgsq9nbtR1rkyGo1Vuk8hhBAC5Bl6IYQQwmKaNm1qGrStTZs2XLx4Ea1WS8OGDc1+3N3dAQgJCSEmJsZsG7t37zZb9vDwMBu8zmAwcPToUdNy69atMRgMpKam3rQfb2/vctXb0dGRwMBAtmzZUub68hzLrXh7e9OwYcObkvmQkBB27dpldrNi586dODo6Uq9ePaA4ke7cuTNvvvkmBw4cQKfTsXbt2jL3o9PpzFrZyxISEnLTtII7d+6kcePGN90MEUIIISxBEnohhBCiil25coWePXvy1VdfcfjwYZKSkvjuu+945513GDx4MAC9e/cmPDycRx99lJ9//pnTp0/z22+/8frrr7N3714AJkyYwJdffsny5cuJj4/njTfe4NixY2b76tmzJxs3bmTjxo2cOHGCl156yWzU9saNGzNy5EieeeYZ1qxZQ1JSEr///jvz589n48aN5T6m2bNns2jRIj766CMSEhLYv38/H3/8cbmPpaLGjh3L2bNnefnllzlx4gQ//PADb7zxBpMmTUKj0RATE2MaLDA5OZk1a9aQlpZGSEhImdsLDAwkJiaG06dPc/ny5TJb1CdPnsyWLVuYM2cO8fHxrFy5kiVLlpgNDiiEEEJYknS5F0IIIaqYg4MDYWFhLF68mJMnT1JYWIifnx8vvvgiM2bMAIpbl3/88Udef/11nn32WdLS0vD29qZr1654eXkB8MQTT3Dy5EleffVV8vLyGDp0KC+99BKbN2827eu5557j0KFDPPPMM2i1Wl555RV69OhhVp/ly5fz9ttvM3nyZM6fP4+7uzsdO3bkkUceKfcxRUREkJeXx+LFi5kyZQru7u4MGzas3MdSUXXr1uXHH39k6tSptGrVijp16vD888/z97//HQAnJye2b9/OBx98QGZmJgEBASxatIgBAwaUub0pU6YQERFB06ZNyc3NJSkp6aYybdq0YdWqVcyaNYs5c+bg4+PDW2+9ZTYFoBBCCGFJilr6QTshhBBC1BqzZ89m3bp1HDx40NJVEUIIIUQ1ky73QgghhBBCCCFELSQJvRBCCCGEEEIIUQtJl3shhBBCCCGEEKIWkhZ6IYQQQgghhBCiFpKEXgghhBBCCCGEqIUkoRdCCCGEEEIIIWohSeiFEEIIIYQQQohaSBJ6IYQQQgghhBCiFpKEXgghhBBCCCGEqIUkoRdCCCGEEEIIIWohSeiFEEIIIYQQQoha6P8BA85JzG20fGIA
AAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "x_test, y_test = next(dataset(1, 1))\n", "y_pred = model(x_test)\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.subplot(2, 1, 1)\n", "plt.plot(x_test[0], label='Input Sequence', alpha=0.8)\n", "plt.plot(y_test[0], label='Target Sequence', alpha=0.8)\n", "plt.plot(y_pred[0], label='Predicted Sequence', linestyle='--', alpha=0.8)\n", "plt.xlabel('Sequence Position')\n", "plt.ylabel('Value')\n", "plt.title('Example Sequence Prediction')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)" ] } ], "metadata": { "accelerator": "TPU", "colab": { "gpuType": "V28", "provenance": [] }, "kernelspec": { "display_name": "jax-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 0 }