PyTorch (6-2) Distributed Applications with PyTorch (Collective Communication)

本文來自官方教學 https://pytorch.org/tutorials/intermediate/dist_tuto.html

所有的溝通方式分成下面六種,而溝通時還要注意 backends。

1. Collective Communication:

用到的 packages:

import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

建立 3 個 processes,將 0,1 形成一個 group。這樣就可以透過 group 限制溝通的對象。

def run(rank, size):
   group = dist.new_group([0, 1])
   tensor = torch.ones(1)
   dist.all_reduce(tensor, op=dist.reduce_op.SUM, group=group)
   print('Rank ', torch.distributed.get_rank(), ' has data ', tensor[0])


def init_processes(rank, size, fn, backend="tcp"):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 3
    process = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        process.append(p)

    for p in process:
        p.join()

執行結果如下:

Rank  1  has data  tensor(2.)
Rank  0  has data  tensor(2.)
Rank  2  has data  tensor(1.)

一般來說執行分散式訓練時我們會把資料切成幾個 partitions,將 model 和 data 分散到各 gpu 上。 各 gpu 分開對各自的資料計算 gradient,再將 gradient 加總取平均並同步更新到所有 models。透過這種方式將原本要跑的 batch 分散出去來加速訓練的過程。 dist.reduce_op.SUM 可以用於加總各 processes 的 gradient。


留言

熱門文章