Instructions for using togethercomputer/m2-bert-80M-2k with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
- Transformers
How to use togethercomputer/m2-bert-80M-2k with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="togethercomputer/m2-bert-80M-2k", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("togethercomputer/m2-bert-80M-2k", trust_remote_code=True, dtype="auto")
```

A runnable end-to-end sketch follows the notebook links below.

- Notebooks
- Google Colab
- Kaggle
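Building on the snippets above, the sketch below loads the tokenizer and masked-LM head directly and fills a masked token. It is a minimal sketch, assuming the checkpoint ships a BERT-style tokenizer and that its custom modeling code (loaded via `trust_remote_code=True`) returns standard masked-LM logits; it is not an official example from the model card.

```python
# Minimal sketch (assumptions: the repo bundles a BERT-style tokenizer and
# the custom M2 modeling code returns standard masked-LM logits).
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "togethercomputer/m2-bert-80M-2k"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Decode the highest-scoring token at the masked position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))
```

The repository's custom sequence-mixing layer, adapted from the Hyena implementation in HazyResearch/safari, is reproduced below.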
```python
# Copyright (c) 2023, Dan Fu and Simran Arora.
# Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py

import torch.nn as nn
from einops import rearrange
import opt_einsum as oe

contract = oe.contract

from .hyena_utils import HyenaFilter


class MonarchMixerSequenceMixing(nn.Module):
    def __init__(
        self,
        d_model,
        l_max=128,
        dropout=0.0,
        hyena_kernel_lr=None,
        bidirectional=False,
        hyena_lr_pos_emb=1e-5,
        hyena_w=10,
        hyena_w_mod=1,
        hyena_wd=0.1,
        hyena_emb_dim=3,
        hyena_filter_dropout=0.0,
        hyena_filter_order=16,
        residual_long_conv=False,
        hyena_training_additions=False,
    ):
        super().__init__()

        self.d_model = d_model
        self.l_max = l_max
        self.kernel_lr = hyena_kernel_lr
        self.channels = 1
        self.bidirectional = bidirectional
        self.residual_long_conv = residual_long_conv
        self.NUM_PROJECTIONS = 3

        print('-- Bidirectional:', self.bidirectional)
        print("-- Using Long Conv Residual:", self.residual_long_conv)
        print('-- Hyena w:', hyena_w)
        print('-- Hyena w mod:', hyena_w_mod)
        print(f"-- Hyena filter order: {hyena_filter_order}")
        print(f"-- Hyena filter dropout: {hyena_filter_dropout}")
        print(f"-- Hyena filter wd: {hyena_wd}")
        print(f"-- Hyena filter emb dim: {hyena_emb_dim}")
        print(f"-- Hyena filter lr: {hyena_kernel_lr}")
        print(f"-- Hyena filter lr pos emb: {hyena_lr_pos_emb}")

        self.filter_fn = HyenaFilter(
            self.d_model,
            order=hyena_filter_order,
            seq_len=self.l_max,
            dropout=hyena_filter_dropout,
            bidirectional=self.bidirectional,
            lr=hyena_kernel_lr,
            lr_pos_emb=hyena_lr_pos_emb,
            w=hyena_w,  # frequency of periodic activations
            w_mod=hyena_w_mod,
            wd=hyena_wd,  # weight decay of kernel parameters
            emb_dim=hyena_emb_dim,
        )
        if self.residual_long_conv:
            self.filter_fn2 = HyenaFilter(
                self.d_model,
                order=hyena_filter_order,
                seq_len=self.l_max,
                dropout=hyena_filter_dropout,
                bidirectional=self.bidirectional,
                lr=hyena_kernel_lr,
                lr_pos_emb=hyena_lr_pos_emb,
                w=hyena_w,  # frequency of periodic activations
                w_mod=hyena_w_mod,
                wd=hyena_wd,  # weight decay of kernel parameters
                emb_dim=hyena_emb_dim,
            )

        # setup projections
        self.in_linear = nn.Linear(d_model, 3 * d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.hyena_training_additions = hyena_training_additions
        if self.hyena_training_additions:
            self.act = nn.Identity()
            self.drop = nn.Dropout(dropout)
            self.layernorm = nn.LayerNorm(d_model)

        # setup short conv
        total_width = self.d_model * self.NUM_PROJECTIONS
        self.short_filter = nn.Conv1d(
            in_channels=total_width,
            out_channels=total_width,
            kernel_size=3,
            groups=total_width,
            padding=2,
        )

    def forward(self, u, **kwargs):
        # u is B L H
        if self.hyena_training_additions:
            u = self.layernorm(u)
        L = u.size(-2)

        # in projection
        u_orig = u
        u = self.in_linear(u)
        u = rearrange(u, "b l d -> b d l")

        # short filter
        uc = self.short_filter(u)[..., :L]

        x1, x2, v = uc.split(self.d_model, dim=1)

        v = v * x1
        if self.hyena_training_additions:
            v = self.drop(v)

        k = self.filter_fn.filter(L, device=u.device)
        k = rearrange(k, "c l d -> c d l")[0]  # `c` is always 1 by default

        if self.bidirectional:
            k_rev = self.filter_fn.filter_rev(L, device=u.device)
            k_rev = rearrange(k_rev, "c l d -> c d l")[0]  # `c` is always 1 by default
        else:
            k_rev = None

        y = self.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias=self.filter_fn.bias[None, :, None])

        if self.residual_long_conv:
            k2 = self.filter_fn2.filter(L, device=u.device)
            k2 = rearrange(k2, "c l d -> c d l")[0]
            if self.bidirectional:
                k2_rev = self.filter_fn2.filter_rev(L, device=u.device)
                k2_rev = rearrange(k2_rev, "c l d -> c d l")[0]  # `c` is always 1 by default
            else:
                k2_rev = None
            yu = self.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias=self.filter_fn2.bias[None, :, None])

        # post gating
        y = y * x2

        if self.residual_long_conv:
            y = y + yu

        y = y.transpose(-1, -2)
        if self.hyena_training_additions:
            y = self.drop(self.act(y))
        y = self.out_linear(y)

        return y, None
```
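To exercise the mixing layer on its own, the hypothetical sketch below instantiates it and runs a forward pass on random input. The import path is a placeholder, the file must be importable as part of the repository package so that `from .hyena_utils import HyenaFilter` resolves, and the dimensions are illustrative rather than the checkpoint's actual configuration.

```python
# Hypothetical standalone sketch; `m2_code.sequence_mixer` is a made-up
# package path, and hyena_utils.py from the same repo must sit next to it.
import torch
from m2_code.sequence_mixer import MonarchMixerSequenceMixing

mixer = MonarchMixerSequenceMixing(
    d_model=768,   # illustrative hidden size, not necessarily the checkpoint's
    l_max=2048,    # e.g. the 2k maximum sequence length in the model name
    bidirectional=True,
)
mixer.eval()

x = torch.randn(2, 2048, 768)   # (batch, seq_len, d_model) -- the "B L H" layout
with torch.no_grad():
    y, _ = mixer(x)             # second return value is always None
print(y.shape)                  # expected: torch.Size([2, 2048, 768])
```

With `bidirectional=True`, the layer builds both a forward and a reversed Hyena filter, so each output position can depend on the whole sequence, matching the non-causal setting used for BERT-style masked-language modeling.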