[Model-Training Nanny] A PyTorch implementation of early stopping

What is early stopping?

A small trick for preventing overfitting: it tells us clearly when training should stop.

Core idea

If the validation loss fails to reach a new low for n consecutive epochs of training, we take this as a sign that training has hit a plateau; continuing to train risks overfitting, so training should be terminated early. Because the model would otherwise have kept running for the full N epochs, this is called early stopping.
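
To make the criterion concrete before the full training loop, here is a minimal standalone sketch of just the counter logic. The loss values and the name patience are illustrative assumptions, not something from the original post:

    # Illustrative validation losses, one per epoch (made-up numbers)
    val_losses = [0.90, 0.72, 0.65, 0.66, 0.64, 0.64, 0.65, 0.66, 0.67, 0.65]

    best_loss = float('inf')
    patience = 3   # n: tolerated number of consecutive epochs without a new low
    counter = 0

    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best_loss:
            best_loss = val_loss   # new low reached: restart the count
            counter = 0
        else:
            counter += 1           # no improvement this epoch
        if counter >= patience:
            print(f'early stopping at epoch {epoch}, best val loss {best_loss:.2f}')
            break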

Code

    import torch
    import numpy as np

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # training device

    EPOCHS = ...        # maximum number of epochs to train
    model = ...         # the model, already moved to `device`
    data_loaders = ...  # dict holding a 'train' and a 'val' DataLoader
    optimizer = ...     # optimizer
    criterion = ...     # loss function

    best_loss = np.inf           # lowest validation loss seen so far
    early_stopping_counter = 0   # epochs since the last new low
    patience = 10                # stop after this many epochs without improvement

    for epoch in range(1, EPOCHS + 1):
        losses = {}
        for phase in ['train', 'val']:
            train_flag = phase == 'train'
            model.train() if train_flag else model.eval()
            losses[phase] = 0.0
            data_n = 0.0
            for inputs, targets in data_loaders[phase]:
                inputs = inputs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad(set_to_none=True)
                with torch.set_grad_enabled(train_flag):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    if train_flag:
                        loss.backward()
                        optimizer.step()
                losses[phase] += loss.item() * inputs.size(0)
                data_n += inputs.size(0)
            losses[phase] = losses[phase] / data_n

        # The early stopping logic
        if losses['val'] < best_loss:
            best_loss = losses['val']
            early_stopping_counter = 0  # new low found, restart the count
            # This is usually also the moment to save a checkpoint
            torch.save({
                'epoch': epoch,                                   # current epoch
                'model_state_dict': model.state_dict(),           # model parameters
                'optimizer_state_dict': optimizer.state_dict(),   # optimizer state
                'train_loss': losses['train'],                    # current training loss
                'val_loss': losses['val'],                        # current validation loss
                }, 'best_model.pth')                              # example checkpoint path
        else:
            early_stopping_counter += 1
        if early_stopping_counter >= patience:
            # no new low on the validation loss for `patience` consecutive epochs
            break
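
Since the loop above saves a checkpoint dict whenever the validation loss reaches a new low, restoring the best model afterwards looks roughly like this. The file name 'best_model.pth' matches the example path assumed in the torch.save call above and is not part of the original post:

    # Load the best checkpoint written during training (path is an example)
    checkpoint = torch.load('best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # only needed if resuming training
    print(f"best val loss {checkpoint['val_loss']:.4f} at epoch {checkpoint['epoch']}")
    model.eval()   # switch to eval mode before validation/inference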
      
      