>>> q
tensor([[-1.0092,  1.1658,  2.0389],
        [-1.1147,  0.8038,  0.4668],
        [ 0.1952, -0.0108,  0.3505],
        [-0.2510,  0.6763, -0.0758],
        [ 0.1935,  0.3007, -0.4039]])
>>> selector
tensor([[1],
        [2],
        [0],
        [0],
        [1]])
>>> q.gather(-1, selector)
tensor([[ 1.1658],
        [ 0.4668],
        [ 0.1952],
        [-0.2510],
        [ 0.3007]])
>>> selector1
tensor([[1, 1],
        [0, 2],
        [2, 0],
        [1, 1],
        [1, 2]])
>>> q.gather(-1, selector1)
tensor([[ 1.1658,  1.1658],
        [-1.1147,  0.4668],
        [ 0.3505,  0.1952],
        [ 0.6763,  0.6763],
        [ 0.3007, -0.4039]])