TNewsDataModule and TNewsModel Definitions
1. TNewsDataModule Definition
import argparse

import pytorch_lightning as pl
from datasets import load_dataset
# Import DataLoader
from torch.utils.data import DataLoader
# Import the dynamic-padding collate_fn from transformers
from transformers import AutoTokenizer, DataCollatorWithPadding


class TNewsDataModule(pl.LightningDataModule):
    def __init__(self, args, **kwargs):
        super().__init__()
        # Allow keyword arguments to override values on args
        for k, v in kwargs.items():
            setattr(args, k, v)
        self.args = args
        # Load the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.args.model_name_or_path)

    def load_data(self):
        # Model type, used to name the tokenization cache files
        model_type = self.args.model_name_or_path.split("/")[-1]
        # Load the dataset via the custom loading script
        raw_datasets = load_dataset(
            "utils/custom_dataset.py", "tnews", cache_dir="cache")

        # Dataset preprocessing function
        def preprocess_function(example):
            # Encode (sentence, keywords) as a sentence pair
            result = self.tokenizer(example["sentence"], example["keywords"],
                                    padding=self.args.padding,
                                    max_length=self.args.max_length,
                                    truncation=True)
            result["labels"] = example["label"]
            return result

        # Preprocess the dataset
        self.processed_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
            cache_file_names={
                "train": f"{self.args.cached_dir}/tnews-cached-tokenized-train-{self.args.max_length}-{model_type}.arrow",
                "validation": f"{self.args.cached_dir}/tnews-cached-tokenized-validation-{self.args.max_length}-{model_type}.arrow",
            },
            desc="Running tokenizer on dataset",
        )
        self.processed_datasets.set_format(type='torch', columns=[
            'input_ids', 'token_type_ids', 'attention_mask', 'labels'])
        self.num_labels = raw_datasets["train"].features["label"].num_classes

    # Setup stage
    def setup(self, stage=None):
        self.load_data()
        if stage == 'fit' or stage is None:
            self.train_ds = self.processed_datasets["train"]
            self.dev_ds = self.processed_datasets["validation"]
        else:
            self.dev_ds = self.processed_datasets["validation"]

    @property
    def collate_fn(self):
        # Dynamically pad each batch to the longest sequence in the batch
        return DataCollatorWithPadding(self.tokenizer)

    # train loader
    def train_dataloader(self):
        return DataLoader(dataset=self.train_ds,
                          batch_size=self.args.train_batch_size,
                          num_workers=self.args.num_workers,
                          collate_fn=self.collate_fn)

    # val loader
    def val_dataloader(self):
        return DataLoader(dataset=self.dev_ds,
                          batch_size=self.args.eval_batch_size,
                          num_workers=self.args.num_workers,
                          collate_fn=self.collate_fn)

    @staticmethod
    def add_dataset_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser],
                                         add_help=False)
        parser.add_argument("--max_length",
                            type=int,
                            default=256,
                            help="max length of dataset")
        parser.add_argument("--cached_dir",
                            type=str,
                            default="data")
        parser.add_argument("--padding", action='store_true')
        parser.add_argument("--train_batch_size", default=8, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument("--num_workers", default=0, type=int)
        return parser
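
With the parser hook above, the datamodule can be exercised on its own. The following is a minimal sketch rather than part of the original code: it assumes the caller supplies --model_name_or_path (the datamodule reads it in __init__ but does not define it), that bert-base-chinese is a reasonable checkpoint, and that the custom loading script utils/custom_dataset.py is present locally.

# Minimal usage sketch (not from the original source).
# Assumes --model_name_or_path is supplied by the caller and that
# utils/custom_dataset.py exists locally.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str,
                        default="bert-base-chinese")
    parser = TNewsDataModule.add_dataset_specific_args(parser)
    args = parser.parse_args()

    dm = TNewsDataModule(args)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    # Each batch is dynamically padded by DataCollatorWithPadding
    print({k: v.shape for k, v in batch.items()})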
2. TNewsModel Definition
import argparse

import pytorch_lightning as pl
from torchmetrics import Accuracy
from utils.dm import TNewsDataModule
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import (AutoModelForSequenceClassification, AdamW,
                          get_linear_schedule_with_warmup,
                          get_cosine_schedule_with_warmup,
                          get_polynomial_decay_schedule_with_warmup)

# Map scheduler names to their transformers factory functions
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
}


class TNewsModel(pl.LightningModule):
    def __init__(self, args, **kwargs):
        super().__init__()
        # Allow keyword arguments to override values on args
        for k, v in kwargs.items():
            setattr(args, k, v)
        self.save_hyperparameters(args)
        # Load the pretrained model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.hparams.model_name_or_path, num_labels=self.hparams.num_labels)
        # Define the metric (using torchmetrics)
        self.accuracy = Accuracy()

    @property
    def num_training_steps(self):
        # copied from https://github.com/PyTorchLightning/lightning-transformers/blob/1c5c91f5b4962f9162bcbd41fee7a7ac5eae00a9/lightning_transformers/core/model.py#L56
        """Total training steps inferred from datamodule and devices."""
        if isinstance(self.trainer.limit_train_batches, int) and self.trainer.limit_train_batches != 0:
            dataset_size = self.trainer.limit_train_batches
        elif isinstance(self.trainer.limit_train_batches, float):
            dataset_size = len(self.trainer.datamodule.train_dataloader())
            dataset_size = int(dataset_size * self.trainer.limit_train_batches)
        else:
            dataset_size = len(self.trainer.datamodule.train_dataloader())
        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes)
        if self.trainer.tpu_cores:
            num_devices = max(num_devices, self.trainer.tpu_cores)
        effective_batch_size = self.trainer.accumulate_grad_batches * num_devices
        max_estimated_steps = (
            dataset_size // effective_batch_size) * self.trainer.max_epochs
        if self.trainer.max_steps and self.trainer.max_steps < max_estimated_steps:
            return self.trainer.max_steps
        return max_estimated_steps

    # Configure the optimizer and learning-rate scheduler
    def configure_optimizers(self):
        # Exclude bias and LayerNorm weights from weight decay
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.hparams.lr,
                          eps=self.hparams.adam_epsilon)
        scheduler = arg_to_scheduler[self.hparams.scheduler](
            optimizer,
            num_warmup_steps=int(self.num_training_steps *
                                 self.hparams.warmup_prob),
            num_training_steps=self.num_training_steps)
        scheduler = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }
        return [optimizer], [scheduler]

    # Since LightningModule inherits from nn.Module, we can implement forward
    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        return self.model(input_ids=input_ids,
                          token_type_ids=token_type_ids,
                          attention_mask=attention_mask,
                          labels=labels)

    # Training step
    def training_step(self, batch, batch_idx):
        # Calls self.forward
        outputs = self(**batch)
        # Log the current learning rate
        self.log("lr",
                 self.trainer.lr_schedulers[0]["scheduler"].get_last_lr()[-1],
                 prog_bar=True)
        # Log the training loss
        self.log("train_loss", outputs.loss)
        return outputs.loss

    # Validation step
    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        # Take the index of the largest logit in each row as the prediction
        preds = outputs.logits.argmax(dim=-1)
        # Update the accuracy metric
        self.accuracy(preds, batch["labels"])
        # Log the validation loss
        self.log("val_loss", outputs.loss, prog_bar=True)

    # Reset accuracy at the start of each validation epoch
    def on_validation_epoch_start(self):
        self.accuracy.reset()

    # Compute the overall validation accuracy at the end of the epoch
    def on_validation_epoch_end(self):
        acc = self.accuracy.compute()
        self.log('val_acc', acc, prog_bar=True)
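
For completeness, a training entry point could look roughly like the sketch below. It is not the original training script: the hyperparameter names (--lr, --weight_decay, --adam_epsilon, --scheduler, --warmup_prob) are the ones configure_optimizers reads from self.hparams, but the default values, the bert-base-chinese checkpoint, and the Trainer flags (which follow the same older PyTorch Lightning API used above) are illustrative assumptions.

# Minimal training sketch (not from the original source); defaults and
# Trainer flags are illustrative and depend on your environment/PL version.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str,
                        default="bert-base-chinese")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--scheduler", type=str, default="linear",
                        choices=list(arg_to_scheduler))
    parser.add_argument("--warmup_prob", type=float, default=0.1)
    parser = TNewsDataModule.add_dataset_specific_args(parser)
    args = parser.parse_args()

    dm = TNewsDataModule(args)
    dm.setup("fit")
    # num_labels is known only after the datamodule has loaded the data
    model = TNewsModel(args, num_labels=dm.num_labels)

    # Keep the checkpoint with the best validation accuracy
    checkpoint = ModelCheckpoint(monitor="val_acc", mode="max")
    # Adjust gpus/max_epochs for your hardware
    trainer = pl.Trainer(gpus=1, max_epochs=3, callbacks=[checkpoint])
    trainer.fit(model, datamodule=dm)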