MXNet (3) Gluon - nn.Block
MXNet 的 Gluon module 引入了類似 PyTorch 的接口,使得我們可以輕鬆地架構深度神經網路。
在 Gluon 中定義了 nn.Block 和 nn.HybridBlock ,今天先從 nn.Block 這個 class 談起。
1. 建構一個簡單的網路架構:
使用 Gluon 建立網路架構主要需要注意兩件事:
- 使類別繼承 nn.Block 或 nn.HybridBlock
- 繼承 nn.Block 要實做 forward(),繼承 nn.HybridBlock 要實做hybrid_forward。
我們定義一個簡單的捲積神經網路 SimpleNet 繼承了 nn.Block 並實做 forward():
from mxnet import nd from mxnet.gluon import nn class SimpleNet(nn.Block): def __init__(self, **kwargs): super(SimpleNet, self).__init__(**kwargs) self.body = nn.Sequential() self.body.add( nn.Conv2D(20, 3, activation="relu"), nn.Conv2D(36, 3, activation="relu"), ) self.output = nn.Dense(10, activation="relu") def forward(self, x): x = self.body(x) return self.output(x)
這邊值得一提的地方是 MXNet 採用 lazy evaluation,所以我們不需要定義輸入的 channel size。
以 nn.Conv2D(20, 3, activation="relu") 為例,20 是輸出的 channel size,3 是 kernel size。
2. 測試 forward 運算:
實例化/初始化 SimpleNet。
net = SimpleNet() net.initialize()
在計算前 net 其實不知道任何 input channel size,每一層的 input channel size 是在 forward 後自動推導出來。
我們可以設個 x 來測試 forward(),MXNet 的 input 順序是 [N,C,H,W] (註: N, C, H, W 分別是 個數, channel, 高, 寬)。
x = nd.ones((10, 1, 28, 28)) out = net(x) print(out.shape) # [N, C] => [10, 10] # print(out)
留言
張貼留言