| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| |
|
| |
|
class NdLinear(nn.Module):
    """Apply one ``nn.Linear`` per trailing tensor mode, one mode at a time."""

    def __init__(self, input_dims: tuple, hidden_size: tuple, transform_outer: bool = True, bias: bool = True):
        """
        NdLinear: A PyTorch layer for projecting tensors into multi-space representations.

        Unlike conventional embedding layers that map into a single vector space, NdLinear
        transforms tensors across a collection of vector spaces, capturing multivariate structure
        and topical information that standard deep learning architectures typically lose.

        Layer ``i`` maps a mode of size ``input_dims[i]`` to size ``hidden_size[i]``.

        Args:
            input_dims (tuple): Sizes of the trailing modes of the input tensor
                (excluding any leading batch dimensions).
            hidden_size (tuple): Target size of each mode after transformation; must have
                the same length as ``input_dims``.
            transform_outer (bool): If True, modes are transformed first-to-last
                (``input_dims[0]`` first); if False, last-to-first. When ``bias`` is
                enabled the order changes the result, because each bias term is
                propagated through the layers applied after it.
            bias (bool): Whether every per-mode ``nn.Linear`` has a bias term.

        Raises:
            ValueError: If ``input_dims`` and ``hidden_size`` have different lengths.
        """
        super().__init__()

        if len(input_dims) != len(hidden_size):
            raise ValueError("Input shape and hidden shape do not match.")

        self.input_dims = input_dims
        self.hidden_size = hidden_size
        self.num_layers = len(input_dims)

        self.transform_outer = transform_outer
        self.bias = bias

        # One independent linear map per mode.
        self.align_layers = nn.ModuleList([
            nn.Linear(in_dim, out_dim, bias=self.bias)
            for in_dim, out_dim in zip(input_dims, hidden_size)
        ])

    def forward(self, X):
        """
        Forward pass to project input tensor into a new multi-space representation.

        For each mode in turn, the mode is swapped onto the last axis, the matching
        linear layer is applied there, and the mode is swapped back into place.

        Expected Input Shape: [*batch_dims, *input_dims]. Any number of leading
        dimensions is accepted, including none (e.g. the usual [batch_size, *input_dims]).
        Output Shape: [*batch_dims, *hidden_size]

        Args:
            X (torch.Tensor): Input tensor whose trailing dimensions equal ``input_dims``.

        Returns:
            torch.Tensor: Contiguous output tensor of shape [*batch_dims, *hidden_size].
        """
        num_modes = self.num_layers

        for step in range(num_modes):
            # Which mode (index into input_dims / align_layers) is transformed at this step.
            mode = step if self.transform_outer else num_modes - 1 - step
            layer = self.align_layers[mode]

            # Axis of that mode counted from the end, so it is correct no matter how many
            # leading (batch) dimensions X carries.
            axis = mode - num_modes

            # nn.Linear acts on the last axis and accepts arbitrary leading dimensions,
            # so no manual flatten/reshape is needed around it.
            X = torch.transpose(X, axis, -1)
            X = layer(X)
            X = torch.transpose(X, axis, -1)

        # Hand back a contiguous tensor so callers can safely .view() the result.
        return X.contiguous()