Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    【5】使用pytorch_lightning+transformers+torchmetric+datasets进行文本分类

    技术交流
    1
    1
    45
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 183****0229
      183****0229 last edited by

      开始训练

      # Build the generic (model-agnostic) argument parser.
      def get_parser():
          """Return an ArgumentParser with the common training hyper-parameters.

          Callers are expected to extend the returned parser with dataset- and
          Trainer-specific arguments before calling ``parse_args``.
          """
          parser = argparse.ArgumentParser(description="Training Args")
          parser.add_argument("--model_name_or_path",
                              default="bert-base-chinese",
                              type=str,
                              help="model_name_or_path")
          parser.add_argument("--seed", default=42, type=int, help="random seed")
          parser.add_argument("--scheduler",
                              choices=["linear", "cosine", "polynomial"],
                              default="linear",
                              help="scheduler type")
          parser.add_argument("--lr", type=float, default=3e-5, help="learning rate")
          parser.add_argument("--weight_decay",
                              default=0.02,
                              type=float,
                              help="Weight decay if we apply some")
          # NOTE: this is a proportion (fraction of total steps), not a step
          # count — the old help text ("Warmup steps") was misleading.
          parser.add_argument("--warmup_prob",
                              default=0.,
                              type=float,
                              help="Warmup proportion (fraction of total steps) used for scheduler")
          parser.add_argument("--adam_epsilon",
                              default=1e-8,
                              type=float,
                              help="Epsilon for Adam optimizer")
      
          return parser
      
      
      def main():
          """Parse CLI arguments, build the data module and model, and train."""
          # Compose the full parser: common args, then TNews dataset args,
          # then the pytorch_lightning Trainer args.
          arg_parser = pl.Trainer.add_argparse_args(
              TNewsDataModule.add_dataset_specific_args(get_parser()))
          args = arg_parser.parse_args()
      
          # Pin all RNGs so runs are reproducible.
          pl.seed_everything(args.seed)
      
          # Keep the five best checkpoints ranked by validation accuracy
          # (weights only, highest val_acc wins).
          checkpoint_cb = ModelCheckpoint(filename='{epoch}-{step}-{val_acc:.4f}',
                                          monitor='val_acc',
                                          mode="max",
                                          save_top_k=5,
                                          save_weights_only=True)
      
          # Data module must prepare its data first so that num_labels is
          # available before the model is constructed.
          data_module = TNewsDataModule(args)
          data_module.prepare_data()
          args.num_labels = data_module.num_labels
      
          classifier = TNewsModel(args)
          trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_cb])
          trainer.fit(model=classifier, datamodule=data_module)
      
      
      if __name__ == '__main__':
          main()
      
      

      训练命令

      使用roformer_chinese_base

      # Silence the HuggingFace tokenizers parallelism warning when the
      # DataLoader forks worker processes (presumably why it is set here).
      export TOKENIZERS_PARALLELISM=false
      # Fine-tune junnyu/roformer_chinese_base: 1 GPU, linear LR schedule with
      # 15% warmup, 5 epochs, seq length 256, fp16 mixed precision,
      # gradient clipping at 5.0; logs/checkpoints go under /tf_logs/.
      python train.py \
          --gpus=1 \
          --default_root_dir /tf_logs/ \
          --model_name_or_path junnyu/roformer_chinese_base \
          --lr 3e-5 \
          --scheduler linear \
          --warmup_prob 0.15 \
          --max_epochs 5 \
          --max_length 256 \
          --train_batch_size 64 \
          --eval_batch_size 128 \
          --num_workers 6 \
          --gradient_clip_val 5.0 \
          --seed 42 \
          --precision 16
      
      

      使用chinese-roberta-wwm-ext

      # Silence the HuggingFace tokenizers parallelism warning when the
      # DataLoader forks worker processes (presumably why it is set here).
      export TOKENIZERS_PARALLELISM=false
      # Same hyper-parameters as the roformer run above, but fine-tuning
      # hfl/chinese-roberta-wwm-ext instead — only the model name differs.
      python train.py \
          --gpus=1 \
          --default_root_dir /tf_logs/ \
          --model_name_or_path hfl/chinese-roberta-wwm-ext \
          --lr 3e-5 \
          --scheduler linear \
          --warmup_prob 0.15 \
          --max_epochs 5 \
          --max_length 256 \
          --train_batch_size 64 \
          --eval_batch_size 128 \
          --num_workers 6 \
          --gradient_clip_val 5.0 \
          --seed 42 \
          --precision 16
      
      1 Reply Last reply Reply Quote 2
      • First post
        Last post