Perceptual Lossとは?VAEのぼやけを改善する方法をCIFAR-10で分かりやすく解説

カテゴリ: AI・機械学習, Python実装

画像系タスクでVAE(Variational Autoencoder、変分オートエンコーダー)を使うと、出力画像がぼやけやすいという問題があります。これは、再構成誤差として画素ごとの差をそのまま最小化すると、複数のあり得る細部を平均したような出力になりやすいためです。特に、形状や背景のバリエーションが大きい自然画像では、この傾向が目立ちます。こうした問題を改善する方法の1つが、Perceptual Lossの導入です。Perceptual Lossは、画素値そのものではなく、事前学習済みの画像分類モデルが捉えた特徴の差を使って画像を比較する考え方です。実際に、Perceptual Lossは画素ベースの損失よりも細部を保ちやすく、VAEへ応用した研究でも見た目の自然さや知覚品質の改善が報告されています。

本記事では、Perceptual Lossの考え方を説明し、それをCIFAR-10を用いたVAEの学習に応用する流れを整理します。VAEの学習は分かってきたものの、自然画像で再構成がぼやけると感じている方にとって、次の一歩として理解しやすい内容を目指します。

Perceptual Lossとは

VAEで自然画像を扱うと、複数のあり得る細部を平均したような出力になりやすく、輪郭が甘くなったり、質感が失われたりすることがあります。そこで重要になるのが、「画素値がどれだけ一致しているか」だけでなく、「見た目としてどれだけ似ているか」を評価することです。Johnsonらは、画像変換や超解像において、画素ベースの損失ではなく、事前学習済みネットワークの特徴空間で画像同士の差を測るPerceptual Lossを用いることで、より細部を保った結果が得られることを示しました。

超解像の例。Johnson, J., Alahi, A., & Fei-Fei, L. (2016). Perceptual Losses for Real-Time Style Transfer and Super-Resolution. In European Conference on Computer Vision (ECCV 2016).

Perceptual Lossでは、一般にVGG-Netのような事前学習済みCNNを使います。元画像と生成画像の両方をVGG-Netに入力し、中間層の出力を取り出して比較します。浅い層ではエッジや局所的なテクスチャのような特徴が、より深い層ではより広い受容野に基づく高次の特徴が表現されやすいため、中間層の差を小さくすることで、単なる画素一致よりも見た目の自然さを保ちやすくなります。深層特徴は、人間の知覚に近い類似度指標として有効であることも報告されています。

PVAE: 通常のVAE、VAE-123とVAE-345はperceptual lossを用いたモデル。Hou, X., Shen, L., Sun, K., & Qiu, G. (2016). Deep Feature Consistent Variational Autoencoder. arXiv preprint arXiv:1610.00291.

Perceptual Lossを用いたVAEの誤差関数

Perceptual LossをVAEに導入する場合、基本的な考え方はシンプルです。もともとのVAEの損失関数である再構成誤差KLダイバージェンスに加えて、Perceptual Lossを足します。

式で書くと、次のような形になります。

L=Lrec+βLKL+λLperc​

ここで、

  • LrecL_{\mathrm{rec}}​:画素ベースの再構成誤差
  • LKLL_{\mathrm{KL}}​:潜在分布を整えるためのKLダイバージェンス
  • LpercL_{\mathrm{perc}}​:VGG-Netなどの中間層特徴の差
  • β,λ\beta, \lambda:各損失の重み

です。

Perceptual Lossそのものは、たとえばVGGの複数層の特徴マップを使って、元画像 xxx と再構成画像 x^\hat{x}x^ の差を測る形で書けます。

Lperc=∑l∥ϕl(x)−ϕl(x^)∥1​

ここで ϕl(⋅)\phi_l(\cdot)は、事前学習済みCNNの ll 層目の出力です。L1ノルムでもL2ノルムでも実装できますが、重要なのは、画素値ではなく特徴量を比較するという点です。Johnsonらはこの考え方を画像変換に用い、HouらはVAEの学習へ応用しています。 

実装例

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg16, VGG16_Weights

import matplotlib.pyplot as plt
import numpy as np
import json


# ---------- 計算機の確認 ----------
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# ---------- データセットの準備 ----------
transform = transforms.ToTensor()

train_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

test_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

