mmrech
/

WHAM
English
microsoft
mmrech katja-hofmann commited on
Commit
aa16c0a
·
0 Parent(s):

Duplicate from microsoft/wham

Browse files

Co-authored-by: Katja Hofmann <katja-hofmann@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +38 -0
  2. CODE_OF_CONDUCT.md +10 -0
  3. CONTRIBUTING.md +14 -0
  4. LICENSE.md +96 -0
  5. README.md +243 -0
  6. SECURITY.md +37 -0
  7. WHAM_Demonstrator.zip +3 -0
  8. assets/Demonstrator/Fig_01.png +3 -0
  9. assets/Demonstrator/Fig_02.png +3 -0
  10. assets/Demonstrator/Fig_03.png +3 -0
  11. assets/Demonstrator/Fig_04.png +3 -0
  12. assets/Demonstrator/Fig_05.png +3 -0
  13. assets/Demonstrator/Fig_06.png +3 -0
  14. assets/Demonstrator/Fig_07.png +3 -0
  15. assets/Demonstrator/Fig_08.png +3 -0
  16. assets/Demonstrator/Fig_09.png +3 -0
  17. assets/Demonstrator/Fig_10.png +3 -0
  18. assets/Demonstrator/Fig_11.png +3 -0
  19. assets/Demonstrator/Fig_12.png +3 -0
  20. assets/Demonstrator/Fig_13.png +3 -0
  21. assets/Demonstrator/Fig_14.png +3 -0
  22. assets/Demonstrator/Fig_15.png +3 -0
  23. assets/Demonstrator/Fig_16.png +3 -0
  24. assets/Demonstrator/Fig_17.png +3 -0
  25. assets/Readme/model_capabilities.gif +3 -0
  26. assets/Readme/wham_gen_1.gif +3 -0
  27. assets/Readme/wham_gen_2.gif +3 -0
  28. assets/Readme/wham_gen_3.gif +3 -0
  29. assets/Readme/wham_gen_4.gif +3 -0
  30. assets/Readme/wham_gen_5.gif +3 -0
  31. assets/Readme/wham_gen_6.gif +3 -0
  32. assets/Readme/wham_gen_7.gif +3 -0
  33. assets/Readme/wham_gen_8.gif +3 -0
  34. assets/Readme/wham_gen_9.gif +3 -0
  35. configs/metadata_custom_tag.config +5 -0
  36. data_summary_card.md +145 -0
  37. models/WHAM_1.6B_v1.ckpt +3 -0
  38. models/WHAM_200M.ckpt +3 -0
  39. models/config.json +0 -0
  40. requirements.txt +48 -0
  41. run_dreaming.py +264 -0
  42. run_server.py +519 -0
  43. setup_local.sh +21 -0
  44. wham/models/nn/model_blocks.py +49 -0
  45. wham/models/nn/nanoGPT.py +665 -0
  46. wham/models/pl/__init__.py +0 -0
  47. wham/models/pl/pl_base_model.py +5 -0
  48. wham/models/vqgan/taming/LICENSE +24 -0
  49. wham/models/vqgan/taming/model.py +696 -0
  50. wham/models/vqgan/taming/quantize.py +146 -0
