MXNet (3) Gluon - nn.Block

MXNet 的 Gluon module 引入了類似 PyTorch 的接口,使得我們可以輕鬆地架構深度神經網路。

在 Gluon 中定義了 nn.Block 和 nn.HybridBlock ,今天先從 nn.Block 這個 class 談起。


1. 建構一個簡單的網路架構:

使用 Gluon 建立網路架構主要需要注意兩件事:

  1. 使類別繼承 nn.Block 或 nn.HybridBlock
  2. 繼承 nn.Block 要實做 forward(),繼承 nn.HybridBlock 要實做hybrid_forward。

我們定義一個簡單的捲積神經網路 SimpleNet 繼承了 nn.Block 並實做 forward():

from mxnet import nd
from mxnet.gluon import nn


class SimpleNet(nn.Block):

    def __init__(self, **kwargs):
        super(SimpleNet, self).__init__(**kwargs)
        self.body = nn.Sequential()
        self.body.add(
            nn.Conv2D(20, 3, activation="relu"),
            nn.Conv2D(36, 3, activation="relu"),
        )
        self.output = nn.Dense(10, activation="relu")

    def forward(self, x):
        x = self.body(x)
        return self.output(x)

這邊值得一提的地方是 MXNet 採用 lazy evaluation,所以我們不需要定義輸入的 channel size

以 nn.Conv2D(20, 3, activation="relu") 為例,20 是輸出的 channel size,3 是 kernel size。


2. 測試 forward 運算:

實例化/初始化 SimpleNet。

net = SimpleNet()
net.initialize()

在計算前 net 其實不知道任何 input channel size,每一層的 input channel size 是在 forward 後自動推導出來。

我們可以設個 x 來測試 forward(),MXNet 的 input 順序是 [N,C,H,W] (註: N, C, H, W 分別是 個數, channel, 高, 寬)。

x = nd.ones((10, 1, 28, 28))
out = net(x)
print(out.shape) # [N, C] => [10, 10]
# print(out)

留言

熱門文章