| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from timm.models.vision_transformer import PatchEmbed, Mlp, DropPath |
| | from typing import Dict, Optional, Tuple, Union |
| | import os, sys |
| |
|
| | sys.path.append("/data/projects/jeongwoo/ldmae_backup/mae") |
| | from util.pos_embed import get_2d_sincos_pos_embed |
| |
|
| | |
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple |
| | from diffusers.utils import BaseOutput |
| | from einops import rearrange |
| | from diffusers import ConfigMixin, ModelMixin |
| | from util.misc import DiagonalGaussianDistribution |
| | from PIL import Image |
| | from torchvision import transforms |
| | import numpy as np |
| | from taming.modules.losses.lpips import LPIPS |
| |
|
| |
|
class Config:
    """Tiny attribute holder for the latent scaling factor.

    The autoencoder in this file stores an instance as ``self.config`` so the
    value can be read as ``model.config.scaling_factor``.
    """

    def __init__(self, scaling_factor):
        # Stored verbatim; no validation or conversion is applied.
        self.scaling_factor = scaling_factor
| |
|
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Result container produced by the decoding path.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        commit_loss (`torch.FloatTensor`, *optional*):
            Optional extra loss slot; left as `None` unless a caller fills it in.
    """

    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None
| |
|
@dataclass
class EncoderOutput(BaseOutput):
    r"""
    Deterministic encoder result that answers `sample()` and `mode()` calls.

    Both accessors hand back the stored tensor unchanged, so code that calls
    `sample()` / `mode()` on an encoder result works without a distribution.

    Args:
        latent (`torch.Tensor`):
            The encoded latent tensor.
    """

    latent: torch.Tensor

    def sample(self):
        """Return the stored latent (no randomness involved)."""
        return self.latent

    def mode(self):
        """Return the stored latent (identical to `sample()`)."""
        return self.latent
| |
|
@dataclass
class MAEOutput(BaseOutput):
    """
    Result container for the autoencoder's encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution` or `EncoderOutput`):
            Either a `DiagonalGaussianDistribution` built from the encoder output, or a
            plain `EncoderOutput` wrapping the latent tensor. Both are used through
            `sample()` / `mode()` in this file.
    """

    latent_dist: Union[DiagonalGaussianDistribution, EncoderOutput]
| |
|
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image and cut out its central ``image_size`` x ``image_size`` square.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Halve with a box filter for as long as the short side is at least twice the target.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Final bicubic resize so the short side lands on image_size.
    ratio = image_size / min(*pil_image.size)
    resized = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(resized, resample=Image.BICUBIC)

    # Crop the centered window on the array (rows first, then columns).
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
| |
|
| | |
| |
|
| |
|
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # Every entry of gamma starts at init_values.
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        """Scale ``x`` by ``gamma``; mutates ``x`` itself when ``inplace`` is set."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
| | |
class Attention(nn.Module):
    """Multi-head self-attention that can also return its pre-softmax scores.

    With ``return_attn_map=True`` the forward pass additionally returns
    ``[qk_attn, x_ctxed]``: a detached copy of the scaled query-key scores and
    the head-merged context tensor before the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, return_attn_map = False):
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # One projection yields q, k and v; split into (3, B, heads, N, head_dim).
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Snapshot the raw scores before softmax/dropout when requested.
        qk_attn = scores.clone().detach() if return_attn_map else None
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge heads back into the channel dimension.
        x_ctxed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(x_ctxed))

        if return_attn_map:
            return out, [qk_attn, x_ctxed]
        return out
| | |
class Block(nn.Module):
    """Transformer block: norm -> attention and norm -> MLP, each added back as a residual.

    ``init_values`` switches on ``LayerScale`` after each sub-layer and
    ``drop_path > 0`` switches on ``DropPath``; otherwise both are ``nn.Identity``.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()

        def build_scale():
            # A falsy init_values (None or 0) disables layer scaling.
            return LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        def build_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Attention sub-layer.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = build_scale()
        self.drop_path1 = build_drop_path()

        # Feed-forward sub-layer.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls2 = build_scale()
        self.drop_path2 = build_drop_path()

    def forward(self, x, return_attn_map = False):
        """Return the block output, plus the attention extras when ``return_attn_map`` is set."""
        normed = self.norm1(x)
        qk_and_x = None
        if return_attn_map:
            attn_out, qk_and_x = self.attn(normed, return_attn_map = True)
        else:
            attn_out = self.attn(normed)

        x = x + self.drop_path1(self.ls1(attn_out))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

        if return_attn_map:
            return x, qk_and_x
        return x
