网易首页 > 网易号 > 正文 申请入驻

使用Modulated Convolutions修改 StarGAN V2

0
分享至

在本教程中,我们将替换 StarGAN V2 模型中的自适应实例归一化(AdaIN)层,并在分辨率为 512x512 像素的图像上对其进行训练。

今天,有许多模型可以生成高质量的图像。具体来说,对于属性交换任务(2021 年),最好的质量是由 StyleGAN 的进一步发展的模型提供的,或者是通过对其进行提炼而获得的,这需要大量的计算时间来训练新的数据域。在单个 Google Colab GPU 上从头开始训练 24 小时后,所提出的模型会生成文章开头所示的图像。

StarGAN V2[1] 是一个图像到图像模型,它使用由条件编码器管理的 AdaIN 层来传输图像风格。它分别使用有关对象结构及其纹理的信息,从而允许用户获得组合图像。

StarGAN 与图像生成相关的部分如下图所示。它们包括 ResNet-like[2] 编码器——用绿色标记,解码器带有 AdainResBlk 模块(将在下面描述)——紫色,以及一组具有共享头层的条件相关样式信息编码器(灰蓝色)——用绿松石标记。

StarGAN 的工作原理如下。 一开始,风格编码器从图像中提取低级特征。 然后生成器编码对象的几何信息并将其提供给 AdainResBlk 模块的金字塔。

每个 AdainResBlk 块都包含 StyleGAN 的自适应实例归一化(AdaIN)模块 [3],它通过从样式编码器接收到的信息来调制抽象对象的几何表示。

让我们开始我们的项目,用来自 StyleGAN 2[4] 的调制卷积替换 AdaIN 归一化。

首先,我们需要原始 StarGAN 的 repo:git clone https://github.com/clovaai/stargan-v2.git。

AdainResBlk 的源代码位于 core/model.py 文件中。 代码如下所示。

class AdainResBlk(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
actv=nn.LeakyReLU(0.2), upsample=False):
# ...
def _build_weights(self, dim_in, dim_out, style_dim=64):
self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
self.norm1 = AdaIN(style_dim, dim_in)
self.norm2 = AdaIN(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
def _shortcut(self, x):
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x, s):
out = self._residual(x, s)
if self.w_hpf == 0:
out = (out + self._shortcut(x)) / math.sqrt(2)
return out

现在,我们用 lucidrains StyleGAN 2 模块 [5] 替换了 AdainResBlk。 类似于 AdainResBlk 的功能在 GeneratorBlock 类(文件 stylegan2_pytorch.py)中实现。 让我们将这个类及其依赖项——Conv2DMod、Blur 和 RGBBlock 复制到我们的仓库中。

生成器块的最终版本如下所示。

from modulated_convolution import Conv2DMod, RGBBlock
class GenResBlk(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, fade_num_channels=4, fade_num_hidden=32,
actv=nn.LeakyReLU(0.2), upsample=False):
super().__init__()
self.actv = actv
self.upsample = upsample
self.needSkipConvolution = dim_in != dim_out
self.conv1 = Conv2DMod(dim_in, dim_out, 3, stride=1, dilation=1)
self.conv2 = Conv2DMod(dim_out, dim_out, 3, stride=1, dilation=1)
self.style1 = nn.Linear(style_dim, dim_in)
self.style2 = nn.Linear(style_dim, dim_out)
if self.needSkipConvolution:
self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
self.toRGB = RGBBlock(style_dim, dim_out, upsample, 3)
def forward(self, x, rgb, s):
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
if self.needSkipConvolution:
x_ = self.conv1x1(x)
else:
x_ = x
s1 = self.style1(s)
x = self.conv1(x, s1)
x = self.actv(x)
s2 = self.style2(s)
x = self.conv2(x, s2)
x = self.actv(x + x_)
rgb = self.toRGB(x, rgb, s)
return x, rgb

为简单起见,我们不会改变 StyleGAN 使用两个流——特征流和 RGB 图像流的原始概念,因此有必要修改生成器的前向方法。

替换最近的几行:

def forward(self, x, s, masks=None):
x = self.from_rgb(x)
cache = {}
for block in self.encode:
if (masks is not None) and (x.size(2) in [32, 64, 128]):
cache[x.size(2)] = x
x = block(x)
for block in self.decode:
x = block(x, s)
if (masks is not None) and (x.size(2) in [32, 64, 128]):
mask = masks[0] if x.size(2) in [32] else masks[1]
mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
x = x + self.hpf(mask * cache[x.size(2)])
return self.to_rgb(x)

