14 lines
580 B
Plaintext
14 lines
580 B
Plaintext
|
|
#定又训练亟
|
|||
|
|
def train():
|
|||
|
|
model.train()
|
|||
|
|
for i, data in enumerate(train loader):
|
|||
|
|
#获得一个batch的数据和对应的标签
|
|||
|
|
inputs,labels = data
|
|||
|
|
inputs,labels =inputs.to(device),labels.to(device)#将数据移到GPU或CPU上
|
|||
|
|
#数据正向传播,(64,8)需要补全以下代码!!!
|
|||
|
|
#计算代价(误差),out(batch,c),labels(batch)
|
|||
|
|
需要补全以下代码!!
|
|||
|
|
# 梯度凊θ
|
|||
|
|
optimizer.zero grad()#误差反向传播---需要补全以下代码!!
|
|||
|
|
#更新网络模型爹数
|
|||
|
|
optimizer.step()
|