first commit
This commit is contained in:
14
user/123
Normal file
14
user/123
Normal file
@@ -0,0 +1,14 @@
|
||||
#定又训练亟
|
||||
def train():
|
||||
model.train()
|
||||
for i, data in enumerate(train loader):
|
||||
#获得一个batch的数据和对应的标签
|
||||
inputs,labels = data
|
||||
inputs,labels =inputs.to(device),labels.to(device)#将数据移到GPU或CPU上
|
||||
#数据正向传播,(64,8)需要补全以下代码!!!
|
||||
#计算代价(误差),out(batch,c),labels(batch)
|
||||
需要补全以下代码!!
|
||||
# 梯度凊θ
|
||||
optimizer.zero grad()#误差反向传播---需要补全以下代码!!
|
||||
#更新网络模型爹数
|
||||
optimizer.step()
|
||||
Reference in New Issue
Block a user