【2】Text Classification with pytorch_lightning + transformers + torchmetrics + datasets
Using the datasets library
1. Data Preparation
This post defines a custom utils/custom_dataset.py script that handles downloading and splitting the dataset.
We can then load the custom dataset with:

```python
raw_datasets = load_dataset("utils/custom_dataset.py", "tnews", cache_dir="cache")
```

The full utils/custom_dataset.py builder script is as follows:

```python
# utils/custom_dataset.py
import os
import json

import datasets

_CITATION = "None"
_DESCRIPTION = "None"
_HOMEPAGE = "None"
_LICENSE = "MIT"

# Download URL of the original dataset
_URLs = {"tnews": "https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip"}

# Mapping from the raw label ids to label names
_LABELS_MAP = {'100': 'news_story', '101': 'news_culture', '102': 'news_entertainment',
               '103': 'news_sports', '104': 'news_finance', '106': 'news_house',
               '107': 'news_car', '108': 'news_edu', '109': 'news_tech',
               '110': 'news_military', '112': 'news_travel', '113': 'news_world',
               '114': 'news_stock', '115': 'news_agriculture', '116': 'news_game'}


class Tnews(datasets.GeneratorBasedBuilder):
    """Tnews"""

    VERSION = datasets.Version("1.0.0")

    # Builder configurations
    BUILDER_CONFIGS = [
        datasets.BuilderConfig(name="tnews", version=VERSION, description="Tnews"),
    ]
    DEFAULT_CONFIG_NAME = "tnews"

    def _info(self):
        # Define the dataset features
        features = datasets.Features(
            {
                # the sentence field is a string
                "sentence": datasets.Value("string"),
                # the keywords field is also a string
                "keywords": datasets.Value("string"),
                # the label field can be a ClassLabel or a plain string
                "label": datasets.ClassLabel(names=['news_story', 'news_culture', 'news_entertainment',
                                                    'news_sports', 'news_finance', 'news_house',
                                                    'news_car', 'news_edu', 'news_tech',
                                                    'news_military', 'news_travel', 'news_world',
                                                    'news_stock', 'news_agriculture', 'news_game']),
            })
        return datasets.DatasetInfo(
            # This is the description that will appear on the datasets page.
            description=_DESCRIPTION,
            # This defines the different columns of the dataset and their types
            features=features,
            # If there's a common (input, target) tuple from the features,
            # specify them here. They'll be used if as_supervised=True in
            # builder.as_dataset.
            supervised_keys=None,
            # Homepage of the dataset for documentation
            homepage=_HOMEPAGE,
            # License for the dataset if available
            license=_LICENSE,
            # Citation for the dataset
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        # Download the data that matches the selected config
        my_urls = _URLs[self.config.name]
        data_dir = dl_manager.download_and_extract(my_urls)
        # Define the dataset splits
        outputs = [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    # train file
                    "filepath": os.path.join(data_dir, "train.json"),
                    "split": "train",
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    # dev file
                    "filepath": os.path.join(data_dir, "dev.json"),
                    "split": "dev",
                },
            ),
        ]
        return outputs

    def _generate_examples(
        # method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
        self, filepath, split
    ):
        """Yields examples as (key, example) tuples."""
        # Yield one example per line of the file
        with open(filepath, "r", encoding="utf8") as f:
            for _id, line in enumerate(f):
                # The dataset is in JSON Lines format: each line is a JSON string
                line = json.loads(line)
                yield _id, {
                    "sentence": line["sentence"],
                    "keywords": line["keywords"],
                    "label": _LABELS_MAP[line["label"]],
                }
```
```python
from datasets import load_dataset

raw_datasets = load_dataset("utils/custom_dataset.py", "tnews", cache_dir="cache")
print(raw_datasets)
print(raw_datasets["train"][0])
```
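Because the label column is declared as a ClassLabel, the dataset stores integer ids and keeps the list of label names alongside them. The optional check below uses the standard ClassLabel accessors (names, int2str, str2int) and assumes nothing beyond the objects defined above:

```python
# Inspect the ClassLabel feature: integer ids <-> label names
label_feature = raw_datasets["train"].features["label"]
print(label_feature.names)                    # the 15 label names defined in the builder
print(label_feature.int2str(0))               # id -> name, e.g. 'news_story'
print(label_feature.str2int("news_sports"))   # name -> id
```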
2. Data Processing
```python
# Use the BERT tokenizer
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")


# Define the preprocessing function
def preprocess_function(example):
    result = tokenizer(example["sentence"], example["keywords"], padding=False, max_length=256, truncation=True)
    result["labels"] = example["label"]
    return result
```
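Before mapping over the whole dataset, it can be worth sanity-checking the function on a single raw example. This quick check is optional and only uses the objects defined above:

```python
# Tokenize one raw example; sentence and keywords are encoded as a text pair
sample = raw_datasets["train"][0]
encoded = preprocess_function(sample)
print(list(encoded.keys()))        # input_ids, token_type_ids, attention_mask, labels
print(len(encoded["input_ids"]))   # unpadded length of this example (<= 256)
```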
```python
# Preprocess the raw dataset with map
processed_datasets = raw_datasets.map(
    preprocess_function,
    # process examples in batches
    batched=True,
    # drop the columns that are no longer needed
    remove_columns=raw_datasets["train"].column_names,
    # cache the processed data
    cache_file_names={
        "train": "features/tnews-cached-tokenized-train-256.arrow",
        "validation": "features/tnews-cached-tokenized-validation-256.arrow"},
    desc="Running tokenizer on dataset",
)

# Return ['input_ids', 'token_type_ids', 'attention_mask', 'labels'] as torch.Tensor
processed_datasets.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

# Inspect the first training example
print(processed_datasets["train"][0])
```
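Because padding=False was used in preprocess_function, the tokenized examples still have different lengths and cannot be stacked into a batch as-is; the dynamic padding in the next snippet handles that. An optional check that makes the varying lengths visible:

```python
# With set_format(type='torch'), indexing returns tensors; lengths differ per example
print(processed_datasets["train"][0]["input_ids"].shape)
print(processed_datasets["train"][1]["input_ids"].shape)
```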
```python
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

# DataCollatorWithPadding pads each batch to the longest sequence in that batch
dl = DataLoader(processed_datasets["train"], batch_size=2, collate_fn=DataCollatorWithPadding(tokenizer))
next(iter(dl))
```
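Each batch produced this way is padded only to the longest sequence it contains, so tensor shapes vary from batch to batch. A purely illustrative way to look at one batch:

```python
batch = next(iter(dl))
# input_ids / token_type_ids / attention_mask: (batch_size, longest_seq_in_batch); labels: (batch_size,)
print({k: v.shape for k, v in batch.items()})
```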