| |
|
class Downsample(nn.Module):
    """Shrink a square token grid with a stride-2 3x3 convolution.

    Args:
        ch_in: channel width of the incoming tokens.
        ch_out: channel width of the outgoing tokens.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, 3, stride=2)

    def forward(self, x):
        """Map tokens ``(B, N, ch_in)`` to ``(B, N', ch_out)``.

        ``N`` must be a perfect square; the tokens are laid out as an
        ``H x H`` grid, padded by one zero row/column on the bottom/right,
        convolved, and flattened back to a token sequence.
        """
        B, N, C = x.shape
        H = int(N ** 0.5)
        assert H * H == N, 'Size mismatch.'
        x = x.reshape(B, H, H, C).permute(0, 3, 1, 2)

        # Pad right and bottom only, so the unpadded stride-2 conv covers the grid.
        x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        x = self.conv(x)

        # Flatten by the conv's real output channels; reshaping with the input
        # channel count C mixed up tokens whenever ch_out != ch_in.
        return x.flatten(2).transpose(1, 2)
| | |
class Upsample(nn.Module):
    """Double the side of a square token grid (nearest-neighbour) and apply a 3x3 conv.

    Args:
        ch_in: channel width of the incoming tokens.
        ch_out: channel width of the outgoing tokens.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, 3, 1, 1)

    def forward(self, x):
        """Map tokens ``(B, N, ch_in)`` to ``(B, 4 * N, ch_out)``; ``N`` must be a perfect square."""
        B, N, C = x.shape
        H = int(N ** 0.5)
        assert H * H == N, 'Size mismatch.'
        x = x.reshape(B, H, H, C).permute(0, 3, 1, 2)

        # NOTE(review): the contiguous() copies for large batches / very large tensors
        # look like a workaround for nearest-neighbour interpolation on non-contiguous
        # inputs -- confirm before removing.
        scale_factor = 2
        if x.shape[0] >= 64 or x.numel() * scale_factor > pow(2, 31):
            x = x.contiguous()

        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        # Flatten by the conv's real output channels; reshaping with the input
        # channel count C mixed up tokens whenever ch_out != ch_in.
        return x.flatten(2).transpose(1, 2)
| | |
class MLP_dim_resize(nn.Module):
    """Linear -> GELU -> Linear stack that changes the feature width of its input."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Kept as a Sequential named ``layers`` so parameter names stay layers.0 / layers.2.
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the three stages in order."""
        return self.layers(x)
