MXNet (4) Gluon - nn.HybridBlock

上次我們介紹了 Gluon 的 nn.Block,這次我們要來介紹它的兄弟 nn.HybridBlock。

nn.HybridBlock 主要是希望把 nn.Block 定義的運算流程進行優化進而提升計算速度。


1. 使用方法:

類似 nn.Block,nn.HybridBlock 使用時需要遵循下列步驟:

  1. 使類別繼承 nn.HybridBlock
  2. 實做 hybrid_forward()
  3. 在實例初始化後呼叫 hybridize() 優化計算流程
from mxnet import nd
from mxnet.gluon import nn


class SimpleNet(nn.HybridBlock): # 1 使類別繼承 nn.HybridBlock

    def __init__(self, **kwargs):
        super(SimpleNet, self).__init__(**kwargs)
        self.body = nn.HybridSequential()
        self.body.add(
            nn.Conv2D(20, 3, activation="relu"),
            nn.Conv2D(36, 3, activation="relu"),
        )
        self.output = nn.Dense(10, activation="relu")

    # 2 實做 hybrid_forward()
    def hybrid_forward(self, F, x, *args, **kwargs):
        x = self.body(x)
        return self.output(x)

net = SimpleNet()
net.initialize()
net.hybridize() # 3 呼叫 hybridize()

x = nd.ones((10, 3, 28, 28))
out = net(x)
print(out.shape)

特別提一下使用 HybridBlock 需要注意的事項:

  1. nn.Sequential 是 nn.Block 的容器,nn.HybridBlock 要使用 nn.HybridSequential
  2. hybridize 後不接受 slice, broadcast 等運算,但可以透過 hybrid_forward 的 F 達成同樣效果。
  3. 不呼叫 hybridize() 程式會以 imperative 的模式執行。

2. 小小結論:

最後整理一下重點部份:

  1. Block, Sequential, forward 對應 HybridBlock, HybridSequential, hybrid_forward。
  2. hybrid_forward 中很多運算要靠 F 達成。

留言

熱門文章