PyTorch (6-1) Distributed Applications with PyTorch (p2p)

This post is based on the official tutorial: https://pytorch.org/tutorials/intermediate/dist_tuto.html

1. Packages:

This post mainly uses two packages, torch.distributed and torch.multiprocessing, to demonstrate an IPC (Inter-Process Communication) workflow.

import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
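
Before running the examples, it can help to confirm that the installed PyTorch build ships torch.distributed and see which backends it supports. The snippet below is a minimal sketch; the helper name available_backends is hypothetical and not part of the tutorial.

```python
import torch.distributed as dist


def available_backends():
    """Return the names of the torch.distributed backends usable in this build.

    Hypothetical helper for illustration; it only wraps PyTorch's own checks.
    """
    if not dist.is_available():
        # The distributed package was not compiled into this build.
        return []
    checks = {
        "gloo": dist.is_gloo_available,
        "nccl": dist.is_nccl_available,
        "mpi": dist.is_mpi_available,
    }
    return [name for name, check in checks.items() if check()]


if __name__ == "__main__":
    print("torch.distributed available:", dist.is_available())
    print("usable backends:", available_backends())
```

On a CPU-only machine, gloo is usually the backend to look for in that list.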

2. Blocking p2p communication:

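We start 4 processes (ranks 0 to 3) and have rank 0 send a tensor to rank 3.

It is easiest to read the whole program at once. The order of execution is main => init_processes => run

  1. main: starts 4 processes, each with init_processes as its target function, then waits for all of them to join.
  2. init_processes: sets the environment variables, initializes the process group, then calls the given function fn. world_size is the number of processes.
  3. run: tensor operations and IPC.

def run(rank, size):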
    """Blocking point-to-point communication."""
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send the tensor to process 3
        dist.send(tensor=tensor, dst=3)
    elif rank == 3:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor)


def init_processes(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 4
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
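
Only ranks 0 and 3 take part in the transfer, so ranks 1 and 2 should still hold zero at the end. The sketch below is one way to check that from the parent process by sending each rank's final value back through a queue. It makes several assumptions that are not in the tutorial: the gloo backend, the "fork" start method (Unix-like systems only), a free port chosen by the OS instead of 29500, and the hypothetical names find_free_port, worker, and collect_results.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def find_free_port():
    """Ask the OS for an unused TCP port (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank, size, port, queue):
    """Join the process group, do the blocking send/recv, report the result."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=size)

    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Blocks until the last rank has received the tensor.
        dist.send(tensor=tensor, dst=size - 1)
    elif rank == size - 1:
        # Blocks until the tensor from rank 0 has arrived.
        dist.recv(tensor=tensor, src=0)

    queue.put((rank, tensor.item()))
    # Keep every rank alive until all are done, then clean up.
    dist.barrier()
    dist.destroy_process_group()


def collect_results(size=4):
    """Run `size` workers and return {rank: final tensor value}."""
    ctx = mp.get_context("fork")  # assumption: Unix-like OS
    queue = ctx.Queue()
    port = find_free_port()
    procs = [
        ctx.Process(target=worker, args=(rank, size, port, queue))
        for rank in range(size)
    ]
    for p in procs:
        p.start()
    # Drain the queue before joining so no worker blocks on a full pipe.
    results = dict(queue.get(timeout=60) for _ in range(size))
    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    print(collect_results())
```

With four processes, ranks 0 and 3 should report 1.0 and ranks 1 and 2 should report 0.0.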

3. Non-blocking p2p communication:

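The non-blocking version is similar to the blocking one, but it uses isend and irecv and must call req.wait() to make sure the transfer has completed. Until wait() returns, the sender should not modify the tensor and the receiver should not read it.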

def run(rank, size):
    """Non-blocking point-to-point communication."""
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        # Send the tensor to process 3
        req = dist.isend(tensor=tensor, dst=3)
        print('Rank 0 started sending')
    elif rank == 3:
        # Receive tensor from process 0
        req = dist.irecv(tensor=tensor, src=0)
        print('Rank 3 started receiving')
    if req is not None:
        req.wait()
    print('Rank ', rank, ' has data ', tensor[0])
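
The point of isend and irecv is that a process can do other work while the transfer is in flight. The following is a minimal sketch of that pattern with two processes. It assumes the gloo backend, the "fork" start method (Unix-like systems only), and an OS-chosen free port; the names worker and run_overlap_demo are hypothetical.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, size, port, queue):
    """Start a non-blocking transfer, do unrelated work, then wait."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=size)

    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        req = dist.isend(tensor=tensor, dst=1)  # returns immediately
    elif rank == 1:
        req = dist.irecv(tensor=tensor, src=0)  # returns immediately

    # Work that never touches `tensor` may overlap with the transfer.
    _ = sum(i * i for i in range(1000))

    if req is not None:
        req.wait()  # only after this is `tensor` safe to read or modify

    queue.put((rank, tensor.item()))
    # Keep both ranks alive until each is done, then clean up.
    dist.barrier()
    dist.destroy_process_group()


def run_overlap_demo(size=2):
    """Run the workers and return {rank: final tensor value}."""
    ctx = mp.get_context("fork")  # assumption: Unix-like OS
    queue = ctx.Queue()
    # Ask the OS for an unused port instead of hard-coding one.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        port = s.getsockname()[1]
    procs = [
        ctx.Process(target=worker, args=(rank, size, port, queue))
        for rank in range(size)
    ]
    for p in procs:
        p.start()
    results = dict(queue.get(timeout=60) for _ in range(size))
    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    print(run_overlap_demo())
```

Both ranks should end up holding 1.0: rank 0 because it incremented its own tensor, rank 1 because it received that value.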
