from ast import Module from bdb import effective from cProfile import label from functools import partial from black import Mode from matplotlib.pyplot import grid from dataclasses import dataclass import warnings import math from copy import deepcopy from typing import Union, Tuple, Sequence, Optional, List, Callable, Set from einops import rearrange import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import trunc_normal_ from torch import Tensor from torch.nn.init import _calculate_fan_in_and_fan_out from torchvision.utils import save_image from transformers.activations import PytorchGELUTanh, ACT2FN from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.integrations import use_kernel_forward_from_hub from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, BaseModelOutputWithPooling, BaseModelOutput from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import ( ModelOutput, can_return_tuple, logging, replace_return_docstrings, torch_int, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, ) if is_flash_attn_2_available(): from transformers.modeling_flash_attention_utils import _flash_attention_forward ViT_Attention_type = "flash_attention_2" else: flash_attn_varlen_func = None ViT_Attention_type = "sdpa" from transformers.activations import ACT2FN from transformers.processing_utils import Unpack from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin from .configuration_umm import Qwen2Config, Siglip2VisionConfig, UMMConfig logger = logging.get_logger(__name__) from torch.nn.attention.flex_attention import flex_attention, create_block_mask from torch.nn.attention.flex_attention import BlockMask IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 IMAGE_GEN_TOKEN_INDEX = -300 IM_START_ID = 151667 IM_END_ID = 151668 NO_MEAN_ID = 151669 EOS_TOKEN_ID = 151645 @dataclass class Siglip2VisionOutput(ModelOutput): """ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. Args: image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): The image embeddings obtained by applying the projection layer to the pooler_output. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ image_embeds: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None class Siglip2VisionEmbeddings(nn.Module): def __init__(self, config: Siglip2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.image_size = config.image_size self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing and no class embeddings. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] num_positions = self.position_embedding.weight.shape[0] # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding(self.position_ids) patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: _, _, height, width = pixel_values.shape target_dtype = self.patch_embedding.weight.dtype patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] embeddings = patch_embeds.flatten(2).transpose(1, 2) if interpolate_pos_encoding: embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) else: embeddings = embeddings + self.position_embedding(self.position_ids) return embeddings def eager_attention_forward_siglip2( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Siglip2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Union[Siglip2VisionConfig]): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout self.is_causal = False # self.is_causal = True self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" batch_size, seq_length, embed_dim = hidden_states.shape queries = self.q_proj(hidden_states) keys = self.k_proj(hidden_states) values = self.v_proj(hidden_states) queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) attention_interface: Callable = eager_attention_forward_siglip2 if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and output_attentions: logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[ViT_Attention_type] attn_output, attn_weights = attention_interface( self, queries, keys, values, attention_mask, is_causal=self.is_causal, scaling=self.scale, dropout=0.0 if not self.training else self.dropout, ) attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights class Siglip2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states class Siglip2EncoderLayer(nn.Module): def __init__(self, config: Union[Siglip2VisionConfig]): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = Siglip2Attention(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Siglip2MLP(config) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class Siglip2Encoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Siglip2EncoderLayer`]. Args: config: Siglip2Config """ def __init__(self, config): super().__init__() self.config = config self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False # Ignore copy @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) class Siglip2VisionTransformer(nn.Module): def __init__(self, config: Siglip2VisionConfig): super().__init__() self.config = config embed_dim = config.hidden_size self.embeddings = Siglip2VisionEmbeddings(config) self.encoder = Siglip2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head if self.use_head: self.head = Siglip2MultiheadAttentionPoolingHead(config) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @can_return_tuple @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig) def forward( self, pixel_values: torch.FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, ) -> BaseModelOutputWithPooling: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooler_output = self.head(last_hidden_state) if self.use_head else None return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsequently scaled and shifted by the mean and std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") class Siglip2MultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__(self, config: Siglip2VisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = Siglip2MLP(config) self.num_heads = config.num_attention_heads def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) if attention_mask is not None: target_len, source_len = probe.shape[1], hidden_state.shape[1] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len) attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1) attention_mask = attention_mask.reshape(-1, target_len, source_len) hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] def swish(x: Tensor) -> Tensor: return x * torch.sigmoid(x) class AttnBlock(nn.Module): def __init__(self, in_channels: int): super().__init__() self.in_channels = in_channels self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) def attention(self, h_: Tensor) -> Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) b, c, h, w = q.shape q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() h_ = nn.functional.scaled_dot_product_attention(q, k, v) return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) def forward(self, x: Tensor) -> Tensor: return x + self.proj_out(self.attention(x)) class ResnetBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h = x h = self.norm1(h) h = swish(h) h = self.conv1(h) h = self.norm2(h) h = swish(h) h = self.conv2(h) if self.in_channels != self.out_channels: x = self.nin_shortcut(x) return x + h class Downsample(nn.Module): def __init__(self, in_channels: int): super().__init__() # no asymmetric padding in torch conv, must do it ourselves self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x: Tensor): pad = (0, 1, 0, 1) x = nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x class Upsample(nn.Module): def __init__(self, in_channels: int): super().__init__() self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: Tensor): x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") x = self.conv(x) return x class Encoder(nn.Module): def __init__( self, config, ): super().__init__() self.quant_conv = torch.nn.Conv2d(2 * config.z_channels, 2 * config.z_channels, 1) self.ch = config.ch self.num_resolutions = len(config.ch_mult) self.num_res_blocks = config.num_res_blocks self.resolution = config.resolution self.in_channels = config.in_channels # downsampling self.conv_in = nn.Conv2d(config.in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = config.resolution in_ch_mult = (1,) + tuple(config.ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() block_in = self.ch for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = config.ch * in_ch_mult[i_level] block_out = config.ch * config.ch_mult[i_level] for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, 2 * config.z_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: Tensor) -> Tensor: # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1]) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # end h = self.norm_out(h) h = swish(h) h = self.conv_out(h) h = self.quant_conv(h) return h class Decoder(nn.Module): def __init__( self, config ): super().__init__() self.post_quant_conv = torch.nn.Conv2d(config.z_channels, config.z_channels, 1) self.ch = config.ch self.num_resolutions = len(config.ch_mult) self.num_res_blocks = config.num_res_blocks self.resolution = config.resolution self.in_channels = config.in_channels self.ffactor = 2 ** (self.num_resolutions - 1) # compute in_ch_mult, block_in and curr_res at lowest res block_in = config.ch * config.ch_mult[self.num_resolutions - 1] curr_res = config.resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, config.z_channels, curr_res, curr_res) # z to block_in self.conv_in = nn.Conv2d(config.z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = config.ch * config.ch_mult[i_level] for _ in range(self.num_res_blocks + 1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, config.out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z: Tensor) -> Tensor: z = self.post_quant_conv(z) # get dtype for proper tracing upscale_dtype = next(self.up.parameters()).dtype # z to block_in h = self.conv_in(z) # middle h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # cast to proper dtype h = h.to(upscale_dtype) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end h = self.norm_out(h) h = swish(h) h = self.conv_out(h) return h class Qwen2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer and "Generating Long Sequences with Sparse Transformers". """ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from Qwen2Attention.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " "removed and `position_embeddings` will be mandatory." ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask # if attention_mask is not None: # no matter the length, we just slice it # causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal = True if causal_mask is None and q_len > 1 else False query_states = query_states.to(value_states.dtype) key_states = key_states.to(value_states.dtype) if type(attention_mask) == BlockMask: attn_output = flex_attention(query_states, key_states, value_states, block_mask=attention_mask) else: causal_mask = causal_mask.to(torch.bool).to(value_states.device) attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value QWEN2_ATTENTION_CLASSES = { "eager": Qwen2Attention, "sdpa": Qwen2SdpaAttention, } class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen2DecoderLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config, device=None): super().__init__() if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @dataclass class UMMCausalLMOutput(CausalLMOutputWithPast): image_latent: Optional[torch.FloatTensor] = None image_latent_label: Optional[torch.FloatTensor] = None image_pixcel: Optional[torch.FloatTensor] = None image_pixcel_label: Optional[torch.FloatTensor] = None text_loss: Optional[torch.FloatTensor] = None image_loss: Optional[torch.FloatTensor] = None time_embeds_1: Optional[torch.FloatTensor] = None time_embeds_2: Optional[torch.FloatTensor] = None class UMMPretrainedModel(PreTrainedModel): config: UMMConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer", "Siglip2EncoderLayer", "Encoder", "Decoder"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _can_compile_fullgraph = True _supports_attention_backend = True def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) class VAEModel(UMMPretrainedModel): def __init__(self, config): super().__init__(config) self.encoder = Encoder(config.vae_encoder_config) self.decoder = Decoder(config.vae_decoder_config) self.bn_eps = 1e-4 self.bn_momentum = 0.1 self.ps = [2, 2] self.bn = torch.nn.BatchNorm2d( math.prod(self.ps) * config.vae_encoder_config.z_channels, eps=self.bn_eps, momentum=self.bn_momentum, affine=False, track_running_stats=True, ) self.post_init() def normalize(self, z): self.bn.eval() return self.bn(z) def inv_normalize(self, z): self.bn.eval() s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps) m = self.bn.running_mean.view(1, -1, 1, 1) return z * s + m def encode(self, x: Tensor) -> Tensor: moments = self.encoder(x) mean = torch.chunk(moments, 2, dim=1)[0] z = rearrange( mean, "... c (i pi) (j pj) -> ... (c pi pj) i j", pi=self.ps[0], pj=self.ps[1], ) z = self.normalize(z) return z def decode(self, z: Tensor) -> Tensor: z = self.inv_normalize(z) z = rearrange( z, "... (c pi pj) i j -> ... c (i pi) (j pj)", pi=self.ps[0], pj=self.ps[1], ) dec = self.decoder(z) return dec class VAEDecoderProjector(nn.Module): def __init__(self, config): super().__init__() self.decoder = Decoder(config.vae_decoder_config) self.bn_eps = 1e-4 self.bn_momentum = 0.1 self.ps = [2, 2] self.bn = torch.nn.BatchNorm2d( math.prod(self.ps) * config.vae_encoder_config.z_channels, eps=self.bn_eps, momentum=self.bn_momentum, affine=False, track_running_stats=True, ) def inv_normalize(self, z): self.bn.eval() s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps) m = self.bn.running_mean.view(1, -1, 1, 1) return z * s + m def forward(self, z: Tensor) -> Tensor: z = self.inv_normalize(z) z = rearrange( z, "... (c pi pj) i j -> ... c (i pi) (j pj)", pi=self.ps[0], pj=self.ps[1], ) dec = self.decoder(z) return dec class UMMUndProjector(nn.Module): def __init__( self, embed_dim, image_embed_dim, compression_factor=(2,2), ): super().__init__() self.embed_dim = embed_dim self.image_embed_dim = image_embed_dim self.layernorm = nn.LayerNorm(image_embed_dim) self.hidden_size = image_embed_dim * (compression_factor[0]*compression_factor[1]) self.mlp = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.GELU(), nn.Linear(self.hidden_size, embed_dim), ) self.compression_factor = compression_factor def forward(self, x): x = x.to(torch.bfloat16) x = self.layernorm(x) height, width = int(x.size(1)**0.5), int(x.size(1)**0.5) x = x.permute(0, 2, 1).unflatten(-1, (height, width)) # b, dim, h, w batch_size, dim, height, width = x.shape unfolded = x.unfold(2, self.compression_factor[0], self.compression_factor[0]).unfold(3, self.compression_factor[1], self.compression_factor[1]) unfolded = unfolded.contiguous().view(batch_size, dim, -1, self.compression_factor[0] * self.compression_factor[1]) unfolded = unfolded.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, dim*self.compression_factor[0] * self.compression_factor[1]) compressed_x = self.mlp(unfolded) return compressed_x class DiTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( self, hidden_size, layer_idx, num_attention_heads=32, num_key_value_heads=8, attention_dropout=0.0, attn_implementation='sdpa' ): super().__init__() self.layer_idx = layer_idx self.head_dim = hidden_size // num_attention_heads self.num_key_value_groups = num_attention_heads // num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = attention_dropout self.attn_implementation = attn_implementation self.is_causal = False self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # attention_interface = ALL_ATTENTION_FUNCTIONS[self.attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[ViT_Attention_type] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, is_causal=self.is_causal, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def modulate(x, shift, scale): input_dtype = x.dtype x = x.to(torch.float32) shift = shift.to(torch.float32) scale = scale.to(torch.float32) if len(x.shape) != len(shift.shape): return (x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)).to(input_dtype) else: return (x * (1 + scale) + shift).to(input_dtype) class MLP(nn.Module): def __init__(self, hidden_size, intermediate_size): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = nn.SiLU() def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class ModulatedAttentionBlock(nn.Module): def __init__( self, hidden_size, layer_idx, num_attention_heads=32, num_key_value_heads=8, attention_dropout=0.0, ): super().__init__() self.hidden_size = hidden_size self.attn = DiTAttention( hidden_size=hidden_size, layer_idx=layer_idx, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, attention_dropout=attention_dropout, ) self.mlp = MLP(self.hidden_size, self.hidden_size*4) self.input_layernorm = RMSNorm(self.hidden_size, eps=1e-6) self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=1e-6) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear( self.hidden_size, 6*self.hidden_size, bias=True, ), ) nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) def forward( self, hidden_states, adaln_input, attention_mask=None, past_key_value=None, output_attentions=False, cache_position=None, position_embeddings=None, **kwargs, ): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(6, dim=-1) residual = hidden_states hidden_states = modulate(self.input_layernorm(hidden_states), shift_msa, scale_msa) hidden_states, self_attn_weights = self.attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, past_key_value=past_key_value, cache_position=cache_position, **kwargs, ) hidden_states = residual + gate_msa * hidden_states residual = hidden_states hidden_states = modulate(self.post_attention_layernorm(hidden_states), shift_mlp, scale_mlp) hidden_states = self.mlp(hidden_states) hidden_states = residual + gate_mlp * hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class DiTRotaryEmbedding(nn.Module): def __init__( self, dim=None, max_position_embeddings=2048, base=10000, scaling_factor=1.0,): super().__init__() self.rope_kwargs={ "rope_type": "default", "factor": scaling_factor, "dim": dim, "base": base, "max_position_embeddings": max_position_embeddings, } self.rope_type = "default" self.max_seq_len_cached = max_position_embeddings self.original_max_seq_len = max_position_embeddings self.config = None self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, None, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class FinalLayer(nn.Module): def __init__(self, hidden_size, patch_size, out_channels): super().__init__() self.norm_final = RMSNorm(hidden_size) self.linear = nn.Linear(hidden_size, patch_size*patch_size*out_channels, bias=True) self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, 2*hidden_size, bias=True) ) def forward(self, x, adaln_input): shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=-1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x class UMMGenProjector(nn.Module): def __init__( self, embed_dim, num_attention_heads, num_key_value_heads, patch_size, output_dim, layers_num, ): super().__init__() self.diffusion_head_a = nn.ModuleList( [ModulatedAttentionBlock(hidden_size=embed_dim, layer_idx=layer_idx, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads) for layer_idx in range(layers_num)] ) self.diffusion_head_b = FinalLayer(hidden_size=embed_dim, patch_size=patch_size, out_channels=output_dim) self.rotary_emb = DiTRotaryEmbedding( dim=embed_dim // num_attention_heads, max_position_embeddings=2048, base=10000, scaling_factor=1.0, ) def forward(self, x, time_embeds, position_ids=None): if position_ids is None: position_ids = torch.arange(x.shape[1], device=x.device). unsqueeze(0) position_embeddings = self.rotary_emb(x, position_ids) hidden_states = x for layer in self.diffusion_head_a: hidden_states = layer( hidden_states=hidden_states, position_embeddings=position_embeddings, adaln_input=time_embeds, position_ids=position_ids, )[0] v_pred = self.diffusion_head_b(hidden_states, time_embeds) return v_pred class UMMGenHiProjector(nn.Module): def __init__( self, embed_dim, num_attention_heads, num_key_value_heads, patch_size, output_dim, layers_num, ): super().__init__() self.diffusion_head_a = nn.ModuleList( [ModulatedAttentionBlock(hidden_size=embed_dim, layer_idx=layer_idx, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads) for layer_idx in range(layers_num)] ) self.diffusion_head_b = FinalLayer(hidden_size=embed_dim, patch_size=patch_size, out_channels=output_dim) self.rotary_emb = DiTRotaryEmbedding( dim=embed_dim // num_attention_heads, max_position_embeddings=2048, base=10000, scaling_factor=1.0, ) def forward(self, x, time_embeds): position_ids = torch.arange(x.shape[1] // 2, device=x.device). unsqueeze(0) cos, sin = self.rotary_emb(x, position_ids) cos = cos.repeat(1, 2, 1) sin = sin.repeat(1, 2, 1) position_embeddings = (cos, sin) hidden_states = x for layer in self.diffusion_head_a: hidden_states = layer( hidden_states=hidden_states, position_embeddings=position_embeddings, adaln_input=time_embeds, position_ids=position_ids, )[0] v_pred = self.diffusion_head_b(hidden_states, time_embeds) return v_pred class TimestepEmbedder(nn.Module): def __init__(self, hidden_size_1, hidden_size_2, frequency_embedding_size=256): super().__init__() self.mlp_1 = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size_1, bias=True), nn.SiLU(), nn.Linear(hidden_size_1, hidden_size_1, bias=True), ) self.mlp_2 = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size_2, bias=True), nn.SiLU(), nn.Linear(hidden_size_2, hidden_size_2, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t, dtype): t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) t_emb_1 = self.mlp_1(t_freq) t_emb_2 = self.mlp_2(t_freq) return t_emb_1, t_emb_2 class HiGate(nn.Module): def __init__(self, embed_dim): super().__init__() self.gate = nn.Sequential( nn.Linear(embed_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim,1), ) self.layer_norm = nn.LayerNorm(embed_dim) nn.init.zeros_(self.gate[-1].weight) nn.init.zeros_(self.gate[-1].bias) def forward(self, low_info, high_info, heat_map=False): hi_gate = self.gate(low_info) output = low_info + hi_gate * high_info output = self.layer_norm(output) if heat_map: hi_gate_map = hi_gate[0, :, 0] seq_len = hi_gate_map.shape[0] H=W=int(math.sqrt(seq_len)) heat_map = hi_gate_map.view(H, W).detach().float().cpu().numpy() return output, heat_map return output class UMMTextModel(UMMPretrainedModel): config: Qwen2Config _no_split_modules = ["Qwen2DecoderLayer"] def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def get_input_embeddings(self): return self.embed_tokens def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> BaseModelOutputWithPast: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = attention_mask hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **kwargs), hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) class UMMModel(UMMPretrainedModel): config_class = UMMConfig _no_split_modules = [ "Encoder", "Decoder", "Siglip2EncoderLayer", "UMMUndProjector", "UMMGenProjector", "UMMGenHiProjector", "Qwen2DecoderLayer" ] def __init__(self, config): super().__init__(config) self.config = config self.vision_representation = Siglip2VisionTransformer(config.vision_representation_config) self.vae_model = VAEModel(config) self.vae_decoder_projector = VAEDecoderProjector(config) text_embed_dim = config.text_config.hidden_size image_embed_dim = config.vision_representation_config.hidden_size t_num_attention_heads = config.text_config.num_attention_heads t_num_key_value_heads = config.text_config.num_key_value_heads i_num_attention_heads = config.vision_representation_config.num_attention_heads i_num_key_value_heads = config.text_config.num_key_value_heads self.time_embed = TimestepEmbedder(hidden_size_1=text_embed_dim, hidden_size_2=image_embed_dim) self.und_projector = UMMUndProjector(embed_dim=text_embed_dim, image_embed_dim=image_embed_dim) self.gen_projector = UMMGenProjector( embed_dim=text_embed_dim, num_attention_heads=t_num_attention_heads, num_key_value_heads=t_num_key_value_heads, patch_size=2, output_dim=image_embed_dim, layers_num=7, ) self.hi_gate = HiGate(embed_dim=image_embed_dim) self.hi_projector = UMMGenHiProjector( embed_dim=image_embed_dim, num_attention_heads=i_num_attention_heads, num_key_value_heads=i_num_key_value_heads, patch_size=1, output_dim=config.vae_decoder_config.ch, layers_num=3, ) self.language_model = UMMTextModel._from_config(config.text_config) self.post_init() def path_sample(self, t, x0, x1): dims = [1] * len(x1[0].size()) t = t.view(t.size(0), *dims) alpha_t, d_alpha_t = t, 1 sigma_t, d_sigma_t = 1-t, -1 xt = alpha_t * x1 + sigma_t * x0 ut = d_alpha_t * x1 + d_sigma_t * x0 mask = (t < 1).float().to(device=ut.device, dtype=ut.dtype) ut = ut * mask return xt, ut def get_image_features(self, pixel_values, t=None): pixel_values = pixel_values.to(self.vae_model.encoder.quant_conv.weight.dtype) with torch.no_grad(): image_latent = self.vae_model.encode(pixel_values) ### b, 128, h/16, w/16 x0 = torch.randn_like(image_latent) if t is not None: xt, ut = self.path_sample(t, x0, image_latent) else: xt, ut = image_latent, torch.zeros_like(image_latent, device=image_latent.device, dtype=image_latent.dtype) xt = xt.to(torch.bfloat16) image_pixel_hat = self.vae_decoder_projector(xt) if image_pixel_hat.size(-1) > 512: interpolate_pos_encoding = True else: interpolate_pos_encoding = False image_features = self.vision_representation(image_pixel_hat, interpolate_pos_encoding=interpolate_pos_encoding).last_hidden_state image_target = ut.reshape(ut.size(0), ut.size(1), -1).permute(0, 2, 1) projected_image_features = self.und_projector(image_features) output = { "image_target": image_target, "projected_image_features": projected_image_features, "pixel_values": pixel_values, "image_pixel_hat": image_pixel_hat, } return output def prepare_inputs_labels_for_multimodal( self, input_ids, t, position_ids=None, attention_mask=None, past_key_values=None, labels=None, pixel_values=None, grid_hws=None ): if pixel_values is None or input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, None, None, None, None image_features = self.get_image_features(pixel_values, t) _labels = labels _position_ids = position_ids _attention_mask = attention_mask if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) else: attention_mask = attention_mask.bool() if position_ids is None: position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) if labels is None: labels = torch.full_like(input_ids, IGNORE_INDEX) input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] image_gen_pixcel_hat = [] new_input_embeds = [] new_labels = [] image_gen_labels = [] image_gen_pixcel_labels = [] image_mask = [] cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() num_images_gen = (cur_input_ids == IMAGE_GEN_TOKEN_INDEX).sum() if num_images == 0 and num_images_gen == 0: cur_new_labels = [] cur_image_features = image_features["projected_image_features"][cur_image_idx] cur_input_embeds_1 = self.language_model.embed_tokens(cur_input_ids) cur_input_embeds = torch.cat([cur_image_features, cur_input_embeds_1], dim=0) cur_image_targets = torch.full((cur_image_features.shape[0],), IMAGE_TOKEN_INDEX, device=labels[batch_idx].device, dtype=labels[batch_idx].dtype) new_input_embeds.append(cur_input_embeds) image_gen_labels.append(image_features["image_target"][cur_image_idx]) image_gen_pixcel_labels.append(image_features["pixel_values"][cur_image_idx]) image_gen_pixcel_hat.append(image_features["image_pixel_hat"][cur_image_idx]) image_mask.append(torch.zeros(cur_input_embeds.size(0), dtype=torch.bool)) cur_new_labels.append(cur_image_targets) cur_new_labels.append(labels[batch_idx]) cur_new_labels = torch.cat(cur_new_labels) new_labels.append(cur_new_labels) cur_image_idx += 1 continue image_token_positions = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() image_gen_token_positions = torch.where(cur_input_ids == IMAGE_GEN_TOKEN_INDEX)[0].tolist() all_insert_positions = sorted([ (pos, "image") for pos in image_token_positions ] + [ (pos, "image_gen") for pos in image_gen_token_positions ], key=lambda x: x[0]) all_token_indices = [-1] + [p[0] for p in all_insert_positions] + [cur_input_ids.shape[0]] cur_input_ids_noim = [] cur_labels = labels[batch_idx] cur_labels_noim = [] for i in range(len(all_token_indices) - 1): start = all_token_indices[i] + 1 end = all_token_indices[i + 1] cur_input_ids_noim.append(cur_input_ids[start: end]) cur_labels_noim.append(cur_labels[start: end]) split_sizes = [x.shape[0] for x in cur_labels_noim] cur_input_embeds = self.language_model.embed_tokens(torch.cat(cur_input_ids_noim)) cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) cur_new_input_embeds = [] cur_new_labels = [] cur_image_mask = [] for i in range(len(cur_input_embeds_no_im)): cur_new_input_embeds.append(cur_input_embeds_no_im[i]) cur_new_labels.append(cur_labels_noim[i]) cur_image_mask.append(torch.zeros(cur_input_embeds_no_im[i].size(0), dtype=torch.bool)) if i < len(all_insert_positions): token_type = all_insert_positions[i][1] if token_type == "image": cur_image_features = image_features["projected_image_features"][cur_image_idx] cur_image_targets = torch.full((cur_image_features.shape[0],), IMAGE_TOKEN_INDEX, device=cur_labels.device, dtype=cur_labels.dtype) image_gen_labels.append(image_features["image_target"][cur_image_idx]) image_gen_pixcel_labels.append(image_features["pixel_values"][cur_image_idx]) image_gen_pixcel_hat.append(image_features["image_pixel_hat"][cur_image_idx]) cur_image_idx += 1 elif token_type == "image_gen": cur_image_features = image_features["projected_image_features"][cur_image_idx] cur_image_targets = torch.full((cur_image_features.shape[0],), IMAGE_GEN_TOKEN_INDEX, device=cur_labels.device, dtype=cur_labels.dtype) image_gen_labels.append(image_features["image_target"][cur_image_idx]) image_gen_pixcel_labels.append(image_features["pixel_values"][cur_image_idx]) image_gen_pixcel_hat.append(image_features["image_pixel_hat"][cur_image_idx]) cur_image_idx += 1 else: raise ValueError(f"Unexpected token type: {token_type}") cur_new_input_embeds.append(cur_image_features) cur_new_labels.append(cur_image_targets) cur_image_mask.append(torch.ones(cur_image_features.size(0), dtype=torch.bool)) cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds) cur_new_labels = torch.cat(cur_new_labels) cur_image_mask = torch.cat(cur_image_mask) new_input_embeds.append(cur_new_input_embeds) new_labels.append(cur_new_labels) image_mask.append(cur_image_mask) tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", 4096) new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] new_labels = [x[:tokenizer_model_max_length] for x in new_labels] image_mask = [x[:tokenizer_model_max_length] for x in image_mask] max_len = max(x.shape[0] for x in new_input_embeds) batch_size = len(new_input_embeds) new_input_embeds_padded = [] new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) new_image_mask = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) for i, (cur_new_embed, cur_new_labels, cur_image_mask) in enumerate(zip(new_input_embeds, new_labels, image_mask)): cur_len = cur_new_embed.shape[0] if getattr(self.config, "tokenizer_padding_side", "right") == "left": new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0)) if cur_len > 0: new_labels_padded[i, -cur_len:] = cur_new_labels attention_mask[i, -cur_len:] = True position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) new_image_mask[i, -cur_len:] = cur_image_mask else: new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) if cur_len > 0: new_labels_padded[i, :cur_len] = cur_new_labels attention_mask[i, :cur_len] = True position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) new_image_mask[i, :cur_len] = cur_image_mask new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) if _labels is None: new_labels = None else: new_labels = new_labels_padded if _attention_mask is None: attention_mask = None else: attention_mask = attention_mask.to(dtype=_attention_mask.dtype) if _position_ids is None: position_ids = None return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, image_gen_labels, image_gen_pixcel_labels, image_gen_pixcel_hat, new_image_mask def forward( self, input_ids = None, t = None, position_ids = None, attention_mask = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, pixel_values = None, grid_hws = None, return_dict = None, **kwargs, ): if t is not None: if isinstance(t, torch.Tensor): t = t.reshape(-1).unsqueeze(1).to(dtype=pixel_values.dtype, device=pixel_values.device) elif isinstance(t, list): t = torch.cat(t).unsqueeze(1).to(dtype=pixel_values.dtype, device=pixel_values.device) else: t = torch.ones((len(pixel_values), 1), dtype=pixel_values.dtype, device=pixel_values.device) if inputs_embeds is None or pixel_values is not None: input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, image_gen_labels, image_gen_pixcel_labels, image_gen_pixcel_hat, new_image_mask = self.prepare_inputs_labels_for_multimodal( input_ids, t, position_ids, attention_mask, past_key_values, labels, pixel_values, grid_hws ) device = inputs_embeds.device t_embeds_1, t_embeds_2 = self.time_embed(t, inputs_embeds.dtype) batch_size, seq_len = new_image_mask.shape head_num = self.config.text_config.num_attention_heads omni_attention_mask = torch.tril(torch.ones((batch_size,head_num,seq_len,seq_len), dtype=torch.long)).to(device) img = new_image_mask.bool() seg = (img & ~torch.nn.functional.pad(img[:, :-1], (1, 0), value=False)).cumsum(1) * img attention_mask_img = ((seg.unsqueeze(2) == seg.unsqueeze(1)) & img.unsqueeze(2) & img.unsqueeze(1)) attention_mask_img = attention_mask_img.to(dtype=torch.long).to(device) txt = attention_mask.bool() seg = (txt & ~torch.nn.functional.pad(txt[:, :-1], (1, 0), value=False)).cumsum(1) * txt attention_mask_txt = ((seg.unsqueeze(2) == seg.unsqueeze(1)) & txt.unsqueeze(2) & txt.unsqueeze(1)) attention_mask_txt = attention_mask_txt.to(dtype=torch.long).to(device) pad = ~attention_mask.bool() seg = (pad & ~torch.nn.functional.pad(pad[:, :-1], (1, 0), value=False)).cumsum(1) * pad attention_mask_pad = ((seg.unsqueeze(2) == seg.unsqueeze(1)) & pad.unsqueeze(2) & pad.unsqueeze(1)) attention_mask_pad = attention_mask_pad.to(dtype=torch.long).to(device) for i in range(omni_attention_mask.size(1)): omni_attention_mask[:,i,:,:] = torch.bitwise_or(omni_attention_mask[:,i,:,:], attention_mask_img) * attention_mask_txt omni_attention_mask[:,i,:,:] = torch.bitwise_or(omni_attention_mask[:,i,:,:], attention_mask_pad) output = self.language_model( input_ids=None, attention_mask=omni_attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) if labels is not None: return output, t_embeds_1, t_embeds_2, labels, image_gen_labels, image_gen_pixcel_labels, image_gen_pixcel_hat return output, t_embeds_1, t_embeds_2, new_image_mask class Cheers(UMMPretrainedModel, GenerationMixin): config_class = UMMConfig _no_split_modules = [ "Encoder", "Decoder", "Siglip2EncoderLayer", "UMMUndProjector", "UMMGenProjector", "Qwen2DecoderLayer" ] _tied_weights_keys = ["lm_head.weight", "model.language_model.embed_tokens.weight"] def __init__(self, config): super().__init__(config) self.config.text_config._attn_implementation = self.config._attn_implementation self.config.vision_representation_config._attn_implementation = self.config._attn_implementation self.model = UMMModel(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.text_loss_fc = nn.CrossEntropyLoss(reduction="none", ignore_index=-100) self.post_init() @property def language_model(self): return self.model.language_model @property def vae_model(self): return self.model.vae_model @property def vision_representation(self): return self.model.vision_representation def get_input_embeddings(self): return self.language_model.embed_tokens def get_output_embeddings(self): return self.lm_head def forward(self, input_ids, t=None, labels=None, attention_mask=None, pixel_values=None, grid_hws=None, **kwargs): if grid_hws is not None: image_h, image_w = grid_hws[0] per_image_token = int(image_h * image_w) // 4 else: image_h, image_w = 32, 32 per_image_token = 256 if labels is not None: outputs, t_embeds_1, t_embeds_2, labels, image_gen_labels, image_gen_pixcel_labels, image_gen_pixcel_hat = self.model(input_ids, t=t, labels=labels, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, **kwargs) hidden_states = outputs[0][:, :-1, :].contiguous() labels = labels[:, 1:].contiguous() bsz, seq_len, dim = hidden_states.shape device = hidden_states.device image_num = len(image_gen_labels) mask_img_de = torch.zeros_like(labels, dtype=torch.bool) image_latent = None for b in range(bsz): cur_labels = labels[b] new_cur_labels = [] for idx, cur_id in enumerate(cur_labels): if cur_id == IM_START_ID: new_cur_labels.append(IM_START_ID) new_cur_labels.append(-100) elif cur_id == IM_END_ID: continue elif cur_id == IMAGE_TOKEN_INDEX or cur_id == IMAGE_GEN_TOKEN_INDEX: new_cur_labels.append(-100) else: new_cur_labels.append(cur_id) new_cur_labels = torch.tensor(new_cur_labels, dtype=cur_labels.dtype, device=cur_labels.device) labels[b] = new_cur_labels start_indices = (new_cur_labels == IM_START_ID).nonzero(as_tuple=True)[0] for start in start_indices: mask_img_de[b, start + 2:start + per_image_token + 2] = True if not mask_img_de[b].any(): mask_img_de[b, :per_image_token] = True mask_text = (labels != -100) mask_text[mask_img_de] = False anchor_text = (self.lm_head(hidden_states[:, :1, :]).sum()) * 0.0 anchor_img = hidden_states[:, :1, :].sum() * 0.0 zero_anchor = anchor_text + anchor_img if mask_text.any(): out_text = self.lm_head(hidden_states) if mask_img_de.any(): h_img_de = hidden_states[mask_img_de] try: h_img_de = h_img_de.reshape(image_num, per_image_token, h_img_de.size(-1)) except: print(h_img_de.shape, image_num, per_image_token, mask_img_de.shape) image_latent = self.model.gen_projector(h_img_de, t_embeds_1) image_latent = image_latent.reshape(image_num, int(image_h//2), int(image_w//2), image_latent.size(-1)) B, H, W, C = image_latent.shape P = 2 D = C//(P*P) image_latent = image_latent.view(B, H, W, P, P, D) image_latent = image_latent.permute(0, 1, 3, 2, 4, 5).contiguous() image_latent = image_latent.view(B, H*P, W*P, D) image_latent = image_latent.view(B, H*P*W*P, D) image_gen_pixcel_hat = torch.stack(image_gen_pixcel_hat, dim=0) patch_embedding_res = self.model.vision_representation.embeddings(image_gen_pixcel_hat) hi_input = self.model.hi_gate(image_latent, patch_embedding_res) image_latent_pre = self.model.hi_projector(hi_input, t_embeds_2) image_latent = image_latent_pre.view(B, image_h, image_w, image_latent_pre.size(-1)) image_latent = image_latent.permute(0, 3, 1, 2) text_loss_denominator = (labels != -100).sum().clamp(min=1) image_loss_denominator = mask_img_de.sum() if text_loss_denominator.item() > 0: text_logits = out_text.view(-1, out_text.shape[-1]).contiguous() text_labels = labels.view(-1).type(torch.long).contiguous() text_loss = self.text_loss_fc(text_logits, text_labels) valid_mask = (text_labels != -100).type_as(text_loss) text_loss = text_loss * valid_mask text_loss = text_loss.sum() / text_loss_denominator else: text_loss = zero_anchor if image_loss_denominator.item() > 0: #### latent space loss #### image_gen_labels = torch.stack(image_gen_labels, dim=0) image_loss = F.mse_loss(image_latent_pre, image_gen_labels) else: image_loss = zero_anchor alpha = 1.0 total_loss = text_loss + alpha * image_loss image_latent_label = image_gen_labels if image_latent is not None and image_latent.numel(): with torch.no_grad(): image_pixcel = self.model.vae_model.decode(image_latent) else: image_pixcel = None return UMMCausalLMOutput( loss=total_loss, logits=text_logits, image_latent=image_latent, image_latent_label=image_latent_label, image_pixcel=image_pixcel, image_pixcel_label=image_gen_pixcel_labels, text_loss=text_loss, image_loss=image_loss, past_key_values=outputs.past_key_values, hidden_states=outputs.last_hidden_state, attentions=outputs.attentions ) else: outputs, t_embeds_1, t_embeds_2, omni_attention_mask = self.model(input_ids, t=t, labels=labels, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, **kwargs) hidden_states = outputs.last_hidden_state slice_indices = slice(0, None) logits = self.lm_head(hidden_states[:,slice_indices,:]) loss = None return UMMCausalLMOutput( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=hidden_states, attentions=omni_attention_mask, time_embeds_1=t_embeds_1, time_embeds_2=t_embeds_2, ) def generate( self, input_ids, max_length=2048, temperature=0.7, top_p=1.0, t=None, cfg_scale=None, attention_mask=None, pixel_values=None, grid_hws=None, num_inference_steps=50, num_beams=1, length_penalty=1.0, repetition_penalty=1.1, edit_image=False, alpha=1.0, **kwargs ): if temperature != 0 and num_beams > 1: warnings.warn( "Both temperature != 0 and num_beams > 1 are set. " "Beam search will be used and sampling paramters are ignored.", UserWarning ) max_input_tokens = 1024 max_output_tokens = max_length orig_len = input_ids.size(1) if grid_hws is not None: image_h, image_w = grid_hws[0] per_image_token = int(image_h * image_w) // 4 else: image_h, image_w = 32, 32 per_image_token = 1024 // 4 if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device) ones = torch.ones((input_ids.size(0), 1), dtype=attention_mask.dtype, device=attention_mask.device) outputs = self( input_ids, t=t, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, use_cache=True, output_hidden_states=True, ) past_key_values = outputs.past_key_values last_input_ids = input_ids[0][-1] last_token_hidden_states = outputs.hidden_states[:,-1,:].unsqueeze(1) if cfg_scale is not None and cfg_scale != 1.0: use_cfg = True uncond_input_ids = self._build_uncond_full_ids(input_ids) uncond_outputs = self( uncond_input_ids, t=t, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, use_cache=True, output_hidden_states=True, ) uncond_last_input_ids = uncond_input_ids[0][-1] uncond_past_key_values = uncond_outputs.past_key_values uncond_last_token_hidden_states = uncond_outputs.hidden_states[:,-1,:].unsqueeze(1) else: use_cfg = False def _clone_dynamic_cache(cache): legacy = cache.to_legacy_cache() legacy = tuple(tuple(x.clone() for x in layer) for layer in legacy) return DynamicCache.from_legacy_cache(legacy) image_latent_pre_list = [] image_gen_time = False cur_time = 0 attention_mask = outputs.attentions while True: if input_ids.size(1) > max_input_tokens: break if cur_time > max_output_tokens: break if last_input_ids == IM_START_ID: image_gen_time = True if image_gen_time: if pixel_values is not None and len(pixel_values) == input_ids.size(0) and edit_image == True: start_t = 0.1 num_inference_steps = int(num_inference_steps * (1-start_t)) last_step_size = 1 / num_inference_steps t_list = torch.linspace(start_t, 1, num_inference_steps) t_list = self.time_shift(t_list, alpha=alpha) z_image = self.model.vae_model.encode(pixel_values.to(dtype=last_token_hidden_states.dtype, device=last_token_hidden_states.device)) noise_z = torch.randn((input_ids.size(0), 128, image_h, image_w), dtype=last_token_hidden_states.dtype, device=last_token_hidden_states.device) z = start_t * z_image + (1 - start_t) * noise_z else: last_step_size = 1 / num_inference_steps t_list = torch.linspace(0, 1, num_inference_steps) t_list = self.time_shift(t_list, alpha=alpha) z = torch.randn((input_ids.size(0), 128, image_h, image_w), dtype=last_token_hidden_states.dtype, device=last_token_hidden_states.device) z_mask = torch.ones((z.size(0), (image_h//2) *(image_w//2)), dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, z_mask], dim=1) x_t = z for n in range(num_inference_steps): ti = t_list[n] with torch.no_grad(): drift, step_output, heat_map = self._drift_fn(x_t, ti, attention_mask, past_key_values, out_heat_map=True) if use_cfg: uncond_drift, uncond_step_output = self._drift_fn(x_t, ti, attention_mask, uncond_past_key_values) drift = uncond_drift + cfg_scale * (drift - uncond_drift) if ti != 1 and n < num_inference_steps-1: dt = t_list[n+1] - t_list[n] x_t = self.euler_maruyama_step(x_t, drift, dt) past_key_values.crop(-per_image_token) if use_cfg: uncond_past_key_values.crop(-per_image_token) else: image_latent_pre = x_t + drift * last_step_size past_key_values = step_output.past_key_values if use_cfg: uncond_past_key_values = uncond_step_output.past_key_values image_latent_pre_list.append(image_latent_pre) attention_mask = torch.cat([attention_mask, ones], dim=1) image_gen_time = False last_input_ids = IM_END_ID last_input_ids_tensor = torch.full((input_ids.size(0),1), last_input_ids, dtype=input_ids.dtype, device=input_ids.device) input_ids = torch.cat([input_ids, last_input_ids_tensor], dim=1) inputs_embeds = self.language_model.embed_tokens(last_input_ids_tensor) text_attention_mask = torch.ones((1, 1, 1, attention_mask.size(1)), dtype=attention_mask.dtype, device=attention_mask.device) step_output = self.language_model( inputs_embeds=inputs_embeds, attention_mask=text_attention_mask, past_key_values=past_key_values, use_cache=True, output_hidden_states=True, ) past_key_values = step_output.past_key_values last_token_hidden_states = step_output.last_hidden_state[:,-1,:].unsqueeze(1) if use_cfg: uncond_last_input_ids = IM_END_ID uncond_last_input_ids_tensor = torch.full((input_ids.size(0),1), uncond_last_input_ids, dtype=input_ids.dtype, device=input_ids.device) uncond_input_ids = torch.cat([uncond_input_ids, uncond_last_input_ids_tensor], dim=1) uncond_inputs_embeds = self.language_model.embed_tokens(uncond_last_input_ids_tensor) text_attention_mask = torch.ones((1, 1, 1, attention_mask.size(1)), dtype=attention_mask.dtype, device=attention_mask.device) step_output = self.language_model( inputs_embeds=uncond_inputs_embeds, attention_mask=text_attention_mask, past_key_values=uncond_past_key_values, use_cache=True, output_hidden_states=True, ) uncond_past_key_values = step_output.past_key_values uncond_last_token_hidden_states = step_output.last_hidden_state[:,-1,:].unsqueeze(1) cur_time += per_image_token else: if last_input_ids == EOS_TOKEN_ID: break if num_beams is None or num_beams <=1: last_logits = self.lm_head(last_token_hidden_states) last_input_ids = self._sample_from_logits(last_logits[:,-1,:], temp=temperature, top_p=top_p) last_input_ids_tensor = last_input_ids.unsqueeze(1) last_input_ids = last_input_ids[0] input_ids = torch.cat([input_ids, last_input_ids_tensor], dim=1) if use_cfg: uncond_last_logits = self.lm_head(uncond_last_token_hidden_states) uncond_last_input_ids = self._sample_from_logits(uncond_last_logits[:,-1,:], temp=temperature, top_p=top_p) uncond_last_input_ids_tensor = uncond_last_input_ids.unsqueeze(1) uncond_last_input_ids = uncond_last_input_ids[0] uncond_input_ids = torch.cat([uncond_input_ids, uncond_last_input_ids_tensor], dim=1) attention_mask = torch.cat([attention_mask, ones], dim=1) inputs_embeds = self.language_model.embed_tokens(last_input_ids_tensor) text_attention_mask = torch.ones((1, 1, 1, attention_mask.size(1)), dtype=attention_mask.dtype, device=attention_mask.device) step_output = self.language_model( inputs_embeds=inputs_embeds, attention_mask=text_attention_mask, past_key_values=past_key_values, use_cache=True, output_hidden_states=True, ) past_key_values = step_output.past_key_values last_token_hidden_states = step_output.last_hidden_state[:,-1,:].unsqueeze(1) if use_cfg: uncond_inputs_embeds = self.language_model.embed_tokens(uncond_last_input_ids_tensor) text_attention_mask = torch.ones((1, 1, 1, attention_mask.size(1)), dtype=attention_mask.dtype, device=attention_mask.device) step_output = self.language_model( inputs_embeds=uncond_inputs_embeds, attention_mask=text_attention_mask, past_key_values=uncond_past_key_values, use_cache=True, output_hidden_states=True, ) uncond_past_key_values = step_output.past_key_values uncond_last_token_hidden_states = step_output.last_hidden_state[:,-1,:].unsqueeze(1) else: assert input_ids.size(0) == 1, "only support beam=1" if 'beam_state' not in locals(): beam_state = { "seqs":[input_ids.clone() for _ in range(num_beams)], "scores": torch.zeros(num_beams, device=input_ids.device), "pkv": [_clone_dynamic_cache(past_key_values) for _ in range(num_beams)], "h": [last_token_hidden_states.clone() for _ in range(num_beams)], "finished": torch.zeros(num_beams, dtype=torch.bool, device=input_ids.device), } k_expand = num_beams all_cand = [] # (new_score, beam_idx, token_id) for bi in range(num_beams): if beam_state["finished"][bi]: all_cand.append((beam_state["scores"][bi], bi, EOS_TOKEN_ID)) continue h = beam_state["h"][bi] # (1,1,D) logits = self.lm_head(h)[:, -1, :].float() # (1,V) logprobs = torch.log_softmax(logits, dim=-1) # (1,V) rep = repetition_penalty if rep is not None and rep != 1.0: prev = beam_state["seqs"][bi][0] prev_ids = prev.unique() lp_vals = logprobs[0, prev_ids] logprobs[0, prev_ids] = torch.where(lp_vals > 0, lp_vals / rep, lp_vals * rep) topk_lp, topk_id = torch.topk(logprobs, k_expand, dim=-1) for j in range(k_expand): tok = int(topk_id[0, j].item()) sc = beam_state["scores"][bi] + topk_lp[0, j] all_cand.append((sc, bi, tok)) all_cand.sort(key=lambda x: x[0].item(), reverse=True) new_cand = all_cand[:num_beams] new_seqs, new_pkv, new_h = [], [], [] new_scores = torch.empty(num_beams, device=input_ids.device) new_finished = torch.empty(num_beams, dtype=torch.bool, device=input_ids.device) for i, (sc, parent_bi, tok) in enumerate(new_cand): new_scores[i] = sc parent_seq = beam_state["seqs"][parent_bi] tok_tensor = torch.tensor([[tok]], dtype=parent_seq.dtype, device=parent_seq.device) seq_i = torch.cat([parent_seq, tok_tensor], dim=1) new_seqs.append(seq_i) if tok == EOS_TOKEN_ID: new_pkv.append(beam_state["pkv"][parent_bi]) new_h.append(beam_state["h"][parent_bi]) new_finished[i] = True continue inputs_embeds = self.language_model.embed_tokens(tok_tensor) past_len = beam_state["pkv"][parent_bi].get_seq_length() cache_position = torch.tensor([past_len], device=inputs_embeds.device, dtype=torch.long) text_attention_mask = torch.ones((1, 1, 1, past_len+1), dtype=attention_mask.dtype, device=attention_mask.device) step_output = self.language_model( inputs_embeds=inputs_embeds, attention_mask=text_attention_mask, past_key_values=beam_state["pkv"][parent_bi], cache_position=cache_position, use_cache=True, output_hidden_states=True, ) new_pkv.append(step_output.past_key_values) new_h.append(step_output.last_hidden_state[:, -1, :].unsqueeze(1)) new_finished[i] = False lengths = torch.tensor([s.size(1) for s in new_seqs], device=input_ids.device, dtype=torch.float32) if length_penalty is not None and length_penalty != 1.0: norm_scores = new_scores / (lengths ** length_penalty) else: norm_scores = new_scores best_idx = int(torch.argmax(norm_scores).item()) beam_state["seqs"] = new_seqs beam_state["pkv"] = new_pkv beam_state["h"] = new_h beam_state["scores"] = new_scores beam_state["finished"] = new_finished input_ids = beam_state["seqs"][best_idx] past_key_values = beam_state["pkv"][best_idx] last_token_hidden_states = beam_state["h"][best_idx] last_input_ids = int(input_ids[0, -1].item()) attention_mask = torch.cat([attention_mask, ones], dim=1) cur_time += 1 all_images = [] for image_latent in image_latent_pre_list: image_pixcel = self.model.vae_model.decode(image_latent) all_images.append(image_pixcel) generated_ids = input_ids[:, orig_len:] return { "input_ids": generated_ids, "images": all_images, } def time_shift(self, ts, alpha=1.0): return (alpha * ts) / (1.0 + (alpha-1.0) * ts) def _drift_fn(self, x_t, t, attention_mask, past_key_values, out_heat_map=False): t = torch.full((x_t.size(0),1), t, device=x_t.device, dtype=x_t.dtype) t_embeds_1, t_embeds_2 = self.model.time_embed(t, t.dtype) x_t = self.model.vae_decoder_projector(x_t) if x_t.size(-1) > 512: interpolate_pos_encoding = True else: interpolate_pos_encoding = False image_feature = self.model.vision_representation(x_t, interpolate_pos_encoding=interpolate_pos_encoding).last_hidden_state projected_image_feature = self.model.und_projector(image_feature) h_w = int(projected_image_feature.size(1) ** 0.5) attention_mask = torch.ones((1, 1, projected_image_feature.size(1), attention_mask.size(1)), dtype=attention_mask.dtype, device=attention_mask.device) step_output = self.model.language_model( inputs_embeds=projected_image_feature, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=True, output_hidden_states=True, ) image_feature_pre = step_output.last_hidden_state[:, -projected_image_feature.size(1):, :] drift = self.model.gen_projector(image_feature_pre, t_embeds_1) ### batch_size, 256, dim ### drift = drift.reshape(drift.size(0), h_w, h_w, drift.size(-1)) B, H, W, C = drift.shape P = 2 D = C//(P*P) drift = drift.view(B, H, W, P, P, D) drift = drift.permute(0, 1, 3, 2, 4, 5).contiguous() drift = drift.view(B, H*P, W*P, D) drift = drift.view(B, H*P*W*P, D) patch_embedding_res = self.model.vision_representation.embeddings(x_t, interpolate_pos_encoding=interpolate_pos_encoding) if out_heat_map: hi_input, heat_map = self.model.hi_gate(drift, patch_embedding_res, heat_map=out_heat_map) else: hi_input = self.model.hi_gate(drift, patch_embedding_res, heat_map=out_heat_map) drift = self.model.hi_projector(hi_input, t_embeds_2) drift = drift.view(B, h_w*2, h_w*2, drift.size(-1)) drift = drift.permute(0, 3, 1, 2) if out_heat_map: return drift, step_output, heat_map return drift, step_output def euler_maruyama_step(self, x, drift, dt): mean_x = x + drift * dt x = mean_x return x # def _build_uncond_full_ids(self, seq_ids): # bsz, seqlen = seq_ids.shape # new_ids = seq_ids.clone() # for b in range(bsz): # seq = seq_ids[b] # in_img_block = False # for t in range(seqlen): # tok = seq[t].item() # if tok == IM_START_ID: # in_img_block = True # new_ids[b, t] = IM_START_ID # elif tok == IM_END_ID and in_img_block: # new_ids[b, t] = IM_END_ID # in_img_block = False # elif tok == IMAGE_TOKEN_INDEX: # new_ids[b, t] = IMAGE_TOKEN_INDEX # else: # if in_img_block: # new_ids[b, t] = seq[t] # else: # new_ids[b, t] = NO_MEAN_ID # return new_ids def _build_uncond_full_ids(self, seq_ids): bsz, seqlen = seq_ids.shape new_ids = seq_ids.clone() for b in range(bsz): seq = seq_ids[b] for t in range(seqlen): tok = seq[t].item() if tok in [IM_START_ID, IM_END_ID, IMAGE_TOKEN_INDEX]: new_ids[b, t] = seq[t] else: new_ids[b, t] = NO_MEAN_ID return new_ids def _sample_from_logits( self, logits: torch.Tensor, # (B, V) temp: float = 0.7, top_p: float = 0.9, top_k: int = 0, repetition_penalty: float = 1.0, prev_tokens: torch.Tensor | None = None, # (B, T) ): if temp is None or temp <= 1e-6: return logits.argmax(dim=-1) logits = logits.float() if repetition_penalty is not None and repetition_penalty != 1.0 and prev_tokens is not None: ids = prev_tokens.long() # (B, T) vals = logits.gather(1, ids) # (B, T) new_vals = torch.where(vals > 0, vals / repetition_penalty, vals * repetition_penalty) logits = logits.scatter(1, ids, new_vals) logits = logits / temp if top_k is not None and top_k > 0 and top_k < logits.size(-1): kth = torch.topk(logits, top_k, dim=-1).values[:, -1].unsqueeze(-1) # (B,1) logits = logits.masked_fill(logits < kth, float("-inf")) if top_p is not None and top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1) sorted_probs = F.softmax(sorted_logits, dim=-1) cum_probs = torch.cumsum(sorted_probs, dim=-1) remove = cum_probs > top_p remove[:, 1:] = remove[:, :-1].clone() remove[:, 0] = False sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) probs = F.softmax(sorted_logits, dim=-1) idx_in_sorted = torch.multinomial(probs, 1).squeeze(-1) # (B,) return sorted_idx.gather(1, idx_in_sorted.unsqueeze(-1)).squeeze(-1) probs = F.softmax(logits, dim=-1) return torch.multinomial(probs, 1).squeeze(-1) ModelClass = Cheers __all__ = ["Cheers", "UMMModel", "VAEModel", "UMMPretrainedModel", "UMMTextModel"]