| | |
class conv_decoder_pred(nn.Module):
    """Decoder prediction head that involves a convolution.

    Two variants, chosen by ``pred_with_conv``:

    * ``True``: a 2x2 convolution over the token grid maps each token directly
      to ``patch_size**2 * in_chans`` values.
    * ``False``: a linear layer predicts the patch values, the patches are
      assembled into an image, a 3x3 convolution runs over that image, and the
      result is cut back into patches.

    Args:
        decoder_embed_dim: channel width of the incoming decoder tokens.
        patch_size: side length of one patch in pixels.
        in_chans: number of image channels.
        pred_with_conv: selects the variant described above.
    """

    def __init__(self, decoder_embed_dim, patch_size, in_chans, pred_with_conv=True):
        super().__init__()
        self.p = patch_size
        self.pred_with_conv = pred_with_conv
        if self.pred_with_conv:
            print('pred only with conv instead of previous linear')
            self.conv_smoother = nn.Conv2d(decoder_embed_dim, patch_size**2 * in_chans, 2, stride=1, padding=0)
        else:
            print('conv on rgb')
            self.linear_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
            self.conv_smoother = nn.Conv2d(in_chans, in_chans, 3, 1, 1)

    def forward(self, x):
        """Map tokens ``(B, L, decoder_embed_dim)`` to ``(B, L, patch_size**2 * in_chans)``.

        ``L`` must be a perfect square (the tokens form an ``h x w`` grid).
        """
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        B = x.shape[0]

        if self.pred_with_conv:
            x = x.reshape(B, h, w, -1).permute(0, 3, 1, 2)
            # Pad right/bottom by one so the unpadded 2x2 conv keeps the h x w grid.
            x = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
            x = self.conv_smoother(x)
            return x.reshape(B, -1, h * w).permute(0, 2, 1)

        p = self.p
        # In this variant the smoothing conv runs on the image, so its channel
        # count is the image channel count (previously hard-coded to 3).
        c = self.conv_smoother.in_channels

        x = self.linear_pred(x)
        # Patches -> image.
        x = x.reshape(B, h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(B, c, h * p, w * p)

        x = self.conv_smoother(x)

        # Image -> patches.
        x = x.reshape(B, c, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(B, h * w, p * p * c)
| |
|
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone

    Encoder tokens pass through a ``to_latent`` / ``from_latent`` bottleneck of
    width ``latent_dim`` before the decoder. ``forward`` dispatches to
    ``forward_ldmae`` or ``forward_vanilla`` depending on ``ldmae_mode``;
    ``encode`` / ``decode`` expose the latent as a ``(B, C, H, W)`` tensor.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 latent_dim=32, ldmae_mode=False, scaling_factor=0.9654248952865601, no_cls=True,
                 gradual_resol=False, finetune_downsample_layer=None, down_nonlinear=False,
                 kl_loss_weight=None, smooth_output=False, pred_with_conv=False, perceptual_loss=None, perceptual_loss_ratio=1.0,
                 fixed_std=None):
        """Build the encoder, latent bottleneck and decoder.

        Args (only what this constructor itself does with them):
            latent_dim: width of the bottleneck; doubled on the encoder side
                when ``kl_loss_weight`` is set (mean and logvar).
            ldmae_mode: selects ``forward_ldmae`` in ``forward``; no
                ``mask_token`` is created in this mode.
            scaling_factor: stored as ``self.config.scaling_factor``.
            no_cls: when False a class token and one extra position are added.
            gradual_resol: halves ``patch_size`` and inserts one ``Downsample``
                in the encoder and one ``Upsample`` in the decoder.
            finetune_downsample_layer: encoder depth after which the
                ``Downsample`` is inserted (default ``depth // 2``).
            down_nonlinear: use ``MLP_dim_resize`` instead of ``nn.Linear`` for
                the bottleneck projections.
            kl_loss_weight: enables the KL path and weights the KL term.
            smooth_output / pred_with_conv: use ``conv_decoder_pred`` as the
                prediction head instead of a linear layer.
            perceptual_loss / perceptual_loss_ratio: optional callable added to
                the loss, and its weight.
            fixed_std: forwarded to ``DiagonalGaussianDistribution`` in
                ``forward_vanilla``.
        """
        super().__init__()

        self.fixed_std = fixed_std
        self.perceptual_loss = perceptual_loss
        self.perceptual_loss_ratio = perceptual_loss_ratio
        self.smooth_output = smooth_output
        self.gradual_resol = gradual_resol
        self.kl_loss_weight = kl_loss_weight
        encoder_latent_dim = latent_dim
        decoder_latent_dim = latent_dim
        if self.kl_loss_weight is not None:
            assert no_cls, 'There should be no class token to use KL loss.'
            encoder_latent_dim = 2 * latent_dim
            print(f'Use KL loss, encoder latent dim is {encoder_latent_dim} to predict mean & logvar')
        if self.gradual_resol:
            patch_size = patch_size // 2
            print(f'patch size: {patch_size}')

        if down_nonlinear:
            print('Use MLP for latent embedding')
            self.to_latent = MLP_dim_resize(embed_dim, latent_dim*4, encoder_latent_dim)
            self.from_latent = MLP_dim_resize(decoder_latent_dim, latent_dim*4, embed_dim)
        else:
            self.to_latent = nn.Linear(embed_dim, encoder_latent_dim)
            self.from_latent = nn.Linear(decoder_latent_dim, embed_dim)
        self.config = Config(scaling_factor=scaling_factor)
        self.ldmae_mode = ldmae_mode
        self.img_size = img_size
        self.patch_size = patch_size
        self.latent_resolution = img_size // patch_size
        self.tile_latent_min_size = self.latent_resolution
        self.latent_dim = latent_dim
        self.no_cls = no_cls
        # One extra token position is reserved for the class token, if any.
        self.num_extra_tokens = 0 if self.no_cls else 1

        # --------------------------------------------------------------
        # Encoder
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        if not self.no_cls:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) embedding, filled in by initialize_weights().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_extra_tokens, embed_dim), requires_grad=False)

        if self.gradual_resol:
            blocks = []
            downsize_time = depth // 2 if finetune_downsample_layer is None else finetune_downsample_layer
            for i in range(depth):
                blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
                if i == downsize_time-1:
                    print(f"Add downsizing block in {i}th layer in encoder.")
                    blocks.append(Downsample(embed_dim, embed_dim))
            self.blocks = nn.ModuleList(blocks)
        else:
            self.blocks = nn.ModuleList([
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------
        # Decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        if not self.ldmae_mode:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        if self.gradual_resol:
            # The encoder's Downsample halves each grid side, i.e. a quarter of the tokens.
            decoder_num_patches = num_patches // 4
            assert decoder_num_patches * 4 == num_patches
        else:
            decoder_num_patches = num_patches
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, decoder_num_patches + self.num_extra_tokens, decoder_embed_dim), requires_grad=False)

        if self.gradual_resol:
            decoder_blocks = []
            upsize_time = decoder_depth - downsize_time
            for i in range(decoder_depth):
                decoder_blocks.append(Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
                if i == upsize_time-1:
                    print(f"Add upsizing block in {i}th layer in decoder.")
                    decoder_blocks.append(Upsample(decoder_embed_dim, decoder_embed_dim))
            self.decoder_blocks = nn.ModuleList(decoder_blocks)
        else:
            self.decoder_blocks = nn.ModuleList([
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        if smooth_output:
            assert no_cls, 'Should be no CLS token for smooth_output.'
            print('Use conv in decoder pred.')
            self.decoder_pred = conv_decoder_pred(decoder_embed_dim, patch_size, in_chans, pred_with_conv)
        else:
            self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos position embeddings and initialise all weights."""
        with_cls = not self.no_cls

        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=with_cls)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.gradual_resol:
            decoder_num_patches = self.patch_embed.num_patches // 4
        else:
            decoder_num_patches = self.patch_embed.num_patches
        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(decoder_num_patches**.5), cls_token=with_cls)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Initialise the patch-embedding conv like a linear layer (on its flattened weight).
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        if not self.no_cls:
            torch.nn.init.normal_(self.cls_token, std=.02)
        if not self.ldmae_mode:
            torch.nn.init.normal_(self.mask_token, std=.02)

        # Linear and LayerNorm modules everywhere in the model.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit LayerNorm weights."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        imgs: (N, C, H, W) with H == W divisible by the patch size
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        n, c = imgs.shape[0], imgs.shape[1]  # channel count taken from the input (was hard-coded 3)
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C) with L a perfect square
        imgs: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        c = x.shape[2] // (p * p)  # channel count recovered from the patch width (was hard-coded 3)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns (x_masked, mask, ids_restore): the kept tokens [N, len_keep, D],
        a binary mask [N, L] in original token order (0 is keep, 1 is remove),
        and the indices that undo the shuffle.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)

        # Ascending sort: the tokens with the smallest noise are the ones kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order, then unshuffle it.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def _add_pos_and_cls(self, x):
        """Add encoder position embeddings to patch tokens and prepend the cls token, if any."""
        if self.no_cls:
            return x + self.pos_embed
        x = x + self.pos_embed[:, 1:, :]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls_tokens, x), dim=1)

    def _encode_tokens(self, imgs):
        """Unmasked encoder pass: patch embed -> position/cls -> blocks -> norm."""
        x = self._add_pos_and_cls(self.patch_embed(imgs))
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

    def _decoder_trunk(self, x):
        """Shared decoder tail: add position embedding, run blocks, norm, predict, drop cls."""
        x = x + self.decoder_pos_embed
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        if not self.no_cls:
            x = x[:, 1:, :]
        return x

    def _decode_latent_tokens(self, z):
        """Latent tokens -> patch predictions; position embedding excludes the cls slot."""
        x = self.decoder_embed(self.from_latent(z))
        if self.no_cls:
            pos = self.decoder_pos_embed
        else:
            pos = self.decoder_pos_embed[:, 1:, :]
        x = x + pos
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        return self.decoder_pred(x)

    def forward_encoder(self, x, mask_ratio):
        """Encode only the randomly kept patches; returns (tokens, mask, ids_restore)."""
        x = self.patch_embed(x)

        # Position embedding for patch tokens only (cls slot is index 0 when present).
        if self.no_cls:
            x = x + self.pos_embed
        else:
            x = x + self.pos_embed[:, 1:, :]

        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        if not self.no_cls:
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Re-insert mask tokens at the masked positions and decode all patches."""
        x = self.decoder_embed(x)

        # x holds the kept tokens plus the cls token when present, so the number
        # of missing tokens is L + num_extra_tokens - x.shape[1]. (The previous
        # fixed "+ 1" created one unused mask token when there is no cls token.)
        n_missing = ids_restore.shape[1] + self.num_extra_tokens - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], n_missing, 1)
        gather_index = ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        if not self.no_cls:
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
            x_ = torch.gather(x_, dim=1, index=gather_index)
            x = torch.cat([x[:, :1, :], x_], dim=1)
        else:
            x_ = torch.cat([x, mask_tokens], dim=1)
            x = torch.gather(x_, dim=1, index=gather_index)

        return self._decoder_trunk(x)

    def forward_encoder_with_mask(self, x, mask_ratio):
        """Encode the full grid, with masked positions replaced by mask tokens.

        Unlike ``forward_encoder`` the sequence keeps its full length, so the
        encoder blocks see a complete token grid.
        """
        x = self.patch_embed(x)

        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # NOTE(review): self.mask_token is created with decoder_embed_dim channels
        # while x has embed_dim channels here, so this concatenation only works
        # when the two widths are equal.
        n_missing = ids_restore.shape[1] - x.shape[1]  # x never contains a cls token here
        mask_tokens = self.mask_token.repeat(x.shape[0], n_missing, 1)
        x_ = torch.cat([x, mask_tokens], dim=1)
        x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

        x = self._add_pos_and_cls(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder_without_mask(self, x, ids_restore):
        """Decode a full-length token sequence; ``ids_restore`` is accepted but unused."""
        x = self.decoder_embed(x)
        return self._decoder_trunk(x)

    def ldmae_encoding(self, imgs, use_mode=False, return_kl=False):
        """Encode images (no masking) into latent tokens.

        With ``kl_loss_weight`` set, the encoder output parameterises a
        ``DiagonalGaussianDistribution`` and the latent is its mode
        (``use_mode``) or a sample. ``return_kl`` only takes effect on that
        path, where ``(latent, posterior.kl())`` is returned.
        """
        latent = self.to_latent(self._encode_tokens(imgs))

        if self.kl_loss_weight is not None:
            # The distribution is built on a channels-first (B, C, N) layout.
            # NOTE(review): fixed_std is not forwarded here, unlike forward_vanilla -- confirm intent.
            latent = latent.permute(0, 2, 1)
            posterior = DiagonalGaussianDistribution(latent)
            if use_mode:
                print('mode')
                latent = posterior.mode()
            else:
                latent = posterior.sample()
            latent = latent.permute(0, 2, 1)
            if return_kl:
                return latent, posterior.kl()
        return latent

    def ldmae_decoding(self, x):
        """Decode latent tokens into per-patch pixel predictions."""
        x = self._decode_latent_tokens(x)
        if not self.no_cls:
            x = x[:, 1:, :]
        return x

    def reconstruct(self, imgs, use_mode=True, return_kl=False, mask_ratio=0.75):
        """Encode then decode ``imgs`` without masking.

        ``mask_ratio`` is only compared against 0.0: when it is exactly 0.0 the
        encoder runs with autograd enabled and always uses the posterior mode;
        otherwise the encoder runs under ``torch.no_grad()`` and honours
        ``use_mode``. The decoder always runs with autograd enabled.

        Returns the patch predictions, or ``(predictions, kl)`` when
        ``return_kl`` is set.
        """
        if mask_ratio == 0.0:
            # return_kl is forwarded so kl_val is defined on this branch too
            # (it used to be hard-wired to False, which left kl_val undefined).
            encoded = self.ldmae_encoding(imgs, use_mode=True, return_kl=return_kl)
        else:
            with torch.no_grad():
                encoded = self.ldmae_encoding(imgs, use_mode=use_mode, return_kl=return_kl)

        if return_kl:
            x, kl_val = encoded
        else:
            x = encoded

        x = self.ldmae_decoding(x)
        if return_kl:
            return x, kl_val
        return x

    def linear_probe_seg(self, images):
        """Per-token probe: frozen encoding, flattened to (B * N, D), then ``self.head``.

        ``self.head`` is not created in ``__init__``; it has to be attached to
        the model before this method is called.
        """
        with torch.no_grad():
            x = self.ldmae_encoding(images)

        if not self.no_cls:
            x = x[:, 1:, :]

        x = x.reshape(-1, x.shape[-1])

        for layer in self.head:
            x = layer(x)

        return x

    def linear_probe(self, images):
        """Image-level probe: frozen encoding, mean over patch tokens, then ``self.head``.

        ``self.head`` is not created in ``__init__``; it has to be attached to
        the model before this method is called.
        """
        with torch.no_grad():
            x = self.ldmae_encoding(images)

        if self.no_cls:
            x = x.mean(dim=1)
        else:
            x = x[:, 1:, :].mean(dim=1)

        for layer in self.head:
            x = layer(x)
        return x

    def forward_loss(self, imgs, pred, mask, visible_loss_ratio=0.5):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove,

        Returns (loss, visible_loss, mask_loss, p_loss) where ``loss`` blends the
        mean squared error on visible and masked patches by ``visible_loss_ratio``
        and adds the weighted perceptual term when one is configured.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        per_patch = ((pred - target) ** 2).mean(dim=-1)  # [N, L], mean loss per patch

        visible = 1 - mask
        # For a 0/1 mask the denominators are patch counts; clamping at 1 only
        # changes the empty case (e.g. mask_ratio == 0), which now yields 0 instead of NaN.
        visible_loss = (per_patch * visible).sum() / visible.sum().clamp(min=1)
        mask_loss = (per_patch * mask).sum() / mask.sum().clamp(min=1)

        loss = (1 - visible_loss_ratio) * mask_loss + visible_loss_ratio * visible_loss

        if self.perceptual_loss is not None:
            p_loss = self.perceptual_loss(
                imgs.contiguous(),
                self.unpatchify(pred).contiguous()
            )
            p_loss = p_loss.mean()
            loss = loss + self.perceptual_loss_ratio * p_loss
        else:
            p_loss = torch.zeros_like(loss)

        return loss, visible_loss, mask_loss, p_loss

    def forward_vanilla(self, imgs, mask_ratio=0.75, visible_loss_ratio=0.5):
        """Masked-autoencoding training pass.

        Returns ``(loss, pred, mask, vis_loss, mask_loss, kl_loss, p_loss)``;
        ``kl_loss`` is ``None`` when ``kl_loss_weight`` is not set.
        """
        if self.gradual_resol:
            latent, mask, ids_restore = self.forward_encoder_with_mask(imgs, mask_ratio)
        else:
            latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)

        latent = self.to_latent(latent)

        kl_loss = None
        if self.kl_loss_weight is not None:
            num_tokens = latent.shape[1]
            latent = latent.permute(0, 2, 1)
            posterior = DiagonalGaussianDistribution(latent, fixed_std=self.fixed_std)

            # KL summed over everything, then averaged over batch entries and tokens.
            kl_loss = posterior.kl()
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] / num_tokens

            latent = posterior.sample()
            latent = latent.permute(0, 2, 1)

        latent = self.from_latent(latent)

        if self.gradual_resol:
            pred = self.forward_decoder_without_mask(latent, ids_restore)
        else:
            pred = self.forward_decoder(latent, ids_restore)

        loss, vis_loss, mask_loss, p_loss = self.forward_loss(imgs, pred, mask, visible_loss_ratio)
        if kl_loss is not None:
            loss = loss + self.kl_loss_weight * kl_loss
        return loss, pred, mask, vis_loss, mask_loss, kl_loss, p_loss

    def forward_ldmae(self, imgs, mask_ratio=0.75):
        """Reconstruction training pass without masking.

        Returns ``(loss, pred, None, vis_loss, p_loss, None)``.
        """
        pred = self.reconstruct(imgs, use_mode=False, mask_ratio=mask_ratio)

        recon = self.unpatchify(pred)  # computed once, reused for both loss terms
        vis_loss = (recon - imgs) ** 2

        if self.perceptual_loss is not None:
            p_loss = self.perceptual_loss(
                imgs.contiguous(),
                recon.contiguous()
            )
            loss = vis_loss + self.perceptual_loss_ratio * p_loss
        else:
            loss = vis_loss
            p_loss = torch.zeros_like(vis_loss)

        loss = loss.mean()
        return loss, pred, None, vis_loss.mean(), p_loss.mean(), None

    def forward(self, imgs, mask_ratio=0.75, visible_loss_ratio=0.5):
        """Dispatch to ``forward_ldmae`` or ``forward_vanilla`` based on ``ldmae_mode``."""
        if self.ldmae_mode:
            return self.forward_ldmae(imgs, mask_ratio=mask_ratio)
        return self.forward_vanilla(imgs, mask_ratio=mask_ratio, visible_loss_ratio=visible_loss_ratio)

    def _latent_grid(self, imgs):
        """Encode ``imgs`` and lay the latent tokens out as ``(B, C, H, W)``.

        NOTE(review): h/w come from ``self.latent_resolution``, which is computed
        from the patch grid *before* any Downsample block; with ``gradual_resol``
        the token count is a quarter of that grid -- confirm this path is only
        used without ``gradual_resol``.
        """
        x = self.to_latent(self._encode_tokens(imgs))
        return rearrange(x, 'b (h w) c -> b c h w', h=self.latent_resolution, w=self.latent_resolution)

    def _encode(self, x):
        """Return the raw ``(B, C, H, W)`` encoder output (no distribution wrapper)."""
        return self._latent_grid(x)

    def encode_gflops(self, x, return_dict=True):
        """Same computation as ``_encode``; ``return_dict`` is accepted but unused."""
        return self._latent_grid(x)

    def encode(self, x, return_dict=True):
        """Encode images into a latent distribution object.

        The ``(B, C, H, W)`` encoder output is wrapped in a
        ``DiagonalGaussianDistribution`` when ``kl_loss_weight`` is set, else in
        an ``EncoderOutput``. Returns ``MAEOutput(latent_dist=...)``, or a
        one-element tuple when ``return_dict`` is False.
        """
        x = self._latent_grid(x)

        if self.kl_loss_weight is not None:
            p = DiagonalGaussianDistribution(x)
        else:
            p = EncoderOutput(x)
        if not return_dict:
            return (p,)

        return MAEOutput(latent_dist=p)

    def decode(self, z, return_dict=True, generator=None):
        """Decode a ``(B, C, H, W)`` latent into an image tensor.

        ``generator`` is accepted but unused. Returns
        ``DecoderOutput(sample=img)``, or a one-element tuple when
        ``return_dict`` is False.
        """
        z = rearrange(z, 'b c h w -> b (h w) c')
        img = self.unpatchify(self._decode_latent_tokens(z))
        if not return_dict:
            return (img,)

        return DecoderOutput(sample=img)

    def _first_tensor(self):
        """Return the first parameter, else the first buffer, else ``None``."""
        first = next(self.parameters(), None)
        if first is None:
            first = next(self.buffers(), None)
        return first

    @property
    def device(self) -> torch.device:
        """
        Returns:
            torch.device: The torch device on which the model's parameters are located.
        """
        first = self._first_tensor()
        return first.device if first is not None else torch.device("cpu")

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the model will be executed; here this is simply `self.device`.
        """
        return self.device

    @property
    def dtype(self) -> torch.dtype:
        """
        Returns:
            torch.dtype: The torch data type on which the model's parameters are stored.
        """
        first = self._first_tensor()
        return first.dtype if first is not None else torch.float32

    def img_transform(self, p_hflip=0, img_size=None):
        """Image preprocessing transforms
        Args:
            p_hflip: Probability of horizontal flip
            img_size: Target image size, use default if None
        Returns:
            transforms.Compose: Image transform pipeline
        """
        img_size = img_size if img_size is not None else self.img_size
        img_transforms = [
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, img_size)),
            transforms.RandomHorizontalFlip(p=p_hflip),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ]
        return transforms.Compose(img_transforms)

    def encode_images(self, images):
        """Encode images to latent representations
        Args:
            images: Input image tensor (moved to the model's device)
        Returns:
            torch.Tensor: Encoded latent representation
        """
        with torch.no_grad():
            # Follow the model's own device instead of assuming CUDA.
            posterior = self.encode(images.to(self.device), return_dict=False)[0]
            return posterior.sample()

    def decode_to_images(self, z):
        """Decode latent representations to images
        Args:
            z: Latent representation tensor (moved to the model's device)
        Returns:
            np.ndarray: Decoded uint8 image array in (B, H, W, C) layout
        """
        with torch.no_grad():
            # Follow the model's own device instead of assuming CUDA.
            images = self.decode(z.to(self.device), return_dict=False)[0]
            # Scale by 127.5, offset by 128 and clamp to the uint8 range.
            images = torch.clamp(127.5 * images + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        return images
| | |
| | |
| |
|
def mae_for_ldmae(**kwargs):
    """Build a 128px, patch-8 model: width 192, 12 heads, 12 encoder and 12 decoder layers, latent_dim 32.

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        img_size=128,
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        latent_dim=32,
        **kwargs,
    )
| |
|
| | |
def mae_for_ldmae_f8d32(**kwargs):
    """Build a 128px, patch-8 model with latent_dim 32 (width 192, 12 heads, 12+12 layers).

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        img_size=128,
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        latent_dim=32,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d16_prev(**kwargs):
    """Build a patch-8 model with latent_dim 16 and a symmetric width-192, 12-head encoder/decoder.

    ``img_size`` is left at the class default. Extra keyword arguments are
    forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        latent_dim=16,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d16_small(**kwargs):
    """Build a patch-8 model with latent_dim 16 and a symmetric width-96, 8-head encoder/decoder.

    ``img_size`` is left at the class default. Extra keyword arguments are
    forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=96,
        depth=12,
        num_heads=8,
        decoder_embed_dim=96,
        decoder_depth=12,
        decoder_num_heads=8,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        latent_dim=16,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d16_asym_small(**kwargs):
    """Build a patch-8 model with latent_dim 16: width-96, 8-head encoder and width-192, 12-head decoder.

    ``img_size`` is left at the class default. Extra keyword arguments are
    forwarded to ``MaskedAutoencoderViT``.
    """
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=96,
        depth=12,
        num_heads=8,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        latent_dim=16,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d16_prev_large(**kwargs):
    """Build a wider ``MaskedAutoencoderViT`` with ``latent_dim=16``.

    Fixed settings: ``patch_size=8``; encoder ``embed_dim=384``, ``depth=12``,
    ``num_heads=16``; decoder ``decoder_embed_dim=384``, ``decoder_depth=12``,
    ``decoder_num_heads=16``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``. ``img_size`` is not set here, so it may be supplied through
    ``kwargs``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=384,
        depth=12,
        num_heads=16,
        decoder_embed_dim=384,
        decoder_depth=12,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=16,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d16(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with ``latent_dim=16`` and ``down_nonlinear=True``.

    Fixed settings: ``patch_size=8``; encoder ``embed_dim=192``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=384``, ``decoder_depth=12``,
    ``decoder_num_heads=24``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``. ``img_size`` is not set here, so it may be supplied through
    ``kwargs``. These are the same values used by
    ``mae_for_ldmae_f8d16_flexible``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here (including
            ``down_nonlinear``) raises ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=384,
        decoder_depth=12,
        decoder_num_heads=24,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=16,
        down_nonlinear=True,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d16_flexible(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with ``latent_dim=16`` and ``down_nonlinear=True``.

    Fixed settings: ``patch_size=8``; encoder ``embed_dim=192``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=384``, ``decoder_depth=12``,
    ``decoder_num_heads=24``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``. ``img_size`` is not set here, so it may be supplied through
    ``kwargs``. These are the same values used by ``mae_for_ldmae_f8d16``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here (including
            ``down_nonlinear``) raises ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=384,
        decoder_depth=12,
        decoder_num_heads=24,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=16,
        down_nonlinear=True,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f16d32(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with ``patch_size=16`` and ``latent_dim=32``.

    Fixed settings: ``img_size=128``; encoder ``embed_dim=192``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=192``, ``decoder_depth=12``,
    ``decoder_num_heads=12``; ``mlp_ratio=4``; ``LayerNorm`` with ``eps=1e-6``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        img_size=128,
        patch_size=16,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=32,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f16d32_large(**kwargs):
    """Build a wider ``patch_size=16`` ``MaskedAutoencoderViT`` with ``latent_dim=32``.

    Fixed settings: ``img_size=128``; encoder ``embed_dim=384``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=384``, ``decoder_depth=12``,
    ``decoder_num_heads=12``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``; ``finetune_downsample_layer=4``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here (including
            ``finetune_downsample_layer``) raises ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        img_size=128,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=12,
        decoder_embed_dim=384,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=32,
        finetune_downsample_layer=4,
        **kwargs,
    )
| |
|
def mae_for_ldmae_f8d32_flexible(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with ``latent_dim=32`` and no fixed ``img_size``.

    Fixed settings: ``patch_size=8``; encoder ``embed_dim=192``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=192``, ``decoder_depth=12``,
    ``decoder_num_heads=12``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``. ``img_size`` is not set here, so it may be supplied through
    ``kwargs``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=32,
        **kwargs,
    )
| |
|
def mae_for_ldmae_16d(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with ``img_size=128`` and ``latent_dim=16``.

    Fixed settings: ``patch_size=8``; encoder ``embed_dim=192``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=192``, ``decoder_depth=12``,
    ``decoder_num_heads=12``; ``mlp_ratio=4``; ``LayerNorm`` with ``eps=1e-6``.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        img_size=128,
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=12,
        decoder_embed_dim=192,
        decoder_depth=12,
        decoder_num_heads=12,
        mlp_ratio=4,
        norm_layer=layer_norm,
        latent_dim=16,
        **kwargs,
    )
| |
|
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with a 768-wide encoder and 512-wide decoder.

    Fixed settings: ``patch_size=16``; encoder ``embed_dim=768``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=512``, ``decoder_depth=8``,
    ``decoder_num_heads=16``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``. Neither ``img_size`` nor ``latent_dim`` is set here.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
| |
|
def mae_vit_base_patch16_dec128d8b(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with a 768-wide encoder and 128-wide decoder.

    Fixed settings: ``patch_size=16``; encoder ``embed_dim=768``, ``depth=12``,
    ``num_heads=12``; decoder ``decoder_embed_dim=128``, ``decoder_depth=8``,
    ``decoder_num_heads=16``; ``mlp_ratio=4``; ``LayerNorm`` with
    ``eps=1e-6``. Neither ``img_size`` nor ``latent_dim`` is set here.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=128,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
| |
|
| |
|
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with a 1024-wide, depth-24 encoder.

    Fixed settings: ``patch_size=16``; encoder ``embed_dim=1024``,
    ``depth=24``, ``num_heads=16``; decoder ``decoder_embed_dim=512``,
    ``decoder_depth=8``, ``decoder_num_heads=16``; ``mlp_ratio=4``;
    ``LayerNorm`` with ``eps=1e-6``. Neither ``img_size`` nor ``latent_dim``
    is set here.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
| |
|
| |
|
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build a ``MaskedAutoencoderViT`` with a 1280-wide, depth-32 encoder.

    Fixed settings: ``patch_size=14``; encoder ``embed_dim=1280``,
    ``depth=32``, ``num_heads=16``; decoder ``decoder_embed_dim=512``,
    ``decoder_depth=8``, ``decoder_num_heads=16``; ``mlp_ratio=4``;
    ``LayerNorm`` with ``eps=1e-6``. Neither ``img_size`` nor ``latent_dim``
    is set here.

    Args:
        **kwargs: Additional keyword arguments forwarded to
            ``MaskedAutoencoderViT``. Repeating a name fixed here raises
            ``TypeError``.
    Returns:
        The constructed ``MaskedAutoencoderViT`` instance.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
| |
|
| |
|
| | |
# Shorter names for the ViT factory functions defined above. Each alias is
# bound to the very same function object, so calling an alias forwards its
# keyword arguments exactly as the long-named factory does.
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder_embed_dim=512, decoder_depth=8
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder_embed_dim=512, decoder_depth=8
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder_embed_dim=512, decoder_depth=8
mae_vit_base_patch16_128 = mae_vit_base_patch16_dec128d8b  # decoder_embed_dim=128, decoder_depth=8
| |
|