메인 콘텐츠로 건너뛰기
PyTorch Lightning은 PyTorch 코드를 체계적으로 구성하고 분산 트레이닝이나 16비트 정밀도 같은 고급 기능을 쉽게 추가할 수 있도록 해주는 가벼운 래퍼를 제공합니다. W&B는 ML 실험을 로깅하기 위한 가벼운 래퍼를 제공합니다. 하지만 이 둘을 직접 함께 구성할 필요는 없습니다. W&B는 WandbLogger를 통해 PyTorch Lightning 라이브러리에 직접 통합되어 있습니다.

Lightning와 통합하기

from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(logger=wandb_logger)

wandb.log() 사용: WandbLogger는 Trainer의 global_step을 사용해 W&B에 로깅합니다. 코드에서 wandb.log()를 직접 추가로 호출하는 경우, wandb.log()에서 step 인수를 사용하지 마세요.대신 다른 메트릭과 마찬가지로 Trainer의 global_step을 로깅하세요:
wandb.log({"accuracy":0.99, "trainer/global_step": step})
대화형 대시보드

가입하고 API 키 생성하기

API 키는 머신을 W&B에 인증하는 데 사용됩니다. 사용자 프로필에서 API 키를 생성할 수 있습니다.
더 간편하게 하려면 User Settings로 바로 이동해 API 키를 생성하세요. 새로 생성한 API 키는 즉시 복사해 비밀번호 관리자와 같은 안전한 위치에 저장하세요.
  1. 오른쪽 상단에서 사용자 프로필 아이콘을 클릭합니다.
  2. User Settings를 선택한 다음 API Keys 섹션으로 스크롤합니다.

wandb 라이브러리 설치 및 로그인

로컬 환경에 wandb 라이브러리를 설치하고 로그인하려면 다음 단계를 따르세요.
  1. WANDB_API_KEY 환경 변수에 API 키를 설정합니다.
    export WANDB_API_KEY=<your_api_key>
    
  2. wandb 라이브러리를 설치하고 로그인합니다.
    pip install wandb
    
    wandb login
    

PyTorch Lightning의 WandbLogger 사용

PyTorch Lightning에는 메트릭, 모델 가중치, 미디어 등을 로깅할 수 있는 WandbLogger 클래스가 여러 개 있습니다. Lightning와 통합하려면 WandbLogger를 인스턴스화한 후 Lightning의 Trainer 또는 Fabric에 전달하세요.
trainer = Trainer(logger=wandb_logger)

일반적인 로거 인수

다음은 WandbLogger에서 가장 자주 사용되는 매개변수입니다. 모든 로거 인수에 대한 자세한 내용은 PyTorch Lightning 문서를 참고하세요.
ParameterDescription
project로깅할 wandb 프로젝트를 지정합니다
namewandb run의 이름을 지정합니다
log_modellog_model="all"이면 모든 모델을 로깅하고, log_model=True이면 트레이닝이 끝날 때 로깅합니다
save_dir데이터가 저장되는 경로

하이퍼파라미터를 로깅하세요

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        self.save_hyperparameters()

추가 설정 매개변수 로깅

# 파라미터 하나 추가
wandb_logger.experiment.config["key"] = value

# 여러 파라미터 추가
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# wandb 모듈 직접 사용
wandb.config["key"] = value
wandb.config.update()

그라디언트, 파라미터 히스토그램 및 모델 토폴로지 로깅

트레이닝하는 동안 모델의 그라디언트와 파라미터를 모니터링하려면 모델 객체를 wandblogger.watch()에 전달할 수 있습니다. PyTorch Lightning WandbLogger 문서를 참조하세요

메트릭 로깅

WandbLogger를 사용할 때는 LightningModule 내에서(예: training_step 또는 validation_step 메서드) self.log('my_metric_name', metric_vale)를 호출해 메트릭을 W&B에 로깅할 수 있습니다.아래 code snippet은 메트릭과 LightningModule 하이퍼파라미터를 로깅하도록 LightningModule을 정의하는 방법을 보여줍니다. 이 예제에서는 메트릭을 계산하기 위해 torchmetrics 라이브러리를 사용합니다.
import torch
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from lightning.pytorch import LightningModule