# CIFAR-10のクラス確認
print("classes:", train_dataset.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# automobile のラベル番号
car_label = train_dataset.class_to_idx["horse"]
print("car_label:", car_label)

# automobile だけの index を取得
train_car_indices = [i for i, label in enumerate(train_dataset.targets) if label == car_label]
test_car_indices = [i for i, label in enumerate(test_dataset.targets) if label == car_label]

# Subset で絞り込み
train_car_dataset = Subset(train_dataset, train_car_indices)
test_car_dataset = Subset(test_dataset, test_car_indices)

train_loader = DataLoader(train_car_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_car_dataset, batch_size=64, shuffle=False)

print("学習データ数:", len(train_car_dataset))
print("テストデータ数:", len(test_car_dataset))


# ---------- データの確認 ----------
images, labels = next(iter(train_loader))

print("images shape:", images.shape)
print("labels shape:", labels.shape)
print("最初のラベル:", labels[0].item())

plt.figure(figsize=(10, 4))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    plt.imshow(images[i].numpy().transpose((1, 2, 0)))
    plt.title(f"label: {labels[i].item()}")
    plt.axis("off")
plt.tight_layout()
plt.show()


# ---------- VAEモデルの作成 ----------
class CNNVAE(nn.Module):
    def __init__(self, latent_dim=16):
        super().__init__()
        self.latent_dim = latent_dim

        # ===== Encoder =====
        self.enc_conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   # 32x32 -> 32x32
        self.enc_conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)   # 32x32 -> 32x32
        self.enc_pool1 = nn.MaxPool2d(2)                              # 32x32 -> 16x16

        self.enc_conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 16x16 -> 16x16
        self.enc_conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)  # 16x16 -> 16x16
        self.enc_pool2 = nn.MaxPool2d(2)                              # 16x16 -> 8x8

        self.fc_mu = nn.Linear(64 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(64 * 8 * 8, latent_dim)

        # ===== Decoder =====
        self.fc_dec = nn.Linear(latent_dim, 64 * 8 * 8)

        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")        # 8x8 -> 16x16
        self.dec_conv1 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.dec_conv2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)

        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")        # 16x16 -> 32x32
        self.dec_conv3 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.dec_conv4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)

        self.dec_conv5 = nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def encode(self, x):
        # 入力: [N, 1, 28, 28]
        x = F.relu(self.enc_conv1(x))       # [N, 32, 32, 32]
        x = F.relu(self.enc_conv2(x))       # [N, 32, 32, 32]
        x = self.enc_pool1(x)               # [N, 32, 16, 16]

        x = F.relu(self.enc_conv3(x))       # [N, 64, 16, 16]
        x = F.relu(self.enc_conv4(x))       # [N, 64, 16, 16]
        x = self.enc_pool2(x)               # [N, 64, 8, 8]

        x = torch.flatten(x, start_dim=1)   # [N, 64*8*8]
        mu = self.fc_mu(x)                  # [N, latent_dim(default=16)]
        logvar = self.fc_logvar(x)          # [N, latent_dim(default=16)]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        # 入力: [N, latent_dim(default=16)]
        x = self.fc_dec(z)                  # [N, 64*8*8]
        x = x.view(-1, 64, 8, 8)            # [N, 64, 8, 8]

        x = self.up1(x)                     # [N, 64, 16, 16]
        x = F.relu(self.dec_conv1(x))       # [N, 32, 16, 16]
        x = F.relu(self.dec_conv2(x))       # [N, 32, 16, 16]

        x = self.up2(x)                     # [N, 32, 32, 32]
        x = F.relu(self.dec_conv3(x))       # [N, 32, 32, 32]
        x = F.relu(self.dec_conv4(x))       # [N, 32, 32, 32]

        x = torch.sigmoid(self.dec_conv5(x))
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


