| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import Callable, Optional, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.integrations import use_kernel_forward_from_hub |
| from transformers.integrations.flex_attention import compile_friendly_flex_attention |
| from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask |
| from transformers.modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer |
| from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.modeling_utils import AttentionInterface, PreTrainedModel |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available |
| from transformers.utils.generic import OutputRecorder, check_model_inputs |
| from .configuration_doge import DogeConfig |
|
|
|
|
| if is_torch_flex_attn_available(): |
| from torch.nn.attention.flex_attention import BlockMask |
|
|
|
|
@use_kernel_forward_from_hub("RMSNorm")
class DogeRMSNorm(nn.Module):
    """
    Root-mean-square layer normalization with a learnable per-channel scale (equivalent to T5LayerNorm).

    The mean of squares over the last dimension is computed in float32. The normalized values are cast back to
    the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        # 1 / sqrt(mean(x^2) + eps) along the last axis
        inv_rms = torch.rsqrt(states_fp32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = (states_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
class DogeResidual(nn.Module):
    """Residual connection with a learnable per-channel scale on the skip path: ``weight * residual + hidden``."""

    def __init__(self, hidden_size):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, residual_states, hidden_states):
        scaled_skip = residual_states * self.weight
        return scaled_skip + hidden_states

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}"
|
|
|
|
class DogeRotaryEmbedding(nn.Module):
    """
    Builds rotary position embedding tables ``(cos, sin)`` for a batch of positions.

    The inverse frequencies and attention scaling come from ``ROPE_INIT_FUNCTIONS[rope_type]``. ``rope_type`` is
    read from ``config.rope_scaling`` when that is a dict (key ``"rope_type"``, falling back to ``"type"``);
    otherwise it is ``"default"``.
    """

    def __init__(self, config: DogeConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Reference to the initial frequencies; presumably read by `dynamic_rope_update` — TODO confirm.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        # position_ids is indexed as (batch, seq); x only supplies the device and output dtype.
        batch_size = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_row = position_ids[:, None, :].float()

        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        # Compute the angles in float32 regardless of any surrounding autocast context.
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension by half: split it into ``[front, back]`` and return ``[-back, front]``."""
    half = x.shape[-1] // 2
    front, back = x.split([half, x.shape[-1] - half], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to query and key tensors.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin`` after ``cos``/``sin`` gain a broadcast axis.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into `cos` and `sin` so they broadcast against `q`
            and `k`. With `cos`/`sin` shaped `[batch_size, seq_len, head_dim]`, use 1 for q/k shaped
            `[batch_size, heads, seq_len, head_dim]`, or 2 for q/k shaped `[batch_size, seq_len, heads, head_dim]`.
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return t * cos_b + rotate_half(t) * sin_b

    return _rotate(q), _rotate(k)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1 (same result as
    ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``), turning
    ``(batch, num_key_value_heads, seqlen, head_dim)`` into ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then merge it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Reference attention implementation.

    Key/value heads are repeated ``module.num_key_value_groups`` times. Scores are ``query @ key^T * scaling``
    plus the (additive) ``attention_mask`` cropped to the key length. The softmax runs in float32 and is cast
    back to ``query.dtype``, then dropout is applied (only while ``module.training``).

    Returns:
        ``(attn_output, attn_weights)``, where ``attn_output`` has dims 1 and 2 swapped relative to the
        probability-weighted values.
    """
    n_rep = module.num_key_value_groups
    key_states = repeat_kv(key, n_rep)
    value_states = repeat_kv(value, n_rep)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    attn_output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return attn_output, probs
|
|
|
|
def flex_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Union[torch.Tensor, "BlockMask"],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Flex-attention backend registered under ``"doge_flex_attention"``.

    A ``BlockMask`` is passed to flex attention unchanged. Any other mask is cropped to the key length and
    added to the scores inside ``score_mod``. ``score_mod`` also applies optional tanh soft-capping and adds
    ``head_mask[b][h][0][0]``.

    Returns:
        ``(attn_output, lse)``. ``attn_output`` has dims 1 and 2 swapped. ``lse`` is the second value flex
        attention returns with ``return_lse=True`` (its log-sum-exp), cast to ``value.dtype``. It is not a
        probability matrix.
    """
    if isinstance(attention_mask, BlockMask):
        block_mask, score_bias = attention_mask, None
    else:
        block_mask, score_bias = None, attention_mask

    if score_bias is not None:
        score_bias = score_bias[:, :, :, : key.shape[-2]]

    def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
        if softcap is not None:
            score = softcap * torch.tanh(score / softcap)
        if score_bias is not None:
            score = score + score_bias[batch_idx][head_idx][q_idx][kv_idx]
        if head_mask is not None:
            score = score + head_mask[batch_idx][head_idx][0][0]
        return score

    attn_output, lse = compile_friendly_flex_attention(
        query,
        key,
        value,
        score_mod=score_mod,
        block_mask=block_mask,
        enable_gqa=True,
        scale=scaling,
        return_lse=True,
    )
    return attn_output.transpose(1, 2).contiguous(), lse.to(value.dtype)
|
|
|
|
# Attention backends looked up by DogeAttention.forward via `config._attn_implementation` (eager is used
# directly). The Doge flex-attention variant is added to this instance by item assignment under its own key.
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
ALL_ATTENTION_FUNCTIONS["doge_flex_attention"] = flex_attention_forward
|
|
|
|
class DogeAttention(nn.Module):
    """
    Multi-head (grouped-query) attention with Doge's dynamic mask attention (DMA).

    In addition to the query/key/value projections, a per-head bias is computed for every key position from
    the value states: ``exp(A * softplus(dt_proj(values)))``. This bias is merged with the causal/padding mask
    and used as the additive attention mask. When there are more than ``keep_window_size`` keys, only the
    ``keep_window_size`` largest entries per query survive.
    """

    def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_attention_heads when the config lacks head_dim or sets it to None.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.keep_window_size = config.keep_window_size

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        # Per-head coefficient of the dynamic mask bias exp(A * softplus(.)); zero-initialised.
        self.A = nn.Parameter(torch.zeros(config.num_attention_heads))
        # Maps the concatenated value heads at each key position to one mask logit per attention head.
        self.dt_proj = nn.Linear(
            config.num_key_value_heads * self.head_dim, config.num_attention_heads, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(attn_output, attn_weights)``; ``attn_weights`` is whatever the selected backend returns."""
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split projections into heads and move the head axis to dim 1.
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dynamic mask bias, computed from the (possibly cached) values so it covers every key position.
        dt_states = self.dt_proj(
            value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
        )
        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
        attn_mask = self.prepare_dynamic_mask(
            hidden_states=hidden_states,
            dt_states=dt_states,
            keep_window_size=self.keep_window_size,
            attention_mask=attention_mask,
        )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attn_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def prepare_dynamic_mask(
        self,
        hidden_states: torch.Tensor,
        dt_states: torch.Tensor,
        keep_window_size: int = 2048,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Build the dynamic attention mask (DMA), which sparsifies attention by masking out low-scoring keys.

        `dt_states` is broadcast over the query axis. Positions masked by a dense `attention_mask` are set to
        the dtype minimum. When the key length exceeds `keep_window_size`, only the `keep_window_size` largest
        entries per query are kept; the rest are set to the dtype minimum.

        Args:
            hidden_states (`torch.Tensor`): The input hidden_states; supplies the dtype (and its minimum value)
                and the query length `hidden_states.shape[1]`.
            dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
            keep_window_size (`int`): Number of keys kept per query; dynamic masking only happens when the key
                length exceeds this value.
            attention_mask (`torch.Tensor`, *optional*): attention mask of shape
                `(batch_size, 1, query_sequence_length, key_sequence_length)`, boolean (True = attend) or
                additive (non-zero = masked). Non-tensor masks (e.g. a flex-attention `BlockMask`) are ignored
                here and left to the attention backend.

        Returns:
            `torch.Tensor` of shape `(batch_size, num_heads, query_sequence_length, key_sequence_length)`.
        """
        min_dtype = torch.finfo(hidden_states.dtype).min
        dtype = hidden_states.dtype
        attn_mask = dt_states[:, :, None, :].expand(-1, -1, hidden_states.shape[1], -1)
        # Only dense tensor masks can be merged here. Testing for torch.Tensor (rather than "not a BlockMask")
        # avoids referencing BlockMask, which is undefined when flex attention is unavailable.
        if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
            if attention_mask.dtype == torch.bool:
                # True (attend) -> 0.0, False (masked) -> dtype minimum
                attention_mask = torch.where(
                    attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
                )
            # Any non-zero entry of the additive mask marks a position that must stay masked.
            attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
        if attn_mask.shape[-1] > keep_window_size:
            active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
            topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, largest=True, sorted=False).indices
            active_mask = active_mask.scatter(-1, topk_indices, 1.0)
            attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
        return attn_mask
|
|
|
|
class DogeMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        hidden, inner, use_bias = config.hidden_size, config.intermediate_size, config.mlp_bias
        self.config = config
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.up_proj = nn.Linear(hidden, inner, bias=use_bias)
        self.down_proj = nn.Linear(inner, hidden, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
class DogeCDMoE(nn.Module):
    """
    Dense gated MLP plus a set of product-key routed experts.

    Each token gets two score vectors of length ``num_keys`` from ``router_gate``. Expert ``i * num_keys + j``
    is scored by ``x[i] + y[j]``, and the ``top_k`` best experts are used. Expert ``e`` is the row pair
    ``(down_embed[e], up_embed[e])``. It contributes
    ``act_fn(<down_embed[e], h>) * routing_weight * up_embed[e]``, which is added to the dense MLP output.
    """

    def __init__(self, config: "DogeConfig"):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]

        self.num_experts = config.num_experts
        self.num_keys = math.floor(math.sqrt(self.num_experts))
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob

        # shared expert
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)

        # router gate for retrieval experts: [x-key logits | y-key logits] per token
        self.router_gate = nn.Linear(self.hidden_size, self.num_keys * 2, bias=False)

        # routed experts
        self.down_embed = nn.Embedding(self.num_experts, self.hidden_size)
        self.up_embed = nn.Embedding(self.num_experts, self.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: tensor of shape `(batch_size, seq_len, hidden_size)`.

        Returns:
            `(hidden_states, router_logits)`. The first has the input shape. `router_logits` has shape
            `(2, batch_size * seq_len, num_keys)`: index 0 holds the x-key logits and index 1 the y-key logits
            of each token.
        """
        bsz, seq_len, _ = hidden_states.shape
        num_tokens = bsz * seq_len

        # Split each token's [x | y] logits per token. A bare `.view(2, num_tokens, -1)` would reinterpret
        # memory and mix the logits of different tokens whenever num_tokens > 1.
        router_logits = self.router_gate(hidden_states).reshape(num_tokens, 2, -1).transpose(0, 1)

        # get experts with the highest routing logits
        (scores_x, scores_y), (indices_x, indices_y) = router_logits.topk(self.num_keys, dim=-1)
        all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
        all_indices = all_indices.view(*all_indices.shape[:-2], -1)
        scores, position_indices = all_scores.topk(self.top_k, dim=-1)
        indices = all_indices.gather(-1, position_indices)
        routing_weights = F.softmax(scores, dim=-1)
        if self.norm_topk_prob:
            # Out-of-place: softmax saves its output for backward, so an in-place division breaks autograd.
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        # mix routing weights with expert states
        down_embed = self.down_embed(indices)
        up_embed = self.up_embed(indices)
        experts_weights = torch.matmul(down_embed, hidden_states.reshape(num_tokens, -1, 1)).view(num_tokens, -1)
        experts_weights = self.act_fn(experts_weights) * routing_weights
        experts_states = torch.matmul(experts_weights.view(num_tokens, 1, -1), up_embed).view(bsz, seq_len, -1)
        hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
        hidden_states = hidden_states + experts_states
        return hidden_states, router_logits
|
|
|
|
class DogeDecoderLayer(GradientCheckpointingLayer):
    """
    One Doge decoder block: pre-norm self-attention followed by a pre-norm MLP (dense or CDMoE).

    Both sub-blocks use a learnable per-channel scaled residual, ``residual_scale * residual + branch``, with
    dropout on the branch output.
    """

    def __init__(self, config: DogeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_dropout = config.hidden_dropout

        self.input_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = DogeAttention(config=config, layer_idx=layer_idx)
        self.input_residual = nn.Parameter(torch.ones(config.hidden_size))

        self.post_attention_layernorm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config)
        self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        """Return the block's output hidden states (same shape as `hidden_states`)."""
        # sequence transformation
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
        hidden_states = self.input_residual * residual + hidden_states

        # state transformation
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if isinstance(hidden_states, tuple):
            # DogeCDMoE returns (hidden_states, router_logits); the router logits are picked up from the
            # module output by the model's output recorder, so only the hidden states continue here.
            hidden_states = hidden_states[0]
        hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
        hidden_states = self.post_attention_residual * residual + hidden_states

        return hidden_states
|
|
|
|
@auto_docstring
class DogePreTrainedModel(PreTrainedModel):
    config: DogeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DogeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Module outputs to record: router logits are element 1 of the tuple returned by DogeCDMoE.forward.
    _can_record_outputs = {
        "router_logits": OutputRecorder(DogeCDMoE, index=1),
        "hidden_states": DogeDecoderLayer,
        "attentions": DogeAttention,
    }

    def _init_weights(self, module):
        """Initialize the weights; Doge-specific parameters are reset after the generic initialization."""
        super()._init_weights(module)
        if isinstance(module, DogeAttention):
            # A starts at zero, so exp(A * softplus(.)) is initially 1 for every head.
            if hasattr(module, "A"):
                module.A.data.zero_()
            return
        if isinstance(module, DogeDecoderLayer):
            # Residual scales start at 1 (plain residual connections).
            for residual_name in ("input_residual", "post_attention_residual"):
                if hasattr(module, residual_name):
                    getattr(module, residual_name).data.fill_(1.0)
|
|
|
|
@auto_docstring
class DogeModel(DogePreTrainedModel):
    def __init__(self, config: DogeConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Registration order (embeddings, decoder stack, final norm, rotary tables) is kept stable.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            DogeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )
        self.norm = DogeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = DogeRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            # New tokens are numbered right after whatever the cache already holds.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_tokens = inputs_embeds.shape[1]
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + new_tokens, device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if self.config.sliding_window is None:
            mask_function = create_causal_mask
        else:
            mask_function = create_sliding_window_causal_mask
        causal_mask = mask_function(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # One (cos, sin) pair shared by all decoder layers.
        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)

        hidden_states = inputs_embeds
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        return MoeModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
|
|
|
|
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    num_keys: Optional[int] = None,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Doge routes with product keys: expert `i * num_keys + j` is scored by `x[i] + y[j]`. Router probabilities
    are therefore computed over all `num_keys ** 2` experts in expert-id order, so that they line up with the
    per-expert token fractions they are multiplied with.

    Args:
        gate_logits:
            Logits from the `router_gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [2, batch_size * sequence_length, num_keys].
        num_experts:
            Number of experts. Defaults to `num_keys ** 2`; experts with id >= `num_keys ** 2` cannot be
            routed to and get zero router probability.
        num_keys:
            Number of keys. Defaults to the last dimension of the router logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss, or the int `0` when `gate_logits` is `None`, not a tuple, or empty.

    Raises:
        ValueError: if `num_keys` does not match the last dimension of the router logits.
    """
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    compute_dtype = gate_logits[0].dtype
    compute_device = gate_logits[0].device
    logits_num_keys = gate_logits[0].shape[-1]
    if num_keys is None:
        num_keys = logits_num_keys
    elif num_keys != logits_num_keys:
        raise ValueError(f"num_keys ({num_keys}) does not match the router logits' last dimension ({logits_num_keys})")
    num_candidates = num_keys * num_keys
    if num_experts is None:
        num_experts = num_candidates

    all_expert_indices = []
    all_routing_weights = []
    for layer_gate_logits in gate_logits:
        layer_gate_logits = layer_gate_logits.to(compute_device)
        logits_x, logits_y = layer_gate_logits[0], layer_gate_logits[1]

        # Score every (i, j) key pair. Flattening puts pair (i, j) at position i * num_keys + j, which is
        # exactly its expert id, so scores, top-k picks and probabilities are all indexed by expert id.
        all_scores = (logits_x.unsqueeze(-1) + logits_y.unsqueeze(-2)).flatten(start_dim=-2)
        expert_indices = all_scores.topk(top_k, dim=-1).indices
        routing_weights = F.softmax(all_scores, dim=-1)
        if num_experts > num_candidates:
            # Experts beyond num_keys**2 are unreachable by the router; give them zero probability.
            routing_weights = F.pad(routing_weights, (0, num_experts - num_candidates))

        all_expert_indices.append(expert_indices)
        all_routing_weights.append(routing_weights)
    all_expert_indices = torch.cat(all_expert_indices, dim=0)
    all_routing_weights = torch.cat(all_routing_weights, dim=0)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        all_expert_indices = all_expert_indices.view(-1)
        tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
        pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
        tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / all_expert_indices.shape[0]

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(all_routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = len(gate_logits)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_indices
        expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k))
            .reshape(-1)
            .to(compute_device)
        )
        all_expert_indices = all_expert_indices.view(-1)[expert_attention_mask.bool()]

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.zeros(num_experts, dtype=compute_dtype, device=compute_device)
        pad = torch.ones_like(all_expert_indices, dtype=compute_dtype, device=compute_device)
        tokens_per_expert = tokens_per_expert.scatter_add_(0, all_expert_indices, pad) / torch.sum(
            expert_attention_mask
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of routing_weights
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(all_routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
    return overall_loss * num_experts
|
|
|
|
@auto_docstring
class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = DogeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_router_logits: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, DogeForCausalLM

        >>> model = DogeForCausalLM.from_pretrained("SmallDoge/Doge-320M")
        >>> tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # Forward the resolved flag: it is a named parameter here (so not inside **kwargs). Without it the base
        # model would only fall back to the config default when deciding whether to record router logits.
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute logits for the last `logits_to_keep` positions (all positions when it is 0).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                math.floor(math.sqrt(self.num_experts)),
                self.num_experts_per_tok,
                attention_mask,
            )
            # The helper returns the int 0 when no router logits were recorded (e.g. a dense model); only a
            # tensor can be moved to loss.device and added.
            if labels is not None and isinstance(aux_loss, torch.Tensor):
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|
class DogeForSequenceClassification(GenericForSequenceClassification, DogePreTrainedModel):
    """Doge with a sequence-classification head; the implementation is inherited from GenericForSequenceClassification."""

    pass
|
|
|
|
# Public API of this module.
__all__ = ["DogeForCausalLM", "DogeModel", "DogePreTrainedModel", "DogeForSequenceClassification"]
|
|