网易首页 > 网易号 > 正文 申请入驻

SimMIM:一种更简单的MIM方法

0
分享至

自从何恺明的(←点击蓝字查看文章详情)出来之后,基于MIM(Masked Image Modeling)的无监督学习方法越来越受到关注。这里介绍一篇和MAE同期的工作:SimMIM: A Simple Framework for Masked Image Modeling,研究团队是微软亚研院。

SimMIM和MAE有很多相似的设计和结论,而且效果也比较接近,比如基于ViT-B的模型无监督训练后再finetune可以ImageNet数据集达到83.8%的top1 accuray(MAE为83.6%)。不过相比MAE,SimMIM更加简单,而且也可以用来无监督训练金字塔结构的vision transformer模型如swin transformer等。目前SimMIM实现代码已经开源,本文将基于论文和源码对SimMIM方法进行解读。

算法原理

SimMIM采用最简单的MIM方法:随机mask掉输入图像的一部分patch,然后通过encoder-decoder来预测masked patchs的原始像素值。算法原理图如上图所示,从设计方面和MAE基本一致。SimMIM的主要结论如下:

  • 直接对图像采用简单的random mask是非常简单有效的方法;

  • 直接回归原始的像素的RGB值不比BEiT采用的分类效果差;

  • decoder采用轻量级的设计(直接采用一个线性层)也能得到很好的效果;

这些结论也是在MAE论文中得到了验证。那么SimMIM和MAE的区别在哪里呢?主要有以下两点:

  • SimMIM的encoder同时处理visible tokens和masked tokens,而MAE的encoder只处理visible tokens;

  • SimMIM的decoder只采用一个线性层来回归像素值,而MAE的decoder采用transformer结构;

第2个差异带来的影响相对很小,因为两个论文都证明了decoder设计对性能影响较小。主要的差异点是第一个,MAE训练时只处理visible tokens一方面可以加速训练(减少了计算量),同时也可以减少pre-training和deploy之间的gap(deploy时输入是非masked的图像,无masked token),MAE实验也证明只处理visible tokens可以提升linear probing性能:73.5% vs 59.6%。

而SimMIM是处理所有的tokens,从实验结果上看也符合MAE的结论,SimMIM方法得到的ViT-B模型的linear probing只有56.7%,不过这不并不会影响finetune后的性能,关于这点MAE论文也论证了。不过SimMIM这样做带来的一个好处是可以用来训练其它非“同质结构”模型,比如swin transformer,由于它各个stage间要对patch进行merge操作,所以token并不是像ViT那样一成不变的。下面我们具体介绍SimMIM的各个部分,这里默认实验都是以Swin-B为encoder,为了减少实验成本,输入图像大小为192x192(原来是224),window size设置为6(原来是7),预训练epoch为100。

Masking Strategy

SimMIM的masking策略按照一定mask ratio随机mask掉一部分patch。在MAE中,masked patch size和ViT的patch size是一致的,比如ViT-B/16模型,masked patch size就要设计为16x16,然后用一个可学习的masked token来代替。但是对于SimMIM,其设计masked patch size不一定等于模型的patch size,比如ViT模型masked patch size可以是32x32,理论上mask patch size只要是ViT模型patch size的整数倍就可以,因此此时每个mask掉的patch可以整分成和模型patch一样大小的若干个patch。

对于金字塔结构的swin transformer,每个stage的patch size是不同的,比如第一个stage的patch size是4x4,而最后一个stage的patch size是32x32,此时设计的mask patch size只需要是第一个stage的patch size整数就好。无论是ViT还是swin transformer,masked token对应的patch size都是其patch embedding层对应的patch size,对于ViT就是16x16,而对于swin transformer就是4x4,而mask patch size只需要是masked token的patch size的整数倍即可。

所以SimMIM采用更灵活的mask patch size,不同mask patch size的可视化效果如下图所示。对于ViT和swin transformer,SimMIM都默认采用:mask ratio=0.6,mask patch size=32x32。

