    BPE Algorithm Explained


      Byte Pair Encoding

      In NLP models the input is usually a sentence, e.g. "I went to New York last week.", which consists of many words (tokens). The traditional approach is to split the sentence on whitespace, e.g. ['i', 'went', 'to', 'New', 'York', 'last', 'week']. This has real drawbacks: for instance, the model cannot transfer the relationship it learns among old, older, oldest to smart, smarter, smartest. If we instead split a token into multiple subtokens, this problem is largely solved. This post describes the most widely used subtoken algorithm: BPE (Byte-Pair Encoding).

      Today's better-performing NLP models, such as GPT, BERT, and RoBERTa, all include a WordPiece step during data preprocessing, and its main implementation is BPE (Byte-Pair Encoding). Take the three words ['loved', 'loving', 'loves'] as a concrete case: they all share the core meaning "love", but if whole words are the unit they count as different entries. English has a great many words that differ only in suffix, so a word-level vocabulary becomes very large, training slows down, and results suffer. Through training, BPE can split these three words into the parts ["lov", "ed", "ing", "es"], separating a word's core meaning from its tense and effectively reducing the vocabulary size. The algorithm proceeds as follows:

      1. Set the maximum number of subwords V
      2. Split every word into individual characters, append the stop symbol </w> at the end, and record the word's frequency. For example, if "low" occurs 5 times, it is represented as {'l o w </w>': 5}
      3. Count the frequency of every adjacent symbol pair and merge the most frequent pair into a new subword
      4. Repeat step 3 until the subword vocabulary size set in step 1 is reached, or the next most frequent pair occurs only once

      For example:

      {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
      

      The most frequent byte pair is **e** and **s**, occurring 6+3=9 times, so they are merged:

      {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
      

      The most frequent byte pair is **es** and **t**, occurring 6+3=9 times, so they are merged:

      {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
      

      The most frequent byte pair is **est** and **</w>**, occurring 6+3=9 times, so they are merged:

      {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
      

      The most frequent byte pair is **l** and **o**, occurring 5+2=7 times, so they are merged:

      {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
      

      The most frequent byte pair is **lo** and **w**, occurring 5+2=7 times, so they are merged:

      {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
      

      …and so on, until the preset subword vocabulary size is reached or the next most frequent pair occurs only once. This yields a much more suitable vocabulary; it may contain entries that are not words themselves, but they are still meaningful units.

      The point of the stop symbol </w> is to mark a subword as a word suffix. For example, st without </w> can appear at the beginning of a word, as in st ar, while st</w> indicates that the subword sits at the end of a word, as in wide st</w>; the two have entirely different meanings.
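
      As a quick illustration in code (a minimal sketch; the is_word_final helper is hypothetical and not part of the implementation below):

      def is_word_final(token):
          # tokens ending in </w> may only close a word; others may start or continue one
          return token.endswith('</w>')

      print(is_word_final('st'))      # False: may appear word-initially, as in 'st ar'
      print(is_word_final('st</w>'))  # True: appears only word-finally, as in 'wide st</w>'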

      BPE Implementation

      import re, collections
      
      def get_vocab(filename):
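          # Split each word in the corpus into characters, append the stop symbol </w>, and count word frequencies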
          vocab = collections.defaultdict(int)
          with open(filename, 'r', encoding='utf-8') as fhand:
              for line in fhand:
                  words = line.strip().split()
                  for word in words:
                      vocab[' '.join(list(word)) + ' </w>'] += 1
          return vocab
      
      def get_stats(vocab):
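          # Count how often each adjacent symbol pair occurs, weighted by word frequency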
          pairs = collections.defaultdict(int)
          for word, freq in vocab.items():
              symbols = word.split()
              for i in range(len(symbols)-1):
                  pairs[symbols[i],symbols[i+1]] += freq
          return pairs
      
      def merge_vocab(pair, v_in):
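          # Merge every occurrence of the given symbol pair into a single symbol;
          # the lookarounds match the pair only when it stands as whole, space-separated symbols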
          v_out = {}
          bigram = re.escape(' '.join(pair))
          p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
          for word in v_in:
              w_out = p.sub(''.join(pair), word)
              v_out[w_out] = v_in[word]
          return v_out
      
      def get_tokens(vocab):
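          # Tally the current subword tokens and their frequencies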
          tokens = collections.defaultdict(int)
          for word, freq in vocab.items():
              word_tokens = word.split()
              for token in word_tokens:
                  tokens[token] += freq
          return tokens
      
      vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
      
      # Get free book from Gutenberg
      # wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt
      # vocab = get_vocab('pg16457.txt')
      
      print('==========')
      print('Tokens Before BPE')
      tokens = get_tokens(vocab)
      print('Tokens: {}'.format(tokens))
      print('Number of tokens: {}'.format(len(tokens)))
      print('==========')
      
      num_merges = 5
      for i in range(num_merges):
          pairs = get_stats(vocab)
          if not pairs:
              break
          best = max(pairs, key=pairs.get)
          vocab = merge_vocab(best, vocab)
          print('Iter: {}'.format(i))
          print('Best pair: {}'.format(best))
          tokens = get_tokens(vocab)
          print('Tokens: {}'.format(tokens))
          print('Number of tokens: {}'.format(len(tokens)))
          print('==========')
      

      The output is as follows:

      ==========
      Tokens Before BPE
      Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 17, 'r': 2, 'n': 6, 's': 9, 't': 9, 'i': 3, 'd': 3})
      Number of tokens: 11
      ==========
      Iter: 0
      Best pair: ('e', 's')
      Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'es': 9, 't': 9, 'i': 3, 'd': 3})
      Number of tokens: 11
      ==========
      Iter: 1
      Best pair: ('es', 't')
      Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})
      Number of tokens: 10
      ==========
      Iter: 2
      Best pair: ('est', '</w>')
      Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})
      Number of tokens: 10
      ==========
      Iter: 3
      Best pair: ('l', 'o')
      Tokens: defaultdict(<class 'int'>, {'lo': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})
      Number of tokens: 9
      ==========
      Iter: 4
      Best pair: ('lo', 'w')
      Tokens: defaultdict(<class 'int'>, {'low': 7, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'w': 9, 'est</w>': 9, 'i': 3, 'd': 3})
      Number of tokens: 9
      ==========
      

      Encoding and Decoding

      Encoding

      The algorithm above produces the subword vocabulary; sort it by token length, from longest to shortest. To encode a word, walk through the sorted vocabulary and check whether any token is a substring of the word; if it is, that token becomes one of the tokens representing the word.

      We iterate from the longest token to the shortest, trying to replace the substrings of each word with tokens. Eventually every token will have been tried, and all matched substrings replaced by tokens. If any substrings remain unmatched after all tokens have been tried, the leftovers are replaced with a special token such as <unk>.

      For example:

      # Given word sequence
      ["the</w>", "highest</w>", "mountain</w>"]
      
      # Sorted subword vocabulary
      # length  6          5          4        4        4         4         2
      ["errrr</w>", "tain</w>", "moun", "est</w>", "high", "the</w>", "a</w>"]
      
      # Result
      "the</w>" -> ["the</w>"]
      "highest</w>" -> ["high", "est</w>"]
      "mountain</w>" -> ["moun", "tain</w>"]
      
      Decoding

      Decoding simply concatenates all the tokens back together, for example:

      # Encoded sequence
      ["the</w>", "high", "est</w>", "moun", "tain</w>"]
      
      # Decoded sequence
      "the</w> highest</w> mountain</w>"
      

      Encoding and Decoding Implementation

      import re, collections
      
      def get_vocab(filename):
          vocab = collections.defaultdict(int)
          with open(filename, 'r', encoding='utf-8') as fhand:
              for line in fhand:
                  words = line.strip().split()
                  for word in words:
                      vocab[' '.join(list(word)) + ' </w>'] += 1
      
          return vocab
      
      def get_stats(vocab):
          pairs = collections.defaultdict(int)
          for word, freq in vocab.items():
              symbols = word.split()
              for i in range(len(symbols)-1):
                  pairs[symbols[i],symbols[i+1]] += freq
          return pairs
      
      def merge_vocab(pair, v_in):
          v_out = {}
          bigram = re.escape(' '.join(pair))
          p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
          for word in v_in:
              w_out = p.sub(''.join(pair), word)
              v_out[w_out] = v_in[word]
          return v_out
      
      def get_tokens_from_vocab(vocab):
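          # Record token frequencies and map each full word to its current tokenization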
          tokens_frequencies = collections.defaultdict(int)
          vocab_tokenization = {}
          for word, freq in vocab.items():
              word_tokens = word.split()
              for token in word_tokens:
                  tokens_frequencies[token] += freq
              vocab_tokenization[''.join(word_tokens)] = word_tokens
          return tokens_frequencies, vocab_tokenization
      
      def measure_token_length(token):
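          # Count the 4-character stop symbol </w> as one symbol rather than four characters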
          if token[-4:] == '</w>':
              return len(token[:-4]) + 1
          else:
              return len(token)
      
      def tokenize_word(string, sorted_tokens, unknown_token='</u>'):
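          # Greedily match tokens (assumed sorted longest-first) against the string;
          # unmatched parts are resolved recursively with the remaining, shorter tokens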
          
          if string == '':
              return []
          if sorted_tokens == []:
              return [unknown_token]
      
          string_tokens = []
          for i in range(len(sorted_tokens)):
              token = sorted_tokens[i]
              token_reg = re.escape(token)  # escape regex metacharacters so the token is matched literally
      
              matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]
              if len(matched_positions) == 0:
                  continue
              substring_end_positions = [matched_position[0] for matched_position in matched_positions]
      
              substring_start_position = 0
              for substring_end_position in substring_end_positions:
                  substring = string[substring_start_position:substring_end_position]
                  string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
                  string_tokens += [token]
                  substring_start_position = substring_end_position + len(token)
              remaining_substring = string[substring_start_position:]
              string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)
              break
          # if no token matched any part of the string, treat the whole string as unknown
          if not string_tokens:
              return [unknown_token]
          return string_tokens
      
      # vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
      
      vocab = get_vocab('pg16457.txt')
      
      print('==========')
      print('Tokens Before BPE')
      tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
      print('All tokens: {}'.format(tokens_frequencies.keys()))
      print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
      print('==========')
      
      num_merges = 10000
      for i in range(num_merges):
          pairs = get_stats(vocab)
          if not pairs:
              break
          best = max(pairs, key=pairs.get)
          vocab = merge_vocab(best, vocab)
          print('Iter: {}'.format(i))
          print('Best pair: {}'.format(best))
          tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
          print('All tokens: {}'.format(tokens_frequencies.keys()))
          print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
          print('==========')
      
      # Let's check how tokenization works for a known word and an unknown word
      word_given_known = 'mountains</w>'
      word_given_unknown = 'Ilikeeatingapples!</w>'
      
      sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item: (measure_token_length(item[0]), item[1]), reverse=True)
      sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]
      
      print(sorted_tokens)
      
      word_given = word_given_known 
      
      print('Tokenizing word: {}...'.format(word_given))
      if word_given in vocab_tokenization:
          print('Tokenization of the known word:')
          print(vocab_tokenization[word_given])
          print('Tokenization treating the known word as unknown:')
          print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
      else:
          print('Tokenization of the unknown word:')
          print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
      
      word_given = word_given_unknown 
      
      print('Tokenizing word: {}...'.format(word_given))
      if word_given in vocab_tokenization:
          print('Tokenization of the known word:')
          print(vocab_tokenization[word_given])
          print('Tokenization treating the known word as unknown:')
          print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
      else:
          print('Tokenization of the unknown word:')
          print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))
      

      The output is as follows:

      Tokenizing word: mountains</w>...
      Tokenization of the known word:
      ['mountains</w>']
      Tokenization treating the known word as unknown:
      ['mountains</w>']
      Tokenizing word: Ilikeeatingapples!</w>...
      Tokenization of the unknown word:
      ['I', 'like', 'ea', 'ting', 'app', 'l', 'es!</w>']
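
      As a quick sanity check, tokenize_word can also be run on the toy sorted vocabulary from the encoding example above (a sketch; this token list is hypothetical rather than learned from pg16457.txt):

      toy_tokens = ["errrr</w>", "tain</w>", "moun", "est</w>", "high", "the</w>", "a</w>"]
      print(tokenize_word('highest</w>', toy_tokens))   # ['high', 'est</w>']
      print(tokenize_word('mountain</w>', toy_tokens))  # ['moun', 'tain</w>']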
      
