Upload folder using huggingface_hub
Browse files- LICENSE.MD +54 -0
- NOTICE.MD +13 -0
- README.MD +46 -0
- __init__.py +11 -0
- config.json +44 -0
- config.py +222 -0
- generation_config.json +5 -0
- gpt.py +900 -0
- hook_utils.py +182 -0
- model.safetensors +3 -0
- modeling_circuitgpt.py +127 -0
- special_tokens_map.json +3 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +6 -0
LICENSE.MD
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 2 |
+
|
| 3 |
+
1. Definitions.
|
| 4 |
+
“License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
| 5 |
+
|
| 6 |
+
“Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
| 7 |
+
|
| 8 |
+
“Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
| 9 |
+
|
| 10 |
+
“You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
|
| 11 |
+
|
| 12 |
+
“Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
| 13 |
+
|
| 14 |
+
“Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
| 15 |
+
|
| 16 |
+
“Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
| 17 |
+
|
| 18 |
+
“Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
| 19 |
+
|
| 20 |
+
“Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
|
| 21 |
+
|
| 22 |
+
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
| 23 |
+
|
| 24 |
+
2. Grant of Copyright License.
|
| 25 |
+
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
| 26 |
+
|
| 27 |
+
3. Grant of Patent License.
|
| 28 |
+
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
| 29 |
+
|
| 30 |
+
4. Redistribution.
|
| 31 |
+
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
| 32 |
+
|
| 33 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
| 34 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
| 35 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
| 36 |
+
If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
| 37 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
| 38 |
+
|
| 39 |
+
5. Submission of Contributions.
|
| 40 |
+
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
| 41 |
+
|
| 42 |
+
6. Trademarks.
|
| 43 |
+
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
| 44 |
+
|
| 45 |
+
7. Disclaimer of Warranty.
|
| 46 |
+
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
| 47 |
+
|
| 48 |
+
8. Limitation of Liability.
|
| 49 |
+
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
| 50 |
+
|
| 51 |
+
9. Accepting Warranty or Additional Liability.
|
| 52 |
+
While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
| 53 |
+
|
| 54 |
+
END OF TERMS AND CONDITIONS
|
NOTICE.MD
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright 2025 OpenAI
|
| 2 |
+
|
| 3 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
you may not use this file except in compliance with the License.
|
| 5 |
+
You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
Distributed under the License is distributed on an “AS IS” BASIS,
|
| 11 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
See the License for the specific language governing permissions and
|
| 13 |
+
Limitations under the License.
|
README.MD
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
## Sparse Model from Gao et al. 2025
|
| 3 |
+
|
| 4 |
+
Weights for a sparse model from Gao et al. 2025, used for the qualitative results from the paper (related to bracket counting and variable binding). All weights for the other models used in the paper, as well as lightweight inference code, are present in https://github.com/openai/circuit_sparsity. In the context of that repo, this model is csp_yolo2.
|
| 5 |
+
|
| 6 |
+
This is a runnable standalone huggingface implementation for one of the models.
|
| 7 |
+
|
| 8 |
+
Some trivial code to load the locally converted HF model + tokenizer and
|
| 9 |
+
run a tiny generation.
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
|
| 16 |
+
PROMPT = "def square_sum(xs):\n return sum(x * x for x in xs)\n\nsquare_sum([1, 2, 3])\n"
|
| 17 |
+
tok = AutoTokenizer.from_pretrained("circuit-sparsity", trust_remote_code=True)
|
| 18 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 19 |
+
"circuit-sparsity",
|
| 20 |
+
trust_remote_code=True,
|
| 21 |
+
torch_dtype="auto",
|
| 22 |
+
)
|
| 23 |
+
model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
inputs = tok(PROMPT, return_tensors="pt", add_special_tokens=False)["input_ids"].to(
|
| 25 |
+
model.device
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
out = model.generate(
|
| 30 |
+
inputs,
|
| 31 |
+
max_new_tokens=64,
|
| 32 |
+
do_sample=True,
|
| 33 |
+
temperature=0.8,
|
| 34 |
+
top_p=0.95,
|
| 35 |
+
return_dict_in_generate=False,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
print("=== Prompt ===")
|
| 39 |
+
print(PROMPT)
|
| 40 |
+
print("\n=== Generation ===")
|
| 41 |
+
print(tok.decode(out[0], skip_special_tokens=True))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
## License
|
| 46 |
+
This project is licensed under the [Apache License 2.0](LICENSE).
|
__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import CircuitGPTConfig
|
| 2 |
+
from .modeling_circuitgpt import CircuitGPTForCausalLM
|
| 3 |
+
from .tokenizer_simple2k import Simple2KTokenizerFast, export_simple2k_tokenizer
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"CircuitGPTConfig",
|
| 7 |
+
"CircuitGPTForCausalLM",
|
| 8 |
+
"Simple2KTokenizerFast",
|
| 9 |
+
"export_simple2k_tokenizer",
|
| 10 |
+
]
|
| 11 |
+
|
config.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation_type": "gelu",
|
| 3 |
+
"afrac": 0.25,
|
| 4 |
+
"afrac_loctypes": "attn_in,attn_out,mlp_in,mlp_out,mlp_neuron,attn_v,attn_k,attn_q",
|
| 5 |
+
"architectures": [
|
| 6 |
+
"CircuitGPTForCausalLM"
|
| 7 |
+
],
|
| 8 |
+
"auto_map": {
|
| 9 |
+
"AutoConfig": "config.CircuitGPTConfig",
|
| 10 |
+
"AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM"
|
| 11 |
+
},
|
| 12 |
+
"bias": true,
|
| 13 |
+
"bigram_table_rank": null,
|
| 14 |
+
"block_size": 1024,
|
| 15 |
+
"bos_token_id": null,
|
| 16 |
+
"d_head": 16,
|
| 17 |
+
"d_mlp": 8192,
|
| 18 |
+
"d_model": 2048,
|
| 19 |
+
"d_pos_emb": 32,
|
| 20 |
+
"dropout": 0.0,
|
| 21 |
+
"dropout_cat_pos_emb": false,
|
| 22 |
+
"enable_bigram_table": true,
|
| 23 |
+
"eos_token_id": 2047,
|
| 24 |
+
"flash": true,
|
| 25 |
+
"is_decoder": true,
|
| 26 |
+
"learnable_bigram_table": true,
|
| 27 |
+
"ln_bias": true,
|
| 28 |
+
"max_position_embeddings": 1024,
|
| 29 |
+
"model_type": "circuitgpt",
|
| 30 |
+
"n_head": 128,
|
| 31 |
+
"n_layer": 8,
|
| 32 |
+
"pad_token_id": null,
|
| 33 |
+
"residual_activation_type": "identity",
|
| 34 |
+
"rms_norm": true,
|
| 35 |
+
"sink": true,
|
| 36 |
+
"sinusoidal_cat_pos_emb": false,
|
| 37 |
+
"tie_word_embeddings": false,
|
| 38 |
+
"tied_unembed": false,
|
| 39 |
+
"torch_dtype": "float32",
|
| 40 |
+
"transformers_version": "4.49.0",
|
| 41 |
+
"unembed_rank": null,
|
| 42 |
+
"use_position_embeddings": true,
|
| 43 |
+
"vocab_size": 2048
|
| 44 |
+
}
|
config.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
from transformers import PretrainedConfig
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CircuitGPTConfig(PretrainedConfig):
|
| 9 |
+
"""
|
| 10 |
+
Minimal Hugging Face config wrapper around the circuit_sparsity GPTConfig.
|
| 11 |
+
Only the fields exercised by the Neuronpedia runs are exposed.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
model_type = "circuitgpt"
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
vocab_size: int = 2048,
|
| 19 |
+
block_size: int = 256,
|
| 20 |
+
n_layer: int = 8,
|
| 21 |
+
n_head: int = 8,
|
| 22 |
+
d_model: int = 1024,
|
| 23 |
+
d_mlp: int | None = None,
|
| 24 |
+
d_head: int | None = None,
|
| 25 |
+
dropout: float = 0.0,
|
| 26 |
+
bias: bool = True,
|
| 27 |
+
ln_bias: bool = True,
|
| 28 |
+
rms_norm: bool = True,
|
| 29 |
+
activation_type: str = "gelu",
|
| 30 |
+
residual_activation_type: str = "identity",
|
| 31 |
+
tied_unembed: bool = False,
|
| 32 |
+
unembed_rank: int | None = None,
|
| 33 |
+
afrac: float | None = None,
|
| 34 |
+
afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out",
|
| 35 |
+
flash: bool = True,
|
| 36 |
+
use_position_embeddings: bool = False,
|
| 37 |
+
sink: bool = False,
|
| 38 |
+
enable_bigram_table: bool = False,
|
| 39 |
+
learnable_bigram_table: bool = False,
|
| 40 |
+
bigram_table_rank: int | None = None,
|
| 41 |
+
dropout_cat_pos_emb: bool = False,
|
| 42 |
+
sinusoidal_cat_pos_emb: bool = False,
|
| 43 |
+
d_pos_emb: int | None = None,
|
| 44 |
+
auto_map: dict[str, str] | None = None,
|
| 45 |
+
**kwargs: Any,
|
| 46 |
+
) -> None:
|
| 47 |
+
# Drop unsupported/sensitive keys that may be present in a loaded config.
|
| 48 |
+
for key in [
|
| 49 |
+
"afrac_ste",
|
| 50 |
+
"afrac_ste_only_non_neurons",
|
| 51 |
+
"afrac_approx",
|
| 52 |
+
"rtopk",
|
| 53 |
+
"mup",
|
| 54 |
+
"mup_width_multiplier",
|
| 55 |
+
"grad_checkpointing",
|
| 56 |
+
"enable_fp8_linear",
|
| 57 |
+
"scale_invariance",
|
| 58 |
+
"cat_pos_emb",
|
| 59 |
+
]:
|
| 60 |
+
kwargs.pop(key, None)
|
| 61 |
+
d_mlp = d_mlp or 4 * d_model
|
| 62 |
+
d_head = d_head or d_model // n_head
|
| 63 |
+
|
| 64 |
+
# Avoid duplicate kwargs when loading from a config dict.
|
| 65 |
+
bos_token_id = kwargs.pop("bos_token_id", None)
|
| 66 |
+
eos_token_id = kwargs.pop("eos_token_id", vocab_size - 1)
|
| 67 |
+
pad_token_id = kwargs.pop("pad_token_id", None)
|
| 68 |
+
|
| 69 |
+
super().__init__(
|
| 70 |
+
bos_token_id=bos_token_id,
|
| 71 |
+
eos_token_id=eos_token_id,
|
| 72 |
+
pad_token_id=pad_token_id,
|
| 73 |
+
**kwargs,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
self.vocab_size = vocab_size
|
| 77 |
+
self.block_size = block_size
|
| 78 |
+
self.max_position_embeddings = block_size
|
| 79 |
+
self.n_layer = n_layer
|
| 80 |
+
self.n_head = n_head
|
| 81 |
+
self.d_model = d_model
|
| 82 |
+
self.d_mlp = d_mlp
|
| 83 |
+
self.d_head = d_head
|
| 84 |
+
self.dropout = dropout
|
| 85 |
+
self.bias = bias
|
| 86 |
+
self.ln_bias = ln_bias
|
| 87 |
+
self.rms_norm = rms_norm
|
| 88 |
+
self.activation_type = activation_type
|
| 89 |
+
self.residual_activation_type = residual_activation_type
|
| 90 |
+
self.tied_unembed = tied_unembed
|
| 91 |
+
self.unembed_rank = unembed_rank
|
| 92 |
+
self.afrac = afrac
|
| 93 |
+
self.afrac_loctypes = afrac_loctypes
|
| 94 |
+
self.flash = flash
|
| 95 |
+
self.use_position_embeddings = use_position_embeddings
|
| 96 |
+
self.d_pos_emb = d_pos_emb
|
| 97 |
+
self.sink = sink
|
| 98 |
+
self.enable_bigram_table = enable_bigram_table
|
| 99 |
+
self.learnable_bigram_table = learnable_bigram_table
|
| 100 |
+
self.bigram_table_rank = bigram_table_rank
|
| 101 |
+
self.dropout_cat_pos_emb = dropout_cat_pos_emb
|
| 102 |
+
self.sinusoidal_cat_pos_emb = sinusoidal_cat_pos_emb
|
| 103 |
+
self.is_decoder = True
|
| 104 |
+
# Provide explicit auto_map entries so AutoModel/AutoConfig can locate
|
| 105 |
+
# the custom classes when trust_remote_code=True on the Hub.
|
| 106 |
+
self.auto_map = auto_map or {
|
| 107 |
+
"AutoConfig": "config.CircuitGPTConfig",
|
| 108 |
+
"AutoModelForCausalLM": "modeling_circuitgpt.CircuitGPTForCausalLM",
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------
|
| 112 |
+
# Conversion helpers
|
| 113 |
+
# ---------------------------------------------------------------------
|
| 114 |
+
@classmethod
|
| 115 |
+
def from_circuit_config(cls, circuit_config: "GPTConfig") -> "CircuitGPTConfig": # type: ignore[name-defined]
|
| 116 |
+
config_dict: dict[str, Any] = {
|
| 117 |
+
"vocab_size": circuit_config.vocab_size,
|
| 118 |
+
"block_size": circuit_config.block_size,
|
| 119 |
+
"n_layer": circuit_config.n_layer,
|
| 120 |
+
"n_head": circuit_config.n_head,
|
| 121 |
+
"d_model": circuit_config.d_model,
|
| 122 |
+
"d_mlp": circuit_config.d_mlp,
|
| 123 |
+
"d_head": circuit_config.d_head,
|
| 124 |
+
"dropout": circuit_config.dropout,
|
| 125 |
+
"bias": circuit_config.bias,
|
| 126 |
+
"ln_bias": circuit_config.ln_bias,
|
| 127 |
+
"rms_norm": circuit_config.rms_norm,
|
| 128 |
+
"activation_type": circuit_config.activation_type,
|
| 129 |
+
"residual_activation_type": circuit_config.residual_activation_type,
|
| 130 |
+
"tied_unembed": circuit_config.tied_unembed,
|
| 131 |
+
"unembed_rank": circuit_config.unembed_rank,
|
| 132 |
+
"afrac": circuit_config.afrac,
|
| 133 |
+
"afrac_loctypes": circuit_config.afrac_loctypes,
|
| 134 |
+
"flash": circuit_config.flash,
|
| 135 |
+
"use_position_embeddings": circuit_config.d_pos_emb is not None,
|
| 136 |
+
"d_pos_emb": getattr(circuit_config, "d_pos_emb", None),
|
| 137 |
+
"sink": getattr(circuit_config, "sink", False),
|
| 138 |
+
"enable_bigram_table": getattr(circuit_config, "enable_bigram_table", False),
|
| 139 |
+
"learnable_bigram_table": getattr(circuit_config, "learnable_bigram_table", False),
|
| 140 |
+
"bigram_table_rank": getattr(circuit_config, "bigram_table_rank", None),
|
| 141 |
+
"dropout_cat_pos_emb": getattr(circuit_config, "dropout_cat_pos_emb", False),
|
| 142 |
+
"sinusoidal_cat_pos_emb": getattr(circuit_config, "sinusoidal_cat_pos_emb", False),
|
| 143 |
+
}
|
| 144 |
+
return cls(**config_dict)
|
| 145 |
+
|
| 146 |
+
def to_circuit_config(self) -> "GPTConfig": # type: ignore[name-defined]
|
| 147 |
+
from circuit_sparsity.gpt import GPTConfig as CircuitConfig
|
| 148 |
+
|
| 149 |
+
config_kwargs: dict[str, Any] = dict(
|
| 150 |
+
vocab_size=self.vocab_size,
|
| 151 |
+
block_size=self.block_size,
|
| 152 |
+
n_layer=self.n_layer,
|
| 153 |
+
n_head=self.n_head,
|
| 154 |
+
d_model=self.d_model,
|
| 155 |
+
dropout=self.dropout,
|
| 156 |
+
bias=self.bias,
|
| 157 |
+
ln_bias=self.ln_bias,
|
| 158 |
+
rms_norm=self.rms_norm,
|
| 159 |
+
activation_type=self.activation_type,
|
| 160 |
+
residual_activation_type=self.residual_activation_type,
|
| 161 |
+
tied_unembed=self.tied_unembed,
|
| 162 |
+
unembed_rank=self.unembed_rank,
|
| 163 |
+
afrac=self.afrac,
|
| 164 |
+
afrac_loctypes=self.afrac_loctypes,
|
| 165 |
+
flash=self.flash,
|
| 166 |
+
afrac_ste=False,
|
| 167 |
+
afrac_ste_only_non_neurons=False,
|
| 168 |
+
afrac_approx=False,
|
| 169 |
+
rtopk=False,
|
| 170 |
+
mup=False,
|
| 171 |
+
mup_width_multiplier=None,
|
| 172 |
+
grad_checkpointing=False,
|
| 173 |
+
enable_fp8_linear=False,
|
| 174 |
+
scale_invariance=False,
|
| 175 |
+
d_mlp=self.d_mlp,
|
| 176 |
+
d_head=self.d_head,
|
| 177 |
+
enable_sparse_kernels=False,
|
| 178 |
+
enable_bigram_table=self.enable_bigram_table,
|
| 179 |
+
learnable_bigram_table=self.learnable_bigram_table,
|
| 180 |
+
bigram_table_rank=self.bigram_table_rank,
|
| 181 |
+
d_pos_emb=self.d_pos_emb
|
| 182 |
+
if self.d_pos_emb is not None
|
| 183 |
+
else (self.d_model if self.use_position_embeddings else None),
|
| 184 |
+
sink=self.sink,
|
| 185 |
+
dropout_cat_pos_emb=self.dropout_cat_pos_emb,
|
| 186 |
+
sinusoidal_cat_pos_emb=self.sinusoidal_cat_pos_emb,
|
| 187 |
+
)
|
| 188 |
+
return CircuitConfig(**config_kwargs)
|
| 189 |
+
|
| 190 |
+
def to_dict(self) -> dict[str, Any]:
|
| 191 |
+
base = super().to_dict()
|
| 192 |
+
data = {
|
| 193 |
+
"vocab_size": self.vocab_size,
|
| 194 |
+
"block_size": self.block_size,
|
| 195 |
+
"n_layer": self.n_layer,
|
| 196 |
+
"n_head": self.n_head,
|
| 197 |
+
"d_model": self.d_model,
|
| 198 |
+
"d_mlp": self.d_mlp,
|
| 199 |
+
"d_head": self.d_head,
|
| 200 |
+
"dropout": self.dropout,
|
| 201 |
+
"bias": self.bias,
|
| 202 |
+
"ln_bias": self.ln_bias,
|
| 203 |
+
"rms_norm": self.rms_norm,
|
| 204 |
+
"activation_type": self.activation_type,
|
| 205 |
+
"residual_activation_type": self.residual_activation_type,
|
| 206 |
+
"tied_unembed": self.tied_unembed,
|
| 207 |
+
"unembed_rank": self.unembed_rank,
|
| 208 |
+
"flash": self.flash,
|
| 209 |
+
"afrac": self.afrac,
|
| 210 |
+
"afrac_loctypes": self.afrac_loctypes,
|
| 211 |
+
"use_position_embeddings": self.use_position_embeddings,
|
| 212 |
+
"d_pos_emb": self.d_pos_emb,
|
| 213 |
+
"sink": self.sink,
|
| 214 |
+
"enable_bigram_table": self.enable_bigram_table,
|
| 215 |
+
"learnable_bigram_table": self.learnable_bigram_table,
|
| 216 |
+
"bigram_table_rank": self.bigram_table_rank,
|
| 217 |
+
"dropout_cat_pos_emb": self.dropout_cat_pos_emb,
|
| 218 |
+
"sinusoidal_cat_pos_emb": self.sinusoidal_cat_pos_emb,
|
| 219 |
+
"auto_map": self.auto_map,
|
| 220 |
+
}
|
| 221 |
+
base.update(data)
|
| 222 |
+
return base
|
generation_config.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"eos_token_id": 2047,
|
| 4 |
+
"transformers_version": "4.49.0"
|
| 5 |
+
}
|
gpt.py
ADDED
|
@@ -0,0 +1,900 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
| 3 |
+
References:
|
| 4 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
| 5 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
| 6 |
+
2) huggingface/transformers PyTorch implementation:
|
| 7 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import Literal
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
# has to be down here to avoid loading cuda too early
|
| 19 |
+
from .hook_utils import (
|
| 20 |
+
hook_namespace,
|
| 21 |
+
hook_save,
|
| 22 |
+
torch_recompute_preserving_hook_context,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def sample_top_k(*, n: int, k: int, shape: tuple[int, ...]):
|
| 27 |
+
"""Fallback sampler used only when sparse kernels are enabled."""
|
| 28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
return torch.randn(shape, device=device, dtype=torch.float32)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AbsTopK(nn.Module):
|
| 33 |
+
def __init__(self, k):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.k = k
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
vals, inds = torch.topk(x.abs(), self.k, dim=-1, sorted=False)
|
| 39 |
+
ret = torch.zeros_like(x)
|
| 40 |
+
ret.scatter_(-1, inds, x.gather(-1, inds))
|
| 41 |
+
return ret
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def barrier():
|
| 45 |
+
# stub
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LayerNorm(nn.Module):
|
| 50 |
+
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, ndim, bias):
|
| 53 |
+
super().__init__()
|
| 54 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
| 55 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
| 56 |
+
|
| 57 |
+
def forward(self, input):
|
| 58 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class CausalSelfAttention(nn.Module):
|
| 62 |
+
def __init__(self, config):
|
| 63 |
+
super().__init__()
|
| 64 |
+
assert config.d_model % config.n_head == 0
|
| 65 |
+
# key, query, value projections for all heads, but in a batch
|
| 66 |
+
self.c_attn = config.Linear(
|
| 67 |
+
config.d_model, 3 * config.d_head * config.n_head, bias=config.bias
|
| 68 |
+
)
|
| 69 |
+
# output projection
|
| 70 |
+
self.c_proj = config.Linear(config.d_head * config.n_head, config.d_model, bias=config.bias)
|
| 71 |
+
# regularization
|
| 72 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 73 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 74 |
+
self.n_head = config.n_head
|
| 75 |
+
self.d_head = config.d_head
|
| 76 |
+
self.d_model = config.d_model
|
| 77 |
+
self.dropout = config.dropout
|
| 78 |
+
|
| 79 |
+
self.config = config
|
| 80 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
|
| 81 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash
|
| 82 |
+
|
| 83 |
+
if self.flash:
|
| 84 |
+
self.attn_imp = (
|
| 85 |
+
SDPAWithSink(config.n_head) if config.sink else F.scaled_dot_product_attention
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if not self.flash:
|
| 89 |
+
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
| 90 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
| 91 |
+
self.register_buffer(
|
| 92 |
+
"bias",
|
| 93 |
+
torch.tril(torch.ones(config.block_size, config.block_size)).view(
|
| 94 |
+
1, 1, config.block_size, config.block_size
|
| 95 |
+
),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def forward(self, x):
|
| 99 |
+
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (d_model)
|
| 100 |
+
|
| 101 |
+
x = self.config.maybe_activation_sparsity(x, "attn_in")
|
| 102 |
+
x = hook_save("act_in", x)
|
| 103 |
+
|
| 104 |
+
if self.config.debug_nans:
|
| 105 |
+
assert x.isfinite().all(), "nan in input"
|
| 106 |
+
|
| 107 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
| 108 |
+
q, k, v = self.c_attn(x).split(self.n_head * self.d_head, dim=2)
|
| 109 |
+
|
| 110 |
+
k = self.config.maybe_activation_sparsity(k, "attn_k")
|
| 111 |
+
q = self.config.maybe_activation_sparsity(q, "attn_q")
|
| 112 |
+
v = self.config.maybe_activation_sparsity(v, "attn_v")
|
| 113 |
+
|
| 114 |
+
k = hook_save("k", k) # (B, T, n_head * d_head)
|
| 115 |
+
q = hook_save("q", q) # (B, T, n_head * d_head)
|
| 116 |
+
v = hook_save("v", v) # (B, T, n_head * d_head)
|
| 117 |
+
|
| 118 |
+
k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B, nh, T, hs)
|
| 119 |
+
q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B, nh, T, hs)
|
| 120 |
+
v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2) # (B, nh, T, hs)
|
| 121 |
+
|
| 122 |
+
if self.config.debug_nans:
|
| 123 |
+
assert q.isfinite().all(), "nan in query"
|
| 124 |
+
assert k.isfinite().all(), "nan in key"
|
| 125 |
+
assert v.isfinite().all(), "nan in value"
|
| 126 |
+
|
| 127 |
+
attention_scale = 1.0 / math.sqrt(k.size(-1))
|
| 128 |
+
|
| 129 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
| 130 |
+
if self.flash:
|
| 131 |
+
# efficient attention using Flash Attention CUDA kernels
|
| 132 |
+
y = self.attn_imp(
|
| 133 |
+
q,
|
| 134 |
+
k,
|
| 135 |
+
v,
|
| 136 |
+
dropout_p=self.dropout if self.training else 0,
|
| 137 |
+
is_causal=True,
|
| 138 |
+
scale=attention_scale,
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
# manual implementation of attention
|
| 142 |
+
att = (q @ k.transpose(-2, -1)) * attention_scale
|
| 143 |
+
att = att.masked_fill(
|
| 144 |
+
self.bias[:, :, :T, :T] == 0, torch.finfo(att.dtype).min
|
| 145 |
+
) # float("-inf"))
|
| 146 |
+
|
| 147 |
+
att = F.softmax(att, dim=-1)
|
| 148 |
+
att = self.attn_dropout(att)
|
| 149 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 150 |
+
|
| 151 |
+
if self.config.debug_nans:
|
| 152 |
+
assert y.isfinite().all(), "nan in attention output"
|
| 153 |
+
|
| 154 |
+
y = (
|
| 155 |
+
y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.d_head)
|
| 156 |
+
) # re-assemble all head outputs side by side
|
| 157 |
+
|
| 158 |
+
# y = self.config.maybe_activation_sparsity(y)
|
| 159 |
+
y = hook_save("y", y) # (B, T, n_head * d_head)
|
| 160 |
+
|
| 161 |
+
# output projection
|
| 162 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 163 |
+
|
| 164 |
+
if self.config.debug_nans:
|
| 165 |
+
assert y.isfinite().all(), "nan in attention output 2"
|
| 166 |
+
|
| 167 |
+
y = self.config.maybe_activation_sparsity(y, "attn_out")
|
| 168 |
+
return y
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class MLP(nn.Module):
|
| 172 |
+
def __init__(self, config):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.config = config
|
| 175 |
+
self.c_fc = config.Linear(config.d_model, config.d_mlp, bias=config.bias)
|
| 176 |
+
self.act_fn = {
|
| 177 |
+
"gelu": nn.GELU(),
|
| 178 |
+
"relu": nn.ReLU(),
|
| 179 |
+
}[config.activation_type]
|
| 180 |
+
self.c_proj = config.Linear(config.d_mlp, config.d_model, bias=config.bias)
|
| 181 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 182 |
+
|
| 183 |
+
def forward(self, x):
|
| 184 |
+
x = self.config.maybe_activation_sparsity(x, "mlp_in")
|
| 185 |
+
x = hook_save("act_in", x)
|
| 186 |
+
|
| 187 |
+
if self.config.debug_nans:
|
| 188 |
+
assert x.isfinite().all(), "nan in mlp input"
|
| 189 |
+
|
| 190 |
+
x = self.c_fc(x)
|
| 191 |
+
|
| 192 |
+
if self.config.debug_nans:
|
| 193 |
+
assert x.isfinite().all(), "nan in mlp after c_fc"
|
| 194 |
+
|
| 195 |
+
x = self.act_fn(x)
|
| 196 |
+
x = self.config.maybe_activation_sparsity(x, "mlp_neuron")
|
| 197 |
+
x = hook_save("post_act", x)
|
| 198 |
+
|
| 199 |
+
if self.config.debug_nans:
|
| 200 |
+
assert x.isfinite().all(), "nan in mlp after act"
|
| 201 |
+
|
| 202 |
+
x = self.c_proj(x)
|
| 203 |
+
|
| 204 |
+
if self.config.debug_nans:
|
| 205 |
+
assert x.isfinite().all(), "nan in mlp after c_proj"
|
| 206 |
+
x = self.dropout(x)
|
| 207 |
+
|
| 208 |
+
x = self.config.maybe_activation_sparsity(x, "mlp_out")
|
| 209 |
+
return x
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class SDPAWithSink(nn.Module):
|
| 213 |
+
"""
|
| 214 |
+
Adds a learnable denominator-only term ("attention sink") to SDPA by
|
| 215 |
+
concatenating a dummy KV slot whose logit is b and whose V is zero.
|
| 216 |
+
"""
|
| 217 |
+
|
| 218 |
+
def __init__(self, num_heads: int, init_logit: float = 0.0):
|
| 219 |
+
super().__init__()
|
| 220 |
+
shape = (num_heads,)
|
| 221 |
+
self.sink_logit = nn.Parameter(torch.full(shape, init_logit))
|
| 222 |
+
|
| 223 |
+
def forward(
|
| 224 |
+
self,
|
| 225 |
+
q: torch.Tensor, # (B, H, Lq, D)
|
| 226 |
+
k: torch.Tensor, # (B, H, Lk, D)
|
| 227 |
+
v: torch.Tensor, # (B, H, Lk, Dv)
|
| 228 |
+
*,
|
| 229 |
+
dropout_p: float = 0.0,
|
| 230 |
+
is_causal: bool = False,
|
| 231 |
+
scale: float | None = None,
|
| 232 |
+
) -> torch.Tensor:
|
| 233 |
+
B, H, Lq, D = q.shape
|
| 234 |
+
_, _, Lk, _ = k.shape
|
| 235 |
+
Dv = v.size(-1)
|
| 236 |
+
|
| 237 |
+
# 1) Prepend a dummy KV slot (always visible)
|
| 238 |
+
k_sink = torch.zeros((B, H, 1, D), dtype=q.dtype, device=q.device)
|
| 239 |
+
v_sink = torch.zeros((B, H, 1, Dv), dtype=v.dtype, device=v.device)
|
| 240 |
+
k_aug = torch.cat([k_sink, k], dim=2) # (B,H,Lk+1,D)
|
| 241 |
+
v_aug = torch.cat([v_sink, v], dim=2) # (B,H,Lk+1,Dv)
|
| 242 |
+
|
| 243 |
+
# 2) Build shifted causal allow-mask over keys (columns 1..), always allow col 0 (sink)
|
| 244 |
+
# allow: 1 where attending is allowed, 0 where disallowed
|
| 245 |
+
# For real keys: allow[i, j+1] = 1 if j <= i else 0 (lower-triangular)
|
| 246 |
+
allow = torch.zeros((Lq, Lk + 1), dtype=torch.bool, device=q.device)
|
| 247 |
+
allow[:, 0] = True # sink column always on
|
| 248 |
+
# lower-triangular for real keys shifted by +1
|
| 249 |
+
real = torch.ones((Lq, Lk), dtype=torch.bool, device=q.device).tril()
|
| 250 |
+
allow[:, 1:] = real
|
| 251 |
+
|
| 252 |
+
# Broadcast to (B,H,Lq,Lk+1)
|
| 253 |
+
allow = allow.view(1, 1, Lq, Lk + 1).expand(B, H, Lq, Lk + 1)
|
| 254 |
+
|
| 255 |
+
# 3) Turn it into an additive mask. 0 for allowed, -inf for disallowed
|
| 256 |
+
neg_inf = torch.finfo(q.dtype).min
|
| 257 |
+
base_mask = torch.where(
|
| 258 |
+
allow,
|
| 259 |
+
torch.zeros((), dtype=q.dtype, device=q.device),
|
| 260 |
+
torch.full((), neg_inf, dtype=q.dtype, device=q.device),
|
| 261 |
+
) # (B,H,Lq,Lk+1)
|
| 262 |
+
|
| 263 |
+
# 4) Add learnable sink bias b to column 0 (per head or shared)
|
| 264 |
+
if self.sink_logit.numel() == H:
|
| 265 |
+
b = self.sink_logit.to(dtype=q.dtype, device=q.device).view(1, H, 1, 1) # (1,H,1,1)
|
| 266 |
+
else:
|
| 267 |
+
b = self.sink_logit.to(dtype=q.dtype, device=q.device).view(1, 1, 1, 1) # (1,1,1,1)
|
| 268 |
+
|
| 269 |
+
sink_bias_mask = torch.zeros((1, 1, 1, Lk + 1), dtype=q.dtype, device=q.device)
|
| 270 |
+
sink_bias_mask[..., 0] = 1.0
|
| 271 |
+
attn_mask = base_mask + sink_bias_mask * b # (B,H,Lq,Lk+1)
|
| 272 |
+
|
| 273 |
+
# 5) SDPA with our custom mask; keep is_causal=False to avoid double-masking
|
| 274 |
+
out = F.scaled_dot_product_attention(
|
| 275 |
+
q,
|
| 276 |
+
k_aug,
|
| 277 |
+
v_aug,
|
| 278 |
+
attn_mask=attn_mask,
|
| 279 |
+
dropout_p=dropout_p,
|
| 280 |
+
is_causal=False, # important
|
| 281 |
+
scale=scale,
|
| 282 |
+
)
|
| 283 |
+
return out
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
class Block(nn.Module):
|
| 287 |
+
# block exactly satisfies the invariant that forward = forward_mlp_block . forward_attn_block
|
| 288 |
+
def __init__(self, config):
|
| 289 |
+
super().__init__()
|
| 290 |
+
self.config = config
|
| 291 |
+
|
| 292 |
+
self.ln_1 = (
|
| 293 |
+
nn.RMSNorm(config.d_model)
|
| 294 |
+
if config.rms_norm
|
| 295 |
+
else LayerNorm(config.d_model, bias=config.ln_bias)
|
| 296 |
+
)
|
| 297 |
+
self.attn = CausalSelfAttention(config)
|
| 298 |
+
self.ln_2 = (
|
| 299 |
+
nn.RMSNorm(config.d_model)
|
| 300 |
+
if config.rms_norm
|
| 301 |
+
else LayerNorm(config.d_model, bias=config.ln_bias)
|
| 302 |
+
)
|
| 303 |
+
self.mlp = MLP(config)
|
| 304 |
+
|
| 305 |
+
def forward_attn_block(self, x):
|
| 306 |
+
x = hook_save("resid_in", x)
|
| 307 |
+
|
| 308 |
+
if self.config.debug_nans:
|
| 309 |
+
assert x.isfinite().all(), "nan in blk input"
|
| 310 |
+
|
| 311 |
+
with hook_namespace("attn"):
|
| 312 |
+
if self.config.grad_checkpointing:
|
| 313 |
+
x = x + hook_save(
|
| 314 |
+
"resid_delta",
|
| 315 |
+
torch_recompute_preserving_hook_context(
|
| 316 |
+
lambda x: self.attn(self.ln_1(x)), x, use_reentrant=False
|
| 317 |
+
),
|
| 318 |
+
)
|
| 319 |
+
else:
|
| 320 |
+
x = x + hook_save("resid_delta", self.attn(self.ln_1(x)))
|
| 321 |
+
|
| 322 |
+
if self.config.residual_activation_type == "relu":
|
| 323 |
+
x = torch.relu(x)
|
| 324 |
+
x = self.config.maybe_activation_sparsity(x, "resid_post_attn")
|
| 325 |
+
|
| 326 |
+
return x
|
| 327 |
+
|
| 328 |
+
def forward_mlp_block(self, x):
|
| 329 |
+
x = hook_save("resid_mid", x)
|
| 330 |
+
with hook_namespace("mlp"):
|
| 331 |
+
if self.config.grad_checkpointing:
|
| 332 |
+
x = x + hook_save(
|
| 333 |
+
"resid_delta",
|
| 334 |
+
torch_recompute_preserving_hook_context(
|
| 335 |
+
lambda x: self.mlp(self.ln_2(x)), x, use_reentrant=False
|
| 336 |
+
),
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
x = x + hook_save("resid_delta", self.mlp(self.ln_2(x)))
|
| 340 |
+
|
| 341 |
+
if self.config.residual_activation_type == "relu":
|
| 342 |
+
x = torch.relu(x)
|
| 343 |
+
x = self.config.maybe_activation_sparsity(x, "resid_post_mlp")
|
| 344 |
+
return x
|
| 345 |
+
|
| 346 |
+
def forward(self, x):
|
| 347 |
+
x = self.forward_attn_block(x)
|
| 348 |
+
x = self.forward_mlp_block(x)
|
| 349 |
+
return x
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class CausalSelfAttentionCatPosEmb(CausalSelfAttention):
|
| 353 |
+
def __init__(self, config):
|
| 354 |
+
# initialize base attention with standard shapes, we'll override projections
|
| 355 |
+
super().__init__(config)
|
| 356 |
+
assert config.d_model % config.n_head == 0
|
| 357 |
+
# key, query, value projections for all heads, but in a batch
|
| 358 |
+
self.c_attn = config.Linear(
|
| 359 |
+
config.d_model_in, 3 * config.d_head * config.n_head, bias=config.bias
|
| 360 |
+
)
|
| 361 |
+
# output projection
|
| 362 |
+
self.c_proj = config.Linear(config.d_head * config.n_head, config.d_model, bias=config.bias)
|
| 363 |
+
# regularization
|
| 364 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 365 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 366 |
+
self.n_head = config.n_head
|
| 367 |
+
self.d_head = config.d_head
|
| 368 |
+
self.d_model_in = config.d_model_in
|
| 369 |
+
self.d_model = config.d_model
|
| 370 |
+
self.dropout = config.dropout
|
| 371 |
+
self.config = config
|
| 372 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
|
| 373 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and config.flash
|
| 374 |
+
|
| 375 |
+
if not self.flash:
|
| 376 |
+
print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
|
| 377 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
| 378 |
+
self.register_buffer(
|
| 379 |
+
"bias",
|
| 380 |
+
torch.tril(torch.ones(config.block_size, config.block_size)).view(
|
| 381 |
+
1, 1, config.block_size, config.block_size
|
| 382 |
+
),
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
def forward(self, x, pos_emb_to_cat):
|
| 386 |
+
# Broadcast pos emb over batch if provided as shape [1, T, C]
|
| 387 |
+
if pos_emb_to_cat is not None and pos_emb_to_cat.size(0) == 1 and x.size(0) != 1:
|
| 388 |
+
pos_emb_to_cat = pos_emb_to_cat.expand(x.size(0), -1, -1)
|
| 389 |
+
x = torch.cat([x, pos_emb_to_cat], dim=-1)
|
| 390 |
+
return super().forward(x)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class MLPCatPosEmb(MLP):
|
| 394 |
+
def __init__(self, config):
|
| 395 |
+
# initialize base MLP, we'll override the projections to match cat shapes
|
| 396 |
+
super().__init__(config)
|
| 397 |
+
self.config = config
|
| 398 |
+
self.c_fc = config.Linear(config.d_model_in, config.d_mlp, bias=config.bias)
|
| 399 |
+
self.act_fn = {
|
| 400 |
+
"gelu": nn.GELU(),
|
| 401 |
+
"relu": nn.ReLU(),
|
| 402 |
+
}[config.activation_type]
|
| 403 |
+
self.c_proj = config.Linear(config.d_mlp, config.d_model, bias=config.bias)
|
| 404 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 405 |
+
|
| 406 |
+
def forward(self, x, pos_emb_to_cat):
|
| 407 |
+
# Broadcast pos emb over batch if provided as shape [1, T, C]
|
| 408 |
+
if pos_emb_to_cat is not None and pos_emb_to_cat.size(0) == 1 and x.size(0) != 1:
|
| 409 |
+
pos_emb_to_cat = pos_emb_to_cat.expand(x.size(0), -1, -1)
|
| 410 |
+
x = torch.cat([x, pos_emb_to_cat], dim=-1)
|
| 411 |
+
x = super().forward(x)
|
| 412 |
+
return x
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class BlockCatPosEmb(Block):
|
| 416 |
+
# block exactly satisfies the invariant that forward = forward_mlp_block . forward_attn_block
|
| 417 |
+
def __init__(self, config):
|
| 418 |
+
# initialize base Block to get ln_1/ln_2 and other invariants
|
| 419 |
+
super().__init__(config)
|
| 420 |
+
self.ln_p1 = (
|
| 421 |
+
nn.RMSNorm(config.d_pos_emb)
|
| 422 |
+
if config.rms_norm
|
| 423 |
+
else LayerNorm(config.d_pos_emb, bias=config.ln_bias)
|
| 424 |
+
)
|
| 425 |
+
self.ln_p2 = (
|
| 426 |
+
nn.RMSNorm(config.d_pos_emb)
|
| 427 |
+
if config.rms_norm
|
| 428 |
+
else LayerNorm(config.d_pos_emb, bias=config.ln_bias)
|
| 429 |
+
)
|
| 430 |
+
self.attn = CausalSelfAttentionCatPosEmb(config)
|
| 431 |
+
self.mlp = MLPCatPosEmb(config)
|
| 432 |
+
|
| 433 |
+
def forward_attn_block(self, x, p):
|
| 434 |
+
x = hook_save("resid_in", x)
|
| 435 |
+
|
| 436 |
+
if self.config.debug_nans:
|
| 437 |
+
assert x.isfinite().all(), "nan in blk input"
|
| 438 |
+
|
| 439 |
+
with hook_namespace("attn"):
|
| 440 |
+
if self.config.grad_checkpointing:
|
| 441 |
+
x = x + hook_save(
|
| 442 |
+
"resid_delta",
|
| 443 |
+
torch_recompute_preserving_hook_context(
|
| 444 |
+
lambda x, p: self.attn(self.ln_1(x), self.ln_p1(p)),
|
| 445 |
+
x,
|
| 446 |
+
p,
|
| 447 |
+
use_reentrant=False,
|
| 448 |
+
),
|
| 449 |
+
)
|
| 450 |
+
else:
|
| 451 |
+
x = x + hook_save("resid_delta", self.attn(self.ln_1(x), self.ln_p1(p)))
|
| 452 |
+
|
| 453 |
+
if self.config.residual_activation_type == "relu":
|
| 454 |
+
x = torch.relu(x)
|
| 455 |
+
x = self.config.maybe_activation_sparsity(x, "resid_post_attn")
|
| 456 |
+
|
| 457 |
+
return x
|
| 458 |
+
|
| 459 |
+
def forward_mlp_block(self, x, p):
|
| 460 |
+
x = hook_save("resid_mid", x)
|
| 461 |
+
with hook_namespace("mlp"):
|
| 462 |
+
if self.config.grad_checkpointing:
|
| 463 |
+
x = x + hook_save(
|
| 464 |
+
"resid_delta",
|
| 465 |
+
torch_recompute_preserving_hook_context(
|
| 466 |
+
lambda x, p: self.mlp(self.ln_2(x), self.ln_p2(p)),
|
| 467 |
+
x,
|
| 468 |
+
p,
|
| 469 |
+
use_reentrant=False,
|
| 470 |
+
),
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
x = x + hook_save("resid_delta", self.mlp(self.ln_2(x), self.ln_p2(p)))
|
| 474 |
+
|
| 475 |
+
if self.config.residual_activation_type == "relu":
|
| 476 |
+
x = torch.relu(x)
|
| 477 |
+
x = self.config.maybe_activation_sparsity(x, "resid_post_mlp")
|
| 478 |
+
return x
|
| 479 |
+
|
| 480 |
+
def forward(self, x, pos_emb_to_cat):
|
| 481 |
+
x = self.forward_attn_block(x, pos_emb_to_cat)
|
| 482 |
+
x = self.forward_mlp_block(x, pos_emb_to_cat)
|
| 483 |
+
return x
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
@dataclass
|
| 487 |
+
class GPTConfig:
|
| 488 |
+
block_size: int = 1024
|
| 489 |
+
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency # TODO: FLAG FOR ACHY
|
| 490 |
+
n_layer: int = 12
|
| 491 |
+
n_head: int = 12
|
| 492 |
+
d_head: int | None = None # defaults to d_model // n_head
|
| 493 |
+
d_model: int = 768
|
| 494 |
+
dropout: float = 0.0
|
| 495 |
+
bias: bool = (
|
| 496 |
+
True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
| 497 |
+
)
|
| 498 |
+
ln_bias: bool = (
|
| 499 |
+
True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
| 500 |
+
)
|
| 501 |
+
rms_norm: bool = False # use RMSNorm instead of LayerNorm
|
| 502 |
+
residual_activation_type: Literal["identity", "relu"] = "identity"
|
| 503 |
+
activation_type: Literal["gelu", "relu"] = "gelu"
|
| 504 |
+
afrac: float | None = None # fraction of activations to keep
|
| 505 |
+
afrac_loctypes: str = "attn_in,attn_out,mlp_in,mlp_out"
|
| 506 |
+
debug_nans: bool = False
|
| 507 |
+
tied_unembed: bool = True
|
| 508 |
+
|
| 509 |
+
tokenizer_name: str = "tinypython_2k"
|
| 510 |
+
|
| 511 |
+
grad_checkpointing: bool = True
|
| 512 |
+
d_mlp: int | None = None # multiplier for MLP hidden layer size
|
| 513 |
+
|
| 514 |
+
enable_bigram_table: bool = False
|
| 515 |
+
learnable_bigram_table: bool = False
|
| 516 |
+
d_pos_emb: int | None = None
|
| 517 |
+
dropout_cat_pos_emb: bool = False
|
| 518 |
+
sinusoidal_cat_pos_emb: bool = False
|
| 519 |
+
enable_sparse_kernels: bool = False
|
| 520 |
+
|
| 521 |
+
flash: bool = True
|
| 522 |
+
sink: bool = False
|
| 523 |
+
|
| 524 |
+
@property
|
| 525 |
+
def cat_pos_emb(self):
|
| 526 |
+
return self.d_pos_emb is not None
|
| 527 |
+
|
| 528 |
+
@property
|
| 529 |
+
def d_model_in(self):
|
| 530 |
+
return self.d_model + self.d_pos_emb if self.cat_pos_emb else self.d_model
|
| 531 |
+
|
| 532 |
+
def __post_init__(self):
|
| 533 |
+
assert self.d_model % self.n_head == 0
|
| 534 |
+
assert self.residual_activation_type in ["identity", "relu"]
|
| 535 |
+
assert self.activation_type in ["gelu", "relu"]
|
| 536 |
+
|
| 537 |
+
if self.d_mlp is None:
|
| 538 |
+
self.d_mlp = 4 * self.d_model
|
| 539 |
+
if self.d_head is None:
|
| 540 |
+
self.d_head = self.d_model // self.n_head
|
| 541 |
+
|
| 542 |
+
@property
|
| 543 |
+
def Linear(self):
|
| 544 |
+
return nn.Linear
|
| 545 |
+
|
| 546 |
+
def maybe_activation_sparsity(self, x, loctype):
|
| 547 |
+
if self.afrac is not None and loctype in self.afrac_loctypes.split(","):
|
| 548 |
+
|
| 549 |
+
def keep_abstopk(x, k):
|
| 550 |
+
ret = torch.zeros_like(x)
|
| 551 |
+
_, topk_inds = torch.topk(x.abs(), k, dim=-1, sorted=False)
|
| 552 |
+
ret.scatter_(-1, topk_inds, x.gather(-1, topk_inds))
|
| 553 |
+
return ret
|
| 554 |
+
|
| 555 |
+
x = keep_abstopk(
|
| 556 |
+
x,
|
| 557 |
+
k=int(self.afrac * x.shape[-1]),
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
return x
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
class GPT(nn.Module):
|
| 564 |
+
def __init__(self, config):
|
| 565 |
+
super().__init__()
|
| 566 |
+
assert config.vocab_size is not None
|
| 567 |
+
assert config.block_size is not None
|
| 568 |
+
self.config = config
|
| 569 |
+
|
| 570 |
+
if config.cat_pos_emb:
|
| 571 |
+
block_cls = BlockCatPosEmb
|
| 572 |
+
else:
|
| 573 |
+
block_cls = Block
|
| 574 |
+
|
| 575 |
+
self.transformer = nn.ModuleDict(
|
| 576 |
+
dict(
|
| 577 |
+
wte=nn.Embedding(config.vocab_size, config.d_model),
|
| 578 |
+
wpe=nn.Embedding(config.block_size, config.d_pos_emb or config.d_model),
|
| 579 |
+
drop=nn.Dropout(config.dropout),
|
| 580 |
+
h=nn.ModuleList([(block_cls(config)) for _ in range(config.n_layer)]),
|
| 581 |
+
ln_f=nn.RMSNorm(config.d_model)
|
| 582 |
+
if config.rms_norm
|
| 583 |
+
else LayerNorm(config.d_model, bias=config.ln_bias),
|
| 584 |
+
)
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
| 588 |
+
|
| 589 |
+
self.register_buffer(
|
| 590 |
+
"final_logits_bias", torch.zeros(config.vocab_size, dtype=torch.float32)
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
if self.config.enable_bigram_table:
|
| 594 |
+
if self.config.learnable_bigram_table:
|
| 595 |
+
# HACK: low rank to fit in mem
|
| 596 |
+
self.bigram_table = nn.Parameter(
|
| 597 |
+
torch.zeros(config.vocab_size, config.vocab_size, dtype=torch.float32)
|
| 598 |
+
)
|
| 599 |
+
else:
|
| 600 |
+
self.register_buffer(
|
| 601 |
+
"bigram_table",
|
| 602 |
+
torch.zeros(config.vocab_size, config.vocab_size, dtype=torch.float32),
|
| 603 |
+
)
|
| 604 |
+
else:
|
| 605 |
+
self.bigram_table = None
|
| 606 |
+
|
| 607 |
+
# Never tie embeddings/unembed to avoid accidental aliasing in exports.
|
| 608 |
+
config.tied_unembed = False
|
| 609 |
+
|
| 610 |
+
# init all weights
|
| 611 |
+
self.apply(self._init_weights)
|
| 612 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
| 613 |
+
for pn, p in self.named_parameters():
|
| 614 |
+
if pn.endswith("c_proj.weight"):
|
| 615 |
+
if p.is_sparse:
|
| 616 |
+
num_nonzero = p._nnz()
|
| 617 |
+
p._values().data = (
|
| 618 |
+
sample_top_k(n=p.numel(), k=num_nonzero, shape=(num_nonzero,))
|
| 619 |
+
* 0.02
|
| 620 |
+
/ math.sqrt(2 * config.n_layer)
|
| 621 |
+
)
|
| 622 |
+
else:
|
| 623 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
|
| 624 |
+
|
| 625 |
+
# If requested, initialize positional embeddings with fixed sinusoids and freeze
|
| 626 |
+
if config.cat_pos_emb and config.sinusoidal_cat_pos_emb:
|
| 627 |
+
assert config.d_pos_emb is not None, (
|
| 628 |
+
"sinusoidal_cat_pos_emb requires cat_pos_emb (d_pos_emb must be set)"
|
| 629 |
+
)
|
| 630 |
+
with torch.no_grad():
|
| 631 |
+
T = config.block_size
|
| 632 |
+
D = config.d_pos_emb
|
| 633 |
+
device = self.transformer.wpe.weight.device
|
| 634 |
+
dtype = self.transformer.wpe.weight.dtype
|
| 635 |
+
positions = torch.arange(T, device=device, dtype=dtype).unsqueeze(1) # [T,1]
|
| 636 |
+
d_half = max(1, D // 2)
|
| 637 |
+
# periods from 4 tokens up to block_size tokens (log-spaced)
|
| 638 |
+
T_float = float(T)
|
| 639 |
+
p_min = 4.0
|
| 640 |
+
p_max = max(p_min, T_float)
|
| 641 |
+
periods = torch.logspace(
|
| 642 |
+
math.log10(p_min), math.log10(p_max), steps=d_half, device=device, dtype=dtype
|
| 643 |
+
)
|
| 644 |
+
freqs = 2 * math.pi / periods # [d_half]
|
| 645 |
+
angles = positions * freqs # [T, d_half]
|
| 646 |
+
sinv = torch.sin(angles)
|
| 647 |
+
cosv = torch.cos(angles)
|
| 648 |
+
enc = torch.cat([sinv, cosv], dim=1) # [T, 2*d_half]
|
| 649 |
+
if enc.shape[1] < D:
|
| 650 |
+
pad = torch.zeros(T, D - enc.shape[1], device=device, dtype=dtype)
|
| 651 |
+
enc = torch.cat([enc, pad], dim=1)
|
| 652 |
+
elif enc.shape[1] > D:
|
| 653 |
+
enc = enc[:, :D]
|
| 654 |
+
self.transformer.wpe.weight.copy_(enc)
|
| 655 |
+
self.transformer.wpe.weight.requires_grad_(False)
|
| 656 |
+
|
| 657 |
+
# report number of parameters
|
| 658 |
+
print("number of parameters: %.2fM" % (self.get_num_params() / 1e6,))
|
| 659 |
+
|
| 660 |
+
def get_num_params(self, non_embedding=True):
|
| 661 |
+
"""
|
| 662 |
+
Return the number of parameters in the model.
|
| 663 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
| 664 |
+
The token embeddings would too, except due to the parameter sharing these
|
| 665 |
+
params are actually used as weights in the final layer, so we include them.
|
| 666 |
+
"""
|
| 667 |
+
n_params = sum(p.numel() for p in self.parameters())
|
| 668 |
+
if non_embedding:
|
| 669 |
+
n_params -= self.transformer.wpe.weight.numel()
|
| 670 |
+
return n_params
|
| 671 |
+
|
| 672 |
+
def _init_weights(self, module):
|
| 673 |
+
if isinstance(module, nn.Linear):
|
| 674 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 675 |
+
if module.bias is not None:
|
| 676 |
+
torch.nn.init.zeros_(module.bias)
|
| 677 |
+
elif isinstance(module, nn.Embedding):
|
| 678 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 679 |
+
|
| 680 |
+
def forward(self, idx, targets=None, include_resid_mid=False):
|
| 681 |
+
device = idx.device
|
| 682 |
+
b, t = idx.size()
|
| 683 |
+
|
| 684 |
+
assert t <= self.config.block_size, (
|
| 685 |
+
f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
| 686 |
+
)
|
| 687 |
+
# pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
|
| 688 |
+
|
| 689 |
+
# forward the GPT model itself
|
| 690 |
+
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, d_model)
|
| 691 |
+
# pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, d_model)
|
| 692 |
+
pos_emb = self.transformer.wpe.weight[:t].unsqueeze(0)
|
| 693 |
+
if self.config.cat_pos_emb:
|
| 694 |
+
x = self.transformer.drop(tok_emb)
|
| 695 |
+
else:
|
| 696 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
| 697 |
+
|
| 698 |
+
if self.config.debug_nans:
|
| 699 |
+
assert x.isfinite().all(), "nan in initial post-embedding"
|
| 700 |
+
|
| 701 |
+
if self.config.enable_bigram_table:
|
| 702 |
+
# add bigram table to the logits bias
|
| 703 |
+
additional_logits_bias = F.embedding(idx, self.bigram_table, padding_idx=-1)
|
| 704 |
+
additional_logits_bias = additional_logits_bias.to(x.dtype)
|
| 705 |
+
else:
|
| 706 |
+
additional_logits_bias = None
|
| 707 |
+
|
| 708 |
+
if self.config.cat_pos_emb:
|
| 709 |
+
pos_emb_to_cat = pos_emb
|
| 710 |
+
if self.config.dropout_cat_pos_emb:
|
| 711 |
+
pos_emb_to_cat = self.transformer.drop(pos_emb)
|
| 712 |
+
else:
|
| 713 |
+
pos_emb_to_cat = None
|
| 714 |
+
|
| 715 |
+
return self.forward_tail(
|
| 716 |
+
x,
|
| 717 |
+
n=0,
|
| 718 |
+
targets=targets,
|
| 719 |
+
additional_logits_bias=additional_logits_bias,
|
| 720 |
+
include_resid_mid=include_resid_mid, # this is hacky we should just switch to using hooks
|
| 721 |
+
pos_emb_to_cat=pos_emb_to_cat,
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
def forward_tail(
|
| 725 |
+
self,
|
| 726 |
+
x,
|
| 727 |
+
n,
|
| 728 |
+
targets=None,
|
| 729 |
+
additional_logits_bias=None,
|
| 730 |
+
include_resid_mid=False,
|
| 731 |
+
pos_emb_to_cat=None,
|
| 732 |
+
):
|
| 733 |
+
# print(x.shape)
|
| 734 |
+
hs = []
|
| 735 |
+
blks = list(self.transformer.h)
|
| 736 |
+
|
| 737 |
+
if include_resid_mid:
|
| 738 |
+
blks = list_join(
|
| 739 |
+
[
|
| 740 |
+
[
|
| 741 |
+
blk.forward_attn_block,
|
| 742 |
+
blk.forward_mlp_block,
|
| 743 |
+
]
|
| 744 |
+
for blk in blks
|
| 745 |
+
]
|
| 746 |
+
)
|
| 747 |
+
|
| 748 |
+
assert n <= len(blks)
|
| 749 |
+
|
| 750 |
+
for i, block_fn in enumerate(blks[n:]):
|
| 751 |
+
global curlayer
|
| 752 |
+
curlayer = i
|
| 753 |
+
with hook_namespace(f"{i // 2}") if include_resid_mid else hook_namespace(f"{i}"):
|
| 754 |
+
hs.append(x)
|
| 755 |
+
if self.config.cat_pos_emb:
|
| 756 |
+
x = block_fn(x, pos_emb_to_cat)
|
| 757 |
+
else:
|
| 758 |
+
x = block_fn(x)
|
| 759 |
+
|
| 760 |
+
x = hook_save("final_resid", x)
|
| 761 |
+
x = self.transformer.ln_f(x)
|
| 762 |
+
|
| 763 |
+
logits = (
|
| 764 |
+
self.lm_head(x)
|
| 765 |
+
+ self.final_logits_bias
|
| 766 |
+
+ (additional_logits_bias if additional_logits_bias is not None else 0)
|
| 767 |
+
)
|
| 768 |
+
if targets is not None:
|
| 769 |
+
loss = F.cross_entropy(
|
| 770 |
+
logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
|
| 771 |
+
)
|
| 772 |
+
else:
|
| 773 |
+
loss = torch.zeros(1, device=x.device)
|
| 774 |
+
|
| 775 |
+
return logits, loss, hs # hs is deprecated in favor of hook stuff
|
| 776 |
+
|
| 777 |
+
def crop_block_size(self, block_size):
|
| 778 |
+
# model surgery to decrease the block size if necessary
|
| 779 |
+
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
|
| 780 |
+
# but want to use a smaller block size for some smaller, simpler model
|
| 781 |
+
assert block_size <= self.config.block_size
|
| 782 |
+
self.config.block_size = block_size
|
| 783 |
+
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
|
| 784 |
+
for block in self.transformer.h:
|
| 785 |
+
if hasattr(block.attn, "bias"):
|
| 786 |
+
block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
|
| 787 |
+
|
| 788 |
+
@torch.no_grad()
|
| 789 |
+
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
|
| 790 |
+
"""
|
| 791 |
+
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
|
| 792 |
+
the sequence max_new_tokens times, feeding the predictions back into the model each time.
|
| 793 |
+
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
| 794 |
+
"""
|
| 795 |
+
for _ in range(max_new_tokens):
|
| 796 |
+
# if the sequence context is growing too long we must crop it at block_size
|
| 797 |
+
idx_cond = (
|
| 798 |
+
idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :]
|
| 799 |
+
)
|
| 800 |
+
# forward the model to get the logits for the index in the sequence
|
| 801 |
+
logits, _, _ = self(idx_cond)
|
| 802 |
+
# pluck the logits at the final step and scale by desired temperature
|
| 803 |
+
logits = logits[:, -1, :] / temperature
|
| 804 |
+
# optionally crop the logits to only the top k options
|
| 805 |
+
if top_k is not None:
|
| 806 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 807 |
+
logits[logits < v[:, -1:]] = -float("Inf")
|
| 808 |
+
# apply softmax to convert logits to (normalized) probabilities
|
| 809 |
+
probs = F.softmax(logits, dim=-1)
|
| 810 |
+
# sample from the distribution
|
| 811 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 812 |
+
# append sampled index to the running sequence and continue
|
| 813 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 814 |
+
|
| 815 |
+
return idx
|
| 816 |
+
|
| 817 |
+
def is_mlp_param(self, p):
|
| 818 |
+
return id(p) in list_join(
|
| 819 |
+
[
|
| 820 |
+
[
|
| 821 |
+
id(self.transformer.h[i].mlp.c_fc.weight),
|
| 822 |
+
id(self.transformer.h[i].mlp.c_proj.weight),
|
| 823 |
+
]
|
| 824 |
+
for i in range(self.config.n_layer)
|
| 825 |
+
]
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
def is_param_embed(self, p):
|
| 829 |
+
return p is self.transformer.wte.weight or p is self.transformer.wpe.weight
|
| 830 |
+
|
| 831 |
+
def is_attn_param(self, p):
|
| 832 |
+
return id(p) in list_join(
|
| 833 |
+
[
|
| 834 |
+
[
|
| 835 |
+
id(self.transformer.h[i].attn.c_attn.weight),
|
| 836 |
+
id(self.transformer.h[i].attn.c_proj.weight),
|
| 837 |
+
]
|
| 838 |
+
for i in range(self.config.n_layer)
|
| 839 |
+
]
|
| 840 |
+
)
|
| 841 |
+
|
| 842 |
+
def is_bias(self, p):
|
| 843 |
+
return id(p) in list_join(
|
| 844 |
+
[
|
| 845 |
+
[
|
| 846 |
+
id(self.transformer.h[i].attn.c_attn.bias),
|
| 847 |
+
id(self.transformer.h[i].attn.c_proj.bias),
|
| 848 |
+
id(self.transformer.h[i].mlp.c_fc.bias),
|
| 849 |
+
id(self.transformer.h[i].mlp.c_proj.bias),
|
| 850 |
+
]
|
| 851 |
+
for i in range(self.config.n_layer)
|
| 852 |
+
]
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
def is_ln_param(self, p):
|
| 856 |
+
return id(p) in list_join(
|
| 857 |
+
[
|
| 858 |
+
[
|
| 859 |
+
id(self.transformer.h[i].ln_1.weight),
|
| 860 |
+
id(self.transformer.h[i].ln_2.weight),
|
| 861 |
+
]
|
| 862 |
+
for i in range(self.config.n_layer)
|
| 863 |
+
]
|
| 864 |
+
) + [
|
| 865 |
+
id(self.transformer.ln_f.weight),
|
| 866 |
+
]
|
| 867 |
+
|
| 868 |
+
def is_sparse_param(self, p, dense_embeddings=None, dense_unembed=None, dense_biases=None):
|
| 869 |
+
# if these params aren't specified, then still give answers, but only for uncontroversial params
|
| 870 |
+
|
| 871 |
+
if dense_embeddings is None:
|
| 872 |
+
assert p is not self.transformer.wte.weight and p is not self.transformer.wpe.weight
|
| 873 |
+
if dense_unembed is None:
|
| 874 |
+
assert p is not self.lm_head.weight
|
| 875 |
+
if dense_biases is None:
|
| 876 |
+
assert not self.is_bias(p)
|
| 877 |
+
|
| 878 |
+
if p is self.transformer.wte.weight or p is self.transformer.wpe.weight:
|
| 879 |
+
return not dense_embeddings
|
| 880 |
+
if p is self.lm_head.weight:
|
| 881 |
+
return not dense_unembed
|
| 882 |
+
if self.is_bias(p):
|
| 883 |
+
return not dense_biases
|
| 884 |
+
|
| 885 |
+
return id(p) in list_join(
|
| 886 |
+
[
|
| 887 |
+
[
|
| 888 |
+
id(self.transformer.h[i].attn.c_attn.weight),
|
| 889 |
+
id(self.transformer.h[i].attn.c_proj.weight),
|
| 890 |
+
id(self.transformer.h[i].mlp.c_fc.weight),
|
| 891 |
+
id(self.transformer.h[i].mlp.c_proj.weight),
|
| 892 |
+
]
|
| 893 |
+
for i in range(self.config.n_layer)
|
| 894 |
+
]
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def list_join(xss: list[list]) -> list:
|
| 899 |
+
"""monadic join for lists"""
|
| 900 |
+
return [x for xs in xss for x in xs]
|
hook_utils.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds.
|
| 2 |
+
|
| 3 |
+
The full module has no exotic dependencies, but mirroring the definitions here
|
| 4 |
+
keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and easy to vendor. The
|
| 5 |
+
implementations below are copied with minor tweaks for readability so that code
|
| 6 |
+
written against :func:`hook_recorder`, :func:`hook_namespace`, and
|
| 7 |
+
:func:`torch_recompute_preserving_hook_context` behaves identically in both the
|
| 8 |
+
training and inference configurations.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import re
|
| 14 |
+
from contextlib import contextmanager
|
| 15 |
+
from functools import partial
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.utils.checkpoint
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class HookContext:
|
| 22 |
+
"""State container used by the hook helpers."""
|
| 23 |
+
|
| 24 |
+
def __init__(self) -> None:
|
| 25 |
+
self._reset()
|
| 26 |
+
self.curintervtransformer = lambda x: x
|
| 27 |
+
|
| 28 |
+
def _reset(self) -> None:
|
| 29 |
+
self.curcontext = None
|
| 30 |
+
self.curname = ""
|
| 31 |
+
self.curregex = None
|
| 32 |
+
self.curinterventions = None
|
| 33 |
+
self.save_grads = None
|
| 34 |
+
|
| 35 |
+
def _get_interventions(self):
|
| 36 |
+
return self.curintervtransformer(
|
| 37 |
+
self.curinterventions if self.curinterventions is not None else {}
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
@contextmanager
|
| 41 |
+
def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False):
|
| 42 |
+
"""Record tensors that pass through hooks matching ``regex``."""
|
| 43 |
+
|
| 44 |
+
assert self.curcontext is None, "reentrancy not allowed!"
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
self.curcontext = {}
|
| 48 |
+
self.curregex = re.compile(regex)
|
| 49 |
+
self.curname = ""
|
| 50 |
+
self.curinterventions = interventions
|
| 51 |
+
self.save_grads = save_grads
|
| 52 |
+
|
| 53 |
+
yield self.curcontext
|
| 54 |
+
finally:
|
| 55 |
+
self._reset()
|
| 56 |
+
get_context()._reset()
|
| 57 |
+
|
| 58 |
+
@contextmanager
|
| 59 |
+
def hook_intervention_transform(self, intervention_transformer):
|
| 60 |
+
oldintervention_transformer = self.curintervtransformer
|
| 61 |
+
|
| 62 |
+
def compose(f, g):
|
| 63 |
+
return lambda x: f(g(x))
|
| 64 |
+
|
| 65 |
+
self.curintervtransformer = compose(
|
| 66 |
+
intervention_transformer,
|
| 67 |
+
self.curintervtransformer,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
yield
|
| 72 |
+
finally:
|
| 73 |
+
self.curintervtransformer = oldintervention_transformer
|
| 74 |
+
|
| 75 |
+
@contextmanager
|
| 76 |
+
def hook_namespace(self, name: str):
|
| 77 |
+
"""Temporarily push ``name`` onto the hook namespace stack."""
|
| 78 |
+
|
| 79 |
+
oldname = self.curname
|
| 80 |
+
self.curname = self.curname + name + "."
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
yield
|
| 84 |
+
finally:
|
| 85 |
+
self.curname = oldname
|
| 86 |
+
|
| 87 |
+
def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
"""Optionally record ``tensor`` using the current namespace."""
|
| 89 |
+
|
| 90 |
+
curinterventions = self._get_interventions()
|
| 91 |
+
if curinterventions is not None:
|
| 92 |
+
key = self.curname + name
|
| 93 |
+
if key in curinterventions:
|
| 94 |
+
tensor = curinterventions[key](tensor)
|
| 95 |
+
|
| 96 |
+
if self.curcontext is not None and self.curregex.match(self.curname + name):
|
| 97 |
+
self.curcontext[self.curname + name] = tensor
|
| 98 |
+
|
| 99 |
+
if self.curcontext is not None and self.save_grads and tensor.requires_grad:
|
| 100 |
+
|
| 101 |
+
class _Grad(torch.autograd.Function):
|
| 102 |
+
@staticmethod
|
| 103 |
+
def forward(ctx, input_tensor):
|
| 104 |
+
return input_tensor
|
| 105 |
+
|
| 106 |
+
@staticmethod
|
| 107 |
+
def backward(ctx, grad_output):
|
| 108 |
+
self.curcontext[self.curname + name + ".grad"] = grad_output
|
| 109 |
+
return grad_output
|
| 110 |
+
|
| 111 |
+
if self.curregex.match(self.curname + name + ".grad"):
|
| 112 |
+
tensor = _Grad.apply(tensor)
|
| 113 |
+
|
| 114 |
+
return tensor
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def set_context(new_context: HookContext) -> None:
|
| 118 |
+
global context
|
| 119 |
+
context = new_context
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_context() -> HookContext:
|
| 123 |
+
global context
|
| 124 |
+
return context
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None):
|
| 128 |
+
"""Wrapper around :func:`torch.utils.checkpoint` that propagates hooks."""
|
| 129 |
+
|
| 130 |
+
oldcontext = get_context()
|
| 131 |
+
curcontext = HookContext()
|
| 132 |
+
curcontext.curcontext = (
|
| 133 |
+
dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None
|
| 134 |
+
)
|
| 135 |
+
curcontext.curregex = oldcontext.curregex
|
| 136 |
+
curcontext.curname = oldcontext.curname
|
| 137 |
+
curcontext.curinterventions = (
|
| 138 |
+
dict(oldcontext.curinterventions) if oldcontext.curinterventions is not None else None
|
| 139 |
+
)
|
| 140 |
+
curcontext.save_grads = oldcontext.save_grads
|
| 141 |
+
|
| 142 |
+
is_recompute = False
|
| 143 |
+
|
| 144 |
+
def _f(curcontext: HookContext, *xs):
|
| 145 |
+
initcontext = get_context()
|
| 146 |
+
nonlocal is_recompute
|
| 147 |
+
|
| 148 |
+
set_context(curcontext)
|
| 149 |
+
try:
|
| 150 |
+
res = f(*xs)
|
| 151 |
+
|
| 152 |
+
if not is_recompute and oldcontext.curcontext is not None:
|
| 153 |
+
oldcontext.curcontext |= curcontext.curcontext
|
| 154 |
+
finally:
|
| 155 |
+
set_context(initcontext)
|
| 156 |
+
is_recompute = True
|
| 157 |
+
return res
|
| 158 |
+
|
| 159 |
+
res = torch.utils.checkpoint.checkpoint(
|
| 160 |
+
partial(_f, curcontext), *xs, use_reentrant=use_reentrant
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
return res
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
context = HookContext()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def hook_recorder(*a, **k):
|
| 170 |
+
return get_context().hook_recorder(*a, **k)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def hook_namespace(*a, **k):
|
| 174 |
+
return get_context().hook_namespace(*a, **k)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def hook_save(*a, **k):
|
| 178 |
+
return get_context().hook_save(*a, **k)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def hook_intervention_transform(*a, **k):
|
| 182 |
+
return get_context().hook_intervention_transform(*a, **k)
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7a63c67d78a1e8da5fc3ead1338fe0f18994d48deedeb4374a6b3ad1d350724
|
| 3 |
+
size 1676511896
|
modeling_circuitgpt.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Sequence
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from transformers.generation.utils import GenerationMixin
|
| 10 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 11 |
+
from transformers.utils.generic import ModelOutput
|
| 12 |
+
|
| 13 |
+
from .config import CircuitGPTConfig
|
| 14 |
+
from .gpt import GPT
|
| 15 |
+
from .hook_utils import hook_recorder
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class CircuitGPTCausalLMOutput(ModelOutput):
|
| 20 |
+
loss: torch.Tensor | None = None
|
| 21 |
+
logits: torch.Tensor | None = None
|
| 22 |
+
activations: dict[str, torch.Tensor] | None = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _activations_regex(keys: Sequence[str]) -> str:
|
| 26 |
+
escaped = (re.escape(k) for k in keys)
|
| 27 |
+
return "^(" + "|".join(escaped) + ")$"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CircuitGPTPreTrainedModel(PreTrainedModel):
|
| 31 |
+
config_class = CircuitGPTConfig
|
| 32 |
+
base_model_prefix = "circuit_model"
|
| 33 |
+
circuit_model: GPT
|
| 34 |
+
|
| 35 |
+
def __init__(self, config: CircuitGPTConfig, *inputs, **kwargs) -> None:
|
| 36 |
+
super().__init__(config, *inputs, **kwargs)
|
| 37 |
+
|
| 38 |
+
def get_input_embeddings(self) -> nn.Module:
|
| 39 |
+
return self.circuit_model.transformer.wte # type: ignore[return-value]
|
| 40 |
+
|
| 41 |
+
def set_input_embeddings(self, value: nn.Module) -> None:
|
| 42 |
+
self.circuit_model.transformer.wte = value # type: ignore[assignment]
|
| 43 |
+
|
| 44 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 45 |
+
return self.circuit_model.lm_head # type: ignore[return-value]
|
| 46 |
+
|
| 47 |
+
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
|
| 48 |
+
self.circuit_model.lm_head = new_embeddings # type: ignore[assignment]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CircuitGPTForCausalLM(CircuitGPTPreTrainedModel, GenerationMixin):
|
| 52 |
+
"""
|
| 53 |
+
Hugging Face-compatible wrapper around `circuit_sparsity.gpt.GPT`.
|
| 54 |
+
All math happens inside the original module so parity is guaranteed.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, config: CircuitGPTConfig, circuit_model: GPT | None = None) -> None:
|
| 58 |
+
super().__init__(config)
|
| 59 |
+
|
| 60 |
+
if circuit_model is None:
|
| 61 |
+
self.circuit_model = GPT(config.to_circuit_config())
|
| 62 |
+
self.post_init()
|
| 63 |
+
else:
|
| 64 |
+
self.circuit_model = circuit_model
|
| 65 |
+
|
| 66 |
+
# ------------------------------------------------------------------
|
| 67 |
+
# Constructors
|
| 68 |
+
# ------------------------------------------------------------------
|
| 69 |
+
@classmethod
|
| 70 |
+
def from_circuit_model(cls, circuit_model: GPT) -> "CircuitGPTForCausalLM":
|
| 71 |
+
config = CircuitGPTConfig.from_circuit_config(circuit_model.config)
|
| 72 |
+
return cls(config, circuit_model=circuit_model)
|
| 73 |
+
|
| 74 |
+
# ------------------------------------------------------------------
|
| 75 |
+
# Forward
|
| 76 |
+
# ------------------------------------------------------------------
|
| 77 |
+
def forward(
|
| 78 |
+
self,
|
| 79 |
+
input_ids: torch.Tensor,
|
| 80 |
+
labels: torch.LongTensor | None = None,
|
| 81 |
+
output_activations: Sequence[str] | None = None,
|
| 82 |
+
return_dict: bool | None = None,
|
| 83 |
+
use_cache: bool | None = None,
|
| 84 |
+
output_attentions: bool | None = None,
|
| 85 |
+
output_hidden_states: bool | None = None,
|
| 86 |
+
**kwargs,
|
| 87 |
+
) -> CircuitGPTCausalLMOutput:
|
| 88 |
+
# Ignore HF generation kwargs we don't use; surface any unknowns.
|
| 89 |
+
remaining_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
| 90 |
+
if remaining_kwargs:
|
| 91 |
+
unsupported = ", ".join(remaining_kwargs.keys())
|
| 92 |
+
raise ValueError(f"Unsupported arguments for CircuitGPTForCausalLM: {unsupported}")
|
| 93 |
+
|
| 94 |
+
if input_ids.size(-1) > self.config.block_size:
|
| 95 |
+
raise ValueError(
|
| 96 |
+
f"Sequence length {input_ids.size(-1)} exceeds block size {self.config.block_size}"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if output_activations:
|
| 100 |
+
regex = _activations_regex(output_activations)
|
| 101 |
+
with hook_recorder(regex=regex) as recorded:
|
| 102 |
+
logits, loss, _ = self.circuit_model(input_ids, targets=labels)
|
| 103 |
+
activations = {key: recorded[key] for key in output_activations if key in recorded}
|
| 104 |
+
else:
|
| 105 |
+
activations = None
|
| 106 |
+
logits, loss, _ = self.circuit_model(input_ids, targets=labels)
|
| 107 |
+
|
| 108 |
+
if labels is None:
|
| 109 |
+
loss = None
|
| 110 |
+
|
| 111 |
+
return CircuitGPTCausalLMOutput(
|
| 112 |
+
loss=loss,
|
| 113 |
+
logits=logits,
|
| 114 |
+
activations=activations,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# ------------------------------------------------------------------
|
| 118 |
+
# Generation helpers
|
| 119 |
+
# ------------------------------------------------------------------
|
| 120 |
+
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, **kwargs):
|
| 121 |
+
if input_ids.size(-1) > self.config.block_size:
|
| 122 |
+
input_ids = input_ids[:, -self.config.block_size :]
|
| 123 |
+
return {"input_ids": input_ids}
|
| 124 |
+
|
| 125 |
+
def _reorder_cache(self, past, beam_idx):
|
| 126 |
+
# No KV cache implemented; method exists for interface completeness.
|
| 127 |
+
return past
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"eos_token": "<|endoftext|>"
|
| 3 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b42d53ab8a66da8fc53288442b1239e373697d630ad61fa66dc35350274617c7
|
| 3 |
+
size 24366
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"eos_token": "<|endoftext|>",
|
| 3 |
+
"model_max_length": 256,
|
| 4 |
+
"padding_side": "left",
|
| 5 |
+
"tokenizer_class": "PreTrainedTokenizerFast"
|
| 6 |
+
}
|