**How to train a model on 10k H100 GPUs?**

A quick note summarizing common knowledge among the large-scale training cohort. Oct 2nd, 2024. [https://soumith.ch/blog.html](https://soumith.ch/blog.html)
My friend François Fleuret asked the question in the title (his original tweet is quoted at the end of this post). I quickly jotted down what I think is fairly common knowledge among engineers working on large-scale training. There are three parts:

1. Fitting as large a network and as large a batch size as possible onto the 10k H100s, by parallelizing and using memory-saving tricks.
2. Communicating state between these GPUs as quickly as possible.
3. Recovering from failures (hardware, software, etc.) as quickly as possible.

# Fitting as large a network and as large a batch size as possible onto the 10k H100s

## Parallelizing

1. Parallelize over batches: each GPU processes its own slice of the batch (data parallel).
2. Parallelize within layers: split a single layer's computation across GPUs (tensor parallel).
3. Parallelize across layers: layers 1 to N live on GPU1, layers N+1 to N+10 live on GPU2, and so on (pipeline parallel).

Keep parallelizing until you are able to use all GPUs well, with maximum utilization. (Minimal code sketches of these techniques appear after the networking discussion below.)

## Checkpointing / compute vs. memory

* You need to save certain intermediate terms from the forward pass to compute the backward pass (`save_for_backward`). However, if the network is sufficiently large, it is more profitable to free these terms in order to fit a larger batch size, and recompute them when they are needed during backward (sketched below).
* Tricks like FSDP discard the parts of the weights that a given GPU does not own (to save memory), and gather those shards from the other GPUs right before they are needed.

# Communicating state between these GPUs as quickly as possible

## Communication overlap

When you need to communicate among GPUs, start the communication as soon as you can:

* Example: when the Nth layer has finished its backward pass, the GPUs holding the Nth layer can all-reduce its gradients while the (N-1)th layer is still computing backward (sketched below).

## Discover and leverage the underlying networking topology

Communicating large amounts of state (gradients, optimizer state) across multiple nodes is complicated. With synchronous SGD, this state has to be communicated in bursts, as quickly as possible. The cluster might have multiple layers of switches, RDMA (the ability to copy GPU memory directly to the NIC, bypassing CPU RAM entirely), and separate frontend and backend NICs (the frontend connects to storage like NFS; the backend connects GPUs to other GPUs in the cluster).

So it's important to leverage all of this information when running communication collectives like all-reduce or scatter/gather. All-reduce, for example, can be done in log(n) steps if you tree-reduce, and the constant factors, which change based on the type of fiber connecting one node to another in the networking tree, are important for overall time and latency. Libraries like NCCL do sophisticated discovery of the underlying networking topology and leverage it when running all-reduce and other collectives.

At this scale, we also have to tune the actual packet-routing algorithms in the switches and NICs to load-balance well. Did you know switches need significant HBM as well (not just GPUs)? As packets queue up, they have to queue up somewhere without getting dropped, and that's switch-level HBM. That's a whole level of sophistication that's also super cool.
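To make the parallelism part concrete, here is a minimal data-parallel sketch using PyTorch's `DistributedDataParallel`. It only covers the data-parallel axis (tensor and pipeline parallelism compose on top of it), and the model, sizes, and loss are made-up placeholders rather than anything from the original post.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes launch via torchrun, which sets RANK / LOCAL_RANK / WORLD_SIZE.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Placeholder model; in practice this is a large transformer.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 4096),
).cuda()

# DDP replicates the model on every rank and all-reduces gradients
# during backward, overlapping communication with computation.
model = DDP(model, device_ids=[local_rank])

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for step in range(10):
    x = torch.randn(8, 4096, device="cuda")  # each rank sees its own slice of the global batch
    loss = model(x).square().mean()           # dummy loss
    loss.backward()                           # gradients are all-reduced across ranks here
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```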
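The compute-vs-memory trade-off maps onto two standard PyTorch tools: `torch.utils.checkpoint` (free activations and recompute them during backward) and FSDP (shard weights and gather them right before they are needed). A rough sketch of both, assuming a process group is already initialized as in the previous snippet; the layer sizes and block count are placeholders.

```python
import torch
from torch.utils.checkpoint import checkpoint
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Block(torch.nn.Module):
    def __init__(self, dim=4096):
        super().__init__()
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        # Activation checkpointing: do not keep self.ff's intermediate
        # activations; recompute them during backward to save memory.
        return x + checkpoint(self.ff, x, use_reentrant=False)

model = torch.nn.Sequential(*[Block() for _ in range(8)]).cuda()

# FSDP shards parameters, gradients, and optimizer state across ranks,
# and all-gathers each weight shard just before it is used.
model = FSDP(model)
```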
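Communication overlap is easy to sketch by hand, even though DDP and FSDP already do it for you (with gradient bucketing and other refinements this toy version lacks): launch an asynchronous all-reduce as soon as each gradient is produced, and only wait on the handles before the optimizer step. This is purely an illustration of the idea and ignores edge cases such as gradient accumulation.

```python
import torch
import torch.distributed as dist

def attach_overlapped_allreduce(model: torch.nn.Module):
    """Kick off an async all-reduce for each gradient as soon as backward produces it."""
    pending = []

    def hook(grad):
        # async_op=True returns immediately with a handle; NCCL runs the
        # all-reduce on its own stream while autograd keeps computing the
        # backward pass of earlier layers.
        pending.append(dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True))
        return grad

    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(hook)

    def wait_for_all():
        # Call after loss.backward() and before optimizer.step();
        # remember to divide gradients by the world size afterwards.
        for work in pending:
            work.wait()
        pending.clear()

    return wait_for_all

# Usage: wait = attach_overlapped_allreduce(model); loss.backward(); wait(); optimizer.step()
```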
# Recovering from failures (hardware, software, etc.) as quickly as possible

At 10k-GPU scale, things fail all the time: GPUs, NICs, cables, etc. Some of these failures are easy to detect quickly; others you can only detect because one node isn't replying back in time (say, an NCCL all-reduce is stuck). We build various tools to monitor fleet health and remove failed nodes from the fleet as quickly as possible. This is quite hard.

Separately, at this scale you can get silent data corruption from memory bits flipping at random (basic physics, with the probability amplified by the sheer number of components), and you suddenly see loss explosions for no reason other than this random phenomenon. These happen at small scale too, but so infrequently that you barely notice. They are very hard to detect in software ahead of time. Some hardware has built-in circuitry that checksums results after computing them, so that if a bit flip occurs the hardware can throw an interrupt; H100s and previous NVIDIA GPUs don't have this feature.

To counter all these failures, you want to save your model state as frequently and as quickly as you can, and when a failure occurs, recover and continue as quickly as you can. Usually, we save model state to CPU memory really quickly in a separate thread, and in the background we save from CPU memory to disk or remote storage. We also save model state in shards (this is torch.distributed's checkpointing feature): not every GPU needs to save all of the model weights; each GPU saves only a portion, and the rest can be recovered from the other GPUs' shard checkpoints.
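Here is a minimal sketch of the "snapshot to CPU quickly, persist in the background" idea. The helper name and path are hypothetical, and production systems would use torch.distributed.checkpoint's asynchronous machinery rather than a hand-rolled thread.

```python
import threading
import torch

def async_save(model, optimizer, path="/tmp/ckpt_step.pt"):  # placeholder path
    """Snapshot state to CPU quickly, then write to disk off the critical path."""
    # 1) Fast, blocking part: copy tensors GPU -> CPU so training can resume
    #    mutating the live weights immediately.
    cpu_state = {
        "model": {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()},
        "optim": optimizer.state_dict(),  # NOTE: a fully safe version would deep-copy this too
    }

    # 2) Slow part: persist the CPU copy to disk (or remote storage) in the background.
    def _write():
        torch.save(cpu_state, path)

    t = threading.Thread(target=_write, daemon=True)
    t.start()
    return t  # join() it before exiting, or before overwriting the same path
```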
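And a sketch of sharded checkpointing with torch.distributed.checkpoint (recent PyTorch versions), where each rank writes only the shards it owns. The directory name is a placeholder, and a real FSDP setup would typically go through the distributed state_dict helpers rather than raw `state_dict()` calls.

```python
import torch.distributed.checkpoint as dcp

# Each rank saves only the shards it owns; on load, shards can be
# re-distributed even if the number of ranks has changed.
state = {"model": model.state_dict(), "optim": optimizer.state_dict()}
dcp.save(state, checkpoint_id="/checkpoints/step_1000")   # placeholder path
# ... later, possibly on a different number of ranks:
dcp.load(state, checkpoint_id="/checkpoints/step_1000")
```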
# Further Reading

1. The [llama3 paper](https://arxiv.org/abs/2407.21783) has a great infrastructure section.
2. Watch Meta's [AI Infra @ Scale talks](https://atscaleconference.com/events/ai-infra-scale-2024/).
3. Watch the [Faster Than Fast: Networking and Communication Optimizations for Llama 3](https://atscaleconference.com/videos/faster-than-fast-networking-and-communication-optimizations-for-llama-3/) talk from Networking @ Scale.
4. The [torchtitan codebase](https://github.com/pytorch/torchtitan) is easy to read and has advanced parallelism and various memory-saving techniques implemented.

The question that prompted this note:

> I realize I have absolutely no clue how you train a single model with e.g. 10k h100s. What is where, updated when, with what?
>
> -- François Fleuret (@francoisfleuret), September 30, 2024