Multi-class News Text Classification with the simpletransformers Library
-
1. Environment
- PyTorch 1.8.1
- Python 3.8.1
- CUDA 11.1
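
Before training, it can help to confirm that the interpreter actually sees the versions listed above. The following is a minimal sketch, not part of the tutorial repository; the helper name `describe_environment` is made up for illustration, and it reports `None` for PyTorch and CUDA when PyTorch is not installed.

```python
import platform


def describe_environment():
    """Collect the Python, PyTorch, and CUDA versions visible to this interpreter."""
    info = {"python": platform.python_version()}
    try:
        import torch  # optional: only reported when PyTorch is installed
    except ImportError:
        info["pytorch"] = None
        info["cuda"] = None
        info["gpu_available"] = False
        return info
    info["pytorch"] = torch.__version__
    info["cuda"] = torch.version.cuda  # None for CPU-only builds
    info["gpu_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for name, value in describe_environment().items():
        print(f"{name}: {value}")
```

If `gpu_available` is `False` on a GPU machine, the installed PyTorch build and the CUDA driver probably do not match.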
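2. Dataset
TNEWS: Toutiao Chinese news (short text) classification (Short Text Classification for News).
The dataset comes from the news section of Toutiao (今日头条) and covers 15 news categories, including travel, education, finance, and military.
Data size: training set (53,360), validation set (10,000), test set (10,000).
Example:

    {"label": "102", "label_des": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}

Each record has three fields, in order: the category ID, the category name, and the news text (headline only).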
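
To make the record layout concrete, here is a small sketch that parses the example record and maps it to a `text`/`labels` pair, which are the column names simpletransformers reads by default. Whether the processed TSV files in the repository use exactly this mapping is an assumption, and `to_row` is a hypothetical helper.

```python
import json

# The example record from the dataset description above.
sample_line = (
    '{"label": "102", "label_des": "news_entertainment", '
    '"sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}'
)


def to_row(json_line):
    """Turn one JSON record into a row with `text` and `labels` keys.

    Assumption: the category name (label_des) is used as the label,
    matching the string labels listed in train.py.
    """
    record = json.loads(json_line)
    return {"text": record["sentence"], "labels": record["label_des"]}


row = to_row(sample_line)
print(row["labels"])
print(row["text"])
```

Applying `to_row` to every line of a raw file and collecting the results into a DataFrame would give a table in the shape the training script expects.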
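3. Download the code and install dependencies

    # Change to the working directory
    cd /hy-tmp
    # The code is hosted on GitHub
    git clone https://github.com/JunnYu/hy_tutorial.git
    # If the download fails, use the mirror instead:
    # git clone https://hub.fastgit.org/JunnYu/hy_tutorial.git
    # Enter the repository
    cd hy_tutorial
    # Unzip the archive
    unzip tnews_classfication.zip
    # Enter the project directory
    cd tnews_classfication
    # Install the required Python packages
    pip install -r requirements.txt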
4. Run the program

    python train.py

train.py contains the following:

    import pandas as pd
    from sklearn.metrics import accuracy_score
    from simpletransformers.classification import ClassificationModel

    # Read the TSV files
    train_df = pd.read_csv("dataset/processed/train.tsv", sep="\t")
    eval_df = pd.read_csv("dataset/processed/dev.tsv", sep="\t")

    # Define the labels
    label_list = ['news_story', 'news_culture', 'news_entertainment',
                  'news_sports', 'news_finance', 'news_house', 'news_car',
                  'news_edu', 'news_tech', 'news_military', 'news_travel',
                  'news_world', 'news_stock', 'news_agriculture', 'news_game']

    model = ClassificationModel(
        model_type="bert",
        model_name="hfl/chinese-roberta-wwm-ext",
        num_labels=15,
        args={
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "use_cached_eval_features": True,
            "do_lower_case": True,
            "evaluate_during_training": True,
            "labels_list": label_list,
            "max_seq_length": 256,
            "manual_seed": 42,
            "num_train_epochs": 10,
            "train_batch_size": 64,
            "eval_batch_size": 128,
            "save_optimizer_and_scheduler": False
        }
    )

    # Train the model
    model.train_model(train_df=train_df, eval_df=eval_df,
                      accuracy=accuracy_score)

    # Evaluate the model
    result, model_outputs, wrong_predictions = model.eval_model(
        eval_df, accuracy=accuracy_score)
    print(result)

To train for a different number of epochs, open train.py and change num_train_epochs.
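
Instead of editing the file each time, the epoch count could be read from the command line. The sketch below is one possible approach and is not part of the repository's train.py: the `--epochs` flag and the helpers `parse_epochs` and `build_model_args` are hypothetical, and only a subset of the training options is shown.

```python
import argparse


def parse_epochs(argv=None, default=10):
    """Read a hypothetical --epochs flag; defaults to the tutorial's value of 10."""
    parser = argparse.ArgumentParser(description="Train the TNEWS classifier")
    parser.add_argument("--epochs", type=int, default=default,
                        help="number of training epochs")
    args = parser.parse_args(argv)
    if args.epochs < 1:
        parser.error("--epochs must be a positive integer")
    return args.epochs


def build_model_args(num_train_epochs):
    """Return a subset of the training options with the epoch count filled in."""
    return {
        "overwrite_output_dir": True,
        "evaluate_during_training": True,
        "manual_seed": 42,
        "num_train_epochs": num_train_epochs,
    }


# Simulate `python train.py --epochs 3` by passing the arguments explicitly.
model_args = build_model_args(parse_epochs(["--epochs", "3"]))
print(model_args["num_train_epochs"])
```

The resulting dictionary would be merged into the `args` passed to `ClassificationModel`; the other options stay as they are in train.py.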
5. Check GPU utilization

    watch -n 0.1 nvidia-smi
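
For logging GPU usage from a script rather than watching it interactively, `nvidia-smi` can also be queried in CSV form. This is a minimal sketch under the assumption that `nvidia-smi` is on the PATH; the helper names are illustrative, and the function returns an empty list when the tool is missing or fails.

```python
import shutil
import subprocess

QUERY_FIELDS = "utilization.gpu,memory.used,memory.total"


def parse_gpu_line(line):
    """Parse one 'util, used, total' CSV line into a dict of integers."""
    util, used, total = (int(part.strip()) for part in line.split(","))
    return {"util_percent": util, "memory_used_mib": used, "memory_total_mib": total}


def query_gpus():
    """Return one dict per GPU, or an empty list when nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    if result.returncode != 0:
        return []
    gpus = []
    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        try:
            gpus.append(parse_gpu_line(line))
        except ValueError:
            continue  # skip lines with non-numeric fields such as "[N/A]"
    return gpus


if __name__ == "__main__":
    for index, gpu in enumerate(query_gpus()):
        print(index, gpu)
```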
6. Check results on the dev set
Go to the outputs folder and open training_progress_scores.csv or eval_results.txt.
Best accuracy: 66.85
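
The best evaluation step can also be picked out programmatically. The sketch below assumes the CSV has a header row with an `accuracy` column (the name of the extra metric passed to `train_model`) alongside columns such as `global_step`; check the actual header before relying on it. `best_checkpoint` is a hypothetical helper.

```python
import csv
from pathlib import Path


def best_checkpoint(csv_path, metric="accuracy"):
    """Return the CSV row (as a dict of strings) with the highest `metric` value."""
    with open(csv_path, newline="", encoding="utf-8") as handle:
        rows = list(csv.DictReader(handle))
    if not rows:
        raise ValueError(f"{csv_path} contains no evaluation rows")
    return max(rows, key=lambda row: float(row[metric]))


if __name__ == "__main__":
    path = Path("outputs/training_progress_scores.csv")
    if path.exists():
        print(best_checkpoint(path))
```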
7. References
https://github.com/ThilinaRajapakse/simpletransformers
https://github.com/CLUEbenchmark/CLUE