.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ fonts/arial.ttf filter=lfs diff=lfs merge=lfs -text
37
+ *.gif filter=lfs diff=lfs merge=lfs -text
38
+ *.png filter=lfs diff=lfs merge=lfs -text
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Microsoft Open Source Code of Conduct
2
+
3
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4
+
5
+ Resources:
6
+
7
+ - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8
+ - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9
+ - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10
+ - Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support)
CONTRIBUTING.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing
2
+
3
+ This project welcomes contributions and suggestions. Most contributions require you to
4
+ agree to a Contributor License Agreement (CLA) declaring that you have the right to,
5
+ and actually do, grant us the rights to use your contribution. For details, visit
6
+ https://cla.microsoft.com.
7
+
8
+ When you submit a pull request, a CLA-bot will automatically determine whether you need
9
+ to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
10
+ instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
11
+
12
+ This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
13
+ For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
14
+ or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
LICENSE.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MICROSOFT RESEARCH LICENSE TERMS
2
+
3
+ **IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.**
4
+
5
+ These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
6
+
7
+ ## 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
8
+
9
+ Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
10
+
11
+ a) **Source Code.** If source code is included, you may use and modify the source code, but you may not distribute the source code.
12
+
13
+ b) **Object Code.** If object code is included, you may use the object code, but you may not distribute the object code.
14
+
15
+ c) **Models.** If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
16
+
17
+ d) **Data.** If data is included, you may use the data, but your use must be consistent with the consent under which the data was provided and/or gathered and you may not modify or distribute the data.
18
+
19
+ ## 2) SCOPE OF LICENSE.
20
+
21
+ The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
22
+
23
+ a) Work around any technical limitations in the Materials that only allow you to use it in certain ways;
24
+
25
+ b) Reverse engineer, decompile or disassemble the Materials;
26
+
27
+ c) Remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
28
+
29
+ d) Use the Materials in any way that is against the law or to create or propagate malware; or
30
+
31
+ e) Share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
32
+
33
+ ## 3) PERSONAL DATA.
34
+
35
+ If the data (set forth in Section 1(d) above) includes or is found to include any data that enables any ability to identify an individual ("Personal Data"), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, **immediately upon the completion of your research.**
36
+
37
+ ## 4) LICENSE TO MICROSOFT.
38
+
39
+ Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
40
+
41
+ ## 5) PUBLICATION.
42
+
43
+ You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
44
+
45
+ ## 6) FEEDBACK.
46
+
47
+ Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. **Additionally,** such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
48
+
49
+ ## 7) COMPLIANCE WITH TRADE LAWS.
50
+
51
+ You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting.
52
+
53
+ ## 8) SUPPORT SERVICES.
54
+
55
+ Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
56
+
57
+ ## 9) BINDING ARBITRATION AND CLASS ACTION WAIVER.
58
+
59
+ **This Section applies if you live in (or, if a business, your principal place of business is in) the United States.** If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to **binding individual arbitration before the American Arbitration Association** under the Federal Arbitration Act ("FAA"), and not to **sue in court in front of a judge or jury.** Instead, a neutral arbitrator will decide. **Class action lawsuits, class-wide arbitrations, private attorney-general actions,** and any other proceeding where someone acts in a representative capacity **are not allowed;** nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
60
+
61
+ ## 10) ENTIRE AGREEMENT.
62
+
63
+ This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
64
+
65
+ ## 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES.
66
+
67
+ If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
68
+
69
+ ## 12) CONSUMER RIGHTS; REGIONAL VARIATIONS.
70
+
71
+ This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
72
+
73
+ a) **Australia.** You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
74
+
75
+ b) **Canada.** If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
76
+
77
+ c) **Germany and Austria.**
78
+ i. **Warranty.** The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
79
+ ii. **Limitation of Liability.** In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
80
+
81
+ Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
82
+
83
+ ## 13) DISCLAIMER OF WARRANTY.
84
+
85
+ THE MATERIALS ARE LICENSED "AS IS." YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
86
+
87
+ ## 14) LIMITATION ON AND EXCLUSION OF DAMAGES.
88
+
89
+ IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
90
+
91
+ This limitation applies to:
92
+ - (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and
93
+ - (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
94
+
95
+ It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
96
+
README.md ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - microsoft/bleeding-edge-gameplay-sample
4
+ tags:
5
+ - wham
6
+ - microsoft
7
+ language:
8
+ - en
9
+ license_link: LICENSE.md
10
+ ---
11
+ # World and Human Action Model (WHAM)
12
+ 📄 [Paper](https://www.nature.com/articles/s41586-025-08600-3) • 🔗 [Sample Data](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample)
13
+ <div align="center">
14
+ Anssi Kanervisto, Dave Bignell, Linda Yilin Wen, Martin Grayson, Raluca Georgescu, Sergio Valcarcel Macua, Shan Zheng Tan, Tabish Rashid, Tim Pearce, Yuhan Cao,
15
+ Abdelhak Lemkhenter, Chentian Jiang, Gavin Costello, Gunshi Gupta, Marko Tot, Shu Ishida, Tarun Gupta, Udit Arora,
16
+ Ryen W. White, Sam Devlin, Cecily Morrison, Katja Hofmann
17
+ </div><br>
18
+ <div align='center'>
19
+ Dynamic Generated Gameplay Sequence using WHAM. Showcasing diverse characters and actions across intricate maps.
20
+ <div style="display: flex; flex-wrap: wrap;">
21
+ <img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_1.gif">
22
+ <img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_2.gif">
23
+ <img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_3.gif">
24
+ <img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_4.gif">
25
+ <img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_5.gif">
26
+ <img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_6.gif">
27
+ <img style="width: calc(33.33%);" src="assets/Readme/wham_gen_7.gif">
28
+ <img style="width: calc(33.33%);" src="assets/Readme/wham_gen_8.gif">
29
+ <img style="width: calc(33.33%);" src="assets/Readme/wham_gen_9.gif">
30
+ </div>
31
+ </div><br>
32
+ <div align='center'>
33
+ WHAM is capable of generating consistent, diverse, and persistent outputs, enabling various use cases for creative iteration.
34
+ <img style="width: 100%;" src="assets/Readme/model_capabilities.gif">
35
+ </div>
36
+
37
+ Muse is powered by a World and Human Action Model (WHAM), which is a generative model of gameplay (visuals and/or controller actions) trained on gameplay data of Ninja Theory’s Xbox game Bleeding Edge. Model development was informed by requirements of game creatives that we identified through a user study. Our goal is to explore the capabilities that generative AI models need to support human creative exploration. WHAM is developed by the [Game Intelligence group](https://www.microsoft.com/en-us/research/group/game-intelligence/) at [Microsoft Research](https://www.microsoft.com/en-us/research/), in collaboration with [TaiX](https://www.microsoft.com/en-us/research/project/taix/) and [Ninja Theory](https://ninjatheory.com/).
38
+
39
+ # Model Card
40
+
41
+ WHAM is an autoregressive model that has been trained to predict (tokenized) game visuals and controller actions given a prompt. Prompts here can be either visual (one or more initial game visuals) and / or controller actions. This allows the user to run the model in (a) world modelling mode (generate visuals given controller actions), (b) behavior policy (generate controller actions given past visuals), or (c) generate both visuals and behavior.
42
+
43
+ WHAM consists of two components, an encoder-decoder [VQ-GAN](https://compvis.github.io/taming-transformers/) trained to encode game visuals to a discrete representation, and a transformer backbone trained to perform next-token prediction. We train both components from scratch. The resulting model can generate consistent game sequences, and shows evidence of capturing the 3D structure of the game environment, the effects of controller actions, and the temporal structure of the game (up to the model’s context length).
44
+
45
+ WHAM was trained on human gameplay data to predict game visuals and players’ controller actions. We worked with the game studio Ninja Theory and their game [Bleeding Edge](https://www.bleedingedge.com/) – a 3D, 4v4 multiplayer video game. From the resulting data we extracted one year’s worth of anonymized gameplay from 27,990 players, capturing a wide range of behaviors and interactions. A sample of this data is provided [here](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample).
46
+
47
+ ## Model Details
48
+
49
+ ### Trained Models
50
+
51
+ In this release we provide the weights of two WHAM instances: 200M WHAM and 1.6B WHAM. Both have been trained from scratch on the same data set. 1.6B WHAM is evaluated in [our paper](https://www.nature.com/articles/s41586-025-08600-3). We additionally provide 200M WHAM as a more lightweight option for faster explorations.
52
+ - [WHAM with 200M parameters](models/WHAM_200M.ckpt), model size: 3.7GB
53
+ - [WHAM with 1.6B parameters](models/WHAM_1.6B_v1.ckpt), model size: 18.9GB
54
+
55
+ ## Usage
56
+
57
+ ### System Requirements
58
+
59
+ The steps below have been tested on the following setup:
60
+ - Linux workstation with Ubuntu 20.04.4 LTS
61
+ - Windows 11 workstation running WSL2 with Ubuntu 20.04.6 LTS
62
+
63
+ The current setup assumes that a CUDA-supported GPU is available for model inference. This has been tested on systems with `NVIDIA RTX A6000` and `NVIDIA A100` respectively. In addition, approximately `15GB` of free hard disk space is required for downloading the models.
64
+
65
+ The steps under Installation assume a python 3.9 installation that can be
66
+ called using the command `python3.9` and the venv package for creating virtual environments. If either of these is not present, you can install this version of python under Ubuntu using:
67
+
68
+ ```bash
69
+ sudo apt install python3.9
70
+ sudo apt install python3.9-venv
71
+ ```
72
+
73
+ If you are using the WHAM Demonstrator, please ensure that you have the required [.NET Core Runtime](https://dotnet.microsoft.com/en-us/download/dotnet/7.0). If this is not yet installed, an error message will pop up from which you can follow a link to download and install this package.
74
+
75
+ ### Installation
76
+
77
+ 1. Clone this repository. We recommend starting without the large model files, using `GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:microsoft/WHAM`
78
+ 2. `cd WHAM`
79
+ 3. `./setup_local.sh`
80
+
81
+ This will set up a `python3.9` virtual environment and install the required packages (this includes packages required for the model server). The typical install time should be approximately 5 minutes.
82
+
83
+ 4. Run `source venv/bin/activate` whenever you want to run model inference or the model server
84
+
85
+ 5. Download model from this HuggingFace repository (See note below):
86
+ 1. Go to Files and versions and navigate to the `models` folder.
87
+ 2. Download the model checkpoint. The instructions below assume that the model checkpoints have been downloaded to your local `models` folder.
88
+
89
+ **Note:** On Linux systems, you can use `git clone` to clone the entire repository, including large files. Due to a limitation of `git lfs` on Windows, only files up to `4GB` are supported and we recommend downloading the model files manually from the `models` folder.
90
+
91
+
92
+ ### Local Model Inference
93
+
94
+ This section assumes that you have followed the installation steps above.
95
+
96
+ (Optional) Download [sample data](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample). For the local inference examples below, we recommend that you start with the `tiny-sample` set of only 4 trajectories for your initial exploration.
97
+
98
+ You can now run model inference to generate gameplay sequences as follows:
99
+
100
+ ```bash
101
+ python run_dreaming.py --model_path <path_to_checkpoint.ckpt> --data_path <path_to_sample_data_folder>
102
+ ```
103
+
104
+ To run the 200M parameter (small) model (if you copied the tiny-sample folder to the root directory):
105
+
106
+ ```bash
107
+ python run_dreaming.py --model_path models/WHAM_200M.ckpt --data_path tiny-sample
108
+ ```
109
+
110
+ This uses the data in `data_path` as initial prompt sequences. The script will create a `dreaming_output` directory which will create two files per ground truth data file:
111
+ - An `.npz` file that contains a number of entries, most important of which are:
112
+ - `encoded_decoded_ground_truth_images`: the original context images, encoded and decoded with the VQGAN.
113
+ - `dreamt_images`: the sequence of all dreamt images.
114
+ - An `.mp4` file of the context data + dreamt images for easier viewing.
115
+
116
+ This requires approximately 4.5GB of VRAM on a single A6000, but only uses batch size of one. To speed up the process, increase batch size with `--batch_size` argument. With a single A6000 and `--batch_size 12` this uses approximately 30GB of VRAM. Generating gameplay sequences from the full 512 video dataset takes around 24 hours.
117
+
118
+ Please note that the first output from the script is generated when the first gameplay sequence has been generated. This may take several minutes when using an `A6000` GPU, or longer for older generation GPUs.
119
+
120
+ See `python run_dreaming.py --help` for different settings.
121
+
122
+ ### WHAM Demonstrator
123
+
124
+ #### Setting up the Model Server
125
+
126
+ We have tested the server code as provided on a single Linux machine with four `A6000 GPUs` (large model) as well as on a Windows machine running Ubuntu under `WSL2`, equipped with a single `GeForce GTX 1080` (small model). Model inferences can be run on lower-spec NVIDIA GPUs by reducing the batch size.
127
+
128
+ The steps below assume that the installation steps above have been followed and that the model files have been downloaded to your local machine.
129
+
130
+ In your terminal, activate the newly installed virtual environment (if it isn't already):
131
+
132
+ ```bash
133
+ source venv/bin/activate
134
+ ```
135
+
136
+ Start the server, pointing it to the model:
137
+
138
+ ```bash
139
+ python run_server.py --model <path_to_model_file>
140
+ ```
141
+
142
+ To run the 200M parameter (small) model:
143
+
144
+ ```bash
145
+ python run_server.py --model models/WHAM_200M.ckpt
146
+ ```
147
+
148
+ To run the 1.6B parameter (large) model:
149
+
150
+ ```bash
151
+ python run_server.py --model models/WHAM_1.6B_v1.ckpt
152
+ ```
153
+
154
+
155
+ The server will start and by default listen on localhost port 5000 (this can be configured with `--port <port>`).
156
+
157
+ **Note:** If you run out of VRAM when running the server, you can reduce the `MAX_BATCH_SIZE` variable in `run_server.py`.
158
+
159
+
160
+ #### Install the WHAM Demonstrator App (Windows only)
161
+
162
+ After cloning or downloading this repository, navigate to the folder `wham/wham_demonstrator`, and start the Windows application `WHAMDemonstrator.exe` within that folder.
163
+
164
+ Follow the instructions in the provided README.md within WHAM Demonstrator to connect to your model server and get an overview of supported functionality.
165
+
166
+
167
+ ## Intended Uses
168
+
169
+ This model and accompanying code are intended for academic research purposes only. WHAM has been trained on gameplay data from a single game, Bleeding Edge, and is intended to be used to generate plausible gameplay sequences resembling this game.
170
+
171
+ The model is not intended to be used to generate imagery outside of the game Bleeding Edge. Generated images include a watermark and provenance metadata. Do not remove the watermark or provenance metadata.
172
+
173
+ WHAM can be used in multiple scenarios. The following list illustrates the types of tasks that WHAM can be used for:
174
+ - World Model: Visuals are predicted, given a real starting state and action sequence.
175
+ - Behaviour Policy: Given visuals, the model predicts the next controller action.
176
+ - Full Generation: The model generates both the visuals and the controller actions a human player might take in the game.
177
+
178
+ ## Training
179
+
180
+ ### Model
181
+
182
+ - Architecture: A decoder-only transformer that predicts the next token corresponding to an interleaved sequence of observations and actions. The image tokenizer is a VQ-GAN.
183
+ - Context length: 10 (observation, action) pairs / 5560 tokens
184
+ - Dataset size: The model was trained on data from approximately `500,000` Bleeding Edge games from all seven game maps (over 1 billion observation, action pairs at 10Hz, equivalent to over 7 years of continuous human gameplay). A data sample is provided in [bleeding-edge-gameplay-sample](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample). This is the test data used for our evaluation results, and has the same format as the training data.
185
+ - GPUs: 98xH100 GPUs
186
+ - Training time: 5 days
187
+
188
+ ### Software
189
+
190
+ - [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning)
191
+ - [Flash-Attention](https://github.com/HazyResearch/flash-attention)
192
+ - [ffmpeg](https://github.com/FFmpeg/FFmpeg)
193
+ - [exiftool](https://github.com/exiftool/exiftool)
194
+
195
+ ## Bias, Risks and Limitations
196
+
197
+ - The training data represents gameplay recordings from a variety of skilled and unskilled gameplayers, representing diverse demographic characteristics. Not all possible player characteristics are represented and model performance may therefore vary.
198
+ - The model, as it is, can only be used to generate visuals and controller inputs. Users should not manipulate images and attempt to generate offensive scenes.
199
+
200
+ ### Technical limitations, operational factors, and ranges
201
+
202
+ Model:
203
+ - Trained on a single game, very specialized, not intended for image prompts that are out of context or from other domains
204
+ - Limited context length (10s)
205
+ - Limited image resolution (300px x 180px), the model can only generate images at this fixed resolution.
206
+ - Generated images and controls can be incorrect or unrecognizable.
207
+ - Inference time is currently too slow for real-time use.
208
+
209
+ WHAM Demonstrator:
210
+ - Developed as a way to explore potential interactions. This is not intended as a fully-fledged user experience or demo.
211
+
212
+ Models trained using game data may potentially behave in ways that are unfair, unreliable, or offensive, in turn causing harms. We emphasize that these types of harms are not mutually exclusive. A single model can exhibit more than one type of harm, potentially relating to multiple different groups of people. For example, the output of the model can be nonsensical or might look reasonable but is inaccurate with respect to external validation sources.
213
+ Although users can input any image as a starting point, the model is only trained to generate images and controller actions based on the structure of the Bleeding Edge game environment that it has learned from the training data. Out-of-domain inputs lead to unpredictable results. For example, this could include a sequence of images that dissolve into unrecognizable blobs.
214
+ Model generations when “out of scope” image elements are introduced will either:
215
+ - Dissolve into unrecognizable blobs of color.
216
+ - Morph into game-relevant items such as game characters.
217
+
218
+ ## Evaluating WHAM
219
+ WHAM is evaluated based on its consistency, diversity, and persistency. Consistency is measured using Fréchet Video Distance (FVD), while diversity is assessed by comparing the marginal distribution of real human actions to those generated by the model using the Wasserstein distance. Persistency is tested using two scenarios: by adding a static power-up object to a game visual and by adding another player character to a game visual used for prompting the model. For detailed evaluation results, see the paper that [introduces the model](https://www.nature.com/articles/s41586-025-08600-3).
220
+
221
+ ### Responsible AI testing
222
+ WHAM has been tested with out of context prompt images to evaluate the risk of outputting harmful or nonsensical images. The generated image sequences did not retain the initial image, but rather dissolved into either unrecognizable blobs or to scenes resembling the training environment.
223
+
224
+
225
+ ## License
226
+
227
+ The model is licensed under the [Microsoft Research License](LICENSE.md).
228
+
229
+ This work has been funded by Microsoft Research.
230
+
231
+ ## Privacy & Ethics Statement
232
+
233
+ [Microsoft Privacy Statement](https://go.microsoft.com/fwlink/?LinkId=521839)
234
+
235
+ ## Trademark Notice
236
+
237
+ **Trademarks** This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
238
+
239
+ ## Contact Information
240
+ For questions, please email muse@microsoft.com.
241
+
242
+ ## Data Summary
243
+ https://huggingface.co/microsoft/wham/blob/main/data_summary_card.md
SECURITY.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Security
2
+
3
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
4
+
5
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
6
+
7
+ ## Reporting Security Issues
8
+
9
+ **Please do not report security vulnerabilities through public GitHub issues.**
10
+
11
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
12
+
13
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
14
+
15
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
16
+
17
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
18
+
19
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
20
+ * Full paths of source file(s) related to the manifestation of the issue
21
+ * The location of the affected source code (tag/branch/commit or direct URL)
22
+ * Any special configuration required to reproduce the issue
23
+ * Step-by-step instructions to reproduce the issue
24
+ * Proof-of-concept or exploit code (if possible)
25
+ * Impact of the issue, including how an attacker might exploit the issue
26
+
27
+ This information will help us triage your report more quickly.
28
+
29
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
30
+
31
+ ## Preferred Languages
32
+
33
+ We prefer all communications to be in English.
34
+
35
+ ## Policy
36
+
37
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
WHAM_Demonstrator.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d19ef23d22081044202e464e48dd4f5f6f232215a4cc797ab1a75dd3eb0e0d9
3
+ size 3565673
assets/Demonstrator/Fig_01.png ADDED

Git LFS Details

  • SHA256: fa511ba7a6259216efce2673a9155cfca299341f30b7a5cfd7f697fd814612d0
  • Pointer size: 129 Bytes
  • Size of remote file: 5.82 kB
assets/Demonstrator/Fig_02.png ADDED

Git LFS Details

  • SHA256: 1befa4db6c4b917a2fce7078145183e0572cdb2f3d0a2a1e3582dc90122ef5ef
  • Pointer size: 129 Bytes
  • Size of remote file: 8.62 kB
assets/Demonstrator/Fig_03.png ADDED

Git LFS Details

  • SHA256: 7daa2e65c3e3711dbbfc2699d1499dd0ae953b4010a08713ccabe39c7b5a2ed2
  • Pointer size: 129 Bytes
  • Size of remote file: 6.14 kB
assets/Demonstrator/Fig_04.png ADDED

Git LFS Details

  • SHA256: c2fdab8b17d4586fd745320af16bd586ba33f9847cb1154b3be9c8fbf10d5972
  • Pointer size: 131 Bytes
  • Size of remote file: 352 kB
assets/Demonstrator/Fig_05.png ADDED

Git LFS Details

  • SHA256: 6a2fde380629bffbc7120b82dae0d9da5fcbd278280b46b908169cf86a9a9fd1
  • Pointer size: 129 Bytes
  • Size of remote file: 3.53 kB
assets/Demonstrator/Fig_06.png ADDED

Git LFS Details

  • SHA256: 214d61764e933f154e6b31c644c88801e24a515ece32e12217b30660b34d256f
  • Pointer size: 131 Bytes
  • Size of remote file: 756 kB
assets/Demonstrator/Fig_07.png ADDED

Git LFS Details

  • SHA256: 5b563b1d5a5699a3c9985f80120fc242bfbaed30f028627b46daa50b1fad8779
  • Pointer size: 129 Bytes
  • Size of remote file: 2.43 kB
assets/Demonstrator/Fig_08.png ADDED

Git LFS Details

  • SHA256: 3c6533f8918807a3b3bbec20422921ecc70206531700903ab7f96ac1bce87198
  • Pointer size: 131 Bytes
  • Size of remote file: 603 kB
assets/Demonstrator/Fig_09.png ADDED

Git LFS Details

  • SHA256: 33ddf89573d4022e7ec6e4cb26be61b7b0f22d1a05a293dfed610e3f19940362
  • Pointer size: 131 Bytes
  • Size of remote file: 314 kB
assets/Demonstrator/Fig_10.png ADDED

Git LFS Details

  • SHA256: 27f25f50ea8b2b2902a67af7f96154e814db4771f3df49dc6dbbe651b0fc0c03
  • Pointer size: 131 Bytes
  • Size of remote file: 218 kB
assets/Demonstrator/Fig_11.png ADDED

Git LFS Details

  • SHA256: 579c6c17b655ff64e8aa8b8586484e082c082c4e7b7ee78cc1c02436483ee396
  • Pointer size: 131 Bytes
  • Size of remote file: 897 kB
assets/Demonstrator/Fig_12.png ADDED

Git LFS Details

  • SHA256: 0724bac7b32b7fb1f14cc830a53c0b74866fca1785174a75ffec893bdfa1714e
  • Pointer size: 130 Bytes
  • Size of remote file: 29.6 kB
assets/Demonstrator/Fig_13.png ADDED

Git LFS Details

  • SHA256: e4ab33f7f26c40e6670f6698c65bc1a8e3c0d839504e96cb05eaade12228b454
  • Pointer size: 131 Bytes
  • Size of remote file: 550 kB
assets/Demonstrator/Fig_14.png ADDED

Git LFS Details

  • SHA256: c9f3f79b94a1f0304cdfd2a22557a49b53720a7f754805e546cc88408dc2b8d2
  • Pointer size: 131 Bytes
  • Size of remote file: 202 kB
assets/Demonstrator/Fig_15.png ADDED

Git LFS Details

  • SHA256: 64b8ca32591c2e61320f0b9c278fd2ff7c103249e41a039a7bcad756aa0dbb75
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
assets/Demonstrator/Fig_16.png ADDED

Git LFS Details

  • SHA256: 9ea8452b22bcbe5beeaf85d57bda2506126cfa99f386b60bdd111549de92020b
  • Pointer size: 130 Bytes
  • Size of remote file: 39.5 kB
assets/Demonstrator/Fig_17.png ADDED

Git LFS Details

  • SHA256: b38646958cb1f0f449a7e817817e9586aaeb09c026aa6e5f377b0e364c5b8050
  • Pointer size: 130 Bytes
  • Size of remote file: 12.5 kB
assets/Readme/model_capabilities.gif ADDED

Git LFS Details

  • SHA256: 87cf1460b2779a1c85b70e2229a7e1e256c501a5e3db26ea74e445b9dc75e965
  • Pointer size: 132 Bytes
  • Size of remote file: 8.63 MB
assets/Readme/wham_gen_1.gif ADDED

Git LFS Details

  • SHA256: 96558d0ad8084eafaf60ee360f13fe8decfbc5ac737b0c2788c01310e81750d1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.42 MB
assets/Readme/wham_gen_2.gif ADDED

Git LFS Details

  • SHA256: 1296bb4ccdac5c7d3a1e7e9adfc48a6ec255933ff252a31d4e45cd117a28aee7
  • Pointer size: 132 Bytes
  • Size of remote file: 4.15 MB
assets/Readme/wham_gen_3.gif ADDED

Git LFS Details

  • SHA256: cb8ea8b3d6c8ec737a9b03f4cd93aeb36ddddc33695849b9b83543a8c2242b6f
  • Pointer size: 132 Bytes
  • Size of remote file: 4.27 MB
assets/Readme/wham_gen_4.gif ADDED

Git LFS Details

  • SHA256: 45e895599dddae5e6d2eb31f66957726fb82662f41b149f4de206466083f5a42
  • Pointer size: 132 Bytes
  • Size of remote file: 4.3 MB
assets/Readme/wham_gen_5.gif ADDED

Git LFS Details

  • SHA256: e7e7675c737bf5cbdfb54dfcc568eeda4c4212dbe5726741205610ab29cfcabb
  • Pointer size: 132 Bytes
  • Size of remote file: 4.24 MB
assets/Readme/wham_gen_6.gif ADDED

Git LFS Details

  • SHA256: e536b1f88a92de4e116a6acd022987778f63ed5a841517758c14a0d7f2a3c2bd
  • Pointer size: 132 Bytes
  • Size of remote file: 4.09 MB
assets/Readme/wham_gen_7.gif ADDED

Git LFS Details

  • SHA256: eb7e6c63eb8c46fc8c824d93406550082b6532ea9473cd021bae72a7d6cbe7db
  • Pointer size: 132 Bytes
  • Size of remote file: 4.13 MB
assets/Readme/wham_gen_8.gif ADDED

Git LFS Details

  • SHA256: 366f3f92310f3cfa55c9f4da719b01c8399c42f7d7bb860c5f7153568e4991d5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.98 MB
assets/Readme/wham_gen_9.gif ADDED

Git LFS Details

  • SHA256: 931713a1d9a9dbdef7b4a1821ef78d490282bf8475e65b39948f8b5f42dc9982
  • Pointer size: 132 Bytes
  • Size of remote file: 4.53 MB
configs/metadata_custom_tag.config ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ %Image::ExifTool::UserDefined = (
2
+ 'Image::ExifTool::XMP::xmp' => {
3
+ 'ProgramName' => { Name => 'ProgramName', Writable => 'string' }
4
+ }
5
+ );
data_summary_card.md ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # Data Summary for microsoft_wham
4
+
5
+
6
+
7
+
8
+
9
+ ## 1. General information
10
+
11
+ **1.0.1 Version of the Summary:** 1.0
12
+
13
+
14
+
15
+ **1.0.2 Last update:** 16-Dec-2025
16
+
17
+
18
+
19
+ ## 1.1 Model Developer Identification
20
+
21
+ **1.1.1 Model Developer name and contact details:** Microsoft Corporation at One Microsoft Way, Redmond, WA 98052. Tel: 425-882-8080
22
+
23
+
24
+
25
+ ## 1.2 Model Identification
26
+
27
+ **1.2.1 Versioned model name(s):** WHAM 1.6B, WHAM 200M
28
+
29
+
30
+
31
+ **1.2.2 Model release date:** 19-Feb-2025
32
+
33
+
34
+
35
+ ## 1.3 Overall training data size and characteristics
36
+
37
+ ### 1.3.1 Size of dataset and characteristics
38
+
39
+ **1.3.1.A Text training data size:** Not applicable. Text data is not part of the training data.
40
+
41
+
42
+
43
+ **1.3.1.B Text training data content:** Not applicable.
44
+
45
+
46
+
47
+ **1.3.1.C Image training data size:** More than 1 billion images
48
+
49
+
50
+
51
+ **1.3.1.D Image training data content:** Sequences of rendered gameplay frames, constructed by replaying telemetry data to render visuals. Visuals were captured at 60 fps and downsampled to 10 Hz. They depict the 3D game environment, characters within the environment, UI elements, and dynamic in-game interactions across seven maps.
52
+
53
+
54
+
55
+ **1.3.1.E Audio training data size:** Not applicable. Audio data is not part of the training data.
56
+
57
+
58
+
59
+ **1.3.1.F Audio training data content:** Not applicable.
60
+
61
+
62
+
63
+ **1.3.1.G Video training data size:** Not applicable. Videos are not part of the training data.
64
+
65
+
66
+
67
+ **1.3.1.H Video training data content:** Not applicable.
68
+
69
+ **1.3.1.I Other training data size:** Controller action sequences paired with each frame, totaling approximately 1.4 billion action entries after downsampling to 10 Hz.
70
+
71
+
72
+
73
+ **1.3.1.J Other training data content:** Controller input logs (telemetry) for gameplay including button states and discretized joystick positions aligned to each video frame to represent player behavior.
74
+
75
+
76
+
77
+ **1.3.2 Latest date of data acquisition/collection for model training:** 31-Oct-2022
78
+
79
+
80
+
81
+ **1.3.3 Is data collection ongoing to update the model with new data collection after deployment?** No
82
+
83
+
84
+
85
+ **1.3.4 Date the training dataset was first used to train the model:** 12-Jan-2022
86
+
87
+
88
+
89
+ **1.3.5 Rationale or purpose of data selection:** The dataset was constructed using gameplay telemetry. This telemetry data was used to render and export visuals (frames) and meta-data, including events corresponding to controller actions. The resulting data was used to learn consistent 3D world dynamics, effects of controller actions, and temporal structure directly from the telemetry. The dataset provides rich, temporally correlated multimodal signals (frames and controller actions) aligning with the model’s goal to support creative ideation of dynamic 3D environments via consistency, diversity, and persistency.
90
+
91
+
92
+
93
+ ## 2. List of data sources
94
+
95
+ ### 2.1 Publicly available datasets
96
+
97
+ **2.1.1 Have you used publicly available datasets to train the model?** No
98
+
99
+
100
+
101
+ ## 2.2 Private non-publicly available datasets obtained from third parties
102
+
103
+ ### 2.2.1 Datasets commercially licensed by rights holders or their representatives
104
+
105
+ **2.2.1.A Have you concluded transactional commercial licensing agreement(s) with rights holder(s) or with their representatives?** Not applicable
106
+
107
+
108
+
109
+ ### 2.2.2 Private datasets obtained from other third-parties
110
+
111
+ **2.2.2.A Have you obtained private datasets from third parties that are not licensed as described in Section 2.2.1, such as data obtained from providers of private databases, or data intermediaries?** No
112
+
113
+
114
+
115
+ ## 2.3 Personal Information
116
+
117
+ **2.3.1 Was personal data used to train the model?** Microsoft follows all relevant laws and regulations pertaining to personal information
118
+
119
+
120
+
121
+ ## 2.4 Synthetic data
122
+
123
+ **2.4.1 Was any synthetic AI-generated data used to train the model?** No
124
+
125
+
126
+
127
+ ## 3. Data processing aspects
128
+
129
+ ### 3.1 Respect of reservation of rights from text and data mining exception or limitation
130
+
131
+ **3.1.1 Does this dataset include any data protected by copyright, trademark, or patent?** Microsoft follows all required regulations and laws for processing data protected by copyright, trademark, or patent
132
+
133
+
134
+
135
+ ## 3.2 Other information
136
+
137
+ **3.2.1 Does the dataset include information about consumer groups without revealing individual consumer identities?** Microsoft follows all required regulations and laws for protecting consumer identities
138
+
139
+
140
+
141
+ **3.2.2 Was the dataset cleaned or modified before model training?** Yes
142
+
143
+
144
+
145
+
models/WHAM_1.6B_v1.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c4997074883aa1a39a5994a7dea91fb62b2382fc039523458827adb777af8e9
3
+ size 20339650059
models/WHAM_200M.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ddb8e03a33f0849a63da030fea3de4994d95e16888993b8ab92faa904f3b31f
3
+ size 3980245067
models/config.json ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ async-timeout==4.0.3
5
+ attrs==23.2.0
6
+ blinker==1.7.0
7
+ certifi==2024.2.2
8
+ charset-normalizer==3.3.2
9
+ click==8.1.7
10
+ cloudpickle==3.0.0
11
+ cmake==3.28.3
12
+ einops==0.6.0
13
+ ffmpegcv==0.3.10
14
+ filelock==3.13.1
15
+ Flask==3.0.2
16
+ frozenlist==1.4.1
17
+ fsspec==2024.2.0
18
+ idna==3.6
19
+ importlib_metadata==7.0.2
20
+ itsdangerous==2.1.2
21
+ Jinja2==3.1.3
22
+ lightning-utilities==0.10.1
23
+ lit==17.0.6
24
+ MarkupSafe==2.1.5
25
+ mpmath==1.3.0
26
+ multidict==6.0.5
27
+ networkx==3.2.1
28
+ numpy==1.25.2
29
+ opencv-python==4.6.0.66
30
+ opencv-python-headless==4.9.0.80
31
+ packaging==23.2
32
+ pillow==10.2.0
33
+ pytorch-lightning==1.9.4
34
+ PyYAML==6.0.1
35
+ requests==2.31.0
36
+ sympy==1.12
37
+ tensordict==0.1.2
38
+ torch==2.0.1+cu118
39
+ torchinfo==1.7.1
40
+ torchmetrics==0.11.4
41
+ torchvision==0.15.2+cu118
42
+ tqdm==4.66.2
43
+ triton==2.0.0
44
+ typing_extensions==4.10.0
45
+ urllib3==2.2.1
46
+ Werkzeug==3.0.1
47
+ yarl==1.9.4
48
+ zipp==3.17.0
run_dreaming.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example script for running dreaming on a dataset.
3
+ The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context.
4
+
5
+ After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players),
6
+ should be identical if model was ideal.
7
+ """
8
+
9
+ import argparse
10
+ from pathlib import Path
11
+ import os
12
+ import subprocess
13
+
14
+ import cv2
15
+ from tensordict import TensorDict
16
+ import torch as th
17
+ from tqdm import tqdm
18
+ import numpy as np
19
+ import ffmpegcv
20
+ from PIL import Image
21
+
22
+ import wham.utils as utils
23
+
24
+
25
# Command-line interface for the dreaming script.
parser = argparse.ArgumentParser(description="Run dreaming.")
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")


parser.add_argument(
    "--protocol",
    type=str,
    default="base",
    choices=["base", "comprehensive"],
    help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.")
parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use as initial context.")
# NOTE: help text fixed — it previously read "Batch size for dreaming." (copy-paste error).
parser.add_argument("--steps_to_dream", type=int, default=10, help="Number of steps to dream beyond the initial context.")

parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")
47
+
48
+
49
def get_context_data(image_context, action_context, action_sequences):
    """Move the context arrays to the GPU and wrap them in a TensorDict.

    Args:
        image_context: numpy array of context frames, CHW per frame
            (assumed (batch, time, 3, H, W) — channel axis is asserted below).
        action_context: numpy array of context actions, (batch, time, action_dim).
        action_sequences: unused; retained for backward compatibility with the
            call site in do_dreaming(), which encodes prompt actions separately.

    Returns:
        TensorDict with keys "images" and "actions_output", batch dims (batch, time).
    """
    # Make sure we have CHW images:
    assert image_context.shape[-3] == 3, "Image context should be CHW"

    image_context = th.from_numpy(image_context).cuda()
    action_data = th.from_numpy(action_context).float().cuda()
    # BUGFIX: the previous version also copied `action_sequences` to the GPU here,
    # then discarded the result — pure wasted VRAM and transfer time. The parameter
    # is kept so existing callers continue to work unchanged.

    return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])
58
+
59
+
60
def add_video_metadata(file_path, metadata_config):
    """Stamp a generated video with the ProgramName origin tag via exiftool.

    Args:
        file_path: path to the video file; modified in place (-overwrite_original).
        metadata_config: path to the exiftool config that defines the custom tag.
    """
    # Construct the exiftool command. subprocess.run() receives an argument list
    # (shell=False), so the value must NOT be shell-quoted: the previous
    # f'-ProgramName=\"{...}\"' form embedded literal quote characters into the
    # written tag value.
    cmd = [
        'exiftool',
        '-config', metadata_config,
        f'-ProgramName={utils.PROGRAM_NAME}',
        '-overwrite_original',
        file_path
    ]

    try:
        # Execute the exiftool command
        subprocess.run(cmd, check=True)
        print("Metadata modified successfully.")
        # Print the new file metadata so the tag can be visually verified in logs
        cmd_output = [
            'exiftool',
            file_path
        ]
        subprocess.run(cmd_output, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error modifying metadata: {e}")
82
+
83
+
84
@th.no_grad()
def do_dreaming(model, image_context, action_context, args, action_sequences=None):
    """
    Autoregressively dream `args.steps_to_dream` frames (and actions) from a context.

    image_context and action_context provide the initial context for the model to dream from.

    If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions
    instead of the actions the model dreams itself (the "base", action-conditioned protocol).

    Returns:
        (dreamed_images, actions_during_dream): numpy arrays with a leading
        (batch, steps_to_dream) shape, decoded back to pixel/action space.
    """
    context_data = get_context_data(image_context, action_context, action_sequences)
    encoded_context_data = model.encode_context(context_data)

    # Pre-encode the prompt actions (if any) once, so each dream step can swap
    # the real action in for the dreamt one.
    encoded_action_sequences = None
    if action_sequences is not None:
        assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
        action_sequences = TensorDict({"actions_output": action_sequences}, batch_size=action_sequences.shape[:2]).cuda()
        encoded_action_sequences = model.encode_context(action_sequences)

    encoded_dreamt_steps = []

    for dream_step in range(args.steps_to_dream):
        encoded_predicted_step, _ = model.predictor.predict_next_step(
            encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
        )

        # Remove first step from context if we are at the max context length:
        # (sliding window over the temporal dimension)
        if encoded_context_data.shape[1] == args.context_length:
            encoded_context_data = encoded_context_data[:, 1:]

        # Add predicted image + action to the context
        append_step = encoded_predicted_step
        if encoded_action_sequences is not None:
            # Replace predicted action with real action
            append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
        encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)

        encoded_dreamt_steps.append(encoded_predicted_step)

    # Decode everything back to image/action space, one dreamt step at a time.
    dreamed_images = []
    actions_during_dream = []
    for seq_i in range(args.steps_to_dream):
        decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
        dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
        actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())

    dreamed_images = np.concatenate(dreamed_images, axis=1)
    actions_during_dream = np.concatenate(actions_during_dream, axis=1)

    return dreamed_images, actions_during_dream
132
+
133
+
134
@th.no_grad()
def encode_decode_images(model, images):
    """
    Round-trip ground-truth images through the model's encoder/decoder.

    The reconstruction is the best the model could possibly reproduce, which
    makes it the fair baseline to compare dreamt frames against.
    """
    batch = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2])
    sequence_length = images.shape[1]
    reconstructed = []
    for step in range(sequence_length):
        tokens = model.encode_context(batch[:, [step]])
        decoded = model.decode_context(tokens)
        reconstructed.append(decoded["images"].cpu().numpy())
    return np.concatenate(reconstructed, axis=1)
146
+
147
+
148
def main(args):
    """Dream continuations for every ground-truth .npz episode under args.data_path.

    For each episode the first `context_length` frames/actions are used as context,
    `steps_to_dream` further steps are dreamt, and both the dream and the
    encoder/decoder reconstruction of the ground truth are saved (npz + watermarked
    mp4) mirroring the input folder structure under args.output.
    """
    total_video_length = args.context_length + args.steps_to_dream

    # Now, load the model:
    model_path = Path(args.model_path)
    assert model_path.is_file(), "Could not find the model!"
    model = utils.load_model_from_checkpoint(model_path).cuda()

    # Glob the dataset to find all the ground truth segments we want to construct a dream for:
    data_path = Path(args.data_path)
    ground_truth_files = list(data_path.rglob("*.npz"))

    if args.max_files is not None:
        # Sort to make sure we always get the same files
        ground_truth_files = sorted(ground_truth_files)
        ground_truth_files = ground_truth_files[: args.max_files]
    num_dreams = len(ground_truth_files)

    output_path = Path(args.output)
    os.makedirs(output_path, exist_ok=True)

    print("=" * 100)
    print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
    print(f"WRITING TO {args.output}")
    print("=" * 100)

    dreams_created = 0
    with tqdm(total=num_dreams, desc="Dreams") as pbar:
        while ground_truth_files:
            # Load up to batch_size episodes:
            batches = min(args.batch_size, len(ground_truth_files))
            batched_image_context = []
            batched_image_sequence = []
            batched_action_context = []
            batched_action_sequence = []
            episode_names = []
            for _ in range(batches):
                episode = ground_truth_files.pop()
                try:
                    data = np.load(episode)
                    images = data["images"]
                    actions = data["actions"]
                except Exception:
                    print(f"Failed to load episode {episode} - skipping.")
                    continue

                if actions.shape[0] < total_video_length:
                    # We want to make sure we have ground_truth comparisons for the entire dream, so we ensure the episode is long enough
                    raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
                # BUGFIX: record the episode name only once the episode has loaded
                # successfully. Previously the name was appended *before* the try
                # block, so a failed load left episode_names longer than the batched
                # data lists and every later result in the batch was written under
                # the wrong filename.
                episode_names.append(episode)
                batched_image_context.append(images[: args.context_length])
                batched_image_sequence.append(images[args.context_length: total_video_length])
                batched_action_context.append(actions[: args.context_length])
                batched_action_sequence.append(actions[args.context_length: total_video_length])

            if not episode_names:
                # Every episode in this batch failed to load; nothing to dream.
                pbar.update(batches)
                continue

            image_context = np.array(batched_image_context)
            image_sequences = np.array(batched_image_sequence)
            action_context = np.array(batched_action_context)
            action_sequences = np.array(batched_action_sequence)

            if args.protocol == "comprehensive":
                # We do not need to pass in the action sequences for comprehensive protocol
                action_sequences = None

            full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)

            dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
            encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)

            pbar.update(batches)
            dreams_created += len(episode_names)

            # Save the dreams:
            # We are aiming to mimic the folder structure of the ground truth dataset, so use the episode names
            # but make them relative to our output folder:
            for i, dream in enumerate(dreamt_images):
                episode = episode_names[i]
                output_file = output_path / episode.relative_to(data_path)
                output_file.parent.mkdir(parents=True, exist_ok=True)
                np.savez(
                    output_file,
                    context_length=args.context_length,
                    steps_to_dream=args.steps_to_dream,
                    raw_context=image_context[i],
                    dreamt_images=dream,
                    all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
                    encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
                )

                video_file = str(output_file.with_suffix(".mp4"))
                writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
                full_sequence = np.concatenate((image_context[i], dream), axis=0)
                for frame in full_sequence:
                    img = frame.transpose(1, 2, 0).astype(np.uint8).copy()
                    # Please DO NOT remove this watermark. This will infringe upon the repo's license agreement
                    (text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
                    x = img.shape[1] - text_width - 10  # 10 pixels from the right edge
                    y = img.shape[0] - 10  # 10 pixels from the bottom edge
                    cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)

                    # Add image metadata
                    pil_image = Image.fromarray(img)
                    pil_image.info['Id'] = 0x0131
                    pil_image.info['Type'] = 2
                    pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
                    pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1

                    # Convert pil_image to a CV2 format for the video writer
                    cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
                    writer.write(cv_image)
                writer.release()
                add_video_metadata(video_file, args.metadata_config)
261
+
262
if __name__ == "__main__":
    # Parse the CLI arguments and run the dreaming pipeline.
    main(parser.parse_args())
run_server.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from dataclasses import dataclass, field
3
+ import json
4
+ import copy
5
+ import multiprocessing as mp
6
+ import uuid
7
+ from datetime import datetime, timedelta
8
+ from collections import defaultdict, deque
9
+ import io
10
+ import zipfile
11
+ import queue
12
+ import time
13
+ import random
14
+ import logging
15
+
16
+ from tensordict import TensorDict
17
+ import cv2
18
+ from flask import Flask, request, make_response, send_file
19
+ from PIL import Image
20
+ import torchvision.transforms as T
21
+ import numpy as np
22
+ import torch as th
23
+
24
+ from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE
25
+
26
# Server-wide logging: INFO level to stderr.
logging.basicConfig(level=logging.INFO)

# Command-line interface for the dreaming server.
parser = argparse.ArgumentParser(description="Simple Dreamer")
parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.")
parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one")
parser.add_argument("--port", type=int, default=5000)

parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.")

parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
parser.add_argument("--image_height", type=int, default=180, help="Height of the image")

parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")

# Name of the JSON file bundled into result downloads.
PREDICTION_JSON_FILENAME = "predictions.json"
# Minimum time between times we check when to delete jobs. We do this when adding new jobs.
JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)

# Cap on how many cancelled job ids we remember.
MAX_CANCELLED_ID_QUEUE_SIZE = 100

# Sampling parameters applied when a request does not override them.
DEFAULT_SAMPLING_SETTINGS = {
    "temperature": 0.9,
    "top_k": None,
    "top_p": 1.0,
    "max_context_length": 10,
}
55
+
56
+
57
def float_or_none(string):
    """Argparse type helper: parse a float, treating the literal "none" (any case) as None."""
    return None if string.lower() == "none" else float(string)
61
+
62
+
63
def be_image_preprocess(image, target_width, target_height):
    """Optionally resize an HWC image to (target_height, target_width), then return it as CHW."""
    # Resize only when both target dimensions are given and the image is not
    # already at the requested size (cv2.resize takes (width, height)).
    if target_width is not None and target_height is not None:
        current_height, current_width = image.shape[0], image.shape[1]
        if (current_width, current_height) != (target_width, target_height):
            image = cv2.resize(image, (target_width, target_height))
    # HWC -> CHW
    return np.transpose(image, (2, 0, 1))
70
+
71
+
72
def action_vector_to_be_action_vector(action):
    """Discretize the stick axes of an action vector, in place.

    Input layout is 16 numbers: 12 buttons in [0, 1] followed by 4 stick axes
    in [-1, 1]. The four stick entries are replaced by their (zero-based)
    POS_BINS_BOUNDARIES bin index so the vector is valid token-model input.
    """
    stick_axes = action[-4:]
    action[-4:] = np.digitize(stick_axes, bins=POS_BINS_BOUNDARIES) - 1
    return action
79
+
80
+
81
def be_action_vector_to_action_vector(action):
    """Map the four discretized stick entries at the end of `action` back to continuous bin-center values (in place)."""
    for idx in (-4, -3, -2, -1):
        action[idx] = POS_BINS_MIDDLE[int(action[idx])]
    return action
86
+
87
+
88
+
89
@dataclass
class DreamJob:
    """A dreaming request (or its continuation) queued for the worker processes."""

    job_id: str
    # Sampling kwargs for predict_step plus a "max_context_length" entry
    # that the worker pops out before calling the model.
    sampling_settings: dict
    # Steps still to dream / already dreamt for this job.
    num_predictions_remaining: int
    num_predictions_done: int
    # (B, T, C, H, W)
    context_images: th.Tensor
    context_actions: th.Tensor
    # Tokens that will replace the context_images if they are provided
    context_tokens: list
    # This will replace the dreamed action if provided.
    # For every step, we remove the first action until exhausted
    # NOTE(review): effectively Optional[th.Tensor]; None once exhausted.
    actions_to_take: th.Tensor = None
103
+
104
+
105
@dataclass
class DreamJobResult:
    """A single dreamed step produced by a worker, keyed back to its originating job."""

    job_id: str
    # 0-based index of this step within the job's dream sequence.
    dream_step_index: int
    # (B, 1, C, H, W)
    dreamt_image: th.Tensor
    dreamt_action: th.Tensor
    dreamt_tokens: th.Tensor
    # Used by the server to expire results that are never polled.
    result_creation_time: datetime = field(default_factory=datetime.now)
114
+
115
+
116
+
117
def setup_and_load_model_be_model(args):
    """Load the WHAM checkpoint named by args.model and enable TF32 matmuls for faster inference."""
    # TF32 trades a little matmul precision for throughput on Ampere+ GPUs.
    th.set_float32_matmul_precision("high")
    th.backends.cuda.matmul.allow_tf32 = True
    return load_model_from_checkpoint(args.model)
122
+
123
+
124
def get_job_batchable_information(job):
    """Return a comparable (context_length, sampling_settings) pair; jobs with equal pairs can be batched together."""
    return (job.context_images.shape[1], job.sampling_settings)
128
+
129
+
130
def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1):
    """Pull up to max_batch_size mutually batchable jobs from job_queue.

    Jobs whose id is in cancelled_ids_set are dropped outright. The first kept
    job fixes the required (context length, sampling settings) signature; the
    first later job that does not match is pushed back and collection stops
    early — the assumption being that the remaining jobs would likely mismatch
    too. Returns a possibly-empty list of jobs.
    """
    batch = []
    batch_signature = None
    while len(batch) < max_batch_size:
        try:
            candidate = job_queue.get(timeout=timeout)
        # Empty queue, or a broken pipe during shutdown: return what we have.
        except (queue.Empty, OSError):
            break
        if candidate.job_id in cancelled_ids_set:
            # Cancelled: discard it completely.
            continue
        signature = get_job_batchable_information(candidate)
        if batch_signature is None:
            batch_signature = signature
        elif signature != batch_signature:
            # Not batchable with the rest; requeue it and stop collecting.
            job_queue.put(candidate)
            break
        batch.append(candidate)
    return batch
157
+
158
+
159
def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
    """IN-PLACE: drain cancelled_ids_queue into the bounded deque, then rebuild the set from it.

    The deque caps how many cancelled ids we remember; the set gives the worker
    loop O(1) membership tests.
    """
    drained_any = False
    while not cancelled_ids_queue.empty():
        try:
            cancelled_ids_deque.append(cancelled_ids_queue.get_nowait())
        except queue.Empty:
            break
        drained_any = True

    if drained_any:
        # Rebuild so ids evicted from the bounded deque also leave the set.
        cancelled_ids_set.clear()
        cancelled_ids_set.update(cancelled_ids_deque)
173
+
174
+
175
def predict_step(context_data, sampling_settings, model, tokens=None):
    """Run one autoregressive prediction step without gradients; sampling_settings are forwarded as kwargs."""
    with th.no_grad():
        return model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings)
179
+
180
+
181
def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
    """Worker-process loop: own a model on `device_to_use` and dream until quit_flag is set.

    Repeatedly pulls batchable jobs from `job_queue`, runs one prediction step
    for the whole batch, pushes a DreamJobResult per job onto `result_queue`,
    and requeues each job (with its context advanced by one step) while it has
    steps remaining. Ids arriving on `cancelled_jobs_queue` cause matching jobs
    to be dropped.
    """
    logger = logging.getLogger(f"dreamer_worker {device_to_use}")
    logger.info("Loading up model...")
    model = setup_and_load_model_be_model(args)
    model = model.to(device_to_use)
    logger.info("Model loaded. Fetching results")

    # Bounded memory of cancelled ids (deque) plus a set for O(1) lookups.
    cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
    cancelled_ids_set = set()

    while not quit_flag.is_set():
        update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
        batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
        if len(batchable_jobs) == 0:
            continue
        # All jobs in the batch share the same sampling settings (that is what
        # made them batchable), so the first job's settings stand in for all.
        sampling_settings = batchable_jobs[0].sampling_settings
        # make better way for passing these arguments around. sampling_settings
        # is passed as kwargs to predicting step, but max_context_length is not part of valid
        # keys there, so we need to pop it out.
        max_context_length = sampling_settings.pop("max_context_length")

        images = [job.context_images[:, :max_context_length] for job in batchable_jobs]
        actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs]
        tokens = [job.context_tokens for job in batchable_jobs]

        # Stack the per-job (B=1) contexts into one batch on the worker's device.
        images = th.concat(images, dim=0).to(device_to_use)
        actions = th.concat(actions, dim=0).to(device_to_use)

        context_data = TensorDict({
            "images": images,
            "actions_output": actions
        }, batch_size=images.shape[:2])

        predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)

        predicted_step = predicted_step.cpu()
        predicted_images = predicted_step["images"]
        predicted_actions = predicted_step["actions_output"]
        predicted_image_tokens = predicted_image_tokens.cpu()

        # Split the batched prediction back into per-job results.
        for job_i, job in enumerate(batchable_jobs):
            image_context = job.context_images
            action_context = job.context_actions
            token_context = job.context_tokens
            # Keep batch dimension
            dreamt_image = predicted_images[job_i].unsqueeze(0)
            dreamt_action = predicted_actions[job_i].unsqueeze(0)
            dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)

            # Replace the dreamed action if provided
            actions_to_take = job.actions_to_take
            if actions_to_take is not None and actions_to_take.shape[1] > 0:
                dreamt_action = actions_to_take[:, 0:1]
                # Remove the action we took
                actions_to_take = actions_to_take[:, 1:]
                if actions_to_take.shape[1] == 0:
                    actions_to_take = None

            result_queue.put(DreamJobResult(
                job_id=job.job_id,
                dream_step_index=job.num_predictions_done,
                dreamt_image=dreamt_image,
                dreamt_action=dreamt_action,
                dreamt_tokens=dreamt_tokens
            ))

            # Add job back in the queue if we have more steps to do
            if job.num_predictions_remaining > 0:
                # Stack the dreamt image and action to the context,
                # sliding the window forward once it reaches max_context_length.
                if image_context.shape[1] >= max_context_length:
                    image_context = image_context[:, 1:]
                    action_context = action_context[:, 1:]
                    token_context = token_context[1:]
                image_context = th.cat([image_context, dreamt_image], dim=1)
                action_context = th.cat([action_context, dreamt_action], dim=1)
                token_context.append(dreamt_tokens[0, 0].tolist())
                # We need to add context length back to sampling settings...
                # add some better way of passing these settings around
                job.sampling_settings["max_context_length"] = max_context_length
                job_queue.put(DreamJob(
                    job_id=job.job_id,
                    sampling_settings=job.sampling_settings,
                    num_predictions_remaining=job.num_predictions_remaining - 1,
                    num_predictions_done=job.num_predictions_done + 1,
                    context_images=image_context,
                    context_actions=action_context,
                    context_tokens=token_context,
                    actions_to_take=actions_to_take
                ))
270
+
271
+
272
class DreamerServer:
    """Flask-facing coordinator: accepts dream jobs, feeds worker queues and collects results."""

    def __init__(self, num_workers, args):
        # Number of dreamer worker processes (one per CUDA device).
        self.num_workers = num_workers
        self.args = args
        self.model = None
        self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
        self.results_queue = mp.Queue()
        self.cancelled_jobs = set()
        # One cancellation queue per worker so every worker hears about every cancelled id.
        self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]
        self._last_result_cleanup = datetime.now()
        self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
        # job_id -> results
        self.local_results = defaultdict(list)
        self.logger = logging.getLogger("DreamerServer")

    def get_details(self):
        """Return a JSON string describing the server configuration (served at `/`)."""
        details = {
            "model_file": self.args.model,
            "max_concurrent_jobs": self.args.max_concurrent_jobs,
            "max_dream_steps_per_job": self.args.max_dream_steps_per_job,
            "max_job_lifespan": self.args.max_job_lifespan,
        }
        return json.dumps(details)

    def _check_if_should_remove_old_jobs(self):
        """Drop results whose newest entry is older than max_job_lifespan (rate-limited)."""
        time_now = datetime.now()
        # Only cleanup every JOB_CLEANUP_CHECK_RATE seconds at most
        if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
            return

        self._last_result_cleanup = time_now
        # First add existing results to the local results
        self._gather_new_results()
        # Check if we should remove old jobs
        job_ids = list(self.local_results.keys())
        for job_id in job_ids:
            results = self.local_results[job_id]
            # If newest result is older than max_job_lifespan, remove the job
            if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
                self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
                del self.local_results[job_id]

    def add_new_job(self, request, request_json):
        """
        Add new dreaming job(s) to the queues.

        request_json must contain "num_steps_to_predict" and "steps" (each step
        with "image_name" referencing a file in request.files, "action" and
        "tokens"); it may contain "num_parallel_predictions", "future_actions"
        and overrides for DEFAULT_SAMPLING_SETTINGS keys.

        Returns: JSON with the new job ids and the current queue size, or a
        400 response when validation fails.
        """
        self._check_if_should_remove_old_jobs()

        # Deep copy so per-request overrides never mutate the module default.
        sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
        if "num_steps_to_predict" not in request_json:
            return make_response("num_steps_to_predict not in request", 400)
        num_steps_to_predict = request_json['num_steps_to_predict']
        if num_steps_to_predict > self.args.max_dream_steps_per_job:
            return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)

        num_parallel_predictions = int(request_json['num_parallel_predictions']) if 'num_parallel_predictions' in request_json else 1

        if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
            return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)

        # Apply any sampling overrides supplied in the request.
        for key in sampling_settings:
            sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key]

        context_images = []
        context_actions = []
        context_tokens = []
        future_actions = []

        # Build the context tensors from the uploaded steps.
        for step in request_json["steps"]:
            image_path = step["image_name"]
            image = np.array(Image.open(request.files[image_path].stream))
            image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
            context_images.append(th.from_numpy(image))

            action = step["action"]
            action = action_vector_to_be_action_vector(action)
            context_actions.append(th.tensor(action))

            tokens = step["tokens"]
            context_tokens.append(tokens)

        # Optional forced actions: these replace the dreamt action step by step.
        future_actions = None
        if "future_actions" in request_json:
            future_actions = []
            for step in request_json["future_actions"]:
                # The rest is the action vector
                action = step["action"]
                action = action_vector_to_be_action_vector(action)
                # Add sequence and batch dimensions
                future_actions.append(th.tensor(action))

        # Add batch dimensions
        context_images = th.stack(context_images).unsqueeze(0)
        context_actions = th.stack(context_actions).unsqueeze(0)
        future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None

        # Queue one independent job per requested parallel prediction.
        list_of_job_ids = []
        for _ in range(num_parallel_predictions):
            job_id = uuid.uuid4().hex
            self.jobs.put(DreamJob(
                job_id=job_id,
                sampling_settings=sampling_settings,
                num_predictions_remaining=num_steps_to_predict,
                num_predictions_done=0,
                context_images=context_images,
                context_actions=context_actions,
                context_tokens=context_tokens,
                actions_to_take=future_actions
            ))
            list_of_job_ids.append(job_id)

        job_queue_size = self.jobs.qsize()
        return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})

    def _gather_new_results(self):
        """Drain the worker result queue into local_results, discarding cancelled jobs."""
        if not self.results_queue.empty():
            for _ in range(self.results_queue.qsize()):
                result = self.results_queue.get()
                if result.job_id in self.cancelled_jobs:
                    # Discard result if job was cancelled
                    continue
                self.local_results[result.job_id].append(result)

    def get_new_results(self, request, request_json):
        """Return (and consume) all pending results for the requested job ids.

        Responds with a zip file containing one PNG per dreamt step plus a
        predictions.json manifest; 204 when there is nothing new, 400 when
        "job_ids" is missing.
        """
        if "job_ids" not in request_json:
            return make_response("job_ids not in request", 400)
        self._gather_new_results()
        job_ids = request_json["job_ids"]
        if not isinstance(job_ids, list):
            job_ids = [job_ids]
        return_results = []
        # Results are handed out once: collected entries are deleted here.
        for job_id in job_ids:
            if job_id in self.local_results:
                return_results.append(self.local_results[job_id])
                del self.local_results[job_id]

        if len(return_results) == 0:
            return make_response("No new responses", 204)

        output_json = []
        output_image_bytes = {}
        for job_results in return_results:
            for result in job_results:
                action = result.dreamt_action.numpy()
                # Remember to remove batch and sequence dimensions
                action = be_action_vector_to_action_vector(action[0, 0].tolist())
                dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
                image_filename = f"{result.job_id}_{result.dream_step_index}.png"
                output_json.append({
                    "job_id": result.job_id,
                    "dream_step_index": result.dream_step_index,
                    "action": action,
                    "tokens": dreamt_tokens,
                    "image_filename": image_filename
                })

                image_bytes = io.BytesIO()
                # this probably is not as smooth as it could be
                T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG")
                output_image_bytes[image_filename] = image_bytes.getvalue()

        # Write a zip file with all the images
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        zip_bytes = io.BytesIO()
        with zipfile.ZipFile(zip_bytes, "w") as z:
            for filename, bytes in output_image_bytes.items():
                z.writestr(filename, bytes)
            # Write the json
            z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))

        zip_bytes.seek(0)

        return send_file(
            zip_bytes,
            mimetype="zip",
            as_attachment=True,
            download_name=f"dreaming_results_{timestamp}.zip"
        )

    def cancel_job(self, request, request_json):
        """Mark a job id as cancelled locally and broadcast the id to every worker."""
        if "job_id" not in request_json:
            return make_response("job_id not in request", 400)
        job_id = request_json["job_id"]
        self.cancelled_jobs.add(job_id)
        # Cancel all jobs in the queue with this id
        for job_queue in self.cancelled_jobs_queues:
            job_queue.put(job_id)
        return make_response("OK", 200)
464
+
465
+
466
def main_run(args):
    """Spawn one dreamer worker per CUDA device and serve the HTTP API until Flask exits.

    Raises:
        RuntimeError: when no CUDA device is available.
    """
    app = Flask(__name__)

    num_workers = th.cuda.device_count()
    if num_workers == 0:
        raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")

    server = DreamerServer(num_workers, args)
    quit_flag = mp.Event()

    # Start the dreamer worker(s), one process pinned to each CUDA device.
    dreamer_worker_processes = []
    for device_i in range(num_workers):
        device = f"cuda:{device_i}"
        dreamer_worker_process = mp.Process(
            target=dreamer_worker,
            args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
        )
        # Daemonize so workers die with the parent even on abrupt exit.
        dreamer_worker_process.daemon = True
        dreamer_worker_process.start()
        dreamer_worker_processes.append(dreamer_worker_process)

    # Add the API endpoints
    @app.route('/')
    def details():
        return server.get_details()

    @app.route('/new_job', methods=['POST'])
    def new_job():
        # Job metadata travels in the "json" form field; images in request.files.
        request_json = json.loads(request.form["json"])
        return server.add_new_job(request, request_json)

    @app.route('/get_job_results', methods=['GET'])
    def get_results():
        # the "Json" is now in regular GET payload/parameters
        request_json = {"job_ids": request.args.getlist("job_ids")}
        return server.get_new_results(request, request_json)

    @app.route('/cancel_job', methods=['GET'])
    def cancel_job():
        request_json = request.args.to_dict()
        return server.cancel_job(request, request_json)

    # Blocks until the Flask server shuts down.
    app.run(host="0.0.0.0", port=args.port, debug=args.debug)

    # Cleanup
    quit_flag.set()
    for dreamer_worker_process in dreamer_worker_processes:
        dreamer_worker_process.join()
515
+
516
+
517
if __name__ == '__main__':
    # Parse the CLI flags declared at module scope, then start server + workers.
    args = parser.parse_args()
    main_run(args)
setup_local.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Set up a local development environment: Python 3.9 venv, Python packages,
# and the system tools (exiftool, ffmpeg) used for media metadata handling.
# NOTE: run with `source setup_local.sh` so the activated venv persists in your shell.
# Tested using Python 3.9

echo "Making a new virtual environment..."
python3.9 -m venv venv

echo "Activating the virtual environment..."
source venv/bin/activate

echo "Upgrading pip..."
pip install --upgrade pip

echo "Installing the required packages..."
pip install -r requirements.txt

echo "Installing the exiftool package for adding file metadata on Linux..."
sudo apt install -y exiftool

echo "Installing ffmpeg..."
# -y keeps the install non-interactive, matching the exiftool step above.
sudo apt install -y ffmpeg

echo "All packages installed successfully!"
wham/models/nn/model_blocks.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ """
4
+ Some Utility blocks for ViT-VQGAN.
5
+
6
+ ConvNeXt blocks are based on:
7
+ Liu, Zhuang, et al. "A convnet for the 2020s."
8
+ Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
9
+ """
10
+
11
+
12
class ConvNextDownsampleBig(nn.Module):
    """Aggressive 4x spatial downsample: per-channel GroupNorm followed by an 8x8 stride-4 convolution."""

    def __init__(self, c_in, c_out):
        super().__init__()
        # One group per channel makes GroupNorm act like a per-channel LayerNorm.
        self.group_norm = nn.GroupNorm(c_in, c_in)
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=8, stride=4, padding=0)

    def forward(self, x):
        normed = self.group_norm(x)
        return self.conv1(normed)
20
+
21
+
22
class ConvNextBlock(nn.Module):
    """ConvNeXt residual block: depthwise 7x7 conv, GroupNorm, then a 1x1-conv MLP with 4x expansion."""

    def __init__(self, channels):
        super().__init__()
        # 'Depthwise' conv: groups == channels, each channel filtered independently.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=7, stride=1, padding=7 // 2, groups=channels)
        self.group_norm = nn.GroupNorm(channels, channels)  # Should be equivalent to layernorm

        # Transformer-style non-linearity
        self.conv2 = nn.Conv2d(channels, channels * 4, kernel_size=1, stride=1, padding=0)
        self.activation = nn.GELU()
        self.conv3 = nn.Conv2d(channels * 4, channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = x
        out = self.group_norm(self.conv1(x))
        out = self.conv3(self.activation(self.conv2(out)))
        return residual + out
40
+
41
+
42
class ConvNextDownsample(nn.Module):
    """2x spatial downsample: per-channel GroupNorm followed by a 3x3 stride-2 convolution."""

    def __init__(self, c_in, c_out):
        super().__init__()
        # One group per channel: per-channel normalization before downsampling.
        self.group_norm = nn.GroupNorm(c_in, c_in)
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        normed = self.group_norm(x)
        return self.conv1(normed)
wham/models/nn/nanoGPT.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/karpathy/nanoGPT/blob/master/model.py - Thanks Andrej Karpathy
2
+
3
+ # MIT License
4
+ # Copyright (c) 2022 Andrej Karpathy
5
+ # 2023 Microsoft Research
6
+
7
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ # of this software and associated documentation files (the "Software"), to deal
9
+ # in the Software without restriction, including without limitation the rights
10
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ # copies of the Software, and to permit persons to whom the Software is
12
+ # furnished to do so, subject to the following conditions:
13
+
14
+ # The above copyright notice and this permission notice shall be included in all
15
+ # copies or substantial portions of the Software.
16
+
17
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
20
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
21
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
22
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
23
+ # OR OTHER DEALINGS IN THE SOFTWARE.
24
+
25
+
26
+ """
27
+ Full definition of a GPT Language Model, all of it in this single file.
28
+ References:
29
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
30
+ https://github.com/openai/gpt-2/blob/master/src/model.py
31
+ 2) huggingface/transformers PyTorch implementation:
32
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
33
+ """
34
+
35
+ from dataclasses import dataclass
36
+ import inspect
37
+ import math
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ from torch.nn import functional as F
42
+
43
# Value used to mask logits so softmax assigns them zero probability.
NEGATIVE_INFINITE_FLOAT = -float("inf")
# Sentinel target value for cross-entropy — presumably marks positions to ignore; confirm against loss usage.
CROSS_ENTROPY_INVALID_CLASS_TARGET = -1
45
+
46
# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
def new_gelu(x):
    """
    Tanh approximation of the GELU activation (identical to Google BERT / OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
    return 0.5 * x * (1.0 + torch.tanh(inner))
53
+
54
+
55
def limit_logits_to_valid_range(logits, valid_token_range):
    """
    MODIFIES logits IN PLACE.
    Mask every position outside valid_token_range with -inf so softmax ignores it.

    Args:
        logits: logits tensor of shape (batch_size, vocab_size)
        valid_token_range: (start, end) inclusive indices that remain untouched;
            everything outside is set to NEGATIVE_INFINITE_FLOAT.
    """
    start, end = valid_token_range
    logits[:, :start] = NEGATIVE_INFINITE_FLOAT
    logits[:, end + 1:] = NEGATIVE_INFINITE_FLOAT
67
+
68
+
69
def default_sample_token(logits, valid_token_range=None, temperature=1.0, deterministic=False, top_k=None, top_p=None, min_tokens_to_keep=1):
    """
    Sample one token index per batch row from logits.

    Args:
        logits: tensor of shape (batch_size, vocab_size).
        valid_token_range: optional inclusive (start, end) span to sample from;
            None samples over the full vocab.
        temperature: softmax temperature, clamped to >= 0.1 to avoid degenerate scaling.
        deterministic: if True, take the argmax (top-k with k=1); user-supplied
            top_k/top_p values are then ignored.
        top_k / top_p: mutually exclusive filtering options for stochastic sampling.
        min_tokens_to_keep: minimum number of tokens surviving top-k/top-p filtering.

    Returns:
        Tensor of sampled indices, shape (batch_size,).
    """
    assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
    # Avoid too low a temp, especially 0
    temperature = max(temperature, 0.1)
    logits = logits / temperature
    if valid_token_range is not None:
        limit_logits_to_valid_range(logits, valid_token_range)
    if deterministic:
        selected_logits = select_logits(logits, top_k=1)
    else:
        selected_logits = select_logits(logits, top_p=top_p, top_k=top_k, min_tokens_to_keep=min_tokens_to_keep)
    probs = F.softmax(selected_logits, dim=-1)
    # More robustly handle errors in the sampling here
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
100
+
101
+
102
def select_logits(logits, top_k=None, top_p=None, min_tokens_to_keep=1):
    """
    Apply top-k or top-p (nucleus) filtering to logits.

    Args:
        logits (torch.Tensor): Logits of shape (..., vocab_size).
        top_k (int, optional): Keep only the k highest-scoring tokens.
        top_p (float, optional): Keep the smallest set of most probable tokens
            whose probabilities sum past top_p.
        min_tokens_to_keep (int, optional): Lower bound on how many tokens survive.

    Returns:
        logits with every filtered-out position set to NEGATIVE_INFINITE_FLOAT.
        At most one of top_k / top_p may be given; with neither, logits are
        returned unchanged.
    """
    assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
    min_tokens_to_keep = min(min_tokens_to_keep, logits.size(-1))

    if top_k is not None:
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
        # Clamp k into [min_tokens_to_keep, vocab_size].
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Everything strictly below the k-th best logit gets masked out.
        kth_best = torch.topk(logits, k)[0][..., -1:]
        logits = torch.where(logits < kth_best, NEGATIVE_INFINITE_FLOAT, logits)

    elif top_p is not None:
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p...
        sorted_mask = cumulative_probs > top_p
        # ...but always keep at least min_tokens_to_keep of the best tokens.
        sorted_mask[..., :min_tokens_to_keep] = False
        # Map the mask from sorted order back to the original token order.
        mask = sorted_mask.scatter(dim=-1, index=sorted_indices, src=sorted_mask)
        logits = torch.where(mask, NEGATIVE_INFINITE_FLOAT, logits)

    # With neither top_k nor top_p, the logits pass through untouched.
    return logits
151
+
152
+
153
class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        # A None bias makes F.layer_norm skip the additive term entirely.
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=1e-5)
163
+
164
class LayerNormMinimal(nn.Module):
    """LayerNorm like above, but without learnable parameters"""

    def __init__(self, ndim, bias):
        super().__init__()
        # `bias` is accepted only for signature parity; there are no parameters to bias.
        self.ndim = (ndim,)

    def forward(self, input):
        normalized = F.layer_norm(input, self.ndim, eps=1e-5)
        return normalized
173
+
174
+
175
+ class CausalSelfAttention(nn.Module):
176
+ def __init__(self, config):
177
+ super().__init__()
178
+ assert config.n_embd % config.n_head == 0
179
+ # key, query, value projections for all heads, but in a batch
180
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
181
+ # output projection
182
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
183
+ # regularization
184
+ self.attn_dropout = nn.Dropout(config.dropout)
185
+ self.resid_dropout = nn.Dropout(config.dropout)
186
+ self.n_head = config.n_head
187
+ self.n_embd = config.n_embd
188
+ self.dropout = config.dropout
189
+ # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
190
+ self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
191
+ # causal mask to ensure that attention is only applied to the left in the input sequence
192
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), persistent=False)
193
+
194
+ self.cached_k = None
195
+ self.cached_v = None
196
+ self.current_cache_size = 0
197
+
198
+ def _manual_causal_attention(self, q, k, v, mask):
199
+ # q, k and v should be of shape (B, nh, T, hs)
200
+ token_len = q.size(-2)
201
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
202
+ att = att.masked_fill(mask[:, :, :token_len, :token_len] == 0, float("-inf"))
203
+ att = F.softmax(att, dim=-1)
204
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
205
+ return y
206
+
207
    def forward(self, x, cache=False):
        """Causal self-attention over `x`.

        Args:
            x: input embeddings of shape (batch, tokens, n_embd).
            cache: if True, perform a single-step cached update (requires a
                prior call to reset_cache(); only token_len == 1 is supported).

        Returns:
            Attention output of shape (batch, tokens, n_embd).
        """
        batch_size, token_len, n_embd = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash and not cache:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        elif cache:
            # manual implemention of attention (as below), but cache arrays we can reuse
            assert token_len == 1, "Cache only works for single step"
            assert self.cached_k is not None, "Must call reset_cache() before using cache"
            assert self.current_cache_size < self.cached_k.size(2), "Trying to generate more steps than provided in reset_cache() `num_steps_to_come`"
            assert self.dropout == 0.0, "Dropout not supported with caching"
            this_step_q = q
            # Write this step's key/value into the next preallocated cache slot.
            self.cached_k[:, :, self.current_cache_size, :] = k[:, :, 0, :]
            self.cached_v[:, :, self.current_cache_size, :] = v[:, :, 0, :]
            # Remove the zero parts (slots not yet filled)
            k = self.cached_k[:, :, : self.current_cache_size + 1, :]
            # compute last row of the attention mask
            # (no explicit causal mask needed: the newest query attends to all cached steps)
            this_step_att_row = (this_step_q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            this_step_att_row = F.softmax(this_step_att_row, dim=-1)
            # We only need output for the current step
            y = this_step_att_row @ self.cached_v[:, :, : self.current_cache_size + 1, :]
            # Update cache
            self.current_cache_size += 1
        else:
            # Fallback path (e.g. no SDPA support or nonzero dropout): manual attention.
            y = self._manual_causal_attention(q, k, v, self.bias)
        y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
245
+
246
    def reset_cache(self, x, num_steps_to_come):
        """
        Reset caches by doing initial pass with x data (returning same output as forward).
        Also set the number of steps to come, which is used to initialize the buffers.

        Args:
            x: context embeddings of shape (batch, tokens, n_embd).
            num_steps_to_come: number of subsequent forward(..., cache=True)
                single-step calls the cache must accommodate.

        Returns:
            Attention output for x, same as forward(x) would produce.
        """
        batch_size, token_len, n_embd = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # Use SDPA instead of a manual implementation
        # y = self._manual_causal_attention(q, k, v, self.bias)
        y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd)
        # output projection
        y = self.resid_dropout(self.c_proj(y))

        # Create full k,q,v for predicting all future steps.
        # Just null-out the last num_steps_to_come-1 steps
        # (pad_size = num_steps_to_come over-allocates one slot; harmless).
        # NOTE(review): the zero padding uses torch's default dtype — confirm the
        # model is run in float32, otherwise k/v and the pad dtypes may differ.
        pad_size = num_steps_to_come
        self.current_cache_size = token_len
        self.cached_k = torch.cat([k, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=k.device)], dim=2)
        self.cached_v = torch.cat([v, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=v.device)], dim=2)

        return y
274
+
275
class SelfAttention(nn.Module):
    """
    Non-causal self-attention layer, the same as CausalSelfAttention but without the causal mask.
    Duplicating the code to keep this separate for clarity.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused projection computing query, key and value in a single matmul.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Projection applied to the concatenated per-head outputs.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        assert self.flash, "SelfAttention only supports flash attention for now."

        # All-ones (block_size x block_size) mask: every position attends to every position.
        self.register_buffer("attn_mask", torch.ones((config.block_size, config.block_size)).bool().unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        bsz, seq_len, emb_dim = x.size()  # batch size, sequence length, embedding dimensionality

        # Project once, then split into per-head query/key/value tensors.
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        head_dim = emb_dim // self.n_head
        q = q.view(bsz, seq_len, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(bsz, seq_len, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(bsz, seq_len, self.n_head, head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Full (non-causal) attention via PyTorch's fused SDPA kernel.
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout, is_causal=False)
        # Re-assemble all head outputs side by side.
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, emb_dim)

        # output projection
        return self.resid_dropout(self.c_proj(out))
316
+
317
class MLP(nn.Module):
    """Transformer feed-forward block with the standard 4x hidden expansion.

    Uses the hand-rolled `new_gelu` activation (defined elsewhere in this
    file), kept for compatibility with previously trained checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = new_gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
330
+
331
class GELU_MLP(nn.Module):
    """MLP block using PyTorch's native tanh-approximate GELU activation."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        hidden = F.gelu(self.c_fc(x), approximate="tanh")
        return self.dropout(self.c_proj(hidden))
345
+
346
+
347
class Block(nn.Module):
    """Pre-norm transformer decoder block: x + attn(ln(x)), then x + mlp(ln(x)).

    Supports attention caching for fast autoregressive generation; see forward().
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
        """
        Args:
            cache: If True, use the cache to predict the next token (assumes model was initialized with `reset_cache`).
            reset_cache_with_num_steps_to_come:
                If not None, reset and prepare the cache for cached prediction of the next
                `reset_cache_with_num_steps_to_come` tokens. Equivalent to calling `reset_cache`
                directly, but exposed through `forward` so torch hook functions (used to grab
                embeddings from this module's output) still fire.

        Caching example:
        ```
        # Initialize model with reset_cache_with_num_steps_to_come=10
        outputs[0] = model(inputs, reset_cache_with_num_steps_to_come=10)
        # Predict next 10 tokens using cache
        for i in range(10):
            outputs[i+1] = model(inputs, cache=True)
        ```
        """
        if reset_cache_with_num_steps_to_come:
            return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
        attended = x + self.attn(self.ln_1(x), cache=cache)
        return attended + self.mlp(self.ln_2(attended))

    def reset_cache(self, x, num_steps_to_come):
        # Same residual pathway as forward(), but primes the attention cache.
        attended = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
        return attended + self.mlp(self.ln_2(attended))
382
+
383
class BlockV2(nn.Module):
    """
    Compared to the Block in the original implementation, this one uses non-parametric LayerNorm and Pytorch's GELU.
    These two changes save significant vram but are incompatible with previously trained models.
    Hence the separate class.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNormMinimal(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNormMinimal(config.n_embd, bias=config.bias)
        self.mlp = GELU_MLP(config)

    def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
        # See Block.forward for the caching contract; behavior is identical.
        if reset_cache_with_num_steps_to_come:
            return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
        attended = x + self.attn(self.ln_1(x), cache=cache)
        return attended + self.mlp(self.ln_2(attended))

    def reset_cache(self, x, num_steps_to_come):
        # Same residual pathway as forward(), but primes the attention cache.
        attended = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
        return attended + self.mlp(self.ln_2(attended))
408
+
409
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block with full (non-causal) self-attention."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))
421
+
422
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model defined below."""

    block_size: int = 1024  # maximum sequence length (size of the learned position-embedding table)
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads per block
    n_embd: int = 768  # embedding / hidden dimensionality
    dropout: float = 0.0
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    version: int = 1  # Version 1 is the original GPT, Version 2 is the one with non-parametric LayerNorm and Pytorch's GELU
432
+
433
+
434
class GPT(nn.Module):
    """Minimal GPT-style autoregressive transformer (nanoGPT lineage).

    Learned token + absolute position embeddings feed a stack of pre-norm
    blocks (Block for version 1, BlockV2 for version 2), followed by a final
    LayerNorm and a weight-tied linear LM head.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.version = config.version

        print(f"[nanoGPT] creating model with version {self.version}")

        if self.version == 1:
            transformer_dict = dict(
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),
            )
        elif self.version == 2:
            transformer_dict = dict(
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([BlockV2(config) for _ in range(config.n_layer)]),
                ln_f=LayerNorm(config.n_embd, bias=config.bias),  # This one is still parametric due to user error
            )
        else:
            # Previously an unknown version fell through to a NameError on
            # `transformer_dict` below; fail fast with a clear message instead.
            raise ValueError(f"Unsupported GPT version: {self.version} (expected 1 or 2)")

        transformer_dict["wte"] = nn.Embedding(config.vocab_size, config.n_embd)
        self.transformer = nn.ModuleDict(transformer_dict)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless.
        self.transformer.wte.weight = self.lm_head.weight  # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        # GPT-2 style init: N(0, 0.02) weights for Linear/Embedding, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _apply_pos_encoding(self, x):
        # Add learned absolute position embeddings for positions [0, T).
        device = x.device
        token_len = x.size(1)
        pos = torch.arange(0, token_len, dtype=torch.long, device=device).unsqueeze(0)
        pos_emb = self.transformer.wpe(pos)
        x = x + pos_emb
        return x

    def original_forward(self, idx, targets=None, loss_mask=None, loss_reduction="mean"):
        """Full forward pass.

        Args:
            idx: (batch, seq_len) token indices.
            targets: optional (batch, seq_len) target token indices. When given,
                logits cover all positions and the cross-entropy loss is computed.
            loss_mask: optional (batch, seq_len) mask; positions where it is 0
                are excluded from the loss.
            loss_reduction: reduction passed to F.cross_entropy ("mean"/"sum"/"none").

        Returns:
            (logits, loss). Without targets, logits cover only the final
            position and loss is None.
        """
        batch_size, seq_len = idx.shape[:2]
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        x = self.transformer.drop(self._apply_pos_encoding(tok_emb))
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            if loss_mask is not None:
                # Feeding target = CROSS_ENTROPY_INVALID_CLASS_TARGET to cross_entropy will ignore the loss
                # for that position. This is useful for padding tokens.
                # Out-of-place masked_fill: the previous in-place assignment
                # silently mutated the caller's `targets` tensor.
                targets = targets.masked_fill(loss_mask == 0, CROSS_ENTROPY_INVALID_CLASS_TARGET)
            loss = F.cross_entropy(
                logits.view(batch_size * seq_len, self.config.vocab_size), targets.reshape(-1), ignore_index=CROSS_ENTROPY_INVALID_CLASS_TARGET, reduction=loss_reduction
            )
            if loss_reduction == "none":
                # Reshape back into batch_size and seq_len
                loss = loss.view(batch_size, seq_len)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def forward(self, x, targets=None, loss_mask=None, loss_reduction="mean"):
        token_len = x.size(1)
        assert token_len <= self.config.block_size, f"Cannot forward sequence of length {token_len}, block size is only {self.config.block_size}"
        return self.original_forward(x, targets, loss_mask, loss_reduction)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, valid_token_range=None, temperature=1.0, top_k=None, raise_cropping=False, deterministic=False):
        """
        Autoregressively sample `max_new_tokens` tokens and append them to `idx`.

        valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
        if None, we'll sample from the full vocab.

        If raise_cropping is True, we'll raise an error if we need to crop the sequence context.
        """
        if valid_token_range is None:
            valid_token_range = (0, self.config.vocab_size - 1)
        assert len(valid_token_range) == 2
        assert valid_token_range[0] < valid_token_range[1]
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx
            if idx.size(1) > self.config.block_size:
                if raise_cropping:
                    raise ValueError("Tried to crop idxs but flag told to raise this")
                else:
                    idx_cond = idx[:, -self.config.block_size :]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature  # logits is B T Vocabsize -> B Vocabsize
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = NEGATIVE_INFINITE_FLOAT

            # Crop out the logits we don't want to sample from.
            # (valid_token_range is always set here; the old None-check was redundant.)
            limit_logits_to_valid_range(logits, valid_token_range)

            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)

            if deterministic:
                # Take max of the results
                idx_next = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    @torch.no_grad()
    def optimized_generate(
        self,
        idx,
        num_new_tokens,
        valid_token_ranges=None,
        temperature=1.0,
        deterministic=False,
        raise_cropping=False,
        top_k=None,
        top_p=None,
        min_tokens_to_keep=1,
    ):
        """
        Generate function but optimized by caching the results in transformer blocks (think this is referred to as "attention caching").
        The higher the num_new_tokens, the more the speedup compared to original generate.

        Caveat: the context length + num_new_tokens must be less than the block size. This means that the first
        generated tokens do not have full context length.

        valid_token_ranges should be None or list of length num_new_tokens, specifying valid range for tokens for every step
        """
        # Properly compile the modules used and/or quantize for improved speed.
        logit_layer = self.lm_head
        embedder_fn = self.transformer.wte

        if valid_token_ranges is None:
            # NOTE(review): this default is [0, vocab_size] while generate() uses the
            # inclusive (0, vocab_size - 1) — confirm how default_sample_token treats
            # the upper bound before unifying the two.
            valid_token_ranges = [[0, self.config.vocab_size] for _ in range(num_new_tokens)]
        assert len(valid_token_ranges) == num_new_tokens, "valid_token_ranges should be list of length num_new_tokens or None"

        _, token_len = idx.size()
        if token_len + num_new_tokens > self.config.block_size:
            raise ValueError("Can't use optimized generation with num_new_tokens + context_length > block_size")
        new_idxs = torch.zeros(idx.size(0), num_new_tokens, dtype=torch.long, device=idx.device)
        # First, we need to cull the sequence to the block size
        # and remove first max_new_tokens so we can reuse same position embeddings
        # and not have to recompute them
        num_original_tokens = idx.size(1)
        original_idx = idx
        # This cropping branch is unreachable given the ValueError check above,
        # but kept for safety should that check ever be relaxed.
        if (num_original_tokens + num_new_tokens) > self.config.block_size:
            if raise_cropping:
                raise ValueError("Tried to crop idxs but flag told to raise this")
            original_idx = idx[:, -self.config.block_size + num_new_tokens :]
        original_pos = torch.arange(0, original_idx.size(1), dtype=torch.long, device=idx.device).unsqueeze(0)
        # Now cache results with the original context
        original_tok_emb = embedder_fn(original_idx)
        original_pos_emb = self.transformer.wpe(original_pos)
        original_x = original_tok_emb + original_pos_emb
        for block in self.transformer.h:
            # Reset the cache for each block, and cache new result
            original_x = block(original_x, reset_cache_with_num_steps_to_come=num_new_tokens)

        # Sample the first token
        original_x = self.transformer.ln_f(original_x)
        last_logit = logit_layer(original_x[:, [-1], :])
        new_idxs[:, 0] = default_sample_token(
            last_logit[:, -1, :], valid_token_ranges[0], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
        )

        # Generate rest of the steps
        for generation_idx in range(1, num_new_tokens):
            # forward the model to get the logits for the index in the sequence
            # This is the position of the latest generated token, not the currently going-to-be-generated token
            latest_token_pos = num_original_tokens + generation_idx - 1
            # We only need to pass in the latest token
            newest_idx = new_idxs[:, generation_idx - 1].unsqueeze(-1)
            newest_tok_emb = embedder_fn(newest_idx)
            newest_pos_emb = self.transformer.wpe(torch.tensor(latest_token_pos, dtype=torch.long, device=idx.device).unsqueeze(0))
            newest_x = newest_tok_emb + newest_pos_emb
            for block in self.transformer.h:
                newest_x = block(newest_x, cache=True)

            newest_x = self.transformer.ln_f(newest_x)
            newest_logit = logit_layer(newest_x)
            # Check this function isn't slowing things down noticeably
            new_idxs[:, generation_idx] = default_sample_token(
                newest_logit[:, -1, :], valid_token_ranges[generation_idx], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
            )

        # Combine indices
        new_idxs = torch.cat((idx, new_idxs), dim=1)
        return new_idxs
wham/models/pl/__init__.py ADDED
File without changes
wham/models/pl/pl_base_model.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+
3
class BaseTrainingModel(pl.LightningModule):
    """Common base class for the project's Lightning training modules.

    Currently only forwards kwargs to pl.LightningModule; shared training
    utilities can be added here without touching every subclass.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
wham/models/vqgan/taming/LICENSE ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ All files under this directory are originally from the taming-transformers repository:
2
+ https://github.com/CompVis/taming-transformers
3
+
4
+ Below is a copy of the original license
5
+ ------------------------------------------------------------------------------
6
+ Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
22
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ OR OTHER DEALINGS IN THE SOFTWARE.
wham/models/vqgan/taming/model.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # All files under this directory are originally from the taming-transformers repository:
2
+ # https://github.com/CompVis/taming-transformers
3
+
4
+ # MIT License
5
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
6
+ # 2023 Microsoft Research
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
22
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ # OR OTHER DEALINGS IN THE SOFTWARE.
25
+
26
+ import math
27
+ import torch
28
+ import torch.nn as nn
29
+ import numpy as np
30
+
31
+
32
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Sinusoidal timestep embeddings, matching Denoising Diffusion Probabilistic
    Models (originally from Fairseq / tensor2tensor; differs slightly from
    Section 3.5 of "Attention Is All You Need").

    Args:
        timesteps: 1-D tensor of timestep indices.
        embedding_dim: width of each embedding vector.

    Returns:
        Tensor of shape (len(timesteps), embedding_dim).
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    log_scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_scale)
    freqs = freqs.to(device=timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd width: zero-pad one column so the output matches embedding_dim.
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
51
+
52
+
53
def nonlinearity(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
56
+
57
+
58
def Normalize(in_channels):
    """32-group GroupNorm used throughout this VQGAN (in_channels must be divisible by 32, a GroupNorm requirement)."""
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
60
+
61
+
62
class Upsample(nn.Module):
    """2x nearest-neighbor spatial upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
74
+
75
+
76
class Downsample(nn.Module):
    """2x spatial downsampling via a stride-2 3x3 conv or average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad right/bottom by one so the stride-2 conv exactly halves H and W.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
92
+
93
+
94
class ResnetBlock(nn.Module):
    """Residual block: two GroupNorm + swish + 3x3-conv stages, optional timestep
    conditioning, and a learned shortcut when channel counts differ."""

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # Skip path needs a projection: either a 3x3 conv or a cheap 1x1
            # ("network-in-network") conv.
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))

        if temb is not None:
            # Broadcast the projected timestep embedding across spatial dims.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.conv_shortcut(skip) if self.use_conv_shortcut else self.nin_shortcut(skip)

        return skip + h
136
+
137
+
138
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over (B, C, H, W) feature maps."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-pixel linear projections for q, k, v and output.
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # compute attention: flatten spatial dims so attention runs over H*W positions
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
        k = k.reshape(b, c, h * w)  # (b, c, hw)
        attn = torch.bmm(q, k) * (int(c) ** (-0.5))  # (b, hw, hw); attn[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        attn = torch.nn.functional.softmax(attn, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        attn = attn.permute(0, 2, 1)  # (b, hw, hw): first hw indexes keys, second queries
        out = torch.bmm(v, attn)  # (b, c, hw); out[b,c,j] = sum_i v[b,c,i] attn[b,i,j]
        out = out.reshape(b, c, h, w)

        return x + self.proj_out(out)
174
+
175
+
176
class Model(nn.Module):
    """UNet-style backbone with optional timestep conditioning (taming-transformers).

    Structure: conv-in, downsampling ResNet stages (attention at the resolutions
    listed in ``attn_resolutions``), an attention-equipped middle, then upsampling
    stages fed by skip connections from the down path, and a final normalized
    3x3 projection to ``out_ch`` channels.

    NOTE(review): the submodule attribute names (``down``, ``mid``, ``up``,
    ``temb``, ...) define the checkpoint state-dict layout — do not rename.
    """

    def __init__(
        self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True
    ):
        super().__init__()
        self.ch = ch
        # Width of the timestep embedding fed to every ResnetBlock.
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two linear layers; the nonlinearity between
            # them is applied in forward().
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # Prepend 1 so stage i reads ch * in_ch_mult[i] and writes ch * ch_mult[i].
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                # One AttnBlock per ResnetBlock at attention resolutions.
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve spatial resolution between stages (not after the last).
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    # Last block of the stage concatenates the skip from the
                    # stage *input* resolution, which has in_ch_mult width.
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None):
        """Run the UNet. ``t`` is required when ``use_timestep`` is True.

        Returns a tensor with ``out_ch`` channels at the input's spatial size.
        """
        # assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: hs accumulates one activation per block for the skips
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: pop skips in reverse order of their creation
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
299
+
300
+
301
class Encoder(nn.Module):
    """Image-to-latent encoder (downsampling half of the autoencoder).

    Conv-in, downsampling ResNet stages with optional attention, an
    attention-equipped middle, then a normalized 3x3 projection to
    ``z_channels`` (or ``2 * z_channels`` when ``double_z`` is True —
    presumably mean and log-variance of a latent posterior; TODO confirm
    against the caller).

    NOTE(review): submodule attribute names are part of the checkpoint layout.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        **ignore_kwargs
    ):
        super().__init__()
        self.ch = ch
        # No timestep conditioning in the encoder.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # Prepend 1 so stage i reads ch * in_ch_mult[i] and writes ch * ch_mult[i].
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve spatial resolution between stages (not after the last).
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Encode an image batch to the latent feature map."""
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding (unused here; ResnetBlock accepts None)
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
388
+
389
+
390
class Decoder(nn.Module):
    """Latent-to-image decoder (upsampling half of the autoencoder).

    Conv-in from ``z_channels``, attention-equipped middle, upsampling ResNet
    stages (no skip connections — unlike ``Model``), then a normalized 3x3
    projection to ``out_ch``. When ``give_pre_end`` is True, ``forward``
    returns the feature map *before* the final norm/nonlinearity/conv.

    NOTE(review): submodule attribute names are part of the checkpoint layout.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        **ignorekwargs
    ):
        super().__init__()
        self.ch = ch
        # No timestep conditioning in the decoder.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # Expected latent shape (batch dim is a placeholder 1).
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Double spatial resolution between stages (not after level 0).
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent ``z`` back to image space (or pre-end features)."""
        # assert z.shape[1:] == self.z_shape[1:]
        # Remember the last latent shape seen (debug/introspection aid).
        self.last_z_shape = z.shape

        # timestep embedding (unused here; ResnetBlock accepts None)
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
486
+
487
+
488
class VUNet(nn.Module):
    """UNet conditioned on a latent ``z`` injected at the bottleneck.

    Like ``Model``, but ``forward`` takes a conditioning image ``x`` (with
    ``c_channels`` channels) plus a latent ``z`` (with ``z_channels``), which is
    projected by a 1x1 conv and concatenated to the bottleneck features.

    BUGFIX(review): the original ``forward(self, x, z)`` referenced an undefined
    name ``t`` whenever ``use_timestep=True`` (``assert t is not None`` raised a
    NameError). ``t`` is now an explicit optional parameter, so existing
    ``use_timestep=False`` callers are unaffected and timestep-conditioned use
    actually works.

    NOTE(review): submodule attribute names are part of the checkpoint layout.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        c_channels,
        resolution,
        z_channels,
        use_timestep=False,
        **ignore_kwargs
    ):
        super().__init__()
        self.ch = ch
        # Width of the timestep embedding fed to every ResnetBlock.
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding: two linear layers; the nonlinearity between
            # them is applied in forward().
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling (operates on the conditioning input with c_channels)
        self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        # Prepend 1 so stage i reads ch * in_ch_mult[i] and writes ch * ch_mult[i].
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # 1x1 projection of the latent to the bottleneck width
        self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0)
        # middle: block_1 consumes the concatenation [features, projected z]
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=2 * block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    # Last block of the stage concatenates the skip from the
                    # stage *input* resolution, which has in_ch_mult width.
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, z, t=None):
        """Run the conditional UNet.

        Args:
            x: conditioning input with ``c_channels`` channels.
            z: latent with ``z_channels`` channels at the bottleneck resolution.
            t: timestep tensor; required iff ``use_timestep`` is True.

        Raises:
            ValueError: if ``use_timestep`` is True and ``t`` is None.
        """
        # assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            if t is None:
                raise ValueError("VUNet was built with use_timestep=True; forward() requires a timestep tensor `t`")
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling: hs accumulates one activation per block for the skips
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle: inject the latent at the bottleneck
        h = hs[-1]
        z = self.z_in(z)
        h = torch.cat((h, z), dim=1)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling: pop skips in reverse order of their creation
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
627
+
628
+
629
class SimpleDecoder(nn.Module):
    """Minimal decoder: 1x1 conv, three ResNet blocks (widen then narrow),
    1x1 conv back down, one 2x upsample, and a normalized 3x3 projection
    to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        wide = 2 * in_channels
        widest = 4 * in_channels
        # NOTE: the order of entries in ``self.model`` is fixed by the
        # checkpoint state-dict layout.
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(in_channels=in_channels, out_channels=wide, temb_channels=0, dropout=0.0),
                ResnetBlock(in_channels=wide, out_channels=widest, temb_channels=0, dropout=0.0),
                ResnetBlock(in_channels=widest, out_channels=wide, temb_channels=0, dropout=0.0),
                nn.Conv2d(wide, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # ResnetBlock layers (entries 1-3) additionally take a timestep
        # embedding, which this decoder does not use.
        for layer in self.model:
            if isinstance(layer, ResnetBlock):
                x = layer(x, None)
            else:
                x = layer(x)

        out = self.norm_out(x)
        out = nonlinearity(out)
        return self.conv_out(out)
657
+
658
+
659
class UpsampleDecoder(nn.Module):
    """Decoder built purely from ResNet stages interleaved with 2x upsamples
    (no attention, no skip connections)."""

    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0  # no timestep conditioning
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        width = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for level in range(self.num_resolutions):
            stage = []
            stage_width = ch * ch_mult[level]
            for _ in range(self.num_res_blocks + 1):
                stage.append(ResnetBlock(in_channels=width, out_channels=stage_width, temb_channels=self.temb_ch, dropout=dropout))
                width = stage_width
            self.res_blocks.append(nn.ModuleList(stage))
            # One upsample between consecutive stages (none after the last).
            if level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(width, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(width)
        self.conv_out = torch.nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Run each ResNet stage, then upsample between stages.
        h = x
        for level, stage in enumerate(self.res_blocks):
            for block in stage:
                h = block(h, None)
            if level != self.num_resolutions - 1:
                h = self.upsample_blocks[level](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        return self.conv_out(h)
wham/models/vqgan/taming/quantize.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # All files under this directory are originally from the taming-transformers repository:
2
+ # https://github.com/CompVis/taming-transformers
3
+
4
+ # MIT License
5
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
6
+ # 2023 Microsoft Research
7
+
8
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ # of this software and associated documentation files (the "Software"), to deal
10
+ # in the Software without restriction, including without limitation the rights
11
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ # copies of the Software, and to permit persons to whom the Software is
13
+ # furnished to do so, subject to the following conditions:
14
+
15
+ # The above copyright notice and this permission notice shall be included in all
16
+ # copies or substantial portions of the Software.
17
+
18
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
22
+ # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23
+ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24
+ # OR OTHER DEALINGS IN THE SOFTWARE.
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import numpy as np
29
+ from einops import rearrange
30
+
31
+
32
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
        """
        Args:
            n_e: number of codebook entries.
            e_dim: dimensionality of each codebook vector.
            beta: commitment-loss weight (see ``legacy`` note above).
            remap: optional path to a .npy file of "used" indices; when set,
                forward() maps raw codebook indices into this reduced set.
            unknown_index: "random", "extra", or an int — what an index that is
                not in the remap table becomes.
            sane_index_shape: if True, forward() returns indices reshaped to
                (batch, height, width) instead of a flat column.
            legacy: keep the original (buggy) placement of the beta term.
        """
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            # Buffer of the retained ("used") codebook indices loaded from disk.
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # Reserve one extra slot to stand in for unknown indices.
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Map raw codebook indices to positions in the reduced ``used`` table.

        Indices absent from the table become ``unknown_index`` (or a random
        index in [0, re_embed) when ``unknown_index == "random"``).
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        # For each index, find its position in `used` (argmax of exact matches).
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1  # no match anywhere in `used`
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Inverse of ``remap_to_used``: reduced positions back to raw indices."""
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize ``z`` (b, c, h, w) to its nearest codebook entries.

        Returns:
            (z_q, loss, (perplexity, min_encodings, min_encoding_indices));
            perplexity and min_encodings are always None here.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, "b c h w -> b h w c").contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding (beta placement differs; see class note)
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """Look up codebook vectors for ``indices``; ``shape`` is
        (batch, height, width, channel) or None to keep the flat layout."""
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q