"""Variant A: Tiny Transformer — 1-2 layer standard transformer.""" from __future__ import annotations import math import torch import torch.nn as nn import torch.nn.functional as F from .configuration_ogma import OgmaConfig from .embeddings import RotaryPositionalEncoding, apply_rope __all__ = ["TransformerVariant"] class SwiGLU(nn.Module): """SwiGLU feedforward network.""" def __init__(self, d_model: int, d_hidden: int, dropout: float = 0.0) -> None: super().__init__() self.w1 = nn.Linear(d_model, d_hidden, bias=False) self.w2 = nn.Linear(d_model, d_hidden, bias=False) self.w3 = nn.Linear(d_hidden, d_model, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: out: torch.Tensor = self.dropout( self.w3(F.silu(self.w1(x)) * self.w2(x)) ) return out class TransformerLayer(nn.Module): """Single transformer encoder layer with RoPE and SwiGLU.""" def __init__(self, config: OgmaConfig, rope: RotaryPositionalEncoding) -> None: super().__init__() self.n_heads = config.n_heads self.d_head = config.d_head self.rope = rope self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.norm1 = nn.LayerNorm(config.d_model) self.norm2 = nn.LayerNorm(config.d_model) self.ffn = SwiGLU(config.d_model, config.ffn_hidden, config.dropout) self.attn_dropout = nn.Dropout(config.dropout) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: B, S, D = x.shape # Pre-norm attention h = self.norm1(x) q = self.q_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) k = self.k_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) v = self.v_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2) cos, sin = self.rope(h) q, k = apply_rope(q, k, cos, sin) scale = 1.0 / math.sqrt(self.d_head) attn = torch.matmul(q, k.transpose(-2, -1)) * scale if attention_mask is not None: # attention_mask: (B, S) -> (B, 1, 1, S) for broadcasting mask = attention_mask.unsqueeze(1).unsqueeze(2) attn = attn.masked_fill(mask == 0, float("-inf")) attn = self.attn_dropout(F.softmax(attn, dim=-1)) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(B, S, D) x = x + self.o_proj(out) # Pre-norm FFN x = x + self.ffn(self.norm2(x)) return x class TransformerVariant(nn.Module): """Variant A: 1-2 layer transformer encoder with RoPE and SwiGLU.""" def __init__(self, config: OgmaConfig) -> None: super().__init__() rope = RotaryPositionalEncoding(config.d_head, config.max_seq_len + 1) self.layers = nn.ModuleList( [TransformerLayer(config, rope) for _ in range(config.n_layers)] ) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: for layer in self.layers: x = layer(x, attention_mask) return x