TinyBERT 蒸馏速度实现加速小记
-
最近做的一个 project 需要复现 EMNLP 2020 Findings 的 TinyBERT,这篇文章就是在复现过程对踩到坑,以及对应的解决方案和实现加速的一个记录。
Overview of TinyBERT
BERT 效果虽好,其较大内存消耗和较长的推理延时会对其上线部署造成一定挑战。内存消耗方面,一系列知识蒸馏的工作,例如 DistilBERT、BERT-PKD 和 TinyBERT 被提出来来降低模型的参数(主要是层数)以及相应地减少时间;推理加速方面,也有例如 DeeBERT、FastBERT 以及 CascadeBERT 等方案来动态地根据样本难度进行模型的执行从而提升推理效率。其中比较具备代表性便是 TinyBERT,其核心框架如下:
分为两个阶段:- General Distillation:在通用的语料,例如 BookCorpus, EnglishWiki 上进行知识蒸馏,目标函数包括 Transformer Layer Attention 矩阵以及 Layer Hidden States 的对齐;
- Task Distillation:在具体的任务数据集上进行蒸馏,又被进一步分成两个步骤:
- Task Transformer Disitllation: 在任务数据集上对齐 Student 和已经 fine-tuned Teacher model 的 attention map 和 hidden states;
- Task Prediction Distillation:在任务数据集上对 student model 和 teacher model 的 output distritbuion 利用 KL loss / MSE loss 进行对齐。
TinyBERT 提供了经过 General Distillation 阶段的 checkpoint,可以认为是一个小的 BERT,包括了 6L786H 版本以及 4L312H 版本。而我们后续的复现就是基于 4L312H v2 版本的。值得注意的是,TinyBERT 对任务数据集进行了数据增强操作,通过基于 Glove 的 Embedding Distance 的相近词替换以及 BERT MLM 预测替换,会将原本的数据集扩增到 20 倍。而我们遇到的第一个 bug 就是在数据增强阶段。
Bug in Data Augmentation
我们可以按照官方给出的代码对数据进行增强操作,但是在 QNLI 上会报错:
Index Error: index 514 is out of dimension 1 with size 512
造成数据增强到一半程序就崩溃了,为什么呢?
很简单,因为数据增强代码 BERT MLM 换词模块对于超长(> 512)的句子没有特殊处理,造成下标越界,具体可以参考 #Issue50。
在对应的函数中进行边界的判断即可:
def _masked_language_model(self, sent, word_pieces, mask_id): if mask_id > 511: # if mask id is longer than max length return [] tokenized_text = self.tokenizer.tokenize(sent) tokenized_text = ['[CLS]'] + tokenized_text tokenized_len = len(tokenized_text) tokenized_text = word_pieces + ['[SEP]'] + tokenized_text[1:] + ['[SEP]'] segments_ids = [0] * (tokenized_len + 1) + [1] * (len(tokenized_text) - tokenized_len - 1) if len(tokenized_text) > 512: # truncation tokenized_text = tokenized_text[:512] segments_ids = segments_ids[:512] token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text) tokens_tensor = torch.tensor([token_ids]).to(device) segments_tensor = torch.tensor([segments_ids]).to(device) self.model.to(device) predictions = self.model(tokens_tensor, segments_tensor) word_candidates = torch.argsort(predictions[0, mask_id], descending=True)[:self.M].tolist() word_candidates = self.tokenizer.convert_ids_to_tokens(word_candidates) return list(filter(lambda x: x.find("##"), word_candidates))
Acceleration of Data Parallel
当我们费劲愉快地完成数据增强之后,下一步就是要进行 Task Specific 蒸馏里的 Step 1,General Distillation 了。对于一些小数据集像 MRPC,增广 20 倍之后的数据量依旧是 80k 不到,因此训练速度还是很快的,20 轮单卡大概半天也能跑完。但是对于像 MNLI 这样 GLUE 中最大的数据集(390k),20 倍增广后的数据集(增广就花费了大约 2 天时间),如果用单卡训练个 10 轮那可能得跑上半个月了,到时候怕不是黄花菜都凉咯。遂打算用多卡训练,一看,官方的实现就通过 nn.DataParallel 支持了多卡。好嘛,直接 CUDA_VISIBLE_DEVICES=“0,1,2,3” 来上 4 块卡。不跑不知道,加载数据(tokenize, padding )花费 1小时,好不容易跑起来了,一开 nvidia-smi 吓一跳,GPU 的利用率都在 50% 左右,再一看预估时间,大约 21h 一轮,10 epoch 那四舍五入就是一个半礼拜。好家伙,这我还做不做实验了?这时候就去翻看 PyTorch 文档,发现 PyTorch 现在都不再推荐使用 nn.DataParallel 了,为什么呢?主要原因在于 DataParallel 的实现是单进程的,每次都是有一块主卡读入数据再发给其他卡,这一部分不进带来了额外的计算开销,而且会造成主卡的 GPU 显存占用会显著高于其他卡,进而造成潜在的 batch size 限制;此外,这种模式下,其他 GPU 算完之后要传回主卡进行同步,这一步又会受限于 Python 的线程之间的 GIL(global interpreter lock),进一步降低了效率。此外,还有多机以及模型切片等 DataParallel 不支持,但是另一个 DistributedDataParallel 模块支持的功能。所以,废话少说,得把原先 TinyBERT DataParallel(DP)改成 DistributedDataParallel(DDP)。那么,请问,把 DP 改成 DDP 需要几步?答:大概,就那么多步。核心的代码就是做一下初始化,以及用 DDP 替换掉 DP:
from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist # 给 parser 增加一个 local rank 参数来在启动的时候传入 rank parser.add_argument('--local_rank', type=int, default=-1) # ... # 初始化 logger.info("Initializing Distributed Environment") torch.cuda.set_device(args.local_rank) dist.init_process_group(backend="nccl") # 设置 devicec local_rank = args.local_rank torch.cuda.set_device(local_rank) # ... # 初始化模型 并且 放到 device 上 student_model = TinyBertForSequenceClassification.from_pretrained(args.student_model, num_labels=num_labels).to(device) teacher_model = TinyBertForSequenceClassification.from_pretrained(args.teacher_model, num_labels=num_labels).to(device) # 用 DDP 包裹模型 student_model = DDP(student_model, device_ids=[local_rank], output_device=local_rank) teacher_model = DDP(teacher_model, device_ids=[local_rank], output_device=local_rank) # .. # 用 DistributedSampler 替换原来的 Random Sampler train_sampler = torch.utils.data.DistributedSampler(train_data)
然后,大功告成,一键启动:
GPU=”0,1,2,3”
CUDA_VISIBLE_DEVICEES=$GPU python -m torch.distributed.launch –n_proc_per_node 4 task_disti.py
启动成功了吗?模型又开始处理数据….One hours later,机器突然卡住,程序的 log 也停了,打开 htop 一看,好家伙,256G 的内存都满了,程序都是 D 状态,咋回事?
Acceleration of Data Loading
我先试了少量数据,降采样到 10k,程序运行没问题, DDP 速度很快;我再尝试了单卡加载,虽然又 load 了一个小时,但是 ok,程序还是能跑起来,那么,问题是如何发生的呢?单卡的时候我看了一眼加载全量数据完毕之后的内存占用,大约在 60G 左右,考虑到 DDP 是多进程的,因此,每个进程都要独立地加载数据,4 块卡 4个进程,大约就是 250 G 的内存,因此内存爆炸,到后面数据的 io 就卡住了(没法从磁盘 load 到内存),所以造成了程序 D 状态。看了下组里的机器,最大的也就是 250 G 内存,也就是说,如果我只用 3 块卡,那么是能够跑的,但是万一有别的同学上来开程序吃了一部分内存,那么就很可能爆内存,然后就是大家的程序都同归于尽的局面,不太妙。一种不太优雅的解决方案就是,把数据切块,然后读完一小块训练完,再读下一块,再训练,再读。咨询了一下组里资深的师兄,还有一种办法就是实现一种把数据存在磁盘上,每次要用的时候才 load 到内存的数据读取方案,这样就能够避免爆内存的问题。行吧,那就干吧,但是总不能从头造轮子吧?脸折师兄提到 huggingface(yyds) 的 datasets 能够支持这个功能,check 了一下文档,发现他是基于 pyarrow 的实现了一个 memory map 的数据读取,以我的 huggingface transformers 的经验,似乎是能够实现这个功能的,所以摩拳擦掌,准备动手。
首先,要把增广的数据 load 进来,datasets 提供的 load_dataset 函数最接近的就是 load_dataset(‘csv’, data_file),然后我们就可以逐个 column 的拿到数据并且进行预处理了。写了一会,发现总是报读取一部分数据后 columns 数目不对的错误,猜测可能原始 MNLI 数据集就不太能保证每个列都是在的,检查了一下 MnliProcessor 里处理的代码,发现其写死了 line[8] 和 line[9] 作为 sentence_a 和 sentence_b。无奈之下,只能采取最粗暴地方式,用 text mode 读进来,每一行是一个数据,再 split:
from datasets import processor = processors[task_name]() output_mode = output_modes[task_name] label_list = processor.get_labels() num_labels = len(label_list) tokenizer = BertTokenizer.from_pretrained(args.student_model, do_lower_case=args.do_lower_case) # 用 text mnli_datasets = load_dataset("text", data_files=os.path.join(args.data_dir, "train_aug.tsv")) label_classes = processor.get_labels() label_map = {label: i for i, label in enumerate(label_classes)} def preprocess_func(examples, max_seq_length=args.max_seq_length): splits = [e.split('\t') for e in examples['text']] # split # tokenize for sent1 & sent2 tokens_s1 = [tokenizer.tokenize(e[8]) for e in splits] tokens_s2 = [tokenizer.tokenize(e[9]) for e in splits] for t1, t2 in zip(tokens_s1, tokens_s2): truncate_seq_pair(t1, t2, max_length=max_seq_length - 3) input_ids_list = [] input_mask_list = [] segment_ids_list = [] seq_length_list = [] labels_list = [] labels = [e[-1] for e in splits] # last column is label column for token_a, token_b, l in zip(tokens_s1, tokens_s2, labels): # zip(tokens_as, tokens_bs): tokens = ["[CLS]"] + token_a + ["[SEP]"] segment_ids = [0] * len(tokens) tokens += token_b + ["[SEP]"] segment_ids += [1] * (len(token_b) + 1) input_ids = tokenizer.convert_tokens_to_ids(tokens) # tokenize to id input_mask = [1] * len(input_ids) seq_length = len(input_ids) padding = [0] * (max_seq_length - len(input_ids)) input_ids += padding input_mask += padding segment_ids += padding assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length input_ids_list.append(input_ids) input_mask_list.append(input_mask) segment_ids_list.append(segment_ids) seq_length_list.append(seq_length) labels_list.append(label_map[l]) results = {"input_ids": input_ids_list, "input_mask": input_mask_list, "segment_ids": segment_ids_list, "seq_length": seq_length_list, "label_ids": labels_list} return results # map datasets mnli_datasets = mnli_datasets.map(preprocess_func, batched=True) # remove column train_data = mnli_datasets['train'].remove_columns('text')
写完这个 preprocess_func ,我觉得胜利在望,但还有几个小坑需要解决:
-
map 完之后,返回的还是一个 DatasetDict,得手动取一下 train set;
-
对于原先存在的列,map 函数并不会去除掉,所以如果不用的列,需要手动 .remove_columns()
-
在配合 DDP 使用的时候,因为 DistributedSample 取数据的维度是在第一维取的,所以取到的数据可能是个 seq_len 长的列表,里面的 tensor 是 [bsz] 形状的,需要在交给 model 之前 stack 一下:
inputs = {} for k, v in batch.items(): if isinstance(v, torch.Tensor): inputs[k] = v.to(device) elif isinstance(v, List): inputs[k] = torch.stack(v, dim=1).to(device)
至此,只要把之前代码的 train_data 都换成现在的版本即可。
此外,为了进一步加速,我还把混合精度也整合了进来,现在 Pytorch 以及自带对混合精度的支持,代码量也很少,但是有个坑就是loss 的计算必须被 auto() 包裹住,同时,所有模型的输出都要参与到 loss 的计算,这对于只做 prediction 或者是 hidden state 对齐的 loss 很不友好,所以只能手动再额外计算一项为系数为 0 的 loss 项(这样他参与到训练但是不会影响梯度)。
-
这篇是真真真·干货