    [Notes] rotary-embedding-torch


      Rotary position embedding (RoPE)
      Repository: https://github.com/lucidrains/rotary-embedding-torch
      Install: pip install rotary-embedding-torch

      The code (note that the einops rearrange/repeat helpers are used here):

      from inspect import isfunction
      from math import pi, log
      
      import torch
      from torch import nn, einsum
      
      from einops import rearrange, repeat
      
      # helper functions
      
      def exists(val):
          return val is not None
      
      def broadcat(tensors, dim = -1):
          num_tensors = len(tensors)
          shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
          assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
          shape_len = list(shape_lens)[0]
      
          dim = (dim + shape_len) if dim < 0 else dim
          dims = list(zip(*map(lambda t: list(t.shape), tensors)))
      
          expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
          assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
          max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
          expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
          expanded_dims.insert(dim, (dim, dims[dim]))
          expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
          tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
          return torch.cat(tensors, dim = dim)
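      # Note: broadcat broadcasts every non-concatenation dimension to its maximum
      # size across the inputs (like ordinary torch broadcasting) and only then
      # concatenates along `dim`. It is defined but never called in this excerpt;
      # the full library uses it when building axial (pixel) frequencies.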
      
      # rotary embedding helper functions
      
      def rotate_half(x):
          x = rearrange(x, '... (d r) -> ... d r', r = 2)
          x1, x2 = x.unbind(dim = -1)
          x = torch.stack((-x2, x1), dim = -1)
          return rearrange(x, '... d r -> ... (d r)')
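      # e.g. rotate_half([x1, x2, x3, x4]) == [-x2, x1, -x4, x3]: adjacent feature
      # pairs are treated as 2D coordinates, so combined with cos/sin in
      # apply_rotary_emb below this rotates each (even, odd) pair by its angle.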
      
      def apply_rotary_emb(freqs, t, start_index = 0):
          rot_dim = freqs.shape[-1]
          end_index = start_index + rot_dim
          assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
          t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
          t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
          return torch.cat((t_left, t, t_right), dim = -1)
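      # t * cos(theta) + rotate_half(t) * sin(theta) is the standard RoPE rotation:
      # each pair (a, b) with angle theta becomes (a*cos - b*sin, b*cos + a*sin).
      # Only the slice [start_index, end_index) of the feature dimension is rotated;
      # the features outside it pass through unchanged.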
      
      # learned rotation helpers
      
      def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
          if exists(freq_ranges):
              rotations = einsum('..., f -> ... f', rotations, freq_ranges)
              rotations = rearrange(rotations, '... r f -> ... (r f)')
      
          rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
          return apply_rotary_emb(rotations, t, start_index = start_index)
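      # Here `rotations` are angles supplied by the caller (e.g. predicted per
      # position by a network). If freq_ranges is given, each angle is first scaled
      # into several frequency bands; the result is then duplicated to the
      # interleaved pair layout and applied exactly like fixed rotary frequencies.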
      
      # classes
      
      class RotaryEmbedding(nn.Module):
          def __init__(
              self,
              dim,
              custom_freqs = None,
              freqs_for = 'lang',
              theta = 10000,
              max_freq = 10,
              num_freqs = 1,
              learned_freq = False
          ):
              super().__init__()
              if exists(custom_freqs):
                  freqs = custom_freqs
              elif freqs_for == 'lang':
                  freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
              elif freqs_for == 'pixel':
                  freqs = torch.logspace(0., log(max_freq / 2) / log(2), dim // 2, base = 2) * pi
              elif freqs_for == 'constant':
                  freqs = torch.ones(num_freqs).float()
              else:
                  raise ValueError(f'unknown modality {freqs_for}')
      
              self.cache = dict()
      
              if learned_freq:
                  self.freqs = nn.Parameter(freqs)
              else:
                  self.register_buffer('freqs', freqs)
      
          def forward(self, t, cache_key = None):
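              # t is typically a 1-D tensor of positions (or a zero-argument function
              # returning one). The output has shape (*t.shape, 2 * len(self.freqs)):
              # the outer product of positions and inverse frequencies, with each
              # angle repeated twice to match the layout expected by apply_rotary_emb.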
              if exists(cache_key) and cache_key in self.cache:
                  return self.cache[cache_key]
      
              if isfunction(t):
                  t = t()
      
              freqs = self.freqs
      
              freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
              freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
      
              if exists(cache_key):
                  self.cache[cache_key] = freqs
      
              return freqs
      

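      To make this easier to try out, here is a minimal usage sketch I put together
      on top of the code above (the dim value and tensor shapes are just
      illustrative assumptions, not something prescribed by the repository):

      # minimal usage sketch, assuming the definitions above have already been run
      import torch

      rotary = RotaryEmbedding(dim = 32)       # rotate the first 32 feature dims

      q = torch.randn(2, 1024, 64)             # (batch, seq_len, head_dim), assumed shape
      k = torch.randn(2, 1024, 64)

      # angles for every position, shape (1024, 32); cached under the given key
      freqs = rotary(torch.arange(1024), cache_key = 1024)

      # only the first 32 of the 64 feature dims get rotated, the rest pass through
      q = apply_rotary_emb(freqs, q)
      k = apply_rotary_emb(freqs, k)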

      This implementation is much more elegant than the one I wrote myself before, far cleaner to read!!!
