def step(self, closure=None):
    """Perform a single optimization step.

    For every parameter that has a gradient, the update direction starts as
    the raw gradient. It is optionally replaced by a per-parameter momentum
    buffer and optionally augmented with a weight-decay term ``wd * p``. The
    parameter is then moved by ``-lr`` times that direction.

    Args:
        closure: Optional callable that re-evaluates the model and returns
            the loss.

    Returns:
        Whatever ``closure()`` returned, or ``None`` if no closure was given.
    """
    loss = None
    if closure is not None:
        loss = closure()

    # Parameter updates must not be recorded by autograd. In-place ops on
    # leaf tensors that require grad are rejected outside no_grad.
    with torch.no_grad():
        for group in self.param_groups:
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if momentum > 0:
                    # assumes self.state auto-creates per-param dicts
                    # (e.g. defaultdict(dict)) — TODO confirm
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                        # NOTE(review): the first step seeds the buffer with
                        # momentum * grad; kept as in the original.
                        buf.mul_(momentum)
                    else:
                        buf = param_state['momentum_buffer']
                        # Exponential moving average of the gradients.
                        buf.mul_(momentum).add_(d_p, alpha=1 - momentum)
                    d_p = buf
                if weight_decay != 0:
                    # Out-of-place: d_p may alias p.grad or the momentum
                    # buffer, and an in-place add would corrupt them.
                    d_p = d_p.add(p, alpha=weight_decay)
                p.add_(d_p, alpha=-group['lr'])
    return loss
def get_lr(self):
    """Return the step-decayed learning rate for each parameter group.

    Every entry of ``self.base_lrs`` is scaled by ``self.gamma`` raised to the
    number of whole ``self.step_size`` intervals contained in
    ``self.last_step``.
    """
    rates = []
    for initial_lr in self.base_lrs:
        completed_intervals = self.last_step // self.step_size
        rates.append(initial_lr * self.gamma ** completed_intervals)
    return rates
def step(self, epoch=None):
    """Advance the schedule and write the new learning rates to the optimizer.

    Without ``epoch`` the step counter is incremented by one. With ``epoch``
    it is set to ``epoch + 1``. The rates from ``self.get_lr()`` are cached in
    ``self._last_lr`` and assigned to the optimizer's param groups in order.
    """
    self.last_step = self.last_step + 1 if epoch is None else epoch + 1

    new_rates = self.get_lr()
    self._last_lr = new_rates
    for group, rate in zip(self.optimizer.param_groups, new_rates):
        group['lr'] = rate
def train_model(model, criterion, optimizer, scheduler, num_epochs, *,
                loaders=None, sizes=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a ``'train'`` phase (forward, backward and optimizer step
    per batch) followed by a ``'val'`` phase (forward only). The weights with
    the highest validation accuracy seen so far are snapshotted and restored
    at the end.

    Args:
        model: Network to train; its ``state_dict`` is snapshotted.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        loaders: Mapping ``{'train': ..., 'val': ...}`` of iterables yielding
            ``(inputs, labels)`` batches. Defaults to the module-level
            ``dataloaders``.
        sizes: Mapping ``{'train': int, 'val': int}`` of sample counts used to
            average loss and accuracy. Defaults to the module-level
            ``dataset_sizes``.

    Returns:
        ``model``, with the best validation weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    if sizes is None:
        sizes = dataset_sizes

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # One iteration per epoch.
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training phase and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # Per batch: compute the loss and, when training, backpropagate.
            for inputs, labels in loaders[phase]:
                optimizer.zero_grad()
                # Build the autograd graph only during training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Assumes criterion averages over the batch (reduction='mean');
                # re-weight by batch size to get a per-sample sum — TODO confirm.
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the counter a plain Python int, not a tensor.
                running_corrects += torch.sum(preds == labels).item()

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]
            print('{} Loss :{:.4f} Acc:{:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        # Exactly once per epoch, after the training phase's optimizer steps.
        scheduler.step()
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    model.load_state_dict(best_model_wts)
    return model
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
# 数据集 — dataset
# 数据迭代器 — dataloader
# 模型 — model
# 损失函数 — criterion (loss function)
# 优化器 — optimizer
# 学习率优化器 — lr_scheduler (learning-rate scheduler)
# 训练 — train
# 保存 — save
# 预测 — predict
Because the comment system is built on GitHub Discussions, commenters receive all notifications by default. To stop receiving them, click "unsubscribe" in the notification email; you can also manage your notifications from the following repository:
Because the comment system is built on GitHub Discussions, commenters receive all notifications by default. To stop receiving them, click "unsubscribe" in the notification email; you can also manage your notifications from the following repository: bg51717/Hexo-Blogs-comments