介绍
这篇博客主要是关于pytorch分布式ddp(DistributedDataParallel)的介绍和大概的食用(这不是错别字)教程。
数据并行DistributedDataParallel指的是在数据集层面进行多进程的切分,对于模型参数和训练状态等其他部分切分。
首先会介绍一下通信。在pytorch分布式ddp中,各个进程的代码是单独运行的。彼此之间在没有显式通信的时候,是不知道对方的的信息的。因此分布式的重点,所以了解通信的情况,也就了解了分布式的原理和使用的方法。
主要通信的方式有:
- 环境变量
- 同步
- tensor操作
然后会介绍一下在数据并行下数据如何进行切分。
最后介绍整体pytorch分布式大概的流程和使用方法。
通信
环境变量
常用的环境变量有:
- WORLD_SIZE 全局进程数
- RANK 当前进程全局标识符
- LOCAL_RANK 在单个节点中的进程标识符
- MASTER_ADDR 主节点IP地址
- MASTER_PORT 主节点端口
常用的为前三个,还有一些使用更加少的暂时没有罗列。
环境变量的获取可以:
import os
os.environ["path"]
os.environ.get('KEY_THAT_MIGHT_EXIST')
os.getenv('KEY_THAT_MIGHT_EXIST', default_value) # 推荐
同步
引入pytorch分布式包
import torch.distributed as dist
同步所有进程进度
dist.barrier()
在tensor通信的时候,也会起到同步进程的作用。很容易理解,不同步的话tensor的值都没有求得。
tensor通信
广播broadcast,收集gather,分发scatter,全收集all-gather,规约reduce,全规约all-reduce,全对称all-to-all,批量广播broadcast_object_list
数据
在ddp中,只考虑的数据的剪切。那么对于某个进程,只需要计算部分数据即可。某个进程根据LOCAL_RANK获取自己所需的数据的方法有两种:
- 数据集定义中加入offset,根据offset获取自己只需要的数据,那么进程只能看到自己的数据,比如
for i, segment in enumerate(open(file)): if i % n_gpus != offset: continue
- 通过设置DataLoader中的sampler控制数据集采样实现数据切分,比如:
from torch.utils.data import DataLoader, DistributedSampler sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
使用方法
在通信前,需要进行初始化操作(如果init_process_group不指定部分参数,也会自动从环境变量中获取):
# 初始化分布式进程
def setup():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)
设置使用的GPU(可以灵活设置,比如每若干个进程共享GPU):
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank) # 设置默认GPU
device = torch.device(f"cuda:{rank}") # 显式指定设备
使用的GPU还会收到环境变量CUDA_VISIBLE_DEVICES的限制
设置默认GPU可以让部分CUDA操作默认在该设备执行
然后包装模型,隐式的进行
tensor
的同步和通信(在模型之外计算某些量(如精度、损失值等),可能需要同步):
from torch.nn.parallel import DistributedDataParallel as DDP
model = SimpleCNN().to(device)
model = DDP(model, device_ids=[rank])
当然,这里也可以切换成别的过程,比如如果不是模型的训练和推理,也可以进行tensor别的计算方法,但是需要手动的进行通信等。
对于一些多个进程只需要完成一次的操作,比如保存模型或者日志记录等,只需要一个进程一般是主进程完成即可:
if dist.get_rank() == 0:
torch.save(model.state_dict(), "model_checkpoint.pth")
代码执行完需要进程组的销毁:
def cleanup():
dist.destroy_process_group()
代码执行
如果执行代码直接使用python,那么需要使用pytorch的包启动多进程:
import torch.multiprocessing as mp
mp.spawn(train, nprocs=world_size, join=True)
如果直接使用 torchrun
命令执行代码,则不需要使用
torch.multiprocessing
,但需要在命令里添加部分参数,等于调用
torch.multiprocessing
的任务交给
torchrun
完成:
torchrun --nproc_per_node=4 your_script.py