pytorch_model

Table of Contents
  1. Introduction
  2. dataset.py
  3. models.py
  4. criterion.py
  5. optimizer.py
  6. lr_scheduler.py
  7. train.py
  8. main.py
  9. Optional optimizations
  10. Other

Introduction

This template serves as a basic skeleton for deep learning projects, so it can be reused as a starting point for new work.

Many of these files can use argparse to pull in external configuration, which abstracts the workflow and improves code reuse; a minimal sketch of that idea follows.
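As an illustration (not part of the original template; the option names here are hypothetical), a config.py could expose an argparse-based parser that the other modules read from:

import argparse

def get_config():
    # hypothetical option names; adapt them to the project
    parser = argparse.ArgumentParser(description='training configuration')
    parser.add_argument('--model-name', type=str, default='resnet18')
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=25)
    return parser.parse_args()

Downstream code can then read, for example, config.lr or config.batch_size instead of hard-coded values.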

dataset.py

This file provides the dataset definitions. A custom dataset needs to implement three main methods:

import torch

class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        # todo: load or index the raw data here

    def __getitem__(self, idx):
        # todo: return one sample (and its label) for the given index
        pass

    def __len__(self):
        # todo: return the number of samples
        pass
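Once these three methods are filled in, the dataset is normally wrapped in a torch.utils.data.DataLoader for batching and shuffling; a minimal usage sketch, assuming MyDataset has been implemented:

from torch.utils.data import DataLoader

dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for inputs, labels in dataloader:
    pass  # one training step per batch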

models.py

This file provides the model definitions; a model can be fully custom or a pretrained one.

import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # todo: define the layers here

    def forward(self, input):
        # todo: define the forward pass and return the output
        pass

def get_model(config):
    # todo: build a custom model or load a pretrained one according to config
    model = MyModel()  # placeholder; replace with the real construction logic
    return model
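For the pretrained case, get_model can, for instance, load a torchvision model and swap its classification head. This is only a sketch assuming config carries model_name and num_classes fields (newer torchvision versions prefer the weights argument over pretrained):

import torch.nn as nn
from torchvision import models

def get_model(config):
    # assumption: config.model_name and config.num_classes exist
    if config.model_name == 'resnet18':
        model = models.resnet18(pretrained=True)
        # replace the final fully connected layer to match the task
        model.fc = nn.Linear(model.fc.in_features, config.num_classes)
    else:
        model = MyModel()
    return model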

criterion.py

This file provides the loss function definitions; a loss can be fully custom or one provided by the framework.

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        # initialize your parameters here, if any

    def forward(self, input, target):
        # loss computation logic
        # as an example, a simple mean squared error
        loss = torch.mean((input - target) ** 2)
        return loss

def get_criterion():
    # todo: return a loss function (custom or framework-provided)
    pass
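get_criterion usually just chooses between framework losses and custom ones. A possible sketch, where the config argument and its loss field are assumptions (the stub above takes no arguments):

def get_criterion(config):
    # assumption: config.loss selects the loss function
    if getattr(config, 'loss', None) == 'mse':
        return CustomLoss()
    return nn.CrossEntropyLoss()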

optimizer.py

This file provides the optimizer definitions; an optimizer can be fully custom or one provided by the framework.

import torch

class CustomOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, momentum=0.5, weight_decay=0, learning_rate_decay=0.9):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        learning_rate_decay=learning_rate_decay)
        super(CustomOptimizer, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if group['momentum'] > 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        # initialize the momentum buffer from the current gradient
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                        buf.mul_(group['momentum'])
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(group['momentum']).add_(d_p, alpha=1 - group['momentum'])
                    d_p = buf

                if group['weight_decay'] != 0:
                    d_p.add_(p.data, alpha=group['weight_decay'])

                # update the parameters with the per-group learning rate
                p.data.add_(d_p, alpha=-group['lr'])

        return loss

def get_optimizer():
    # todo: return an optimizer (custom or framework-provided)
    pass
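get_optimizer can likewise build the custom optimizer above or fall back to a framework one. A sketch assuming it receives the model and a config with lr and momentum fields (the stub above takes no arguments):

def get_optimizer(model, config):
    # assumption: config.lr and config.momentum exist
    return torch.optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
    # or: return CustomOptimizer(model.parameters(), lr=config.lr, momentum=config.momentum)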

lr_scheduler.py

This file provides the learning rate scheduler definitions; a scheduler can be fully custom or one provided by the framework.

import torch

class MyLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, step_size=10, gamma=0.1):
        self.step_size = step_size
        self.gamma = gamma
        self.last_step = -1  # bumped to 0 by the step() call inside the parent __init__
        super(MyLRScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Calculate the learning rate at a given step."""
        # decay every `step_size` steps by a factor of `gamma`
        return [base_lr * self.gamma ** (self.last_step // self.step_size)
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Update the learning rate at the end of the given epoch."""
        if epoch is None:
            self.last_step += 1
        else:
            self.last_step = epoch + 1
        self._last_lr = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, self._last_lr):
            param_group['lr'] = lr

def get_lr_scheduler():
    # todo: return a learning rate scheduler (custom or framework-provided)
    pass
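get_lr_scheduler can return either the custom scheduler or a built-in one such as StepLR; a sketch with arbitrary step_size and gamma values:

def get_lr_scheduler(optimizer):
    # a built-in alternative to MyLRScheduler
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)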

train.py

This file provides the training procedures and the overall training loop.

import copy
import time

import torch

def train_model(model, criterion, optimizer, scheduler, num_epochs):
    # note: dataloaders, dataset_sizes and device are assumed to be defined elsewhere
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # iterate over epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            # compute the loss and backpropagate for every batch
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # keep the best weights seen on the validation set
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        # step the learning rate scheduler once per epoch
        scheduler.step()
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # load the best weights before returning
    model.load_state_dict(best_model_wts)
    return model
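The loop above relies on dataloaders, dataset_sizes and device being defined at module level (or injected in some other way). One illustrative way to set them up, where train_dataset and val_dataset are assumed to exist:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloaders = {
    'train': torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True),
    'val': torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False),
}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}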

main.py

This file ties the individual modules together; the basic flow is as follows (a minimal sketch is shown after the outline):

# dataset
dataset
# dataloader
dataloader
# model
model
# loss function
criterion
# optimizer
optimizer
# learning rate scheduler
lr_scheduler
# training
train
# saving
save
# prediction
predict
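Put together, a minimal main.py following this outline might look like the sketch below; the factory signatures follow the earlier sketches rather than the no-argument stubs, and the save path and hyperparameters are arbitrary:

import torch

from config import get_config  # hypothetical config.py from the introduction sketch
from dataset import MyDataset
from models import get_model
from criterion import get_criterion
from optimizer import get_optimizer
from lr_scheduler import get_lr_scheduler
from train import train_model

def main():
    config = get_config()
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
    model = get_model(config)
    criterion = get_criterion(config)
    optimizer = get_optimizer(model, config)
    scheduler = get_lr_scheduler(optimizer)
    model = train_model(model, criterion, optimizer, scheduler, num_epochs=config.epochs)
    torch.save(model.state_dict(), 'best_model.pth')
    # prediction: switch to model.eval() and run new inputs under torch.no_grad()

if __name__ == '__main__':
    main()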

Optional optimizations

  • Gradient clipping with torch.nn.utils.clip_grad_norm_ (a sketch follows this list)
  • Loading the best parameters after training
  • ...
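For gradient clipping, the call usually goes between loss.backward() and optimizer.step() inside the training loop; a sketch with an arbitrary max_norm:

    if phase == 'train':
        loss.backward()
        # clip the global gradient norm to stabilize training
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()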

Other

You can check the official PyTorch website for more model optimization techniques.