使用下一个代码块:

def forward(self, x, s, masks=None):
x = self.from_rgb(x)
cache = {}
for block in self.encode:
if (masks is not None) and (x.size(2) in [32, 64, 128]):
cache[x.size(2)] = x
x = block(x)
rgb = None
for block in self.decode:
x, rgb = block(x, rgb, s)
if (masks is not None) and (x.size(2) in [32, 64, 128]):
mask = masks[0] if x.size(2) in [32] else masks[1]
mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
x = x + self.hpf(mask * cache[x.size(2)])
return rgb

为避免测试调用时出现 OOM,请在 debug_image 函数(文件 utils.py)中注释“潜在引导图像合成”和“参考引导图像合成”块。

对于 512x512 图像的训练模型,我们必须将批量大小减少到 1。为了稳定训练过程,我们将使用假图像缓冲区(来自 pytorch-CycleGAN-and-pix2pix repo),它允许我们使用以下方法更新鉴别器的权重 生成数据的历史记录,而不是最新的假输出。

如果您将在 Colab 环境中训练模型,您可以修改 _save_checkpoint 和 _load_checkpoint 函数中的步骤参数(在任何情况下,记得Google Drive 创建备份)并在将当前模型复制到 Drive 的训练函数中添加下一行:

# save model checkpoints
if (i+1) % args.save_every == 0:
self._save_checkpoint(step=i+1)
print("Saving on GDrive...")
import subprocess
subprocess.run(f"cp --force -R ./expr/ /content/drive/MyDrive/stargan_animals_expr/", shell=True, check=True)

将 AFHQ 放入 data/ 文件夹后,我们就可以开始训练了。

可以通过以下方式开始对大小为 256x256 的图像进行训练:

python main.py --img_size 256 --resume_iter 100 --mode train --num_domains 3 --w_hpf 0 \ --lambda_reg 1 --lambda_sty 1 --lambda_ds 2 --lambda_cyc 1 \ --train_img_dir data/afhq/train --val_img_dir data/afhq/val \ --batch_size 4 --sample_every 100 --save_every 500

要在 512x512px 分辨率上进行训练,请运行:

python main.py --img_size 256 --resume_iter 100 --mode train --num_domains 3 --w_hpf 0 \ --lambda_reg 1 --lambda_sty 1 --lambda_ds 2 --lambda_cyc 1 \ --train_img_dir data/afhq/train --val_img_dir data/afhq/val \ --batch_size 4 --sample_every 100 --save_every 500

本文的源代码:https://github.com/Hramchenko/modulated_stargan

本文作者:Vitaliy Hramchenko

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
4.4女单四强确定:孙颖莎王曼昱温特申裕斌晋级,美桥本无缘

4.4女单四强确定:孙颖莎王曼昱温特申裕斌晋级,美桥本无缘

富贵体坛说
2026-04-04 23:22:10
美国人通告全球,美不“护台”,话音刚落,国台办再将台当局一军

美国人通告全球,美不“护台”,话音刚落,国台办再将台当局一军

潋滟晴方DAY
2026-04-05 02:02:42
香菇再次被关注!医生发现:癌症患者吃香菇,不用多久或有5改善

香菇再次被关注!医生发现:癌症患者吃香菇,不用多久或有5改善

读懂世界历史
2026-02-12 21:48:53
中产心头好!国产咖啡机一哥泡出4.49亿,顺德女老板冲刺第一股

中产心头好!国产咖啡机一哥泡出4.49亿,顺德女老板冲刺第一股

品牌观察官
2026-04-04 20:56:20
难怪咸丰帝31岁去世,你看他逃往热河干了啥?每天都做4件致命事

难怪咸丰帝31岁去世,你看他逃往热河干了啥?每天都做4件致命事

铭记历史呀
2026-03-21 17:44:21
毛岸英究竟怎么牺牲的?2020年彭德怀发的绝密电报公开,写了什么

毛岸英究竟怎么牺牲的?2020年彭德怀发的绝密电报公开,写了什么

楚风说历史
2026-02-18 07:25:03
退休后,存款高于“这个数”,你的家庭就很有底气了,说话都硬气

