Python/알면 쓸모있는 잡다한 코드
torch.max , torch.argmax 차이
joannekim0420
2022. 11. 11. 17:10
728x90
1. torch.max
input tensor들 중 가장 최댓값을 output함.
import torch
a = torch.randn(1, 3)
print(a)
print(torch.max(a))
1x3 tensor에서 최대값만 결과로 나옴.
2차원 배열에서는,,
a = torch.randn(3, 4)
print(a)
print(torch.max(a))
단순히 torch.max을 print 하면, 2차원 배열에서 최댓값을 output
하지만, 이런식으로 dim 자리에 차원을 설정해줄 수 있다.
value, idx = torch.max(a,1)
print("value",value)
print("idx",idx)
output = (values, idx)
2. torch.argmax
입력은 torch.max와 다르지 않지만 output에서 차이.
import torch
a = torch.randn(3, 4)
print(a)
idx = torch.argmax(a,1)
print(idx)
idx 값들만 output함.