AI & GPU
روش‌های ساده برای درک آسان GAN در PyTorch برای مبتدیان

روش‌های ساده برای درک آسان GAN در PyTorch برای مبتدیان

I. مقدمه‌ای بر شبکه‌های Adversarial Generation (GAN) A. تعریف و اجزای کلیدی GAN ها

  • GAN ها یک کلاس از مدل‌های یادگیری ماشین هستند که شامل دو شبکه عصبی، یک مولد و یک تمییزکننده است، که به صورت رقابتی آموزش می‌بینند.
  • شبکه مولد مسئول تولید نمونه‌های واقع‌گرایانه (مانند تصاویر، متن، صدا) از یک فضای ورودی پنهان است.
  • شبکه تمییزکننده برای تشخیص بین نمونه‌های واقعی از مجموعه داده و نمونه‌های جعلی توسط مولد آموزش می‌بیند.
  • این دو شبکه به صورت رقابتی و آنتاگونیستی آموزش می‌بینند، به طوری که مولد سعی در فریب دادن تمیزکننده دارد و تمیزکننده سعی می‌کند نمونه‌های واقعی و جعلی را به درستی دسته‌بندی کند.

B. تاریخچه مختصر و تکامل GAN ها

  • GAN ها در سال 2014 توسط ایان گودفلو و همکاران به عنوان یک رویکرد نوین برای مدلسازی تولید معرفی شدند.
  • از زمان معرفی آنها، GAN ها پیشرفت‌های قابل توجهی کرده و در زمینه‌های گسترده‌ای مانند تولید تصاویر، تولید متن و حتی سنتز صدا مورد استفاده قرار گرفته‌اند.
  • برخی از دستاوردهای کلیدی در تکامل GAN ها شامل معرفی GAN های شرطی (cGANs)، GAN های کانوالوشنی عمیق (DCGANs)، GAN های واسرشتین (WGANs) و GAN های رشد تدریجی (PGGANs) هستند.

