Skip to content

Commit 4ec4d9c

Browse files
authored
Merge pull request #1354 from mi804/low_vram_training_ds
low vram training with deepspeed zero3
1 parent 7a80f10 commit 4ec4d9c

File tree

7 files changed

+346
-14
lines changed

7 files changed

+346
-14
lines changed

diffsynth/core/gradient/gradient_checkpoint.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,50 @@
11
import torch
22

33

4+
try:
5+
import deepspeed
6+
_HAS_DEEPSPEED = True
7+
except ModuleNotFoundError:
8+
_HAS_DEEPSPEED = False
9+
10+
411
def create_custom_forward(module):
    """Wrap *module* so its forward pass can be handed to a checkpoint API.

    Returns a plain callable that forwards every positional and keyword
    argument straight to ``module`` unchanged.
    """
    def _forward(*args, **kwargs):
        return module(*args, **kwargs)

    return _forward
815

916

17+
def create_custom_forward_use_reentrant(module):
    """Wrap *module* for reentrant checkpointing backends.

    Unlike :func:`create_custom_forward`, only positional arguments are
    forwarded — reentrant checkpoint implementations (e.g. deepspeed's)
    pass a flat tuple of inputs and do not support keyword arguments.
    """
    def _forward(*inputs):
        return module(*inputs)

    return _forward
21+
22+
23+
def judge_args_requires_grad(*args):
    """Return ``True`` if any argument is a tensor that requires grad.

    Non-tensor arguments are ignored; nested containers are not inspected.
    Used to decide whether a checkpointed forward would actually record a
    gradient, since checkpointing a grad-free call is wasted work.
    """
    return any(
        isinstance(candidate, torch.Tensor) and candidate.requires_grad
        for candidate in args
    )
