{ "cells": [ { "cell_type": "markdown", "id": "17be077e", "metadata": { "id": "17be077e" }, "source": [ "# Tutorial 3: Tensor Parallel and Transformers Scaling\n", "\n", "[View on GitHub](https://github.com/sshkhr/MinText/blob/main/docs/tutorials/3_Tensor_Parallel_and_Transformers.ipynb)\n", "[Open in Colab](https://colab.research.google.com/github/sshkhr/MinText/blob/main/docs/tutorials/3_Tensor_Parallel_and_Transformers.ipynb)" ] }, { "cell_type": "markdown", "id": "47d7d9d3", "metadata": { "id": "47d7d9d3" }, "source": [ "In the previous tutorial, we learned about data parallelism and how to use it to shard data batches across devices. We also learned about Fully Sharded Data Parallel (FSDP) and how it can be used to shard model parameters, gradients and optimizer states across devices. In this part, we will cover tensor parallelism and how it can be used to shard model layers across devices. We will also learn how to use the different parallelism techniques together to scale up training of an actual transformer model." ] }, { "cell_type": "markdown", "id": "992ec8d3", "metadata": { "id": "992ec8d3" }, "source": [ "## 0. Setup\n", "\n", "Let's start by importing the necessary libraries and initializing our environment." 
] }, { "cell_type": "code", "execution_count": 1, "id": "bd37fc51", "metadata": { "id": "bd37fc51" }, "outputs": [], "source": [ "import os\n", "# Uncomment to simulate 8 devices on CPU (only when not using a TPU runtime).\n", "# This must run before JAX is imported.\n", "#os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding\n", "from jax.experimental import mesh_utils\n", "from flax import nnx\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import time\n", "from functools import partial\n", "import dataclasses\n", "\n", "import optax" ] }, { "cell_type": "code", "execution_count": 2, "id": "ea947ba6", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ea947ba6", "outputId": "3b056705-36ea-4bb3-ed1a-a7770c36ec3d" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "JAX version: 0.5.2\n", "Available devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)]...\n", "Number of devices: 8\n" ] } ], "source": [ "# Check available devices\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"Available devices: {jax.devices()[:4]}...\")\n", "print(f\"Number of devices: {jax.device_count()}\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "nm_GcLiSWxbx", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nm_GcLiSWxbx", "outputId": "7967fad5-fecf-4c5d-d817-7a289a663cd8" }, "outputs": [], "source": [ "# Requirements for Language Modelling\n", "!pip install -Uq tiktoken grain matplotlib" ] }, { "cell_type": "markdown", "id": "425d3061", "metadata": { "id": "425d3061" }, "source": [ "## 1. Tensor Parallelism\n", "\n", "Fully-sharded data parallelism shards model weights across devices and AllGathers each layer's weights just before they are used. Tensor parallelism takes a different approach. Also known as \"1D model parallelism\" or Megatron sharding, this technique shards the feedforward (hidden) dimension of each layer's weights and passes activations between devices during computation. Because it communicates activations rather than weights, it can remain efficient at smaller per-device batch sizes than data parallelism or FSDP, which makes it particularly useful for training very large models. 
The diagram below illustrates how a single matrix is partitioned across devices using this approach:" ] }, { "cell_type": "markdown", "id": "7c76edab", "metadata": { "id": "7c76edab" }, "source": [ "### 1.1 Tensor Parallelism Theory\n", "\n", "**Sharding**: Model parameters are sharded along the feedforward (hidden) dimension $F$ across the $Y$ devices of the model axis. Activations are sharded along the model dimension $D$ between layers and gathered across $Y$ before each layer is computed.\n", "\n", "**Equation** (for our MLP example):\n", "$$\\text{In}[B, D_Y] \\cdot_D W_\\text{in}[D, F_Y] \\cdot_F W_\\text{out}[F_Y, D] \\rightarrow \\text{Out}[B, D_Y]$$\n", "\n", "A subscript $Y$ on a dimension (e.g. $D_Y$, $F_Y$) means that dimension is sharded across the $Y$ mesh axis. $\\cdot_D$ denotes a matrix multiplication that contracts over dimension $D$." ] }, { "cell_type": "markdown", "id": "01006675", "metadata": { "id": "01006675" }, "source": [ "" ] }, { "cell_type": "markdown", "id": "2f540e70", "metadata": { "id": "2f540e70" }, "source": [ "Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book)" ] }, { "cell_type": "markdown", "id": "6a24d684", "metadata": { "id": "6a24d684" }, "source": [ "### 1.2 Tensor Parallelism Algorithm\n", "\n", "The computation pattern In[B, D_Y] * W_in[D, F_Y] * W_out[F_Y, D] → Out[B, D_Y] requires gathering the activations before the first matrix multiplication. This approach communicates activations instead of weights. It therefore becomes more efficient than ZeRO sharding when the activations are smaller than the weights.\n", "\n", "In the steps below, {U_Y} marks a result that is *unreduced* over Y. Each device holds a partial sum that still has to be summed across the Y axis.\n", "\n", "**Forward pass:** need to compute Loss[B]\n", "\n", "1. In[B, D] = **AllGather**(In[B, D_Y]) *(on critical path)*\n", "2. Tmp[B, F_Y] = In[B, D] \\*_D W_in[D, F_Y] *(neither operand is sharded along the contracting dimension D, so no communication is needed)*\n", "3. Out[B, D] {U_Y} = Tmp[B, F_Y] \\*_F W_out[F_Y, D]\n", "4. Out[B, D_Y] = **ReduceScatter**(Out[B, D] {U_Y}) *(on critical path)*\n", "5. Loss[B] = ...\n", "\n", "**Backward pass:** need to compute dW_out[F_Y, D] and dW_in[D, F_Y]\n", "\n", "1. dOut[B, D_Y] = ...\n", "2. dOut[B, D] = **AllGather**(dOut[B, D_Y]) *(on critical path)*\n", "3. dW_out[F_Y, D] = Tmp[B, F_Y] \\*_B dOut[B, D]\n", "4. dTmp[B, F_Y] = dOut[B, D] \\*_D W_out[F_Y, D] *(dOut[B, D] can be discarded here)*\n", "5. In[B, D] = **AllGather**(In[B, D_Y]) *(can be skipped by reusing the result of forward step 1)*\n", "6. dW_in[D, F_Y] = dTmp[B, F_Y] \\*_B In[B, D]\n", "7. dIn[B, D] {U_Y} = dTmp[B, F_Y] \\*_F W_in[D, F_Y] *(needed for previous layers)*\n", "8. dIn[B, D_Y] = **ReduceScatter**(dIn[B, D] {U_Y}) *(on critical path)*\n", "\n", "Tensor parallelism fits the two back-to-back matrix multiplications of our MLP particularly well. If each matrix multiplication were sharded on its own, we would need an AllReduce after each one. Here, once the input has been gathered, In[B, D] * W_in[D, F_Y] → Tmp[B, F_Y] produces an output that is already sharded along F. That is exactly the layout Tmp[B, F_Y] * W_out[F_Y, D] → Out[B, D_Y] expects. As a result, we need only a single AllGather on the input at the start and a single ReduceScatter on the output at the end, with no intermediate AllReduce." ] }, { "cell_type": "markdown", "id": "5334a065", "metadata": { "id": "5334a065" }, "source": [ "## 2. Combining Parallelism Techniques\n", "\n", "In this section, we combine data parallelism with tensor parallelism to train a simple MLP model. Batches are sharded along the `data` mesh axis and the MLP weights along the `model` axis. The weights remain replicated along `data`, so this is plain data parallelism rather than FSDP. The efficiency gained by gathering activations before the matrix multiplications in tensor parallelism typically emerges only when combined with some degree of ZeRO sharding, which reduces the gather operation's overhead. 
This synergy explains why ZeRO sharding and model parallelism are commonly used together in practice.\n" ] }, { "cell_type": "markdown", "id": "60a857dc", "metadata": { "id": "60a857dc" }, "source": [ "### 2.1 Mesh Definition" ] }, { "cell_type": "code", "execution_count": 4, "id": "45ed2560", "metadata": { "id": "45ed2560" }, "outputs": [], "source": [ "# Arrange the 8 devices into a 2 x 4 grid and assign the logical names\n", "# 'data' and 'model' to its axes.\n", "# The first dimension (size 2) is named 'data'.\n", "# The second dimension (size 4) is named 'model'.\n", "mesh = jax.sharding.Mesh(\n", " mesh_utils.create_device_mesh((2, 4)),\n", " ('data', 'model'),\n", ")" ] }, { "cell_type": "markdown", "id": "9ce31aa1", "metadata": { "id": "9ce31aa1" }, "source": [ "### 2.2 Sharding Helper Functions" ] }, { "cell_type": "code", "execution_count": 5, "id": "5d266bc6", "metadata": { "id": "5d266bc6" }, "outputs": [], "source": [ "# A helper function to quickly create a NamedSharding object\n", "# using the globally defined 'mesh'.\n", "def named_sharding(*names: str | None) -> NamedSharding:\n", " # P(*names) creates a PartitionSpec, e.g., P('data', None)\n", " # NamedSharding binds this PartitionSpec to the 'mesh'.\n", " return NamedSharding(mesh, P(*names))\n", "\n", "\n", "@dataclasses.dataclass(unsafe_hash=True)\n", "class MeshRules:\n", " \"\"\"Rules for data parallel + tensor parallel sharding (weights are replicated along 'data', no FSDP)\"\"\"\n", " weight_0: str | None = None # First dimension of weights (replicated)\n", " weight_1: str | None = 'model' # Second dimension of weights (tensor parallel)\n", " bias: str | None = 'model' # Bias sharded along model axis\n", " data: str | None = 'data' # Data sharded along data axis\n", "\n", " def __call__(self, *keys: str) -> tuple[str | None, ...]:\n", " return tuple(getattr(self, key) for key in keys)\n", "\n", "mesh_rules = MeshRules()" ] }, { "cell_type": "markdown", "id": "8f6e98b3", "metadata": { "id": "8f6e98b3" }, "source": [ "### 2.3 Define The Sharded Model" ] }, { "cell_type": "code", 
"execution_count": 6, "id": "31970b77", "metadata": { "id": "31970b77" }, "outputs": [], "source": [ "# Modified MLP using nnx.Linear with tensor parallelism\n", "class MLP(nnx.Module):\n", " def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):\n", " # For linear1: (128, 2048) -> shard second dimension\n", " self.linear1 = nnx.Linear(\n", " din, dmid,\n", " kernel_init=nnx.with_metadata(\n", " nnx.initializers.lecun_normal(),\n", " sharding=mesh_rules('weight_0', 'weight_1') # (None, 'model')\n", " ),\n", " bias_init=nnx.with_metadata(\n", " nnx.initializers.zeros_init(),\n", " sharding=mesh_rules('bias') # ('model',)\n", " ),\n", " rngs=rngs\n", " )\n", "\n", " # For linear2: (2048, 128) -> shard first dimension\n", " self.linear2 = nnx.Linear(\n", " dmid, dout,\n", " kernel_init=nnx.with_metadata(\n", " nnx.initializers.lecun_normal(),\n", " sharding=('model', None) # Custom sharding for this layer\n", " ),\n", " bias_init=nnx.with_metadata(\n", " nnx.initializers.zeros_init(),\n", " sharding=(None,) # Don't shard output bias\n", " ),\n", " rngs=rngs\n", " )\n", "\n", " def __call__(self, x):\n", " x = nnx.relu(self.linear1(x))\n", " return self.linear2(x)" ] }, { "cell_type": "markdown", "id": "06d184b8", "metadata": { "id": "06d184b8" }, "source": [ "### 2.4 Handling Sharded Optimizer State" ] }, { "cell_type": "code", "execution_count": 7, "id": "1d9f4258", "metadata": { "id": "1d9f4258" }, "outputs": [], "source": [ "# Define a custom type for SGD momentum state, inheriting from nnx.Variable.\n", "# This allows it to be tracked as part of the NNX state tree.\n", "class SGDState(nnx.Variable):\n", " pass\n", "\n", "# Define the SGD optimizer using NNX API.\n", "class SGD(nnx.Object):\n", " # Constructor takes the model parameters (as nnx.State), learning rate, and decay.\n", " def __init__(self, params: nnx.State, lr, decay=0.9):\n", " # Helper function to initialize momentum buffer for a given parameter.\n", " def init_optimizer_state(variable: 
nnx.Variable):\n", " # Create momentum state with zeros, same shape and metadata (incl. sharding)\n", " # as the parameter it corresponds to.\n", " return SGDState(\n", " jnp.zeros_like(variable.value), **variable.get_metadata()\n", " )\n", "\n", " self.lr = lr\n", " # Store a reference to the parameter State tree.\n", " self.params = params\n", " # Create the momentum state tree, mirroring the structure of 'params',\n", " # using the helper function. Momentum will have the same sharding as params.\n", " self.momentum = jax.tree.map(init_optimizer_state, self.params)\n", " self.decay = decay\n", "\n", " # Method to update parameters based on gradients.\n", " def update(self, grads: nnx.State):\n", " # Define the update logic for a single parameter/momentum/gradient triple.\n", " def update_fn(\n", " params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState\n", " ):\n", " # Standard SGD with momentum update rule.\n", " # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)\n", " momentum.value = self.decay * momentum.value + (1 - self.decay) * grad.value\n", " # θ_{t+1} = θ_t - α * v_t\n", " params.value -= self.lr * momentum.value # NOTE: Direct mutation of param value!\n", "\n", " # Apply the update function across the parameter, momentum, and gradient trees.\n", " # This performs the update in-place on the parameter values referenced by self.params.\n", " jax.tree.map(update_fn, self.params, self.momentum, grads)\n" ] }, { "cell_type": "markdown", "id": "850d7981", "metadata": { "id": "850d7981" }, "source": [ "### 2.5 Applying Sharding to the Model and Optimizer" ] }, { "cell_type": "code", "execution_count": 8, "id": "3f54c41e", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 434 }, "id": "3f54c41e", "outputId": "3100ce64-0cd4-4c5c-ff23-c0bfe3d7beb0" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Linear1 kernel sharding (128, 2048):\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ 
"\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" 
], "text/html": [ "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "\n", "Linear2 kernel sharding (2048, 128):\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,6\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1,7\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2,4\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3,5\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ], "text/html": [ "
\n", " TPU 0,6 \n", " \n", " \n", " TPU 1,7 \n", " \n", " \n", " TPU 2,4 \n", " \n", " \n", " TPU 3,5 \n", " \n", "\n" ] }, "metadata": {} } ], "source": [ "@nnx.jit\n", "def create_model():\n", " model = MLP(128, 2048, 128, rngs=nnx.Rngs(0))\n", " optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)\n", "\n", " # Extract state\n", " state = nnx.state(optimizer)\n", "\n", " # Define sharding for the state pytree\n", " def get_named_shardings(path: tuple, value: nnx.VariableState):\n", " if path[0] == 'params':\n", " return value.replace(NamedSharding(mesh, P(*value.sharding)))\n", " elif path[0] == 'momentum':\n", " return value.replace(NamedSharding(mesh, P(*value.sharding)))\n", " else:\n", " raise ValueError(f'Unknown path: {path}')\n", "\n", " named_shardings = state.map(get_named_shardings)\n", " sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)\n", " nnx.update(optimizer, sharded_state)\n", "\n", " return model, optimizer\n", "\n", "model, optimizer = create_model()\n", "\n", "# Visualize sharding\n", "print(\"Linear1 kernel sharding (128, 2048):\")\n", "jax.debug.visualize_array_sharding(model.linear1.kernel.value)\n", "print(\"\\nLinear2 kernel sharding (2048, 128):\")\n", "jax.debug.visualize_array_sharding(model.linear2.kernel.value)" ] }, { "cell_type": "markdown", "id": "0bcf251c", "metadata": { "id": "0bcf251c" }, "source": [ "### 2.6 Distributed Training" ] }, { "cell_type": "code", "execution_count": 9, "id": "ca7002f6", "metadata": { "id": "ca7002f6" }, "outputs": [], "source": [ "# JIT-compile the training step function.\n", "@nnx.jit\n", "def train_step(model: MLP, optimizer: SGD, x, y):\n", " # Define the loss function (Mean Squared Error).\n", " # Takes the model object as input, consistent with nnx.value_and_grad.\n", " def loss_fn(model):\n", " y_pred = model(x) # Forward pass\n", " loss = jnp.mean((y - y_pred) ** 2)\n", " return loss\n", "\n", " # Calculate loss and gradients w.r.t the model's state (its 
nnx.Param variables).\n", " # 'grad' will be an nnx.State object mirroring model's Param structure.\n", " loss, grad = nnx.value_and_grad(loss_fn)(model)\n", "\n", " # Call the optimizer's update method to apply gradients.\n", " # This updates the model parameters in-place.\n", " optimizer.update(grad)\n", "\n", " # Return the calculated loss.\n", " return loss\n" ] }, { "cell_type": "markdown", "id": "7296dd13", "metadata": { "id": "7296dd13" }, "source": [ "### 2.7 Training Loop and Results" ] }, { "cell_type": "code", "execution_count": 10, "id": "ba4d14ce", "metadata": { "id": "ba4d14ce" }, "outputs": [], "source": [ "# Dataset function (as before)\n", "def dataset(steps, batch_size):\n", " \"\"\"Generate 128D sequence data with underlying pattern.\"\"\"\n", " for _ in range(steps):\n", " # Generate base signal\n", " t = np.linspace(0, 4*np.pi, 128)\n", " base_patterns = np.array([\n", " np.sin(t + np.random.uniform(0, 2*np.pi)),\n", " np.cos(2*t + np.random.uniform(0, 2*np.pi)),\n", " np.sin(3*t + np.random.uniform(0, 2*np.pi))\n", " ])\n", "\n", " # Create batch of sequences\n", " x = np.zeros((batch_size, 128))\n", " y = np.zeros((batch_size, 128))\n", "\n", " for i in range(batch_size):\n", " # Mix base patterns with random weights\n", " weights = np.random.randn(3)\n", " signal = np.sum(weights[:, np.newaxis] * base_patterns, axis=0)\n", "\n", " # Add noise\n", " x[i] = signal + np.random.normal(0, 0.1, 128)\n", "\n", " # Output is a non-linear transformation of input\n", " y[i] = np.roll(x[i], 5) * 0.8 + 0.1 * x[i]**2\n", " y[i] += np.random.normal(0, 0.05, 128)\n", "\n", " yield x.astype(np.float32), y.astype(np.float32)\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "beed9e47", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "beed9e47", "outputId": "0a46bed1-5f28-4daf-f2cc-88441a0702c3" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Step 0: Loss = 1.7442927360534668\n", "Step 100: Loss = 
0.13562121987342834\n", "Step 200: Loss = 0.10096409171819687\n", "Step 300: Loss = 0.07630820572376251\n", "Step 400: Loss = 0.06262542307376862\n", "Step 500: Loss = 0.05251384899020195\n" ] } ], "source": [ "# Training Loop\n", "losses = []\n", "for step, (x_batch, y_batch) in enumerate(\n", " dataset(batch_size=8192, steps=501)\n", "):\n", " # Shard data along 'data' axis\n", " x_batch, y_batch = jax.device_put((x_batch, y_batch), named_sharding('data'))\n", "\n", " loss = train_step(model, optimizer, x_batch, y_batch)\n", " losses.append(float(loss))\n", "\n", " if step % 100 == 0:\n", " print(f'Step {step}: Loss = {loss}')" ] }, { "cell_type": "code", "execution_count": 12, "id": "7dcef594", "metadata": { "id": "7dcef594", "outputId": "a406443e-480f-4462-877d-9afeef367b04", "colab": { "base_uri": "https://localhost:8080/", "height": 490 } }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "Text(0, 0.5, 'MSE Loss')" ] }, "metadata": {}, "execution_count": 12 }, { "output_type": "display_data", "data": { "text/plain": [ "