# ogma-base / transformer.py
"""Variant A: Tiny Transformer — 1-2 layer standard transformer."""
from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_ogma import OgmaConfig
from .embeddings import RotaryPositionalEncoding, apply_rope

__all__ = ["TransformerVariant"]


class SwiGLU(nn.Module):
    """SwiGLU feed-forward network: w3(silu(w1(x)) * w2(x))."""

    def __init__(self, d_model: int, d_hidden: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_model, d_hidden, bias=False)
        self.w3 = nn.Linear(d_hidden, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU(w1 x) gates w2 x; w3 projects back to d_model.
        out: torch.Tensor = self.dropout(
            self.w3(F.silu(self.w1(x)) * self.w2(x))
        )
        return out
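
# Illustrative shape check for SwiGLU (sketch only, not part of the module's API):
#     ffn = SwiGLU(d_model=64, d_hidden=256)
#     ffn(torch.randn(2, 10, 64)).shape  # -> torch.Size([2, 10, 64])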


class TransformerLayer(nn.Module):
    """Single transformer encoder layer with RoPE and SwiGLU."""

    def __init__(self, config: OgmaConfig, rope: RotaryPositionalEncoding) -> None:
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.rope = rope
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.ffn = SwiGLU(config.d_model, config.ffn_hidden, config.dropout)
        self.attn_dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        B, S, D = x.shape

        # Pre-norm attention.
        h = self.norm1(x)
        # Project and split heads: (B, S, D) -> (B, n_heads, S, d_head).
        q = self.q_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(h).view(B, S, self.n_heads, self.d_head).transpose(1, 2)

        # Rotary position embeddings are applied to queries and keys only.
        cos, sin = self.rope(h)
        q, k = apply_rope(q, k, cos, sin)

        # Scaled dot-product attention with an optional padding mask.
        scale = 1.0 / math.sqrt(self.d_head)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if attention_mask is not None:
            # attention_mask: (B, S) -> (B, 1, 1, S) for broadcasting over heads and query positions.
            mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = self.attn_dropout(F.softmax(attn, dim=-1))
        out = torch.matmul(attn, v)

        # Merge heads back: (B, n_heads, S, d_head) -> (B, S, D).
        out = out.transpose(1, 2).contiguous().view(B, S, D)
        x = x + self.o_proj(out)

        # Pre-norm FFN.
        x = x + self.ffn(self.norm2(x))
        return x


class TransformerVariant(nn.Module):
    """Variant A: 1-2 layer transformer encoder with RoPE and SwiGLU."""

    def __init__(self, config: OgmaConfig) -> None:
        super().__init__()
        # A single RoPE cache is built once and shared across all layers.
        rope = RotaryPositionalEncoding(config.d_head, config.max_seq_len + 1)
        self.layers = nn.ModuleList(
            [TransformerLayer(config, rope) for _ in range(config.n_layers)]
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x
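

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch). Assumption: OgmaConfig accepts the
    # fields referenced above (d_model, n_heads, d_head, ffn_hidden, dropout,
    # n_layers, max_seq_len) as keyword arguments; the real signature may differ.
    # Because of the relative imports, run as `python -m <package>.transformer`.
    config = OgmaConfig(
        d_model=64,
        n_heads=4,
        d_head=16,
        ffn_hidden=128,
        dropout=0.0,
        n_layers=2,
        max_seq_len=32,
    )
    model = TransformerVariant(config)
    x = torch.randn(2, 8, config.d_model)                 # (batch, seq, d_model)
    attention_mask = torch.ones(2, 8, dtype=torch.long)   # 1 = attend, 0 = padding
    out = model(x, attention_mask)
    print(out.shape)  # expected: torch.Size([2, 8, 64]), same shape as the input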