VAEとは?MNISTとCIFAR-10で学ぶ画像生成の基礎

カテゴリ: AI・機械学習, Python実装

VAE(変分オートエンコーダー、Variational Autoencoder)は、入力データを圧縮して特徴を抽出し、その情報をもとにデータを再構成する機械学習モデルです。次元圧縮、画像生成、異常検知、データ補完など、さまざまな用途で使われています。

本記事では、手書き文字データセットであるMNISTと、車や馬などの自然画像を含むCIFAR-10を用いて、VAEによる画像生成の基本を分かりやすく説明します。また、なぜCIFAR-10では画像がぼやけやすいのか、そしてその改善方法についても解説します。

1. はじめに通常のAE(Autoencoder、オートエンコーダー)の説明

VAEを理解するには、まず通常のAE(Autoencoder)を知っておくと分かりやすくなります。

AEは、入力データをいったん圧縮し、そこから元のデータを再構成するモデルです。構造は大きく次の2つに分かれます。

  • Encoder:入力データを圧縮して潜在変数に変換する
  • Decoder:潜在変数から元のデータを再構成する

たとえば画像を入力すると、Encoderが画像の特徴を少数の値にまとめ、Decoderがその値から元画像に近い画像を再現します。

この仕組みによって、AEは画像の特徴をうまく圧縮できます。しかし、通常のAEには弱点があります。潜在空間が単に「圧縮された特徴の置き場」になりやすく、潜在変数を少し変えたときに、意味のある変化が起きるとは限らないのです。つまり、画像生成に使いやすい滑らかな潜在空間にならないことがあります。

この弱点を改善したのがVAEです。

2. VAE(変分オートエンコーダー、Variational Autoencoder)の説明

VAEは、AEを発展させたモデルで、潜在空間に連続性を持たせるように設計されています。通常のAEでは、Encoderは入力画像をそのまま1つの潜在ベクトルに変換しますが、VAEでは少し考え方が異なります。

VAEのEncoderは、潜在変数を1点として出力するのではなく、平均 μ(ミュー)と分散 σ²(または標準偏差)を出力し、その周辺の分布として潜在空間を表現します。つまり、「この画像は潜在空間のこのあたりにありそうだ」と確率的に表すわけです。

このようにすると、潜在空間が滑らかになり、潜在変数を少しずつ変化させたときにも自然に画像が変わりやすくなります。これが、VAEが画像生成に向いている理由の1つです。

Reparameterization Trickとは何か

ただし、ここで1つ問題があります。VAEでは分布から乱数をサンプリングする必要がありますが、そのままでは誤差逆伝播による学習がしにくくなります。

そこで使われるのがReparameterization Trickです。

VAEでは、潜在変数 z を直接ランダムにサンプリングする代わりに、次のように表します。

z=μ+σϵ

ここで、ϵは標準正規分布 N(0,1)からサンプリングした乱数です。こうすることで、ランダム性を ϵに分離しつつ、μとσに対して勾配を流せるようになります。これにより、VAEを通常のニューラルネットワークと同じように学習できます。

また、VAEの損失関数は大きく2つの項からなります。

  • 再構成誤差:入力画像をどれだけうまく再現できたか
  • KLダイバージェンス:潜在分布が標準正規分布に近くなるようにする項

この2つのバランスによって、VAEは「再構成のうまさ」と「潜在空間の扱いやすさ」を両立しようとします。

3. VAEをMNISTで学習する

まず、VAEをMNISTで学習させることを考えます。MNISTは0から9までの手書き数字画像からなるデータセットで、白黒かつ構造が比較的単純です。そのため、VAEの入門に非常に向いています。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np


# ---------- 計算機の確認 ----------
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# ---------- データセットの準備 ----------
transform = transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("学習データ数:", len(train_dataset))
print("テストデータ数:", len(test_dataset))


# ---------- データの確認 ----------
images, labels = next(iter(train_loader))

print("images shape:", images.shape)
print("labels shape:", labels.shape)
print("最初のラベル:", labels[0].item())

plt.figure(figsize=(10, 4))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    plt.imshow(images[i].squeeze(), cmap="gray")
    plt.title(f"label: {labels[i].item()}")
    plt.axis("off")
plt.tight_layout()
plt.show()


