14 lines
580 B
Plaintext
14 lines
580 B
Plaintext
#定又训练亟
|
||
def train():
|
||
model.train()
|
||
for i, data in enumerate(train loader):
|
||
#获得一个batch的数据和对应的标签
|
||
inputs,labels = data
|
||
inputs,labels =inputs.to(device),labels.to(device)#将数据移到GPU或CPU上
|
||
#数据正向传播,(64,8)需要补全以下代码!!!
|
||
#计算代价(误差),out(batch,c),labels(batch)
|
||
需要补全以下代码!!
|
||
# 梯度凊θ
|
||
optimizer.zero grad()#误差反向传播---需要补全以下代码!!
|
||
#更新网络模型爹数
|
||
optimizer.step() |