# SPDX-License-Identifier: MIT
# Remote code for Hugging Face Hub (QED / SLLM causal LM).
# Single module so transformers dynamic import does not treat configuration_qed as a pip package.
# Mirrors training-time sllm.model.SLLMForCausalLM weight names for load_state_dict compatibility.

from __future__ import annotations

from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast


class QEDConfig(PretrainedConfig):
    """Configuration for QED (custom RoPE + SwiGLU decoder-only LM)."""

    model_type = "qed"

    def __init__(
        self,
        vocab_size: int = 49_152,
        max_seq_len: int = 8_192,
        d_model: int = 384,
        n_layers: int = 32,
        n_heads: int = 6,
        ffn_hidden_dim: int = 1_024,
        rope_theta: float = 10_000.0,
        rms_norm_eps: float = 1e-5,
        initializer_range: float = 0.02,
        dropout: float = 0.0,
        tie_word_embeddings: bool = True,
        bias: bool = False,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len  # context window; also the size of the RoPE cos/sin cache
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ffn_hidden_dim = ffn_hidden_dim
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range  # std of the normal init used by _init_weights
        self.dropout = dropout  # attention dropout, applied only in training mode
        self.tie_word_embeddings = tie_word_embeddings
        self.bias = bias  # bias on attention / MLP projections (not on lm_head)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-channel scale."""

    def __init__(self, dim: int, eps: float) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` by its RMS and scale by ``weight``.

        The reduction runs in float32: squaring half-precision activations can
        overflow (fp16) or lose precision (bf16). For float32 inputs this is
        identical to computing in the input dtype.
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class RotaryEmbedding(nn.Module):
    """Rotary position embedding with cos/sin tables precomputed for ``max_seq_len`` positions."""

    def __init__(self, dim: int, max_seq_len: int, theta: float) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)  # (max_seq_len, dim)
        # Non-persistent: derived from config, so they are not stored in checkpoints.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` (batch, heads, seq, dim) by the angles for ``position_ids`` (batch, seq).

        Position ids must be < ``max_seq_len``; indexing the cache fails otherwise.
        """
        # (batch, seq, dim) -> (batch, 1, seq, dim) so it broadcasts over heads.
        cos = self.cos_cached[position_ids].unsqueeze(1).to(dtype=x.dtype, device=x.device)
        sin = self.sin_cached[position_ids].unsqueeze(1).to(dtype=x.dtype, device=x.device)
        return (x * cos) + (rotate_half(x) * sin)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RoPE and a legacy ``(key, value)`` tuple cache."""

    def __init__(self, config: QEDConfig) -> None:
        super().__init__()
        if config.d_model % config.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads.")
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.scale = self.head_dim**-0.5
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.rotary = RotaryEmbedding(self.head_dim, config.max_seq_len, config.rope_theta)
        self.dropout = config.dropout

    def _shape(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) -> (batch, n_heads, seq, head_dim)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _build_attn_mask(
        key: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        has_past: bool,
    ) -> Tuple[Optional[torch.Tensor], bool]:
        """Return ``(attn_mask, is_causal)`` for ``F.scaled_dot_product_attention``.

        - No padding and no cache: use SDPA's built-in causal mask.
        - Otherwise: build a boolean mask allowing key index <= query position,
          intersected with the key padding mask when one has zeros.

        An all-ones ``attention_mask`` is treated exactly like no mask; it must
        never disable causal masking.
        """
        kv_len = key.size(-2)
        device = key.device

        padding_mask: Optional[torch.Tensor] = None
        if attention_mask is not None:
            padding_mask = attention_mask[:, None, None, :].to(dtype=torch.bool, device=device)
            if torch.all(padding_mask):
                padding_mask = None  # nothing is padded
            else:
                padding_mask = padding_mask[..., :kv_len]

        if padding_mask is None and not has_past:
            return None, True

        query_positions = position_ids[:, None, :, None]
        key_positions = torch.arange(kv_len, device=device)[None, None, None, :]
        attn_mask = key_positions <= query_positions
        if padding_mask is not None:
            attn_mask = attn_mask & padding_mask
            # A query row with no visible key (e.g. a left-pad position) makes SDPA
            # emit NaN, which then spreads via 0 * NaN in later layers. Let such rows
            # attend everywhere; their outputs belong to padding and are unused.
            attn_mask = attn_mask | ~attn_mask.any(dim=-1, keepdim=True)
        return attn_mask, False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over cached + current keys.

        Returns the projected output and, when ``use_cache``, the updated
        ``(key, value)`` pair (batch, n_heads, kv_len, head_dim) with RoPE
        already applied to the keys.
        """
        query = self._shape(self.q_proj(hidden_states))
        key = self._shape(self.k_proj(hidden_states))
        value = self._shape(self.v_proj(hidden_states))

        query = self.rotary(query, position_ids)
        key = self.rotary(key, position_ids)

        if past_key_value is not None:
            past_key, past_value = past_key_value
            key = torch.cat([past_key, key], dim=-2)
            value = torch.cat([past_value, value], dim=-2)

        next_past_key_value = (key, value) if use_cache else None

        attn_mask, is_causal = self._build_attn_mask(
            key,
            position_ids,
            attention_mask,
            has_past=past_key_value is not None,
        )

        attn_output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
            scale=self.scale,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().view(hidden_states.shape)
        return self.o_proj(attn_output), next_past_key_value


class SwiGLU(nn.Module):
    """Gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config: QEDConfig) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=config.bias)
        self.up_proj = nn.Linear(config.d_model, config.ffn_hidden_dim, bias=config.bias)
        self.down_proj = nn.Linear(config.ffn_hidden_dim, config.d_model, bias=config.bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))