# ---------- VAEモデルの作成 ----------
class CNNVAE(nn.Module):
    def __init__(self, latent_dim=16):
        super().__init__()
        self.latent_dim = latent_dim

        # ===== Encoder =====
        self.enc_conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)   # 28x28 -> 28x28
        self.enc_pool1 = nn.MaxPool2d(2)                              # 28x28 -> 14x14

        self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 14x14 -> 14x14
        self.enc_pool2 = nn.MaxPool2d(2)                              # 14x14 -> 7x7

        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

        # ===== Decoder =====
        self.fc_dec = nn.Linear(latent_dim, 64 * 7 * 7)

        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")        # 7x7 -> 14x14
        self.dec_conv1 = nn.Conv2d(64, 32, kernel_size=3, padding=1)

        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")        # 14x14 -> 28x28
        self.dec_conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)

        self.dec_conv3 = nn.Conv2d(16, 1, kernel_size=3, padding=1)

    def encode(self, x):
        # 入力: [N, 1, 28, 28]
        x = F.relu(self.enc_conv1(x))       # [N, 32, 28, 28]
        x = self.enc_pool1(x)               # [N, 32, 14, 14]

        x = F.relu(self.enc_conv2(x))       # [N, 64, 14, 14]
        x = self.enc_pool2(x)               # [N, 64, 7, 7]

        x = torch.flatten(x, start_dim=1)   # [N, 64*7*7]
        mu = self.fc_mu(x)                  # [N, latent_dim(default=16)]
        logvar = self.fc_logvar(x)          # [N, latent_dim(default=16)]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        # 入力: [N, latent_dim(default=16)]
        x = self.fc_dec(z)                  # [N, 64*7*7]
        x = x.view(-1, 64, 7, 7)            # [N, 64, 7, 7]

        x = self.up1(x)                     # [N, 64, 14, 14]
        x = F.relu(self.dec_conv1(x))       # [N, 32, 14, 14]

        x = self.up2(x)                     # [N, 32, 28, 28]
        x = F.relu(self.dec_conv2(x))       # [N, 1, 28, 28]

        x = torch.sigmoid(self.dec_conv3(x))
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

# ---------- 学習モデルの作成 ----------
latent_dim = 8 # 潜在次元
model = CNNVAE(latent_dim=latent_dim).to(device)
print(model)


# ---------- 損失誤差(再構成誤差 + KLダイバージェンス) ----------
def vae_loss_function(recon_x, x, mu, logvar):
    # 再構成誤差
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")

    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss


# ---------- オプティマイザー ----------
optimizer = optim.Adam(model.parameters(), lr=0.001)


# ---------- 学習 ----------
num_epochs = 10

for epoch in range(num_epochs):
    model.train()

    train_loss = 0.0
    train_recon = 0.0
    train_kl = 0.0

    for images, _ in train_loader:
        images = images.to(device)

        optimizer.zero_grad()

        recon, mu, logvar = model(images)
        loss, recon_loss, kl_loss = vae_loss_function(recon, images, mu, logvar)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_recon += recon_loss.item()
        train_kl += kl_loss.item()

    avg_loss = train_loss / len(train_dataset)
    avg_recon = train_recon / len(train_dataset)
    avg_kl = train_kl / len(train_dataset)

    print(
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | KL: {avg_kl:.4f}"
    )


# ---------- 画像の再構成テスト ----------
model.eval()

images, _ = next(iter(test_loader))
images = images[:8].to(device)

with torch.no_grad():
    recon, mu, logvar = model(images)

images = images.cpu()
recon = recon.cpu()

plt.figure(figsize=(12, 4))
for i in range(8):
    # 元画像
    plt.subplot(2, 8, i + 1)
    plt.imshow(images[i].squeeze(), cmap="gray")
    plt.title("Original")
    plt.axis("off")

    # 再構成画像
    plt.subplot(2, 8, 8 + i + 1)
    plt.imshow(recon[i].squeeze(), cmap="gray")
    plt.title("Recon")
    plt.axis("off")

plt.tight_layout()
plt.show()


# ---------- 潜在ベクトルから画像生成 ----------
model.eval()

with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    samples = model.decode(z).cpu()

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(samples[i].squeeze(), cmap="gray")
    plt.axis("off")
plt.tight_layout()
plt.show()

MNISTでVAEを学習すると、潜在空間の次元が2次元でも、ある程度は画像を再構成できることが分かります。さらに4次元、8次元と潜在次元を増やしていくと、数字の輪郭や細部もより安定して再構成されるようになります。

上: 元画像、下: 再構成画像

これは、MNISTが比較的単純なデータセットだからです。数字の種類は10パターンで、背景もほぼ一定です。そのため、潜在空間に必要な情報量が比較的少なく、VAEでも特徴を捉えやすいのです。

4. VAEをCIFAR-10で学習する

次に、VAEをCIFAR-10で学習させます。CIFAR-10は、自動車、馬、鳥、猫などを含む自然画像データセットです。MNISTと同じく10クラスですが、画像の難しさは大きく異なります。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np


# ---------- 計算機の確認 ----------
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# ---------- データセットの準備 ----------
transform = transforms.ToTensor()