不同的mask type,mask patch size和mask ratio对模型效果(finetune)的影响如下表所示,可以看到不同的设置均可以取得类似的效果,其中random+masked patch size=32x32+mask ratio=0.5可取得最优的效果83.0%。

从表中可以看出,采用较小的masked patch size(4x4,8x8,16x16),模型效果随着mask ratio的增加而提升,而对于更大的masked patch size(64x64),需要采用较小的mask ratio才能得到较好的结果。masked patch size和mask ratio影响的是MIM任务的难度,两者越大,MIM任务越难,要想取得较好的模型训练效果,MIM任务的难度要适当大一些。

论文也提出了AvgDist指标来进一步分析masked patch size和mask ratio对模型finetune效果的影响,这里AvgDist指标计算的是所有masked pixels到最近的visible pixels的平均欧式距离,它综合了masked patch size和mask ratio对MIM任务的影响。从下图可以看出,AvgDist随着mask ratio的增加而增加,对于较小的masked patch size,其AvgDist在较大的mask ratio下依然较小,而较大的masked patch size,其AvgDist在较小的mask ratio下就比较大。从右图可以看出,AvgDist在[10, 20]区间内都可以取得较好的finetune效果,这个可以用来指导选择不同masked patch size和mask ratio组合。

采用不同的masked patch size,其预测的图像效果如下所示,可以看到masked patch size越小,图像还原度越高,这也比较合理。但是MIM本身并不是为了更好地恢复图像,而是希望encoder学习到好的特征以迁移到下游任务。

随机mask策略的实现比较简单,在对每个图像进行数据增强后,同时随机生成一个mask;在模型forward时,将masked patch用mask token来替换,注意由于masked patch size和model_patch_size不一定相等,所以要将随机生成mask转换成和model_patch_size一致的mask。具体实现代码如下所示:

class MaskGenerator:
def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
self.input_size = input_size # 输入图像大小
self.mask_patch_size = mask_patch_size # masked patch大小
self.model_patch_size = model_patch_size # 模型patch embed层的patch大小
self.mask_ratio = mask_ratio

assert self.input_size % self.mask_patch_size == 0
assert self.mask_patch_size % self.model_patch_size == 0

self.rand_size = self.input_size // self.mask_patch_size
self.scale = self.mask_patch_size // self.model_patch_size

self.token_count = self.rand_size ** 2
self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

def __call__(self):
mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
mask = np.zeros(self.token_count, dtype=int)
mask[mask_idx] = 1

# 要转换成和model_patch size一致的mask
mask = mask.reshape((self.rand_size, self.rand_size))
mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

return mask

class SimMIMTransform:
def __init__(self, config):
self.transform_img = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
])

if config.MODEL.TYPE == 'swin':
model_patch_size=config.MODEL.SWIN.PATCH_SIZE
elif config.MODEL.TYPE == 'vit':
model_patch_size=config.MODEL.VIT.PATCH_SIZE
else:
raise NotImplementedError

self.mask_generator = MaskGenerator(
input_size=config.DATA.IMG_SIZE,
mask_patch_size=config.DATA.MASK_PATCH_SIZE,
model_patch_size=model_patch_size,
mask_ratio=config.DATA.MASK_RATIO,
)

def __call__(self, img):
img = self.transform_img(img) # 图像数据增强
mask = self.mask_generator() # 生成mask

return img, mask

class SwinTransformerForSimMIM(SwinTransformer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

assert self.num_classes == 0

self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
trunc_normal_(self.mask_token, mean=0., std=.02)

def forward(self, x, mask):
x = self.patch_embed(x)

assert mask is not None
B, L, _ = x.shape

mask_tokens = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
x = x * (1. - w) + mask_tokens * w

if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)

for layer in self.layers:
x = layer(x)
x = self.norm(x)

x = x.transpose(1, 2)
B, C, L = x.shape
H = W = int(L ** 0.5)
x = x.reshape(B, C, H, W)
return x

