开始训练
# 定义通用的parser
def get_parser():
parser = argparse.ArgumentParser(description="Training Args")
parser.add_argument("--model_name_or_path",
default="bert-base-chinese",
type=str,
help="model_name_or_path")
parser.add_argument("--seed", default=42, type=int, help="random seed")
parser.add_argument("--scheduler",
choices=["linear", "cosine", "polynomial"],
default="linear",
help="scheduler type")
parser.add_argument("--lr", type=float, default=3e-5, help="learning rate")
parser.add_argument("--weight_decay",
default=0.02,
type=float,
help="Weight decay if we apply some")
parser.add_argument("--warmup_prob",
default=0.,
type=float,
help="Warmup steps used for scheduler")
parser.add_argument("--adam_epsilon",
default=1e-8,
type=float,
help="Epsilon for Adam optimizer")
return parser
def main():
# 获取通用的parser
parser = get_parser()
# 给parser添加TNewsDataModule特定的参数
parser = TNewsDataModule.add_dataset_specific_args(parser)
# 给parser添加pytorch_lightning的Trainer特定的参数
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# 固定随机种子,便于复现结果
pl.seed_everything(args.seed)
# 定义保存模型的callback,根据最好的val_acc保存模型参数
ckpt_callback = ModelCheckpoint(filename='{epoch}-{step}-{val_acc:.4f}',
monitor='val_acc',
mode="max",
save_top_k=5,
save_weights_only=True)
# 初始化TNewsDataModule
tnews_dm = TNewsDataModule(args)
tnews_dm.prepare_data()
args.num_labels = tnews_dm.num_labels
# 初始化TNewsModel
model = TNewsModel(args)
# 定义trainer
trainer = pl.Trainer.from_argparse_args(args, callbacks=[ckpt_callback])
# 开始训练
trainer.fit(model=model, datamodule=tnews_dm)
if __name__ == '__main__':
main()
训练命令
使用roformer_chinese_base
export TOKENIZERS_PARALLELISM=false
python train.py \
--gpus=1 \
--default_root_dir /tf_logs/ \
--model_name_or_path junnyu/roformer_chinese_base \
--lr 3e-5 \
--scheduler linear \
--warmup_prob 0.15 \
--max_epochs 5 \
--max_length 256 \
--train_batch_size 64 \
--eval_batch_size 128 \
--num_workers 6 \
--gradient_clip_val 5.0 \
--seed 42 \
--precision 16
使用chinese-roberta-wwm-ext
export TOKENIZERS_PARALLELISM=false
python train.py \
--gpus=1 \
--default_root_dir /tf_logs/ \
--model_name_or_path hfl/chinese-roberta-wwm-ext \
--lr 3e-5 \
--scheduler linear \
--warmup_prob 0.15 \
--max_epochs 5 \
--max_length 256 \
--train_batch_size 64 \
--eval_batch_size 128 \
--num_workers 6 \
--gradient_clip_val 5.0 \
--seed 42 \
--precision 16