"""OgmaModel — top-level model wrapping any architecture variant.""" from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from .config import OgmaConfig, TaskToken, VariantType from .embeddings import TokenEmbedding from .pooling import create_pooling from .transformer import TransformerVariant __all__ = ["OgmaModel"] MAX_PARAMS = 10_000_000 def _build_variant(config: OgmaConfig) -> nn.Module: """Instantiate the released Ogma architecture variant.""" if config.variant != VariantType.TRANSFORMER: raise ValueError(f"This HF release supports transformer checkpoints, got {config.variant}") return TransformerVariant(config) class OgmaModel(PreTrainedModel): """Ogma embedding model. Wraps any architecture variant with shared embedding, pooling, and normalization. Produces L2-normalized embeddings at d_output dimensions, Matryoshka-compatible at configured sub-dimensions. """ config_class = OgmaConfig base_model_prefix = "ogma" supports_gradient_checkpointing = False _tied_weights_keys: list[str] = [] all_tied_weights_keys: dict[str, str] = {} def __init__(self, config: OgmaConfig) -> None: super().__init__(config) self.config = config self.embedding = TokenEmbedding(config) self.variant = _build_variant(config) self.pooling = create_pooling(config) # Output projection if variant output != d_output needs_proj = ( config.variant == VariantType.DEEP_NARROW and config.d_model != config.d_output ) # DeepNarrowVariant already has output_proj, so no extra needed here if not needs_proj and config.d_model != config.d_output: self.output_proj: nn.Module = nn.Linear( config.d_model, config.d_output ) else: self.output_proj = nn.Identity() def forward( self, input_ids: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, task_token_ids: torch.Tensor | None = None, token_ids: torch.Tensor | None = None, **_: object, ) -> torch.Tensor: """Forward pass producing L2-normalized embeddings. Args: input_ids: (B, S) token IDs, Hugging Face style. attention_mask: (B, S) attention mask (1=valid, 0=pad). task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM). token_ids: Backward-compatible alias for input_ids. Returns: (B, d_output) L2-normalized embeddings. """ if input_ids is None: input_ids = token_ids if input_ids is None: raise ValueError("input_ids or token_ids must be provided") if attention_mask is None: attention_mask = torch.ones_like(input_ids) if task_token_ids is None: task_token_ids = torch.full( (input_ids.shape[0],), self.config.sym_id, device=input_ids.device, dtype=torch.long, ) token_ids = input_ids # Embed tokens with task token prepended -> (B, S+1, d_model) x = self.embedding(token_ids, task_token_ids) # Extend attention mask for prepended task token task_mask = torch.ones( attention_mask.shape[0], 1, device=attention_mask.device, dtype=attention_mask.dtype, ) extended_mask = torch.cat([task_mask, attention_mask], dim=1) # Run through variant x = self.variant(x, extended_mask) # Pool x = self.pooling(x, extended_mask) # Project if needed x = self.output_proj(x) # L2 normalize return F.normalize(x, p=2, dim=-1) def encode( self, token_ids: torch.Tensor, attention_mask: torch.Tensor, task: TaskToken = TaskToken.SYM, ) -> torch.Tensor: """Encode tokens with a specified task mode. Args: token_ids: (B, S) token IDs. attention_mask: (B, S) attention mask. task: Task token to use. Returns: (B, d_output) L2-normalized embeddings. """ task_ids = torch.full( (token_ids.shape[0],), self.config.task_token_id(task), device=token_ids.device, dtype=torch.long, ) return self.forward(input_ids=token_ids, attention_mask=attention_mask, task_token_ids=task_ids) def param_count(self) -> int: """Count total trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad) def assert_param_budget(self) -> None: """Assert model is under the 10M parameter budget.""" count = self.param_count() assert count < MAX_PARAMS, ( f"Model has {count:,} params, exceeds {MAX_PARAMS:,} budget" ) @classmethod def from_config(cls, config: OgmaConfig) -> OgmaModel: """Factory method to build a model from config.""" model = cls(config) if model.param_count() < MAX_PARAMS: model.assert_param_budget() return model @classmethod def from_checkpoint( cls, path: str, device: str = "cpu", ) -> OgmaModel: """Load model from a checkpoint directory. Args: path: Path to checkpoint directory containing config.yaml and model.pt. device: Device to load model to. Returns: Loaded OgmaModel. """ from pathlib import Path import yaml ckpt_path = Path(path) with open(ckpt_path / "config.yaml") as f: config_dict = yaml.safe_load(f) config = OgmaConfig.from_dict(config_dict) model = cls(config) state_dict = torch.load( ckpt_path / "model.pt", map_location=device, weights_only=True, ) model.load_state_dict(state_dict) model.to(device) model.eval() return model def save_checkpoint(self, path: str) -> None: """Save model checkpoint. Args: path: Directory to save config.yaml and model.pt. """ from pathlib import Path import yaml ckpt_path = Path(path) ckpt_path.mkdir(parents=True, exist_ok=True) with open(ckpt_path / "config.yaml", "w") as f: yaml.dump(self.config.to_dict(), f, default_flow_style=False) torch.save(self.state_dict(), ckpt_path / "model.pt")