# source https://github.com/TheDenk/wan2.1-dilated-controlnet/blob/main/wan_controlnet.py
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_wan import (
    WanTimeTextImageEmbedding,
    WanRotaryPosEmbed,
    WanTransformerBlock,
)


def zero_module(module):
    """Zero-initialize all parameters of a module, so its output is zero at the start of training."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A ControlNet Transformer model for video-like data, used with the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each attention head.
        in_channels (`int`, defaults to `3`):
            The number of channels in the controlnet input (e.g. an RGB control video).
        vae_channels (`int`, defaults to `16`):
            The number of channels in the VAE latent input.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in the feed-forward network.
        num_layers (`int`, defaults to `20`):
            The number of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Type of query/key normalization to use.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Dimension of the image embeddings (used by the I2V model). If `None`, no image conditioning is used.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length for the rotary position embeddings.
        downscale_coef (`int`, defaults to `8`):
            Coefficient by which the control encoder spatially downscales the controlnet input video.
        out_proj_dim (`int`, defaults to `128 * 12`):
            Output projection dimension of the final linear layers.
    """
    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 3,
        vae_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 20,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        downscale_coef: int = 8,
        out_proj_dim: int = 128 * 12,
    ) -> None:
        super().__init__()

        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]
        self.control_encoder = nn.ModuleList([
            # Spatial compression with time awareness
            nn.Sequential(
                nn.Conv3d(
                    in_channels,
                    input_channels[0],
                    kernel_size=(3, downscale_coef + 1, downscale_coef + 1),
                    stride=(1, downscale_coef, downscale_coef),
                    padding=(1, downscale_coef // 2, downscale_coef // 2),
                ),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[0]),
            ),
            # Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[1]),
            ),
            # Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[2]),
            ),
        ])
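
        # Shape trace for the control encoder above (illustrative, using the example shapes from
        # the __main__ block at the bottom of this file; exact sizes depend on resolution and frame count):
        #   controlnet_states: (B, 3, 81, 480, 720)
        #   -> block 0 (stride 1x8x8): (B, 192, 81, 60, 90)
        #   -> block 1 (stride 2x1x1): (B,  96, 41, 60, 90)
        #   -> block 2 (stride 2x1x1): (B,  48, 21, 60, 90)
        # which matches the VAE latent grid (B, 16, 21, 60, 90), so the two can be concatenated
        # along the channel axis in `forward`.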

        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(
            vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size
        )

        # 2. Condition embeddings
        # image_embed_dim is 1280 for the I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Controlnet blocks
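        # One linear projection per transformer block, mapping to the base model's hidden size.
        # Zero-initializing these projections (standard ControlNet practice) means the control
        # branch contributes nothing at the start of training and is learned gradually.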
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.blocks)):
            controlnet_block = nn.Linear(inner_dim, out_proj_dim)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[Tuple[Tuple[torch.Tensor, ...]], Transformer2DModelOutput]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        rotary_emb = self.rope(hidden_states)

        # 0. Controlnet encoder: compress the control video down to the latent grid resolution
        for control_encoder_block in self.control_encoder:
            controlnet_states = control_encoder_block(controlnet_states)
        hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)

        # 1. Patch embedding and flattening into a token sequence
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
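        # time_proj_dim is inner_dim * 6, so the projected timestep is reshaped into the six
        # per-block modulation vectors (adaLN-style shift/scale/gate terms) consumed by
        # WanTransformerBlock.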
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 2. Transformer blocks
        controlnet_hidden_states = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                controlnet_hidden_states += (controlnet_block(hidden_states),)
        else:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                controlnet_hidden_states += (controlnet_block(hidden_states),)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_hidden_states,)

        return Transformer2DModelOutput(sample=controlnet_hidden_states)
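

# How a denoising pipeline might consume the per-block outputs (a minimal sketch under
# assumptions, not part of this file): each projected residual is added to the base Wan
# transformer's hidden states at a corresponding block, scaled by a conditioning weight.
# For this dilated variant the controlnet has fewer blocks than the base model, so the
# mapping from controlnet outputs to base blocks (e.g. applying them with a stride) is
# decided by the consuming pipeline. `conditioning_scale` and `block_index` below are
# hypothetical names used only for illustration.
#
#   controlnet_residuals = controlnet(..., return_dict=False)[0]
#   ...
#   hidden_states = base_block(hidden_states, ...)
#   hidden_states = hidden_states + conditioning_scale * controlnet_residuals[block_index]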


if __name__ == "__main__":
    parameters = {
        "added_kv_proj_dim": None,
        "attention_head_dim": 128,
        "cross_attn_norm": True,
        "eps": 1e-06,
        "ffn_dim": 8960,
        "freq_dim": 256,
        "image_dim": None,
        "in_channels": 3,
        "num_attention_heads": 12,
        "num_layers": 2,
        "patch_size": [1, 2, 2],
        "qk_norm": "rms_norm_across_heads",
        "rope_max_seq_len": 1024,
        "text_dim": 4096,
        "downscale_coef": 8,
        "out_proj_dim": 12 * 128,
    }
    controlnet = WanControlnet(**parameters)

    hidden_states = torch.rand(1, 16, 21, 60, 90)
    timestep = torch.randint(low=0, high=1000, size=(1,), dtype=torch.long)
    encoder_hidden_states = torch.rand(1, 512, 4096)
    controlnet_states = torch.rand(1, 3, 81, 480, 720)

    controlnet_hidden_states = controlnet(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_states=controlnet_states,
        return_dict=False,
    )
    print("Output states count", len(controlnet_hidden_states[0]))
    for out_hidden_states in controlnet_hidden_states[0]:
        print(out_hidden_states.shape)
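    # Expected (given the shapes above): 2 tensors, one per transformer block, each of shape
    # (1, 21 * 30 * 45, 1536) = (1, 28350, 1536), i.e. the patchified token sequence projected
    # to out_proj_dim.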