class My_LitModule(LightningModule):
    def __init__(self, n_classes=10, n_layer_1=128, n_layer_2=256, lr=1e-3):
        """모델 파라미터를 정의하는 데 사용하는 메서드"""
        super().__init__()

        # mnist 이미지는 (1, 28, 28) (채널, 너비, 높이)입니다
        self.layer_1 = Linear(28 * 28, n_layer_1)
        self.layer_2 = Linear(n_layer_1, n_layer_2)
        self.layer_3 = Linear(n_layer_2, n_classes)

        self.loss = CrossEntropyLoss()
        self.lr = lr

        # 하이퍼파라미터를 self.hparams에 저장합니다(W&B에서 자동 로깅됨)
        self.save_hyperparameters()

    def forward(self, x):
        """추론 시 입력에서 출력까지 처리하는 데 사용하는 메서드"""

        # (b, 1, 28, 28) -> (b, 1*28*28)
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        # 3 x (linear + relu)를 수행합니다
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return x

    def training_step(self, batch, batch_idx):
        """단일 배치의 loss를 반환해야 합니다"""
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # loss와 메트릭 로깅
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """메트릭 로깅에 사용됩니다"""
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # loss와 메트릭 로깅
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

    def configure_optimizers(self):
        """모델 옵티마이저를 정의합니다"""
        return Adam(self.parameters(), lr=self.lr)

    def _get_preds_loss_accuracy(self, batch):
        """train/valid/test 단계가 유사하므로 사용하는 편의 함수입니다"""
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)
        acc = accuracy(preds, y)
        return preds, loss, acc

메트릭의 최소/최대값 로깅

