【记录】x-transformers库
-
地址:https://github.com/lucidrains/x-transformers
安装:pip install x-transformers
优点:- 这里包含了很多transformer的变种模型,感觉很全。
- 可以通过这里的代码学习一下pytorch和python知识。
- 可以学习一下rearrange库的使用。
缺点:
- 无法加载预训权重,因为改起来会需要费精力。
例子:
Full encoder / decoder
import torch from x_transformers import XTransformer model = XTransformer( dim = 512, enc_num_tokens = 256, enc_depth = 6, enc_heads = 8, enc_max_seq_len = 1024, dec_num_tokens = 256, dec_depth = 6, dec_heads = 8, dec_max_seq_len = 1024, tie_token_emb = True # tie embeddings of encoder and decoder ) src = torch.randint(0, 256, (1, 1024)) src_mask = torch.ones_like(src).bool() tgt = torch.randint(0, 256, (1, 1024)) tgt_mask = torch.ones_like(tgt).bool() loss = model(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask) # (1, 1024, 512) loss.backward()
GPT
import torch from x_transformers import TransformerWrapper, Decoder model = TransformerWrapper( num_tokens = 20000, max_seq_len = 1024, attn_layers = Decoder( dim = 512, depth = 12, heads = 8 ) ).cuda() x = torch.randint(0, 256, (1, 1024)).cuda() model(x) # (1, 1024, 20000)
END
更多的内容可以看下这个作者的其他仓库,感觉写的都非常好!!!