    Multi-class news text classification with the simpletransformers library

    • 183****0229
      183****0229 last edited by 183****0229

      1. Environment

      • pytorch 1.8.1
      • python 3.8.1
      • cuda 11.1
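
To check whether a machine matches the versions listed above, something like the following sketch may help. The helper name `environment_report` is made up for this example. It reports the Python version and, only if PyTorch is installed, the PyTorch version, the CUDA version PyTorch was built against, and whether a GPU is visible.

```python
import platform


def environment_report():
    """Collect the versions that matter for this tutorial (illustrative helper)."""
    report = {"python": platform.python_version()}
    try:
        import torch  # only available once PyTorch is installed
    except ImportError:
        report["pytorch"] = None
        return report
    report["pytorch"] = torch.__version__
    # CUDA version PyTorch was compiled with (None for CPU-only builds).
    report["cuda"] = torch.version.cuda
    report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for name, value in environment_report().items():
        print(f"{name}: {value}")
```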

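      2. Dataset

      TNEWS: Toutiao (今日头条) Chinese news short-text classification (Short Text Classification for News)

      The dataset comes from the news section of Toutiao and covers 15 news categories, including travel, education, finance, and military.

           Size: training set (53,360), validation set (10,000), test set (10,000)
           Example:
           {"label": "102", "label_des": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}
           Each record has three fields, in order: the category ID, the category name, and the news text (headline only).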
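
The record format described above can be read with the standard `json` module. The sketch below parses the example record and turns it into one tab-separated row. The `text<TAB>labels` layout is an assumption based on the column names simpletransformers expects in a DataFrame. The post does not show how `dataset/processed/train.tsv` was actually produced, and the function names are made up for this example.

```python
import json

# The example record quoted in the dataset description.
EXAMPLE_LINE = (
    '{"label": "102", "label_des": "news_entertainment", '
    '"sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}'
)


def parse_record(line):
    """Return (category ID, category name, headline) from one JSON line."""
    record = json.loads(line)
    return record["label"], record["label_des"], record["sentence"]


def to_tsv_row(line):
    """Build a 'text<TAB>labels' row (assumed layout, see the note above)."""
    _, label_des, sentence = parse_record(line)
    # Collapse any whitespace so a stray tab or newline cannot break the row.
    text = " ".join(sentence.split())
    return f"{text}\t{label_des}"


if __name__ == "__main__":
    print(parse_record(EXAMPLE_LINE))
    print(to_tsv_row(EXAMPLE_LINE))
```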
      

      TNEWS dataset download

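      3. Download the code and install dependencies

      # Change to the working directory
      cd /hy-tmp
      # The code has been uploaded to GitHub
      git clone https://github.com/JunnYu/hy_tutorial.git
      # If the download fails, use the mirror instead: git clone https://hub.fastgit.org/JunnYu/hy_tutorial.git
      # Enter the repository
      cd hy_tutorial
      # Unzip the archive
      unzip tnews_classfication.zip
      # Enter the project directory
      cd tnews_classfication
      # Install the required Python packages
      pip install -r requirements.txt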
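
After the install step it can be useful to confirm that the modules the training script imports are actually available. The sketch below checks the three modules imported by the script in this post. The contents of `requirements.txt` are not shown here, so this list is taken from the script's import lines only, and `missing_modules` is a made-up helper name.

```python
import importlib.util

# Top-level modules imported by the training script in this post.
REQUIRED_MODULES = ["pandas", "sklearn", "simpletransformers"]


def missing_modules(module_names):
    """Return the names that cannot be found in the current environment."""
    return [
        name for name in module_names
        if importlib.util.find_spec(name) is None
    ]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```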
      

      4. Run the program

      python train.py

      train.py contains the following code:

      import pandas as pd
      from sklearn.metrics import accuracy_score
      from simpletransformers.classification import ClassificationModel
      # Read the TSV files
      train_df = pd.read_csv("dataset/processed/train.tsv", sep="\t")
      eval_df = pd.read_csv("dataset/processed/dev.tsv", sep="\t")
      # Define the labels
      label_list = ['news_story',
                    'news_culture',
                    'news_entertainment',
                    'news_sports',
                    'news_finance',
                    'news_house',
                    'news_car',
                    'news_edu',
                    'news_tech',
                    'news_military',
                    'news_travel',
                    'news_world',
                    'news_stock',
                    'news_agriculture',
                    'news_game']
      
      model = ClassificationModel(
          model_type="bert",
          model_name="hfl/chinese-roberta-wwm-ext",
          num_labels=15,
          args={
              "reprocess_input_data": True,
              "overwrite_output_dir": True,
              "use_cached_eval_features": True,
              "do_lower_case": True,
              "evaluate_during_training": True,
              "labels_list": label_list,
              "max_seq_length": 256,
              "manual_seed": 42,
              "num_train_epochs": 10,
              "train_batch_size": 64,
              "eval_batch_size": 128,
              "save_optimizer_and_scheduler": False
          }
      )
      
      # Train the model
      model.train_model(train_df=train_df, eval_df=eval_df,
                        accuracy=accuracy_score)
      
      # Evaluate the model
      result, model_outputs, wrong_predictions = model.eval_model(
          eval_df,  accuracy=accuracy_score)
      print(result)
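
The script passes `accuracy=accuracy_score` as an extra metric. To my understanding, simpletransformers calls such a function with the true labels and the predicted labels and stores the return value under the keyword name (`accuracy` here). The plain-Python sketch below shows what that number means. It is only an illustration, not the scikit-learn implementation.

```python
def accuracy(labels, preds):
    """Fraction of predictions that match the true labels."""
    labels = list(labels)
    preds = list(preds)
    if len(labels) != len(preds):
        raise ValueError("labels and preds must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on empty input")
    correct = sum(1 for y, p in zip(labels, preds) if y == p)
    return correct / len(labels)


if __name__ == "__main__":
    # Three of the four predictions match the true labels.
    print(accuracy([0, 1, 2, 2], [0, 2, 2, 2]))
```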
      
      

      You can open train.py and change num_train_epochs to the number of epochs you want.
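
If you prefer not to edit the file for every run, one option is to read the number of epochs from the command line and pass the result as `args`. The sketch below is an assumption about how that could look. The `--epochs` flag and the `build_args` helper do not exist in the original train.py, and only three of the original settings are repeated here.

```python
import argparse

# A subset of the settings used in train.py, repeated here as defaults.
DEFAULT_ARGS = {
    "num_train_epochs": 10,
    "train_batch_size": 64,
    "eval_batch_size": 128,
}


def build_args(argv=None):
    """Return a copy of DEFAULT_ARGS with num_train_epochs taken from --epochs."""
    parser = argparse.ArgumentParser(description="Override the epoch count")
    parser.add_argument(
        "--epochs",  # hypothetical flag, not part of the original script
        type=int,
        default=DEFAULT_ARGS["num_train_epochs"],
        help="number of training epochs",
    )
    namespace = parser.parse_args(argv)
    if namespace.epochs < 1:
        parser.error("--epochs must be at least 1")
    args = dict(DEFAULT_ARGS)
    args["num_train_epochs"] = namespace.epochs
    return args


if __name__ == "__main__":
    print(build_args())
```

For example, `python train.py --epochs 3` would then train for 3 epochs, provided the returned dictionary is merged into the `args` passed to `ClassificationModel`.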

      5. Check GPU utilization

      watch -n 0.1 nvidia-smi
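
To log the numbers instead of watching them, `nvidia-smi` can also print selected fields as CSV through its `--query-gpu` and `--format` options. The sketch below reads that output from Python. The function names are made up for this example, and the parser assumes every field is a plain integer, which may not hold on devices that report `[N/A]`.

```python
import subprocess

# Ask nvidia-smi for one CSV line per GPU, without header or units.
QUERY_COMMAND = [
    "nvidia-smi",
    "--query-gpu=index,utilization.gpu,memory.used,memory.total",
    "--format=csv,noheader,nounits",
]


def parse_gpu_stats(output):
    """Parse lines such as '0, 87, 10240, 24576' into dictionaries."""
    stats = []
    for line in output.strip().splitlines():
        index, util, used, total = [field.strip() for field in line.split(",")]
        stats.append({
            "index": int(index),
            "utilization_pct": int(util),
            "memory_used_mib": int(used),
            "memory_total_mib": int(total),
        })
    return stats


def read_gpu_stats():
    """Run nvidia-smi once and return the parsed statistics."""
    result = subprocess.run(
        QUERY_COMMAND, capture_output=True, text=True, check=True
    )
    return parse_gpu_stats(result.stdout)


if __name__ == "__main__":
    for gpu in read_gpu_stats():
        print(gpu)
```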
      

      6. View results on the dev set

      Open the outputs folder and double-click training_progress_scores.csv or eval_results.txt.

      Best accuracy: 66.85%
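
Instead of scanning the CSV by eye, the best checkpoint can be picked out with a few lines of Python. This sketch assumes that `training_progress_scores.csv` has a header row with an `accuracy` column (named after the keyword passed to `train_model`) and a `global_step` column. Please check the header of your own file, as the exact columns may differ between simpletransformers versions. `best_row` is a made-up helper name.

```python
import csv


def best_row(csv_file, metric="accuracy"):
    """Return the row (as a dict) with the highest value in the metric column."""
    rows = list(csv.DictReader(csv_file))
    if not rows:
        raise ValueError("the CSV file contains no data rows")
    return max(rows, key=lambda row: float(row[metric]))


if __name__ == "__main__":
    # Path taken from the tutorial text above.
    with open("outputs/training_progress_scores.csv", newline="") as handle:
        best = best_row(handle)
    print("best accuracy:", best["accuracy"], "at step", best.get("global_step"))
```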

      7. References

      https://github.com/ThilinaRajapakse/simpletransformers
      https://github.com/CLUEbenchmark/CLUE