wandb의 define_metric 함수를 사용하면 해당 메트릭에 대해 W&B summary 메트릭에 최소값, 최대값, 평균값 또는 최적값 중 무엇을 표시할지 지정할 수 있습니다. define_metric _를 사용하지 않으면 마지막으로 로깅된 값이 summary 메트릭에 표시됩니다. 자세한 내용은 define_metric 레퍼런스 문서가이드를 참조하세요. W&B summary 메트릭에서 최대 검증 정확도를 추적하도록 하려면, 트레이닝 시작 시 wandb.define_metric()을 한 번만 호출하세요:
class My_LitModule(LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0:
            wandb.define_metric("val_accuracy", summary="max")

        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # loss와 메트릭 로깅
        self.log("val_loss", loss)
        self.log("val_accuracy", acc)
        return preds

모델 체크포인트 저장

모델 체크포인트를 W&B Artifacts로 저장하려면 Lightning ModelCheckpoint 콜백을 사용하고 WandbLogger에서 log_model 인자를 설정하세요.
python trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
W&B Artifact에서 모델 체크포인트를 쉽게 조회할 수 있도록 latestbest alias가 자동으로 설정됩니다:
# Artifacts 패널에서 레퍼런스를 확인할 수 있습니다
# "VERSION"은 버전(예: "v2") 또는 별칭("latest" 또는 "best")일 수 있습니다
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
python # download checkpoint locally (if not already cached) wandb_logger.download_artifact(checkpoint_reference, artifact_type="model")
# 체크포인트 로드
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
로깅한 모델 체크포인트는 W&B Artifacts UI에서 확인할 수 있으며, 모델의 전체 리니지를 포함합니다(UI에서 모델 체크포인트 예시는 여기를 참조하세요). 최고의 모델 체크포인트를 북마크하고 팀 전체에서 중앙 집중식으로 관리하려면 W&B Model 레지스트리에 연결할 수 있습니다. 여기에서 작업별로 최고의 모델을 정리하고, 모델 라이프사이클을 관리하고, ML 라이프사이클 전반에 걸쳐 추적 및 감사를 쉽게 수행할 수 있으며, 웹훅이나 작업으로 후속 액션을 자동화할 수 있습니다.

이미지, 텍스트 등 로깅하기

WandbLogger에는 미디어를 로깅하는 log_image, log_text, log_table 메서드가 있습니다. 또한 wandb.log() 또는 trainer.logger.experiment.log()를 직접 호출해 오디오, Molecules, Point Clouds, 3D 객체 등 다른 미디어 유형도 로깅할 수 있습니다.
# 텐서, numpy 배열 또는 PIL 이미지 사용
wandb_logger.log_image(key="samples", images=[img1, img2])

# 캡션 추가

wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# 파일 경로 사용

wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# trainer에서 .log 사용

trainer.logger.experiment.log(
{"samples": [wandb.Image(img, caption=caption) for (img, caption) in my_images]},
step=current_trainer_global_step,
)

Lightning의 Callbacks 시스템을 사용하면 WandbLogger를 통해 W&B에 로깅할 시점을 제어할 수 있습니다. 이 예제에서는 검증 이미지와 예측 샘플을 로깅합니다:
import torch
import wandb
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger

# 또는
# from wandb.integration.lightning.fabric import WandbLogger


class LogPredictionSamplesCallback(Callback):
    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """검증 배치가 끝날 때 호출됩니다."""

        # `outputs`는 `LightningModule.validation_step`에서 전달되며,
        # 이 경우 모델 예측값에 해당합니다

        # 첫 번째 배치에서 샘플 이미지 예측값 20개를 로깅합니다
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [
                f"Ground Truth: {y_i} - Prediction: {y_pred}"
                for y_i, y_pred in zip(y[:n], outputs[:n])
            ]

            # 옵션 1: `WandbLogger.log_image`로 이미지를 로깅합니다
            wandb_logger.log_image(key="sample_images", images=images, caption=captions)

            # 옵션 2: 이미지와 예측값을 W&B Table로 로깅합니다
            columns = ["image", "ground truth", "prediction"]
            data = [
                [wandb.Image(x_i), y_i, y_pred] or x_i,
                y_i,
                y_pred in list(zip(x[:n], y[:n], outputs[:n])),
            ]
            wandb_logger.log_table(key="sample_table", columns=columns, data=data)


trainer = pl.Trainer(callbacks=[LogPredictionSamplesCallback()])

Lightning과 W&B에서 여러 GPU 사용하기

PyTorch Lightning은 DDP 인터페이스를 통해 멀티 GPU를 지원합니다. 하지만 PyTorch Lightning의 설계상 각 GPU를 어떻게 인스턴스화하는지 주의해야 합니다. Lightning은 트레이닝 루프의 각 GPU(또는 rank)가 동일한 초기 조건으로 정확히 같은 방식으로 인스턴스화되어야 한다고 가정합니다. 하지만 wandb.run 객체에 접근할 수 있는 것은 rank 0 프로세스뿐이며, 0이 아닌 rank 프로세스에서는 wandb.run = None입니다. 이로 인해 0이 아닌 프로세스가 실패할 수 있습니다. 이런 상황에서는 이미 중단된 0이 아닌 rank 프로세스들이 join하지 못한 채 rank 0 프로세스가 이들을 기다리게 되므로 데드락이 발생할 수 있습니다. 따라서 트레이닝 코드를 어떻게 구성할지 주의해야 합니다. 권장되는 방식은 코드가 wandb.run 객체와 독립적으로 동작하도록 작성하는 것입니다.
class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("train/loss", loss)
        return {"train_loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)

        self.log("val/loss", loss)
        return {"val_loss": loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def main():
    # 모든 랜덤 시드를 같은 값으로 설정합니다.
    # 이는 분산 트레이닝 환경에서 중요합니다.
    # 각 rank는 고유한 초기 가중치 집합을 받습니다.
    # 이 값들이 서로 일치하지 않으면 그라디언트도 일치하지 않아
    # 트레이닝이 수렴하지 않을 수 있습니다.
    pl.seed_everything(1)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project="<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath="checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs=3, gpus=2, logger=wandb_logger, strategy="ddp", callbacks=callbacks
    )
    trainer.fit(model, train_loader, val_loader)

예제

Colab 노트북이 포함된 비디오 튜토리얼을 보면서 따라 할 수 있습니다.

자주 묻는 질문

W&B는 Lightning과 어떻게 통합되나요?

핵심 인테그레이션은 Lightning loggers API를 기반으로 하며, 이를 통해 로깅 코드의 상당 부분을 프레임워크에 구애받지 않는 방식으로 작성할 수 있습니다. LoggerLightning Trainer에 전달되며, 해당 API의 풍부한 hook-and-callback system에 따라 트리거됩니다. 이렇게 하면 연구 코드와 엔지니어링 및 로깅 코드를 깔끔하게 분리할 수 있습니다.

추가 코드 없이 인테그레이션은 무엇을 로깅하나요?

모델 체크포인트를 W&B에 저장하므로, 이를 확인하거나 이후 run에서 사용하기 위해 다운로드할 수 있습니다. 또한 GPU 사용량 및 네트워크 I/O와 같은 system metrics, 하드웨어 및 OS 정보와 같은 환경 정보, code state (git commit 및 diff patch, notebook 내용과 세션 이력 포함), 그리고 표준 출력에 출력되는 모든 내용을 캡처합니다.

wandb.run을 트레이닝 설정에서 사용해야 한다면 어떻게 해야 하나요?

직접 접근해야 하는 변수의 스코프를 넓혀야 합니다. 다시 말해, 모든 프로세스에서 초기 조건이 동일하도록 해야 합니다.
if os.environ.get("LOCAL_RANK", None) is None:
    os.environ["WANDB_DIR"] = wandb.run.dir
그 경우 os.environ["WANDB_DIR"]를 사용해 모델 체크포인트 디렉터리를 설정할 수 있습니다. 이렇게 하면 rank가 0이 아닌 모든 프로세스가 wandb.run.dir에 접근할 수 있습니다.