# 基于swinT的SimMIM
class SwinTransformerForSimMIM(SwinTransformer):
def __init__(self, **kwargs):
super().__init__(**kwargs)

assert self.num_classes == 0

# 定义可学习的masked token
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
trunc_normal_(self.mask_token, mean=0., std=.02)

def forward(self, x, mask):
x = self.patch_embed(x)

assert mask is not None
B, L, _ = x.shape

mask_tokens = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
x = x * (1. - w) + mask_tokens * w # 用masked token替换masked patch对应的patch embedding

if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)

for layer in self.layers:
x = layer(x)
x = self.norm(x)

x = x.transpose(1, 2)
B, C, L = x.shape
H = W = int(L ** 0.5)
x = x.reshape(B, C, H, W)
return x

Prediction Head

这里的prediction head指的就是decoder,用来预测masked patch的原始像素值。论文发现采用一个非常轻量级的decoder(只用1个linear层)就非常有效。采用更复杂的head,效果没有提升,反而会增加训练成本。MAE也指出decoder的设计对finetune性能影响较小,但是却会影响linear probing效果,如果采用较轻的decoder,那么encoder的后面部分层就要承担一部分像素预测任务(无监督训练代理任务),但这个却不是图像分类任务所需要的,所以会带来linear probing的下降,所以如果要想得到比较好的linear probing效果,就需要设计一个适当的decoder以将预测任务集中在decoder上。

SimMIM默认采用单个linear层来预测像素值,在实现上采用一个1x1卷积层。对于swin transformer,其得到的特征图(恢复成hxw)是原来图像的1/32大小,那么卷积层输出channels等于3072=32x32x3,每个特征点预测32x32个pixels的RGB值。

Prediction Tragets

SimMIM是直接回归masked patch的原始像素值,所以target就是原始图像的RGB值,而回归损失采用L1 loss,注意这里和MAE一样,只计算masked pixels的损失,论文也发现如果对所有pixels计算loss,效果会下降(82.8% -> 81.7%),prediction而不是reconstruction能更好地让encoder学习到更强的特征。另外一个参数是prediction resolution,SimMIM默认的prediction resolution是原始图像大小,但也可以对原始图像进行下采样,从而降低prediction resolution,从实验结果来看,采用不同的prediction resolution均能得到较好的结果,除了1/32表现相对差一些(图像损失比较严重):

论文也对比了其它类型的targets,比如像BEiT那样用dVAE将回归变成分类任务,或者像IGPT那样采用color clustering。从下表的对比结果可以看到直接回归像素值并不比这些更复杂的设计差。

loss计算部分的实现也比较简单,具体的代码如下所示(注意这里回归的像素值是归一化后的像素值):

class SimMIM(nn.Module):
def __init__(self, encoder, encoder_stride):
super().__init__()
self.encoder = encoder
self.encoder_stride = encoder_stride

# 定义encoder
self.decoder = nn.Sequential(
nn.Conv2d(
in_channels=self.encoder.num_features,
out_channels=self.encoder_stride ** 2 * 3, kernel_size=1), # 1x1 conv等价于linear
nn.PixelShuffle(self.encoder_stride), # [B, 3*r*r, h, w] -> [B, 3, h*r, w*r]
)

self.in_chans = self.encoder.in_chans
self.patch_size = self.encoder.patch_size

def forward(self, x, mask):
z = self.encoder(x, mask) # encoder提取特征
x_rec = self.decoder(z) # decoder预测图像

# mask转变成和原始图像一样大小
mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
loss_recon = F.l1_loss(x, x_rec, reduction='none') # L1 loss
loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans # 只计算masked pixels并取mean
return loss

实验设置及对比结果