# ---------- Perceptual lossの作成 ----------
class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize_to_224=True):
        super().__init__()
        self.resize_to_224 = resize_to_224

        vgg = vgg16(weights=VGG16_Weights.DEFAULT).features

        # 使いやすい浅〜中層を3段
        self.blocks = nn.ModuleList([
            vgg[:4].eval(),    # relu1_2 付近
            vgg[4:9].eval(),   # relu2_2 付近
            vgg[9:16].eval(),  # relu3_3 付近
        ])

        for block in self.blocks:
            for p in block.parameters():
                p.requires_grad = False

        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, x):
        if self.resize_to_224:
            x = F.interpolate(x, size=(224, 224), mode="bilinear", align_corners=False)
        x = (x - self.mean) / self.std
        return x

    def forward(self, pred, target):
        # target側へは勾配不要
        target = target.detach()

        pred = self.preprocess(pred)
        target = self.preprocess(target)

        loss = 0.0
        x = pred
        y = target

        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss = loss + F.l1_loss(x, y)

        return loss


# ---------- 損失誤差(再構成誤差 + KLダイバージェンス + Perceptual) ----------
def vae_perceptual_loss(recon_x, x, mu, logvar, perceptual_fn, beta=1e-3, lambda_perc=0.1):
    # pixel-level reconstruction
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")

    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # perceptual loss
    perc_loss = perceptual_fn(recon_x, x)

    total_loss = recon_loss + beta * kl_loss + lambda_perc * perc_loss
    return total_loss, recon_loss, kl_loss, perc_loss


# ---------- 学習モデルの作成 ----------
latent_dim = 64

model = CNNVAE(latent_dim=latent_dim).to(device)
perceptual_fn = VGGPerceptualLoss(resize_to_224=False).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(model)


# ---------- 学習 ----------
train_losses = []
train_recon_losses = []
train_kl_losses = []

num_epochs = 100
beta = 1.0
lambda_perc = 10000

for epoch in range(num_epochs):
    model.train()

    train_loss = 0.0
    train_recon = 0.0
    train_kl = 0.0

    for images, _ in train_loader:
        images = images.to(device)

        optimizer.zero_grad()

        recon, mu, logvar = model(images)
        
        loss, recon_loss, kl_loss, perc_loss = vae_perceptual_loss(
            recon, images, mu, logvar,
            perceptual_fn=perceptual_fn,
            beta=beta,
            lambda_perc=lambda_perc
        )

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_recon += recon_loss.item()
        train_kl += kl_loss.item()

    avg_loss = train_loss / len(train_dataset)
    avg_recon = train_recon / len(train_dataset)
    avg_kl = train_kl / len(train_dataset)

    train_losses.append(avg_loss)
    train_recon_losses.append(avg_recon)
    train_kl_losses.append(avg_kl)

    print(
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | KL: {avg_kl:.4f}"
    )


# ---------- 潜在ベクトルから画像生成 ----------
model.eval()

images, _ = next(iter(test_loader))
images = images[:8].to(device)

with torch.no_grad():
    recon, mu, logvar = model(images)

images = images.cpu()
recon = recon.cpu()

plt.figure(figsize=(12, 4))
for i in range(8):
    # 元画像
    plt.subplot(2, 8, i + 1)
    plt.imshow(images[i].numpy().transpose((1, 2, 0)))
    plt.title("Original")
    plt.axis("off")

    # 再構成画像
    plt.subplot(2, 8, 8 + i + 1)
    plt.imshow(recon[i].numpy().transpose((1, 2, 0)))
    plt.title("Recon")
    plt.axis("off")

plt.tight_layout()
plt.show()


# ---------- 画像の連続性の確認 ----------
mu1, logvar1 = model.encode(images[0:1].to('cuda'))
mu2, logvar2 = model.encode(images[1:2].to('cuda'))

z1 = model.reparameterize(mu1, logvar1)
z2 = model.reparameterize(mu2, logvar2)

N = 6
plt.figure(figsize=(15, 4))
for i in range(N):
    z = z1*(i/(N-1)) + z2*(1-i/(N-1))
    y = model.decode(z)
    y = y.cpu().detach()[0]

    plt.subplot(1,N,i+1)
    plt.imshow(y.numpy().transpose((1, 2, 0)))
    plt.axis("off")


# ---------- ランダムな潜在ベクトルからの画像生成 ----------
N = 6
plt.figure(figsize=(12, 4))
for i in range(N):
    z = torch.randn(1,64).cuda()
    y = model.decode(z)
    y = y.cpu().detach()[0]

    plt.subplot(2,N,i+1)
    plt.imshow(y.numpy().transpose((1, 2, 0)))
    plt.axis("off")

