MXNet (4) Gluon - nn.HybridBlock
上次我們介紹了 Gluon 的 nn.Block,這次我們要來介紹它的兄弟 nn.HybridBlock。
nn.HybridBlock 主要是希望把 nn.Block 定義的運算流程進行優化進而提升計算速度。
1. 使用方法:
類似 nn.Block,nn.HybridBlock 使用時需要遵循下列步驟:
- 使類別繼承 nn.HybridBlock
- 實做 hybrid_forward()
- 在實例初始化後呼叫 hybridize() 優化計算流程
from mxnet import nd from mxnet.gluon import nn class SimpleNet(nn.HybridBlock): # 1 使類別繼承 nn.HybridBlock def __init__(self, **kwargs): super(SimpleNet, self).__init__(**kwargs) self.body = nn.HybridSequential() self.body.add( nn.Conv2D(20, 3, activation="relu"), nn.Conv2D(36, 3, activation="relu"), ) self.output = nn.Dense(10, activation="relu") # 2 實做 hybrid_forward() def hybrid_forward(self, F, x, *args, **kwargs): x = self.body(x) return self.output(x) net = SimpleNet() net.initialize() net.hybridize() # 3 呼叫 hybridize() x = nd.ones((10, 3, 28, 28)) out = net(x) print(out.shape)
特別提一下使用 HybridBlock 需要注意的事項:
- nn.Sequential 是 nn.Block 的容器,nn.HybridBlock 要使用 nn.HybridSequential
- hybridize 後不接受 slice, broadcast 等運算,但可以透過 hybrid_forward 的 F 達成同樣效果。
- 不呼叫 hybridize() 程式會以 imperative 的模式執行。
2. 小小結論:
最後整理一下重點部份:
- Block, Sequential, forward 對應 HybridBlock, HybridSequential, hybrid_forward。
- hybrid_forward 中很多運算要靠 F 達成。
留言
張貼留言