class TransformerBlock(nn.Module):
    """Pre-norm decoder block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, config: QEDConfig) -> None:
        super().__init__()
        self.input_norm = RMSNorm(config.d_model, config.rms_norm_eps)
        self.attention = CausalSelfAttention(config)
        self.post_attn_norm = RMSNorm(config.d_model, config.rms_norm_eps)
        self.mlp = SwiGLU(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Apply the block; returns new hidden states and this layer's updated KV pair (or None)."""
        attn_output, next_past_key_value = self.attention(
            self.input_norm(hidden_states),
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output
        hidden_states = hidden_states + self.mlp(self.post_attn_norm(hidden_states))
        return hidden_states, next_past_key_value


class QEDForCausalLM(PreTrainedModel, GenerationMixin):
    """QED decoder-only language model with an LM head and a legacy tuple KV cache."""

    config_class = QEDConfig
    supports_gradient_checkpointing = False
    _no_split_modules = ["TransformerBlock"]
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: QEDConfig) -> None:
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model, config.rms_norm_eps)
        # bias=True kept to match the training-time checkpoint's parameter names.
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=True)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize parameters using ``config.initializer_range``.

        Called by ``post_init`` and by ``from_pretrained`` for weights missing
        from a checkpoint. Without this hook those weights would be left
        uninitialized when ``torch.nn.init`` is disabled during loading.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

    @classmethod
    def _supports_default_dynamic_cache(cls) -> bool:
        """Use legacy tuple KV cache; DynamicCache expects standard HF config fields."""
        return False

    @torch.no_grad()
    def _sample_next_token(
        self,
        next_token_logits: torch.Tensor,
        temperature: float,
        top_k: int | None,
    ) -> torch.Tensor:
        """
        Sample next token from logits. Matches behavior of the training-time SLLM generator.

        ``temperature <= 0`` means greedy argmax. Otherwise logits are divided by
        ``temperature``, optionally restricted to the ``top_k`` largest, and one token
        is drawn from the softmax. Returns shape (batch, 1).
        """
        if temperature <= 0:
            return torch.argmax(next_token_logits, dim=-1, keepdim=True)
        next_token_logits = next_token_logits / temperature
        if top_k is not None and top_k > 0:
            top_k = min(top_k, next_token_logits.size(-1))
            values, _ = torch.topk(next_token_logits, top_k)
            cutoff = values[:, [-1]]
            next_token_logits = next_token_logits.masked_fill(
                next_token_logits < cutoff, float("-inf")
            )
        probs = F.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 128,
        temperature: float = 0.8,
        top_k: int | None = 50,
        eos_token_id: Union[int, Sequence[int], None] = None,
        do_sample: bool = True,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate tokens using the same logic as `src/sllm/model.py::SLLMForCausalLM.generate`.

        We override HF's `GenerationMixin.generate()` because its cache/position semantics
        can differ from this model's legacy KV cache path; this keeps Hub inference
        identical to the training-time generator.

        Args:
            input_ids: Prompt tokens (batch, seq). Only the last ``max_seq_len`` are used.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature; ``<= 0`` means greedy.
            top_k: Restrict sampling to the top-k logits (None / 0 disables).
            eos_token_id: A single id or a collection of ids. Generation stops once every
                sequence in the batch emits one of them in the same step. Defaults to
                ``config.eos_token_id``.
            do_sample: If False, forces greedy decoding.
            **kwargs: Accepted for HF API compatibility and ignored (including
                ``attention_mask``; padded batches are not handled here).

        Returns:
            The (possibly truncated) prompt followed by the generated tokens.
        """
        _ = kwargs
        if eos_token_id is None:
            eos_token_id = getattr(self.config, "eos_token_id", None)
        # Normalize int / list / tuple / tensor EOS ids into a 1-D tensor for torch.isin.
        eos_ids = (
            None
            if eos_token_id is None
            else torch.as_tensor(eos_token_id, device=input_ids.device).reshape(-1).long()
        )
        # For compatibility: if caller doesn't want sampling, force greedy decoding.
        if not do_sample:
            temperature = 0.0

        max_len = self.config.max_seq_len
        generated = input_ids[:, -max_len:]
        outputs = self(generated, use_cache=True)
        past_key_values = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]

        for _ in range(max_new_tokens):
            next_token = self._sample_next_token(
                next_token_logits, temperature=temperature, top_k=top_k
            )
            generated = torch.cat([generated, next_token], dim=1)

            if eos_ids is not None and bool(torch.isin(next_token.squeeze(-1), eos_ids).all()):
                break

            if generated.size(1) >= max_len:
                # Sliding window when the context is full: recompute from scratch.
                outputs = self(generated[:, -max_len:], use_cache=True)
            else:
                # One-step decode with cached KV.
                outputs = self(
                    next_token,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]

        return generated

    def get_input_embeddings(self) -> nn.Module:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
        self.lm_head = new_embeddings

    def _tie_weights(self) -> None:
        """Share the embedding matrix with ``lm_head`` when ``tie_word_embeddings`` is set."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[list] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> dict:
        """Legacy tuple cache + one-token steps for HF generate().

        On the first step, ``inputs_embeds`` (if given) replaces ``input_ids``.
        ``forward`` rejects receiving both. Once a cache exists, only the last
        token id is fed.
        """
        _ = kwargs
        has_past = bool(past_key_values)  # None and [] both mean "no cache yet"
        if has_past:
            input_ids = input_ids[:, -1:]

        if inputs_embeds is not None and not has_past:
            model_inputs: dict = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs["past_key_values"] = past_key_values
        model_inputs["use_cache"] = True
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        return model_inputs

    def _reorder_cache(self, past_key_values: list, beam_idx: torch.LongTensor) -> list:
        """Select batch rows ``beam_idx`` from every cached (key, value) tensor (used by beam search)."""
        return [
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        ]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and LM head.

        ``past_key_values`` is a per-layer list of ``(key, value)`` tuples; an empty
        list is treated like None. ``output_attentions``, ``output_hidden_states``,
        ``token_type_ids`` and extra kwargs are accepted and ignored.

        NOTE(review): ``labels`` are compared position-by-position with ``logits``
        (no internal shift), unlike most HF causal LMs. Callers must pass labels
        already shifted to the next token.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
                given, or if cached + new length exceeds ``config.max_seq_len``.
        """
        _ = token_type_ids, kwargs
        _ = output_attentions, output_hidden_states
        return_dict = return_dict if return_dict is not None else True
        use_cache = use_cache if use_cache is not None else False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            batch_size, seq_len = input_ids.shape
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = inputs_embeds
            batch_size, seq_len = hidden_states.shape[:2]

        has_past = bool(past_key_values)
        past_length = past_key_values[0][0].size(-2) if has_past else 0

        total_seq_len = past_length + seq_len
        if total_seq_len > self.config.max_seq_len:
            raise ValueError(
                f"Input length {total_seq_len} exceeds model context window {self.config.max_seq_len}."
            )

        if position_ids is None:
            position_ids = torch.arange(
                past_length,
                total_seq_len,
                device=hidden_states.device,
            ).unsqueeze(0).expand(batch_size, -1)

        next_past_key_values: list = []
        for layer_index, layer in enumerate(self.layers):
            layer_past = past_key_values[layer_index] if has_past else None
            hidden_states, next_past_key_value = layer(
                hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_value=layer_past,
                use_cache=use_cache,
            )
            if use_cache and next_past_key_value is not None:
                next_past_key_values.append(next_past_key_value)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous labels are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )

        if not return_dict:
            out = (logits,)
            if past_key_values is not None or use_cache:
                out = out + (next_past_key_values if use_cache else None,)
            if loss is not None:
                out = (loss,) + out
            return out

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=next_past_key_values if use_cache else None,
            hidden_states=None,
            attentions=None,
        )