Offload Training
This document introduces the Offload Training feature in DiffSynth-Studio, which significantly reduces GPU memory usage during training by moving model weights layer-by-layer between CPU and GPU.
Note: Offload Training currently supports single-GPU training only and is not compatible with multi-GPU (DDP) setups.
What is Offload Training
When training large-scale models (e.g., Qwen-Image with 60 layers, Wan2.1-14B with 40 layers) in the usual way, the weights of all layers reside on the GPU simultaneously, consuming tens of GB of memory for the weights alone. The core idea of Offload Training is to keep on the GPU only the weights of the module currently being computed, and to offload them back to the CPU immediately after computation. This reduces weight memory from O(N × params_per_layer) to O(params_per_layer), where N is the number of layers.
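To make the complexity claim concrete, here is the weight-memory arithmetic for a hypothetical model. The figures (40 layers of 350M parameters each, stored in bf16 at 2 bytes per parameter) are illustrative assumptions chosen to roughly match a 14B-parameter, 40-layer model. Non-layer weights, gradients, optimizer state and activations are ignored.

```python
def weight_memory_gb(resident_layers: int, params_per_layer: int, bytes_per_param: int = 2) -> float:
    """Weight memory in GB (10**9 bytes) for the layers resident on the GPU at the same time."""
    return resident_layers * params_per_layer * bytes_per_param / 1e9

# Hypothetical figures: 40 layers x 350M params = 14B params, bf16 weights
NUM_LAYERS = 40
PARAMS_PER_LAYER = 350_000_000

all_on_gpu = weight_memory_gb(NUM_LAYERS, PARAMS_PER_LAYER)  # O(N x params_per_layer)
offloaded = weight_memory_gb(1, PARAMS_PER_LAYER)            # O(params_per_layer)
print(f"all layers resident: {all_on_gpu:.1f} GB, one layer resident: {offloaded:.1f} GB")
```

Under these assumptions the weights drop from 28.0 GB to 0.7 GB on the GPU. In practice the hooked unit is usually a leaf module, which is smaller than a full layer.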
This feature is implemented via PyTorch’s Module Hook mechanism and requires no modifications to model code.
How It Works
Core Mechanism
OffloadTrainingManager scans the model and registers 4 hooks on each managed module. During one training step, a managed module goes through the following sequence:
- forward_pre_hook → Load module weights from CPU to GPU (onload)
- module.forward() → Normal forward computation
- forward_hook → Offload module weights from GPU back to CPU (offload)
- backward_pre_hook → Reload module weights from CPU to GPU (onload)
- Autograd backward through the module → Compute gradients
- backward_hook → Offload module weights back to CPU (offload)
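The sketch below shows when these four hooks fire, using the standard PyTorch module hook APIs register_forward_pre_hook, register_forward_hook, register_full_backward_pre_hook and register_full_backward_hook. register_full_backward_pre_hook requires PyTorch 2.0 or newer. The hooks here only log an event and do not move any weights. This is not DiffSynth-Studio's implementation, and attach_tracing_hooks is a hypothetical helper.

```python
import torch

def attach_tracing_hooks(module: torch.nn.Module, name: str, log: list) -> None:
    """Record when each of the four hook types fires (stand-ins for onload/offload)."""
    module.register_forward_pre_hook(lambda m, args: log.append(f"{name} forward_pre_hook -> onload"))
    module.register_forward_hook(lambda m, args, output: log.append(f"{name} forward_hook -> offload"))
    module.register_full_backward_pre_hook(lambda m, grad_output: log.append(f"{name} backward_pre_hook -> onload"))
    module.register_full_backward_hook(lambda m, grad_input, grad_output: log.append(f"{name} backward_hook -> offload"))

log = []
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
for child_name, layer in model.named_children():
    attach_tracing_hooks(layer, f"layer{child_name}", log)

# The input requires grad so that each full backward hook fires after its module's gradients are computed
x = torch.randn(2, 8, requires_grad=True)
model(x).sum().backward()
print("\n".join(log))
```

The forward pass visits layer0 then layer1, and the backward pass visits them in reverse order. Each module therefore needs its weights on the GPU twice per step, once in each direction. For how full backward hooks behave when no module input requires grad, see the PyTorch hook documentation.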
Parameter and Buffer Classification
Parameters get different offload strategies depending on whether they are trainable, and buffers get a strategy of their own:
| Type | Offloader Class | Behavior |
|---|---|---|
| Non-trainable (requires_grad=False) | StaticParamOffloader | Copies weights to pre-allocated pinned memory at init, keeping a permanent CPU copy, and replaces param.data with an empty GPU placeholder (freeing GPU memory); onload asynchronously copies from CPU to GPU; offload reassigns param.data to the placeholder (no PCIe transfer back) |
| Trainable + enable_optimizer_cpu_offload=True | TrainableParamOffloader | Weights change during training, so no static copy is kept; onload/offload use param.data.to(device) with an actual data transfer; param.grad is also moved to CPU after backward |
| Trainable + enable_optimizer_cpu_offload=False | AlwaysOnGPUParamOffloader | Moves parameters to GPU at init and never offloads them; suitable for LoRA training (small number of trainable params) |
| Module buffers (e.g., BatchNorm's running_mean/running_var) | BufferOffloader | Similar to StaticParamOffloader: copies the buffer to pinned memory at init; onload asynchronously copies from CPU to GPU; offload reassigns module._buffers[name] back to the CPU copy |
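The following sketch illustrates the StaticParamOffloader row. It keeps a permanent CPU copy of the weight, pinned when CUDA is available. It swaps param.data for an empty placeholder on the target device, onloads with a non-blocking copy, and offloads by reassigning the placeholder without copying anything back. The class name StaticParamOffloaderSketch is hypothetical. The real class may differ in details such as stream handling and drawing its pinned storage from the memory pool. On a CPU-only machine the "transfer" is a no-op, but the placeholder swap still shows up in the parameter's shape.

```python
import torch

class StaticParamOffloaderSketch:
    """Hypothetical re-creation of the StaticParamOffloader behavior for one frozen parameter."""

    def __init__(self, param: torch.nn.Parameter, device: torch.device):
        self.param = param
        self.device = device
        # Permanent CPU copy; pinned (page-locked) memory allows non-blocking CPU->GPU copies
        use_pinned = device.type == "cuda"
        self.cpu_copy = torch.empty(param.shape, dtype=param.dtype, device="cpu", pin_memory=use_pinned)
        self.cpu_copy.copy_(param.data)
        # Zero-element placeholder on the target device: holds no weight memory
        self.placeholder = torch.empty(0, dtype=param.dtype, device=device)
        param.data = self.placeholder

    def onload(self) -> None:
        self.param.data = self.cpu_copy.to(self.device, non_blocking=True)

    def offload(self) -> None:
        # Weights are frozen, so nothing is copied back; the device tensor is simply dropped
        self.param.data = self.placeholder


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = torch.nn.Linear(4, 4).requires_grad_(False)
offloader = StaticParamOffloaderSketch(layer.weight, device)
print(layer.weight.shape)  # torch.Size([0]) while offloaded
offloader.onload()
print(layer.weight.shape)  # torch.Size([4, 4]) while onloaded
offloader.offload()
```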
Pinned Memory Pool
StaticParamOffloader and BufferOffloader allocate a pinned-memory copy on the CPU for each non-trainable parameter or buffer. Pinned (page-locked) memory enables asynchronous, non-blocking CPU→GPU transfers, which are typically much faster than transfers from regular pageable memory.
Problem: PyTorch's pin_memory() allocates through CachingHostAllocator, which rounds each allocation up to the next power of two. For example, a 17MB tensor actually occupies 32MB. Large models contain thousands of parameter tensors, so pinning each one independently with pin_memory() wastes a large amount of memory (measured inflation: 50% to 100%).
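The rounding overhead is simple arithmetic. The sketch below models the power-of-two rounding described above. It does not call the real allocator, whose exact policy may vary between PyTorch versions. The tensor sizes are hypothetical.

```python
MiB = 1024 * 1024

def rounded_pinned_size(nbytes: int) -> int:
    """Round an allocation up to the next power of two, as described above."""
    return 1 << (nbytes - 1).bit_length() if nbytes > 1 else 1

# Hypothetical tensor sizes, e.g. a few weight matrices of one block
sizes = [17 * MiB, 40 * MiB, 9 * MiB, 64 * MiB]
requested = sum(sizes)
allocated = sum(rounded_pinned_size(s) for s in sizes)
print(f"requested {requested // MiB} MiB, allocated {allocated // MiB} MiB "
      f"({allocated / requested - 1:.0%} overhead)")
```

For these sizes this prints `requested 130 MiB, allocated 176 MiB (35% overhead)`. A size that is already a power of two, such as the 64 MiB tensor, incurs no waste.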
Solution: PinnedArenaPool pre-allocates a few large blocks of pinned memory (arenas, i.e., large pre-allocated regions from which many small objects are carved out). It then uses bump-pointer allocation to pack each tensor into them compactly, which avoids the per-tensor rounding waste:
- from_model() scans all non-trainable parameters and buffers in the model and computes their total size
- The total size is decomposed into several power-of-two-sized chunks (each chunk is a PinnedBuffer)
- Each allocation probes the chunks in order for remaining space, then advances the bump pointer (only 64-byte alignment is applied, so there is no rounding waste)
- New chunks are grown automatically when space is insufficient
- On exceptions, allocation falls back to per-tensor pin_memory()
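These allocation steps can be simulated in pure Python by tracking byte offsets instead of real pinned buffers. This is a hypothetical sketch, and ArenaSketch, power_of_two_chunks and align_up are illustrative names. It makes three assumptions: the total size is split by its binary decomposition, chunks are probed first-fit in order, and growth adds one power-of-two chunk large enough for the failing request.

```python
ALIGN = 64  # bytes

def align_up(n: int, alignment: int = ALIGN) -> int:
    return (n + alignment - 1) // alignment * alignment

def power_of_two_chunks(total: int) -> list[int]:
    """Split `total` into power-of-two sizes, largest first (binary decomposition)."""
    return [1 << bit for bit in reversed(range(total.bit_length())) if total >> bit & 1]

class ArenaSketch:
    """Hypothetical bump-pointer arena: tracks offsets only, allocates no real memory."""

    def __init__(self, tensor_sizes: list[int]):
        total = sum(align_up(s) for s in tensor_sizes)
        self.chunks = power_of_two_chunks(total)  # capacity of each chunk in bytes
        self.used = [0] * len(self.chunks)        # bump pointer of each chunk

    def allocate(self, nbytes: int) -> tuple[int, int]:
        """Return (chunk_index, byte_offset) for a tensor of `nbytes` bytes."""
        size = align_up(nbytes)
        for i, capacity in enumerate(self.chunks):
            if self.used[i] + size <= capacity:
                offset = self.used[i]
                self.used[i] += size  # bump the pointer; no power-of-two rounding
                return i, offset
        # No chunk has room: grow by one power-of-two chunk that fits this request
        self.chunks.append(1 << (size - 1).bit_length())
        self.used.append(size)
        return len(self.chunks) - 1, 0


sizes = [3000, 1000, 64, 10]
arena = ArenaSketch(sizes)
print(arena.chunks)                        # [4096, 64]
print([arena.allocate(s) for s in sizes])  # [(0, 0), (0, 3008), (0, 4032), (1, 0)]
```

The four requests fit in 4160 bytes of chunk capacity. Rounding each request to a power of two individually would reserve 4096 + 1024 + 64 + 16 = 5200 bytes.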
Gradient Checkpointing Compatibility
Gradient checkpointing re-executes the forward pass during backward to recompute activations, which triggers forward_hook a second time. This is handled by the _in_recompute set:
- First forward: weights are offloaded as usual, and the module is added to _in_recompute
- Recomputed forward (during backward): the module is found in _in_recompute, so the offload is skipped and the weights stay on the GPU for backward
- When after_backward() is called: _in_recompute is cleared in preparation for the next step
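The bookkeeping described above amounts to a small state machine. The sketch below models it with module names as plain strings. RecomputeTrackerSketch is a hypothetical name, and the real manager tracks module objects and performs the actual onload/offload.

```python
class RecomputeTrackerSketch:
    """Hypothetical model of the _in_recompute logic; strings stand in for modules."""

    def __init__(self):
        self._in_recompute = set()
        self.events = []

    def forward_hook(self, module_name: str) -> None:
        if module_name in self._in_recompute:
            # Second forward in this step: activations are being recomputed for backward
            self.events.append(f"{module_name}: skip offload, keep on GPU")
        else:
            # First forward in this step: offload as usual and remember the module
            self.events.append(f"{module_name}: offload")
            self._in_recompute.add(module_name)

    def after_backward(self) -> None:
        # Reset so that the next step's first forward offloads again
        self._in_recompute.clear()


tracker = RecomputeTrackerSketch()
tracker.forward_hook("block0")  # first forward
tracker.forward_hook("block0")  # recomputation inside backward
tracker.after_backward()
tracker.forward_hook("block0")  # first forward of the next step
print(tracker.events)
```

Without the after_backward() reset, every forward in the next step would be mistaken for a recomputation and no weights would be offloaded.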
Hook Registration Granularity
OffloadTrainingManager registers hooks at leaf module granularity by default (nn.Linear, nn.LayerNorm, etc.), meaning each leaf module is independently onloaded/offloaded. Additionally, “orphan parameters” and “orphan buffers” not managed by any leaf module are automatically collected and hooked.
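Leaf modules and orphan tensors can be found with standard nn.Module introspection: named_modules(), children(), and named_parameters/named_buffers with recurse=False. In the sketch below, find_hook_targets and Block are hypothetical names. The real manager may, for example, skip leaf modules that hold no tensors.

```python
import torch

def find_hook_targets(model: torch.nn.Module):
    """Return leaf-module names and orphan parameter/buffer names (owned by non-leaf modules)."""
    leaves, orphans = [], []
    for name, module in model.named_modules():
        if next(module.children(), None) is None:
            leaves.append(name)
            continue
        # Tensors registered directly on a container are not covered by any leaf module's hooks
        own = list(module.named_parameters(recurse=False)) + list(module.named_buffers(recurse=False))
        orphans += [f"{name}.{t}" if name else t for t, _ in own]
    return leaves, orphans


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.norm = torch.nn.LayerNorm(4)
        self.gate = torch.nn.Parameter(torch.ones(4))     # registered on Block itself
        self.register_buffer("freqs", torch.arange(4.0))  # likewise

model = torch.nn.Sequential(Block(), Block())
leaves, orphans = find_hook_targets(model)
print(leaves)   # ['0.proj', '0.norm', '1.proj', '1.norm']
print(orphans)  # ['0.gate', '0.freqs', '1.gate', '1.freqs']
```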
Experimental: The cpu_offload_split_threshold parameter (unit: MB) adjusts the hook registration granularity. When it is set, modules whose total parameter size exceeds the threshold are recursively split into their children, while modules below the threshold are hooked as a whole. In the current version this feature may not be compatible with all model architectures, and it is disabled by default.
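The splitting rule can be sketched as a short recursion. The assumptions are labeled here because the source does not specify them. split_by_threshold is a hypothetical function, MB is taken to mean 2**20 bytes, and a module with no children is always hooked as one unit. The sketch also ignores parameters registered directly on a module that gets split, which the real manager would still need to cover.

```python
import torch

def split_by_threshold(module: torch.nn.Module, threshold_mb: float, prefix: str = "") -> list:
    """Names of the modules that would each receive one set of hooks (hypothetical sketch)."""
    size_mb = sum(p.numel() * p.element_size() for p in module.parameters()) / 2**20
    children = list(module.named_children())
    if size_mb <= threshold_mb or not children:
        # Small enough (or not divisible further): hook the whole module as one unit
        return [prefix or "<root>"]
    groups = []
    for name, child in children:
        groups += split_by_threshold(child, threshold_mb, f"{prefix}.{name}" if prefix else name)
    return groups


# Two float32 Linear(1024, 1024) layers, about 4 MB each and about 8 MB in total
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))
print(split_by_threshold(model, threshold_mb=10))  # ['<root>']
print(split_by_threshold(model, threshold_mb=6))   # ['0', '1']
```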
Training Loop Integration
Execution flow in runner.py:
```python
# When enable_model_cpu_offload=True:
# 1. The model does NOT call model.to(device); it stays on CPU
# 2. Only the optimizer, dataloader, and scheduler are prepared (the model is NOT prepared)
# 3. An OffloadTrainingManager is created, which auto-registers hooks on the model

# Training loop:
for data in dataloader:
    loss = model(data)
    accelerator.backward(loss)
    offload_manager.after_backward()  # Clear recompute marks + move gradients to CPU
    optimizer.step()
    optimizer.zero_grad()
```
Usage
Parameters
| Parameter | Default | Description |
|---|---|---|
| --enable_model_cpu_offload | False | Enable layer-wise offload training |
| --enable_optimizer_cpu_offload | False | Used with --enable_model_cpu_offload; moves trainable params and optimizer to CPU |
| --cpu_offload_split_threshold | None | Experimental (unit: MB); modules above this threshold are recursively split |
Parameter Combinations
| Scenario | --enable_model_cpu_offload | --enable_optimizer_cpu_offload | Effect |
|---|---|---|---|
| Default training | ❌ | ❌ | All weights and optimizer on GPU |
| Offload non-trainable params | ✅ | ❌ | Non-trainable params offloaded layer-by-layer; trainable params and optimizer stay on GPU |
| Offload all params | ✅ | ✅ | All params offloaded layer-by-layer; gradients and optimizer run on CPU |
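With --enable_model_cpu_offload set, the second flag decides which offloader each trainable parameter receives, following the table under Parameter and Buffer Classification. The small function below restates that mapping for illustration. pick_offloader is a hypothetical name, not the library's dispatch code.

```python
def pick_offloader(is_buffer: bool, requires_grad: bool, enable_optimizer_cpu_offload: bool) -> str:
    """Offloader class per the classification table, assuming --enable_model_cpu_offload is on."""
    if is_buffer:
        return "BufferOffloader"
    if not requires_grad:
        return "StaticParamOffloader"
    if enable_optimizer_cpu_offload:
        return "TrainableParamOffloader"
    return "AlwaysOnGPUParamOffloader"


# LoRA training: frozen base weights plus a small set of trainable LoRA weights
for optimizer_offload in (False, True):
    base = pick_offloader(False, False, optimizer_offload)
    lora = pick_offloader(False, True, optimizer_offload)
    print(f"optimizer offload={optimizer_offload}: base -> {base}, LoRA -> {lora}")
```

Frozen weights are offloaded in both scenarios. The flag only changes whether the trainable weights (and with them the gradients and optimizer state) stay on the GPU.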
Example
Simply add --enable_model_cpu_offload to your existing training command. Example with Qwen-Image LoRA training:
```shell
accelerate launch examples/qwen_image/model_training/train.py \
  --dataset_base_path data/example_dataset \
  --dataset_metadata_path data/example_dataset/metadata.json \
  --max_pixels 1048576 \
  --dataset_repeat 50 \
  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Qwen-Image_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
  --lora_rank 32 \
  --use_gradient_checkpointing \
  --dataset_num_workers 8 \
  --find_unused_parameters \
  --enable_model_cpu_offload
```
For full offload (optimizer also on CPU), add --enable_optimizer_cpu_offload:
```shell
  --enable_model_cpu_offload \
  --enable_optimizer_cpu_offload
```
Compatibility
| Feature | Compatible | Notes |
|---|---|---|
| Gradient Checkpointing | ✅ | _in_recompute mechanism handles recomputation |
| Accelerate DDP (multi-GPU) | ⚠️ | With enable_model_cpu_offload, the model is not wrapped by DDP (accelerator.prepare(model) is not called), so no gradient all-reduce is performed. Each GPU trains independently without gradient synchronization, and multi-GPU training is not guaranteed to work correctly |
| Split Training | ✅ | launch_data_process_task also supports --enable_model_cpu_offload |
| DeepSpeed | ❌ | ZeRO's parameter gathering conflicts with hooks |
Notes
- With --enable_model_cpu_offload enabled, the model never calls model.to(device); its weights are managed entirely by hooks
- Training is slower because of CPU↔GPU transfers (typically 2-10x slower), and larger models see a greater slowdown; this mode is intended for memory-constrained scenarios
- Using it together with --use_gradient_checkpointing is recommended to further reduce activation memory
- --enable_optimizer_cpu_offload only supports a gradient accumulation step count of 1 (--gradient_accumulation_steps 1)
Integrating Offload Training Module in Other Codebases
The Offload Training module is relatively self-contained, so developers can integrate it into other codebases. Below is a plain PyTorch training script that uses 4GB of VRAM:
```python
import torch
from tqdm import tqdm

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(4096, 4096) for _ in range(10))

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(torch.nn.functional.layer_norm(x, (4096,)))
        return x

model = ToyModel().to("cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

pbar = tqdm(range(100))
for i in pbar:
    x = torch.randn((512, 4096), device="cuda")
    y = x + 1
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    pbar.set_postfix(loss=f"{loss.item():.4f}")
```
With Offload Training enabled, VRAM usage drops to 1.4GB:
```python
import torch
from tqdm import tqdm
from diffsynth.core import OffloadTrainingManager

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(4096, 4096) for _ in range(10))

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(torch.nn.functional.layer_norm(x, (4096,)))
        return x

# Keep the model on CPU; the manager's hooks move weights to the GPU on demand
model = ToyModel().to("cpu")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
offload_manager = OffloadTrainingManager(model, target_device="cuda", enable_optimizer_cpu_offload=True)

pbar = tqdm(range(100))
for i in pbar:
    x = torch.randn((512, 4096), device="cuda")
    y = x + 1
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    loss.backward()
    # Clear recompute marks and move gradients to CPU before the optimizer step
    offload_manager.after_backward()
    optimizer.step()
    optimizer.zero_grad()
    pbar.set_postfix(loss=f"{loss.item():.4f}")
```