train_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("学習データ数:", len(train_dataset))
print("テストデータ数:", len(test_dataset))


# ---------- データの確認 ----------
images, labels = next(iter(train_loader))

print("images shape:", images.shape)
print("labels shape:", labels.shape)
print("最初のラベル:", labels[0].item())

plt.figure(figsize=(10, 4))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    plt.imshow(images[i].numpy().transpose((1, 2, 0)))
    plt.title(f"label: {labels[i].item()}")
    plt.axis("off")
plt.tight_layout()
plt.show()


# ---------- VAEモデルの作成 ----------
class CNNVAE(nn.Module):
    def __init__(self, latent_dim=16):
        super().__init__()
        self.latent_dim = latent_dim

        # ===== Encoder =====
        self.enc_conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)   # 32x32 -> 32x32
        self.enc_pool1 = nn.MaxPool2d(2)                              # 32x32 -> 16x16

        self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 16x16 -> 16x16
        self.enc_pool2 = nn.MaxPool2d(2)                              # 16x16 -> 8x8

        self.fc_mu = nn.Linear(64 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(64 * 8 * 8, latent_dim)

        # ===== Decoder =====
        self.fc_dec = nn.Linear(latent_dim, 64 * 8 * 8)

        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")        # 8x8 -> 16x16
        self.dec_conv1 = nn.Conv2d(64, 32, kernel_size=3, padding=1)

        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")        # 16x16 -> 32x32
        self.dec_conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)

        self.dec_conv3 = nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def encode(self, x):
        # 入力: [N, 1, 28, 28]
        x = F.relu(self.enc_conv1(x))       # [N, 32, 32, 32]
        x = self.enc_pool1(x)               # [N, 32, 16, 16]

        x = F.relu(self.enc_conv2(x))       # [N, 64, 16, 16]
        x = self.enc_pool2(x)               # [N, 64, 8, 8]

        x = torch.flatten(x, start_dim=1)   # [N, 64*8*8]
        mu = self.fc_mu(x)                  # [N, latent_dim(default=16)]
        logvar = self.fc_logvar(x)          # [N, latent_dim(default=16)]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

    def decode(self, z):
        # 入力: [N, latent_dim(default=16)]
        x = self.fc_dec(z)                  # [N, 64*8*8]
        x = x.view(-1, 64, 8, 8)            # [N, 64, 8, 8]

        x = self.up1(x)                     # [N, 64, 16, 16]
        x = F.relu(self.dec_conv1(x))       # [N, 32, 16, 16]

        x = self.up2(x)                     # [N, 32, 32, 32]
        x = F.relu(self.dec_conv2(x))       # [N, 3, 32, 32]

        x = torch.sigmoid(self.dec_conv3(x))
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

# ---------- 学習モデルの作成 ----------
latent_dim = 64
model = CNNVAE(latent_dim=latent_dim).to(device)
print(model)


# ---------- 損失誤差(再構成誤差 + KLダイバージェンス)----------
def vae_loss_function(recon_x, x, mu, logvar):
    # 再構成誤差
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")

    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss


# ---------- オプティマイザー ----------
optimizer = optim.Adam(model.parameters(), lr=0.001)


# ---------- 学習 ----------
train_losses = []
train_recon_losses = []
train_kl_losses = []

num_epochs = 100

for epoch in range(num_epochs):
    model.train()

    train_loss = 0.0
    train_recon = 0.0
    train_kl = 0.0

    for images, _ in train_loader:
        images = images.to(device)

        optimizer.zero_grad()

        recon, mu, logvar = model(images)
        loss, recon_loss, kl_loss = vae_loss_function(recon, images, mu, logvar)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_recon += recon_loss.item()
        train_kl += kl_loss.item()

    avg_loss = train_loss / len(train_dataset)
    avg_recon = train_recon / len(train_dataset)
    avg_kl = train_kl / len(train_dataset)

    train_losses.append(avg_loss)
    train_recon_losses.append(avg_recon)
    train_kl_losses.append(avg_kl)

    print(
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Loss: {avg_loss:.4f} | Recon: {avg_recon:.4f} | KL: {avg_kl:.4f}"
    )


# ---------- 画像の再構成テスト ----------
model.eval()

images, _ = next(iter(test_loader))
images = images[:8].to(device)

with torch.no_grad():
    recon, mu, logvar = model(images)

images = images.cpu()
recon = recon.cpu()

plt.figure(figsize=(12, 4))
for i in range(8):
    # 元画像
    plt.subplot(2, 8, i + 1)
    plt.imshow(images[i].numpy().transpose((1, 2, 0)))
    plt.title("Original")
    plt.axis("off")

    # 再構成画像
    plt.subplot(2, 8, 8 + i + 1)
    plt.imshow(recon[i].numpy().transpose((1, 2, 0)))
    plt.title("Recon")
    plt.axis("off")

