빅데이타 & 머신러닝/Pytorch

파이토치 3. 파이토치라이트닝

Terry Cho 2024. 8. 16. 04:51

파이토치 라이트닝

 

파이토치 라이트닝은 파이토치를 한번 더 추상화하는 프레임웍으로, 파이토치 코딩의 복잡도를 낮춰주는 프레임이다. 텐서플로우의 복잡도를 케라스로 잡아주는 느낌이라고 할까? 한번더 추상화하는 만큼, 약간의 성능저하가 발생한다.

 

파이토치와 파이토치 라이트닝 성능 비교 출처 : https://pytorch-lightning.readthedocs.io/en/1.2.10/benchmarking/benchmarks.html

 

아래와 같은 간단한 선형 회귀 모델이 있다고 하자.

import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning as L
from torch.utils.data import TensorDataset

# 데이터
x_train = torch.FloatTensor([[1], [2], [3]])
y_train = torch.FloatTensor([[1], [2], [3]])
# 모델 초기화
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

이 모델을 학습을 파이토치로 학습하는 코드를 만들어 보면 다음과 같다. 

옵티마이저와, 손실 함수를 직접 적용하고, for 문을 이용해서 직접 루프를 돌리면서 손실함수를 계산하고, 그라디언트에 따라서 값을 이동 시키고, 백프로퍼게이션을 실행한다. 이 모든 것을 코드레벨에서 직접 컨트롤해야 한다. 

model = LinearRegressionModel()
# optimizer 설정
optimizer = optim.SGD(model.parameters(), lr=0.01)

nb_epochs = 1000
cost_list = []
for epoch in range(nb_epochs + 1):

    # H(x) 계산
    prediction = model(x_train)

    # cost 계산
    cost = F.mse_loss(prediction, y_train)
    cost_list.append(cost.item()) # cost 값을 리스트에 추가

    # cost로 H(x) 개선
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

 

반면 라이트닝을 사용할 경우, LightingModule을 이용하여, 옵티마이저, 데이터로더, 학습 스택들을 추상화한후에, 학습은 trainer.fit으로 간단하게 실행할 수 있다. 즉 앞의 코드와 차이점은, 보일러플레이트 (학습 스텝의 틀)을 제공함으로써 조금 더 구조화되고 간결한 트레이닝 루프를 구현할 수 있다는 장점이 있다. 

# LightningModule 정의
class LinearRegressionLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LinearRegressionModel()
        self.loss_fn = nn.MSELoss()  # 손실 함수 정의

    def forward(self, x):
        return self.model(x)

    def train_dataloader(self):
        dataset = TensorDataset(x_train, y_train)
        return DataLoader(dataset, batch_size=2)  # 배치 크기 설정

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch #trainin data & label
        y_hat = self(x) #predictied 데이터
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)  # 로깅
        # loss 값 출력 추가
        print(f"Batch {batch_idx}, Loss: {loss.item()}")
        return loss

# Trainer 설정 및 학습 시작
trainer = L.Trainer(max_epochs=50)  # 에폭 설정
model = LinearRegressionLightningModule()
trainer.fit(model)

 

개발자 커뮤니티를 통해서 파이토치 개발 이코 시스템을 보면 많은 개발자들이 파이토치를 그대로 쓰는것 보다, 파이토치 라이트닝을 많이 사용하고 있다.