
Welcome to MinText’s documentation!

MinText is a minimalistic framework for distributed LLM training and inference in JAX, built around 3D parallelism.
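This page does not say which three axes MinText's "3D parallelism" refers to. The sketch below assumes data, FSDP, and tensor parallelism, which are the techniques covered in Tutorials 2 and 3. It uses only plain `jax.sharding` APIs, not MinText's own interface. The axis names `"data"`, `"fsdp"`, and `"tensor"` are hypothetical. One device mesh carries all three axes, and each array declares how it is split across them.

```python
import os

# Assumption: when no accelerators are attached, ask XLA to expose 8 virtual
# CPU devices so the 3-axis mesh below is non-trivial. This must run before
# JAX is imported; it is skipped if XLA_FLAGS is already set.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size

# Hypothetical axis names: "data" (data parallel), "fsdp" (fully sharded
# data parallel), "tensor" (tensor parallel). With 8 devices this is a
# 2 x 2 x 2 mesh; otherwise every device is placed on the "data" axis.
mesh_shape = (2, 2, 2) if n == 8 else (n, 1, 1)
mesh = Mesh(devices.reshape(mesh_shape), axis_names=("data", "fsdp", "tensor"))

# Activations: the batch dimension is split across both data-like axes.
x = jax.device_put(
    jnp.ones((8 * n, 16)), NamedSharding(mesh, P(("data", "fsdp"), None))
)
# Weights: rows are split over "fsdp", columns over "tensor".
w = jax.device_put(
    jnp.ones((16, 32)), NamedSharding(mesh, P("fsdp", "tensor"))
)

@jax.jit
def forward(x, w):
    # Written as single-device code; the XLA SPMD partitioner inserts
    # whatever collectives the input shardings require.
    return x @ w

y = forward(x, w)
print(y.shape)
print(y.sharding)
```

With eight devices, `x` is split four ways along its batch dimension and `w` is split two ways along each of its dimensions. The forward function never mentions devices, and changing the shardings changes the parallelism strategy without touching the model code.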

Tutorials

  • Tutorial 1: Parallelization Basics
    • 0. Why JAX?
    • 1. Multi-Device Computation
    • 2. Sharded Matrices
    • 3. Collective Operations
    • 4. Notes on JAX Sharding
    • 5. Conclusion
  • Tutorial 2: Data Parallel and Fully Sharded Data Parallel Training
    • 0. Background
    • 1. Data Parallel
    • 2. Data Parallel Algorithm
    • 3. Example: 8-way Data Parallel Training with Plain JAX
    • 4. Data Parallel Training with Flax NNX
    • 5. Fully Sharded Data Parallelism (FSDP)
    • 6. Fully Sharded Data Parallel (FSDP) Training with Flax NNX
  • Tutorial 3: Tensor Parallel and Transformers Scaling
    • 0. Setup
    • 1. Tensor Parallelism
    • 2. Combining Parallelism Techniques
    • 3. Transformer Scaling with Tensor Parallelism
    • 4. Scaling Transformers in Flax
  • Tutorial 4: Up Next
    • 1. Other Forms of Parallelism in Distributed Deep Learning
    • 2. What to Read Next



© Copyright 2025, Shashank Shekhar.

Built with Sphinx using a theme provided by Read the Docs.