{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "o5p_3ph2u3sv" }, "source": [ "# Tutorial 1: Parallelization Basics\n", "\n", "[![Open in GitHub](https://img.shields.io/badge/Open%20in-GitHub-181717?style=flat-square&logo=github)](https://github.com/sshkhr/MinText/blob/main/docs/tutorials/1_Parallelization_Basics.ipynb)\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sshkhr/MinText/blob/main/docs/tutorials/1_Parallelization_Basics.ipynb)\n", "\n", "This tutorial covers the fundamental concepts of parallelization in JAX, including device meshes, sharded arrays, and collective operations. We'll build understanding step by step, starting with basic concepts and working towards practical implementations.\n", "\n", "**Learning objectives:**\n", "- Understand JAX's device mesh and sharding concepts\n", "- Learn about how matrices can be sharded across devices\n", "- Learn about collective operations (AllGather, ReduceScatter, etc.)\n", "- Implement basic parallel computations\n", "\n", "**Prerequisites:**\n", "- Basic familiarity with NumPy\n", "- Understanding of matrix operations\n", "- No prior knowledge of distributed computing required" ] }, { "cell_type": "markdown", "metadata": { "id": "qloP9r0Tu3sw" }, "source": [ "## 0. Why JAX?\n", "\n", "[JAX](https://docs.jax.dev/) is a high-performance numerical computing library that combines NumPy's familiar API with the power of automatic differentiation and hardware acceleration on GPUs and TPUs. Developed by Google Research, JAX enables writing high-performance code that can run efficiently on a single device or scale across multiple devices.\n", "\n", "\n", "We use JAX for this notebook (and in the MinText library) because of several reasons:\n", "\n", "1. Beginner-friendly automatic parallelization using `jax.jit`\n", "2. Ability to simulate multiple devices using `\"XLA_FLAGS\"`\n", "3. 
Google Colab provides an 8-device runtime, the v2-8 TPU, for free. It consists of 8 TPU cores with 8 GB of memory each (64 GB in total), so you can actually run distributed operations over 8 devices.\n", "4. There are already several great pedagogical libraries in PyTorch (such as [HuggingFace Nanotron](https://github.com/huggingface/nanotron)) which serve a similar purpose. The key concepts from these tutorials can often be directly translated to PyTorch." ] }, { "cell_type": "markdown", "metadata": { "id": "RgGURMiBu3sw" }, "source": [ "## 1. Multi-Device Computation\n", "\n", "Choose the v2-8 TPU runtime in Google Colab to run this notebook. Once the runtime restarts, we can start by importing the required libraries and exploring the available devices." ] }, { "cell_type": "markdown", "metadata": { "id": "N2jINWX81Nag" }, "source": [ "![Colab v2-8 run time](https://github.com/sshkhr/MinText/blob/main/docs/_static/colab-runtime.png?raw=1)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "2i4n1zkju3sw" }, "outputs": [], "source": [ "# Install JAX if needed\n", "# !pip install --upgrade jax jaxlib" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "bJytD437u3sw" }, "outputs": [], "source": [ "# Uncomment this to simulate running the code on 8 CPU devices (use for local runs)\n", "# import os\n", "# os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "JvYD0GEWu3sx" }, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n", "from jax.experimental import mesh_utils\n", "from jax.experimental.shard_map import shard_map\n", "\n", "from functools import partial\n", "import numpy as np\n", "import time\n", "import matplotlib.pyplot as plt\n", "from typing import Optional" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": 
"https://localhost:8080/" }, "id": "Dk-ci7Ozu3sx", "outputId": "6563680b-2fc4-4a94-e575-d0ed0c8d0559" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX version: 0.5.2\n", "Available devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]\n", "Device count: 8\n" ] } ], "source": [ "# Check available devices\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"Available devices: {jax.devices()}\")\n", "print(f\"Device count: {jax.device_count()}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "_HjN6kuixEnI" }, "source": [ "If you cannot access the v2-8 TPU from Google Colab (you timed out or are running this locally) restart this notebook and un-comment the `os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'` flag to simulate multiple devices on your CPU.\n", "\n", "The output of the cell above should then change to look something like this:\n", "\n", "```bash\n", "Available devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]\n", "Device count: 8\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "_NVkU2c4u3sx" }, "source": [ "### 1.1 Creating a Device Mesh\n", "\n", "JAX uses the concept of a device mesh to organize available devices. A device can be a CPU (or a CPU core), GPU, or TPU for JAX's purpose. A mesh is a multi-dimensional array of devices that can be addressed along different axes. 
This allows us to partition our data and computation along different dimensions of the mesh." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fnLKxdOyu3sx", "outputId": "07285e46-4417-431d-cdd2-b7fce38bd709" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2D Mesh: Mesh('x': 2, 'y': 2)\n", "[0, 1, 2, 3]\n" ] } ], "source": [ "devices = jax.devices()\n", "\n", "# If you have multiple devices, you can create a 2D mesh\n", "if len(devices) >= 4:\n", " mesh_2d = jax.make_mesh((2, 2), ('x', 'y'))\n", " print(f\"2D Mesh: {mesh_2d}\")\n", " print([d.id for d in mesh_2d.devices.flat])\n", "else:\n", " print(\"Not enough devices for a 2D mesh demonstration\")" ] }, { "cell_type": "markdown", "metadata": { "id": "fL_6yWgyu3sx" }, "source": [ "## 2. Sharded Matrices\n", "\n", "Sharded matrices are arrays that are split across multiple devices. JAX provides abstractions to create and operate on these distributed arrays efficiently." 
] }, { "cell_type": "markdown", "metadata": { "id": "o025TqGbSPUx" }, "source": [ "### 2.1 Device Mesh and Partition Specifications\n", "\n", "- **Mesh**: A logical arrangement of devices\n", "- **PartitionSpec (P)**: Specifies how to partition an array across mesh dimensions\n", "- **jax.device_put()**: Places an array on a specific device or according to a sharding" ] }, { "cell_type": "markdown", "metadata": { "id": "qfGkerAz1Enu" }, "source": [ "![Sharding](https://github.com/jax-ml/scaling-book/blob/main/assets/img/sharding-example.png?raw=true)" ] }, { "cell_type": "markdown", "metadata": { "id": "hfdTmAIu1R2b" }, "source": [ " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] }, { "cell_type": "markdown", "metadata": { "id": "8OyInnwOSK3c" }, "source": [ "Let's see how to create sharded arrays using JAX's sharding API:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "biXk36aWu3sx", "outputId": "613a11c2-3c66-4ef5-e3a3-bc18edf62738" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matrix shape: (8192, 8192), Size in memory: 256.00 MB\n" ] } ], "source": [ "# Let's create a large matrix and shard it\n", "matrix_size = 8192 # Adjust based on your device memory\n", "matrix = jnp.ones((matrix_size, matrix_size))\n", "print(f\"Matrix shape: {matrix.shape}, Size in memory: {matrix.size * 4 / (1024**2):.2f} MB\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 251 }, "id": "9CJJEff5u3sx", "outputId": "91b2f282-8734-4b36-a76b-5ff2b5ec8513" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sharded matrix type: \n", "Sharding spec: NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec('x', 'y'), memory_kind=device)\n" ] }, { "data": { "text/html": [ "
                        \n",
              "                        \n",
              "   TPU 0       TPU 1    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "   TPU 2       TPU 3    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Define the partition spec: shard rows along mesh axis 'x' and columns along mesh axis 'y'\n", "partition_spec = P('x', 'y')\n", "\n", "# Create a NamedSharding that binds the partition spec to the mesh\n", "# (NamedSharding was already imported at the top of the notebook)\n", "shardings = NamedSharding(mesh_2d, partition_spec)\n", "\n", "# Create the sharded 
array\n", "sharded_matrix = jax.device_put(x=matrix, device=shardings)\n", "\n", "print(f\"Sharded matrix type: {type(sharded_matrix)}\")\n", "print(f\"Sharding spec: {shardings}\")\n", "\n", "jax.debug.visualize_array_sharding(sharded_matrix)" ] }, { "cell_type": "markdown", "metadata": { "id": "Aa9Xp22Pu3sx" }, "source": [ "### 2.2 Inspecting Sharded Matrices\n", "\n", "We can inspect how the array is distributed across devices:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BPvuoHw96fHZ", "outputId": "8cbb4bc0-078d-4ac2-db04-05e99b6c6857" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Global matrix shape (8192, 8192)\n", "Shapes of Matrix Shards:\n", "TPU_0(process=0,(0,0,0,0)) (slice(0, 4096, None), slice(0, 4096, None)) (4096, 4096)\n", "TPU_1(process=0,(0,0,0,1)) (slice(0, 4096, None), slice(4096, 8192, None)) (4096, 4096)\n", "TPU_2(process=0,(1,0,0,0)) (slice(4096, 8192, None), slice(0, 4096, None)) (4096, 4096)\n", "TPU_3(process=0,(1,0,0,1)) (slice(4096, 8192, None), slice(4096, 8192, None)) (4096, 4096)\n" ] } ], "source": [ "print(\"Global matrix shape\", sharded_matrix.shape)\n", "\n", "# Get the local arrays on each device\n", "print(\"Shapes of Matrix Shards:\")\n", "for shard in sharded_matrix.addressable_shards:\n", " print(shard.device, shard.index, shard.data.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "ra5f9aOe9936" }, "source": [ "![Array Sharding Visualization](https://jax-ml.github.io/scaling-book/assets/img/sharding-colored4.png)" ] }, { "cell_type": "markdown", "metadata": { "id": "KUtX2m3S9937" }, "source": [ " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] }, { "cell_type": "markdown", "metadata": { "id": "v1aOZjpYu3sx" }, "source": [ "### 2.3 Performance Benefits\n", "\n", "Let's compare the performance of element-wise operations on sharded vs. 
non-sharded matrices (we will cover matrix-level operations in a bit)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c3y0rJzMu3sx", "outputId": "47f990e6-3dd6-410b-a0f1-2e6790cf1c23" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The slowest run took 10.40 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "27.9 ms ± 36.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], "source": [ "# `matrix` is present on a single device\n", "%timeit -n 5 -r 5 jnp.sin(matrix).block_until_ready()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hZpzir6m_D6u", "outputId": "a899263a-1269-4c8c-97d8-9418f77afec0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The slowest run took 29.52 times longer than the fastest. This could mean that an intermediate result is being cached.\n", "21.4 ms ± 36.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], "source": [ "# `sharded_matrix` is distributed across 4 devices\n", "%timeit -n 5 -r 5 jnp.sin(sharded_matrix).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "9PFgKkSJu3sy" }, "source": [ "### 2.4 Sharding Strategies\n", "\n", "Each axis of the matrix can be sharded along any axis of the device mesh. This gives rise to a combinatorial number of possible shardings." 
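, "\n", "\n", "As a rough worked example (our own arithmetic, assuming the 8192 x 8192 `float32` matrix and the 2 x 2 mesh used above), the shard held by each device under a few partition specs would be:\n", "\n", "\\begin{equation} \\begin{aligned} \\texttt{P('x', 'y')} &: \\left(\\tfrac{8192}{2}, \\tfrac{8192}{2}\\right) = (4096, 4096), & 4096 \\cdot 4096 \\cdot 4 \\text{ bytes} &= 64 \\text{ MB per device} \\\\ \\texttt{P('x', None)} &: \\left(\\tfrac{8192}{2}, 8192\\right) = (4096, 8192), & 4096 \\cdot 8192 \\cdot 4 \\text{ bytes} &= 128 \\text{ MB per device} \\\\ \\texttt{P(None, None)} &: (8192, 8192), & 8192 \\cdot 8192 \\cdot 4 \\text{ bytes} &= 256 \\text{ MB per device} \\end{aligned} \\end{equation}\n", "\n", "Sharding an axis divides its length by the size of the mesh axis it is mapped to. An unsharded axis keeps its full length, and the resulting shard is replicated across any mesh axis that the partition spec does not use."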
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3UCpiympBJB6" }, "outputs": [], "source": [ "# A small helper function to define a sharding mesh\n", "default_mesh = jax.make_mesh((2, 2), ('x', 'y'))\n", "\n", "def mesh_sharding(\n", " pspec: P, mesh: Optional[Mesh] = None,\n", " ) -> NamedSharding:\n", " if mesh is None:\n", " mesh = default_mesh\n", " return NamedSharding(mesh, pspec)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 215 }, "id": "dc-8P5xVu3sy", "outputId": "9602fcc0-cbc1-41c7-9c26-e1e2532335be" }, "outputs": [ { "data": { "text/html": [ "
                        \n",
              "                        \n",
              "   TPU 0       TPU 1    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "   TPU 2       TPU 3    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Shard first axis of matrix along mesh axis 'x', second axis of matrix along mesh axis 'y'\n", "ix_jy_sharding = jax.device_put(matrix, mesh_sharding(P('x', 'y')))\n", "jax.debug.visualize_array_sharding(ix_jy_sharding)" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 215 }, "id": "Pn0Ihnt4BdgE", "outputId": "721fb4b7-8313-4a18-b810-5a78470357c5" }, "outputs": [ { "data": { "text/html": [ "
                        \n",
              "                        \n",
              "   TPU 0       TPU 2    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "   TPU 1       TPU 3    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Shard first axis of matrix along mesh axis 'y', second axis of matrix along mesh axis 'x'\n", "iy_jx_sharding = jax.device_put(matrix, mesh_sharding(P('y', 'x')))\n", "jax.debug.visualize_array_sharding(iy_jx_sharding)" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 199 }, "id": "bk-iDwLkB1HB", "outputId": "d5a1ed04-0dcc-441a-8950-2157cc187e53" }, "outputs": [ { "data": { "text/html": [ "
┌───────────────────────┐\n",
              "│                       │\n",
              "│        TPU 0,1        │\n",
              "│                       │\n",
              "│                       │\n",
              "├───────────────────────┤\n",
              "│                       │\n",
              "│        TPU 2,3        │\n",
              "│                       │\n",
              "│                       │\n",
              "└───────────────────────┘\n",
              "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", "│ │\n", "│ │\n", "├───────────────────────┤\n", "│ │\n", "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │\n", "│ │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Shard first axis of matrix along mesh axis 'x', replicate second axis of matrix along each shard\n", "ix_j_sharding = jax.device_put(matrix, mesh_sharding(P('x', None)))\n", "jax.debug.visualize_array_sharding(ix_j_sharding, use_color=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 199 }, "id": "qglM7sfHCFd6", "outputId": "eb07b182-38d9-45a8-a931-136fdb0951e1" }, "outputs": [ { "data": { "text/html": [ "
┌──────────┬──────────┐\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│ TPU 0,2  │ TPU 1,3  │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "└──────────┴──────────┘\n",
              "
\n" ], "text/plain": [ "┌──────────┬──────────┐\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "└──────────┴──────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Shard second axis of matrix along mesh axis 'y'; the first axis is not sharded, so each shard is replicated across mesh axis 'x'\n", "i_jy_shard = jax.device_put(matrix, mesh_sharding(P(None, 'y')))\n", "jax.debug.visualize_array_sharding(i_jy_shard, use_color=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "EK_qrIQqIODG" }, "source": [ "For a 2D matrix sharded along a 2D device mesh, here are all the possible sharding strategies:" ] }, { "cell_type": "markdown", "metadata": { "id": "Sr3T40pqCdwX" }, "source": [ "![Possible Array Shardings](https://jax-ml.github.io/scaling-book/assets/img/sharding-colored5.png)" ] }, { "cell_type": "markdown", "metadata": { "id": "rJy0WY50CdwY" }, "source": [ " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] }, { "cell_type": "markdown", "metadata": { "id": "UKtWvBtju3sy" }, "source": [ "### 2.5 Choosing the Right Sharding Strategy\n", "\n", "As we discuss collectives (next section) and different parallelism strategies (in the next tutorials), we will gradually take a deeper dive into how to choose a sharding strategy for your use case." ] }, { "cell_type": "markdown", "metadata": { "id": "wvNt9kOLu3sy" }, "source": [ "## 3. Collective Operations\n", "\n", "Collective operations are essential for distributed algorithms where devices need to share or aggregate information. In deep learning, they are used when performing computations on sharded arrays." 
] }, { "cell_type": "markdown", "metadata": { "id": "eYxP1wOzEGu1" }, "source": [ "### 3.1 Matrix Operations With Sharded Arrays\n", "\n", "If we want to perform matrix operations on sharded arrays, we need to think through the overheads involved in moving data between devices.\n", "\n", "In deep learning, these two kinds of operations are used most often:\n", "\n", "- **Element-wise operations (e.g. ReLU)**: Operations that can be performed independently on each element of the array. These operations can be performed in parallel across array shards without any communication between them.\n", "- **Matrix multiplication (e.g. linear layers, attention, etc.)**: A more complex operation that requires communication between devices to compute the result. This is where collective operations come into play." ] }, { "cell_type": "markdown", "metadata": { "id": "Q4FlzP66EJLF" }, "source": [ "#### 3.1.1 Block Matrix Multiplication\n", "\n", "We can think of a matrix as being composed of smaller blocks, which can be processed independently.\n", "\n", "\\begin{equation} \\begin{pmatrix} a_{00} & a_{01} & a_{02} & a_{03} \\\\ a_{10} & a_{11} & a_{12} & a_{13} \\\\ a_{20} & a_{21} & a_{22} & a_{23} \\\\ a_{30} & a_{31} & a_{32} & a_{33} \\end{pmatrix} = \\left( \\begin{matrix} \\begin{bmatrix} a_{00} & a_{01} \\\\ a_{10} & a_{11} \\end{bmatrix} \\\\ \\begin{bmatrix} a_{20} & a_{21} \\\\ a_{30} & a_{31} \\end{bmatrix} \\end{matrix} \\begin{matrix} \\begin{bmatrix} a_{02} & a_{03} \\\\ a_{12} & a_{13} \\end{bmatrix} \\\\ \\begin{bmatrix} a_{22} & a_{23} \\\\ a_{32} & a_{33} \\end{bmatrix} \\end{matrix} \\right) = \\begin{pmatrix} \\mathbf{A_{00}} & \\mathbf{A_{01}} \\\\ \\mathbf{A_{10}} & \\mathbf{A_{11}} \\end{pmatrix} \\end{equation}\n", "\n", "Matrix multiplication has the convenient property that the product of two matrices can be written in terms of block matmuls.\n", "\n", "\\begin{equation} \\begin{pmatrix} A_{00} & A_{01} \\\\ A_{10} & A_{11} \\end{pmatrix} \\cdot \\begin{pmatrix} 
B_{00} & B_{01} \\\\ B_{10} & B_{11} \\end{pmatrix} = \\begin{pmatrix} A_{00}B_{00} + A_{01}B_{10} & A_{00}B_{01} + A_{01}B_{11} \\\\ A_{10}B_{00} + A_{11}B_{10} & A_{10}B_{01} + A_{11}B_{11} \\end{pmatrix} \\end{equation}\n", "\n", "So we can compute distributed matrix multiplication by computing the block matmuls in parallel. The question is what communication is required to compute the final result, when to do it, and how expensive it is to perform." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "weLAPiSB16Gv" }, "outputs": [], "source": [ "# Create a 2x2 mesh with 4 devices to visualize matmul examples\n", "mesh_2x2 = jax.make_mesh((2, 2), ('x', 'y'))" ] }, { "cell_type": "markdown", "metadata": { "id": "42gwCTSPERfY" }, "source": [ "### 3.2 Case 1: No Sharded Contracting Dimension" ] }, { "cell_type": "markdown", "metadata": { "id": "gH7ecu0pR06m" }, "source": [ "Consider the matrix multiplication of two sharded matrices $\\mathbf{A}[I_X, J]$ and $\\mathbf{B}[J, K_Y]$. Note that the contracting dimension $J$ is not sharded. Thus we have:" ] }, { "cell_type": "markdown", "metadata": { "id": "yYl_aup2R06m" }, "source": [ "$$ \\mathbf{A}[I_X, J] \\cdot \\mathbf{B}[J, K_Y] \\rightarrow \\mathbf{C}[I_X, K_Y] $$" ] }, { "cell_type": "markdown", "metadata": { "id": "EdzfIGC-1Nal" }, "source": [ "#### Example: Case 1 - No Sharded Contracting Dimension\n", "\n", "Let's demonstrate this with 4x4 matrices on a 2x2 mesh where the contracting dimension is not sharded:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 737 }, "id": "gjX2eend1Nal", "outputId": "e6390eb9-4a80-453d-dd3c-12f4d20a1c1e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matrix A sharding (rows split along X):\n" ] }, { "data": { "text/html": [ "
                         \n",
              "                         \n",
              "         TPU 0,1         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "         TPU 2,3         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,1\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 2,3\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Matrix B sharding (columns split along Y):\n" ] }, { "data": { "text/html": [ "
                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "  TPU 0,2     TPU 1,3   \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,2\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 1,3\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Result C sharding (both dimensions sharded):\n" ] }, { "data": { "text/html": [ "
                        \n",
              "                        \n",
              "   TPU 0       TPU 1    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "   TPU 2       TPU 3    \n",
              "                        \n",
              "                        \n",
              "                        \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214mTPU 1\u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;222;158;214m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74mTPU 2\u001b[0m\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107mTPU 3\u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;173;73;74m \u001b[0m\u001b[38;2;0;0;0;48;2;181;207;107m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Correct result: True\n", "C shape: (4, 4)\n" ] } ], "source": [ "# Case 1: No sharded contracting dimension\n", "# A[I_X, J] @ B[J, K_Y] -> C[I_X, K_Y]\n", "\n", "# Create two small 4x4 matrices\n", "key = 
jax.random.PRNGKey(42)\n", "A = jax.random.uniform(key, (4, 4))\n", "B = jax.random.uniform(key, (4, 4)) + 1\n", "\n", "# Shard A along first dimension (rows) on X axis, B along second dimension (columns) on Y axis\n", "A_sharded = jax.device_put(A, NamedSharding(mesh_2x2, P('x', None)))\n", "B_sharded = jax.device_put(B, NamedSharding(mesh_2x2, P(None, 'y')))\n", "\n", "print(\"Matrix A sharding (rows split along X):\")\n", "jax.debug.visualize_array_sharding(A_sharded)\n", "print(\"\\nMatrix B sharding (columns split along Y):\")\n", "jax.debug.visualize_array_sharding(B_sharded)\n", "\n", "# Direct multiplication works without any collective operations\n", "C = A_sharded @ B_sharded\n", "print(\"\\nResult C sharding (both dimensions sharded):\")\n", "jax.debug.visualize_array_sharding(C)\n", "\n", "# Verify the result\n", "C_expected = A @ B\n", "print(f\"\\nCorrect result: {jnp.allclose(C, C_expected)}\")\n", "print(f\"C shape: {C.shape}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "XwrnJc3OEMGS" }, "source": [ "We can multiply each local shard without any communication between devices. Each device computes its local shard of the result matrix $\\mathbf{C}$ independently. Each of the following possible sharded matrix multiplications can be performed without any communication:\n", "\n", "\\begin{align*} \\mathbf{A}[I, J] \\cdot \\mathbf{B}[J, K] \\rightarrow &\\ \\mathbf{C}[I, K] \\\\ \\mathbf{A}[I_X, J] \\cdot \\mathbf{B}[J, K] \\rightarrow &\\ \\mathbf{C}[I_X, K]\\\\ \\mathbf{A}[I, J] \\cdot \\mathbf{B}[J, K_Y] \\rightarrow &\\ \\mathbf{C}[I, K_Y]\\\\ \\mathbf{A}[I_X, J] \\cdot \\mathbf{B}[J, K_Y] \\rightarrow &\\ \\mathbf{C}[I_X, K_Y] \\end{align*}" ] }, { "cell_type": "markdown", "metadata": { "id": "tbfYuMEau3sy" }, "source": [ "### 3.3 Case 2 (All-Gather): One matrix has a sharded contracting dimension\n", "\n", "Consider the case of a distributed matrix multiplication where one of the matrices has a sharded contracting dimension. 
For example,\n", "\n", "$$ \\mathbf{A}[I, J_X] \\cdot \\mathbf{B}[J, K] \\rightarrow \\mathbf{C}[I, K] $$" ] }, { "cell_type": "markdown", "metadata": { "id": "AvopF6cTR06t" }, "source": [ "Now, we cannot directly multiply the local shards of $\\mathbf{A}$ and $\\mathbf{B}$ without communication. Each device holds only a slice of the columns of $\\mathbf{A}$, so it must first gather the remaining shards of $\\mathbf{A}$ from the other devices along $X$ before it can compute $\\mathbf{C}$. This is done using an all-gather operation." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 755 }, "id": "SwffjJwW1Naf", "outputId": "39cf075c-b15a-4186-eda6-e76067f17830" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matrix A sharding (columns split along X - contracting dimension):\n" ] }, { "data": { "text/html": [ "
┌──────────┬──────────┐\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│ TPU 0,1  │ TPU 2,3  │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "└──────────┴──────────┘\n",
              "
\n" ], "text/plain": [ "┌──────────┬──────────┐\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "└──────────┴──────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Matrix B (replicated on all devices):\n" ] }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
              "│                       │\n",
              "│                       │\n",
              "│                       │\n",
              "│                       │\n",
              "│         TPU 0         │\n",
              "│                       │\n",
              "│                       │\n",
              "│                       │\n",
              "│                       │\n",
              "└───────────────────────┘\n",
              "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ │\n", "│ │\n", "│ │\n", "│ │\n", "│ TPU \u001b[1;36m0\u001b[0m │\n", "│ │\n", "│ │\n", "│ │\n", "│ │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "What each device sees locally:\n", "\n", "Device 0 has A columns slice(0, 2, None):\n", "[[ 0. 1.]\n", " [ 4. 5.]\n", " [ 8. 9.]\n", " [12. 13.]]\n", "But needs ALL columns of A to multiply with B!\n", "\n", "Device 2 has A columns slice(2, 4, None):\n", "[[ 2. 3.]\n", " [ 6. 7.]\n", " [10. 11.]\n", " [14. 15.]]\n", "But needs ALL columns of A to multiply with B!\n", "\n", "⚠️ Cannot multiply directly - each device only has partial columns of A\n" ] } ], "source": [ "# Case 2: Show why direct multiplication fails\n", "# A[I, J_X] @ B[J, K] - contracting dimension J is sharded in A\n", "\n", "# Create matrices\n", "A = jnp.arange(16).reshape(4, 4).astype(jnp.float32)\n", "B = jnp.arange(16, 32).reshape(4, 4).astype(jnp.float32)\n", "\n", "# Shard A's columns (contracting dimension) along X axis\n", "A_sharded = jax.device_put(A, NamedSharding(mesh_2x2, P(None, 'x')))\n", "B_unsharded = B # B is replicated\n", "\n", "print(\"Matrix A sharding (columns split along X - contracting dimension):\")\n", "jax.debug.visualize_array_sharding(A_sharded, use_color=False)\n", "print(\"\\nMatrix B (replicated on all devices):\")\n", "jax.debug.visualize_array_sharding(B_unsharded, use_color=False)\n", "\n", "# Show what each device sees locally\n", "print(\"\\nWhat each device sees locally:\")\n", "for shard in A_sharded.addressable_shards[:3:2]: # Show each group of devices\n", " print(f\"\\nDevice {shard.device.id} has A columns {shard.index[1]}:\")\n", " print(shard.data)\n", " print(\"But needs ALL columns of A to multiply with B!\")\n", "\n", "# This would fail or give incorrect results if done locally\n", "print(\"\\n⚠️ Cannot multiply directly - each device only has 
partial columns of A\")" ] }, { "cell_type": "markdown", "metadata": { "id": "7iMCiPVgR06t" }, "source": [ "![All Gather](https://github.com/sshkhr/MinText/blob/main/docs/_static/all-gather-vis.png?raw=1)\n", "\n", "All Gather is a collective operation that gathers data from all devices and distributes it to all devices. In this case, each device gathers the shards of an array across all devices and gets a fully replicated array.\n", "\n", "Here is a simple example of how to perform an All-Gather operation in JAX:\n", "\n", "![All Gather](https://github.com/sshkhr/MinText/blob/main/docs/_static/all-gather.png?raw=1)" ] }, { "cell_type": "markdown", "metadata": { "id": "JZJTjdf9R06t" }, "source": [ " Image Source: [JAX documentation](https://docs.jax.dev/en/latest/notebooks/shard_map.html#all-gather) " ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "jLYZOeZimxZQ" }, "outputs": [], "source": [ "mesh1d = Mesh(jax.devices()[:4], ('i',))" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SGdkWPXPR06t", "outputId": "070dfd1c-6e65-40dc-849b-8bb20dcb9c30" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BEFORE: On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[3]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[9]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[5]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[2]\n", "\n", "AFTER: On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[3 9 5 2]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[3 9 5 2]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[3 9 5 2]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[3 9 5 2]\n", "\n", "FINAL RESULT: [3 9 5 2 3 9 5 2 3 9 5 2 3 9 5 2]\n" ] } ], 
"source": [ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def all_gather(x_block):\n", "    print('BEFORE:', x_block)\n", "    y_block = jax.lax.all_gather(x_block, 'i', tiled=True)\n", "    print('AFTER:', y_block)\n", "    return y_block\n", "\n", "x = jnp.array([3, 9, 5, 2])\n", "y = all_gather(x)\n", "print('FINAL RESULT:', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "ujqRqxoVR06t" }, "source": [ "To perform a matrix multiplication using the AllGather operation, we can follow these steps:\n", "\n", "1. **AllGather** the shards of the first matrix along the mesh axis $X$, so that every device holds the full matrix.\n", "$$\\textbf{AllGather}_X \\mathbf{A}[I, J_X] \\rightarrow \\mathbf{A}[I, J]$$\n", "\n", "2. **Multiply** the gathered matrix with the second matrix.\n", "$$\\mathbf{A}[I, J] \\cdot \\mathbf{B}[J, K] \\rightarrow \\mathbf{C}[I, K]$$" ] }, { "cell_type": "markdown", "metadata": { "id": "j5-VoDz21Nal" }, "source": [ "#### Example: Case 2 - One Matrix has Sharded Contracting Dimension" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 395 }, "id": "L8uumJii1Nal", "outputId": "744b7da6-c964-4a80-dafa-1bc33319b9d3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before all-gather, A_shard shape: (4, 2)\n", "After all-gather, A_full shape: (4, 4)\n", "\n", "Result C (replicated on all devices):\n" ] }, { "data": { "text/html": [ "
                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "       TPU 0,1,2,3       \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,1,2,3\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Correct result: True\n", "Result shape: (4, 4)\n", "\n", "First few elements of result:\n", "[[152. 158.]\n", " [504. 526.]]\n" ] } ], "source": [ "# Case 2: Solution using All-Gather\n", "# Step 1: All-gather A to reconstruct full matrix on each device\n", "# Step 2: Multiply with B\n", "\n", "# Note: mesh axis names are case-sensitive; mesh_2x2 uses 'x' and 'y'\n", "@partial(shard_map, mesh=mesh_2x2,\n", "         in_specs=(P(None, 'x'), P(None, None)),\n", "         out_specs=P(None, None),\n", "         check_rep=False)\n", "def matmul_with_allgather(A_shard, B_shard):\n", "    print(f\"Before all-gather, A_shard shape: {A_shard.shape}\")\n", "\n", "    # All-gather along the 'x' mesh axis to get the full A matrix\n", "    A_full = jax.lax.all_gather(A_shard, 'x', axis=1, tiled=True)\n", "    print(f\"After all-gather, A_full shape: {A_full.shape}\")\n", "\n", "    # Now we can multiply\n", "    C = A_full @ B_shard\n", "    return C\n", "\n", "# Execute the multiplication with all-gather\n", "try:\n", "    C = matmul_with_allgather(A_sharded, B_unsharded)\n", "except Exception as e:\n", "    print(\"An error occurred:\", e)\n", "    raise  # do not continue with a stale or undefined C\n", "\n", "print(\"\\nResult C (replicated on all devices):\")\n", "jax.debug.visualize_array_sharding(C)\n", "\n", "# Verify 
correctness\n", "C_expected = A @ B\n", "print(f\"\\nCorrect result: {jnp.allclose(C, C_expected)}\")\n", "print(f\"Result shape: {C.shape}\")\n", "print(f\"\\nFirst few elements of result:\\n{C[:2, :2]}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "C7CHmsnbR06u" }, "source": [ "#### 3.3.1 How is an all-gather performed?\n", "\n", "![All Gather Operation](https://jax-ml.github.io/scaling-book/assets/img/all-gather.gif)" ] }, { "cell_type": "markdown", "metadata": { "id": "P7oUFy2ZR06u" }, "source": [ " Image Source: [How To Scale Your Model](https://jax-ml.github.io/scaling-book) " ] }, { "cell_type": "markdown", "metadata": { "id": "57IU22SOu3sy" }, "source": [ "### 3.4 Case 3 (All-Reduce): Both Matrices have sharded contracting dimensions\n", "\n", "Consider the case where both matrices to be multiplied are sharded on their contracting dimensions, along the same mesh axis.\n", "\n", "$$\\textbf{A}[I, J_X] \\cdot \\textbf{B}[J_X, K] \\rightarrow C[I, K]$$" ] }, { "cell_type": "markdown", "metadata": { "id": "50txPaeFR06u" }, "source": [ "In this case, we can multiply the local shards of the matrices without communication. However, each device will only hold a partial result.\n", "\n", "$$\\textbf{A}[I, J_X] \\cdot_\\text{LOCAL} \\textbf{B}[J_X, K] \\rightarrow C[I, K] \\{\\ U_X \\}$$\n", "\n", "The notation $\\{\\ U_X \\}$ here refers to the fact that the matrix $C$ is unreduced along the mesh axis $X$: the true $C$ is the sum of the partial results held by the devices along $X$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 559 }, "id": "dlwNc81q1Naf", "outputId": "e6125b9a-1300-4d8b-ffe8-3d2861c331d4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matrix A sharding (columns split along X):\n" ] }, { "data": { "text/html": [ "
┌──────────┬──────────┐\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│ TPU 0,1  │ TPU 2,3  │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "└──────────┴──────────┘\n",
              "
\n" ], "text/plain": [ "┌──────────┬──────────┐\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "└──────────┴──────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Matrix B sharding (rows split along X):\n" ] }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
              "│                       │\n",
              "│        TPU 0,1        │\n",
              "│                       │\n",
              "│                       │\n",
              "├───────────────────────┤\n",
              "│                       │\n",
              "│        TPU 2,3        │\n",
              "│                       │\n",
              "│                       │\n",
              "└───────────────────────┘\n",
              "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", "│ │\n", "│ │\n", "├───────────────────────┤\n", "│ │\n", "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │\n", "│ │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Local multiplication on each device:\n", "Devices 0,1: A[:, 0:2] @ B[0:2, :] = partial result\n", "Devices 2,3: A[:, 2:4] @ B[2:4, :] = partial result\n", "\n", "Each device computes only PART of the final result!\n", "Need to SUM all partial results to get the correct answer\n" ] } ], "source": [ "# Case 3: Show why we need all-reduce\n", "# A[I, J_X] @ B[J_X, K] - contracting dimension J is sharded in both\n", "\n", "# Create matrices\n", "A = jnp.arange(16).reshape(4, 4).astype(jnp.float32)\n", "B = jnp.arange(16, 32).reshape(4, 4).astype(jnp.float32)\n", "\n", "# Shard contracting dimensions along X axis\n", "A_sharded = jax.device_put(A, NamedSharding(mesh_2x2, P(None, 'x')))\n", "B_sharded = jax.device_put(B, NamedSharding(mesh_2x2, P('x', None)))\n", "\n", "print(\"Matrix A sharding (columns split along X):\")\n", "jax.debug.visualize_array_sharding(A_sharded, use_color=False)\n", "print(\"\\nMatrix B sharding (rows split along X):\")\n", "jax.debug.visualize_array_sharding(B_sharded, use_color=False)\n", "\n", "# Show what happens with local multiplication\n", "# (device groups match the sharding printed above: 0,1 hold the first half of J, 2,3 the second)\n", "print(\"\\nLocal multiplication on each device:\")\n", "print(\"Devices 0,1: A[:, 0:2] @ B[0:2, :] = partial result\")\n", "print(\"Devices 2,3: A[:, 2:4] @ B[2:4, :] = partial result\")\n", "print(\"\\nEach device computes only PART of the final result!\")\n", "print(\"Need to SUM all partial results to get the correct answer\")" ] }, { "cell_type": "markdown", "metadata": { "id": "AfD6WMW1R06u" }, "source": [ "![All-Reduce](https://github.com/sshkhr/MinText/blob/main/docs/_static/all-reduce-vis.png?raw=1)\n", "\n", 
"To obtain the final result, we need to perform an All-Reduce operation on the shards of $C$. Since the all-reduce sum operation is very common, jax provides the `jax.lax.psum` function to perform this operation efficiently.\n", "\n", "![All-Reduce in JAX](https://github.com/sshkhr/MinText/blob/main/docs/_static/psum.png?raw=1)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "q3AHaoIlu3sy", "outputId": "15b686fc-fce9-4014-8ff9-de120ce80771" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BEFORE:\n", " On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[3 1 4 1]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[5 9 2 6]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[5 3 5 8]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[9 7 1 2]\n", "\n", "AFTER:\n", " On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[22 20 12 17]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[22 20 12 17]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[22 20 12 17]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[22 20 12 17]\n", "\n", "FINAL RESULT:\n", " [22 20 12 17]\n" ] } ], "source": [ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P(None))\n", "def psum(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum(x_block, 'i')\n", " print('AFTER:\\n', y_block)\n", " return y_block\n", "\n", "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", "y = psum(x)\n", "print('FINAL RESULT:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "Nx1pyIcT1Nam" }, "source": [ "#### Example: Case 3 - Both Matrices have Sharded Contracting Dimensions\n", "\n", "When both matrices have their contracting dimensions sharded along the 
same axis, we can multiply locally but need to sum the partial results:" ] }, { "cell_type": "markdown", "metadata": { "id": "4CT76fByR06u" }, "source": [ "In order to perform the matrix multiplication $C = A \\cdot B$ using the AllReduce operation, we can break down the process into two main steps.\n", "\n", "1. **Local Matrix Multiplication** of input matrix shards on each device.\n", "$$A[I, J_X] \\cdot_\\text{LOCAL} B[J_X, K] \\rightarrow C[I, K] \\{ U_X \\}$$\n", "\n", "2. **AllReduce** the partial results across all devices.\n", "$$\\textbf{AllReduce}_X C[I, K] \\{ U_X \\} \\rightarrow C[I, K]$$" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 466 }, "id": "Bv8FDz6j1Nam", "outputId": "ee86396b-c9e9-44ab-d8d7-5484f15cca7d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Local shard shapes: A=(4, 2), B=(2, 4)\n", "Partial result shape: (4, 4)\n", "After all-reduce shape: (4, 4)\n", "\n", "Result C (replicated on all devices after all-reduce):\n" ] }, { "data": { "text/html": [ "
                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "       TPU 0,1,2,3       \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "                         \n",
              "
\n" ], "text/plain": [ "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121mTPU 0,1,2,3\u001b[0m\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n", "\u001b[38;2;255;255;255;48;2;57;59;121m \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Correct result: True\n", "Result shape: (4, 4)\n", "\n", "Computation breakdown:\n", "A[:, 0:2] @ B[0:2, :] + A[:, 2:4] @ B[2:4, :] = Full result\n", "\n", "First few elements of result:\n", "[[152. 158.]\n", " [504. 
526.]]\n" ] } ], "source": [ "# Case 3: Solution using All-Reduce (psum)\n", "# Step 1: Local matrix multiplication\n", "# Step 2: All-reduce (sum) the partial results\n", "\n", "@partial(shard_map, mesh=mesh_2x2,\n", " in_specs=(P(None, 'x'), P('x', None)),\n", " out_specs=P(None, None))\n", "def matmul_with_allreduce(A_shard, B_shard):\n", " print(f\"Local shard shapes: A={A_shard.shape}, B={B_shard.shape}\")\n", "\n", " # Step 1: Local multiplication (each device computes partial result)\n", " C_partial = A_shard @ B_shard\n", " print(f\"Partial result shape: {C_partial.shape}\")\n", "\n", " # Step 2: All-reduce (sum) across X axis\n", " C_full = jax.lax.psum(C_partial, 'x')\n", " print(f\"After all-reduce shape: {C_full.shape}\")\n", "\n", " return C_full\n", "\n", "# Execute the multiplication with all-reduce\n", "C = matmul_with_allreduce(A_sharded, B_sharded)\n", "\n", "print(\"\\nResult C (replicated on all devices after all-reduce):\")\n", "jax.debug.visualize_array_sharding(C)\n", "\n", "# Verify correctness\n", "C_expected = A @ B\n", "print(f\"\\nCorrect result: {jnp.allclose(C, C_expected)}\")\n", "print(f\"Result shape: {C.shape}\")\n", "\n", "# Show the computation breakdown\n", "print(\"\\nComputation breakdown:\")\n", "print(f\"A[:, 0:2] @ B[0:2, :] + A[:, 2:4] @ B[2:4, :] = Full result\")\n", "print(f\"\\nFirst few elements of result:\\n{C[:2, :2]}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x04ltG32R06u" }, "source": [ "#### 3.4.1 All-Reduce as ReduceScatter + AllGather" ] }, { "cell_type": "markdown", "metadata": { "id": "BNYZsPPqnJDq" }, "source": [ "![AllReduce breakdown](https://engineering.fb.com/wp-content/uploads/2021/07/FSDP-graph-2a.png)" ] }, { "cell_type": "markdown", "metadata": { "id": "cMxpvlhnnMYl" }, "source": [ "We can express AllReduce as two different collectives, a Reduce Scatter followed by AllGather. 
A Reduce Scatter operation is visualized below" ] }, { "cell_type": "markdown", "metadata": { "id": "-X9v41O2R06u" }, "source": [ "![Reduce Scatter in JAX](https://github.com/sshkhr/MinText/blob/main/docs/_static/psum_scatter.png?raw=1)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UfKDbqXFnZWF", "outputId": "778d8c75-5670-4a75-e282-e6ab1819d199" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BEFORE:\n", " On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[3 1 4 1]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[5 9 2 6]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[5 3 5 8]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[9 7 1 2]\n", "\n", "AFTER:\n", " On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[22]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[20]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[12]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[17]\n", "\n", "FINAL RESULT:\n", " [22 20 12 17]\n" ] } ], "source": [ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def scatter_gather_sum(x_block):\n", " print('BEFORE:\\n', x_block)\n", " y_block = jax.lax.psum_scatter(x_block, 'i', tiled=True)\n", " print('AFTER:\\n', y_block)\n", " return y_block\n", "\n", "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", "y = scatter_gather_sum(x)\n", "print('FINAL RESULT:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "4zGyIDphR06u" }, "source": [ "### 3.5 Case 4 (All-Gather): Both Matrices have non-contracting dimensions sharded along the same mesh axes\n", "\n", "**Whenever we shard a tensor, each mesh dimension can appear AT MOST ONCE.** Consider the case where both matrices to be 
multiplied are sharded on their non-contracting dimensions, along the same mesh axis.\n", "\n", "$$\\textbf{A}[I_X, J] \\cdot \\textbf{B}[J, K_X] \\rightarrow C[I_X, K_X]$$\n", "\n", "Such a sharding of the result is **not allowed**: the mesh axis $X$ would appear twice in $C[I_X, K_X]$, so the devices would only hold the diagonal blocks of $C$, and the shards would not contain enough information to reconstruct the full matrix. In this case, we need to change the sharding of at least one of the matrices before multiplication." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 559 }, "id": "myzARm_N1Naf", "outputId": "91ef3ed5-266b-4293-ca42-1f7f6e5219c2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matrix A sharding (rows split along X):\n" ] }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
              "│                       │\n",
              "│        TPU 0,1        │\n",
              "│                       │\n",
              "│                       │\n",
              "├───────────────────────┤\n",
              "│                       │\n",
              "│        TPU 2,3        │\n",
              "│                       │\n",
              "│                       │\n",
              "└───────────────────────┘\n",
              "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", "│ │\n", "│ │\n", "├───────────────────────┤\n", "│ │\n", "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │\n", "│ │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Matrix B sharding (columns split along X):\n" ] }, { "data": { "text/html": [ "
┌──────────┬──────────┐\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│ TPU 0,1  │ TPU 2,3  │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "└──────────┴──────────┘\n",
              "
\n" ], "text/plain": [ "┌──────────┬──────────┐\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "└──────────┴──────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "⚠️ Problem: Each device needs data from OTHER devices!\n", "Device 0 has: A[0:2, :] and B[:, 0:2]\n", "It can compute C[0:2, 0:2] = A[0:2, :] @ B[:, 0:2] locally,\n", "but C[0:2, 2:4] = A[0:2, :] @ B[:, 2:4] needs columns of B held by devices 2,3\n", "\n", "Cannot compute the full result without resharding!\n" ] } ], "source": [ "# Case 4: Show the problem with both non-contracting dimensions sharded\n", "# A[I_X, J] @ B[J, K_X] - trying to get C[I_X, K_X] is problematic\n", "\n", "# Create matrices\n", "A = jnp.arange(16).reshape(4, 4).astype(jnp.float32)\n", "B = jnp.arange(16, 32).reshape(4, 4).astype(jnp.float32)\n", "\n", "# Shard non-contracting dimensions on same axis X\n", "A_sharded = jax.device_put(A, NamedSharding(mesh_2x2, P('x', None)))\n", "B_sharded = jax.device_put(B, NamedSharding(mesh_2x2, P(None, 'x')))\n", "\n", "print(\"Matrix A sharding (rows split along X):\")\n", "jax.debug.visualize_array_sharding(A_sharded, use_color=False)\n", "print(\"\\nMatrix B sharding (columns split along X):\")\n", "jax.debug.visualize_array_sharding(B_sharded, use_color=False)\n", "\n", "print(\"\\n⚠️ Problem: Each device needs data from OTHER devices!\")\n", "print(\"Device 0 has: A[0:2, :] and B[:, 0:2]\")\n", "print(\"It can compute C[0:2, 0:2] = A[0:2, :] @ B[:, 0:2] locally,\")\n", "print(\"but C[0:2, 2:4] = A[0:2, :] @ B[:, 2:4] needs columns of B held by devices 2,3\")\n", "print(\"\\nCannot compute the full result without resharding!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "wwik9QxP1ZBF" }, "source": [ "We have two options:\n", "\n", "1. 
**All-Gather the sharding dimension of matrix A** to have the non-contracting dimension unsharded.\n", "$$\\begin{align*} \\textbf{AllGather}_X A[I_X, J] \\rightarrow &\\ A[I, J] \\\\ A[I, J] \\cdot B[J, K_X] \\rightarrow &\\ C[I, K_X] \\end{align*}$$\n", "2. **All-Gather the sharding dimension of matrix B** to have the non-contracting dimension unsharded.\n", "$$\\begin{align*} \\textbf{AllGather}_X B[J, K_X] \\rightarrow &\\ B[J, K] \\\\ A[I_X, J] \\cdot B[J, K] \\rightarrow &\\ C[I_X, K] \\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "id": "VZ_6NCwQ1Nam" }, "source": [ "#### Example: Both Matrices have non-contracting dimensions sharded along the same mesh axes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 594 }, "id": "FEDkHT5O1Nam", "outputId": "b50ff55c-d902-4ac1-8e1c-fbb25f3fbea0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Before all-gather: A=(2, 4), B=(4, 2)\n", "After all-gather B: B_full=(4, 4)\n", "Result shape: (2, 4)\n", "\n", "Result C sharding (rows split along X):\n" ] }, { "data": { "text/html": [ "
┌───────────────────────┐\n",
              "│                       │\n",
              "│        TPU 0,1        │\n",
              "│                       │\n",
              "│                       │\n",
              "├───────────────────────┤\n",
              "│                       │\n",
              "│        TPU 2,3        │\n",
              "│                       │\n",
              "│                       │\n",
              "└───────────────────────┘\n",
              "
\n" ], "text/plain": [ "┌───────────────────────┐\n", "│ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", "│ │\n", "│ │\n", "├───────────────────────┤\n", "│ │\n", "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │\n", "│ │\n", "└───────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Correct result: True\n", "Result shape: (4, 4)\n", "\n", "--- Alternative: All-gather A instead of B ---\n", "Result with A gathered (columns split along X):\n" ] }, { "data": { "text/html": [ "
┌──────────┬──────────┐\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│ TPU 0,1  │ TPU 2,3  │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "│          │          │\n",
              "└──────────┴──────────┘\n",
              "
\n" ], "text/plain": [ "┌──────────┬──────────┐\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "│ │ │\n", "└──────────┴──────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Also correct: True\n" ] } ], "source": [ "# Case 4: Solution - All-gather one of the matrices\n", "# Option 2 from above: All-gather B to remove column sharding\n", "\n", "@partial(shard_map, mesh=mesh_2x2,\n", "         in_specs=(P('x', None), P(None, 'x')),\n", "         out_specs=P('x', None))\n", "def matmul_case4_allgather_B(A_shard, B_shard):\n", "    print(f\"Before all-gather: A={A_shard.shape}, B={B_shard.shape}\")\n", "\n", "    # All-gather B along X axis to get full B on each device\n", "    B_full = jax.lax.all_gather(B_shard, 'x', axis=1, tiled=True)\n", "    print(f\"After all-gather B: B_full={B_full.shape}\")\n", "\n", "    # Now multiply: each device computes its rows of C\n", "    C_shard = A_shard @ B_full\n", "    print(f\"Result shape: {C_shard.shape}\")\n", "\n", "    return C_shard\n", "\n", "# Execute the multiplication\n", "C = matmul_case4_allgather_B(A_sharded, B_sharded)\n", "\n", "print(\"\\nResult C sharding (rows split along X):\")\n", "jax.debug.visualize_array_sharding(C, use_color=False)\n", "\n", "# Verify correctness\n", "C_expected = A @ B\n", "print(f\"\\nCorrect result: {jnp.allclose(C, C_expected)}\")\n", "print(f\"Result shape: {C.shape}\")\n", "\n", "# Alternative (Option 1 from above): All-gather A instead\n", "print(\"\\n--- Alternative: All-gather A instead of B ---\")\n", "\n", "@partial(shard_map, mesh=mesh_2x2,\n", "         in_specs=(P('x', None), P(None, 'x')),\n", "         out_specs=P(None, 'x'))\n", "def matmul_case4_allgather_A(A_shard, B_shard):\n", "    # All-gather A along X axis to get full A on each device\n", "    A_full = jax.lax.all_gather(A_shard, 'x', axis=0, tiled=True)\n", "    # Multiply to get column-sharded result\n", "    C_shard = A_full @ B_shard\n", "    return C_shard\n", "\n", "C_alt = matmul_case4_allgather_A(A_sharded, B_sharded)\n", "print(\"Result with A gathered (columns split along X):\")\n", "jax.debug.visualize_array_sharding(C_alt, use_color=False)\n", "print(f\"Also correct: {jnp.allclose(C_alt, C_expected)}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DF4DVnDiFH00" }, "source": [ "### 3.6 Other Collectives" ] }, { "cell_type": "markdown", "metadata": { "id": "SL5QXIWlu3sy" }, "source": [ "#### 3.6.2 All-to-All\n", "\n", "All-to-all exchanges slices of data between all devices: each device splits its local block into one chunk per device, sends chunk `j` to device `j`, and concatenates the chunks it receives. This is useful for operations like matrix transposition or for redistributing data under a different sharding. It is not used in the three parallelism strategies we will discuss, but it does appear in other forms of parallelism (mentioned in Tutorial 4)." ] }, { "cell_type": "markdown", "metadata": { "id": "n2aiHk1lNAif" }, "source": [ "![All-to-All in JAX](https://github.com/sshkhr/MinText/blob/main/docs/_static/all-to-all.png?raw=1)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wLFnWWJ2u3sy", "outputId": "4f000973-d28f-421a-992c-7775990d19d9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BEFORE:\n", " On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[3 1 4 1]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[5 9 2 6]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[5 3 5 8]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[9 7 1 2]\n", "\n", "AFTER:\n", " On TPU_0(process=0,(0,0,0,0)) at mesh coordinates (i,) = (0,):\n", "[3 5 5 9]\n", "\n", "On TPU_1(process=0,(0,0,0,1)) at mesh coordinates (i,) = (1,):\n", "[1 9 3 7]\n", "\n", "On TPU_2(process=0,(1,0,0,0)) at mesh coordinates (i,) = (2,):\n", "[4 2 5 1]\n", "\n", "On TPU_3(process=0,(1,0,0,1)) at mesh coordinates (i,) = (3,):\n", "[1 6 8 2]\n", "\n", 
"FINAL RESULT:\n", " [3 5 5 9 1 9 3 7 4 2 5 1 1 6 8 2]\n" ] } ], "source": [ "@partial(shard_map, mesh=mesh1d, in_specs=P('i'), out_specs=P('i'))\n", "def all_to_all(x_block):\n", "    print('BEFORE:\\n', x_block)\n", "    # Split the local block along axis 0, send chunk j to device j,\n", "    # and concatenate the received chunks along axis 0\n", "    y_block = jax.lax.all_to_all(x_block, 'i', split_axis=0, concat_axis=0,\n", "                                 tiled=True)\n", "    print('AFTER:\\n', y_block)\n", "    return y_block\n", "\n", "x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5, 8, 9, 7, 1, 2])\n", "y = all_to_all(x)\n", "print('FINAL RESULT:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "739BfzUPu3sy" }, "source": [ "## 4. Notes on JAX Sharding" ] }, { "cell_type": "markdown", "metadata": { "id": "LSWH770vUJUI" }, "source": [ "| Mode | View? | Explicit sharding? | Explicit collectives? |\n", "|---|---|---|---|\n", "| Auto | Global | ❌ | ❌ |\n", "| Explicit | Global | ✅ | ❌ |\n", "| Manual | Per-device | ✅ | ✅ |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX provides three levels of granularity for sharding and parallelism, summarized in the table above. In these tutorials we will focus mostly on the **Auto** and **Explicit** modes, where we at most specify how arrays are sharded and leave the collectives to the compiler. The examples covered are simple enough that manual sharding is not required, and the JAX compiler can take care of many of the optimizations for us. As we will see in Tutorial 4, however, going beyond Tensor Parallelism may require you to explicitly specify the sharding of your network, data, and gradients, as well as the collectives." ] }, { "cell_type": "markdown", "metadata": { "id": "93CxXrqgu3sz" }, "source": [ "## 5. Conclusion\n", "\n", "In this tutorial, we explored JAX's support for distributed computation using sharded matrices and collective operations. We covered:\n", "\n", "1. **Setting up a device mesh** for organizing available devices\n", "2. **Creating and using sharded matrices** to distribute data across devices\n", "3. **Different sharding strategies** and when to use them\n", "4. **Collective operations** for efficient device-to-device communication\n", "\n", "JAX's sharding capabilities let numerical computations scale efficiently across multiple devices, which makes JAX well suited to large-scale machine learning and scientific computing. The next few tutorials apply these building blocks to specific parallelism strategies." ] }, { "cell_type": "markdown", "metadata": { "id": "9nidiY1eT29A" }, "source": [ "### 5.1 References\n", "\n", "- [JAX Documentation](https://jax.readthedocs.io/)\n", "- [JAX Docs: Distributed arrays and automatic parallelization](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)\n", "- [JAX Docs: Manual parallelism with `shard_map`](https://docs.jax.dev/en/latest/notebooks/shard_map.html)\n", "- [How to Scale Your Model: Sharded Matrices and How to Multiply Them](https://jax-ml.github.io/scaling-book/sharding/)" ] } ], "metadata": { "accelerator": "TPU", "colab": { "gpuType": "V28", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 0 }