【炼丹保姆】Automatically decaying the learning rate during training in PyTorch


Why decay the learning rate?

When training hits a plateau, lowering the learning rate can help the model break through it and settle into a better local minimum.

How is it done?

PyTorch ships many different decay strategies; here we take ReduceLROnPlateau as a simple example. First, the most important parameters:

• factor: new learning rate = old learning rate × factor (e.g., with factor=0.95, a learning rate of 1e-3 becomes 9.5e-4)
• patience: similar to early stopping; the scheduler tolerates up to `patience` epochs in which the model makes no progress
• threshold: taking loss as the metric, the gap between the new loss and the previous best loss must exceed `threshold`; otherwise the model is judged to still be standing still
• threshold_mode: determines how that gap between the new and best loss is computed ('rel' for relative, 'abs' for absolute); see the sketch after this list
• eps: minimum decay step; if the difference between the old and new learning rate is smaller than `eps`, the update is skipped
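
To make threshold and threshold_mode concrete, here is a minimal sketch of the improvement check that ReduceLROnPlateau runs internally, assuming mode='min' (the function name mirrors the scheduler's internal is_better check):

      # Minimal sketch of how ReduceLROnPlateau decides whether a new
      # metric value counts as an improvement, assuming mode='min'.
      def is_better(current, best, threshold=1e-4, threshold_mode='rel'):
          if threshold_mode == 'rel':
              # relative: current must beat best by the fraction `threshold`
              return current < best * (1 - threshold)
          # 'abs': current must beat best by at least `threshold`
          return current < best - threshold

      # With best = 0.500 and threshold = 1e-2:
      print(is_better(0.496, 0.500, 1e-2, 'rel'))  # False: needs < 0.495
      print(is_better(0.489, 0.500, 1e-2, 'abs'))  # True:  needs < 0.490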

Code

      import torch
      import numpy as np


      EPOCHS = ...        # maximum number of training epochs (placeholder)
      model = ...         # your model (placeholder)
      data_loaders = ...  # dict holding the train and val DataLoaders (placeholder)
      optimizer = ...     # your optimizer (placeholder)
      criterion = ...     # your loss function (placeholder)

      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      best_loss = np.inf          # best validation loss so far
      early_stopping_counter = 0  # epochs since the last improvement

      # Define the LR schedule.
      scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=5, verbose=True, eps=1e-7)
      # Note: patience=5 here, combined with the early-stopping patience of 10 below,
      # means the LR is reduced after 5 epochs without improvement; if the model
      # still fails to improve after roughly another 5 epochs, training stops.

      for epoch in range(1, EPOCHS + 1):
          losses = {}
          for phase in ['train', 'val']:
              train_flag = phase == 'train'
              model.train() if train_flag else model.eval()
              losses[phase] = 0.0
              data_n = 0.0
              for inputs, targets in data_loaders[phase]:
                  inputs = inputs.to(device)
                  targets = targets.to(device)
                  optimizer.zero_grad(set_to_none=True)
                  with torch.set_grad_enabled(train_flag):
                      outputs = model(inputs)
                      loss = criterion(outputs, targets)
                      if train_flag:
                          loss.backward()
                          optimizer.step()
                  losses[phase] += loss.item() * inputs.size(0)
                  data_n += inputs.size(0)
              losses[phase] = losses[phase] / data_n

          # Step the scheduler on the validation loss.
          scheduler.step(losses['val'])

          # Early stopping.
          if losses['val'] < best_loss:
              best_loss = losses['val']
              early_stopping_counter = 0  # reset the counter once a new best is found
              # This is usually the right moment to save a checkpoint.
              torch.save({
                  'epoch': epoch,                                  # current epoch
                  'model_state_dict': model.state_dict(),          # model weights
                  'optimizer_state_dict': optimizer.state_dict(),  # optimizer state
                  'train_loss': losses['train'],                   # current training loss
                  'val_loss': losses['val'],                       # current validation loss
                  }, 'checkpoint.pt')  # save path (illustrative)
          else:
             early_stopping_counter += 1
          if early_stopping_counter > 10:  # no new validation low for more than 10 consecutive epochs: stop training
              break
      
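Since the loop saves a checkpoint dict whenever the validation loss hits a new low, here is a minimal sketch of restoring it later to resume training or run inference, assuming the illustrative 'checkpoint.pt' path used above:

      # Minimal sketch: load the checkpoint saved in the training loop above.
      checkpoint = torch.load('checkpoint.pt', map_location=device)
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      start_epoch = checkpoint['epoch'] + 1  # resume from the following epoch
      print(f"resuming at epoch {start_epoch}, val loss {checkpoint['val_loss']:.4f}")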