    【2】Text Classification with pytorch_lightning + transformers + torchmetrics + datasets


      Using datasets

      1. Data Preparation

      This post defines a custom utils/custom_dataset.py script that handles downloading the dataset, defining its splits, and related setup.

      We can then load the custom dataset in Python with raw_datasets = load_dataset("utils/custom_dataset.py", "tnews", cache_dir="cache").

      # utils/custom_dataset.py
      import os
      import json
      import datasets
      
      _CITATION = "None"
      
      _DESCRIPTION = "None"
      
      _HOMEPAGE = "None"
      
      _LICENSE = "MIT"
      
      # Download URL of the original dataset
      _URLs = {"tnews": "https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip"}
      
      # Mapping from raw label ids to label names
      _LABELS_MAP = {'100': 'news_story', '101': 'news_culture', '102': 'news_entertainment', '103': 'news_sports', '104': 'news_finance', '106': 'news_house', '107': 'news_car',
                     '108': 'news_edu', '109': 'news_tech', '110': 'news_military', '112': 'news_travel', '113': 'news_world', '114': 'news_stock', '115': 'news_agriculture', '116': 'news_game'}
      
      
      class Tnews(datasets.GeneratorBasedBuilder):
          """Tnews"""
      
          VERSION = datasets.Version("1.0.0")
      	
          # Builder configs for this dataset
          BUILDER_CONFIGS = [
              datasets.BuilderConfig(
                  name="tnews", version=VERSION, description="Tnews"
              ),
      
          ]
      
          DEFAULT_CONFIG_NAME = "tnews"
      
          def _info(self):
              # Define the dataset features
              features = datasets.Features(
                  {
                      # the sentence field is a string
                      "sentence": datasets.Value("string"),
                      # the keywords field is also a string
                      "keywords": datasets.Value("string"),
                      # the label field can be a ClassLabel or simply a string
                      "label": datasets.ClassLabel(names=['news_story',
                                                          'news_culture',
                                                          'news_entertainment',
                                                          'news_sports',
                                                          'news_finance',
                                                          'news_house',
                                                          'news_car',
                                                          'news_edu',
                                                          'news_tech',
                                                          'news_military',
                                                          'news_travel',
                                                          'news_world',
                                                          'news_stock',
                                                          'news_agriculture',
                                                          'news_game'])
                  })
              return datasets.DatasetInfo(
                  # This is the description that will appear on the datasets page.
                  description=_DESCRIPTION,
                  # This defines the different columns of the dataset and their types
                  # Here we define them above because they are different between the two configurations
                  features=features,
                  # If there's a common (input, target) tuple from the features,
                  # specify them here. They'll be used if as_supervised=True in
                  # builder.as_dataset.
                  supervised_keys=None,
                  # Homepage of the dataset for documentation
                  homepage=_HOMEPAGE,
                  # License for the dataset if available
                  license=_LICENSE,
                  # Citation for the dataset
                  citation=_CITATION,
              )
      
          def _split_generators(self, dl_manager):
              """Returns SplitGenerators."""
              # Download the data for the selected config
              my_urls = _URLs[self.config.name]
              data_dir = dl_manager.download_and_extract(my_urls)
              # Define the dataset splits
              outputs = [
                  datasets.SplitGenerator(
                      name=datasets.Split.TRAIN,
                      # These kwargs will be passed to _generate_examples
                      gen_kwargs={
                          # train file
                          "filepath": os.path.join(data_dir, "train.json"),
                          "split": "train",
                      },
                  ),
                  datasets.SplitGenerator(
                      name=datasets.Split.VALIDATION,
                      # These kwargs will be passed to _generate_examples
                      gen_kwargs={
                          # dev file
                          "filepath": os.path.join(data_dir, "dev.json"),
                          "split": "dev",
                      },
                  ),
              ]
      
              return outputs
      
          def _generate_examples(
              # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
              self, filepath, split
          ):
              """Yields examples as (key, example) tuples."""
              # Yield one (key, example) pair per line
              with open(filepath, "r", encoding="utf8") as f:
                  for _id, line in enumerate(f):
                      # The data is in JSON Lines format (one JSON string per line), so parse each line with json.loads
                      line = json.loads(line)
                      yield _id, {
                          "sentence": line["sentence"],
                          "keywords": line["keywords"],
                          "label": _LABELS_MAP[line["label"]],
                      }
      
      from datasets import load_dataset
      raw_datasets = load_dataset("utils/custom_dataset.py", "tnews", cache_dir="cache")
      print(raw_datasets)
      print(raw_datasets["train"][0])
      

      2. Data Processing

      # Use the BERT tokenizer
      from transformers import AutoTokenizer
      tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
      # Define the preprocessing function
      def preprocess_function(example):
          result = tokenizer(example["sentence"], example["keywords"],
                                  padding=False, max_length=256, truncation=True)
          result["labels"] = example["label"]
          return result
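
      As a quick sanity check, here is a sketch of calling preprocess_function on a hypothetical toy batch (the sentence and keywords below are made up) to see which columns it produces:

      # Hypothetical example, only to inspect the output columns
      sample = {"sentence": ["这是一条测试新闻"], "keywords": ["测试,新闻"], "label": [0]}
      out = preprocess_function(sample)
      print(list(out.keys()))  # ['input_ids', 'token_type_ids', 'attention_mask', 'labels']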
      

      # Preprocess the raw dataset with map
      processed_datasets = raw_datasets.map(
          preprocess_function,
          # process examples in batches
          batched=True,
          # drop columns that are no longer needed
          remove_columns=raw_datasets["train"].column_names,
          # cache the processed data
          cache_file_names={
              "train": "features/tnews-cached-tokenized-train-256.arrow",
              "validation": "features/tnews-cached-tokenized-validation-256.arrow",
          },
          desc="Running tokenizer on dataset",
      )
      # Convert the ['input_ids', 'token_type_ids', 'attention_mask', 'labels'] columns to torch.Tensor
      processed_datasets.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
      # Inspect the first training example
      print(processed_datasets["train"][0])
      

      from torch.utils.data import DataLoader
      from transformers import DataCollatorWithPadding
      dl = DataLoader(processed_datasets["train"], batch_size=2, collate_fn=DataCollatorWithPadding(tokenizer))
      next(iter(dl))
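
      Because the tokenizer was called with padding=False, the cached examples still have different lengths; DataCollatorWithPadding pads each batch only up to the longest sequence in that batch (dynamic padding). A small sketch, under the same setup as above, to verify the padded shapes:

      # Each tensor is padded to the longest sequence in this batch of 2,
      # so the second dimension can vary from batch to batch
      batch = next(iter(dl))
      print({k: v.shape for k, v in batch.items()})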
      