前面的实验都是以Swin-B为backbone,预训练的epoch为100,而最后的实验训练800个epoch,batch size为2048。在数据增强方面,只采用random resize croping:RandomResizedCrop(192, scale=(0.67, 1.),ratio=(3. / 4., 4. / 3.))以及水平翻转,和MAE一样属于轻量级的数据增强,这说明MIM方法确实不像对比学习那样需要较heavy的数据增强。对于ViT,预训练的图像大小是224,而SwinT采用的图像大小为192,对比结果如下表所示。可以看到:

  • 基于SimMIM训练的ViT-B优于BEiT方法(83.8 vs 83.2),训练成本也比较低,但是linear probing效果均比较差(56.7);

  • 基于SimMIM预训练的SwinT也优于有监督训练的模型,对于Swin-B,预训练800epoch相比100epoch有一定提升(82.8 vs 84.0),这里也包含SwinV2的实验,其中30亿参数的SwinV2-G的效果可达到90.2%。

下图是一些masked图像重建后的可视化,可以看出经过SimMIM训练后,模型能学习到一定的推理能力,比如mask掉一个物体或者人后,模型能学会补全背景。

小结

总结来看,SimMIM和MAE方法大致相同,两者的差异大概源自SimMIM是为Swin设计的,而MAE是为单纯的ViT结构设计的。一个缺憾是SimMIM方法虽然在SwinV2上做了验证,但是没有直接在下游检测和分割任务上的对比实验,而MAE方法在随后的工作Benchmarking Detection Transfer Learning with Vision Transformers中论证了其迁移到实例分割任务上的有效性。

参考文献

【1】Masked Autoencoders Are Scalable Vision Learners

【2】SimMIM: A Simple Framework for Masked Image Modeling

【3】https://link.zhihu.com/?target=https%3A//github.com/microsoft/SimMIM

来源:知乎

作者:小小将

深延科技|

深延科技成立于2018年,是深兰科技(DeepBlue)旗下的子公司,以“人工智能赋能企业与行业”为使命,助力合作伙伴降低成本、提升效率并挖掘更多商业机会,进一步开拓市场,服务民生。公司推出四款平台产品——深延智能数据标注平台、深延AI开发平台、深延自动化机器学习平台、深延AI开放平台,涵盖从数据标注及处理,到模型构建,再到行业应用和解决方案的全流程服务,一站式助力企业“AI”化。

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
美国移民与海关执法局副局长宣布离职

美国移民与海关执法局副局长宣布离职

每日经济新闻
2026-01-16 08:06:39
随着越南3-2,亚洲杯四强仅剩2席:中国队还未晋级

随着越南3-2,亚洲杯四强仅剩2席:中国队还未晋级

侧身凌空斩
2026-01-17 02:09:34
中国将迎人口死亡高峰!22年1041万,23年1100万,去年死亡多少?

中国将迎人口死亡高峰!22年1041万,23年1100万,去年死亡多少?

长歌侃娱
2026-01-16 07:55:03
订婚时新郎长相普通,新娘颜值爆表却望着他移不开眼:生理性喜欢

订婚时新郎长相普通,新娘颜值爆表却望着他移不开眼:生理性喜欢

梅子的小情绪
2026-01-16 21:09:56
69岁男星为还巨债「凌晨厦门卖水果」,网友心痛:看起来好累

69岁男星为还巨债「凌晨厦门卖水果」,网友心痛:看起来好累

粤睇先生
2026-01-16 00:35:43
多地现快递驿站转让帖,经营者直言太累,称每天工作十几个小时,每月只赚五六千,“不如去当保安”

多地现快递驿站转让帖,经营者直言太累,称每天工作十几个小时,每月只赚五六千,“不如去当保安”

大风新闻
2026-01-14 17:03:09
台湾黑帮大佬陈启礼:我宁可被共产党管着,也不让台湾被外人拿走

台湾黑帮大佬陈启礼:我宁可被共产党管着,也不让台湾被外人拿走

浩渺青史
2026-01-16 16:39:21
上海一汤泉店顾客称遗失13万余元劳力士!涉事门店:公安介入,已全面排查更衣柜门锁

上海一汤泉店顾客称遗失13万余元劳力士!涉事门店:公安介入,已全面排查更衣柜门锁

