Python/알면 쓸모있는 잡다한 코드

torch.max , torch.argmax 차이

joannekim0420 2022. 11. 11. 17:10
728x90

 

 

1. torch.max

input tensor들 중 가장 최댓값을 output함.

import torch

a = torch.randn(1, 3)
print(a)
print(torch.max(a))

 

1x3 tensor에서 최대값만 결과로 나옴. 

 

 

2차원 배열에서는,,

a = torch.randn(3, 4)
print(a)
print(torch.max(a))

단순히 torch.max을 print 하면, 2차원 배열에서 최댓값을 output

 

하지만, 이런식으로 dim 자리에 차원을 설정해줄 수 있다. 

value, idx = torch.max(a,1)

print("value",value)
print("idx",idx)

output = (values, idx)

 

 

 

 

2. torch.argmax

입력은 torch.max와 다르지 않지만 output에서 차이. 

 

import torch

a = torch.randn(3, 4)
print(a)
idx = torch.argmax(a,1)

print(idx)

idx 값들만 output함.