【炼丹保姆】early stopping 的pytorch实现
-
什么是early stopping?
一种防止模型过拟合的小技巧,能清楚的告诉我们什么时候训练应该结束了
核心思想
若模型经过连续n个epochs的训练,验证数据集的损失值仍然没有实现向下突破,则我们认为模型训练遇到了瓶颈,再往下训练则模型可能存在过拟合的风险,此时应该提前终止训练,而模型本应该在N个epochs后停止训练,所以叫做early stopping。
代码
import torch import numpy as np EPOCHS = # 定义最大epoch训练数 model = # 调用模型 data_loaders = # 封装训练和验证数据集 optimizer = # 定义优化器 criterion = # 定义损失函数 best_loss = np.inf # 最小验证集损失值初始化 for epoch in range(1, EPOCHS+1): losses = {} for phase in ['train', 'val']: train_flag = phase == 'train' model.train() if train_flag else model.eval() losses[phase] = 0.0 data_n = 0.0 for _, (inputs,targets) in enumerate(data_loaders[phase]): inputs=inputs.to(device) targets=targets.to(device) optimizer.zero_grad(set_to_none=True) with torch.set_grad_enabled(train_flag): outputs = model(inputs) loss = criterion(outputs, targets) if train_flag: loss.backward() optimizer.step() losses[phase] += loss.item() * inputs.size(0) data_n += inputs.size(0) losses[phase] = losses[phase] / data_n # early stopping 的实现 if losses['val'] < best_loss: best_loss = val_loss early_stopping_counter = 0 # 一旦最小验证集损失值找到,开始重新计数 # 一般在此时保存训练模型 torch.save({ 'epoch': epoch, # 当前epoch 'model_state_dict': model.state_dict(), # 模型参数 'optimizer_state_dict': optimizer.state_dict(), #优化器参数 'train_loss': losses['train'], # 当前训练集损失值 'val_loss': losses['val'], # 当前验证集损失值 }) else: early_stopping_counter += 1 if early_stopping_counter > 10: # 若验证集损失值连续10个epochs仍然没有实现向下突破,则宣布训练结束了 break
-
Alice_恒源云