| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutput |
| from .configuration_duchifat_v2 import DuchifatConfig |
|
|
class DuchifatBlock(nn.Module):
    """Pre-norm transformer decoder block: causal self-attention + GELU MLP.

    Both sub-layers are residual: x = x + attn(ln1(x)); x = x + mlp(ln2(x)).
    """

    def __init__(self, config):
        super().__init__()
        # Attention sub-layer: fused QKV projection and an output projection.
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.wo = nn.Linear(config.hidden_size, config.hidden_size)
        # Feed-forward sub-layer with the conventional 4x hidden expansion.
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, 4 * config.hidden_size),
            nn.GELU(approximate='tanh'),
            nn.Linear(4 * config.hidden_size, config.hidden_size),
        )
        self.n_head = config.nhead
        self.head_dim = config.hidden_size // config.nhead

    def forward(self, x):
        """Apply both residual branches to x of shape (batch, seq, hidden)."""
        batch, seq_len, width = x.shape
        normed = self.ln1(x)

        # Split the fused projection into Q/K/V, then lay each out per head:
        # (batch, n_head, seq, head_dim), as expected by SDPA.
        q, k, v = self.qkv(normed).chunk(3, dim=-1)
        head_shape = (batch, seq_len, self.n_head, self.head_dim)
        q = q.view(head_shape).transpose(1, 2)
        k = k.view(head_shape).transpose(1, 2)
        v = v.view(head_shape).transpose(1, 2)

        # Fused causal attention (uses flash / mem-efficient kernels when available).
        ctx = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        ctx = ctx.transpose(1, 2).contiguous().view(batch, seq_len, width)

        x = x + self.wo(ctx)
        x = x + self.mlp(self.ln2(x))
        return x
|
|
class DuchifatPreTrainedModel(PreTrainedModel):
    """Base class wiring the Duchifat architecture into the HF `PreTrainedModel` API.

    Subclassing `PreTrainedModel` provides `from_pretrained` / `save_pretrained`
    and standard weight initialization via `post_init`.
    """

    # Config type that `from_pretrained` instantiates for this architecture.
    config_class = DuchifatConfig
    base_model_prefix = "model"
    # Keep each transformer block on a single device when device-map sharding.
    _no_split_modules = ["DuchifatBlock"]
|
|
class DuchifatCore(DuchifatPreTrainedModel):
    """GPT-style causal language model.

    Learned token embeddings plus learned absolute position embeddings, a stack
    of `DuchifatBlock` layers, a final LayerNorm, and an (untied) linear LM head.
    """

    def __init__(self, config):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_seq, config.hidden_size)
        self.blocks = nn.ModuleList([DuchifatBlock(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Run the HF weight-initialization / final-setup hooks.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table (HF resize/embedding API)."""
        return self.wte

    def set_input_embeddings(self, value):
        """Replace the token embedding table (HF resize/embedding API)."""
        self.wte = value

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the model and optionally compute the causal-LM loss.

        Args:
            input_ids: LongTensor of shape (batch, seq) of token ids. Required.
            attention_mask: accepted for HF API compatibility but currently unused
                (see NOTE below).
            labels: optional LongTensor of shape (batch, seq); when given, the
                standard next-token cross-entropy loss is computed.

        Returns:
            CausalLMOutput with `logits` of shape (batch, seq, vocab_size) and
            `loss` (or None).

        Raises:
            ValueError: if `input_ids` is missing or longer than the maximum
                sequence length supported by the position embeddings.
        """
        if input_ids is None:
            raise ValueError("You must specify input_ids")

        B, T = input_ids.size()
        # Fail fast with a clear message instead of an opaque embedding index
        # error (or a CUDA device-side assert) from `wpe` below.
        if T > self.wpe.num_embeddings:
            raise ValueError(
                f"Sequence length {T} exceeds the maximum supported length "
                f"{self.wpe.num_embeddings}"
            )
        device = input_ids.device

        # NOTE(review): attention_mask is accepted but never applied, so padded
        # positions are attended to. Confirm callers only pass unpadded batches;
        # a real fix requires threading the mask through DuchifatBlock.
        pos = torch.arange(0, T, dtype=torch.long, device=device)
        x = self.wte(input_ids) + self.wpe(pos)

        for block in self.blocks:
            x = block(x)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """Generation hook: no KV cache, so the full sequence is re-run each step."""
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # No KV cache is maintained, so there is nothing to reorder for beam search.
        return past_key_values