achyutarajaram committed
Commit 179cd55 · verified · 1 Parent(s): e7b93d7

Upload folder using huggingface_hub
LICENSE.MD ADDED
@@ -0,0 +1,54 @@
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.

“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”

“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License.
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

You must give any other recipients of the Work or Derivative Works a copy of this License; and
You must cause any modified files to carry prominent notices stating that You changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions.
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks.
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty.
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability.
While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS
NOTICE.MD ADDED
@@ -0,0 +1,13 @@
Copyright 2025 OpenAI

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.MD ADDED
@@ -0,0 +1,46 @@
## Sparse Model from Gao et al. 2025

Weights for a sparse model from Gao et al. 2025, used for the qualitative results in the paper (bracket counting and variable binding). Weights for the other models used in the paper, along with lightweight inference code, are available at https://github.com/openai/circuit_sparsity. In that repo, this model is `csp_yolo2`.

This is a runnable, standalone Hugging Face implementation of one of the models.

The short script below loads the locally converted HF model and tokenizer and runs a small generation:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

if __name__ == "__main__":
    PROMPT = "def square_sum(xs):\n    return sum(x * x for x in xs)\n\nsquare_sum([1, 2, 3])\n"
    tok = AutoTokenizer.from_pretrained("circuit-sparsity", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        "circuit-sparsity",
        trust_remote_code=True,
        torch_dtype="auto",
    )
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tok(PROMPT, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
        model.device
    )

    with torch.no_grad():
        out = model.generate(
            inputs,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            return_dict_in_generate=False,
        )

    print("=== Prompt ===")
    print(PROMPT)
    print("\n=== Generation ===")
    print(tok.decode(out[0], skip_special_tokens=True))
```

## License
This project is licensed under the [Apache License 2.0](LICENSE).
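For qualitative inspection it can also be useful to look at the model's next-token distribution directly rather than sampling. The sketch below is illustrative only (not part of the committed README); it assumes the custom `CircuitGPTForCausalLM` wrapper returns a standard causal-LM output with a `.logits` field, as its registration under `AutoModelForCausalLM` suggests, and that the checkpoint lives in the same local `circuit-sparsity` folder used above.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes the converted model/tokenizer are in the local "circuit-sparsity" folder.
tok = AutoTokenizer.from_pretrained("circuit-sparsity", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("circuit-sparsity", trust_remote_code=True)
model.eval()

prompt = "def square_sum(xs):\n"
ids = tok(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"]

with torch.no_grad():
    # Assumption: the forward pass returns a standard HF output with .logits
    # of shape (batch, seq_len, vocab_size).
    logits = model(input_ids=ids).logits

# Log-probabilities for the token that would follow the prompt.
next_logprobs = F.log_softmax(logits[0, -1], dim=-1)
top = next_logprobs.topk(5)
for lp, idx in zip(top.values.tolist(), top.indices.tolist()):
    print(f"{tok.decode([idx])!r}: {lp:.3f}")
```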
__init__.py ADDED
@@ -0,0 +1,11 @@
from .config import CircuitGPTConfig
from .modeling_circuitgpt import CircuitGPTForCausalLM
from .tokenizer_simple2k import Simple2KTokenizerFast, export_simple2k_tokenizer

__all__ = [
    "CircuitGPTConfig",
    "CircuitGPTForCausalLM",
    "Simple2KTokenizerFast",
    "export_simple2k_tokenizer",
]
config.json ADDED
@@ -0,0 +1,44 @@
{
  "activation_type": "gelu",
  "afrac": 0.25,
  "afrac_loctypes": "attn_in,attn_out,mlp_in,mlp_out,mlp_neuron,attn_v,attn_k,attn_q",
  "architectures": [
    "CircuitGPTForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "config.CircuitGPTConfig",
    "AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM"
  },
  "bias": true,
  "bigram_table_rank": null,
  "block_size": 1024,
  "bos_token_id": null,
  "d_head": 16,
  "d_mlp": 8192,
  "d_model": 2048,
  "d_pos_emb": 32,
  "dropout": 0.0,
  "dropout_cat_pos_emb": false,
  "enable_bigram_table": true,
  "eos_token_id": 2047,
  "flash": true,
  "is_decoder": true,
  "learnable_bigram_table": true,
  "ln_bias": true,
  "max_position_embeddings": 1024,
  "model_type": "circuitgpt",
  "n_head": 128,
  "n_layer": 8,
  "pad_token_id": null,
  "residual_activation_type": "identity",
  "rms_norm": true,
  "sink": true,
  "sinusoidal_cat_pos_emb": false,
  "tie_word_embeddings": false,
  "tied_unembed": false,
  "torch_dtype": "float32",
  "transformers_version": "4.49.0",
  "unembed_rank": null,
  "use_position_embeddings": true,
  "vocab_size": 2048
}
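The `afrac` and `afrac_loctypes` fields drive the activation sparsity applied in gpt.py below: at each listed location, only the top `afrac` fraction of entries by absolute value is kept per position and the rest are zeroed. A small standalone sketch of that abs-top-k operation, mirroring `maybe_activation_sparsity` in gpt.py (shown here purely for illustration):

```python
import torch

def keep_abstopk(x: torch.Tensor, afrac: float) -> torch.Tensor:
    """Keep the top `afrac` fraction of entries (by |value|) along the last dim, zero the rest."""
    k = int(afrac * x.shape[-1])
    out = torch.zeros_like(x)
    _, idx = torch.topk(x.abs(), k, dim=-1, sorted=False)
    out.scatter_(-1, idx, x.gather(-1, idx))
    return out

x = torch.randn(2, 4, 2048)           # (batch, seq, d_model)
sparse = keep_abstopk(x, afrac=0.25)  # 25% of entries survive, as in this config
print((sparse != 0).float().mean())   # ~0.25
```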
config.py ADDED
@@ -0,0 +1,222 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from transformers import PretrainedConfig
6
+
7
+
8
+ class CircuitGPTConfig(PretrainedConfig):
9
+ """
10
+ Minimal Hugging Face config wrapper around the circuit_sparsity GPTConfig.
11
+ Only the fields exercised by the Neuronpedia runs are exposed.
12
+ """
13
+
14
+ model_type = "circuitgpt"
15
+
16
+ def __init__(
17
+ self,
18
+ vocab_size: int = 2048,
19
+ block_size: int = 256,
20
+ n_layer: int = 8,
21
+ n_head: int = 8,
22
+ d_model: int = 1024,
23
+ d_mlp: int | None = None,
24
+ d_head: int | None = None,
25
+ dropout: float = 0.0,
26
+ bias: bool = True,
27
+ ln_bias: bool = True,
28
+ rms_norm: bool = True,
29
+ activation_type: str = "gelu",
30
+ residual_activation_type: str = "identity",
31
+ tied_unembed: bool = False,
32
+ unembed_rank: int | None = None,
33
+ afrac: float | None = None,
34
+ afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out",
35
+ flash: bool = True,
36
+ use_position_embeddings: bool = False,
37
+ sink: bool = False,
38
+ enable_bigram_table: bool = False,
39
+ learnable_bigram_table: bool = False,
40
+ bigram_table_rank: int | None = None,
41
+ dropout_cat_pos_emb: bool = False,
42
+ sinusoidal_cat_pos_emb: bool = False,
43
+ d_pos_emb: int | None = None,
44
+ auto_map: dict[str, str] | None = None,
45
+ **kwargs: Any,
46
+ ) -> None:
47
+ # Drop unsupported/sensitive keys that may be present in a loaded config.
48
+ for key in [
49
+ "afrac_ste",
50
+ "afrac_ste_only_non_neurons",
51
+ "afrac_approx",
52
+ "rtopk",
53
+ "mup",
54
+ "mup_width_multiplier",
55
+ "grad_checkpointing",
56
+ "enable_fp8_linear",
57
+ "scale_invariance",
58
+ "cat_pos_emb",
59
+ ]:
60
+ kwargs.pop(key, None)
61
+ d_mlp = d_mlp or 4 * d_model
62
+ d_head = d_head or d_model // n_head
63
+
64
+ # Avoid duplicate kwargs when loading from a config dict.
65
+ bos_token_id = kwargs.pop("bos_token_id", None)
66
+ eos_token_id = kwargs.pop("eos_token_id", vocab_size - 1)
67
+ pad_token_id = kwargs.pop("pad_token_id", None)
68
+
69
+ super().__init__(
70
+ bos_token_id=bos_token_id,
71
+ eos_token_id=eos_token_id,
72
+ pad_token_id=pad_token_id,
73
+ **kwargs,
74
+ )
75
+
76
+ self.vocab_size = vocab_size
77
+ self.block_size = block_size
78
+ self.max_position_embeddings = block_size
79
+ self.n_layer = n_layer
80
+ self.n_head = n_head
81
+ self.d_model = d_model
82
+ self.d_mlp = d_mlp
83
+ self.d_head = d_head
84
+ self.dropout = dropout
85
+ self.bias = bias
86
+ self.ln_bias = ln_bias
87
+ self.rms_norm = rms_norm
88
+ self.activation_type = activation_type
89
+ self.residual_activation_type = residual_activation_type
90
+ self.tied_unembed = tied_unembed
91
+ self.unembed_rank = unembed_rank
92
+ self.afrac = afrac
93
+ self.afrac_loctypes = afrac_loctypes
94
+ self.flash = flash
95
+ self.use_position_embeddings = use_position_embeddings
96
+ self.d_pos_emb = d_pos_emb
97
+ self.sink = sink
98
+ self.enable_bigram_table = enable_bigram_table
99
+ self.learnable_bigram_table = learnable_bigram_table
100
+ self.bigram_table_rank = bigram_table_rank
101
+ self.dropout_cat_pos_emb = dropout_cat_pos_emb
102
+ self.sinusoidal_cat_pos_emb = sinusoidal_cat_pos_emb
103
+ self.is_decoder = True
104
+ # Provide explicit auto_map entries so AutoModel/AutoConfig can locate
105
+ # the custom classes when trust_remote_code=True on the Hub.
106
+ self.auto_map = auto_map or {
107
+ "AutoConfig": "config.CircuitGPTConfig",
108
+ "AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM",
109
+ }
110
+
111
+ # ---------------------------------------------------------------------
112
+ # Conversion helpers
113
+ # ---------------------------------------------------------------------
114
+ @classmethod
115
+ def from_circuit_config(cls, circuit_config: "GPTConfig") -> "CircuitGPTConfig": # type: ignore[name-defined]
116
+ config_dict: dict[str, Any] = {
117
+ "vocab_size": circuit_config.vocab_size,
118
+ "block_size": circuit_config.block_size,
119
+ "n_layer": circuit_config.n_layer,
120
+ "n_head": circuit_config.n_head,
121
+ "d_model": circuit_config.d_model,
122
+ "d_mlp": circuit_config.d_mlp,
123
+ "d_head": circuit_config.d_head,
124
+ "dropout": circuit_config.dropout,
125
+ "bias": circuit_config.bias,
126
+ "ln_bias": circuit_config.ln_bias,
127
+ "rms_norm": circuit_config.rms_norm,
128
+ "activation_type": circuit_config.activation_type,
129
+ "residual_activation_type": circuit_config.residual_activation_type,
130
+ "tied_unembed": circuit_config.tied_unembed,
131
+ "unembed_rank": circuit_config.unembed_rank,
132
+ "afrac": circuit_config.afrac,
133
+ "afrac_loctypes": circuit_config.afrac_loctypes,
134
+ "flash": circuit_config.flash,
135
+ "use_position_embeddings": circuit_config.d_pos_emb is not None,
136
+ "d_pos_emb": getattr(circuit_config, "d_pos_emb", None),
137
+ "sink": getattr(circuit_config, "sink", False),
138
+ "enable_bigram_table": getattr(circuit_config, "enable_bigram_table", False),
139
+ "learnable_bigram_table": getattr(circuit_config, "learnable_bigram_table", False),
140
+ "bigram_table_rank": getattr(circuit_config, "bigram_table_rank", None),
141
+ "dropout_cat_pos_emb": getattr(circuit_config, "dropout_cat_pos_emb", False),
142
+ "sinusoidal_cat_pos_emb": getattr(circuit_config, "sinusoidal_cat_pos_emb", False),
143
+ }
144
+ return cls(**config_dict)
145
+
146
+ def to_circuit_config(self) -> "GPTConfig": # type: ignore[name-defined]
147
+ from circuit_sparsity.gpt import GPTConfig as CircuitConfig
148
+
149
+ config_kwargs: dict[str, Any] = dict(
150
+ vocab_size=self.vocab_size,
151
+ block_size=self.block_size,
152
+ n_layer=self.n_layer,
153
+ n_head=self.n_head,
154
+ d_model=self.d_model,
155
+ dropout=self.dropout,
156
+ bias=self.bias,
157
+ ln_bias=self.ln_bias,
158
+ rms_norm=self.rms_norm,
159
+ activation_type=self.activation_type,
160
+ residual_activation_type=self.residual_activation_type,
161
+ tied_unembed=self.tied_unembed,
162
+ unembed_rank=self.unembed_rank,
163
+ afrac=self.afrac,
164
+ afrac_loctypes=self.afrac_loctypes,
165
+ flash=self.flash,
166
+ afrac_ste=False,
167
+ afrac_ste_only_non_neurons=False,
168
+ afrac_approx=False,
169
+ rtopk=False,
170
+ mup=False,
171
+ mup_width_multiplier=None,
172
+ grad_checkpointing=False,
173
+ enable_fp8_linear=False,
174
+ scale_invariance=False,
175
+ d_mlp=self.d_mlp,
176
+ d_head=self.d_head,
177
+ enable_sparse_kernels=False,
178
+ enable_bigram_table=self.enable_bigram_table,
179
+ learnable_bigram_table=self.learnable_bigram_table,
180
+ bigram_table_rank=self.bigram_table_rank,
181
+ d_pos_emb=self.d_pos_emb
182
+ if self.d_pos_emb is not None
183
+ else (self.d_model if self.use_position_embeddings else None),
184
+ sink=self.sink,
185
+ dropout_cat_pos_emb=self.dropout_cat_pos_emb,
186
+ sinusoidal_cat_pos_emb=self.sinusoidal_cat_pos_emb,
187
+ )
188
+ return CircuitConfig(**config_kwargs)
189
+
190
+ def to_dict(self) -> dict[str, Any]:
191
+ base = super().to_dict()
192
+ data = {
193
+ "vocab_size": self.vocab_size,
194
+ "block_size": self.block_size,
195
+ "n_layer": self.n_layer,
196
+ "n_head": self.n_head,
197
+ "d_model": self.d_model,
198
+ "d_mlp": self.d_mlp,
199
+ "d_head": self.d_head,
200
+ "dropout": self.dropout,
201
+ "bias": self.bias,
202
+ "ln_bias": self.ln_bias,
203
+ "rms_norm": self.rms_norm,
204
+ "activation_type": self.activation_type,
205
+ "residual_activation_type": self.residual_activation_type,
206
+ "tied_unembed": self.tied_unembed,
207
+ "unembed_rank": self.unembed_rank,
208
+ "flash": self.flash,
209
+ "afrac": self.afrac,
210
+ "afrac_loctypes": self.afrac_loctypes,
211
+ "use_position_embeddings": self.use_position_embeddings,
212
+ "d_pos_emb": self.d_pos_emb,
213
+ "sink": self.sink,
214
+ "enable_bigram_table": self.enable_bigram_table,
215
+ "learnable_bigram_table": self.learnable_bigram_table,
216
+ "bigram_table_rank": self.bigram_table_rank,
217
+ "dropout_cat_pos_emb": self.dropout_cat_pos_emb,
218
+ "sinusoidal_cat_pos_emb": self.sinusoidal_cat_pos_emb,
219
+ "auto_map": self.auto_map,
220
+ }
221
+ base.update(data)
222
+ return base
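As a usage note on the conversion helpers above: `CircuitGPTConfig.from_circuit_config` / `to_circuit_config` round-trip between the HF config and the training-time `GPTConfig`, while the `auto_map` entries let `AutoConfig` / `AutoModelForCausalLM` resolve the custom classes under `trust_remote_code=True`. A minimal sketch, assuming the repo files sit in a local `circuit-sparsity` folder (and, for the round trip, that the `circuit_sparsity` package from the GitHub repo is installed):

```python
from transformers import AutoConfig

# auto_map in config.json points AutoConfig at config.CircuitGPTConfig.
cfg = AutoConfig.from_pretrained("circuit-sparsity", trust_remote_code=True)
print(cfg.model_type, cfg.n_layer, cfg.d_model, cfg.afrac)  # circuitgpt 8 2048 0.25

# Optional round trip through the training-time dataclass
# (requires circuit_sparsity.gpt.GPTConfig to be importable).
gpt_cfg = cfg.to_circuit_config()
cfg2 = type(cfg).from_circuit_config(gpt_cfg)
assert cfg2.d_mlp == cfg.d_mlp and cfg2.sink == cfg.sink
```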
generation_config.json ADDED
@@ -0,0 +1,5 @@
{
  "_from_model_config": true,
  "eos_token_id": 2047,
  "transformers_version": "4.49.0"
}
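This generation config only pins the end-of-sequence token (2047, the last id in the 2,048-token vocabulary); sampling parameters are passed per call. If persistent defaults are preferred, they can be set on the model's generation config instead. A small sketch, assuming `model` and `inputs` from the README example above:

```python
# `model` is the CircuitGPTForCausalLM instance loaded earlier; these become
# the defaults for subsequent generate() calls.
model.generation_config.do_sample = True
model.generation_config.temperature = 0.8
model.generation_config.top_p = 0.95
model.generation_config.max_new_tokens = 64

out = model.generate(inputs)  # now uses the defaults above
```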
gpt.py ADDED
@@ -0,0 +1,900 @@
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ from dataclasses import dataclass
12
+ from typing import Literal
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ # has to be down here to avoid loading cuda too early
19
+ from .hook_utils import (
20
+ hook_namespace,
21
+ hook_save,
22
+ torch_recompute_preserving_hook_context,
23
+ )
24
+
25
+
26
+ def sample_top_k(*, n: int, k: int, shape: tuple[int, ...]):
27
+ """Fallback sampler used only when sparse kernels are enabled."""
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ return torch.randn(shape, device=device, dtype=torch.float32)
30
+
31
+
32
+ class AbsTopK(nn.Module):
33
+ def __init__(self, k):
34
+ super().__init__()
35
+ self.k = k
36
+
37
+ def forward(self, x):
38
+ vals, inds = torch.topk(x.abs(), self.k, dim=-1, sorted=False)
39
+ ret = torch.zeros_like(x)
40
+ ret.scatter_(-1, inds, x.gather(-1, inds))
41
+ return ret
42
+
43
+
44
+ def barrier():
45
+ # stub
46
+ pass
47
+
48
+
49
+ class LayerNorm(nn.Module):
50
+ """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
51
+
52
+ def __init__(self, ndim, bias):
53
+ super().__init__()
54
+ self.weight = nn.Parameter(torch.ones(ndim))
55
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
56
+
57
+ def forward(self, input):
58
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
59
+
60
+
61
+ class CausalSelfAttention(nn.Module):
62
+ def __init__(self, config):
63
+ super().__init__()
64
+ assert config.d_model % config.n_head == 0
65
+ # key, query, value projections for all heads, but in a batch
66
+ self.c_attn = config.Linear(
67
+ config.d_model, 3 * config.d_head * config.n_head, bias=config.bias
68
+ )
69
+ # output projection
70
+ self.c_proj = config.Linear(config.d_head * config.n_head, config.d_model, bias=config.bias)
71
+ # regularization
72
+ self.attn_dropout = nn.Dropout(config.dropout)
73
+ self.resid_dropout = nn.Dropout(config.dropout)
74
+ self.n_head = config.n_head
75
+ self.d_head = config.d_head
76
+ self.d_model = config.d_model
77
+ self.dropout = config.dropout
78
+
79
+ self.config = config
80
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
81
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash
82
+
83
+ if self.flash:
84
+ self.attn_imp = (
85
+ SDPAWithSink(config.n_head) if config.sink else F.scaled_dot_product_attention
86
+ )
87
+
88
+ if not self.flash:
89
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
90
+ # causal mask to ensure that attention is only applied to the left in the input sequence
91
+ self.register_buffer(
92
+ "bias",
93
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
94
+ 1, 1, config.block_size, config.block_size
95
+ ),
96
+ )
97
+
98
+ def forward(self, x):
99
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (d_model)
100
+
101
+ x = self.config.maybe_activation_sparsity(x, "attn_in")
102
+ x = hook_save("act_in", x)
103
+
104
+ if self.config.debug_nans:
105
+ assert x.isfinite().all(), "nan in input"
106
+
107
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
108
+ q, k, v = self.c_attn(x).split(self.n_head * self.d_head, dim=2)
109
+
110
+ k = self.config.maybe_activation_sparsity(k, "attn_k")
111
+ q = self.config.maybe_activation_sparsity(q, "attn_q")
112
+ v = self.config.maybe_activation_sparsity(v, "attn_v")
113
+
114
+ k = hook_save("k", k) # (B, T, n_head * d_head)
115
+ q = hook_save("q", q) # (B, T, n_head * d_head)
116
+ v = hook_save("v", v) # (B, T, n_head * d_head)
117
+
118
+ k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B, nh, T, hs)
119
+ q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B, nh, T, hs)
120
+ v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B, nh, T, hs)
121
+
122
+ if self.config.debug_nans:
123
+ assert q.isfinite().all(), "nan in query"
124
+ assert k.isfinite().all(), "nan in key"
125
+ assert v.isfinite().all(), "nan in value"
126
+
127
+ attention_scale = 1.0 / math.sqrt(k.size(-1))
128
+
129
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
130
+ if self.flash:
131
+ # efficient attention using Flash Attention CUDA kernels
132
+ y = self.attn_imp(
133
+ q,
134
+ k,
135
+ v,
136
+ dropout_p=self.dropout if self.training else 0,
137
+ is_causal=True,
138
+ scale=attention_scale,
139
+ )
140
+ else:
141
+ # manual implementation of attention
142
+ att = (q @ k.transpose(-2, -1)) * attention_scale
143
+ att = att.masked_fill(
144
+ self.bias[:, :, :T, :T] == 0, torch.finfo(att.dtype).min
145
+ ) # float("-inf"))
146
+
147
+ att = F.softmax(att, dim=-1)
148
+ att = self.attn_dropout(att)
149
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
150
+
151
+ if self.config.debug_nans:
152
+ assert y.isfinite().all(), "nan in attention output"
153
+
154
+ y = (
155
+ y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.d_head)
156
+ ) # re-assemble all head outputs side by side
157
+
158
+ # y = self.config.maybe_activation_sparsity(y)
159
+ y = hook_save("y", y) # (B, T, n_head * d_head)
160
+
161
+ # output projection
162
+ y = self.resid_dropout(self.c_proj(y))
163
+
164
+ if self.config.debug_nans:
165
+ assert y.isfinite().all(), "nan in attention output 2"
166
+
167
+ y = self.config.maybe_activation_sparsity(y, "attn_out")
168
+ return y
169
+
170
+
171
+ class MLP(nn.Module):
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ self.config = config
175
+ self.c_fc = config.Linear(config.d_model, config.d_mlp, bias=config.bias)
176
+ self.act_fn = {
177
+ "gelu": nn.GELU(),
178
+ "relu": nn.ReLU(),
179
+ }[config.activation_type]
180
+ self.c_proj = config.Linear(config.d_mlp, config.d_model, bias=config.bias)
181
+ self.dropout = nn.Dropout(config.dropout)
182
+
183
+ def forward(self, x):
184
+ x = self.config.maybe_activation_sparsity(x, "mlp_in")
185
+ x = hook_save("act_in", x)
186
+
187
+ if self.config.debug_nans:
188
+ assert x.isfinite().all(), "nan in mlp input"
189
+
190
+ x = self.c_fc(x)
191
+
192
+ if self.config.debug_nans:
193
+ assert x.isfinite().all(), "nan in mlp after c_fc"
194
+
195
+ x = self.act_fn(x)
196
+ x = self.config.maybe_activation_sparsity(x, "mlp_neuron")
197
+ x = hook_save("post_act", x)
198
+
199
+ if self.config.debug_nans:
200
+ assert x.isfinite().all(), "nan in mlp after act"
201
+
202
+ x = self.c_proj(x)
203
+
204
+ if self.config.debug_nans:
205
+ assert x.isfinite().all(), "nan in mlp after c_proj"
206
+ x = self.dropout(x)
207
+
208
+ x = self.config.maybe_activation_sparsity(x, "mlp_out")
209
+ return x
210
+
211
+
212
+ class SDPAWithSink(nn.Module):
213
+ """
214
+ Adds a learnable denominator-only term ("attention sink") to SDPA by
215
+ concatenating a dummy KV slot whose logit is b and whose V is zero.
216
+ """
217
+
218
+ def __init__(self, num_heads: int, init_logit: float = 0.0):
219
+ super().__init__()
220
+ shape = (num_heads,)
221
+ self.sink_logit = nn.Parameter(torch.full(shape, init_logit))
222
+
223
+ def forward(
224
+ self,
225
+ q: torch.Tensor, # (B, H, Lq, D)
226
+ k: torch.Tensor, # (B, H, Lk, D)
227
+ v: torch.Tensor, # (B, H, Lk, Dv)
228
+ *,
229
+ dropout_p: float = 0.0,
230
+ is_causal: bool = False,
231
+ scale: float | None = None,
232
+ ) -> torch.Tensor:
233
+ B, H, Lq, D = q.shape
234
+ _, _, Lk, _ = k.shape
235
+ Dv = v.size(-1)
236
+
237
+ # 1) Prepend a dummy KV slot (always visible)
238
+ k_sink = torch.zeros((B, H, 1, D), dtype=q.dtype, device=q.device)
239
+ v_sink = torch.zeros((B, H, 1, Dv), dtype=v.dtype, device=v.device)
240
+ k_aug = torch.cat([k_sink, k], dim=2) # (B,H,Lk+1,D)
241
+ v_aug = torch.cat([v_sink, v], dim=2) # (B,H,Lk+1,Dv)
242
+
243
+ # 2) Build shifted causal allow-mask over keys (columns 1..), always allow col 0 (sink)
244
+ # allow: 1 where attending is allowed, 0 where disallowed
245
+ # For real keys: allow[i, j+1] = 1 if j <= i else 0 (lower-triangular)
246
+ allow = torch.zeros((Lq, Lk + 1), dtype=torch.bool, device=q.device)
247
+ allow[:, 0] = True # sink column always on
248
+ # lower-triangular for real keys shifted by +1
249
+ real = torch.ones((Lq, Lk), dtype=torch.bool, device=q.device).tril()
250
+ allow[:, 1:] = real
251
+
252
+ # Broadcast to (B,H,Lq,Lk+1)
253
+ allow = allow.view(1, 1, Lq, Lk + 1).expand(B, H, Lq, Lk + 1)
254
+
255
+ # 3) Turn it into an additive mask. 0 for allowed, -inf for disallowed
256
+ neg_inf = torch.finfo(q.dtype).min
257
+ base_mask = torch.where(
258
+ allow,
259
+ torch.zeros((), dtype=q.dtype, device=q.device),
260
+ torch.full((), neg_inf, dtype=q.dtype, device=q.device),
261
+ ) # (B,H,Lq,Lk+1)
262
+
263
+ # 4) Add learnable sink bias b to column 0 (per head or shared)
264
+ if self.sink_logit.numel() == H:
265
+ b = self.sink_logit.to(dtype=q.dtype, device=q.device).view(1, H, 1, 1) # (1,H,1,1)
266
+ else:
267
+ b = self.sink_logit.to(dtype=q.dtype, device=q.device).view(1, 1, 1, 1) # (1,1,1,1)
268
+
269
+ sink_bias_mask = torch.zeros((1, 1, 1, Lk + 1), dtype=q.dtype, device=q.device)
270
+ sink_bias_mask[..., 0] = 1.0
271
+ attn_mask = base_mask + sink_bias_mask * b # (B,H,Lq,Lk+1)
272
+
273
+ # 5) SDPA with our custom mask; keep is_causal=False to avoid double-masking
274
+ out = F.scaled_dot_product_attention(
275
+ q,
276
+ k_aug,
277
+ v_aug,
278
+ attn_mask=attn_mask,
279
+ dropout_p=dropout_p,
280
+ is_causal=False, # important
281
+ scale=scale,
282
+ )
283
+ return out
284
+
285
+
286
+ class Block(nn.Module):
287
+ # block exactly satisfies the invariant that forward = forward_mlp_block . forward_attn_block
288
+ def __init__(self, config):
289
+ super().__init__()
290
+ self.config = config
291
+
292
+ self.ln_1 = (
293
+ nn.RMSNorm(config.d_model)
294
+ if config.rms_norm
295
+ else LayerNorm(config.d_model, bias=config.ln_bias)
296
+ )
297
+ self.attn = CausalSelfAttention(config)
298
+ self.ln_2 = (
299
+ nn.RMSNorm(config.d_model)
300
+ if config.rms_norm
301
+ else LayerNorm(config.d_model, bias=config.ln_bias)
302
+ )
303
+ self.mlp = MLP(config)
304
+
305
+ def forward_attn_block(self, x):
306
+ x = hook_save("resid_in", x)
307
+
308
+ if self.config.debug_nans:
309
+ assert x.isfinite().all(), "nan in blk input"
310
+
311
+ with hook_namespace("attn"):
312
+ if self.config.grad_checkpointing:
313
+ x = x + hook_save(
314
+ "resid_delta",
315
+ torch_recompute_preserving_hook_context(
316
+ lambda x: self.attn(self.ln_1(x)), x, use_reentrant=False
317
+ ),
318
+ )
319
+ else:
320
+ x = x + hook_save("resid_delta", self.attn(self.ln_1(x)))
321
+
322
+ if self.config.residual_activation_type == "relu":
323
+ x = torch.relu(x)
324
+ x = self.config.maybe_activation_sparsity(x, "resid_post_attn")
325
+
326
+ return x
327
+
328
+ def forward_mlp_block(self, x):
329
+ x = hook_save("resid_mid", x)
330
+ with hook_namespace("mlp"):
331
+ if self.config.grad_checkpointing:
332
+ x = x + hook_save(
333
+ "resid_delta",
334
+ torch_recompute_preserving_hook_context(
335
+ lambda x: self.mlp(self.ln_2(x)), x, use_reentrant=False
336
+ ),
337
+ )
338
+ else:
339
+ x = x + hook_save("resid_delta", self.mlp(self.ln_2(x)))
340
+
341
+ if self.config.residual_activation_type == "relu":
342
+ x = torch.relu(x)
343
+ x = self.config.maybe_activation_sparsity(x, "resid_post_mlp")
344
+ return x
345
+
346
+ def forward(self, x):
347
+ x = self.forward_attn_block(x)
348
+ x = self.forward_mlp_block(x)
349
+ return x
350
+
351
+
352
+ class CausalSelfAttentionCatPosEmb(CausalSelfAttention):
353
+ def __init__(self, config):
354
+ # initialize base attention with standard shapes, we'll override projections
355
+ super().__init__(config)
356
+ assert config.d_model % config.n_head == 0
357
+ # key, query, value projections for all heads, but in a batch
358
+ self.c_attn = config.Linear(
359
+ config.d_model_in, 3 * config.d_head * config.n_head, bias=config.bias
360
+ )
361
+ # output projection
362
+ self.c_proj = config.Linear(config.d_head * config.n_head, config.d_model, bias=config.bias)
363
+ # regularization
364
+ self.attn_dropout = nn.Dropout(config.dropout)
365
+ self.resid_dropout = nn.Dropout(config.dropout)
366
+ self.n_head = config.n_head
367
+ self.d_head = config.d_head
368
+ self.d_model_in = config.d_model_in
369
+ self.d_model = config.d_model
370
+ self.dropout = config.dropout
371
+ self.config = config
372
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
373
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash
374
+
375
+ if not self.flash:
376
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
377
+ # causal mask to ensure that attention is only applied to the left in the input sequence
378
+ self.register_buffer(
379
+ "bias",
380
+ torch.tril(torch.ones(config.block_size, config.block_size)).view(
381
+ 1, 1, config.block_size, config.block_size
382
+ ),
383
+ )
384
+
385
+ def forward(self, x, pos_emb_to_cat):
386
+ # Broadcast pos emb over batch if provided as shape [1, T, C]
387
+ if pos_emb_to_cat is not None and pos_emb_to_cat.size(0) == 1 and x.size(0) != 1:
388
+ pos_emb_to_cat = pos_emb_to_cat.expand(x.size(0), -1, -1)
389
+ x = torch.cat([x, pos_emb_to_cat], dim=-1)
390
+ return super().forward(x)
391
+
392
+
393
+ class MLPCatPosEmb(MLP):
394
+ def __init__(self, config):
395
+ # initialize base MLP, we'll override the projections to match cat shapes
396
+ super().__init__(config)
397
+ self.config = config
398
+ self.c_fc = config.Linear(config.d_model_in, config.d_mlp, bias=config.bias)
399
+ self.act_fn = {
400
+ "gelu": nn.GELU(),
401
+ "relu": nn.ReLU(),
402
+ }[config.activation_type]
403
+ self.c_proj = config.Linear(config.d_mlp, config.d_model, bias=config.bias)
404
+ self.dropout = nn.Dropout(config.dropout)
405
+
406
+ def forward(self, x, pos_emb_to_cat):
407
+ # Broadcast pos emb over batch if provided as shape [1, T, C]
408
+ if pos_emb_to_cat is not None and pos_emb_to_cat.size(0) == 1 and x.size(0) != 1:
409
+ pos_emb_to_cat = pos_emb_to_cat.expand(x.size(0), -1, -1)
410
+ x = torch.cat([x, pos_emb_to_cat], dim=-1)
411
+ x = super().forward(x)
412
+ return x
413
+
414
+
415
+ class BlockCatPosEmb(Block):
416
+ # block exactly satisfies the invariant that forward = forward_mlp_block . forward_attn_block
417
+ def __init__(self, config):
418
+ # initialize base Block to get ln_1/ln_2 and other invariants
419
+ super().__init__(config)
420
+ self.ln_p1 = (
421
+ nn.RMSNorm(config.d_pos_emb)
422
+ if config.rms_norm
423
+ else LayerNorm(config.d_pos_emb, bias=config.ln_bias)
424
+ )
425
+ self.ln_p2 = (
426
+ nn.RMSNorm(config.d_pos_emb)
427
+ if config.rms_norm
428
+ else LayerNorm(config.d_pos_emb, bias=config.ln_bias)
429
+ )
430
+ self.attn = CausalSelfAttentionCatPosEmb(config)
431
+ self.mlp = MLPCatPosEmb(config)
432
+
433
+ def forward_attn_block(self, x, p):
434
+ x = hook_save("resid_in", x)
435
+
436
+ if self.config.debug_nans:
437
+ assert x.isfinite().all(), "nan in blk input"
438
+
439
+ with hook_namespace("attn"):
440
+ if self.config.grad_checkpointing:
441
+ x = x + hook_save(
442
+ "resid_delta",
443
+ torch_recompute_preserving_hook_context(
444
+ lambda x, p: self.attn(self.ln_1(x), self.ln_p1(p)),
445
+ x,
446
+ p,
447
+ use_reentrant=False,
448
+ ),
449
+ )
450
+ else:
451
+ x = x + hook_save("resid_delta", self.attn(self.ln_1(x), self.ln_p1(p)))
452
+
453
+ if self.config.residual_activation_type == "relu":
454
+ x = torch.relu(x)
455
+ x = self.config.maybe_activation_sparsity(x, "resid_post_attn")
456
+
457
+ return x
458
+
459
+ def forward_mlp_block(self, x, p):
460
+ x = hook_save("resid_mid", x)
461
+ with hook_namespace("mlp"):
462
+ if self.config.grad_checkpointing:
463
+ x = x + hook_save(
464
+ "resid_delta",
465
+ torch_recompute_preserving_hook_context(
466
+ lambda x, p: self.mlp(self.ln_2(x), self.ln_p2(p)),
467
+ x,
468
+ p,
469
+ use_reentrant=False,
470
+ ),
471
+ )
472
+ else:
473
+ x = x + hook_save("resid_delta", self.mlp(self.ln_2(x), self.ln_p2(p)))
474
+
475
+ if self.config.residual_activation_type == "relu":
476
+ x = torch.relu(x)
477
+ x = self.config.maybe_activation_sparsity(x, "resid_post_mlp")
478
+ return x
479
+
480
+ def forward(self, x, pos_emb_to_cat):
481
+ x = self.forward_attn_block(x, pos_emb_to_cat)
482
+ x = self.forward_mlp_block(x, pos_emb_to_cat)
483
+ return x
484
+
485
+
486
+ @dataclass
487
+ class GPTConfig:
488
+ block_size: int = 1024
489
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency # TODO: FLAG FOR ACHY
490
+ n_layer: int = 12
491
+ n_head: int = 12
492
+ d_head: int | None = None # defaults to d_model // n_head
493
+ d_model: int = 768
494
+ dropout: float = 0.0
495
+ bias: bool = (
496
+ True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
497
+ )
498
+ ln_bias: bool = (
499
+ True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
500
+ )
501
+ rms_norm: bool = False # use RMSNorm instead of LayerNorm
502
+ residual_activation_type: Literal["identity", "relu"] = "identity"
503
+ activation_type: Literal["gelu", "relu"] = "gelu"
504
+ afrac: float | None = None # fraction of activations to keep
505
+ afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out"
506
+ debug_nans: bool = False
507
+ tied_unembed: bool = True
508
+
509
+ tokenizer_name: str = "tinypython_2k"
510
+
511
+ grad_checkpointing: bool = True
512
+ d_mlp: int | None = None  # MLP hidden layer size; defaults to 4 * d_model
513
+
514
+ enable_bigram_table: bool = False
515
+ learnable_bigram_table: bool = False
516
+ d_pos_emb: int | None = None
517
+ dropout_cat_pos_emb: bool = False
518
+ sinusoidal_cat_pos_emb: bool = False
519
+ enable_sparse_kernels: bool = False
520
+
521
+ flash: bool = True
522
+ sink: bool = False
523
+
524
+ @property
525
+ def cat_pos_emb(self):
526
+ return self.d_pos_emb is not None
527
+
528
+ @property
529
+ def d_model_in(self):
530
+ return self.d_model + self.d_pos_emb if self.cat_pos_emb else self.d_model
531
+
532
+ def __post_init__(self):
533
+ assert self.d_model % self.n_head == 0
534
+ assert self.residual_activation_type in ["identity", "relu"]
535
+ assert self.activation_type in ["gelu", "relu"]
536
+
537
+ if self.d_mlp is None:
538
+ self.d_mlp = 4 * self.d_model
539
+ if self.d_head is None:
540
+ self.d_head = self.d_model // self.n_head
541
+
542
+ @property
543
+ def Linear(self):
544
+ return nn.Linear
545
+
546
+ def maybe_activation_sparsity(self, x, loctype):
547
+ if self.afrac is not None and loctype in self.afrac_loctypes.split(","):
548
+
549
+ def keep_abstopk(x, k):
550
+ ret = torch.zeros_like(x)
551
+ _, topk_inds = torch.topk(x.abs(), k, dim=-1, sorted=False)
552
+ ret.scatter_(-1, topk_inds, x.gather(-1, topk_inds))
553
+ return ret
554
+
555
+ x = keep_abstopk(
556
+ x,
557
+ k=int(self.afrac * x.shape[-1]),
558
+ )
559
+
560
+ return x
561
+
562
+
563
+ class GPT(nn.Module):
564
+ def __init__(self, config):
565
+ super().__init__()
566
+ assert config.vocab_size is not None
567
+ assert config.block_size is not None
568
+ self.config = config
569
+
570
+ if config.cat_pos_emb:
571
+ block_cls = BlockCatPosEmb
572
+ else:
573
+ block_cls = Block
574
+
575
+ self.transformer = nn.ModuleDict(
576
+ dict(
577
+ wte=nn.Embedding(config.vocab_size, config.d_model),
578
+ wpe=nn.Embedding(config.block_size, config.d_pos_emb or config.d_model),
579
+ drop=nn.Dropout(config.dropout),
580
+ h=nn.ModuleList([(block_cls(config)) for _ in range(config.n_layer)]),
581
+ ln_f=nn.RMSNorm(config.d_model)
582
+ if config.rms_norm
583
+ else LayerNorm(config.d_model, bias=config.ln_bias),
584
+ )
585
+ )
586
+
587
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
588
+
589
+ self.register_buffer(
590
+ "final_logits_bias", torch.zeros(config.vocab_size, dtype=torch.float32)
591
+ )
592
+
593
+ if self.config.enable_bigram_table:
594
+ if self.config.learnable_bigram_table:
595
+ # HACK: low rank to fit in mem
596
+ self.bigram_table = nn.Parameter(
597
+ torch.zeros(config.vocab_size, config.vocab_size, dtype=torch.float32)
598
+ )
599
+ else:
600
+ self.register_buffer(
601
+ "bigram_table",
602
+ torch.zeros(config.vocab_size, config.vocab_size, dtype=torch.float32),
603
+ )
604
+ else:
605
+ self.bigram_table = None
606
+
607
+ # Never tie embeddings/unembed to avoid accidental aliasing in exports.
608
+ config.tied_unembed = False
609
+
610
+ # init all weights
611
+ self.apply(self._init_weights)
612
+ # apply special scaled init to the residual projections, per GPT-2 paper
613
+ for pn, p in self.named_parameters():
614
+ if pn.endswith("c_proj.weight"):
615
+ if p.is_sparse:
616
+ num_nonzero = p._nnz()
617
+ p._values().data = (
618
+ sample_top_k(n=p.numel(), k=num_nonzero, shape=(num_nonzero,))
619
+ * 0.02
620
+ / math.sqrt(2 * config.n_layer)
621
+ )
622
+ else:
623
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
624
+
625
+ # If requested, initialize positional embeddings with fixed sinusoids and freeze
626
+ if config.cat_pos_emb and config.sinusoidal_cat_pos_emb:
627
+ assert config.d_pos_emb is not None, (
628
+ "sinusoidal_cat_pos_emb requires cat_pos_emb (d_pos_emb must be set)"
629
+ )
630
+ with torch.no_grad():
631
+ T = config.block_size
632
+ D = config.d_pos_emb
633
+ device = self.transformer.wpe.weight.device
634
+ dtype = self.transformer.wpe.weight.dtype
635
+ positions = torch.arange(T, device=device, dtype=dtype).unsqueeze(1) # [T,1]
636
+ d_half = max(1, D // 2)
637
+ # periods from 4 tokens up to block_size tokens (log-spaced)
638
+ T_float = float(T)
639
+ p_min = 4.0
640
+ p_max = max(p_min, T_float)
641
+ periods = torch.logspace(
642
+ math.log10(p_min), math.log10(p_max), steps=d_half, device=device, dtype=dtype
643
+ )
644
+ freqs = 2 * math.pi / periods # [d_half]
645
+ angles = positions * freqs # [T, d_half]
646
+ sinv = torch.sin(angles)
647
+ cosv = torch.cos(angles)
648
+ enc = torch.cat([sinv, cosv], dim=1) # [T, 2*d_half]
649
+ if enc.shape[1] < D:
650
+ pad = torch.zeros(T, D - enc.shape[1], device=device, dtype=dtype)
651
+ enc = torch.cat([enc, pad], dim=1)
652
+ elif enc.shape[1] > D:
653
+ enc = enc[:, :D]
654
+ self.transformer.wpe.weight.copy_(enc)
655
+ self.transformer.wpe.weight.requires_grad_(False)
656
+
657
+ # report number of parameters
658
+ print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
659
+
660
+ def get_num_params(self, non_embedding=True):
661
+ """
662
+ Return the number of parameters in the model.
663
+ For non-embedding count (default), the position embeddings get subtracted.
664
+ The token embeddings would too, except due to the parameter sharing these
665
+ params are actually used as weights in the final layer, so we include them.
666
+ """
667
+ n_params = sum(p.numel() for p in self.parameters())
668
+ if non_embedding:
669
+ n_params -= self.transformer.wpe.weight.numel()
670
+ return n_params
671
+
672
+ def _init_weights(self, module):
673
+ if isinstance(module, nn.Linear):
674
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
675
+ if module.bias is not None:
676
+ torch.nn.init.zeros_(module.bias)
677
+ elif isinstance(module, nn.Embedding):
678
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
679
+
680
+ def forward(self, idx, targets=None, include_resid_mid=False):
681
+ device = idx.device
682
+ b, t = idx.size()
683
+
684
+ assert t <= self.config.block_size, (
685
+ f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
686
+ )
687
+ # pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
688
+
689
+ # forward the GPT model itself
690
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, d_model)
691
+ # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, d_model)
692
+ pos_emb = self.transformer.wpe.weight[:t].unsqueeze(0)
693
+ if self.config.cat_pos_emb:
694
+ x = self.transformer.drop(tok_emb)
695
+ else:
696
+ x = self.transformer.drop(tok_emb + pos_emb)
697
+
698
+ if self.config.debug_nans:
699
+ assert x.isfinite().all(), "nan in initial post-embedding"
700
+
701
+ if self.config.enable_bigram_table:
702
+ # add bigram table to the logits bias
703
+ additional_logits_bias = F.embedding(idx, self.bigram_table, padding_idx=-1)
704
+ additional_logits_bias = additional_logits_bias.to(x.dtype)
705
+ else:
706
+ additional_logits_bias = None
707
+
708
+ if self.config.cat_pos_emb:
709
+ pos_emb_to_cat = pos_emb
710
+ if self.config.dropout_cat_pos_emb:
711
+ pos_emb_to_cat = self.transformer.drop(pos_emb)
712
+ else:
713
+ pos_emb_to_cat = None
714
+
715
+ return self.forward_tail(
716
+ x,
717
+ n=0,
718
+ targets=targets,
719
+ additional_logits_bias=additional_logits_bias,
720
+ include_resid_mid=include_resid_mid, # this is hacky we should just switch to using hooks
721
+ pos_emb_to_cat=pos_emb_to_cat,
722
+ )
723
+
724
+ def forward_tail(
725
+ self,
726
+ x,
727
+ n,
728
+ targets=None,
729
+ additional_logits_bias=None,
730
+ include_resid_mid=False,
731
+ pos_emb_to_cat=None,
732
+ ):
733
+ # print(x.shape)
734
+ hs = []
735
+ blks = list(self.transformer.h)
736
+
737
+ if include_resid_mid:
738
+ blks = list_join(
739
+ [
740
+ [
741
+ blk.forward_attn_block,
742
+ blk.forward_mlp_block,
743
+ ]
744
+ for blk in blks
745
+ ]
746
+ )
747
+
748
+ assert n <= len(blks)
749
+
750
+ for i, block_fn in enumerate(blks[n:]):
751
+ global curlayer
752
+ curlayer = i
753
+ with hook_namespace(f"{i // 2}") if include_resid_mid else hook_namespace(f"{i}"):
754
+ hs.append(x)
755
+ if self.config.cat_pos_emb:
756
+ x = block_fn(x, pos_emb_to_cat)
757
+ else:
758
+ x = block_fn(x)
759
+
760
+ x = hook_save("final_resid", x)
761
+ x = self.transformer.ln_f(x)
762
+
763
+ logits = (
764
+ self.lm_head(x)
765
+ + self.final_logits_bias
766
+ + (additional_logits_bias if additional_logits_bias is not None else 0)
767
+ )
768
+ if targets is not None:
769
+ loss = F.cross_entropy(
770
+ logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
771
+ )
772
+ else:
773
+ loss = torch.zeros(1, device=x.device)
774
+
775
+ return logits, loss, hs # hs is deprecated in favor of hook stuff
776
+
777
+ def crop_block_size(self, block_size):
778
+ # model surgery to decrease the block size if necessary
779
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
780
+ # but want to use a smaller block size for some smaller, simpler model
781
+ assert block_size <= self.config.block_size
782
+ self.config.block_size = block_size
783
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
784
+ for block in self.transformer.h:
785
+ if hasattr(block.attn, "bias"):
786
+ block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
787
+
788
+ @torch.no_grad()
789
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
790
+ """
791
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
792
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
793
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
794
+ """
795
+ for _ in range(max_new_tokens):
796
+ # if the sequence context is growing too long we must crop it at block_size
797
+ idx_cond = (
798
+ idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :]
799
+ )
800
+ # forward the model to get the logits for the index in the sequence
801
+ logits, _, _ = self(idx_cond)
802
+ # pluck the logits at the final step and scale by desired temperature
803
+ logits = logits[:, -1, :] / temperature
804
+ # optionally crop the logits to only the top k options
805
+ if top_k is not None:
806
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
807
+ logits[logits < v[:, -1:]] = -float("Inf")
808
+ # apply softmax to convert logits to (normalized) probabilities
809
+ probs = F.softmax(logits, dim=-1)
810
+ # sample from the distribution
811
+ idx_next = torch.multinomial(probs, num_samples=1)
812
+ # append sampled index to the running sequence and continue
813
+ idx = torch.cat((idx, idx_next), dim=1)
814
+
815
+ return idx
816
+
817
+ def is_mlp_param(self, p):
818
+ return id(p) in list_join(
819
+ [
820
+ [
821
+ id(self.transformer.h[i].mlp.c_fc.weight),
822
+ id(self.transformer.h[i].mlp.c_proj.weight),
823
+ ]
824
+ for i in range(self.config.n_layer)
825
+ ]
826
+ )
827
+
828
+ def is_param_embed(self, p):
829
+ return p is self.transformer.wte.weight or p is self.transformer.wpe.weight
830
+
831
+ def is_attn_param(self, p):
832
+ return id(p) in list_join(
833
+ [
834
+ [
835
+ id(self.transformer.h[i].attn.c_attn.weight),
836
+ id(self.transformer.h[i].attn.c_proj.weight),
837
+ ]
838
+ for i in range(self.config.n_layer)
839
+ ]
840
+ )
841
+
842
+ def is_bias(self, p):
843
+ return id(p) in list_join(
844
+ [
845
+ [
846
+ id(self.transformer.h[i].attn.c_attn.bias),
847
+ id(self.transformer.h[i].attn.c_proj.bias),
848
+ id(self.transformer.h[i].mlp.c_fc.bias),
849
+ id(self.transformer.h[i].mlp.c_proj.bias),
850
+ ]
851
+ for i in range(self.config.n_layer)
852
+ ]
853
+ )
854
+
855
+ def is_ln_param(self, p):
856
+ return id(p) in list_join(
857
+ [
858
+ [
859
+ id(self.transformer.h[i].ln_1.weight),
860
+ id(self.transformer.h[i].ln_2.weight),
861
+ ]
862
+ for i in range(self.config.n_layer)
863
+ ]
864
+ ) + [
865
+ id(self.transformer.ln_f.weight),
866
+ ]
867
+
868
+ def is_sparse_param(self, p, dense_embeddings=None, dense_unembed=None, dense_biases=None):
869
+ # if these params aren't specified, then still give answers, but only for uncontroversial params
870
+
871
+ if dense_embeddings is None:
872
+ assert p is not self.transformer.wte.weight and p is not self.transformer.wpe.weight
873
+ if dense_unembed is None:
874
+ assert p is not self.lm_head.weight
875
+ if dense_biases is None:
876
+ assert not self.is_bias(p)
877
+
878
+ if p is self.transformer.wte.weight or p is self.transformer.wpe.weight:
879
+ return not dense_embeddings
880
+ if p is self.lm_head.weight:
881
+ return not dense_unembed
882
+ if self.is_bias(p):
883
+ return not dense_biases
884
+
885
+ return id(p) in list_join(
886
+ [
887
+ [
888
+ id(self.transformer.h[i].attn.c_attn.weight),
889
+ id(self.transformer.h[i].attn.c_proj.weight),
890
+ id(self.transformer.h[i].mlp.c_fc.weight),
891
+ id(self.transformer.h[i].mlp.c_proj.weight),
892
+ ]
893
+ for i in range(self.config.n_layer)
894
+ ]
895
+ )
896
+
897
+
898
+ def list_join(xss: list[list]) -> list:
899
+ """monadic join for lists"""
900
+ return [x for xs in xss for x in xs]
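The `SDPAWithSink` module above implements the attention sink purely through masking: the prepended key/value slot has zero value, so its only effect is adding `exp(b)` to the softmax denominator, letting a head attend "nowhere". A small self-contained check of that equivalence (illustrative only, not part of the committed file):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, L, D = 1, 2, 5, 8
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
b = torch.tensor([0.7, -0.3])  # per-head sink logits (like self.sink_logit)
scale = 1.0 / math.sqrt(D)

# Reference: causal softmax attention with an extra exp(b) term in the denominator only.
causal = torch.ones(L, L, dtype=torch.bool).tril()
scores = (q @ k.transpose(-2, -1)) * scale
e = scores.masked_fill(~causal, float("-inf")).exp()
weights = e / (e.sum(-1, keepdim=True) + b.exp().view(1, H, 1, 1))
ref = weights @ v

# SDPAWithSink formulation: prepend a zero-valued KV slot whose logit is b.
k_aug = torch.cat([torch.zeros(B, H, 1, D), k], dim=2)
v_aug = torch.cat([torch.zeros(B, H, 1, D), v], dim=2)
mask = torch.full((H, L, L + 1), float("-inf"))
mask[:, :, 1:] = torch.where(causal, 0.0, float("-inf"))  # ordinary causal mask for real keys
mask[:, :, 0] = b.view(H, 1)                              # sink column: always visible, logit b
out = F.scaled_dot_product_attention(q, k_aug, v_aug, attn_mask=mask, scale=scale)

print(torch.allclose(ref, out, atol=1e-5))  # True: the sink only inflates the denominator
```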
hook_utils.py ADDED
@@ -0,0 +1,182 @@
+ """Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds.
+
+ The full module has no exotic dependencies, but mirroring the definitions here
+ keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and easy to vendor. The
+ implementations below are copied with minor tweaks for readability so that code
+ written against :func:`hook_recorder`, :func:`hook_namespace`, and
+ :func:`torch_recompute_preserving_hook_context` behaves identically in both the
+ training and inference configurations.
+ """
+
+ from __future__ import annotations
+
+ import re
+ from contextlib import contextmanager
+ from functools import partial
+
+ import torch
+ import torch.utils.checkpoint
+
+
+ class HookContext:
+     """State container used by the hook helpers."""
+
+     def __init__(self) -> None:
+         self._reset()
+         self.curintervtransformer = lambda x: x
+
+     def _reset(self) -> None:
+         self.curcontext = None
+         self.curname = ""
+         self.curregex = None
+         self.curinterventions = None
+         self.save_grads = None
+
+     def _get_interventions(self):
+         return self.curintervtransformer(
+             self.curinterventions if self.curinterventions is not None else {}
+         )
+
+     @contextmanager
+     def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False):
+         """Record tensors that pass through hooks matching ``regex``."""
+
+         assert self.curcontext is None, "reentrancy not allowed!"
+
+         try:
+             self.curcontext = {}
+             self.curregex = re.compile(regex)
+             self.curname = ""
+             self.curinterventions = interventions
+             self.save_grads = save_grads
+
+             yield self.curcontext
+         finally:
+             self._reset()
+             get_context()._reset()
+
+     @contextmanager
+     def hook_intervention_transform(self, intervention_transformer):
+         oldintervention_transformer = self.curintervtransformer
+
+         def compose(f, g):
+             return lambda x: f(g(x))
+
+         self.curintervtransformer = compose(
+             intervention_transformer,
+             self.curintervtransformer,
+         )
+
+         try:
+             yield
+         finally:
+             self.curintervtransformer = oldintervention_transformer
+
+     @contextmanager
+     def hook_namespace(self, name: str):
+         """Temporarily push ``name`` onto the hook namespace stack."""
+
+         oldname = self.curname
+         self.curname = self.curname + name + "."
+
+         try:
+             yield
+         finally:
+             self.curname = oldname
+
+     def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor:
+         """Optionally record ``tensor`` using the current namespace."""
+
+         curinterventions = self._get_interventions()
+         if curinterventions is not None:
+             key = self.curname + name
+             if key in curinterventions:
+                 tensor = curinterventions[key](tensor)
+
+         if self.curcontext is not None and self.curregex.match(self.curname + name):
+             self.curcontext[self.curname + name] = tensor
+
+         if self.curcontext is not None and self.save_grads and tensor.requires_grad:
+
+             class _Grad(torch.autograd.Function):
+                 @staticmethod
+                 def forward(ctx, input_tensor):
+                     return input_tensor
+
+                 @staticmethod
+                 def backward(ctx, grad_output):
+                     self.curcontext[self.curname + name + ".grad"] = grad_output
+                     return grad_output
+
+             if self.curregex.match(self.curname + name + ".grad"):
+                 tensor = _Grad.apply(tensor)
+
+         return tensor
+
+
+ def set_context(new_context: HookContext) -> None:
+     global context
+     context = new_context
+
+
+ def get_context() -> HookContext:
+     global context
+     return context
+
+
+ def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None):
+     """Wrapper around :func:`torch.utils.checkpoint` that propagates hooks."""
+
+     oldcontext = get_context()
+     curcontext = HookContext()
+     curcontext.curcontext = (
+         dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None
+     )
+     curcontext.curregex = oldcontext.curregex
+     curcontext.curname = oldcontext.curname
+     curcontext.curinterventions = (
+         dict(oldcontext.curinterventions) if oldcontext.curinterventions is not None else None
+     )
+     curcontext.save_grads = oldcontext.save_grads
+
+     is_recompute = False
+
+     def _f(curcontext: HookContext, *xs):
+         initcontext = get_context()
+         nonlocal is_recompute
+
+         set_context(curcontext)
+         try:
+             res = f(*xs)
+
+             if not is_recompute and oldcontext.curcontext is not None:
+                 oldcontext.curcontext |= curcontext.curcontext
+         finally:
+             set_context(initcontext)
+             is_recompute = True
+         return res
+
+     res = torch.utils.checkpoint.checkpoint(
+         partial(_f, curcontext), *xs, use_reentrant=use_reentrant
+     )
+
+     return res
+
+
+ context = HookContext()
+
+
+ def hook_recorder(*a, **k):
+     return get_context().hook_recorder(*a, **k)
+
+
+ def hook_namespace(*a, **k):
+     return get_context().hook_namespace(*a, **k)
+
+
+ def hook_save(*a, **k):
+     return get_context().hook_save(*a, **k)
+
+
+ def hook_intervention_transform(*a, **k):
+     return get_context().hook_intervention_transform(*a, **k)
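A quick usage sketch for these helpers, assuming the file is importable as `hook_utils`; the module, tensor name, and namespace below are made up for illustration:

import torch
from torch import nn

from hook_utils import hook_namespace, hook_recorder, hook_save


class TinyBlock(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, d)

    def forward(self, x):
        # Tag the pre-activation so callers can record it (or intervene on it).
        pre = hook_save("pre", self.fc(x))
        return torch.relu(pre)


block = TinyBlock(4)
x = torch.randn(2, 4)

# Record every tensor saved under the "blk." namespace.
with hook_recorder(regex=r"blk\..*") as recorded:
    with hook_namespace("blk"):
        y = block(x)

print(sorted(recorded))  # ['blk.pre']

Interventions work the same way: passing `interventions={"blk.pre": lambda t: t * 0}` to `hook_recorder` replaces the saved tensor before it flows onward.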
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7a63c67d78a1e8da5fc3ead1338fe0f18994d48deedeb4374a6b3ad1d350724
+ size 1676511896
modeling_circuitgpt.py ADDED
@@ -0,0 +1,127 @@
+ from __future__ import annotations
+
+ import re
+ from dataclasses import dataclass
+ from typing import Sequence
+
+ import torch
+ from torch import nn
+ from transformers.generation.utils import GenerationMixin
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils.generic import ModelOutput
+
+ from .config import CircuitGPTConfig
+ from .gpt import GPT
+ from .hook_utils import hook_recorder
+
+
+ @dataclass
+ class CircuitGPTCausalLMOutput(ModelOutput):
+     loss: torch.Tensor | None = None
+     logits: torch.Tensor | None = None
+     activations: dict[str, torch.Tensor] | None = None
+
+
+ def _activations_regex(keys: Sequence[str]) -> str:
+     escaped = (re.escape(k) for k in keys)
+     return "^(" + "|".join(escaped) + ")$"
+
+
+ class CircuitGPTPreTrainedModel(PreTrainedModel):
+     config_class = CircuitGPTConfig
+     base_model_prefix = "circuit_model"
+     circuit_model: GPT
+
+     def __init__(self, config: CircuitGPTConfig, *inputs, **kwargs) -> None:
+         super().__init__(config, *inputs, **kwargs)
+
+     def get_input_embeddings(self) -> nn.Module:
+         return self.circuit_model.transformer.wte  # type: ignore[return-value]
+
+     def set_input_embeddings(self, value: nn.Module) -> None:
+         self.circuit_model.transformer.wte = value  # type: ignore[assignment]
+
+     def get_output_embeddings(self) -> nn.Module:
+         return self.circuit_model.lm_head  # type: ignore[return-value]
+
+     def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
+         self.circuit_model.lm_head = new_embeddings  # type: ignore[assignment]
+
+
+ class CircuitGPTForCausalLM(CircuitGPTPreTrainedModel, GenerationMixin):
+     """
+     Hugging Face-compatible wrapper around `circuit_sparsity.gpt.GPT`.
+     All math happens inside the original module so parity is guaranteed.
+     """
+
+     def __init__(self, config: CircuitGPTConfig, circuit_model: GPT | None = None) -> None:
+         super().__init__(config)
+
+         if circuit_model is None:
+             self.circuit_model = GPT(config.to_circuit_config())
+             self.post_init()
+         else:
+             self.circuit_model = circuit_model
+
+     # ------------------------------------------------------------------
+     # Constructors
+     # ------------------------------------------------------------------
+     @classmethod
+     def from_circuit_model(cls, circuit_model: GPT) -> "CircuitGPTForCausalLM":
+         config = CircuitGPTConfig.from_circuit_config(circuit_model.config)
+         return cls(config, circuit_model=circuit_model)
+
+     # ------------------------------------------------------------------
+     # Forward
+     # ------------------------------------------------------------------
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: torch.LongTensor | None = None,
+         output_activations: Sequence[str] | None = None,
+         return_dict: bool | None = None,
+         use_cache: bool | None = None,
+         output_attentions: bool | None = None,
+         output_hidden_states: bool | None = None,
+         **kwargs,
+     ) -> CircuitGPTCausalLMOutput:
+         # Ignore HF generation kwargs we don't use; surface any unknowns.
+         remaining_kwargs = {k: v for k, v in kwargs.items() if v is not None}
+         if remaining_kwargs:
+             unsupported = ", ".join(remaining_kwargs.keys())
+             raise ValueError(f"Unsupported arguments for CircuitGPTForCausalLM: {unsupported}")
+
+         if input_ids.size(-1) > self.config.block_size:
+             raise ValueError(
+                 f"Sequence length {input_ids.size(-1)} exceeds block size {self.config.block_size}"
+             )
+
+         if output_activations:
+             regex = _activations_regex(output_activations)
+             with hook_recorder(regex=regex) as recorded:
+                 logits, loss, _ = self.circuit_model(input_ids, targets=labels)
+             activations = {key: recorded[key] for key in output_activations if key in recorded}
+         else:
+             activations = None
+             logits, loss, _ = self.circuit_model(input_ids, targets=labels)
+
+         if labels is None:
+             loss = None
+
+         return CircuitGPTCausalLMOutput(
+             loss=loss,
+             logits=logits,
+             activations=activations,
+         )
+
+     # ------------------------------------------------------------------
+     # Generation helpers
+     # ------------------------------------------------------------------
+     def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
+         if input_ids.size(-1) > self.config.block_size:
+             input_ids = input_ids[:, -self.config.block_size :]
+         return {"input_ids": input_ids}
+
+     def _reorder_cache(self, past, beam_idx):
+         # No KV cache implemented; method exists for interface completeness.
+         return past
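For reference, a sketch of how the wrapper might be driven once the files in this commit are on disk. The checkpoint path and the activation key are placeholders (real keys depend on what the underlying `GPT` saves via `hook_save`), and auto-loading assumes the checkpoint's config registers this class for remote code; otherwise import `CircuitGPTForCausalLM` directly:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "path/to/this/checkpoint"  # placeholder path
tok = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

ids = tok("hello world", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(ids, output_activations=["transformer.h.0.mlp"])  # hypothetical hook key

print(out.logits.shape)
print(list(out.activations) if out.activations else None)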
special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "eos_token": "<|endoftext|>"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b42d53ab8a66da8fc53288442b1239e373697d630ad61fa66dc35350274617c7
+ size 24366
tokenizer_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 256,
+   "padding_side": "left",
+   "tokenizer_class": "PreTrainedTokenizerFast"
+ }