Sentence Similarity
ONNX
Safetensors
English
ogma
embeddings
dense-retrieval
matryoshka
rag
agents
mteb
semantic-search
text-embeddings
text-embedding
vector-search
document-retrieval
similarity-search
classification
clustering
edge-ai
on-device
local-inference
efficient-ai
rag-retrieval
custom_code
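Since this release ships its modeling code in the repository (the `custom_code` tag above), it can be loaded through `transformers` with remote code enabled. The sketch below is illustrative only: the repository id is a placeholder, and it assumes the repo registers the model for `AutoModel` and bundles a compatible tokenizer, neither of which this card confirms.

# Minimal loading sketch. Assumptions: "your-org/ogma" is a placeholder repo id,
# the config registers an auto_map for AutoModel, and a tokenizer is bundled.
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("your-org/ogma", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("your-org/ogma", trust_remote_code=True)
model.eval()

batch = tokenizer(["how do matryoshka embeddings work?"], padding=True, return_tensors="pt")
with torch.no_grad():
    emb = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(emb.shape)  # (1, d_output); rows are L2-normalized

Because no task_token_ids are passed here, the forward pass falls back to the symmetric (SYM) task token, which suits plain sentence-similarity use; see the modeling code below.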
| """OgmaModel — top-level model wrapping any architecture variant.""" | |
| from __future__ import annotations | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from .config import OgmaConfig, TaskToken, VariantType | |
| from .embeddings import TokenEmbedding | |
| from .pooling import create_pooling | |
| from .transformer import TransformerVariant | |
| __all__ = ["OgmaModel"] | |
| MAX_PARAMS = 10_000_000 | |
| def _build_variant(config: OgmaConfig) -> nn.Module: | |
| """Instantiate the released Ogma architecture variant.""" | |
| if config.variant != VariantType.TRANSFORMER: | |
| raise ValueError(f"This HF release supports transformer checkpoints, got {config.variant}") | |
| return TransformerVariant(config) | |
class OgmaModel(PreTrainedModel):
    """Ogma embedding model.

    Wraps any architecture variant with shared embedding, pooling, and
    normalization. Produces L2-normalized embeddings at d_output dimensions,
    Matryoshka-compatible at configured sub-dimensions.
    """

    config_class = OgmaConfig
    base_model_prefix = "ogma"
    supports_gradient_checkpointing = False
    _tied_weights_keys: list[str] = []
    all_tied_weights_keys: dict[str, str] = {}

    def __init__(self, config: OgmaConfig) -> None:
        super().__init__(config)
        self.config = config
        self.embedding = TokenEmbedding(config)
        self.variant = _build_variant(config)
        self.pooling = create_pooling(config)
        # Output projection if the variant's output width != d_output.
        # DeepNarrowVariant already has its own output_proj, so no extra
        # projection is added in that case.
        variant_has_proj = (
            config.variant == VariantType.DEEP_NARROW
            and config.d_model != config.d_output
        )
        if not variant_has_proj and config.d_model != config.d_output:
            self.output_proj: nn.Module = nn.Linear(config.d_model, config.d_output)
        else:
            self.output_proj = nn.Identity()
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        task_token_ids: torch.Tensor | None = None,
        token_ids: torch.Tensor | None = None,
        **_: object,
    ) -> torch.Tensor:
        """Forward pass producing L2-normalized embeddings.

        Args:
            input_ids: (B, S) token IDs, Hugging Face style.
            attention_mask: (B, S) attention mask (1=valid, 0=pad).
            task_token_ids: (B,) task token IDs (4=QRY, 5=DOC, 6=SYM).
            token_ids: Backward-compatible alias for input_ids.

        Returns:
            (B, d_output) L2-normalized embeddings.
        """
        if input_ids is None:
            input_ids = token_ids
        if input_ids is None:
            raise ValueError("input_ids or token_ids must be provided")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if task_token_ids is None:
            task_token_ids = torch.full(
                (input_ids.shape[0],),
                self.config.sym_id,
                device=input_ids.device,
                dtype=torch.long,
            )
        token_ids = input_ids
        # Embed tokens with task token prepended -> (B, S+1, d_model)
        x = self.embedding(token_ids, task_token_ids)
        # Extend attention mask for the prepended task token
        task_mask = torch.ones(
            attention_mask.shape[0], 1,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        extended_mask = torch.cat([task_mask, attention_mask], dim=1)
        # Run through the architecture variant
        x = self.variant(x, extended_mask)
        # Pool token states into a single vector
        x = self.pooling(x, extended_mask)
        # Project to d_output if needed
        x = self.output_proj(x)
        # L2 normalize
        return F.normalize(x, p=2, dim=-1)
    def encode(
        self,
        token_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        task: TaskToken = TaskToken.SYM,
    ) -> torch.Tensor:
        """Encode tokens with a specified task mode.

        Args:
            token_ids: (B, S) token IDs.
            attention_mask: (B, S) attention mask.
            task: Task token to use.

        Returns:
            (B, d_output) L2-normalized embeddings.
        """
        task_ids = torch.full(
            (token_ids.shape[0],),
            self.config.task_token_id(task),
            device=token_ids.device,
            dtype=torch.long,
        )
        return self.forward(
            input_ids=token_ids,
            attention_mask=attention_mask,
            task_token_ids=task_ids,
        )
    def param_count(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def assert_param_budget(self) -> None:
        """Assert model is under the 10M parameter budget."""
        count = self.param_count()
        assert count < MAX_PARAMS, (
            f"Model has {count:,} params, exceeds {MAX_PARAMS:,} budget"
        )
    @classmethod
    def from_config(cls, config: OgmaConfig) -> OgmaModel:
        """Factory method to build a model from config and verify the budget."""
        model = cls(config)
        model.assert_param_budget()
        return model
    @classmethod
    def from_checkpoint(
        cls,
        path: str,
        device: str = "cpu",
    ) -> OgmaModel:
        """Load model from a checkpoint directory.

        Args:
            path: Path to checkpoint directory containing config.yaml
                and model.pt.
            device: Device to load model to.

        Returns:
            Loaded OgmaModel.
        """
        from pathlib import Path

        import yaml

        ckpt_path = Path(path)
        with open(ckpt_path / "config.yaml") as f:
            config_dict = yaml.safe_load(f)
        config = OgmaConfig.from_dict(config_dict)
        model = cls(config)
        state_dict = torch.load(
            ckpt_path / "model.pt",
            map_location=device,
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        return model
    def save_checkpoint(self, path: str) -> None:
        """Save model checkpoint.

        Args:
            path: Directory to save config.yaml and model.pt.
        """
        from pathlib import Path

        import yaml

        ckpt_path = Path(path)
        ckpt_path.mkdir(parents=True, exist_ok=True)
        with open(ckpt_path / "config.yaml", "w") as f:
            yaml.dump(self.config.to_dict(), f, default_flow_style=False)
        torch.save(self.state_dict(), ckpt_path / "model.pt")
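For retrieval-style use, `encode` selects the task token explicitly, and because the outputs are unit-normalized, cosine similarity reduces to a dot product; Matryoshka sub-dimensions are obtained by truncating and re-normalizing. A minimal sketch follows, assuming an `ogma` package path and a local checkpoint directory (both placeholders), with random token IDs standing in for real tokenizer output and 256 standing in for one of the configured sub-dimensions.

import torch
import torch.nn.functional as F

# Assumed import paths for illustration; they mirror the relative imports in the module above.
from ogma.config import TaskToken
from ogma.model import OgmaModel

# Placeholder checkpoint directory containing config.yaml and model.pt.
model = OgmaModel.from_checkpoint("checkpoints/ogma", device="cpu")

# Dummy token IDs / masks standing in for real tokenizer output
# (assumed to fall inside the model's vocabulary).
query_ids = torch.randint(10, 1000, (2, 16))
query_mask = torch.ones_like(query_ids)
doc_ids = torch.randint(10, 1000, (4, 32))
doc_mask = torch.ones_like(doc_ids)

with torch.no_grad():
    # Task tokens per the forward() docstring: QRY for queries, DOC for documents.
    q = model.encode(query_ids, query_mask, task=TaskToken.QRY)  # (2, d_output)
    d = model.encode(doc_ids, doc_mask, task=TaskToken.DOC)      # (4, d_output)

# Unit-norm embeddings: dot product equals cosine similarity.
scores = q @ d.T  # (2, 4)

# Matryoshka: truncate to a configured sub-dimension and re-normalize.
q256 = F.normalize(q[:, :256], p=2, dim=-1)
d256 = F.normalize(d[:, :256], p=2, dim=-1)
scores256 = q256 @ d256.T

Truncating to a shorter prefix trades some retrieval quality for a proportionally smaller index and faster search, which is the point of the matryoshka configuration.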