빅데이타 & 머신러닝/Pytorch
파이토치 3. 파이토치라이트닝
Terry Cho
2024. 8. 16. 04:51
파이토치 라이트닝
파이토치 라이트닝은 파이토치를 한번 더 추상화하는 프레임웍으로, 파이토치 코딩의 복잡도를 낮춰주는 프레임이다. 텐서플로우의 복잡도를 케라스로 잡아주는 느낌이라고 할까? 한번더 추상화하는 만큼, 약간의 성능저하가 발생한다.
아래와 같은 간단한 선형 회귀 모델이 있다고 하자.
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import lightning as L
from torch.utils.data import TensorDataset
# 데이터
x_train = torch.FloatTensor([[1], [2], [3]])
y_train = torch.FloatTensor([[1], [2], [3]])
# 모델 초기화
class LinearRegressionModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
이 모델을 학습을 파이토치로 학습하는 코드를 만들어 보면 다음과 같다.
옵티마이저와, 손실 함수를 직접 적용하고, for 문을 이용해서 직접 루프를 돌리면서 손실함수를 계산하고, 그라디언트에 따라서 값을 이동 시키고, 백프로퍼게이션을 실행한다. 이 모든 것을 코드레벨에서 직접 컨트롤해야 한다.
model = LinearRegressionModel()
# optimizer 설정
optimizer = optim.SGD(model.parameters(), lr=0.01)
nb_epochs = 1000
cost_list = []
for epoch in range(nb_epochs + 1):
# H(x) 계산
prediction = model(x_train)
# cost 계산
cost = F.mse_loss(prediction, y_train)
cost_list.append(cost.item()) # cost 값을 리스트에 추가
# cost로 H(x) 개선
optimizer.zero_grad()
cost.backward()
optimizer.step()
반면 라이트닝을 사용할 경우, LightingModule을 이용하여, 옵티마이저, 데이터로더, 학습 스택들을 추상화한후에, 학습은 trainer.fit으로 간단하게 실행할 수 있다. 즉 앞의 코드와 차이점은, 보일러플레이트 (학습 스텝의 틀)을 제공함으로써 조금 더 구조화되고 간결한 트레이닝 루프를 구현할 수 있다는 장점이 있다.
# LightningModule 정의
class LinearRegressionLightningModule(L.LightningModule):
def __init__(self):
super().__init__()
self.model = LinearRegressionModel()
self.loss_fn = nn.MSELoss() # 손실 함수 정의
def forward(self, x):
return self.model(x)
def train_dataloader(self):
dataset = TensorDataset(x_train, y_train)
return DataLoader(dataset, batch_size=2) # 배치 크기 설정
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
return optimizer
def training_step(self, batch, batch_idx):
x, y = batch #trainin data & label
y_hat = self(x) #predictied 데이터
loss = self.loss_fn(y_hat, y)
self.log("train_loss", loss) # 로깅
# loss 값 출력 추가
print(f"Batch {batch_idx}, Loss: {loss.item()}")
return loss
# Trainer 설정 및 학습 시작
trainer = L.Trainer(max_epochs=50) # 에폭 설정
model = LinearRegressionLightningModule()
trainer.fit(model)
개발자 커뮤니티를 통해서 파이토치 개발 이코 시스템을 보면 많은 개발자들이 파이토치를 그대로 쓰는것 보다, 파이토치 라이트닝을 많이 사용하고 있다.