II. راه اندازی محیط پایتورچ A. نصب پایتورچ

  • پایتورچ یک کتابخانه محبوب یادگیری ماشین منبع باز است که یک چارچوب قابل انعطاف و کارآمد برای ساخت و آموزش مدل‌های یادگیری عمیق، از جمله GAN ها، ارائه می‌کند.
  • برای نصب پایتورچ، می توانید دستورالعمل نصب رسمی را که در وب سایت پایتورچ (https://pytorch.org/get-started/locally/ (opens in a new tab)) ارائه شده است، دنبال کنید.
  • فرآیند نصب ممکن است بسته به سیستم عامل، نسخه پایتون و نسخه CUDA (در صورت استفاده از گرافیک) شما، متفاوت باشد.

B. وارد کردن کتابخانه‌ها و ماژول‌های لازم

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

III. درک ساختار GAN A. شبکه مولد

  1. ساختار ورودی و خروجی

    • شبکه مولد یک بردار ورودی پنهان (مانند یک بردار داده تصادفی) را می‌گیرد و یک نمونه تولید شده را (مانند یک تصویر) تولید می‌کند.
    • اندازه بردار ورودی پنهان و نمونه خروجی، بسته به مساله خاص و خروجی مورد نظر مشخص می‌شود.
  2. لایه‌ها و توابع فعال‌سازی شبکه

    • شبکه مولد معمولاً شامل یک سری لایه‌های کاملاً متصل یا کانوالوشنی است، بسته به حوزه مسئله.
    • توابع فعال‌سازی مانند ReLU، Leaky ReLU یا tanh به طور معمول در شبکه مولد استفاده می‌شوند.
  3. بهینه کردن مولد

    • شبکه مولد آموزش می‌بیند تا نمونه‌هایی تولید کند که می‌توانند تمیزکننده را گول بزنند.
    • تابع خطا برای مولد برای بیشینه کردن احتمال تمیزکننده در نادرست دسته‌بندی کردن نمونه‌های تولید شده به عنوان واقعی طراحی شده است.

B. شبکه تمییزکننده

  1. ساختار ورودی و خروجی

    • شبکه تمییزکننده یک نمونه (به طور واقعی از مجموعه داده یا تولید شده توسط مولد) را می‌گیرد و احتمال واقعی بودن نمونه را خروجی می‌دهد.
    • اندازه ورودی تمیزکننده بستگی به اندازه نمونه‌ها (مانند اندازه تصویر) دارد و خروجی یک مقدار عددی بین 0 و 1 است.
  2. لایه‌ها و توابع فعال‌سازی شبکه

    • شبکه تمیزکننده معمولاً شامل یک سری لایه‌های کانوالوشنی یا کاملاً متصل است، بسته به حوزه مسئله.
    • توابع فعال‌سازی مانند Leaky ReLU یا sigmoid به طور معمول در شبکه تمیزکننده استفاده می‌شوند.
  3. بهینه‌سازی تمیزکننده

    • شبکه تمیزکننده آموزش دیده است تا نمونه‌های واقعی را به درستی به عنوان واقعی و نمونه‌های تولید شده را به عنوان جعلی دسته‌بندی کند.
    • تابع خطا برای تمیزکننده طراحی شده است تا بیشینه احتمال درست دسته‌بندی کردن نمونه‌های واقعی و جعلی را داشته باشد.

C. فرآیند آموزش رقابتی

  1. توابع خطا برای مولد و تمیزکننده

    • تابع خطا مولد برای بیشینه کردن احتمال تمیزکننده در نادرست دسته‌بندی کردن نمونه‌های تولید شده به عنوان واقعی طراحی شده است.
    • تابع خطا تمیزکننده برای بیشینه کردن احتمال درست دسته‌بندی کردن نمونه‌های واقعی و جعلی طراحی شده است.
  2. بهینه‌سازی متناوب مولد و تمیزکننده

    • فرآیند آموزش شامل تناوب بین به‌روزرسانی شبکه‌های مولد و تمیزکننده است.
    • در ابتدا، تمیزکننده آموزش می‌بیند تا قدرت خود در تشخیص نمونه‌های واقعی و جعلی را بهبود ببخشد.
    • سپس، مولد آموزش می‌بیند تا قدرت خود در تولید نمونه‌هایی که تمیزکننده را فریب بدهند، بهبود ببخشد.
    • این فرآیند آموزش رقابتی تا زمانیکه مولد و تمیزکننده به تعادل برسند، ادامه می‌یابد.

IV. پیاده‌سازی یک GAN ساده در PyTorch A. تعریف مدل‌های مولد و تمیزکننده

  1. ساخت شبکه مولد

    class Generator(nn.Module):
        def __init__(self, latent_dim, img_shape):
            super(Generator, self).__init__()
            self.latent_dim = latent_dim
            self.img_shape = img_shape
     
            self.model = nn.Sequential(
                nn.Linear(self.latent_dim, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 1024),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, np.prod(self.img_shape)),
                nn.Tanh()
            )
     
        def forward(self, z):
            img = self.model(z)
            img = img.view(img.size(0), *self.img_shape)
            return img
  2. ساخت شبکه تمیزکننده

    class Discriminator(nn.Module):
        def __init__(self, img_shape):
            super(Discriminator, self).__init__()
            self.img_shape = img_shape
     
            self.model = nn.Sequential(
                nn.Linear(np.prod(self.img_shape), 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )
     
        def forward(self, img):
            img_flat = img.view(img.size(0), -1)
            validity = self.model(img_flat)
            return validity

B. تنظیم حلقه آموزش

  1. شروع مولد و تمیزکننده

    latent_dim = 100
    img_shape = (1, 28, 28)  # مثالی برای مجموعه داده MNIST
     
    generator = Generator(latent_dim, img_shape)
    discriminator = Discriminator(img_shape)
  2. تعریف تابع خطاها

    adversarial_loss = nn.BCELoss()
     
    def generator_loss(fake_output):
        return adversarial_loss(fake_output, torch.ones_like(fake_output))
     
    def discriminator_loss(real_output, fake_output):
        real_loss = adversarial_loss(real_output, torch.ones_like(real_output))
        fake_loss = adversarial_loss(fake_output, torch.zeros_like(fake_output))
        return (real_loss + fake_loss) / 2
  3. تناوب بهینه‌سازی مولد و تمیزکننده

    num_epochs = 200
    batch_size = 64
     
    # بهینه‌سازها
    generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
     
    for epoch in range(num_epochs):
        # آموزش تمیزکننده
        discriminator.zero_grad()
        real_samples = next(iter(dataloader))[0]
        real_output = discriminator(real_samples)
        fake_noise = torch.randn(batch_size, latent_dim)
        fake_samples = generator(fake_noise)
        fake_output = discriminator(fake_samples.detach())
        d_loss = discriminator_loss(real_output, fake_output)
        d_loss.backward()
        discriminator_optimizer.step()
     
        # آموزش مولد
        generator.zero_grad()
        fake_noise = torch.randn(batch_size, latent_dim)
        fake_samples = generator(fake_noise)
        fake_output = discriminator(fake_samples)
        g_loss = generator_loss(fake_output)
        g_loss.backward()
        generator_optimizer.step()

C. نظارت بر پیشرفت آموزش

  1. تصویرسازی نمونه‌های تولیدی

    # تولید نمونه‌ها و نمایش آنها
    fake_noise = torch.randn(64, latent_dim)
    fake_samples = generator(fake_noise)
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.imshow(np.transpose(vutils.make_grid(fake_samples.detach()[:64], padding=2, normalize=True), (1, 2, 0)))
    plt.show()
  2. ارزیابی عملکرد GAN

    • ارزیابی عملکرد یک GAN می‌تواند چالش‌برانگیز باشد، زیرا هیچ معیاری وجود ندارد که تمام جنبه‌های نمونه‌های تولیدی را در بر بگیرد.
    • معیارهای معمولاً استفاده شده شامل Inception Score (IS) و Fréchet Inception Distance (FID) می‌باشند که کیفیت و گستردگی نمونه‌های تولیدی را اندازه‌گیری می‌کنند.

V. شبکه‌های Adversarial Generation شرطی (cGANs) A. دلایل و کاربردهای GAN های شرطی- GANهای شرطی (cGANها) توسعه ای برای چارچوب معمول GAN هستند که امکان تولید نمونه هایی را با توجه به اطلاعات ورودی خاص، مانند برچسب های کلاس، توصیفات متنی یا داده های کمکی دیگر فراهم می کنند.

  • cGANها در برنامه هایی که می خواهید نمونه هایی با ویژگی ها یا ویژگی های خاص تولید کنید، مانند تولید تصاویری از یک کلاس جسم خاص یا ترجمه متن به تصویر، مفید است.

B. تغییر معماری GAN برای تولید شرطی

  1. یکپارچه کردن اطلاعات برچسب در Generator و Discriminator

    • در یک cGAN، شبکه های generator و discriminator طوری تغییر یافته اند که ورودی اضافی ای داشته باشند که اطلاعات شرطی (مانند برچسب کلاس، توصیف متنی) را شامل می شود.
    • این می تواند با اتصال ورودی شرطی با ورودی مخفی برای generator و با نمونه واقعی / تقلبی برای discriminator انجام شود.
  2. تعریف تابع های خسارت برای cGAN ها

    • تابع های خسارت برای generator و discriminator در یک cGAN شبیه به GAN استاندارد هستند، اما همچنین اطلاعات شرطی را در نظر می گیرند.
    • به عنوان مثال، هدف از تابع خسارت discriminator این است که نمونه های واقعی و نمونه های تقلبی را با توجه به اطلاعات برچسب ارائه شده درست طبقه بندی کند.

C. پیاده سازی یک cGAN در PyTorch

  1. تعریف مدل های cGAN
    class ConditionalGenerator(nn.Module):
        def __init__(self, latent_dim, num_classes, img_shape):
            super(ConditionalGenerator, self).__init__()
            self.latent_dim = latent_dim
            self.num_classes = num_classes
            self.img_shape = img_shape
     
            self.model = nn.Sequential(
                nn.Linear(self.latent_dim + self.num_classes, 256),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(256, 512),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(512, 1024),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Linear(1024, np.prod(self.img_shape)),
                nn.Tanh()
            )
     
        def forward(self, z, labelsدر این آموزش، شما درباره اجزای اصلی فرایند آموزش مدل‌های یادگیری عمیق مانند بهینه‌سازها، توابع خطا، معیارهای ارزیابی، تکنیک‌های تنظیم مجازی سازی و ذخیره و بارگیری مدل‌ها آموختید. با درک این مفاهیم و استفاده از آنها در پروژه‌های یادگیری عمیق خود، به راحتی می‌توانید مدل‌هایی با عملکرد بالا ایجاد و آموزش دهید که قادر به حل یک دامنه گسترده از مسائل باشند.

به خاطر داشته باشید که یادگیری عمیق یک حوزه در حال تکامل است و همیشه مطالب جدیدی برای یادگیری وجود دارد. بهره‌برداری از این فیلد را ادامه دهید، آزمایش‌ها را انجام دهید و با پیشرفت‌های جدیدترین مطالعات در این حوزه، خوش‌شانس باشید. موفقیت در پیشه‌یادگیری عمیق خودتان را آرزومندیم!