Caffe2 (1) Net

1. Packages:

from caffe2.python import (
    workspace,
    model_helper,
    net_drawer,
)
import numpy as np

2. Data and Label:

將假 data 和 label 做成 blob 放到 workspace 中。data 的維度是 [16, 100],即 16 筆大小為 100 的資料。label 是 16 筆 0~9 之間的數。

# Create the input data, 16 is the size of the batch
data = np.random.rand(16, 100).astype(np.float32)

# Create labels for the data as integers [0, 9].
label = (np.random.rand(16) * 10).astype(np.int32)

workspace.FeedBlob("data", data)
workspace.FeedBlob("label", label)

3. 建立 model:

ModelHelper 可以幫我們建立 model,model 有兩個重要的成員 param_init_net 和 net。param_init_net 用於初始化參數,net 定義運算過程。

# Create model using a model helper
m = model_helper.ModelHelper(name="my first net")
weight = m.param_init_net.XavierFill([], 'fc_w', shape=[10, 100])
bias = m.param_init_net.ConstantFill([], 'fc_b', shape=[10, ])

fc_1 = m.net.FC(["data", "fc_w", "fc_b"], "fc1")
pred = m.net.Sigmoid(fc_1, "pred")
softmax, loss = m.net.SoftmaxWithLoss([pred, "label"], ["softmax", "loss"])

一系列的動作目的是產生 protobuf 文件。caffe 的風格就是編寫 protobuf 解析再執行,caffe2 也類似 caffe。

print(m.net.Proto())
print(m.param_init_net.Proto())

將 protobuf 可視化:

graph = net_drawer.GetPydotGraph(m.net, rankdir="BT")
graph.write_png("hello.png")

4. 加入梯度:

把梯度和修正的計算也加入模型中。

前面提過 param_init_net 用於初始化參數,net 定義運算過程,我們先 init 再 create network。

m.AddGradientOperators([loss])  # add gradient
# print(m.net.Proto())          # observe gradient

workspace.RunNetOnce(m.param_init_net)
workspace.CreateNet(m.net)

for ii in range(100):
    data = np.random.rand(16, 100).astype(np.float32)
    label = (np.random.rand(16) * 10).astype(np.int32)

    workspace.FeedBlob("data", data)
    workspace.FeedBlob("label", label)

    workspace.RunNet(m.name, 10)   # run for 10 times
    # print("Run: ", ii)

# save the model with grad
graph = net_drawer.GetPydotGraph(m.net, rankdir="BT")
graph.write_png("hello_with_grad.png")

由圖可知我們把梯度傳播也加入計算的傳遞中,計算完梯度後會再反向傳遞修正參數。註: 雖然名為反向傳遞,但在圖的上半部箭頭方向並未改變,可以理解為計算進行的方向。


Reference:

[1] https://caffe2.ai/docs/intro-tutorial.html

留言

熱門文章