for i in range(N):
    z = torch.randn(1,64).cuda()
    y = model.decode(z)
    y = y.cpu().detach()[0]

    plt.subplot(2,N,i+7)
    plt.imshow(y.numpy().transpose((1, 2, 0)))
    plt.axis("off")

まず、通常どおりEncoderとDecoderを持つVAEを用意します。次に、事前学習済みのVGG16またはVGG19を読み込み、こちらは学習させずに固定します。そのうえで、元画像とVAEの再構成画像の両方をVGGに通し、選んだ中間層の特徴マップを取り出します。最後に、その特徴マップ同士の差をPerceptual Lossとして計算し、VAEの損失へ加えます。

実装上のポイントは、次の3つです。

  • VGG側の重みは更新しない
  • 入力画像はVGGが想定する形式に合わせて正規化する
  • Perceptual Lossの重み λ\lambda を大きくしすぎない

λ\lambdaλ が大きすぎると、再構成画像が「特徴は似ているが色や全体バランスが崩れる」ことがあります。逆に小さすぎると、通常のVAEとの差がほとんど出ません。そのため、再構成誤差、KLダイバージェンス、Perceptual Lossの重みのバランス調整が重要になります。 

結果

はじめに、CIFAR-10に含まれる10クラスすべての画像をまとめて学習した結果を示します。Perceptual Lossなしのパターンと比較すると、輪郭や形状のまとまりが改善し、ぼやけが明確に軽減されたことが分かります。これは、画素の一致だけでなく、特徴空間での一致も同時に学習した効果と考えられます。

上: Perceptual lossなし。下: Perceptual lossあり

次に、潜在空間の連続性を確かめるために、2画像間の潜在ベクトルを補間した結果と、ランダムな潜在ベクトルから画像を生成した結果を確認します。2画像間の補間では、画像がなめらかに変化しており、潜在空間の連続性は保たれていることが分かりました。一方で、ランダムな潜在ベクトルから生成した画像は、形状はある程度出ているものの、まだ不明瞭なものが残りました。

CIFAR-10の全クラスを用いた実験。

上: 潜在空間の連続性の確認。下: ランダムな潜在画像からの生成

この理由としては、まず潜在ベクトルが64次元と比較的小さく、モデル自体もシンプルであることが考えられます。CIFAR-10は自然画像であり、同じクラスの中でも見た目のばらつきが大きいため、限られた表現力で全体をカバーするのは簡単ではありません。また、飛行機やカエルのように、クラス間の見た目の差が大きいデータを1つの連続潜在空間でまとめて扱うため、中間的な潜在ベクトルを取ったときに、やや曖昧な画像になりやすいと考えられます。

そこで、「馬」や「車」など、クラスを絞ったドメイン内でVAEを学習させた結果も確認しました。この場合は、画像がより鮮明に再構成され、ランダムな潜在ベクトルから生成した画像でも、形状をある程度識別できる結果が得られました。これは、学習対象のばらつきが小さくなり、潜在空間が表現すべき変動が限定されたためだと考えられます。

CIFAR-10の馬と車での実験。

上: 潜在空間の連続性の確認。下: ランダムな潜在画像からの生成

まとめ

本記事では、VAEで自然画像を扱うときに起こりやすいぼやけの問題に対して、Perceptual Lossを導入する考え方を説明しました。

Perceptual Lossは、元画像と生成画像を事前学習済みCNNに通し、その中間層の特徴の差を小さくするように学習する方法です。これにより、単なる画素一致では捉えにくいエッジや形状の情報を保ちやすくなり、VAEの再構成画像の見た目を改善しやすくなります。Perceptual Lossは画像変換や超解像で有効性が示されており、VAEへ応用した研究でも、より自然な見た目と高い知覚品質が報告されています。

今回の結果からも、CIFAR-10全体のようにばらつきの大きい自然画像では依然として難しさが残る一方で、Perceptual Lossを導入することで、通常のVAEよりもぼやけを改善しやすいことが分かります。特に、クラスを絞った学習では改善がより分かりやすく、研究用途でまず試しやすい拡張として有力です。

VAEのぼやけに悩んでいる場合は、まずPerceptual Lossを追加することを検討するとよいでしょう。実装の難易度は比較的低く、それでいて見た目の改善効果が得られやすい方法です。