[PyTorch] PyTorch Basic Tutorial
PyTorch Basics
A. torch
- Vectors
torch.tensor([1,2,3])
tensor([1, 2, 3])
- Vector addition
torch.tensor([1,2,3]) + torch.tensor([2,2,2])
tensor([3, 4, 5])
- Broadcasting
torch.tensor([1,2,3]) + 2
tensor([3, 4, 5])
B. Vectors and matrices
- $3 \times 2$ matrix
torch.tensor([[1,2],[3,4],[5,6]])
tensor([[1, 2],
[3, 4],
[5, 6]])
- $3 \times 1$ matrix = $3 \times 1$ column vector
torch.tensor([[1],[3],[5]])
tensor([[1],
[3],
[5]])
- $1 \times 2$ matrix = $1 \times 2$ row vector
torch.tensor([[1,2]])
tensor([[1, 2]])
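A quick way to see the difference between these objects is to compare their .shape attributes. The sketch below (variable names are only for illustration) also includes a plain 1D tensor, which has no row/column distinction at all:
import torch
A = torch.tensor([[1,2],[3,4],[5,6]]) # 3x2 matrix
c = torch.tensor([[1],[3],[5]])       # 3x1 column vector
r = torch.tensor([[1,2]])             # 1x2 row vector
v = torch.tensor([1,2,3])             # 1D tensor: neither row nor column
print(A.shape, c.shape, r.shape, v.shape)
# torch.Size([3, 2]) torch.Size([3, 1]) torch.Size([1, 2]) torch.Size([3])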
- Addition
Broadcasting (the convenient case)
torch.tensor([[1,2],[3,4],[5,6]]) - 1
tensor([[0, 1],
[2, 3],
[4, 5]])
torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([[-1],[-3],[-5]])
tensor([[0, 1],
[0, 1],
[0, 1]])
torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([[-1,-2]])
tensor([[0, 0],
[2, 2],
[4, 4]])
Invalid broadcasting
torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([[-1,-3,-5]])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[11], line 1
----> 1 torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([[-1,-3,-5]])
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([[-1],[-2]])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[12], line 1
----> 1 torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([[-1],[-2]])
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0
Odd cases
torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([-1,-2])
tensor([[0, 0],
[2, 2],
[4, 4]])
torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([-1,-3,-5])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[14], line 1
----> 1 torch.tensor([[1,2],[3,4],[5,6]]) + torch.tensor([-1,-3,-5])
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
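The reason the 1D case above behaves this way is PyTorch's broadcasting rule: shapes are aligned from the trailing dimension, and a 1D tensor is treated as if a leading axis of size 1 were added. So a (3, 2) tensor plus a (2,) tensor works (the 1D tensor acts like a (1, 2) row and is repeated down the rows), while a (3,) tensor cannot be aligned with the trailing size 2 and fails. A minimal sketch of the equivalence, reusing the tensors already shown:
A = torch.tensor([[1,2],[3,4],[5,6]])  # shape (3, 2)
v = torch.tensor([-1,-2])              # shape (2,), treated like (1, 2)
torch.equal(A + v, A + v.reshape(1,2)) # True: the 1D tensor broadcasts like its (1, 2) version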
- Matrix multiplication
Valid matrix multiplication
torch.tensor([[1,2],[3,4],[5,6]]) @ torch.tensor([[1],[2]])
tensor([[ 5],
[11],
[17]])
torch.tensor([[1,2,3]]) @ torch.tensor([[1,2],[3,4],[5,6]])
tensor([[22, 28]])
Invalid matrix multiplication
torch.tensor([[1,2],[3,4],[5,6]]) @ torch.tensor([[1,2]])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[17], line 1
----> 1 torch.tensor([[1,2],[3,4],[5,6]]) @ torch.tensor([[1,2]])
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x2 and 1x2)
torch.tensor([[1],[2],[3]]) @ torch.tensor([[1,2],[3,4],[5,6]])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[18], line 1
----> 1 torch.tensor([[1],[2],[3]]) @ torch.tensor([[1,2],[3,4],[5,6]])
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x1 and 3x2)
Odd cases
torch.tensor([[1,2],[3,4],[5,6]]) @ torch.tensor([1,2]) # why does this work?
tensor([ 5, 11, 17])
torch.tensor([1,2,3]) @ torch.tensor([[1,2],[3,4],[5,6]]) # and why does this work?
tensor([22, 28])
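These two cases are documented behaviour of torch.matmul, which @ calls: a 1D right operand gives the matrix-vector product, and a 1D left operand has a size-1 dimension temporarily prepended, is multiplied, and then has that dimension removed again. A sketch of the equivalence (variable names are only for illustration):
A = torch.tensor([[1,2],[3,4],[5,6]])                # shape (3, 2)
v = torch.tensor([1,2])                              # shape (2,)
torch.equal(A @ v, (A @ v.reshape(2,1)).reshape(-1)) # True: v acts as a (2, 1) column, result squeezed to (3,)
w = torch.tensor([1,2,3])                            # shape (3,)
torch.equal(w @ A, (w.reshape(1,3) @ A).reshape(-1)) # True: w acts as a (1, 3) row, result squeezed to (2,)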
C. transpose, reshape
- transpose
torch.tensor([[1,2],[3,4]]).T
tensor([[1, 3],
[2, 4]])
torch.tensor([[1],[3]]).T
tensor([[1, 3]])
torch.tensor([[1,2]]).T
tensor([[1],
[2]])
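For 2D tensors, .T simply swaps the two axes, so it is equivalent to transpose(0, 1); a quick check:
A = torch.tensor([[1,2],[3,4],[5,6]])  # shape (3, 2)
A.T.shape                              # torch.Size([2, 3])
torch.equal(A.T, A.transpose(0, 1))    # True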
- reshape
Typical usage
torch.tensor([[1,2],[3,4],[5,6]]).reshape(2,3)
tensor([[1, 2, 3],
[4, 5, 6]])
torch.tensor([[1,2],[3,4],[5,6]])
tensor([[1, 2],
[3, 4],
[5, 6]])
torch.tensor([[1,2],[3,4],[5,6]]).reshape(1,6)
tensor([[1, 2, 3, 4, 5, 6]])
torch.tensor([[1,2],[3,4],[5,6]]).reshape(6)
tensor([1, 2, 3, 4, 5, 6])
The convenient way (letting -1 infer a dimension)
torch.tensor([[1,2],[3,4],[5,6]]).reshape(2,-1)
tensor([[1, 2, 3],
[4, 5, 6]])
torch.tensor([[1,2],[3,4],[5,6]]).reshape(6,-1)
tensor([[1],
[2],
[3],
[4],
[5],
[6]])
torch.tensor([[1,2],[3,4],[5,6]]).reshape(-1,6)
tensor([[1, 2, 3, 4, 5, 6]])
torch.tensor([[1,2],[3,4],[5,6]]).reshape(-1)
tensor([1, 2, 3, 4, 5, 6])
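Two constraints worth remembering with reshape: the total number of elements must stay the same, and at most one dimension may be -1 (PyTorch infers it from the others). A short sketch; the commented-out lines indicate calls that would raise errors:
A = torch.tensor([[1,2],[3,4],[5,6]])  # 6 elements
A.reshape(3,-1).shape                  # torch.Size([3, 2]): -1 is inferred as 2
# A.reshape(4,-1)                      # error: 6 elements cannot fill 4 rows evenly
# A.reshape(-1,-1)                     # error: only one dimension can be inferred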
D. concat, stack $(\star\star\star)$
- concat
a = torch.tensor([[1],[3],[5]])
b = torch.tensor([[2],[4],[6]])
print(a.shape)
print(b.shape)
torch.concat([a,b],axis=1)
torch.Size([3, 1])
torch.Size([3, 1])
tensor([[1, 2],
[3, 4],
[5, 6]])
torch.concat([a,b])
tensor([[1],
[3],
[5],
[2],
[4],
[6]])
torch.concat([a,b],axis=0)
tensor([[1],
[3],
[5],
[2],
[4],
[6]])
torch.concat([a,b],axis=1)
tensor([[1, 2],
[3, 4],
[5, 6]])
torch.concat([a,b],axis=2)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[36], line 1
----> 1 torch.concat([a,b],axis=2)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
aa = torch.tensor([[1,2],[3,4],[5,6]]) # 3x2
bb = torch.tensor([[7,8],[9,0],[10,2]]) # 3x2
torch.concat([aa,bb]) # 6x2
tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 0],
[10, 2]])
torch.concat([aa,bb]).shape
torch.Size([6, 2])
print(torch.concat([aa,bb],axis=0).shape)
print(torch.concat([aa,bb],axis=1).shape)
print(torch.concat([aa,bb],axis=2).shape)
torch.Size([6, 2])
torch.Size([3, 4])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[39], line 3
1 print(torch.concat([aa,bb],axis=0).shape)
2 print(torch.concat([aa,bb],axis=1).shape)
----> 3 print(torch.concat([aa,bb],axis=2).shape)
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
aaa = torch.tensor([[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]) # 2 x 2 x 3
bbb = torch.tensor([[[13,14,15],[16,17,18]],[[19,20,21],[22,23,24]]]) # 2 x 2 x 3
print(aaa.shape)
print(bbb.shape)
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])
print(torch.concat([aaa,bbb],axis=0).shape)
print(torch.concat([aaa,bbb],axis=1).shape)
print(torch.concat([aaa,bbb],axis=2).shape)
torch.Size([4, 2, 3])
torch.Size([2, 4, 3])
torch.Size([2, 2, 6])
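concat only needs the sizes to agree on the dimensions you are not concatenating along, so tensors with different lengths on the chosen axis can still be joined. A small sketch with made-up tensors (torch.ones / torch.zeros are used only to get the shapes):
x = torch.ones(1,2,3)             # shape (1, 2, 3)
y = torch.zeros(4,2,3)            # shape (4, 2, 3): differs only along axis 0
torch.concat([x,y],axis=0).shape  # torch.Size([5, 2, 3])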
- stack
a = torch.tensor([1,3,5])
b = torch.tensor([2,4,6])
print(a.shape)
print(b.shape)
torch.stack([a,b],axis=1)
torch.Size([3])
torch.Size([3])
tensor([[1, 2],
[3, 4],
[5, 6]])
torch.concat([a.reshape(3,1),b.reshape(3,1)],axis=1)
tensor([[1, 2],
[3, 4],
[5, 6]])
Note: np.concatenate combines arrays while keeping the total number of axes, whereas np.stack combines them while adding one new axis; the same distinction holds for torch.concat and torch.stack.
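The shape difference is easy to verify directly: with the 1D tensors a and b defined above, stack adds a new axis at the requested position, while concat keeps the number of axes and simply makes the existing axis longer.
print(torch.stack([a,b],axis=0).shape)   # torch.Size([2, 3]): a new axis is added
print(torch.stack([a,b],axis=1).shape)   # torch.Size([3, 2])
print(torch.concat([a,b],axis=0).shape)  # torch.Size([6]): the number of axes is unchanged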
:::{.callout-warning}
If this is the first time you are seeing concat and stack, it is a good idea to review the following:
https://guebin.github.io/PP2024/posts/06wk-2.html#numpy와-축axis
:::