ldmae / VMAE /models_mae.py
isno0907's picture
Upload 115 files
6c49103 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import PatchEmbed, Mlp, DropPath
from typing import Dict, Optional, Tuple, Union
import os, sys
sys.path.append("/data/projects/jeongwoo/ldmae_backup/mae")
from util.pos_embed import get_2d_sincos_pos_embed
############################## Docking Functions ##############################
from dataclasses import dataclass
from typing import Optional, Tuple
from diffusers.utils import BaseOutput
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from util.misc import DiagonalGaussianDistribution
from PIL import Image
from torchvision import transforms
import numpy as np
from taming.modules.losses.lpips import LPIPS
class Config:
    """Tiny attribute bag that exposes ``scaling_factor`` on ``model.config``."""

    def __init__(self, scaling_factor):
        # Stored as given; no validation or conversion is performed.
        self.scaling_factor = scaling_factor
@dataclass
class DecoderOutput(BaseOutput):
r"""
Output of decoding method.
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The decoded output sample from the last layer of the model.
commit_loss (`torch.FloatTensor`, *optional*):
Optional auxiliary loss slot; defaults to `None`. `decode` in this file never sets it.
"""
sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None
@dataclass
class EncoderOutput(BaseOutput):
    r"""
    Deterministic output of the encoding method.

    Used in place of a `DiagonalGaussianDistribution` when the model is built
    without a KL term, so callers can still call `.sample()` / `.mode()`.

    Args:
        latent (`torch.Tensor`):
            The latent produced by the encoder. `encode` in this file passes a
            tensor laid out as `(batch_size, channels, height, width)`.
    """
    latent: torch.Tensor

    def sample(self, generator=None):
        """Return the stored latent unchanged.

        `generator` is accepted and ignored so that code written for a
        distribution-style `sample(generator)` call also works here.
        """
        return self.latent

    def mode(self):
        """Return the stored latent unchanged (there is no distribution)."""
        return self.latent
@dataclass
class MAEOutput(BaseOutput):
"""
Output of the `encode` method of `MaskedAutoencoderViT`.
Args:
latent_dist (`DiagonalGaussianDistribution` or `EncoderOutput`):
Encoded outputs of the encoder. `encode` stores a `DiagonalGaussianDistribution`
when `kl_loss_weight` is set and a plain `EncoderOutput` wrapper otherwise;
both are expected to offer `sample()` and `mode()`.
"""
latent_dist: Union[DiagonalGaussianDistribution, EncoderOutput] # noqa: F821
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image so its short side equals ``image_size`` and return the
    central ``image_size`` x ``image_size`` crop as a new PIL image.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Halve with a box filter for as long as the short side is at least twice
    # the target size.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Final bicubic resize that brings the short side to image_size.
    ratio = image_size / min(*pil_image.size)
    resized = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(resized, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
####################################################################################
class LayerScale(nn.Module):
    """Learnable per-channel scaling applied along the last dimension."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # One scale per channel, all starting at init_values.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled by ``gamma``; mutates ``x`` itself when ``inplace``."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class Attention(nn.Module):
    """Multi-head self-attention that can also hand back its pre-softmax logits."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, return_attn_map = False):
        """
        x: (B, N, C) token sequence.

        Returns the attended tokens (B, N, C). With ``return_attn_map`` it
        returns ``(tokens, [qk_attn, x_ctxed])`` where ``qk_attn`` is a detached
        copy of the scaled Q.K^T logits (before softmax) and ``x_ctxed`` is the
        attention output before the final projection.
        """
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        # unbind keeps torchscript happy (a tensor cannot be used as a tuple)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        qk_attn = logits.clone().detach() if return_attn_map else None

        weights = self.attn_drop(logits.softmax(dim=-1))
        x_ctxed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(x_ctxed))

        if return_attn_map:
            return out, [qk_attn, x_ctxed]
        return out
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual branch."""

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP branch.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, return_attn_map = False):
        """
        Run the block on ``x``. With ``return_attn_map`` the result is
        ``(x, qk_and_x)`` where ``qk_and_x`` is the extra list returned by
        ``Attention`` (scaled Q.K^T logits before softmax, and the pre-projection
        attention output).
        """
        normed = self.norm1(x)
        qk_and_x = None
        if return_attn_map:
            attn_out, qk_and_x = self.attn(normed, return_attn_map = True)
        else:
            attn_out = self.attn(normed)

        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

        return (x, qk_and_x) if return_attn_map else x
class Downsample(nn.Module):  # from downblock in AutoencoderKL
    """Halve the side of a square token grid with a stride-2 3x3 convolution."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, 3, stride=2)

    def forward(self, x):  # x : B N C
        """
        x: (B, N, ch_in) with N a perfect square.
        Returns (B, N_out, ch_out), where N_out is the flattened conv output grid.
        """
        B, N, C = x.shape
        H = int(N ** 0.5)
        assert H * H == N, 'Size mismatch.'
        x = x.reshape(B, H, H, C).permute(0, 3, 1, 2)  # B C H W
        # Pad one pixel on the right and bottom only, then convolve with stride 2.
        x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        x = self.conv(x)  # B ch_out H' W'
        # Flatten using the conv's own channel count; reshaping with the input
        # channel count C would scramble tokens whenever ch_out != ch_in.
        return x.flatten(2).transpose(1, 2)  # B N_out ch_out
