Files

14 lines
580 B
Plaintext
Raw Permalink Normal View History

2025-11-24 11:20:33 +08:00
#定又训练亟
def train():
model.train()
for i, data in enumerate(train loader):
#获得一个batch的数据和对应的标签
inputs,labels = data
inputs,labels =inputs.to(device),labels.to(device)#将数据移到GPU或CPU上
#数据正向传播,(648)需要补全以下代码!!!
#计算代价(误差)out(batch,c),labels(batch)
需要补全以下代码!!
# 梯度凊θ
optimizer.zero grad()#误差反向传播---需要补全以下代码!!
#更新网络模型爹数
optimizer.step()