plt.tight_layout()
plt.show()

CIFAR-10でVAEを学習すると、潜在空間の次元を増やすにつれて、たしかに画像の再構成精度は向上します。しかし、ある程度以上は改善が頭打ちになり、画像は全体的にぼやけやすいという結果になりがちです。

上: 元画像、下: 再構成画像。

潜在次元=64、エンコーダーCNN層深さ=2、デコーダーCNN層深さ=2のベストなもの

たとえば潜在次元を増やしたり、層を深くしたりすると多少の改善は見られますが、MNISTのようにくっきりした画像を再現するのは簡単ではありません。

最適化結果。潜在次元が64以上は大きく改善しない。層の深さも大幅に寄与しない。

5. なぜCIFAR-10ではぼやけた画像になるのか

その理由は、MNISTとCIFAR-10ではタスクの難しさが大きく異なるためです。

  • MNIST:白黒で単純な手書き数字を再構成すればよい
  • CIFAR-10:物体の形、色、背景、視点が大きく異なる自然画像を再構成する必要がある

MNISTでは、画像のパターンがある程度限定されています。一方でCIFAR-10では、同じクラスでも見た目が大きく異なり、背景や構図もさまざまです。つまり、自然画像は情報量が非常に多いのです。

そのため、CIFAR-10のような複雑な画像を限られた次元の潜在変数に圧縮してから再構成しようとすると、どうしても細かな情報が落ちやすくなります。

また、VAEでは再構成誤差として画素ごとの差を用いることが多く、この場合、複数のあり得る細部を平均したような出力になりやすいという特徴があります。これも、画像がぼやける大きな理由です。

さらに、VAEでは潜在空間を滑らかで扱いやすい形に保つため、KLダイバージェンスによる制約も加えています。この制約は生成には有利ですが、再構成画像の細かさだけを見ると不利に働くことがあります。つまり、生成しやすさと鮮明さの間にはトレードオフがあるのです。

6. VAEのぼやけをどのように克服するか?

VAEのぼやけを改善する方法はいくつかあります。

  • β-VAEやKL項の重み調整:KLダイバージェンスの重みを調整し、潜在空間の性質と再構成のしやすさのバランスを取る
  • Perceptual Lossの導入:画素ごとの差だけでなく、画像の特徴の差も見ることで、見た目を自然にする
  • VQ-VAE:連続的な潜在空間ではなく、離散的な潜在表現を使って、よりはっきりした表現を学習しやすくする
  • Diffusionモデル:ノイズ除去を繰り返して高品質な画像を生成する手法で、VAEとは別系統の代表的な生成モデル

7. 研究者が使う画像系タスクでは何がおすすめか

研究用途でVAEを使う場合、特に化学や材料系の研究では、潜在変数を連続ベクトルとして扱いやすいことが大きな利点です。潜在空間の解析や、潜在変数を使った最適化を考えるなら、この性質は非常に便利です。

そのため、研究者がまず試しやすい方法としては、次の2つがおすすめです。

  • KL係数の調整:実装が簡単で、再構成とのバランスを取りやすい
  • Perceptual Lossの導入:見た目の自然さを改善しやすい

一方で、VQ-VAEやDiffusionモデルは高品質な画像生成には有力ですが、モデルの構造や運用が複雑になりやすく、潜在空間の解釈や操作を重視する研究では扱いにくい場面もあります。 

例 Perceptual lossを導入したCNN-VAE。明確に鮮明に再構成されている。

8. まとめ

本記事では、VAEをMNISTとCIFAR-10で学習させたときの違いを通して、VAEの基本と限界を説明しました。

MNISTでは画像の構造が比較的単純であるため、VAEでも少ない潜在次元でうまく再構成できます。一方、CIFAR-10のような自然画像では、形状や背景の多様性が大きく、限られた潜在空間に情報を押し込めるのが難しいため、画像がぼやけやすくなります。

これは単なる実装の問題ではなく、VAEの「圧縮しながら滑らかな潜在空間を作る」という性質そのものに由来する課題です。

その改善方法としては、KL係数の調整、Perceptual Lossの導入、VQ-VAE、Diffusionモデルなどがあります。特に研究用途では、実装のしやすさと潜在空間の扱いやすさのバランスから、まずはKL係数の調整とPerceptual Lossを試すのが現実的です。

VAEはシンプルで理解しやすい一方で、生成モデルとしての重要な考え方が多く詰まっています。画像生成AIを学ぶ最初の一歩として、今でも非常に有用なモデルです。