monsoon-nlp commited on
Commit
e6ae131
·
1 Parent(s): 0d23b7d

test hosting of finetuned model

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "gLM2ForMaskedLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_glm2.gLM2Config",
7
+ "AutoModel": "modeling_glm2.gLM2Model",
8
+ "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM"
9
+ },
10
+ "depth": 30,
11
+ "dim": 640,
12
+ "dtype": "bfloat16",
13
+ "ffn_dim_multiplier": null,
14
+ "heads": 10,
15
+ "model_type": "gLM2",
16
+ "norm_eps": 1e-05,
17
+ "swiglu_multiple_of": 256,
18
+ "transformers_version": "4.56.1",
19
+ "vocab_size": 37
20
+ }
configuration_glm2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """gLM2 model configuration"""
2
+
3
+ from typing import Optional
4
+ from transformers import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class gLM2Config(PretrainedConfig):
11
+ model_type = "gLM2"
12
+
13
+ def __init__(
14
+ self,
15
+ dim: int = 640,
16
+ depth: int = 30,
17
+ heads: int = 10,
18
+ vocab_size: int = 37,
19
+ swiglu_multiple_of: int = 256,
20
+ ffn_dim_multiplier: Optional[float] = None,
21
+ norm_eps: float = 1e-5,
22
+ **kwargs
23
+ ):
24
+ super().__init__(**kwargs)
25
+ self.dim = dim
26
+ self.depth = depth
27
+ self.heads = heads
28
+ self.vocab_size = vocab_size
29
+ self.swiglu_multiple_of = swiglu_multiple_of
30
+ self.ffn_dim_multiplier = ffn_dim_multiplier
31
+ self.norm_eps = norm_eps
32
+
33
+ self.auto_map = {
34
+ "AutoConfig": "configuration_glm2.gLM2Config",
35
+ "AutoModel": "modeling_glm2.gLM2Model",
36
+ "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM"
37
+ }
glm_tokenizer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import Tokenizer
2
+ from tokenizers.models import BPE
3
+ from transformers import PreTrainedTokenizerFast
4
+
5
+
6
+ class gLM2Tokenizer(PreTrainedTokenizerFast):
7
+
8
+ VOCAB = [
9
+ "<cls>", "<pad>", "<eos>", "<unk>",
10
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
11
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
12
+ "O", "a", "t", "c", "g", "<+>", "<->", "<mask>", "<sep>",
13
+ ]
14
+
15
+ def __init__(
16
+ self,
17
+ unk_token="<unk>",
18
+ cls_token="<cls>",
19
+ pad_token="<pad>",
20
+ mask_token="<mask>",
21
+ eos_token="<eos>",
22
+ sep_token="<sep>",
23
+ pos_token="<+>",
24
+ neg_token="<->",
25
+ **kwargs,
26
+ ):
27
+ all_tokens = self.VOCAB
28
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
29
+
30
+ bpe = BPE(token_to_id, merges=[], unk_token=str(unk_token))
31
+ tokenizer = Tokenizer(bpe)
32
+ special_tokens = [cls_token, pad_token,
33
+ mask_token, eos_token, sep_token, pos_token, neg_token]
34
+
35
+ tokenizer.add_special_tokens(
36
+ special_tokens,
37
+ )
38
+
39
+ super().__init__(
40
+ tokenizer_object=tokenizer,
41
+ unk_token=unk_token,
42
+ cls_token=cls_token,
43
+ pad_token=pad_token,
44
+ mask_token=mask_token,
45
+ eos_token=eos_token,
46
+ sep_token=sep_token,
47
+ **kwargs,
48
+ )
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00886a87fc91c831a592917192d9a68b33cb91d92d005c45d5a815b5c90b2d45
3
+ size 304940024
modeling_glm2.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch gLM2 model.
2
+
3
+ Some modules adapted from:
4
+ https://github.com/meta-llama/llama/blob/main/llama/model.py
5
+ """
6
+
7
+ import torch
8
+ from einops import rearrange, repeat
9
+ from typing import Optional, Tuple, Union
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutput,
14
+ MaskedLMOutput,
15
+ )
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+ from .configuration_glm2 import gLM2Config
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
+ def rotate_half(x, interleaved=False):
24
+ if not interleaved:
25
+ x1, x2 = x.chunk(2, dim=-1)
26
+ return torch.cat((-x2, x1), dim=-1)
27
+ else:
28
+ x1, x2 = x[..., ::2], x[..., 1::2]
29
+ return rearrange(
30
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
31
+ )
32
+
33
+
34
+ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
35
+ """
36
+ x: (batch_size, seqlen, nheads, headdim)
37
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
38
+ """
39
+ ro_dim = cos.shape[-1] * 2
40
+ assert ro_dim <= x.shape[-1]
41
+ seqlen = x.shape[1]
42
+ cos, sin = cos[:seqlen], sin[:seqlen]
43
+ cos = repeat(
44
+ cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
45
+ )
46
+ sin = repeat(
47
+ sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
48
+ )
49
+ return torch.cat(
50
+ [
51
+ x[..., :ro_dim] * cos +
52
+ rotate_half(x[..., :ro_dim], interleaved) * sin,
53
+ x[..., ro_dim:],
54
+ ],
55
+ dim=-1,
56
+ )
57
+
58
+
59
+ class RotaryEmbedding(torch.nn.Module):
60
+ """
61
+ Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
62
+ Changed to use the torch version of apply_rotary_emb_func.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ dim: int,
68
+ base=10000.0,
69
+ interleaved=False,
70
+ scale_base=None,
71
+ pos_idx_in_fp32=True,
72
+ device=None,
73
+ ):
74
+ super().__init__()
75
+ self.dim = dim
76
+ self.base = float(base)
77
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
78
+ # Generate and save the inverse frequency buffer (non trainable)
79
+ inv_freq = self._compute_inv_freq(device)
80
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
81
+ self.interleaved = interleaved
82
+ self.scale_base = scale_base
83
+ scale = (
84
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
85
+ / (1.4 * dim)
86
+ if scale_base is not None
87
+ else None
88
+ )
89
+ self.register_buffer("scale", scale, persistent=False)
90
+
91
+ self._seq_len_cached = 0
92
+ self._cos_cached = None
93
+ self._sin_cached = None
94
+ self._cos_k_cached = None
95
+ self._sin_k_cached = None
96
+
97
+ def _compute_inv_freq(self, device=None):
98
+ return 1.0 / (
99
+ self.base
100
+ ** (
101
+ torch.arange(0, self.dim, 2, device=device,
102
+ dtype=torch.float32)
103
+ / self.dim
104
+ )
105
+ )
106
+
107
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
108
+ # Reset the tables if the sequence length has changed,
109
+ # if we're on a new device (possibly due to tracing for instance),
110
+ # or if we're switching from inference mode to training
111
+ if (
112
+ seqlen > self._seq_len_cached
113
+ or self._cos_cached is None
114
+ or self._cos_cached.device != device
115
+ or self._cos_cached.dtype != dtype
116
+ or (self.training and self._cos_cached.is_inference())
117
+ ):
118
+ self._seq_len_cached = seqlen
119
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
120
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
121
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
122
+ if self.pos_idx_in_fp32:
123
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
124
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
125
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
126
+ # cos & sin output to change significantly.
127
+ # We want to recompute self.inv_freq if it was not loaded in fp32
128
+ if self.inv_freq.dtype != torch.float32:
129
+ inv_freq = self._compute_inv_freq(device=device)
130
+ else:
131
+ inv_freq = self.inv_freq
132
+ else:
133
+ t = torch.arange(seqlen, device=device,
134
+ dtype=self.inv_freq.dtype)
135
+ inv_freq = self.inv_freq
136
+ # Don't do einsum, it converts fp32 to fp16 under AMP
137
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
138
+ freqs = torch.outer(t, inv_freq)
139
+ if self.scale is None:
140
+ self._cos_cached = torch.cos(freqs).to(dtype)
141
+ self._sin_cached = torch.sin(freqs).to(dtype)
142
+ else:
143
+ power = (
144
+ torch.arange(
145
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
146
+ )
147
+ - seqlen // 2
148
+ ) / self.scale_base
149
+ scale = self.scale.to(device=power.device) ** rearrange(
150
+ power, "s -> s 1"
151
+ )
152
+ # We want the multiplication by scale to happen in fp32
153
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
154
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
155
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
156
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
157
+
158
+ def forward(
159
+ self,
160
+ qkv: torch.Tensor,
161
+ max_seqlen: Optional[int] = None,
162
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
163
+ """
164
+ qkv: (batch, seqlen, 3, nheads, headdim)
165
+ """
166
+ seqlen = qkv.shape[1]
167
+ if seqlen > self._seq_len_cached:
168
+ self._update_cos_sin_cache(
169
+ seqlen, device=qkv.device, dtype=qkv.dtype)
170
+ elif max_seqlen is not None:
171
+ self._update_cos_sin_cache(
172
+ max_seqlen, device=qkv.device, dtype=qkv.dtype)
173
+ q_rot = apply_rotary_emb_torch(
174
+ qkv[:, :, 0], self._cos_cached, self._sin_cached, self.interleaved
175
+ )
176
+ k_rot = apply_rotary_emb_torch(
177
+ qkv[:, :, 1], self._cos_cached, self._sin_cached, self.interleaved
178
+ )
179
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
180
+
181
+
182
+ # @torch.jit.script
183
+ def rmsnorm_func(hidden_states, weight, variance_epsilon):
184
+ """Apply the root mean square normalization."""
185
+ input_dtype = hidden_states.dtype
186
+ hidden_states = hidden_states.to(torch.float32)
187
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
188
+ hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
189
+ return (weight * hidden_states).to(input_dtype)
190
+
191
+
192
+ class RMSNorm(nn.Module):
193
+ """Root mean square normalization."""
194
+
195
+ def __init__(self, dim, eps=1e-6):
196
+ super().__init__()
197
+ self.weight = nn.Parameter(torch.ones(dim))
198
+ self.register_buffer(
199
+ "variance_epsilon",
200
+ torch.tensor(eps),
201
+ persistent=False,
202
+ )
203
+
204
+ def forward(self, hidden_states):
205
+ return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
206
+
207
+
208
+ class Attention(nn.Module):
209
+ """Multi-head attention module."""
210
+
211
+ def __init__(self, config: gLM2Config):
212
+ super().__init__()
213
+ self.n_heads = config.heads
214
+ self.head_dim = config.dim // config.heads
215
+
216
+ self.wqkv = nn.Linear(config.dim, self.n_heads *
217
+ self.head_dim * 3, bias=False)
218
+ self.wo = nn.Linear(config.heads * self.head_dim,
219
+ config.dim, bias=False)
220
+
221
+ self.rotary_emb = RotaryEmbedding(self.head_dim)
222
+
223
+ def forward(
224
+ self,
225
+ x: torch.Tensor,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ ) -> torch.Tensor:
228
+ bsz, seqlen, h_size = x.shape
229
+ qkv = self.wqkv(x)
230
+
231
+ qkv = qkv.view(bsz, seqlen, 3, self.n_heads, self.head_dim)
232
+ qkv = self.rotary_emb(qkv)
233
+
234
+ # (batch, nheads, 3, seqlen, headdim)
235
+ qkv = torch.transpose(qkv, 3, 1)
236
+ q = qkv[:, :, 0]
237
+ k = qkv[:, :, 1]
238
+ v = qkv[:, :, 2]
239
+ if attention_mask is not None:
240
+ attention_mask = attention_mask[:, None, None, :]
241
+ attention_mask = attention_mask.expand(
242
+ bsz, self.n_heads, seqlen, seqlen
243
+ ).bool()
244
+ # [B, heads, seq, D]
245
+ output = torch.nn.functional.scaled_dot_product_attention(
246
+ q, k, v, attn_mask=attention_mask
247
+ )
248
+ output = output.permute(0, 2, 1, 3).contiguous()
249
+
250
+ output = output.view(bsz, seqlen, h_size)
251
+ return self.wo(output)
252
+
253
+
254
+ class FeedForward(nn.Module):
255
+ def __init__(
256
+ self,
257
+ dim: int,
258
+ hidden_dim: int,
259
+ multiple_of: int,
260
+ ffn_dim_multiplier: Optional[float],
261
+ ):
262
+ """
263
+ SwiGLU FeedForward module.
264
+
265
+ Args:
266
+ dim (int): Input dimension.
267
+ hidden_dim (int): Hidden dimension of the feedforward layer.
268
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
269
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
270
+ """
271
+ super().__init__()
272
+ hidden_dim = int(2 * hidden_dim / 3)
273
+ # custom dim factor multiplier
274
+ if ffn_dim_multiplier is not None:
275
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
276
+ hidden_dim = multiple_of * \
277
+ ((hidden_dim + multiple_of - 1) // multiple_of)
278
+
279
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
280
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
281
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
282
+
283
+ def forward(self, x):
284
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
285
+
286
+
287
+ class TransformerBlock(nn.Module):
288
+ def __init__(self, config: gLM2Config):
289
+ super().__init__()
290
+ self.n_heads = config.heads
291
+ self.dim = config.dim
292
+ self.head_dim = config.dim // config.heads
293
+ self.attention = Attention(config)
294
+ self.feed_forward = FeedForward(
295
+ dim=config.dim,
296
+ hidden_dim=4 * config.dim,
297
+ multiple_of=config.swiglu_multiple_of,
298
+ ffn_dim_multiplier=config.ffn_dim_multiplier,
299
+ )
300
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
301
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
302
+
303
+ def forward(
304
+ self,
305
+ x: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
307
+ ) -> torch.Tensor:
308
+ r = self.attention(self.attention_norm(
309
+ x), attention_mask=attention_mask)
310
+ h = x + r
311
+ r = self.feed_forward(self.ffn_norm(h))
312
+ out = h + r
313
+ return out
314
+
315
+
316
+ class TransformerLayers(nn.Module):
317
+ def __init__(self, config: gLM2Config):
318
+ super().__init__()
319
+ self.config = config
320
+ self.layers = torch.nn.ModuleList(
321
+ [TransformerBlock(config=config) for _ in range(config.depth)]
322
+ )
323
+
324
+ def forward(
325
+ self,
326
+ x: torch.FloatTensor,
327
+ attention_mask: Optional[torch.BoolTensor] = None,
328
+ return_all_hiddens: bool = False,
329
+ ):
330
+ if x.shape[-1] != self.config.dim:
331
+ raise ValueError(
332
+ f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
333
+ )
334
+ hiddens = []
335
+ for layer in self.layers:
336
+ x = layer(x, attention_mask=attention_mask)
337
+ if return_all_hiddens:
338
+ hiddens.append(x)
339
+
340
+ if return_all_hiddens:
341
+ return x, hiddens
342
+ return x
343
+
344
+
345
+ class gLM2PreTrainedModel(PreTrainedModel):
346
+ """
347
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
348
+ models.
349
+ """
350
+ config_class = gLM2Config
351
+ base_model_prefix = "glm2"
352
+ supports_gradient_checkpointing = False
353
+
354
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
355
+ def _init_weights(module, initializer_range=0.02):
356
+ if isinstance(module, nn.Linear):
357
+ nn.init.normal_(module.weight, std=initializer_range)
358
+ if module.bias is not None:
359
+ nn.init.zeros_(module.bias)
360
+ elif isinstance(module, nn.Embedding):
361
+ nn.init.normal_(module.weight, std=initializer_range)
362
+ if module.padding_idx is not None:
363
+ nn.init.zeros_(module.weight[module.padding_idx])
364
+
365
+
366
+ class gLM2Model(gLM2PreTrainedModel):
367
+ """gLM2 Model."""
368
+
369
+ def __init__(self, config: gLM2Config):
370
+ super().__init__(config)
371
+ self.config = config
372
+
373
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
374
+ self.encoder = TransformerLayers(config)
375
+ # Initialize weights and apply final processing
376
+ self.post_init()
377
+
378
+ def forward(
379
+ self,
380
+ input_ids: torch.Tensor,
381
+ attention_mask: Optional[torch.Tensor] = None,
382
+ output_hidden_states: Optional[bool] = None,
383
+ return_dict: Optional[bool] = None,
384
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
385
+ output_hidden_states = (
386
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
387
+ )
388
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
389
+
390
+ h = self.tok_embeddings(input_ids)
391
+ if output_hidden_states:
392
+ sequence_output, all_hidden_states = self.encoder(
393
+ h, attention_mask, return_all_hiddens=True)
394
+ else:
395
+ sequence_output = self.encoder(h, attention_mask)
396
+ all_hidden_states = None
397
+
398
+ if not return_dict:
399
+ return (sequence_output, all_hidden_states)
400
+
401
+ return BaseModelOutput(
402
+ last_hidden_state=sequence_output,
403
+ hidden_states=all_hidden_states,
404
+
405
+ )
406
+
407
+
408
+ class gLM2ForMaskedLM(gLM2PreTrainedModel):
409
+
410
+ def __init__(self, config: gLM2Config):
411
+ super().__init__(config)
412
+
413
+ self.glm2 = gLM2Model(config)
414
+ self.lm_head = gLM2LMHead(config)
415
+ self.init_weights()
416
+
417
+ def forward(
418
+ self,
419
+ input_ids: torch.Tensor,
420
+ attention_mask: Optional[torch.Tensor] = None,
421
+ labels: Optional[torch.LongTensor] = None,
422
+ output_hidden_states: Optional[bool] = None,
423
+ return_dict: Optional[bool] = None,
424
+ ) -> Union[Tuple, MaskedLMOutput]:
425
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
426
+
427
+ outputs = self.glm2(
428
+ input_ids,
429
+ attention_mask=attention_mask,
430
+ output_hidden_states=output_hidden_states,
431
+ return_dict=return_dict,
432
+ )
433
+ sequence_output = outputs[0]
434
+ prediction_scores = self.lm_head(sequence_output)
435
+
436
+ masked_lm_loss = None
437
+ if labels is not None:
438
+ loss_fct = CrossEntropyLoss()
439
+
440
+ labels = labels.to(prediction_scores.device)
441
+ masked_lm_loss = loss_fct(
442
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
443
+
444
+ if not return_dict:
445
+ output = (prediction_scores,) + outputs[2:]
446
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
447
+
448
+ return MaskedLMOutput(
449
+ loss=masked_lm_loss,
450
+ logits=prediction_scores,
451
+ hidden_states=outputs.hidden_states,
452
+ attentions=outputs.attentions,
453
+ )
454
+
455
+
456
+ class gLM2LMHead(nn.Module):
457
+ """gLM2 head for masked language modeling."""
458
+
459
+ def __init__(self, config):
460
+ super().__init__()
461
+
462
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
463
+ self.proj_output = nn.Linear(
464
+ config.dim, config.vocab_size, bias=False)
465
+
466
+ def forward(self, features):
467
+ return self.proj_output(self.norm(features))
special_tokens_map.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "<cls>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<eos>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "mask_token": {
17
+ "content": "<mask>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "pad_token": {
24
+ "content": "<pad>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "sep_token": {
31
+ "content": "<sep>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "unk_token": {
38
+ "content": "<unk>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ }
44
+ }
tokenizer.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<cls>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<pad>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<eos>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "<unk>",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 33,
44
+ "content": "<+>",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ },
51
+ {
52
+ "id": 34,
53
+ "content": "<->",
54
+ "single_word": false,
55
+ "lstrip": false,
56
+ "rstrip": false,
57
+ "normalized": false,
58
+ "special": true
59
+ },
60
+ {
61
+ "id": 35,
62
+ "content": "<mask>",
63
+ "single_word": false,
64
+ "lstrip": false,
65
+ "rstrip": false,
66
+ "normalized": false,
67
+ "special": true
68
+ },
69
+ {
70
+ "id": 36,
71
+ "content": "<sep>",
72
+ "single_word": false,
73
+ "lstrip": false,
74
+ "rstrip": false,
75
+ "normalized": false,
76
+ "special": true
77
+ }
78
+ ],
79
+ "normalizer": null,
80
+ "pre_tokenizer": null,
81
+ "post_processor": null,
82
+ "decoder": null,
83
+ "model": {
84
+ "type": "BPE",
85
+ "dropout": null,
86
+ "unk_token": "<unk>",
87
+ "continuing_subword_prefix": null,
88
+ "end_of_word_suffix": null,
89
+ "fuse_unk": false,
90
+ "byte_fallback": false,
91
+ "ignore_merges": false,
92
+ "vocab": {
93
+ "<cls>": 0,
94
+ "<pad>": 1,
95
+ "<eos>": 2,
96
+ "<unk>": 3,
97
+ "L": 4,
98
+ "A": 5,
99
+ "G": 6,
100
+ "V": 7,
101
+ "S": 8,
102
+ "E": 9,
103
+ "R": 10,
104
+ "T": 11,
105
+ "I": 12,
106
+ "D": 13,
107
+ "P": 14,
108
+ "K": 15,
109
+ "Q": 16,
110
+ "N": 17,
111
+ "F": 18,
112
+ "Y": 19,
113
+ "M": 20,
114
+ "H": 21,
115
+ "W": 22,
116
+ "C": 23,
117
+ "X": 24,
118
+ "B": 25,
119
+ "U": 26,
120
+ "Z": 27,
121
+ "O": 28,
122
+ "a": 29,
123
+ "t": 30,
124
+ "c": 31,
125
+ "g": 32,
126
+ "<+>": 33,
127
+ "<->": 34,
128
+ "<mask>": 35,
129
+ "<sep>": 36
130
+ },
131
+ "merges": []
132
+ }
133
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<cls>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "<eos>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "33": {
36
+ "content": "<+>",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "34": {
44
+ "content": "<->",
45
+ "lstrip": false,
46
+ "normalized": false,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "35": {
52
+ "content": "<mask>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": true
58
+ },
59
+ "36": {
60
+ "content": "<sep>",
61
+ "lstrip": false,
62
+ "normalized": false,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": true
66
+ }
67
+ },
68
+ "auto_map": {
69
+ "AutoTokenizer": [
70
+ "glm_tokenizer.gLM2Tokenizer",
71
+ null
72
+ ]
73
+ },
74
+ "clean_up_tokenization_spaces": true,
75
+ "cls_token": "<cls>",
76
+ "eos_token": "<eos>",
77
+ "mask_token": "<mask>",
78
+ "model_max_length": 1000000000000000019884624838656,
79
+ "pad_token": "<pad>",
80
+ "sep_token": "<sep>",
81
+ "tokenizer_class": "gLM2Tokenizer",
82
+ "unk_token": "<unk>"
83
+ }