```python
def step(self, closure=None):
    """Performs a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            d_p = p.grad.data
            # Apply L2 weight decay to the raw gradient, before momentum,
            # so the momentum buffer is not mutated by the decay term.
            if group['weight_decay'] != 0:
                d_p = d_p.add(p.data, alpha=group['weight_decay'])
            if group['momentum'] > 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # First step: seed the buffer with the current gradient.
                    buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                else:
                    # Exponential moving average of gradients.
                    buf = param_state['momentum_buffer']
                    buf.mul_(group['momentum']).add_(d_p, alpha=1 - group['momentum'])
                d_p = buf
            # Gradient-descent update.
            p.data.add_(d_p, alpha=-group['lr'])
    return loss
```
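To see this `step` in context, here is a minimal, self-contained sketch; the wrapper class name `MySGD`, its constructor defaults, and the toy usage at the bottom are illustrative assumptions, not part of the original code:

```python
import torch

class MySGD(torch.optim.Optimizer):
    # Assumed wrapper: step() relies on every param group carrying
    # 'lr', 'momentum' and 'weight_decay'; the defaults are illustrative.
    def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0.0):
        super().__init__(params, dict(lr=lr, momentum=momentum,
                                      weight_decay=weight_decay))

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if group['weight_decay'] != 0:
                    d_p = d_p.add(p.data, alpha=group['weight_decay'])
                if group['momentum'] > 0:
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        state['momentum_buffer'].mul_(group['momentum']).add_(
                            d_p, alpha=1 - group['momentum'])
                    d_p = state['momentum_buffer']
                p.data.add_(d_p, alpha=-group['lr'])
        return loss

# Toy usage: minimizing ||w||^2, one step of plain gradient descent.
w = torch.tensor([1.0, -2.0], requires_grad=True)
opt = MySGD([w], lr=0.1, momentum=0.0)
(w ** 2).sum().backward()
opt.step()
print(w)  # each entry moves toward 0: tensor([ 0.8000, -1.6000], ...)
```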
```python
def get_lr(self):
    """Calculate the learning rate at the current step."""
    return [base_lr * self.gamma ** (self.last_step // self.step_size)
            for base_lr in self.base_lrs]

def step(self, epoch=None):
    """Update the learning rate at the end of the given epoch."""
    if epoch is None:
        self.last_step += 1
    else:
        self.last_step = epoch + 1
    self._last_lr = self.get_lr()
    for param_group, lr in zip(self.optimizer.param_groups, self._last_lr):
        param_group['lr'] = lr
```
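Together, `get_lr` and `step` implement step decay: the learning rate is multiplied by `gamma` once every `step_size` scheduler steps. A minimal sketch of the surrounding class and its use, where the class name `StepDecayLR` and the constructor are assumptions inferred from the two methods above:

```python
import torch

class StepDecayLR:
    # Assumed constructor: records one base lr per param group and
    # initializes the step counter the methods above rely on.
    def __init__(self, optimizer, step_size, gamma=0.1, last_step=-1):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma
        self.last_step = last_step
        self.base_lrs = [g['lr'] for g in optimizer.param_groups]

    def get_lr(self):
        return [base_lr * self.gamma ** (self.last_step // self.step_size)
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            self.last_step += 1
        else:
            self.last_step = epoch + 1
        self._last_lr = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, self._last_lr):
            param_group['lr'] = lr

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = StepDecayLR(opt, step_size=3, gamma=0.5)
for epoch in range(7):
    sched.step()
    print(epoch, opt.param_groups[0]['lr'])
# lr halves every 3 epochs: 0.1, 0.1, 0.1, 0.05, 0.05, 0.05, 0.025
```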
```python
import copy
import time

import torch

# dataloaders, dataset_sizes and device are assumed to be module-level globals.

def train_model(model, criterion, optimizer, scheduler, num_epochs):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Loop over epochs.
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training phase and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            # Compute the loss and backpropagate for each batch.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Keep the weights with the best validation accuracy.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        # Advance the learning-rate schedule once per epoch.
        scheduler.step()
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # Reload the best weights before returning the model.
    model.load_state_dict(best_model_wts)
    return model
```
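A hypothetical driver for `train_model`, shown to make the implicit globals (`dataloaders`, `dataset_sizes`, `device`) explicit; the dataset paths, transforms, and hyperparameters are placeholders, and it assumes `train_model` is defined in the same script so those globals are visible to it (the `weights=` argument requires a recent torchvision; older versions use `pretrained=True`):

```python
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Illustrative data pipeline; 'data/train' and 'data/val' are placeholder paths.
tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
image_datasets = {x: datasets.ImageFolder(f'data/{x}', tfm) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                              shuffle=(x == 'train'))
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Fine-tune a pretrained backbone with a fresh classification head.
model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, len(image_datasets['train'].classes))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

model = train_model(model, criterion, optimizer, scheduler, num_epochs=25)
```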
main.py
This file is responsible for importing the components defined in the other files and tying them together. The basic flow is:
```python
# dataset        build the dataset
# dataloader     build the data iterator
# model          build the model
# criterion      define the loss function
# optimizer      define the optimizer
# lr_scheduler   define the learning-rate scheduler
# train          train the model
# save           save the weights
# predict        run prediction
```
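A sketch of what this main.py glue could look like; the module names (`data.py`, `model.py`, `train.py`), the builder functions, and the file name `best_model.pt` are assumptions about the repo layout rather than code from the post, and the wiring of `dataloaders`/`device` into `train_model` is glossed over:

```python
# main.py — hypothetical glue code following the outline above.
import torch

from data import build_datasets, build_dataloaders   # assumed module
from model import build_model                         # assumed module
from train import train_model                         # assumed module

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    image_datasets = build_datasets()                 # dataset
    dataloaders = build_dataloaders(image_datasets)   # dataloader
    model = build_model().to(device)                  # model

    criterion = torch.nn.CrossEntropyLoss()           # criterion
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    model = train_model(model, criterion, optimizer, scheduler, num_epochs=25)  # train
    torch.save(model.state_dict(), 'best_model.pt')   # save
    # predict: reload the weights and run model(...) in eval mode as needed.

if __name__ == '__main__':
    main()
```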
Since the comment system is built on GitHub Discussions, commenters receive every notification by default. You can click "unsubscribe" in the notification email to stop receiving them, or manage notifications for the repository below:

bg51717/Hexo-Blogs-comments