class Upsample(nn.Module):  # from upblock in AutoencoderKL
    """Double the side of a square token grid (nearest neighbour) and apply a 3x3 conv."""

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):  # x : B N C
        """
        x: (B, N, ch_in) with N a perfect square.
        Returns (B, 4 * N, ch_out).
        """
        B, N, C = x.shape
        H = int(N ** 0.5)
        assert H * H == N, 'Size mismatch.'
        x = x.reshape(B, H, H, C).permute(0, 3, 1, 2)  # B C H W

        # Make the tensor contiguous before interpolating for large batches or
        # very large tensors (same two conditions as the original code).
        scale_factor = 2
        if x.shape[0] >= 64 or x.numel() * scale_factor > pow(2, 31):
            x = x.contiguous()

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)  # B ch_out 2H 2W
        # Flatten using the conv's own channel count; reshaping with the input
        # channel count C would scramble tokens whenever ch_out != ch_in.
        return x.flatten(2).transpose(1, 2)  # B 4N ch_out
class MLP_dim_resize(nn.Module):
    """Two-layer perceptron (Linear -> GELU -> Linear) that changes the feature width."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Map the last dimension of ``x`` from ``input_dim`` to ``output_dim``."""
        return self.layers(x)
class conv_decoder_pred(nn.Module):
    """
    Decoder output head that uses a convolution so neighbouring patches can
    influence each other.

    pred_with_conv=True : a single 2x2 conv maps decoder tokens straight to
                          patch pixels (p*p*in_chans channels per token).
    pred_with_conv=False: a linear layer predicts patch pixels, the patches are
                          assembled into an image, smoothed by a 3x3 conv in
                          pixel space, and split back into patches.
    """

    def __init__(self, decoder_embed_dim, patch_size, in_chans, pred_with_conv=True):
        super().__init__()
        self.p = patch_size
        # Kept so forward() does not have to assume 3-channel images.
        self.in_chans = in_chans
        self.pred_with_conv = pred_with_conv
        if self.pred_with_conv:
            print('pred only with conv instead of previous linear')
            self.conv_smoother = nn.Conv2d(decoder_embed_dim, patch_size**2 * in_chans, 2, stride=1, padding=0)
        else:
            print('conv on rgb')
            self.linear_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
            self.conv_smoother = nn.Conv2d(in_chans, in_chans, 3, 1, 1)

    def forward(self, x):
        """
        x: (B, h*w, decoder_embed_dim) with a square token grid.
        Returns (B, h*w, p*p*in_chans).
        """
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        B = x.shape[0]

        if self.pred_with_conv:
            x = x.reshape(B, h, w, -1).permute(0, 3, 1, 2)
            # Pad 1 on the right (W) and 1 on the bottom (H) so the 2x2 conv
            # keeps the h x w grid size.
            x = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
            x = self.conv_smoother(x)  # B C H W
            return x.reshape(B, -1, h * w).permute(0, 2, 1)  # B HW C

        p, c = self.p, self.in_chans
        x = self.linear_pred(x)  # B HW p*p*c
        # patches -> image
        x = x.reshape(shape=(B, h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(shape=(B, c, h * p, w * p))
        x = self.conv_smoother(x)
        # image -> patches
        x = x.reshape(B, c, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(shape=(B, h * w, p * p * c))  # B HW C
class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
latent_dim=32, ldmae_mode=False, scaling_factor=0.9654248952865601, no_cls=True,
gradual_resol=False, finetune_downsample_layer=None, down_nonlinear=False,
kl_loss_weight=None, smooth_output=False, pred_with_conv=False, perceptual_loss=None, perceptual_loss_ratio=1.0,
fixed_std=None):
# Build the MAE encoder/decoder plus the latent projections used by LDMAE.
#
# Arguments (as used below):
#   img_size, patch_size, in_chans: forwarded to PatchEmbed; patch_size is
#       halved first when gradual_resol is set.
#   embed_dim, depth, num_heads: encoder width / number of Blocks / heads.
#   decoder_embed_dim, decoder_depth, decoder_num_heads: same for the decoder.
#   mlp_ratio, norm_layer: forwarded to every Block; norm_layer also builds
#       the final encoder and decoder norms.
#   norm_pix_loss: stored; read by forward_loss.
#   latent_dim: channel width of the bottleneck between to_latent/from_latent.
#   ldmae_mode: stored; when True no mask_token parameter is created.
#   scaling_factor: wrapped in Config and exposed as self.config.
#   no_cls: when False a cls_token is created and one extra position is
#       reserved in both positional-embedding tables.
#   gradual_resol: inserts one Downsample in the encoder and one Upsample in
#       the decoder; the decoder tables then cover num_patches // 4 tokens.
#   finetune_downsample_layer: number of encoder Blocks that precede the
#       Downsample (defaults to depth // 2); only read when gradual_resol.
#   down_nonlinear: use MLP_dim_resize instead of nn.Linear for the latent
#       projections.
#   kl_loss_weight: when not None the encoder emits 2 * latent_dim channels
#       (mean and logvar) and no_cls must be True.
#   smooth_output, pred_with_conv: select conv_decoder_pred as output head.
#   perceptual_loss, perceptual_loss_ratio, fixed_std: stored for later use.
#
# NOTE: submodules are created in a fixed order below; reordering them would
# change both the state_dict key order and the random-initialisation sequence.
super().__init__()
# --------------------------------------------------------------------------
# MAE for LDMAE settings
self.fixed_std = fixed_std
self.perceptual_loss = perceptual_loss
self.perceptual_loss_ratio = perceptual_loss_ratio
self.smooth_output = smooth_output
self.gradual_resol = gradual_resol
self.kl_loss_weight = kl_loss_weight
# The encoder-side latent is twice as wide when it has to carry mean and logvar.
encoder_latent_dim = latent_dim
decoder_latent_dim = latent_dim
if self.kl_loss_weight is not None:
assert no_cls, 'There should be no class token to use KL loss.'
encoder_latent_dim = 2 * latent_dim
print(f'Use KL loss, encoder latent dim is {encoder_latent_dim} to predict mean & logvar')
# With gradual_resol the patch embedding works on half-size patches.
if self.gradual_resol:
patch_size = patch_size // 2
print(f'patch size: {patch_size}')
if down_nonlinear:
print('Use MLP for latent embedding')
self.to_latent = MLP_dim_resize(embed_dim, latent_dim*4, encoder_latent_dim)
self.from_latent = MLP_dim_resize(decoder_latent_dim, latent_dim*4, embed_dim)
else:
self.to_latent = nn.Linear(embed_dim, encoder_latent_dim)
self.from_latent = nn.Linear(decoder_latent_dim, embed_dim)
self.config = Config(scaling_factor=scaling_factor)
self.ldmae_mode = ldmae_mode
self.img_size = img_size
self.patch_size = patch_size
# Side length of the patch grid, computed from the (possibly halved) patch size.
self.latent_resolution = img_size // patch_size
self.tile_latent_min_size = self.latent_resolution
self.latent_dim = latent_dim
self.no_cls = no_cls
self.num_extra_tokens = 0
# NOTE(review): the classification head below is commented out, so self.head is
# never created here even though linear_probe and linear_probe_seg read it.
# if head_output_num is not None:
# self.head = torch.nn.ModuleList([
# torch.nn.LayerNorm(self.latent_dim),
# # torch.nn.BatchNorm1d(self.latent_dim, affine=False, eps=1e-6),
# torch.nn.Linear(in_features=self.latent_dim, out_features=head_output_num, bias=False)
# ])
# self.head_norm = torch.nn.LayerNorm(self.latent_dim)
# self.head_bn = torch.nn.BatchNorm1d(self.latent_dim, affine=False, eps=1e-6)
# self.head = torch.nn.Linear(in_features=self.latent_dim, out_features=head_output_num)
if not self.no_cls:
self.num_extra_tokens += 1
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
if not self.no_cls:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_extra_tokens, embed_dim), requires_grad=False) # fixed sin-cos embedding
if self.gradual_resol:
blocks = []
# Downsample goes right after the downsize_time-th encoder Block.
downsize_time = depth // 2 if finetune_downsample_layer is None else finetune_downsample_layer
for i in range(depth):
blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
if i == downsize_time-1:
print(f"Add downsizing block in {i}th layer in encoder.")
blocks.append(Downsample(embed_dim, embed_dim))
self.blocks = nn.ModuleList(blocks)
else:
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
# The mask token has decoder width and only exists outside ldmae_mode.
if not self.ldmae_mode:
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
if self.gradual_resol:
decoder_num_patches = num_patches // 4
assert decoder_num_patches * 4 == num_patches
else:
decoder_num_patches = num_patches
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, decoder_num_patches + self.num_extra_tokens, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
if self.gradual_resol:
decoder_blocks = []
# Upsample goes right after the upsize_time-th decoder Block; downsize_time
# was set in the encoder branch above under the same gradual_resol condition.
upsize_time = decoder_depth - downsize_time
for i in range(decoder_depth):
decoder_blocks.append(Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
if i == upsize_time-1:
print(f"Add upsizing block in {i}th layer in decoder.")
decoder_blocks.append(Upsample(decoder_embed_dim, decoder_embed_dim))
self.decoder_blocks = nn.ModuleList(decoder_blocks)
else:
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
if smooth_output:
assert no_cls, 'Should be no CLS token for smooth_output.'
print('Use conv in decoder pred.')
self.decoder_pred = conv_decoder_pred(decoder_embed_dim, patch_size, in_chans, pred_with_conv)
else:
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = norm_pix_loss
# Fills the positional tables and (re)initialises Linear / LayerNorm weights.
self.initialize_weights()
def initialize_weights(self):
    """
    Fill both positional-embedding tables with fixed 2D sin-cos values and
    initialise the remaining weights.
    """
    uses_cls = not self.no_cls
    enc_tokens = self.patch_embed.num_patches
    # With gradual_resol the decoder table covers a quarter of the tokens.
    dec_tokens = enc_tokens // 4 if self.gradual_resol else enc_tokens

    # initialize (and freeze) pos_embed / decoder_pos_embed by sin-cos embedding
    for table, tokens in ((self.pos_embed, enc_tokens), (self.decoder_pos_embed, dec_tokens)):
        sincos = get_2d_sincos_pos_embed(table.shape[-1], int(tokens**.5), cls_token= uses_cls)
        table.data.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

    # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
    proj_w = self.patch_embed.proj.weight.data
    torch.nn.init.xavier_uniform_(proj_w.view([proj_w.shape[0], -1]))

    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
    if uses_cls:
        torch.nn.init.normal_(self.cls_token, std=.02)
    if not self.ldmae_mode:
        torch.nn.init.normal_(self.mask_token, std=.02)

    # initialize nn.Linear and nn.LayerNorm
    self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
if self.no_cls:
x = x + self.pos_embed
else:
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
if not self.no_cls:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
if not self.no_cls:
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
else:
x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
if not self.no_cls:
x = x[:, 1:, :]
return x
def forward_encoder_with_mask(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
# if self.no_cls:
# x = x + self.pos_embed
# else:
# x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
# if not self.no_cls:
# x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
# x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# else:
x_ = torch.cat([x, mask_tokens], dim=1)
x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# append cls token
if self.no_cls:
x = x + self.pos_embed
else:
x = x + self.pos_embed[:, 1:, :]
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder_without_mask(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
# mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
# if not self.no_cls:
# x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
# x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# else:
# x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
# x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
if not self.no_cls:
x = x[:, 1:, :]
return x
# Encoding and decoding behave differently depending on self.no_cls.
# The decoder output is still in patch form; unpatchify must be applied to it afterwards.
def ldmae_encoding(self, imgs, use_mode=False, return_kl=False):
# encoding
x = self.patch_embed(imgs)
if not self.no_cls:
x = x + self.pos_embed[:, 1:, :]
else:
x = x + self.pos_embed
if not self.no_cls:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# to latent
latent = self.to_latent(x)
if self.kl_loss_weight is not None:
latent = latent.permute(0,2,1) # B D HW
posterior = DiagonalGaussianDistribution(latent) # B D HW
if use_mode:
print('mode')
latent = posterior.mode() # B D HW
else:
latent = posterior.sample() # B D HW
latent = latent.permute(0,2,1)
if return_kl:
return latent, posterior.kl()
return latent
def ldmae_decoding(self, x):
# if self.smooth_output:
# with torch.no_grad():
# # from latent
# x = self.from_latent(x)
# x = self.decoder_embed(x)
# # if x.shape[1] != self.decoder_pos_embed.shape[1]:
# if not self.no_cls:
# decoder_pos_embed = self.decoder_pos_embed[:,1:,:]
# else:
# decoder_pos_embed = self.decoder_pos_embed
# x = x + decoder_pos_embed
# for blk in self.decoder_blocks:
# x = blk(x)
# x = self.decoder_norm(x)
# else:
# from latent
x = self.from_latent(x)
x = self.decoder_embed(x)
# if x.shape[1] != self.decoder_pos_embed.shape[1]:
if not self.no_cls:
decoder_pos_embed = self.decoder_pos_embed[:,1:,:]
else:
decoder_pos_embed = self.decoder_pos_embed
x = x + decoder_pos_embed
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
x = self.decoder_pred(x)
if not self.no_cls:
x = x[:, 1:, :]
return x
def reconstruct(self, imgs, use_mode=True, return_kl=False, mask_ratio=0.75):
# return it back
if mask_ratio == 0.0:
x = self.ldmae_encoding(imgs, use_mode=True, return_kl=False)
else:
with torch.no_grad():
if return_kl:
x, kl_val = self.ldmae_encoding(imgs, use_mode=use_mode, return_kl=return_kl)
else:
x = self.ldmae_encoding(imgs, use_mode=use_mode, return_kl=return_kl)
x = self.ldmae_decoding(x)
if return_kl:
return x, kl_val
return x
def linear_probe_seg(self, images):
with torch.no_grad():
x = self.ldmae_encoding(images)
if not self.no_cls:
x = x[:, 1:, :] # B HW D
B, N, D = x.shape
x = x.reshape(-1, D) # BHW D
for layer in self.head:
x = layer(x)
# x = self.head(x)
return x # BHW num_classes
def linear_probe(self, images):
with torch.no_grad():
x = self.ldmae_encoding(images)
# x = self.head_norm(x)
if self.no_cls:
x = x.mean(dim=1) # global pool, [B D]
else:
x = x[:, 1:, :].mean(dim=1) # global pool, [B D]
# x = self.head_bn(x)
for layer in self.head:
x = layer(x)
return x
def forward_loss(self, imgs, pred, mask, visible_loss_ratio=0.5):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [B, L], mean loss per patch
visible_loss = (loss * (1-mask)).sum() / (1-mask).sum()
mask_loss = (loss * mask).sum() / mask.sum() # now per pixel, scalar
loss = (1-visible_loss_ratio) * mask_loss + visible_loss_ratio * visible_loss
if self.perceptual_loss is not None:
p_loss = self.perceptual_loss(
imgs.contiguous(),
self.unpatchify(pred).contiguous()
) # B 1 1 1
p_loss = p_loss.mean()
loss = loss + self.perceptual_loss_ratio * p_loss
else:
p_loss = torch.zeros_like(loss)
# loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss, visible_loss, mask_loss, p_loss
def forward_vanilla(self, imgs, mask_ratio=0.75, visible_loss_ratio=0.5):
if self.gradual_resol:
latent, mask, ids_restore = self.forward_encoder_with_mask(imgs, mask_ratio)
else:
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
# --------------------------------------------------------------------------
# to latent
latent = self.to_latent(latent)
if self.kl_loss_weight is not None:
B, N, D = latent.shape
latent = latent.permute(0,2,1) # B D HW
posterior = DiagonalGaussianDistribution(latent, fixed_std=self.fixed_std) # B D HW
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] / N # per patch
latent = posterior.sample() # B D HW
latent = latent.permute(0,2,1)
latent = self.from_latent(latent)
# --------------------------------------------------------------------------
if self.gradual_resol:
pred = self.forward_decoder_without_mask(latent, ids_restore) # [N, L, p*p*3]
else:
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss, vis_loss, mask_loss, p_loss = self.forward_loss(imgs, pred, mask, visible_loss_ratio)
if self.kl_loss_weight is not None:
loss = loss + self.kl_loss_weight * kl_loss
else:
kl_loss = None
return loss, pred, mask, vis_loss, mask_loss, kl_loss, p_loss
def forward_ldmae(self, imgs, mask_ratio=0.75):
pred = self.reconstruct(imgs, use_mode=False, mask_ratio=mask_ratio)
# target = self.patchify(imgs)
vis_loss = (self.unpatchify(pred) - imgs) ** 2
if self.perceptual_loss is not None:
p_loss = self.perceptual_loss(
imgs.contiguous(),
self.unpatchify(pred).contiguous()
)
loss = vis_loss + self.perceptual_loss_ratio * p_loss
else:
loss = vis_loss
p_loss = torch.zeros_like(vis_loss)
loss = loss.mean()
return loss, pred, None, vis_loss.mean(), p_loss.mean(), None
def forward(self, imgs, mask_ratio=0.75, visible_loss_ratio=0.5):
if self.ldmae_mode:
return self.forward_ldmae(imgs, mask_ratio=mask_ratio)
else:
return self.forward_vanilla(imgs, mask_ratio=mask_ratio, visible_loss_ratio=visible_loss_ratio)
############################ Docking Functions ####################################
def _encode(self, x):
# encoding
x = self.patch_embed(x)
if not self.no_cls:
x = x + self.pos_embed[:, 1:, :]
else:
x = x + self.pos_embed
if not self.no_cls:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# to latent
x = self.to_latent(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=self.latent_resolution, w=self.latent_resolution)
return x
def encode_gflops(self, x, return_dict=True):
# encoding
x = self.patch_embed(x)
if not self.no_cls:
x = x + self.pos_embed[:, 1:, :]
else:
x = x + self.pos_embed
if not self.no_cls:
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
# to latent
x = self.to_latent(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=self.latent_resolution, w=self.latent_resolution)
return x
def encode(self, x, return_dict=True):
    """
    Encode images to a latent "distribution" object.

    The encoder pass itself is delegated to ``_encode`` (this method used to
    carry its own copy of the same code). The resulting (B, C, h, w) latent is
    wrapped in a ``DiagonalGaussianDistribution`` when ``kl_loss_weight`` is
    set, otherwise in a deterministic ``EncoderOutput``.

    Returns ``MAEOutput(latent_dist=p)``, or the 1-tuple ``(p,)`` when
    ``return_dict`` is False.

    NOTE(review): unlike ``forward_vanilla`` the posterior here is built
    without ``fixed_std=self.fixed_std`` — confirm that is intended.
    """
    latent = self._encode(x)

    if self.kl_loss_weight is not None:
        p = DiagonalGaussianDistribution(latent)  # B c h w
    else:
        p = EncoderOutput(latent)

    if not return_dict:
        return (p,)
    return MAEOutput(latent_dist=p)
def decode(self, z, return_dict=True, generator=None):
    """
    Decode a (B, C, h, w) latent grid to images.

    ``generator`` is accepted but not used. Returns
    ``DecoderOutput(sample=img)``, or the 1-tuple ``(img,)`` when
    ``return_dict`` is False.
    """
    # from latent
    tokens = rearrange(z, 'b c h w -> b (h w) c')
    tokens = self.decoder_embed(self.from_latent(tokens))

    # The latent grid has no cls token, so skip its slot in the table.
    pos = self.decoder_pos_embed if self.no_cls else self.decoder_pos_embed[:, 1:, :]
    tokens = tokens + pos

    for blk in self.decoder_blocks:
        tokens = blk(tokens)

    patches = self.decoder_pred(self.decoder_norm(tokens))
    img = self.unpatchify(patches)

    if return_dict:
        return DecoderOutput(sample=img)
    return (img,)
@property
def device(self) -> torch.device:
"""
Returns:
torch.device: The torch device on which the model's parameters are located.
"""
# Check if the model has parameters that specify a device
for param in self.parameters():
return param.device
# Check if the model has buffers (like self.pos_embed or self.cls_token)
for buffer in self.buffers():
return buffer.device
# Fallback to CPU if no parameters or buffers are found
return torch.device("cpu")
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
return self.device
@property
def dtype(self) -> torch.dtype:
"""
Returns:
torch.dtype: The torch data type on which the model's parameters are stored.
"""
# Check if the model has any parameters to infer dtype
for param in self.parameters():
return param.dtype
# Check if the model has buffers to infer dtype
for buffer in self.buffers():
return buffer.dtype
# Default to float32 if no parameters or buffers are found
return torch.float32
#################################################################################################
####################################### VAVAE Docking Functions #################################
def img_transform(self, p_hflip=0, img_size=None):
    """Build the image preprocessing pipeline.

    Steps: center crop to a square, random horizontal flip, conversion to a
    tensor, then per-channel normalisation with mean 0.5 and std 0.5.

    Args:
        p_hflip: Probability of horizontal flip
        img_size: Target image size; ``self.img_size`` is used when None
    Returns:
        transforms.Compose: Image transform pipeline
    """
    if img_size is None:
        img_size = self.img_size

    center_crop = transforms.Lambda(partial(center_crop_arr, image_size=img_size))
    horizontal_flip = transforms.RandomHorizontalFlip(p=p_hflip)
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)

    return transforms.Compose([center_crop, horizontal_flip, transforms.ToTensor(), normalize])
def encode_images(self, images):
"""Encode images to latent representations
Args:
images: Input image tensor
Returns:
torch.Tensor: Encoded latent representation
"""
with torch.no_grad():
posterior = self.encode(images.cuda(), return_dict=False)[0]
return posterior.sample()
def decode_to_images(self, z):
"""Decode latent representations to images
Args:
z: Latent representation tensor
Returns:
np.ndarray: Decoded image array
"""
with torch.no_grad():
images = self.decode(z.cuda(), return_dict=False)[0]
images = torch.clamp(127.5 * images + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
return images
############################################################################################
def mae_for_ldmae(**kwargs):
    """Preset: 128px images, patch 8, width 192 / depth 12 / 12 heads for both
    encoder and decoder, latent_dim 32. Extra ``kwargs`` go to the constructor."""
    preset = dict(
        img_size=128, patch_size=8, embed_dim=192, depth=12, num_heads=12,
        decoder_embed_dim=192, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=32,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
# Same as `mae_for_ldmae` above; copy-pasted under a different name.
def mae_for_ldmae_f8d32(**kwargs):
    """Alias preset (patch 8, latent_dim 32). The configuration is identical to
    ``mae_for_ldmae``, so this delegates instead of repeating the argument list."""
    return mae_for_ldmae(**kwargs)
def mae_for_ldmae_f8d16_prev(**kwargs):
    """Preset: patch 8, width 192 / depth 12 / 12 heads for both encoder and
    decoder, latent_dim 16. ``img_size`` is not fixed here (constructor default
    unless passed in ``kwargs``)."""
    preset = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=12,
        decoder_embed_dim=192, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=16,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f8d16_small(**kwargs):
    """Preset: patch 8, width 96 / depth 12 / 8 heads for both encoder and
    decoder, latent_dim 16. Extra ``kwargs`` go to the constructor."""
    preset = dict(
        patch_size=8, embed_dim=96, depth=12, num_heads=8,
        decoder_embed_dim=96, decoder_depth=12, decoder_num_heads=8,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=16,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f8d16_asym_small(**kwargs):
    """Asymmetric preset: patch 8, encoder width 96 / 8 heads, decoder width
    192 / 12 heads, depth 12 each, latent_dim 16."""
    preset = dict(
        patch_size=8, embed_dim=96, depth=12, num_heads=8,
        decoder_embed_dim=192, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=16,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f8d16_prev_large(**kwargs):
    """Preset: patch 8, width 384 / depth 12 / 16 heads for both encoder and
    decoder, latent_dim 16. Extra ``kwargs`` go to the constructor."""
    preset = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=16,
        decoder_embed_dim=384, decoder_depth=12, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=16,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f8d16(**kwargs):
    """Preset: patch 8, encoder width 192 / 12 heads, decoder width 384 / 24
    heads, depth 12 each, latent_dim 16, with MLP latent projections
    (``down_nonlinear=True``)."""
    preset = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=12,
        decoder_embed_dim=384, decoder_depth=12, decoder_num_heads=24,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=16,
        down_nonlinear=True,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f8d16_flexible(**kwargs):
    """Alias preset. The configuration is identical to ``mae_for_ldmae_f8d16``,
    so this delegates instead of repeating the argument list."""
    return mae_for_ldmae_f8d16(**kwargs)
def mae_for_ldmae_f16d32(**kwargs):
    """Preset: 128px images, patch 16, width 192 / depth 12 / 12 heads for both
    encoder and decoder, latent_dim 32."""
    preset = dict(
        img_size=128, patch_size=16, embed_dim=192, depth=12, num_heads=12,
        decoder_embed_dim=192, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=32,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f16d32_large(**kwargs):
    """Preset: 128px images, patch 16, width 384 / depth 12 / 12 heads for both
    encoder and decoder, latent_dim 32, ``finetune_downsample_layer=4`` (only
    read by the constructor when ``gradual_resol`` is enabled)."""
    preset = dict(
        img_size=128, patch_size=16, embed_dim=384, depth=12, num_heads=12,
        decoder_embed_dim=384, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=32,
        finetune_downsample_layer=4,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_f8d32_flexible(**kwargs):
    """Preset: patch 8, width 192 / depth 12 / 12 heads for both encoder and
    decoder, latent_dim 32. ``img_size`` is not fixed here (constructor default
    unless passed in ``kwargs``)."""
    preset = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=12,
        decoder_embed_dim=192, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=32,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_for_ldmae_16d(**kwargs):
    """Preset: 128px images, patch 8, width 192 / depth 12 / 12 heads for both
    encoder and decoder, latent_dim 16."""
    preset = dict(
        img_size=128, patch_size=8, embed_dim=192, depth=12, num_heads=12,
        decoder_embed_dim=192, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), latent_dim=16,
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """ViT-Base style preset: patch 16, encoder 768 wide / 12 blocks / 12 heads,
    decoder 512 wide / 8 blocks / 16 heads."""
    preset = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_vit_base_patch16_dec128d8b(**kwargs):
    """ViT-Base style preset with a narrow decoder: patch 16, encoder 768 wide /
    12 blocks / 12 heads, decoder 128 wide / 8 blocks / 16 heads."""
    preset = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=128, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """ViT-Large style preset: patch 16, encoder 1024 wide / 24 blocks / 16
    heads, decoder 512 wide / 8 blocks / 16 heads."""
    preset = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """ViT-Huge style preset: patch 14, encoder 1280 wide / 32 blocks / 16
    heads, decoder 512 wide / 8 blocks / 16 heads."""
    preset = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return MaskedAutoencoderViT(**preset, **kwargs)
# Recommended architectures: short public names bound to the factory functions
# defined above. Each alias is the same callable as its target.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_base_patch16_128 = mae_vit_base_patch16_dec128d8b # decoder: 128 dim, 8 blocks