روشهای ساده برای درک آسان GAN در PyTorch برای مبتدیان
I. مقدمهای بر شبکههای Adversarial Generation (GAN) A. تعریف و اجزای کلیدی GAN ها
- GAN ها یک کلاس از مدلهای یادگیری ماشین هستند که شامل دو شبکه عصبی، یک مولد و یک تمییزکننده است، که به صورت رقابتی آموزش میبینند.
- شبکه مولد مسئول تولید نمونههای واقعگرایانه (مانند تصاویر، متن، صدا) از یک فضای ورودی پنهان است.
- شبکه تمییزکننده برای تشخیص بین نمونههای واقعی از مجموعه داده و نمونههای جعلی توسط مولد آموزش میبیند.
- این دو شبکه به صورت رقابتی و آنتاگونیستی آموزش میبینند، به طوری که مولد سعی در فریب دادن تمیزکننده دارد و تمیزکننده سعی میکند نمونههای واقعی و جعلی را به درستی دستهبندی کند.
B. تاریخچه مختصر و تکامل GAN ها
- GAN ها در سال 2014 توسط ایان گودفلو و همکاران به عنوان یک رویکرد نوین برای مدلسازی تولید معرفی شدند.
- از زمان معرفی آنها، GAN ها پیشرفتهای قابل توجهی کرده و در زمینههای گستردهای مانند تولید تصاویر، تولید متن و حتی سنتز صدا مورد استفاده قرار گرفتهاند.
- برخی از دستاوردهای کلیدی در تکامل GAN ها شامل معرفی GAN های شرطی (cGANs)، GAN های کانوالوشنی عمیق (DCGANs)، GAN های واسرشتین (WGANs) و GAN های رشد تدریجی (PGGANs) هستند.
II. راه اندازی محیط پایتورچ A. نصب پایتورچ
- پایتورچ یک کتابخانه محبوب یادگیری ماشین منبع باز است که یک چارچوب قابل انعطاف و کارآمد برای ساخت و آموزش مدلهای یادگیری عمیق، از جمله GAN ها، ارائه میکند.
- برای نصب پایتورچ، می توانید دستورالعمل نصب رسمی را که در وب سایت پایتورچ (https://pytorch.org/get-started/locally/ (opens in a new tab)) ارائه شده است، دنبال کنید.
- فرآیند نصب ممکن است بسته به سیستم عامل، نسخه پایتون و نسخه CUDA (در صورت استفاده از گرافیک) شما، متفاوت باشد.
B. وارد کردن کتابخانهها و ماژولهای لازم
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
III. درک ساختار GAN A. شبکه مولد
-
ساختار ورودی و خروجی
- شبکه مولد یک بردار ورودی پنهان (مانند یک بردار داده تصادفی) را میگیرد و یک نمونه تولید شده را (مانند یک تصویر) تولید میکند.
- اندازه بردار ورودی پنهان و نمونه خروجی، بسته به مساله خاص و خروجی مورد نظر مشخص میشود.
-
لایهها و توابع فعالسازی شبکه
- شبکه مولد معمولاً شامل یک سری لایههای کاملاً متصل یا کانوالوشنی است، بسته به حوزه مسئله.
- توابع فعالسازی مانند ReLU، Leaky ReLU یا tanh به طور معمول در شبکه مولد استفاده میشوند.
-
بهینه کردن مولد
- شبکه مولد آموزش میبیند تا نمونههایی تولید کند که میتوانند تمیزکننده را گول بزنند.
- تابع خطا برای مولد برای بیشینه کردن احتمال تمیزکننده در نادرست دستهبندی کردن نمونههای تولید شده به عنوان واقعی طراحی شده است.
B. شبکه تمییزکننده
-
ساختار ورودی و خروجی
- شبکه تمییزکننده یک نمونه (به طور واقعی از مجموعه داده یا تولید شده توسط مولد) را میگیرد و احتمال واقعی بودن نمونه را خروجی میدهد.
- اندازه ورودی تمیزکننده بستگی به اندازه نمونهها (مانند اندازه تصویر) دارد و خروجی یک مقدار عددی بین 0 و 1 است.
-
لایهها و توابع فعالسازی شبکه
- شبکه تمیزکننده معمولاً شامل یک سری لایههای کانوالوشنی یا کاملاً متصل است، بسته به حوزه مسئله.
- توابع فعالسازی مانند Leaky ReLU یا sigmoid به طور معمول در شبکه تمیزکننده استفاده میشوند.
-
بهینهسازی تمیزکننده
- شبکه تمیزکننده آموزش دیده است تا نمونههای واقعی را به درستی به عنوان واقعی و نمونههای تولید شده را به عنوان جعلی دستهبندی کند.
- تابع خطا برای تمیزکننده طراحی شده است تا بیشینه احتمال درست دستهبندی کردن نمونههای واقعی و جعلی را داشته باشد.
C. فرآیند آموزش رقابتی
-
توابع خطا برای مولد و تمیزکننده
- تابع خطا مولد برای بیشینه کردن احتمال تمیزکننده در نادرست دستهبندی کردن نمونههای تولید شده به عنوان واقعی طراحی شده است.
- تابع خطا تمیزکننده برای بیشینه کردن احتمال درست دستهبندی کردن نمونههای واقعی و جعلی طراحی شده است.
-
بهینهسازی متناوب مولد و تمیزکننده
- فرآیند آموزش شامل تناوب بین بهروزرسانی شبکههای مولد و تمیزکننده است.
- در ابتدا، تمیزکننده آموزش میبیند تا قدرت خود در تشخیص نمونههای واقعی و جعلی را بهبود ببخشد.
- سپس، مولد آموزش میبیند تا قدرت خود در تولید نمونههایی که تمیزکننده را فریب بدهند، بهبود ببخشد.
- این فرآیند آموزش رقابتی تا زمانیکه مولد و تمیزکننده به تعادل برسند، ادامه مییابد.
IV. پیادهسازی یک GAN ساده در PyTorch A. تعریف مدلهای مولد و تمیزکننده
-
ساخت شبکه مولد
class Generator(nn.Module): def __init__(self, latent_dim, img_shape): super(Generator, self).__init__() self.latent_dim = latent_dim self.img_shape = img_shape self.model = nn.Sequential( nn.Linear(self.latent_dim, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 1024), nn.LeakyReLU(0.2, inplace=True), nn.Linear(1024, np.prod(self.img_shape)), nn.Tanh() ) def forward(self, z): img = self.model(z) img = img.view(img.size(0), *self.img_shape) return img
-
ساخت شبکه تمیزکننده
class Discriminator(nn.Module): def __init__(self, img_shape): super(Discriminator, self).__init__() self.img_shape = img_shape self.model = nn.Sequential( nn.Linear(np.prod(self.img_shape), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, img): img_flat = img.view(img.size(0), -1) validity = self.model(img_flat) return validity
B. تنظیم حلقه آموزش
-
شروع مولد و تمیزکننده
latent_dim = 100 img_shape = (1, 28, 28) # مثالی برای مجموعه داده MNIST generator = Generator(latent_dim, img_shape) discriminator = Discriminator(img_shape)
-
تعریف تابع خطاها
adversarial_loss = nn.BCELoss() def generator_loss(fake_output): return adversarial_loss(fake_output, torch.ones_like(fake_output)) def discriminator_loss(real_output, fake_output): real_loss = adversarial_loss(real_output, torch.ones_like(real_output)) fake_loss = adversarial_loss(fake_output, torch.zeros_like(fake_output)) return (real_loss + fake_loss) / 2
-
تناوب بهینهسازی مولد و تمیزکننده
num_epochs = 200 batch_size = 64 # بهینهسازها generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) for epoch in range(num_epochs): # آموزش تمیزکننده discriminator.zero_grad() real_samples = next(iter(dataloader))[0] real_output = discriminator(real_samples) fake_noise = torch.randn(batch_size, latent_dim) fake_samples = generator(fake_noise) fake_output = discriminator(fake_samples.detach()) d_loss = discriminator_loss(real_output, fake_output) d_loss.backward() discriminator_optimizer.step() # آموزش مولد generator.zero_grad() fake_noise = torch.randn(batch_size, latent_dim) fake_samples = generator(fake_noise) fake_output = discriminator(fake_samples) g_loss = generator_loss(fake_output) g_loss.backward() generator_optimizer.step()
C. نظارت بر پیشرفت آموزش
-
تصویرسازی نمونههای تولیدی
# تولید نمونهها و نمایش آنها fake_noise = torch.randn(64, latent_dim) fake_samples = generator(fake_noise) plt.figure(figsize=(8, 8)) plt.axis("off") plt.imshow(np.transpose(vutils.make_grid(fake_samples.detach()[:64], padding=2, normalize=True), (1, 2, 0))) plt.show()
-
ارزیابی عملکرد GAN
- ارزیابی عملکرد یک GAN میتواند چالشبرانگیز باشد، زیرا هیچ معیاری وجود ندارد که تمام جنبههای نمونههای تولیدی را در بر بگیرد.
- معیارهای معمولاً استفاده شده شامل Inception Score (IS) و Fréchet Inception Distance (FID) میباشند که کیفیت و گستردگی نمونههای تولیدی را اندازهگیری میکنند.
V. شبکههای Adversarial Generation شرطی (cGANs) A. دلایل و کاربردهای GAN های شرطی- GANهای شرطی (cGANها) توسعه ای برای چارچوب معمول GAN هستند که امکان تولید نمونه هایی را با توجه به اطلاعات ورودی خاص، مانند برچسب های کلاس، توصیفات متنی یا داده های کمکی دیگر فراهم می کنند.
- cGANها در برنامه هایی که می خواهید نمونه هایی با ویژگی ها یا ویژگی های خاص تولید کنید، مانند تولید تصاویری از یک کلاس جسم خاص یا ترجمه متن به تصویر، مفید است.
B. تغییر معماری GAN برای تولید شرطی
-
یکپارچه کردن اطلاعات برچسب در Generator و Discriminator
- در یک cGAN، شبکه های generator و discriminator طوری تغییر یافته اند که ورودی اضافی ای داشته باشند که اطلاعات شرطی (مانند برچسب کلاس، توصیف متنی) را شامل می شود.
- این می تواند با اتصال ورودی شرطی با ورودی مخفی برای generator و با نمونه واقعی / تقلبی برای discriminator انجام شود.
-
تعریف تابع های خسارت برای cGAN ها
- تابع های خسارت برای generator و discriminator در یک cGAN شبیه به GAN استاندارد هستند، اما همچنین اطلاعات شرطی را در نظر می گیرند.
- به عنوان مثال، هدف از تابع خسارت discriminator این است که نمونه های واقعی و نمونه های تقلبی را با توجه به اطلاعات برچسب ارائه شده درست طبقه بندی کند.
C. پیاده سازی یک cGAN در PyTorch
- تعریف مدل های cGAN
class ConditionalGenerator(nn.Module): def __init__(self, latent_dim, num_classes, img_shape): super(ConditionalGenerator, self).__init__() self.latent_dim = latent_dim self.num_classes = num_classes self.img_shape = img_shape self.model = nn.Sequential( nn.Linear(self.latent_dim + self.num_classes, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 1024), nn.LeakyReLU(0.2, inplace=True), nn.Linear(1024, np.prod(self.img_shape)), nn.Tanh() ) def forward(self, z, labelsدر این آموزش، شما درباره اجزای اصلی فرایند آموزش مدلهای یادگیری عمیق مانند بهینهسازها، توابع خطا، معیارهای ارزیابی، تکنیکهای تنظیم مجازی سازی و ذخیره و بارگیری مدلها آموختید. با درک این مفاهیم و استفاده از آنها در پروژههای یادگیری عمیق خود، به راحتی میتوانید مدلهایی با عملکرد بالا ایجاد و آموزش دهید که قادر به حل یک دامنه گسترده از مسائل باشند.
به خاطر داشته باشید که یادگیری عمیق یک حوزه در حال تکامل است و همیشه مطالب جدیدی برای یادگیری وجود دارد. بهرهبرداری از این فیلد را ادامه دهید، آزمایشها را انجام دهید و با پیشرفتهای جدیدترین مطالعات در این حوزه، خوششانس باشید. موفقیت در پیشهیادگیری عمیق خودتان را آرزومندیم!