退休后,存款高于“这个数”,你的家庭就很有底气了,说话都硬气

美食格物
2026-03-04 23:23:06
你敢地面入侵,我就派志愿军!伊朗迎来新帮手,海湾7国沉默不语

你敢地面入侵,我就派志愿军!伊朗迎来新帮手,海湾7国沉默不语

顾史
2026-04-03 23:27:25
原装进口!全时四驱SUV,硬派全能王,仅售19.98万,放弃坦克300

原装进口!全时四驱SUV,硬派全能王,仅售19.98万,放弃坦克300

隔壁说车老王
2026-04-04 07:00:59
两岸会谈前,洪秀柱重提《国家统一纲领》,郑丽文:此行挑战非凡

两岸会谈前,洪秀柱重提《国家统一纲领》,郑丽文:此行挑战非凡

锦年衍生烦愁
2026-04-03 15:19:06
因向中国大陆提供台湾机密文件,民进党官员被除名,面临5年徒刑

因向中国大陆提供台湾机密文件,民进党官员被除名,面临5年徒刑

混沌录
2026-04-03 22:59:28
盐城亭湖反腐新动态:新兴镇原财政“一把手”张爱明落马

盐城亭湖反腐新动态:新兴镇原财政“一把手”张爱明落马

飞鹤传媒
2026-04-03 20:45:37
越南教科书:广东,广西是越南故土,至今未收复,两千年抗北历史

越南教科书:广东,广西是越南故土,至今未收复,两千年抗北历史

长风文史
2026-03-19 20:48:02
加餐换自动铅笔后续:宝妈强势追责,同学赔偿道歉,方式太窒息了

加餐换自动铅笔后续:宝妈强势追责,同学赔偿道歉,方式太窒息了

阿纂看事
2026-04-02 13:47:58
男人胡子长得快,说明了什么?刮胡子频率与寿命有关?告诉你答案

男人胡子长得快,说明了什么?刮胡子频率与寿命有关?告诉你答案

熊猫医学社
2026-03-31 11:40:03
美国怕的不是伊朗,如果不是中国虎视眈眈,美伊战争或许早已结束

美国怕的不是伊朗,如果不是中国虎视眈眈,美伊战争或许早已结束

安安说
2026-03-29 13:42:09
成人版“抖*阴” ,终于还是凉凉了 !

成人版“抖*阴” ,终于还是凉凉了 !

肇庆之星
2021-04-23 08:33:36
内存一年涨四倍!国产手机厂商集体涨价:会持续多久?苹果会加入吗?

内存一年涨四倍!国产手机厂商集体涨价:会持续多久?苹果会加入吗?

澎湃新闻
2026-04-03 21:08:35
美以伊冲突持续,美媒:华盛顿通知东京,约400枚“战斧”导弹交付计划将受影响

美以伊冲突持续,美媒:华盛顿通知东京,约400枚“战斧”导弹交付计划将受影响

环球网资讯
2026-04-03 21:25:09
女子因桃花眼走红,订婚两年热度依旧,网友喊话:88号快回来上班

女子因桃花眼走红,订婚两年热度依旧,网友喊话:88号快回来上班

梅子的小情绪
2025-12-19 14:04:18
2026-04-05 03:24:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1966文章数 1461关注度
往期回顾 全部

科技要闻

内存一年涨四倍!国产手机厂商集体涨价

头条要闻

伊朗发动第七轮导弹袭击 耶路撒冷拦截导弹升空

头条要闻

伊朗发动第七轮导弹袭击 耶路撒冷拦截导弹升空

体育要闻

刹不住的泰格·伍兹,口袋里的两粒药丸

娱乐要闻

Q女士反击,否认逼宋宁峰张婉婷离婚

财经要闻

中微董事长,给半导体泼点冷水

汽车要闻

17万级海豹07EV 不仅续航长还有9分钟满电的快乐

态度原创

教育
艺术
亲子
游戏
军事航空

教育要闻

这些英国大学开始崩盘!

艺术要闻

你绝对不能错过的梦幻性感摄影作品!

亲子要闻

我这个00后舅舅怎么这么会带娃

好玩还上头!创新与传统并存的战棋黑马《永铃回响》值不值得玩?

军事要闻

美军又一架战机坠毁 此前F-15E被击落

无障碍浏览 进入关怀版