pytorch分布式-ddp

文章目录
  1. 1. 介绍
  2. 2. 通信
    1. 2.1. 环境变量
    2. 2.2. 同步
    3. 2.3. tensor通信
  3. 3. 数据
  4. 4. 使用方法
  5. 5. 代码执行
  6. 6. 参考资料

介绍

这篇博客主要是关于pytorch分布式ddp(DistributedDataParallel)的介绍和大概的食用(这不是错别字)教程。

数据并行DistributedDataParallel指的是在数据集层面进行多进程的切分,对于模型参数和训练状态等其他部分切分。

首先会介绍一下通信。在pytorch分布式ddp中,各个进程的代码是单独运行的。彼此之间在没有显式通信的时候,是不知道对方的的信息的。因此分布式的重点,所以了解通信的情况,也就了解了分布式的原理和使用的方法。

主要通信的方式有:

  • 环境变量
  • 同步
  • tensor操作

然后会介绍一下在数据并行下数据如何进行切分。

最后介绍整体pytorch分布式大概的流程和使用方法。

通信

环境变量

常用的环境变量有:

  • WORLD_SIZE 全局进程数
  • RANK 当前进程全局标识符
  • LOCAL_RANK 在单个节点中的进程标识符
  • MASTER_ADDR 主节点IP地址
  • MASTER_PORT 主节点端口

常用的为前三个,还有一些使用更加少的暂时没有罗列。

环境变量的获取可以:

1
2
3
4
import os
os.environ["path"]
os.environ.get('KEY_THAT_MIGHT_EXIST')
os.getenv('KEY_THAT_MIGHT_EXIST', default_value) # 推荐

同步

引入pytorch分布式包

1
import torch.distributed as dist

同步所有进程进度

1
dist.barrier()

在tensor通信的时候,也会起到同步进程的作用。很容易理解,不同步的话tensor的值都没有求得。

tensor通信

广播broadcast,收集gather,分发scatter,全收集all-gather,规约reduce,全规约all-reduce,全对称all-to-all,批量广播broadcast_object_list

数据

在ddp中,只考虑的数据的剪切。那么对于某个进程,只需要计算部分数据即可。某个进程根据LOCAL_RANK获取自己所需的数据的方法有两种:

  1. 数据集定义中加入offset,根据offset获取自己只需要的数据,那么进程只能看到自己的数据,比如
    1
    2
    3
    for i, segment in enumerate(open(file)):
    if i % n_gpus != offset:
    continue
  2. 通过设置DataLoader中的sampler控制数据集采样实现数据切分,比如:
    1
    2
    3
    from torch.utils.data import DataLoader, DistributedSampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

使用方法

在通信前,需要进行初始化操作(如果init_process_group不指定部分参数,也会自动从环境变量中获取):

1
2
3
4
5
# 初始化分布式进程
def setup():
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", rank=rank, world_size=world_size)

设置使用的GPU(可以灵活设置,比如每若干个进程共享GPU):

1
2
3
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank) # 设置默认GPU
device = torch.device(f"cuda:{rank}") # 显式指定设备

使用的GPU还会收到环境变量CUDA_VISIBLE_DEVICES的限制

设置默认GPU可以让部分CUDA操作默认在该设备执行

然后包装模型,隐式的进行 tensor的同步和通信(在模型之外计算某些量(如精度、损失值等),可能需要同步):

1
2
3
from torch.nn.parallel import DistributedDataParallel as DDP
model = SimpleCNN().to(device)
model = DDP(model, device_ids=[rank])

当然,这里也可以切换成别的过程,比如如果不是模型的训练和推理,也可以进行tensor别的计算方法,但是需要手动的进行通信等。

对于一些多个进程只需要完成一次的操作,比如保存模型或者日志记录等,只需要一个进程一般是主进程完成即可:

1
2
if dist.get_rank() == 0:
torch.save(model.state_dict(), "model_checkpoint.pth")

代码执行完需要进程组的销毁:

1
2
def cleanup():
dist.destroy_process_group()

代码执行

如果执行代码直接使用python,那么需要使用pytorch的包启动多进程:

1
2
import torch.multiprocessing as mp
mp.spawn(train, nprocs=world_size, join=True)

如果直接使用 torchrun命令执行代码,则不需要使用 torch.multiprocessing,但需要在命令里添加部分参数,等于调用 torch.multiprocessing的任务交给 torchrun完成:

1
torchrun --nproc_per_node=4 your_script.py

参考资料

由于评论系统依托于Github的Discuss存在,因此默认评论者会收到所有通知。可以在邮件里点击"unsubscribe"停止接受,后续也可以点击下列仓库进行通知管理: bg51717/Hexo-Blogs-comments
Since the comment system relies on GitHub's Discussions feature, by default, commentators will receive all notifications. You can click "unsubscribe" in the email to stop receiving them, and you can also manage your notifications by clicking on the following repositories: bg51717/Hexo-Blogs-comments