台州交通广播
2026-01-15 22:37:50
老公执意接3个侄子来借读,我转身出差,他打30通电话"孩子谁管"

老公执意接3个侄子来借读,我转身出差,他打30通电话"孩子谁管"

磊子讲史
2025-05-27 11:20:08
首次访华就给下马威?中国当即撤走仪仗队,表明态度:想好了再来

首次访华就给下马威?中国当即撤走仪仗队,表明态度:想好了再来

老媹古装影视解说
2025-12-24 18:15:45
医生调查发现:每天喝豆浆的糖尿病患者,不久后,身体或有6变化

医生调查发现:每天喝豆浆的糖尿病患者,不久后,身体或有6变化

读懂世界历史
2025-12-07 09:30:46
一夜3大震动!NBA局势风云突变

一夜3大震动!NBA局势风云突变

魔血獄苼
2026-01-17 01:03:26
《风筝》隐藏伏笔终揭晓:精明韩冰不知,识破她的人不是郑马二人

《风筝》隐藏伏笔终揭晓:精明韩冰不知,识破她的人不是郑马二人

磊子讲史
2026-01-15 09:36:26
国乒小将惊天逆转!0-2连扳3局胜日本世界冠军,国乒男单6人晋级

国乒小将惊天逆转!0-2连扳3局胜日本世界冠军,国乒男单6人晋级

全言作品
2026-01-17 00:53:04
3:0速胜!陈俊菘化身张本克星,多哈再燃复仇烽火

3:0速胜!陈俊菘化身张本克星,多哈再燃复仇烽火

阿晞体育
2026-01-16 13:26:11
故事:老知青病重托儿子去甘肃找初恋见最后一面,儿子见到后愣了

故事:老知青病重托儿子去甘肃找初恋见最后一面,儿子见到后愣了

红豆讲堂
2025-03-02 13:30:05
雷军直播官宣小米YU7新政策:首付4.99万起,全系送1.8万元零重力座椅

雷军直播官宣小米YU7新政策:首付4.99万起,全系送1.8万元零重力座椅

TechWeb
2026-01-16 10:53:03
一滴也别喝!科学家调查发现,2020年全球有74万人因饮酒患癌

一滴也别喝!科学家调查发现,2020年全球有74万人因饮酒患癌

蜉蝣说
2026-01-15 10:50:10
逆天运气!网友134元买到PS5 Pro 阿里跨境电商薅羊毛

逆天运气!网友134元买到PS5 Pro 阿里跨境电商薅羊毛

游民星空
2026-01-15 23:18:10
闫学晶风波再升级!一人已被行拘,婚变丑闻被扒,这回恐彻底栽了

闫学晶风波再升级!一人已被行拘,婚变丑闻被扒,这回恐彻底栽了

奇思妙想草叶君
2026-01-14 22:49:54
2026-01-17 04:00:49
深兰深延AI
深兰深延AI
让AI赋能更简单!
71文章数 0关注度
往期回顾 全部

科技要闻

贾国龙与罗永浩被禁言,微博CEO回应

头条要闻

罗永浩、贾国龙微博账号均被禁言

头条要闻

罗永浩、贾国龙微博账号均被禁言

体育要闻

全队身价=登贝莱,他们凭什么领跑法甲?

娱乐要闻

李湘翻车,早就有迹可循!

财经要闻

清流|酒店商家在携程和美团之间沦为炮灰

汽车要闻

方程豹品牌销量突破30万辆 2026年还将推出轿跑系列

态度原创

教育
亲子
家居
旅游
公开课

教育要闻

孩子千万不要早恋

亲子要闻

精神科医生:家长的“为你好”也可能对孩子造成创伤

家居要闻

岁月柔情 现代品质轻奢

旅游要闻

北境飘雪南园绽,青秀山用一场菊花展,藏住南宁冬日顶流密码!

公开课

李玫瑾:为什么性格比能力更重要?

无障碍浏览 进入关怀版