# source https://github.com/TheDenk/wan2.1-dilated-controlnet/blob/main/wan_controlnet.py
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_wan import (
    WanTimeTextImageEmbedding,
    WanRotaryPosEmbed,
    WanTransformerBlock,
)


def zero_module(module):
    # Zero-initialize every parameter so the module initially contributes nothing to the residual path.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A ControlNet Transformer model for video-like data, used with the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each attention head.
        in_channels (`int`, defaults to `3`):
            The number of channels in the controlnet conditioning input (e.g. an RGB control video).
        vae_channels (`int`, defaults to `16`):
            The number of channels in the VAE latent input.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in the feed-forward network.
        num_layers (`int`, defaults to `20`):
            The number of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Type of query/key normalization to use.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Dimension of the image embeddings (e.g. 1280 for the I2V model). If `None`, no image embedding is used.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length supported by the rotary position embedding.
        downscale_coef (`int`, defaults to `8`):
            Spatial downscale coefficient for the controlnet input video (should match the VAE spatial compression
            factor).
        out_proj_dim (`int`, defaults to `128 * 12`):
            Output dimension of the per-block projection layers (should match the hidden size of the base Wan
            transformer the controlnet residuals are added to).
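
    Example (illustrative; this mirrors the small smoke-test configuration in the `__main__` block at the bottom of
    this file, not a released checkpoint):

        controlnet = WanControlnet(
            num_attention_heads=12,
            attention_head_dim=128,
            num_layers=2,
            ffn_dim=8960,
            in_channels=3,
            downscale_coef=8,
            out_proj_dim=12 * 128,
        )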
""" _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["WanTransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config def __init__( self, patch_size: Tuple[int] = (1, 2, 2), num_attention_heads: int = 40, attention_head_dim: int = 128, in_channels: int = 3, vae_channels: int = 16, text_dim: int = 4096, freq_dim: int = 256, ffn_dim: int = 13824, num_layers: int = 20, cross_attn_norm: bool = True, qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, image_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, downscale_coef: int = 8, out_proj_dim: int = 128 * 12, ) -> None: super().__init__() start_channels = in_channels * (downscale_coef ** 2) input_channels = [start_channels, start_channels // 2, start_channels // 4] self.control_encoder = nn.ModuleList([ ## Spatial compression with time awareness nn.Sequential( nn.Conv3d( in_channels, input_channels[0], kernel_size=(3, downscale_coef + 1, downscale_coef + 1), stride=(1, downscale_coef, downscale_coef), padding=(1, downscale_coef // 2, downscale_coef // 2) ), nn.GELU(approximate="tanh"), nn.GroupNorm(2, input_channels[0]), ), ## Temporal compression with spatial awareness nn.Sequential( nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1), nn.GELU(approximate="tanh"), nn.GroupNorm(2, input_channels[1]), ), ## Temporal compression with spatial awareness nn.Sequential( nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1), nn.GELU(approximate="tanh"), nn.GroupNorm(2, input_channels[2]), ) ]) inner_dim = num_attention_heads * attention_head_dim # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size) # 2. Condition embeddings # image_embedding_dim=1280 for I2V model self.condition_embedder = WanTimeTextImageEmbedding( dim=inner_dim, time_freq_dim=freq_dim, time_proj_dim=inner_dim * 6, text_embed_dim=text_dim, image_embed_dim=image_dim, ) # 3. 

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Controlnet modules: one zero-initialized output projection per transformer block
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.blocks)):
            controlnet_block = nn.Linear(inner_dim, out_proj_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[Tuple, Transformer2DModelOutput]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        rotary_emb = self.rope(hidden_states)

        # 0. Controlnet encoder: compress the conditioning video down to the latent spatio-temporal resolution
        for control_encoder_block in self.control_encoder:
            controlnet_states = control_encoder_block(controlnet_states)
        hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)

        ## 1. Patch embedding and flatten into a token sequence
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
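
        # Note: `temb` is not used in this controlnet; the blocks only consume the per-block modulation signal
        # `timestep_proj` (reshaped into 6 chunks) together with the text/image context.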

        # 2. Transformer blocks
        controlnet_hidden_states = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                controlnet_hidden_states += (controlnet_block(hidden_states),)
        else:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                controlnet_hidden_states += (controlnet_block(hidden_states),)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_hidden_states,)

        return Transformer2DModelOutput(sample=controlnet_hidden_states)


if __name__ == "__main__":
    parameters = {
        "added_kv_proj_dim": None,
        "attention_head_dim": 128,
        "cross_attn_norm": True,
        "eps": 1e-06,
        "ffn_dim": 8960,
        "freq_dim": 256,
        "image_dim": None,
        "in_channels": 3,
        "num_attention_heads": 12,
        "num_layers": 2,
        "patch_size": [1, 2, 2],
        "qk_norm": "rms_norm_across_heads",
        "rope_max_seq_len": 1024,
        "text_dim": 4096,
        "downscale_coef": 8,
        "out_proj_dim": 12 * 128,
    }
    controlnet = WanControlnet(**parameters)

    hidden_states = torch.rand(1, 16, 21, 60, 90)
    timestep = torch.randint(low=0, high=1000, size=(1,), dtype=torch.long)
    encoder_hidden_states = torch.rand(1, 512, 4096)
    controlnet_states = torch.rand(1, 3, 81, 480, 720)

    controlnet_hidden_states = controlnet(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_states=controlnet_states,
        return_dict=False,
    )

    print("Output states count", len(controlnet_hidden_states[0]))
    for out_hidden_states in controlnet_hidden_states[0]:
        print(out_hidden_states.shape)
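
    # Illustrative sanity check: with patch_size (1, 2, 2) the 21 x 60 x 90 latent is patchified into
    # 21 * 30 * 45 = 28350 tokens, and every per-block output is projected to out_proj_dim = 1536 channels,
    # so each tensor printed above is expected to have shape (1, 28350, 1536).
    expected_shape = (1, 21 * 30 * 45, parameters["out_proj_dim"])
    assert all(s.shape == expected_shape for s in controlnet_hidden_states[0]), "unexpected controlnet output shape"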