# ogma-base / ogma_model.py — Hugging Face release file.
# Commit 9019da5 (verified): "Enable AutoModel loading".
"""OgmaModel — top-level model wrapping any architecture variant."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .config import OgmaConfig, TaskToken, VariantType
from .embeddings import TokenEmbedding
from .pooling import create_pooling
from .transformer import TransformerVariant
__all__ = ["OgmaModel"]
MAX_PARAMS = 10_000_000
def _build_variant(config: OgmaConfig) -> nn.Module:
    """Build the architecture variant shipped with this release.

    Only transformer checkpoints are published on the Hub; any other
    ``config.variant`` value is rejected up front.

    Args:
        config: Model configuration carrying the requested variant.

    Returns:
        The instantiated transformer variant module.

    Raises:
        ValueError: If ``config.variant`` is not the transformer variant.
    """
    if config.variant == VariantType.TRANSFORMER:
        return TransformerVariant(config)
    raise ValueError(f"This HF release supports transformer checkpoints, got {config.variant}")
class OgmaModel(PreTrainedModel):
    """Ogma embedding model.

    Wraps the released architecture variant with shared token embedding,
    pooling, and normalization. Produces L2-normalized embeddings at
    d_output dimensions, Matryoshka-compatible at configured
    sub-dimensions.
    """

    config_class = OgmaConfig
    base_model_prefix = "ogma"
    supports_gradient_checkpointing = False
    # No weight tying in this model family; declared for HF bookkeeping.
    _tied_weights_keys: list[str] = []
    all_tied_weights_keys: dict[str, str] = {}

    def __init__(self, config: OgmaConfig) -> None:
        super().__init__(config)
        self.config = config
        self.embedding = TokenEmbedding(config)
        # _build_variant raises for every variant except TRANSFORMER, so
        # from here on the variant emits d_model-wide hidden states.
        self.variant = _build_variant(config)
        self.pooling = create_pooling(config)
        # Project pooled output only when the widths differ. (The original
        # code special-cased DEEP_NARROW — which carries its own
        # output_proj — but that variant never reaches this point because
        # _build_variant rejects it first, so the check reduces to this.)
        if config.d_model != config.d_output:
            self.output_proj: nn.Module = nn.Linear(
                config.d_model, config.d_output
            )
        else:
            self.output_proj = nn.Identity()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        task_token_ids: torch.Tensor | None = None,
        token_ids: torch.Tensor | None = None,
        **_: object,
    ) -> torch.Tensor:
        """Forward pass producing L2-normalized embeddings.

        Args:
            input_ids: (B, S) token IDs, Hugging Face style.
            attention_mask: (B, S) attention mask (1=valid, 0=pad).
                Defaults to all-ones when omitted.
            task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).
                Defaults to the SYM task when omitted.
            token_ids: Backward-compatible alias for input_ids; ignored
                when input_ids is also given.
            **_: Extra HF-style kwargs, accepted and ignored.

        Returns:
            (B, d_output) L2-normalized embeddings.

        Raises:
            ValueError: If neither input_ids nor token_ids is provided.
        """
        if input_ids is None:
            input_ids = token_ids
        if input_ids is None:
            raise ValueError("input_ids or token_ids must be provided")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if task_token_ids is None:
            # Default to the symmetric (SYM) task mode.
            task_token_ids = torch.full(
                (input_ids.shape[0],),
                self.config.sym_id,
                device=input_ids.device,
                dtype=torch.long,
            )
        # Embed tokens with the task token prepended -> (B, S+1, d_model).
        x = self.embedding(input_ids, task_token_ids)
        # The prepended task token must always be attended to, so extend
        # the mask with a leading column of ones.
        task_mask = torch.ones(
            attention_mask.shape[0], 1,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        extended_mask = torch.cat([task_mask, attention_mask], dim=1)
        x = self.variant(x, extended_mask)   # contextualize tokens
        x = self.pooling(x, extended_mask)   # (B, d_model)
        x = self.output_proj(x)              # (B, d_output)
        # Unit-norm output so dot product equals cosine similarity.
        return F.normalize(x, p=2, dim=-1)

    def encode(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        task: TaskToken = TaskToken.SYM,
    ) -> torch.Tensor:
        """Encode tokens with a specified task mode.

        Args:
            token_ids: (B, S) token IDs.
            attention_mask: (B, S) attention mask.
            task: Task token to use (defaults to symmetric mode).

        Returns:
            (B, d_output) L2-normalized embeddings.
        """
        task_ids = torch.full(
            (token_ids.shape[0],),
            self.config.task_token_id(task),
            device=token_ids.device,
            dtype=torch.long,
        )
        return self.forward(
            input_ids=token_ids,
            attention_mask=attention_mask,
            task_token_ids=task_ids,
        )

    def param_count(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def assert_param_budget(self) -> None:
        """Assert the model is under the 10M trainable-parameter budget.

        Raises:
            AssertionError: If the budget is exceeded. NOTE: asserts are
                stripped under ``python -O``; callers needing a hard
                guarantee should check :meth:`param_count` directly.
        """
        count = self.param_count()
        assert count < MAX_PARAMS, (
            f"Model has {count:,} params, exceeds {MAX_PARAMS:,} budget"
        )

    @classmethod
    def from_config(cls, config: OgmaConfig) -> OgmaModel:
        """Factory method to build a model from config, enforcing the budget.

        Raises:
            AssertionError: If the built model exceeds MAX_PARAMS.
        """
        model = cls(config)
        # Bug fix: the budget check used to run only when the model was
        # already under budget, making it a no-op. Enforce unconditionally.
        model.assert_param_budget()
        return model

    @classmethod
    def from_checkpoint(
        cls,
        path: str,
        device: str = "cpu",
    ) -> OgmaModel:
        """Load model from a checkpoint directory.

        Args:
            path: Path to checkpoint directory containing config.yaml
                and model.pt.
            device: Device to load model to.

        Returns:
            Loaded OgmaModel in eval mode on ``device``.
        """
        from pathlib import Path
        import yaml

        ckpt_path = Path(path)
        with open(ckpt_path / "config.yaml") as f:
            config_dict = yaml.safe_load(f)
        config = OgmaConfig.from_dict(config_dict)
        model = cls(config)
        # weights_only=True avoids arbitrary-code execution via pickled
        # payloads when loading untrusted checkpoints.
        state_dict = torch.load(
            ckpt_path / "model.pt",
            map_location=device,
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model

    def save_checkpoint(self, path: str) -> None:
        """Save model checkpoint.

        Writes config.yaml (model config) and model.pt (state dict) into
        ``path``, creating the directory if needed.

        Args:
            path: Directory to save config.yaml and model.pt.
        """
        from pathlib import Path
        import yaml

        ckpt_path = Path(path)
        ckpt_path.mkdir(parents=True, exist_ok=True)
        with open(ckpt_path / "config.yaml", "w") as f:
            yaml.dump(self.config.to_dict(), f, default_flow_style=False)
        torch.save(self.state_dict(), ckpt_path / "model.pt")