    [Notes] The x-transformers library


      Repository: https://github.com/lucidrains/x-transformers
      Installation: pip install x-transformers
      Pros:

      • It collects a large number of Transformer variants in one place; the coverage feels very comprehensive.
      • Reading the code is a good way to pick up PyTorch and Python techniques.
      • It is also a nice way to learn how einops' rearrange is used (see the sketch after this list).
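
      As a pointer for the rearrange item, here is a minimal einops sketch of a pattern typical of the library's attention code (splitting heads out of the feature dimension); the shapes are illustrative only:

      import torch
      from einops import rearrange

      x = torch.randn(1, 1024, 512)                     # (batch, seq, dim)
      # split 512 features into 8 heads of 64, move heads before seq
      h = rearrange(x, 'b n (h d) -> b h n d', h = 8)   # (1, 8, 1024, 64)
      # merge the heads back into the feature dimension
      y = rearrange(h, 'b h n d -> b n (h d)')          # (1, 1024, 512)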

      Cons:

      • Pretrained weights cannot be loaded; adapting the code to support them would take considerable effort.

      Examples:

      Full encoder / decoder

      import torch
      from x_transformers import XTransformer
      
      model = XTransformer(
          dim = 512,
          enc_num_tokens = 256,
          enc_depth = 6,
          enc_heads = 8,
          enc_max_seq_len = 1024,
          dec_num_tokens = 256,
          dec_depth = 6,
          dec_heads = 8,
          dec_max_seq_len = 1024,
          tie_token_emb = True      # tie embeddings of encoder and decoder
      )
      
      src = torch.randint(0, 256, (1, 1024))
      src_mask = torch.ones_like(src).bool()
      tgt = torch.randint(0, 256, (1, 1024))
      tgt_mask = torch.ones_like(tgt).bool()
      
      loss = model(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask) # scalar cross-entropy loss
      loss.backward()
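
      The returned loss plugs straight into an ordinary PyTorch training loop. A minimal sketch, continuing from the model above and using random token ids as a stand-in for a real dataset:

      optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

      for step in range(100):
          optimizer.zero_grad()
          # hypothetical toy batch; replace with real (src, tgt) pairs
          src = torch.randint(0, 256, (1, 1024))
          tgt = torch.randint(0, 256, (1, 1024))
          src_mask = torch.ones_like(src).bool()
          tgt_mask = torch.ones_like(tgt).bool()
          loss = model(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask)
          loss.backward()
          optimizer.step()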
      

      GPT

      import torch
      from x_transformers import TransformerWrapper, Decoder
      
      model = TransformerWrapper(
          num_tokens = 20000,
          max_seq_len = 1024,
          attn_layers = Decoder(
              dim = 512,
              depth = 12,
              heads = 8
          )
      ).cuda()
      
      x = torch.randint(0, 256, (1, 1024)).cuda()
      
      model(x) # logits over the vocabulary, shape (1, 1024, 20000)
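
      Those logits can be decoded autoregressively. A minimal greedy-decoding sketch continuing from the GPT model above (the library also ships an AutoregressiveWrapper with proper sampling, which is the more convenient route); the prompt here is a hypothetical stand-in for real token ids:

      prompt = torch.randint(0, 256, (1, 4)).cuda()

      seq = prompt
      with torch.no_grad():
          for _ in range(16):
              logits = model(seq)                         # (1, seq_len, 20000)
              next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
              seq = torch.cat((seq, next_token), dim = 1)

      print(seq) # prompt followed by 16 greedily chosen tokens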
      

      END

      For more, have a look at this author's other repositories as well; they are all very well written!