28+
29+
1030
def gradient_checkpoint_forward(
1131
model,
1232
use_gradient_checkpointing,
1333
use_gradient_checkpointing_offload,
1434
*args,
1535
**kwargs,
1636
):
37+
if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():
38+
all_args = args + tuple(kwargs.values())
39+
if not judge_args_requires_grad(*all_args):
40+
# get the first grad_enabled tensor from un_checkpointed forward
41+
model_output = model(*args, **kwargs)
42+
else:
43+
model_output = deepspeed.checkpointing.checkpoint(
44+
create_custom_forward_use_reentrant(model),
45+
*all_args,
46+
)
47+
return model_output
1748
if use_gradient_checkpointing_offload:
1849
with torch.autograd.graph.save_on_cpu():
1950
model_output = torch.utils.checkpoint.checkpoint(

diffsynth/diffusion/runner.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def launch_training_task(
2929
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
3030
model.to(device=accelerator.device)
3131
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
32-
32+
initialize_deepspeed_gradient_checkpointing(accelerator)
3333
for epoch_id in range(num_epochs):
3434
for data in tqdm(dataloader):
3535
with accelerator.accumulate(model):
@@ -70,3 +70,19 @@ def launch_data_process_task(
7070
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
7171
data = model(data)
7272
torch.save(data, save_path)
73+
74+
75+
def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
    """Configure deepspeed activation checkpointing from the accelerator config.

    If *accelerator* runs with a deepspeed plugin whose config contains an
    ``activation_checkpointing`` section, ``deepspeed.checkpointing`` is
    configured from it so that checkpointed forwards can be routed through
    deepspeed (PyTorch's native gradient checkpointing is incompatible with
    ZeRO stage 3). Otherwise this function is a no-op (a notice is printed
    when a deepspeed plugin exists but has no ``activation_checkpointing``
    section).

    Args:
        accelerator: The prepared ``accelerate.Accelerator`` instance.
    """
    deepspeed_plugin = getattr(accelerator.state, "deepspeed_plugin", None)
    if deepspeed_plugin is None:
        return
    ds_config = deepspeed_plugin.deepspeed_config
    act_config = ds_config.get("activation_checkpointing")
    if act_config is None:
        print(
            "activation_checkpointing was not found in the deepspeed config; "
            "skipping deepspeed gradient checkpointing initialization."
        )
        return
    # Imported lazily so that runs without a deepspeed plugin never require
    # the package to be installed.
    import deepspeed
    deepspeed.checkpointing.configure(
        mpu_=None,
        partition_activations=act_config.get("partition_activations", False),
        checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
        contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False),
    )

docs/en/Pipeline_Usage/Model_Training.md

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ Similar to [model loading during inference](../Pipeline_Usage/Model_Inference.md
123123
124124
<details>
125125
126-
<details>
127126
128127
<summary>Load models from local file paths</summary>
129128
@@ -244,4 +243,119 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
244243
* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](../QA.md#why-doesnt-the-training-framework-support-batch-size--1)
245244
* Some models contain redundant parameters. For example, the text encoding part of the last layer of Qwen-Image's DiT part. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.
246245
* The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and manually closing the training program after the effect converges.
247-
* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.
246+
* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.
247+
248+
## Low VRAM Training
249+
250+
If you want to complete LoRA model training on a GPU with low VRAM, you can combine [Two-Stage Split Training](../Training/Split_Training.md) with `deepspeed_zero3_offload` training. First, split the preprocessing steps into the first stage and store the computed results on the hard disk. Second, read these results from the disk and train the denoising model. With `deepspeed_zero3_offload`, the training parameters and optimizer states are offloaded to the CPU or disk. We provide examples for some models, primarily by specifying the `deepspeed` configuration via `--config_file`.
251+
252+
Please note that the `deepspeed_zero3_offload` mode is incompatible with PyTorch's native gradient checkpointing mechanism. To address this, we have adapted the `checkpointing` interface of `deepspeed`. Users need to fill in the `activation_checkpointing` field of the `deepspeed` configuration to enable gradient checkpointing.
253+
254+
Below is the script for low VRAM model training for the Qwen-Image model:
255+
256+
```shell
257+
accelerate launch examples/qwen_image/model_training/train.py \
258+
--dataset_base_path data/example_image_dataset \
259+
--dataset_metadata_path data/example_image_dataset/metadata.csv \
260+
--max_pixels 1048576 \
261+
--dataset_repeat 1 \
262+
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
263+
--learning_rate 1e-4 \
264+
--num_epochs 5 \
265+
--remove_prefix_in_ckpt "pipe.dit." \
266+
--output_path "./models/train/Qwen-Image_lora-splited-cache" \
267+
--lora_base_model "dit" \
268+
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
269+
--lora_rank 32 \
270+
--task "sft:data_process" \
271+
--use_gradient_checkpointing \
272+
--dataset_num_workers 8 \
273+
--find_unused_parameters
274+
275+
accelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \
276+
--dataset_base_path "./models/train/Qwen-Image_lora-splited-cache" \
277+
--max_pixels 1048576 \
278+
--dataset_repeat 50 \
279+
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
280+
--learning_rate 1e-4 \
281+
--num_epochs 5 \
282+
--remove_prefix_in_ckpt "pipe.dit." \
283+
--output_path "./models/train/Qwen-Image_lora" \
284+
--lora_base_model "dit" \
285+
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
286+
--lora_rank 32 \
287+
--task "sft:train" \
288+
--use_gradient_checkpointing \
289+
--dataset_num_workers 8 \
290+
--find_unused_parameters \
291+
--initialize_model_on_cpu
292+
```
293+
294+
The configurations for `accelerate` and `deepspeed` are as follows:
295+
296+
```yaml
297+
compute_environment: LOCAL_MACHINE
298+
debug: true
299+
deepspeed_config:
300+
deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json
301+
zero3_init_flag: true
302+
distributed_type: DEEPSPEED
303+
downcast_bf16: 'no'
304+
enable_cpu_affinity: false
305+
machine_rank: 0
306+
main_training_function: main
307+
num_machines: 1
308+
num_processes: 1
309+
rdzv_backend: static
310+
same_network: true
311+
tpu_env: []
312+
tpu_use_cluster: false
313+
tpu_use_sudo: false
314+
use_cpu: false
315+
```
316+
317+
```json
318+
{
319+
"fp16": {
320+
"enabled": "auto",
321+
"loss_scale": 0,
322+
"loss_scale_window": 1000,
323+
"initial_scale_power": 16,
324+
"hysteresis": 2,
325+
"min_loss_scale": 1
326+
},
327+
"bf16": {
328+
"enabled": "auto"
329+
},
330+
"zero_optimization": {
331+
"stage": 3,
332+
"offload_optimizer": {
333+
"device": "cpu",
334+
"pin_memory": true
335+
},
336+
"offload_param": {
337+
"device": "cpu",
338+
"pin_memory": true
339+
},
340+
"overlap_comm": false,
341+
"contiguous_gradients": true,
342+
"sub_group_size": 1e9,
343+
"reduce_bucket_size": 5e7,
344+
"stage3_prefetch_bucket_size": 5e7,
345+
"stage3_param_persistence_threshold": 1e5,
346+
"stage3_max_live_parameters": 1e8,
347+
"stage3_max_reuse_distance": 1e8,
348+
"stage3_gather_16bit_weights_on_model_save": true
349+
},
350+
"activation_checkpointing": {
351+
"partition_activations": false,
352+
"cpu_checkpointing": false,
353+
"contiguous_memory_optimization": false
354+
},
355+
"gradient_accumulation_steps": "auto",
356+
"gradient_clipping": "auto",
357+
"train_batch_size": "auto",
358+
"train_micro_batch_size_per_gpu": "auto",
359+
"wall_clock_breakdown": false
360+
}
361+
```

docs/zh/Pipeline_Usage/Model_Training.md

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,116 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
243243
* 少数模型包含冗余参数,例如 Qwen-Image 的 DiT 部分最后一层的文本编码部分,在训练这些模型时,需设置 `--find_unused_parameters` 避免在多 GPU 训练中报错。出于对开源社区模型兼容性的考虑,我们不打算删除这些冗余参数。
244244
* Diffusion 模型的损失函数值与实际效果的关系不大,因此我们在训练过程中不会记录损失函数值。我们建议把 `--num_epochs` 设置为足够大的数值,边训边测,直至效果收敛后手动关闭训练程序。
245245
* `--use_gradient_checkpointing` 通常是开启的,除非 GPU 显存足够;`--use_gradient_checkpointing_offload` 则按需开启,详见 [`diffsynth.core.gradient`](../API_Reference/core/gradient.md)
246+
247+
## 低显存训练
248+
如果想在低显存显卡上完成 LoRA 模型训练,可以同时采用 [两阶段拆分训练](../Training/Split_Training.md) 和 `deepspeed_zero3_offload` 训练。首先,将前处理过程拆分到第一阶段,将计算结果存储到硬盘中。其次,在第二阶段从硬盘中读取这些结果并进行去噪模型的训练,训练通过采用 `deepspeed_zero3_offload`,将训练参数和优化器状态 offload 到 CPU 或者磁盘上。我们为部分模型提供了样例,主要是通过 `--config_file` 指定 `deepspeed` 配置。
249+
250+
需要注意的是,`deepspeed_zero3_offload` 模式与 `pytorch` 原生的梯度检查点机制不兼容,我们为此对 `deepspeed` 的 `checkpointing` 接口做了适配。用户需要在 `deepspeed` 配置中填写 `activation_checkpointing` 字段以启用梯度检查点。
251+
252+
以下为 Qwen-Image 模型的低显存模型训练脚本:
253+
```shell
254+
accelerate launch examples/qwen_image/model_training/train.py \
255+
--dataset_base_path data/example_image_dataset \
256+
--dataset_metadata_path data/example_image_dataset/metadata.csv \
257+
--max_pixels 1048576 \
258+
--dataset_repeat 1 \
259+
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
260+
--learning_rate 1e-4 \
261+
--num_epochs 5 \
262+
--remove_prefix_in_ckpt "pipe.dit." \
263+
--output_path "./models/train/Qwen-Image_lora-splited-cache" \
264+
--lora_base_model "dit" \
265+
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
266+
--lora_rank 32 \
267+
--task "sft:data_process" \
268+
--use_gradient_checkpointing \
269+
--dataset_num_workers 8 \
270+
--find_unused_parameters
271+
272+
accelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \
273+
--dataset_base_path "./models/train/Qwen-Image_lora-splited-cache" \
274+
--max_pixels 1048576 \
275+
--dataset_repeat 50 \
276+
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
277+
--learning_rate 1e-4 \
278+
--num_epochs 5 \
279+
--remove_prefix_in_ckpt "pipe.dit." \
280+
--output_path "./models/train/Qwen-Image_lora" \
281+
--lora_base_model "dit" \
282+
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
283+
--lora_rank 32 \
284+
--task "sft:train" \
285+
--use_gradient_checkpointing \
286+
--dataset_num_workers 8 \
287+
--find_unused_parameters \
288+
--initialize_model_on_cpu
289+
```
290+
291+
其中,`accelerate` 和 `deepspeed` 的配置文件如下:
292+
293+
```yaml
294+
compute_environment: LOCAL_MACHINE
295+
debug: true
296+
deepspeed_config:
297+
deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json
298+
zero3_init_flag: true
299+
distributed_type: DEEPSPEED
300+
downcast_bf16: 'no'
301+
enable_cpu_affinity: false
302+
machine_rank: 0
303+
main_training_function: main
304+
num_machines: 1
305+
num_processes: 1
306+
rdzv_backend: static
307+
same_network: true
308+
tpu_env: []
309+
tpu_use_cluster: false
310+
tpu_use_sudo: false
311+
use_cpu: false
312+
```
313+
314+
```json
315+
{
316+
"fp16": {
317+
"enabled": "auto",
318+
"loss_scale": 0,
319+
"loss_scale_window": 1000,
320+
"initial_scale_power": 16,
321+
"hysteresis": 2,
322+
"min_loss_scale": 1
323+
},
324+
"bf16": {
325+
"enabled": "auto"
326+
},
327+
"zero_optimization": {
328+
"stage": 3,
329+
"offload_optimizer": {
330+
"device": "cpu",
331+
"pin_memory": true
332+
},
333+
"offload_param": {
334+
"device": "cpu",
335+
"pin_memory": true
336+
},
337+
"overlap_comm": false,
338+
"contiguous_gradients": true,
339+
"sub_group_size": 1e9,
340+
"reduce_bucket_size": 5e7,
341+
"stage3_prefetch_bucket_size": 5e7,
342+
"stage3_param_persistence_threshold": 1e5,
343+
"stage3_max_live_parameters": 1e8,
344+
"stage3_max_reuse_distance": 1e8,
345+
"stage3_gather_16bit_weights_on_model_save": true
346+
},
347+
"activation_checkpointing": {
348+
"partition_activations": false,
349+
"cpu_checkpointing": false,
350+
"contiguous_memory_optimization": false
351+
},
352+
"gradient_accumulation_steps": "auto",
353+
"gradient_clipping": "auto",
354+
"train_batch_size": "auto",
355+
"train_micro_batch_size_per_gpu": "auto",
356+
"wall_clock_breakdown": false
357+
}
358+
```

examples/qwen_image/model_training/special/low_vram_training/Qwen-Image-LoRA.sh

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,32 @@ accelerate launch examples/qwen_image/model_training/train.py \
44
--max_pixels 1048576 \
55
--dataset_repeat 1 \
66
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
7-
--fp8_models "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
87
--learning_rate 1e-4 \
98
--num_epochs 5 \
109
--remove_prefix_in_ckpt "pipe.dit." \
11-
--output_path "./models/train/Qwen-Image-LoRA-lowvram-cache" \
10+
--output_path "./models/train/Qwen-Image_lora-splited-cache" \
1211
--lora_base_model "dit" \
1312
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
1413
--lora_rank 32 \
14+
--task "sft:data_process" \
1515
--use_gradient_checkpointing \
16-
--use_gradient_checkpointing_offload \
1716
--dataset_num_workers 8 \
18-
--find_unused_parameters \
19-
--task "sft:data_process"
17+
--find_unused_parameters
2018

21-
accelerate launch examples/qwen_image/model_training/train.py \
22-
--dataset_base_path "./models/train/Qwen-Image-LoRA-lowvram-cache" \
19+
accelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \
20+
--dataset_base_path "./models/train/Qwen-Image_lora-splited-cache" \
2321
--max_pixels 1048576 \
2422
--dataset_repeat 50 \
2523
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
26-
--fp8_models "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
2724
--learning_rate 1e-4 \
2825
--num_epochs 5 \
2926
--remove_prefix_in_ckpt "pipe.dit." \
30-
--output_path "./models/train/Qwen-Image-LoRA-lowvram" \
27+
--output_path "./models/train/Qwen-Image_lora" \
3128
--lora_base_model "dit" \
3229
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
3330
--lora_rank 32 \
31+
--task "sft:train" \
3432
--use_gradient_checkpointing \
35-
--use_gradient_checkpointing_offload \
3633
--dataset_num_workers 8 \
3734
--find_unused_parameters \
38-
--task "sft:train"
35+
--initialize_model_on_cpu
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
compute_environment: LOCAL_MACHINE
2+
debug: true
3+
deepspeed_config:
4+
deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json
5+
zero3_init_flag: true
6+
distributed_type: DEEPSPEED
7+
downcast_bf16: 'no'
8+
enable_cpu_affinity: false
9+
machine_rank: 0
10+
main_training_function: main
11+
num_machines: 1
12+
num_processes: 1
13+
rdzv_backend: static
14+
same_network: true
15+
tpu_env: []
16+
tpu_use_cluster: false
17+
tpu_use_sudo: false
18+
use_cpu: false

0 commit comments

Comments
 (0)