具體流程參考教程:MindSpore快速入門 MindSpore 接口文檔
注:本文章記錄的是我在開發(fā)過程中的學(xué)習(xí)筆記,僅供參考學(xué)習(xí),歡迎討論,但不作為開發(fā)教程使用。
def datapipe(dataset, batch_size):
'''
數(shù)據(jù)處理流水線
'''
image_transform = [
vision.Rescale(1.0 / 255.0, 0), # 縮放 output = image * rescale + shift.
vision.Normalize(mean=(0.1307,), std=(0.3081,)), # 根據(jù)平均值和標(biāo)準(zhǔn)偏差對輸入圖像進(jìn)行歸一化
vision.HWC2CHW() # 轉(zhuǎn)換為NCHW格式
]
label_transform = transforms.TypeCast(mindspore.int32) # 轉(zhuǎn)為mindspore的int32格式
dataset = dataset.map(image_transform, 'image') # 對各個(gè)圖像按照流水線處理
dataset = dataset.map(label_transform, 'label') # 對各個(gè)標(biāo)簽轉(zhuǎn)換為int32
dataset = dataset.batch(batch_size)
return dataset
這段代碼中對輸入圖片進(jìn)行了縮放、歸一化和格式轉(zhuǎn)換三個(gè)操作,按照流水線運(yùn)行。
流水線操作數(shù)據(jù)流水線處理的介紹:【AI設(shè)計(jì)模式】03-數(shù)據(jù)處理-流水線(Pipeline)模式
總結(jié)而言,海量數(shù)據(jù)下,流水線模式可以實(shí)現(xiàn)高效的數(shù)據(jù)處理,當(dāng)然也會占用更多的CPU和內(nèi)存資源。
mindspore下dataset的map操作:第一個(gè)參數(shù)是處理函數(shù)列表,第二個(gè)參數(shù)是需要處理的列。
map函數(shù)會將數(shù)據(jù)集中第二個(gè)參數(shù)的指定的列作為輸入,調(diào)用第一個(gè)參數(shù)的處理函數(shù)執(zhí)行處理,如果有多個(gè)處理函數(shù),上一個(gè)函數(shù)的輸出作為下一個(gè)函數(shù)的輸入。
NCHW
缺點(diǎn):必須等所有通道輸入準(zhǔn)備好才能得到最終輸出結(jié)果,需要占用較大的臨時(shí)空間。
優(yōu)點(diǎn):是 Nvidia cuDNN 默認(rèn)格式,使用 GPU 加速時(shí)用 NCHW 格式速度會更快。(這個(gè)是什么原因呢?沒找到資料_(:з」∠)_)
NHWC
缺點(diǎn):GPU 加速較NCHW更慢
優(yōu)點(diǎn):訪存局部性更好(每三個(gè)輸入像素即可得到一個(gè)輸出像素)
參考文章:【深度學(xué)習(xí)框架輸入格式】NCHW還是NHWC?
為什么pytorch中transforms.ToTorch要把(H,W,C)的矩陣轉(zhuǎn)為(C,H,W)?
class Network(nn.Cell):
'''
Network model
'''
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512),
nn.ReLU(),
nn.Dense(512, 512),
nn.ReLU(),
nn.Dense(512, 10)
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
基類mindspore的模型基類是mindspore.nn.Cell
pytorch的模型基類是torch.nn.Module
mindspore的全連接層是mindspore.nn.Dense
pytorch的全連接層是torch.nn.Linear
mindspore的順序容器是mindspore.nn.SequentialCell
pytorch的順序容器是torch.nn.Sequential
mindspore的前向傳播函數(shù)(要執(zhí)行的計(jì)算邏輯)基類函數(shù)為construct(self, xxx)
pytorch的前向傳播函數(shù)基類函數(shù)為forward(self, xxx)
my_loss_fn = nn.CrossEntropyLoss()
my_optimizer = nn.SGD(model.trainable_params(), 1e-2)
交叉熵:把來自一個(gè)分布q的消息使用另一個(gè)分布p的最佳代碼傳達(dá)方式計(jì)算得到的平均消息長度,即為交叉熵。針對交叉熵,這個(gè)文章講的較好:損失函數(shù):交叉熵詳解
mindspore的交叉熵函數(shù)和pytorch類似:
前者是mindspore.nn.CrossEntropyLoss(),后者是torch.nn.CrossEntropyLoss()
def train(model_train, dataset, loss_fn, optimizer):
'''
訓(xùn)練函數(shù)
'''
# Define forward function
def forward_fn(data, label):
logits = model_train(data)
loss = loss_fn(logits, label)
return loss, logits
# Get gradient function
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# Define function of one-step training
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
loss = ops.depend(loss, optimizer(grads))
return loss
size = dataset.get_dataset_size()
model_train.set_train()
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
loss = train_step(data, label)
if batch % 100 == 0:
loss, current = loss.asnumpy(), batch
print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
value_and_grad官網(wǎng)對value_and_grad函數(shù)的介紹如下:mindspore.ops.value_and_grad
按照官網(wǎng)的描述,這個(gè)函數(shù)的作用是:生成求導(dǎo)函數(shù),用于計(jì)算給定函數(shù)的正向計(jì)算結(jié)果和梯度。
我們需要給這個(gè)函數(shù)傳入模型的正向傳輸函數(shù)和待求導(dǎo)的參數(shù)
其中模型的正向傳輸函數(shù)需要封裝一下,返回loss的計(jì)算, 用于后續(xù)優(yōu)化器的梯度計(jì)算;
待求導(dǎo)的參數(shù)可以寫為model.trainable_params(),也可以由優(yōu)化器提供(optimizer.parameters),因?yàn)閮?yōu)化器初始化時(shí)已經(jīng)傳入需要求導(dǎo)的參數(shù)。
總之,這個(gè)接口返回的是一個(gè)函數(shù),函數(shù)的作用是把正向傳播、反向傳播的整個(gè)流程走一遍,最后的輸出為正向傳輸函數(shù)的返回值+待求導(dǎo)參數(shù)的梯度值
在訓(xùn)練時(shí)使用到了depend算子,官網(wǎng)對Depend函數(shù)的介紹如下:mindspore.ops.Depend
# Define function of one-step training
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
loss = ops.depend(loss, optimizer(grads))
return loss
經(jīng)詢問分析,使用depend算子的原因是,在靜態(tài)圖模式下,函數(shù)執(zhí)行的先后順序可能會被優(yōu)化,這就可能存在loss在grad_fn(value_and_grad)之前就被返回使用的情況,導(dǎo)致返回的loss不正確。
因此通過使用depend算子,來保證loss的返回動作在optimizer之后執(zhí)行,而optimizer的輸入依賴grad_fn,因此optimizer一定在grad_fn之后執(zhí)行,這就保證了depend返回的loss確實(shí)是經(jīng)過grad_fn計(jì)算的最新結(jié)果。
當(dāng)然,mindspore也是支持動態(tài)圖模式的,只需加一行代碼:
# 設(shè)置為動態(tài)圖模式
mindspore.set_context(mode=mindspore.PYNATIVE_MODE)
# 設(shè)置為靜態(tài)圖模式
# mindspore.set_context(mode=mindspore.GRAPH_MODE)
model = Network()
print(model)
然后訓(xùn)練函數(shù)就可以這么寫:
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
optimizer(grads)
return loss
但是實(shí)測,動態(tài)圖模式下,訓(xùn)練速度相比靜態(tài)圖慢了很多。
關(guān)于mindspore動態(tài)圖和靜態(tài)圖模式的介紹,可看這個(gè)官方文檔:動靜態(tài)圖
各個(gè)文章在介紹梯度下降法時(shí),通常介紹的是批量梯度下降法,但是訓(xùn)練模型時(shí)用的最多的是小批量梯度下降法。這里先講下批量梯度下降、隨機(jī)梯度下降和小批量梯度下降的區(qū)別。
批量梯度下降批量梯度下降法的流程是:假設(shè)有1000個(gè)數(shù)據(jù),經(jīng)過正向計(jì)算,得到1000個(gè)計(jì)算結(jié)果,誤差函數(shù)的計(jì)算公式依賴這1000個(gè)計(jì)算結(jié)果;再對誤差函數(shù)進(jìn)行反向傳播求導(dǎo),得到模型里參數(shù)的梯度值;同樣地,對誤差函數(shù)求導(dǎo)得梯度,也依賴這1000個(gè)計(jì)算結(jié)果;最后基于學(xué)習(xí)率更新參數(shù),然后進(jìn)入下一輪訓(xùn)練。
因此,標(biāo)準(zhǔn)的批量梯度下降,需要每次計(jì)算出1000個(gè)數(shù)據(jù)的正向傳播結(jié)果,才可以得到參數(shù)梯度值,然后下一輪訓(xùn)練,重新計(jì)算1000個(gè)計(jì)算結(jié)果…這就存在大量的運(yùn)算量,使得訓(xùn)練容易變得非常耗時(shí)。
隨機(jī)梯度下降法的流程是,假設(shè)有1000個(gè)數(shù)據(jù),我們隨機(jī)取1個(gè)數(shù)據(jù),經(jīng)過正向計(jì)算,得到1個(gè)計(jì)算結(jié)果,誤差函數(shù)的計(jì)算公式就只依賴這1個(gè)計(jì)算結(jié)果;然后反向傳播求導(dǎo),得到基于1個(gè)計(jì)算結(jié)果的梯度值,最后基于學(xué)習(xí)率更新參數(shù),然后進(jìn)入下一輪訓(xùn)練。下一輪訓(xùn)練時(shí),隨機(jī)取另1個(gè)數(shù)據(jù),重復(fù)上述操作…
這種方法下,極大地降低了計(jì)算量,而且理論上,只要數(shù)據(jù)量夠大,數(shù)據(jù)足夠隨機(jī),最后也總會下降到所需極值點(diǎn),畢竟計(jì)算數(shù)據(jù)量小了很多,算得更快了,下降速度也會快很多。但是每次只依賴1個(gè)數(shù)據(jù),就使得梯度的下降方向在整體方向上不穩(wěn)定,容易到處飄,最后的結(jié)果可能不會是全局最優(yōu)。
小批量梯度下降法的流程是:假設(shè)有1000個(gè)數(shù)據(jù),我們隨機(jī)取100個(gè)數(shù)據(jù),經(jīng)過正向計(jì)算,得到100個(gè)計(jì)算結(jié)果,誤差函數(shù)的計(jì)算公式依賴這100個(gè)計(jì)算結(jié)果;然后反向傳播求導(dǎo),得到基于100個(gè)計(jì)算結(jié)果的梯度值,最后基于學(xué)習(xí)率更新參數(shù),然后進(jìn)入下一輪訓(xùn)練。下一輪訓(xùn)練時(shí),隨機(jī)取另100個(gè)數(shù)據(jù),重復(fù)上述操作…
可以看出,小批量梯度下降 結(jié)合了 批量梯度下降 和 隨機(jī)梯度下降 的優(yōu)缺點(diǎn),使得計(jì)算即不那么耗時(shí),又保證參數(shù)更新路徑和結(jié)果相對穩(wěn)定。
mindspore的這個(gè)例子用的是小批量梯度下降,train_step每次輸入64個(gè)數(shù)據(jù),然后前向傳播、計(jì)算梯度、更新參數(shù),再進(jìn)入下一個(gè)epoch,隨機(jī)取新的64個(gè)數(shù)據(jù),重復(fù)訓(xùn)練…
size = dataset.get_dataset_size()
model_train.set_train()
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
loss = train_step(data, label)
if batch % 100 == 0:
loss, current = loss.asnumpy(), batch
print(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
在將數(shù)據(jù)集進(jìn)行datapipe后,返回的train_dataset和test_dataset都是以batch_size=64個(gè)為一組進(jìn)行輸出的,此處dataset.get_dataset_size()返回的size是有多少組數(shù)據(jù)。測試集返回的size為938,表示一共有938組,每組64個(gè)圖片數(shù)據(jù)。實(shí)際上MNIST只有60000個(gè)測試集圖片,因此最后一組只有32個(gè)圖片。
運(yùn)行結(jié)果Epoch 1
-------------------------------
loss: 2.303684 [ 0/938]
loss: 2.291476 [100/938]
loss: 2.273411 [200/938]
loss: 2.212310 [300/938]
loss: 1.969760 [400/938]
loss: 1.600426 [500/938]
loss: 1.004380 [600/938]
loss: 0.735266 [700/938]
loss: 0.672223 [800/938]
loss: 0.578563 [900/938]
Test:
Accuracy: 85.3%, Avg loss: 0.528851
Epoch 2
-------------------------------
loss: 0.384008 [ 0/938]
loss: 0.453575 [100/938]
loss: 0.277697 [200/938]
loss: 0.317674 [300/938]
loss: 0.294471 [400/938]
loss: 0.519272 [500/938]
loss: 0.253794 [600/938]
loss: 0.389252 [700/938]
loss: 0.383196 [800/938]
loss: 0.334877 [900/938]
Test:
Accuracy: 90.2%, Avg loss: 0.334850
此處跑了兩輪訓(xùn)練,可以看出,第一輪的938組數(shù)據(jù)的訓(xùn)練過程中,參數(shù)快速調(diào)整至合理范圍(loss從2.3降低到0.5),但是第二輪的938組數(shù)據(jù)的訓(xùn)練過程中,loss出現(xiàn)了上下波動(0.3->0.4->0.2->0.3…),即模型參數(shù)向當(dāng)前數(shù)據(jù)組的梯度下降的方向走了一小步后,新的數(shù)據(jù)組算出的loss反而比之前還提高了。
這主要是因?yàn)楫?dāng)前數(shù)據(jù)組的梯度下降方向 無法代表 替他數(shù)據(jù)組/所有數(shù)據(jù)的梯度下降方向,當(dāng)然也可能是學(xué)習(xí)率(步長)太大導(dǎo)致跨過了最低點(diǎn),這個(gè)就具體問題具體分析了。
mindspore和pytorch在接口命名上存在區(qū)別,但是實(shí)際使用過程中,開發(fā)思路還是一致的。因此最關(guān)鍵的還是要熟悉深度學(xué)習(xí)的思路和流程,至于思路和代碼實(shí)現(xiàn)的映射,這就唯手熟爾。
你是否還在尋找穩(wěn)定的海外服務(wù)器提供商?創(chuàng)新互聯(lián)www.cdcxhl.cn海外機(jī)房具備T級流量清洗系統(tǒng)配攻擊溯源,準(zhǔn)確流量調(diào)度確保服務(wù)器高可用性,企業(yè)級服務(wù)器適合批量采購,新人活動首月15元起,快前往官網(wǎng)查看詳情吧
當(dāng)前題目:【mindspore】mindspore實(shí)現(xiàn)手寫數(shù)字識別-創(chuàng)新互聯(lián)
網(wǎng)頁路徑:http://bm7419.com/article40/gdsho.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供自適應(yīng)網(wǎng)站、商城網(wǎng)站、ChatGPT、微信小程序、品牌網(wǎng)站設(shè)計(jì)、面包屑導(dǎo)航
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容