    【4】Text classification with pytorch_lightning + transformers + torchmetrics + datasets


      Defining TNewsDataModule and TNewsModel

      1. TNewsDataModule definition
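
      The LightningDataModule below loads the TNews data with the datasets library, tokenizes each sentence together with its keywords, caches the tokenized Arrow files, and serves train/validation dataloaders whose batches are padded dynamically.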

      import argparse
      import pytorch_lightning as pl
      
      from datasets import load_dataset
      # import DataLoader
      from torch.utils.data import DataLoader
      # import the dynamic-padding collate_fn from transformers
      from transformers import AutoTokenizer, DataCollatorWithPadding
      
      
      class TNewsDataModule(pl.LightningDataModule):
          def __init__(self, args, **kwargs):
              super().__init__()
              for k, v in kwargs.items():
                  setattr(args, k, v)
              self.args = args
              # load the tokenizer
              self.tokenizer = AutoTokenizer.from_pretrained(
                  self.args.model_name_or_path)
      
          def load_data(self):
              # model type, used in the cache file names below
              model_type = self.args.model_name_or_path.split("/")[-1]
              # load the dataset
              raw_datasets = load_dataset(
                  "utils/custom_dataset.py", "tnews", cache_dir="cache")
              # dataset preprocessing function
              def preprocess_function(example):
                  result = self.tokenizer(example["sentence"], example["keywords"],
                                          padding=self.args.padding, max_length=self.args.max_length, truncation=True)
                  result["labels"] = example["label"]
                  return result
              # preprocess the dataset
              self.processed_datasets = raw_datasets.map(
                  preprocess_function,
                  batched=True,
                  remove_columns=raw_datasets["train"].column_names,
                  cache_file_names={
                      "train": f"{self.args.cached_dir}/tnews-cached-tokenized-train-{self.args.max_length}-{model_type}.arrow",
                      "validation": f"{self.args.cached_dir}/tnews-cached-tokenized-validation-{self.args.max_length}-{model_type}.arrow",
                  },
                  desc="Running tokenizer on dataset",
              )
              self.processed_datasets.set_format(type='torch', columns=[
                  'input_ids', 'token_type_ids', 'attention_mask', 'labels'])
              self.num_labels = raw_datasets["train"].features["label"].num_classes
          
          # setup: assign dataset splits for each trainer stage
          def setup(self, stage=None):
              self.load_data()
              if stage == 'fit' or stage is None:
                  self.train_ds = self.processed_datasets["train"]
                  self.dev_ds = self.processed_datasets["validation"]
      
              else:
                  self.dev_ds = self.processed_datasets["validation"]
      
          @property
          def collate_fn(self):
              return DataCollatorWithPadding(self.tokenizer)
          
          # train loader
          def train_dataloader(self):
              return DataLoader(dataset=self.train_ds,
                                batch_size=self.args.train_batch_size,
                                num_workers=self.args.num_workers,
                                collate_fn=self.collate_fn)
          # val loader
          def val_dataloader(self):
              return DataLoader(dataset=self.dev_ds,
                                batch_size=self.args.eval_batch_size,
                                num_workers=self.args.num_workers,
                                collate_fn=self.collate_fn)
      
          @staticmethod
          def add_dataset_specific_args(parent_parser):
              parser = argparse.ArgumentParser(parents=[parent_parser],
                                               add_help=False)
              parser.add_argument("--max_length",
                                  type=int,
                                  default=256,
                                  help="max length of dataset")
              parser.add_argument("--cached_dir",
                                  type=str,
                                  default="data")
              parser.add_argument("--padding", action='store_true')
              parser.add_argument("--train_batch_size", default=8, type=int)
              parser.add_argument("--eval_batch_size", default=32, type=int)
              parser.add_argument("--num_workers", default=0, type=int)
      
              return parser
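
      A quick smoke test of the DataModule, as a sketch: it assumes the project's utils/custom_dataset.py loading script is available, and the bert-base-chinese checkpoint here is only an example value for --model_name_or_path.

      import argparse

      from utils.dm import TNewsDataModule

      parser = argparse.ArgumentParser()
      # --model_name_or_path is assumed to be registered by the training script
      parser.add_argument("--model_name_or_path", type=str,
                          default="bert-base-chinese")
      parser = TNewsDataModule.add_dataset_specific_args(parser)
      args = parser.parse_args([])

      dm = TNewsDataModule(args)
      dm.setup("fit")
      batch = next(iter(dm.train_dataloader()))
      print(batch["input_ids"].shape)  # (train_batch_size, padded_seq_len)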
      

      2. TNewsModel definition
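
      The LightningModule below wraps AutoModelForSequenceClassification, groups parameters so that bias and LayerNorm weights skip weight decay, warms up the learning rate with a configurable schedule, and tracks validation accuracy via torchmetrics.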

      import argparse
      import pytorch_lightning as pl
      
      from torchmetrics import Accuracy
      from utils.dm import TNewsDataModule
      from pytorch_lightning.callbacks import ModelCheckpoint
      from transformers import (AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup,
                                get_cosine_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup)
      
      
      arg_to_scheduler = {
          "linear": get_linear_schedule_with_warmup,
          "cosine": get_cosine_schedule_with_warmup,
          "polynomial": get_polynomial_decay_schedule_with_warmup,
      }
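      # the --scheduler argument selects one of these warmup schedules
      # in configure_optimizers below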
      
      
      class TNewsModel(pl.LightningModule):
          def __init__(self, args, **kwargs):
              super().__init__()
              for k, v in kwargs.items():
                  setattr(args, k, v)
              self.save_hyperparameters(args)
              # load the pretrained model
              self.model = AutoModelForSequenceClassification.from_pretrained(
                  self.hparams.model_name_or_path, num_labels=self.hparams.num_labels)
              # define the metric with torchmetrics (note: torchmetrics>=0.11
              # requires a task argument, e.g. Accuracy(task="multiclass",
              # num_classes=self.hparams.num_labels))
              self.accuracy = Accuracy()
      
          @property
          def num_training_steps(self):
              # copy from https://github.com/PyTorchLightning/lightning-transformers/blob/1c5c91f5b4962f9162bcbd41fee7a7ac5eae00a9/lightning_transformers/core/model.py#L56
              """Total training steps inferred from datamodule and devices."""
              if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
                  dataset_size = self.trainer.limit_train_batches
              elif isinstance(self.trainer.limit_train_batches, float):
                  dataset_size = len(self.trainer.datamodule.train_dataloader())
                  dataset_size = int(dataset_size * self.trainer.limit_train_batches)
              else:
                  dataset_size = len(self.trainer.datamodule.train_dataloader())
      
              num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
              if self.trainer.tpu_cores:
                  num_devices = max(num_devices, self.trainer.tpu_cores)
      
              effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
              max_estimated_steps = (
                  dataset_size // effective_batch_size) * self.trainer.max_epochs
      
              if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
                  return self.trainer.max_steps
              return max_estimated_steps
         
          # define the optimizer and LR scheduler
          def configure_optimizers(self):
              no_decay = ["bias", "LayerNorm.weight"]
              optimizer_grouped_parameters = [
                  {
                      "params": [
                          p for n, p in self.model.named_parameters()
                          if not any(nd in n for nd in no_decay)
                      ],
                      "weight_decay":
                      self.hparams.weight_decay,
                  },
                  {
                      "params": [
                          p for n, p in self.model.named_parameters()
                          if any(nd in n for nd in no_decay)
                      ],
                      "weight_decay":
                      0.0,
                  },
              ]
      
              optimizer = AdamW(optimizer_grouped_parameters,
                                lr=self.hparams.lr,
                                eps=self.hparams.adam_epsilon)
      
              scheduler = arg_to_scheduler[self.hparams.scheduler](
                  optimizer,
                  num_warmup_steps=int(self.num_training_steps *
                                       self.hparams.warmup_prob),
                  num_training_steps=self.num_training_steps)
              scheduler = {
                  "scheduler": scheduler,
                  "interval": "step",
                  "frequency": 1
              }
      
              return [optimizer], [scheduler]
          
          # since LightningModule inherits from nn.Module, forward can be implemented
          def forward(self,
                      input_ids=None,
                      token_type_ids=None,
                      attention_mask=None,
                      labels=None):
              return self.model(input_ids=input_ids,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                labels=labels)
          # define the training step
          def training_step(self, batch, batch_idx):
              # call self.forward
              outputs = self(**batch)
              # log the current learning rate
              self.log("lr",
                       self.trainer.lr_schedulers[0]["scheduler"].get_last_lr()[-1],
                       prog_bar=True)
              # log train loss
              self.log("train_loss", outputs.loss)
              return outputs.loss
         
          # define the validation step
          def validation_step(self, batch, batch_idx):
              outputs = self(**batch)
              # take the argmax position of the logits in each row
              preds = outputs.logits.argmax(dim=-1)
              # update the accuracy metric
              self.accuracy(preds, batch["labels"])
              # log the validation loss
              self.log("val_loss", outputs.loss, prog_bar=True)
         
          # reset the accuracy metric at the start of each validation epoch
          def on_validation_epoch_start(self):
              self.accuracy.reset()
      
          # compute the aggregate validation accuracy at the end of the epoch
          def on_validation_epoch_end(self):
              acc = self.accuracy.compute()
              self.log('val_acc', acc, prog_bar=True)
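
      3. Putting it together (a sketch)

      A possible training entry point. This is a sketch built on assumptions: the optimizer/scheduler flags (--lr, --weight_decay, --adam_epsilon, --scheduler, --warmup_prob) and the module path utils.model are inferred from the hparams and imports above and are not shown in the original post.

      import argparse

      import pytorch_lightning as pl
      from pytorch_lightning.callbacks import ModelCheckpoint

      from utils.dm import TNewsDataModule
      from utils.model import TNewsModel  # assumed location of the class above


      def main():
          parser = argparse.ArgumentParser()
          parser.add_argument("--model_name_or_path", type=str,
                              default="bert-base-chinese")  # example checkpoint
          # flags below are assumptions inferred from the self.hparams usage above
          parser.add_argument("--lr", type=float, default=3e-5)
          parser.add_argument("--weight_decay", type=float, default=0.01)
          parser.add_argument("--adam_epsilon", type=float, default=1e-8)
          parser.add_argument("--scheduler", type=str, default="linear",
                              choices=["linear", "cosine", "polynomial"])
          parser.add_argument("--warmup_prob", type=float, default=0.1)
          parser = TNewsDataModule.add_dataset_specific_args(parser)
          parser = pl.Trainer.add_argparse_args(parser)  # pytorch_lightning 1.x API
          args = parser.parse_args()

          dm = TNewsDataModule(args)
          dm.setup("fit")  # runs load_data, which sets dm.num_labels
          model = TNewsModel(args, num_labels=dm.num_labels)

          ckpt = ModelCheckpoint(monitor="val_acc", mode="max")
          trainer = pl.Trainer.from_argparse_args(args, callbacks=[ckpt])
          trainer.fit(model, datamodule=dm)


      if __name__ == "__main__":
          main()

      Launched, for example, as python train.py --gpus 1 --max_epochs 3 --scheduler cosine (Trainer flag names per pytorch_lightning 1.x).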
      
      
      