| | """ |
| | Implementation of "Attention is All You Need" and of |
| | subsequent transformer based architectures |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from onmt.decoders.decoder import DecoderBase |
| | from onmt.modules import MultiHeadedAttention, AverageAttention |
| | from onmt.modules.position_ffn import PositionwiseFeedForward |
| | from onmt.modules.position_ffn import ActivationFunction |
| | from onmt.utils.misc import sequence_mask |
| |
|
| |
|
class TransformerDecoderLayerBase(nn.Module):
    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            d_model (int): the dimension of keys/values/queries in
                :class:`MultiHeadedAttention`, also the input size of
                the first layer of the :class:`PositionwiseFeedForward`.
            heads (int): the number of heads for MultiHeadedAttention.
            d_ff (int): the hidden size of the second layer of the
                :class:`PositionwiseFeedForward`.
            dropout (float): dropout in residual, self-attn(dot) and
                feed-forward.
            attention_dropout (float): dropout in context_attn (and
                self-attn(avg)).
            self_attn_type (string): type of self-attention,
                "scaled-dot" or "average".
            max_relative_positions (int):
                maximum distance between inputs in relative position
                representations.
            aan_useffn (bool): turn on the FFN layer in the AAN decoder.
            full_context_alignment (bool):
                whether to enable an extra full-context decoder forward
                pass for alignment.
            alignment_heads (int):
                number of cross-attention heads to use for alignment
                guiding.
            pos_ffn_activation_fn (ActivationFunction):
                activation function choice for the
                PositionwiseFeedForward layer.
        """
        super(TransformerDecoderLayerBase, self).__init__()

        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads,
                d_model,
                dropout=attention_dropout,
                max_relative_positions=max_relative_positions,
            )
        elif self_attn_type == "average":
            self.self_attn = AverageAttention(
                d_model, dropout=attention_dropout, aan_useffn=aan_useffn
            )

        self.feed_forward = PositionwiseFeedForward(
            d_model, d_ff, dropout, pos_ffn_activation_fn
        )
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
        """Extend `_forward` for (possibly) multiple decoder passes:
        always a default (future-masked) decoder forward pass, and
        possibly a second, future-aware decoder pass for jointly
        learning full-context alignment, :cite:`garg2019jointly`.

        Args:
            * All arguments of _forward.
            with_align (bool): whether to return alignment attention.

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

            * output ``(batch_size, T, model_dim)``
            * top_attn ``(batch_size, T, src_len)``
            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop("with_align", False)
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        if with_align:
            if self.full_context_alignment:
                # Re-run the layer without the future mask; attns is
                # ``(batch_size, heads, T, src_len)``.
                _, attns = self._forward(*args, **kwargs, future=True)

            if self.alignment_heads > 0:
                attns = attns[:, : self.alignment_heads, :, :].contiguous()
            # Average the (possibly restricted) heads to get
            # ``(batch_size, T, src_len)``.
            attn_align = attns.mean(dim=1)
        return output, top_attn, attn_align
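
    # A minimal usage sketch (an assumption for illustration, not part
    # of the original file): with ``full_context_alignment=True``,
    # passing ``with_align=True`` triggers the second, future-aware
    # pass and yields the head-averaged cross-attention for alignment
    # guiding:
    #
    #     out, top_attn, attn_align = layer(
    #         x, memory_bank, src_pad_mask, tgt_pad_mask, with_align=True
    #     )
    #     # attn_align: ``(batch_size, T, src_len)``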

    def update_dropout(self, dropout, attention_dropout):
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout

    def _forward(self, *args, **kwargs):
        raise NotImplementedError

    def _compute_dec_mask(self, tgt_pad_mask, future):
        tgt_len = tgt_pad_mask.size(-1)
        if not future:
            # Mask out future positions with an upper-triangular matrix.
            future_mask = torch.ones(
                [tgt_len, tgt_len],
                device=tgt_pad_mask.device,
                dtype=torch.uint8,
            )
            future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
            # BoolTensor only exists in torch >= 1.2; fall back to uint8.
            try:
                future_mask = future_mask.bool()
            except AttributeError:
                pass
            dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
        else:
            dec_mask = tgt_pad_mask
        return dec_mask
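
    # A minimal sketch (an assumption for illustration, not part of the
    # original file) of the mask this builds for T=4 when the last
    # target token is padding: the causal future mask and the padding
    # mask are OR-ed, so position i attends only to non-pad positions
    # <= i.
    #
    #     tgt_pad_mask = torch.tensor([[[False, False, False, True]]])
    #     future_mask = torch.ones(4, 4, dtype=torch.uint8).triu_(1)
    #     dec_mask = torch.gt(tgt_pad_mask + future_mask.view(1, 4, 4), 0)
    #     # dec_mask[0, 1] -> tensor([False, False,  True,  True])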

    def _forward_self_attn(self, inputs_norm, dec_mask, layer_cache, step):
        if isinstance(self.self_attn, MultiHeadedAttention):
            return self.self_attn(
                inputs_norm,
                inputs_norm,
                inputs_norm,
                mask=dec_mask,
                layer_cache=layer_cache,
                attn_type="self",
            )
        elif isinstance(self.self_attn, AverageAttention):
            return self.self_attn(
                inputs_norm, mask=dec_mask, layer_cache=layer_cache, step=step
            )
        else:
            raise ValueError(
                f"self attention {type(self.self_attn)} not supported"
            )


class TransformerDecoderLayer(TransformerDecoderLayerBase):
    """Transformer decoder layer block in Pre-Norm style.

    Pre-Norm style is an improvement over the original paper's Post-Norm
    style, providing better convergence speed and performance. This is
    also the actual implementation in tensor2tensor and is also
    available in fairseq. See https://tunz.kr/post/4 and
    :cite:`DeeperTransformer`.

    .. mermaid::

        graph LR
        %% "*SubLayer" can be self-attn, src-attn or feed forward block
        A(input) --> B[Norm]
        B --> C["*SubLayer"]
        C --> D[Drop]
        D --> E((+))
        A --> E
        E --> F(out)

    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        self_attn_type="scaled-dot",
        max_relative_positions=0,
        aan_useffn=False,
        full_context_alignment=False,
        alignment_heads=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        """
        Args:
            See TransformerDecoderLayerBase.
        """
        super(TransformerDecoderLayer, self).__init__(
            d_model,
            heads,
            d_ff,
            dropout,
            attention_dropout,
            self_attn_type,
            max_relative_positions,
            aan_useffn,
            full_context_alignment,
            alignment_heads,
            pos_ffn_activation_fn=pos_ffn_activation_fn,
        )
        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout
        )
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)

    def update_dropout(self, dropout, attention_dropout):
        super(TransformerDecoderLayer, self).update_dropout(
            dropout, attention_dropout
        )
        self.context_attn.update_dropout(attention_dropout)

    def _forward(
        self,
        inputs,
        memory_bank,
        src_pad_mask,
        tgt_pad_mask,
        layer_cache=None,
        step=None,
        future=False,
    ):
        """A naive forward pass for transformer decoder.

        # T: could be 1 in the case of stepwise decoding or tgt_len

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (bool): ``(batch_size, 1, src_len)``
            tgt_pad_mask (bool): ``(batch_size, 1, T)``
            layer_cache (dict or None): cached layer info when stepwise
                decoding
            step (int or None): stepwise decoding counter
            future (bool): if set to True, do not apply the future mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, heads, T, src_len)``

        """
        dec_mask = None

        if inputs.size(1) > 1:
            # Masking is only needed when attending over more than one
            # position; cached stepwise decoding feeds a single query.
            dec_mask = self._compute_dec_mask(tgt_pad_mask, future)

        inputs_norm = self.layer_norm_1(inputs)

        query, _ = self._forward_self_attn(
            inputs_norm, dec_mask, layer_cache, step
        )

        # Residual connection around the self-attention block.
        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid, attns = self.context_attn(
            memory_bank,
            memory_bank,
            query_norm,
            mask=src_pad_mask,
            layer_cache=layer_cache,
            attn_type="context",
        )
        # Residual connection around the context attention, then FFN.
        output = self.feed_forward(self.drop(mid) + query)

        return output, attns
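
    # A minimal usage sketch (an assumption for illustration, not part
    # of the original file): one pre-norm decoder layer over random
    # tensors. All hyperparameters below are arbitrary.
    #
    #     layer = TransformerDecoderLayer(
    #         d_model=16, heads=4, d_ff=32, dropout=0.1,
    #         attention_dropout=0.1,
    #     )
    #     x = torch.rand(2, 5, 16)            # (batch, T, d_model)
    #     memory_bank = torch.rand(2, 7, 16)  # (batch, src_len, d_model)
    #     src_pad_mask = torch.zeros(2, 1, 7, dtype=torch.bool)
    #     tgt_pad_mask = torch.zeros(2, 1, 5, dtype=torch.bool)
    #     out, top_attn, attn_align = layer(
    #         x, memory_bank, src_pad_mask, tgt_pad_mask
    #     )
    #     # out: (2, 5, 16), top_attn: (2, 5, 7), attn_align: None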


class TransformerDecoderBase(DecoderBase):
    def __init__(self, d_model, copy_attn, alignment_layer):
        super(TransformerDecoderBase, self).__init__()

        # Decoder state: source and the stepwise attention cache.
        self.state = {}

        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        # NOTE: `embeddings` is kept for interface compatibility but is
        # not forwarded, since this decoder variant consumes
        # pre-computed target embeddings (see `TransformerDecoder.
        # forward`'s `tgt_emb` argument); passing it positionally would
        # shift every following argument.
        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            opt.attention_dropout[0]
            if type(opt.attention_dropout) is list
            else opt.attention_dropout,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
        )
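
    # A minimal usage sketch (an assumption for illustration, not part
    # of the original file): `from_opt` reads its hyperparameters off a
    # parsed options namespace. A hand-rolled `opt` exposing only the
    # fields referenced above might look like:
    #
    #     from argparse import Namespace
    #     opt = Namespace(
    #         dec_layers=6, dec_rnn_size=512, heads=8, transformer_ff=2048,
    #         copy_attn=False, self_attn_type="scaled-dot", dropout=[0.1],
    #         attention_dropout=[0.1], max_relative_positions=0,
    #         aan_useffn=False, full_context_alignment=False,
    #         alignment_layer=-3, alignment_heads=0,
    #         pos_ffn_activation_fn=ActivationFunction.relu,
    #     )
    #     decoder = TransformerDecoder.from_opt(opt, embeddings=None)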

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state."""
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.state["src"] is not None:
            self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])

    def detach_state(self):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def update_dropout(self, dropout, attention_dropout):
        # This decoder variant does not own an embeddings module (it
        # consumes pre-computed target embeddings), so only update one
        # when a subclass provides it.
        if hasattr(self, "embeddings"):
            self.embeddings.update_dropout(dropout)
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)


class TransformerDecoder(TransformerDecoderBase):
    """The Transformer decoder from "Attention is All You Need".
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

        graph BT
        A[input]
        B[multi-head self-attn]
        BB[multi-head src-attn]
        C[feed forward]
        O[output]
        A --> B
        B --> BB
        BB --> C
        C --> O

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        copy_attn (bool): if using a separate copy attention
        self_attn_type (str): type of self-attention, "scaled-dot" or
            "average"
        dropout (float): dropout in residual, self-attn(dot) and
            feed-forward
        attention_dropout (float): dropout in context_attn (and
            self-attn(avg))
        max_relative_positions (int):
            maximum distance between inputs in relative position
            representations
        aan_useffn (bool): turn on the FFN layer in the AAN decoder
        full_context_alignment (bool):
            whether to enable an extra full-context decoder forward pass
            for alignment
        alignment_layer (int): layer index to supervise for alignment
            guiding
        alignment_heads (int):
            number of cross-attention heads to use for alignment guiding
    """

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        copy_attn,
        self_attn_type,
        dropout,
        attention_dropout,
        max_relative_positions,
        aan_useffn,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        super(TransformerDecoder, self).__init__(
            d_model, copy_attn, alignment_layer
        )

        self.transformer_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    self_attn_type=self_attn_type,
                    max_relative_positions=max_relative_positions,
                    aan_useffn=aan_useffn,
                    full_context_alignment=full_context_alignment,
                    alignment_heads=alignment_heads,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                )
                for i in range(num_layers)
            ]
        )

    def detach_state(self):
        self.state["src"] = self.state["src"].detach()

    def forward(
        self,
        tgt_emb,
        memory_bank,
        src_pad_mask=None,
        tgt_pad_mask=None,
        step=None,
        **kwargs
    ):
        """Decode, possibly stepwise."""
        if step == 0:
            self._init_cache(memory_bank)

        batch_size, src_len, src_dim = memory_bank.size()
        device = memory_bank.device
        # Default to all-False (no padding) masks when none are given.
        if src_pad_mask is None:
            src_pad_mask = torch.zeros(
                (batch_size, 1, src_len), dtype=torch.bool, device=device
            )
        output = tgt_emb
        batch_size, tgt_len, tgt_dim = tgt_emb.size()
        if tgt_pad_mask is None:
            tgt_pad_mask = torch.zeros(
                (batch_size, 1, tgt_len), dtype=torch.bool, device=device
            )

        future = kwargs.pop("future", False)
        with_align = kwargs.pop("with_align", False)
        attn_aligns = []
        hiddens = []

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = (
                self.state["cache"]["layer_{}".format(i)]
                if step is not None
                else None
            )
            output, attn, attn_align = layer(
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
                future=future,
            )
            hiddens.append(output)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        output = self.layer_norm(output)

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]

        return output, attns, hiddens

    def _init_cache(self, memory_bank):
        self.state["cache"] = {}
        batch_size = memory_bank.size(0)
        depth = memory_bank.size(-1)
        for i, layer in enumerate(self.transformer_layers):
            layer_cache = {
                "memory_keys": None,
                "memory_values": None,
                "self_keys": None,
                "self_values": None,
            }
            if isinstance(layer.self_attn, AverageAttention):
                # The AAN keeps a running average of previous inputs.
                layer_cache["prev_g"] = torch.zeros(
                    (batch_size, 1, depth), device=memory_bank.device
                )
            self.state["cache"]["layer_{}".format(i)] = layer_cache