Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    【记录】pytorch_scatter工具

    技术交流
    1
    1
    106
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 183****0229
      183****0229 last edited by

      地址:https://github.com/rusty1s/pytorch_scatter
      pytorch_scatter包是一个高度优化的稀疏update操作(scatter 和 segment)的小型扩展库。(PyTorch主包中缺少这些操作)。

      安装:

      pip

      pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html

      conda

      conda install pytorch-scatter -c rusty1s

      例子:

      import torch
      from torch_scatter import scatter_max
      
      src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
      index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
      
      out, argmax = scatter_max(src, index, dim=-1)
      
      print(out)
      tensor([[0, 0, 4, 3, 2, 0],
              [2, 4, 3, 0, 0, 0]])
      
      print(argmax)
      tensor([[5, 5, 3, 4, 0, 1]
              [1, 4, 3, 5, 5, 5]])
      

      END:

      使用了这个方法,可以帮助我们快速对最终的结果进行一些sao操作。

      1 Reply Last reply Reply Quote 1
      • First post
        Last post