Data Analysis

Python/머신러닝 & 딥러닝

[Pytorch] 경사하강법으로 이미지 복원해보기

Holy_Water 2023. 1. 26. 12:59

전제조건

  • 이미지 처리를 위해 만들어 두었던 weird_function() 함수에 실수로 버그가 들어가 100×100 픽셀의 오염된 이미지가 만들어졌습니다.
  • 이 오염된 이미지와 오염되기 전 원본 이미지를 동시에 파일로 저장하려고 했으나, 모종의 이유로 원본 이미지 파일은 삭제된 상황입니다.
  • 다행히도 weird_function()의 소스코드는 남아 있습니다.
  • 오염된 이미지와 weird_function()을 활용해 원본 이미지를 복원해봅시다.

 

문제 해결 방안

  1. 오염된 이미지와 크기가 같은 랜덤 텐서를 생성한다.(랜덤 텐서는 오염된 이미지와 크기가 같은 무작위 이미지 텐서입니다)
  2. 랜덤 텐서를 weird_function 함수에 입력해 똑같이 오염된 이미지를 가설이라고 부른다.
  • a. [사실] 원본 이미지가 weird_function() 함수에 입력되어 오염된 이미지를 출력했다
  • b. [사실] 인위적으로 생성한 무작위 이미지가 weird_function()함수에 입력되어 가설을 출력했다.
  1. 가설과 오염된 이미지가 같다면, 무작위 이미지와 원본 이미지도 같을 것이다.
  2. 그러므로 weird_function(random_tensor) = broken_image 관계가 성립하도록 만든다.

'가설'과 원본 이미지가 weird_function() 함수를 통해 오염되기 전의 이미지(정답) 사이의 거리를 오차라 하고, 이 오찻값이 최솟값이 되도록 랜덤 텐서(random_tensor)를 바꿔주는 것이 우리의 목표입니다.

 

 

필요 라이브러리 가져오기

import torch
import pickle
import matplotlib.pyplot as plt

 

오염된 이미지 불러오기

# 오염된 이미지를 파이토치의 텐서 형태로 읽습니다
shp_original_img = (100, 100)
broken_image =  torch.FloatTensor( pickle.load(open('broken_image_t.p', 'rb'),encoding='latin1' ) )

plt.imshow(broken_image.view(100,100)) 

 

문제에서 제공하는 weird_function 함수 구현

# 이미지를 오염시키는 weird_function()입니다. 머신러닝을 이용해 복원할 것이기 때문에 다음 함수를 이해할 필요는 없습니다!
def weird_function(x, n_iter=5):
    h = x    
    filt = torch.tensor([-1./3, 1./3, -1./3])
    for i in range(n_iter):
        zero_tensor = torch.tensor([1.0*0])
        h_l = torch.cat( (zero_tensor, h[:-1]), 0)
        h_r = torch.cat((h[1:], zero_tensor), 0 )
        h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
        if i % 2 == 0:
            h = torch.cat( (h[h.shape[0]//2:],h[:h.shape[0]//2]), 0  )
    return h

 

오차를 구하는 함수 만들기

# 무작위 텐서(random_tensor)를 weird_function()함수에 입력해 얻은 가설텐서와 오염된 이미지 사이의 오차를 구하는 함수입니다
def distance_loss(hypothesis, broken_image):    
    return torch.dist(hypothesis, broken_image) 
#torch.dist()는 두 텐서 사이의 거리를 구하는 함수입니다. 이 예제에서는 단순 거리를 오차로 설정하였습니다

 

random_tensor, lr 만들기

# broken_image와 같은 모양과 랭크를 지니는 무작위 텐서를 생성합니다.([100,100]모양의 행렬이 [10000]모양의 벡터로 표현된 텐서입니다.)
random_tensor = torch.randn(10000, dtype = torch.float)

# 학습을 얼마나 급하게 진행하는가를 결정하는 매개변수를 학습률(learning rate)라고 합니다
# 예제에서는 learning rate를 0.8로 설정하였습니다
lr = 0.8

 

for문을 활용해 경사하강법 20000번 반복하여 broken_image와  유사한 random_tensor 구하기

#경사하강법의 for 반복문입니다.
for i in range(0,20000):
    random_tensor.requires_grad_(True) #오차함수를 random_tensor로 미분하기 위해 requrires_grad는 True로 설정합니다.
    hypothesis = weird_function(random_tensor)
    loss = distance_loss(hypothesis, broken_image)
    loss.backward()
    # random_tensor를 weird_function에 통과시켜 가설(hypothesis)을 구합니다.
    # 앞서 정의한 distance_loss 함수로 hypothesis와 broken_image의 오차를 계산합니다.
    # 이후 loss.backward() 함수를 호출해 loss를 random_tensor로 미분합니다.
    with torch.no_grad(): 
        # 이번 예제에서는 직접 경사하강법을 구현하므로 torch.no_grad()를 이용해 파이토치의 자동기울기 계산을 비활성화합니다.
        random_tensor = random_tensor - lr*random_tensor.grad
        # loss.backward()에서 계산한 loss의 기울기(loss가 최댓점이 되는 곳의 방향)와 만대쪽으로 random_tensor를 학습률(lr)만큼 이동시킵니다.
    if i % 1000 == 0:
        print('Loss at {} = {}'.format(i, loss.item()))
        # for문이 1000번 반복될 때마다 오차를 출력합니다

 

결과 확인해보기

plt.imshow(random_tensor.view(100,100).data)
# 반복문 실행 결과 random_tensor가 제대로 복원되었는지 확인합니다