【记录】pytorch_scatter工具
-
地址:https://github.com/rusty1s/pytorch_scatter
pytorch_scatter包是一个高度优化的稀疏update操作(scatter 和 segment)的小型扩展库。(PyTorch主包中缺少这些操作)。安装:
pip
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html
conda
conda install pytorch-scatter -c rusty1s
例子:
import torch from torch_scatter import scatter_max src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) out, argmax = scatter_max(src, index, dim=-1) print(out) tensor([[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]) print(argmax) tensor([[5, 5, 3, 4, 0, 1] [1, 4, 3, 5, 5, 5]])
END:
使用了这个方法,可以帮助我们快速对最终的结果进行一些sao操作。