以二維numpy矩陣為例
import torch
import numpy as np
K=3 #取每行最小3個(gè)值的索引
data=np.random.rand(4,7)
print(data)
data=torch.from_numpy(data)
a, idx = torch.sort(data, descending=False)
lists=idx[:,:K]
print(lists)
運(yùn)行結(jié)果如下:
results.jpg
import torch
import numpy as np
K=3 #取每行最小3個(gè)值的索引
data=np.random.rand(4,7)
print(data)
data=torch.from_numpy(data)
a, idx = torch.sort(data, descending=False)
lists=idx[:,:K]
print(lists)