Commit ·
aa16c0a
0
Parent(s):
Duplicate from microsoft/wham
Browse filesCo-authored-by: Katja Hofmann <katja-hofmann@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +38 -0
- CODE_OF_CONDUCT.md +10 -0
- CONTRIBUTING.md +14 -0
- LICENSE.md +96 -0
- README.md +243 -0
- SECURITY.md +37 -0
- WHAM_Demonstrator.zip +3 -0
- assets/Demonstrator/Fig_01.png +3 -0
- assets/Demonstrator/Fig_02.png +3 -0
- assets/Demonstrator/Fig_03.png +3 -0
- assets/Demonstrator/Fig_04.png +3 -0
- assets/Demonstrator/Fig_05.png +3 -0
- assets/Demonstrator/Fig_06.png +3 -0
- assets/Demonstrator/Fig_07.png +3 -0
- assets/Demonstrator/Fig_08.png +3 -0
- assets/Demonstrator/Fig_09.png +3 -0
- assets/Demonstrator/Fig_10.png +3 -0
- assets/Demonstrator/Fig_11.png +3 -0
- assets/Demonstrator/Fig_12.png +3 -0
- assets/Demonstrator/Fig_13.png +3 -0
- assets/Demonstrator/Fig_14.png +3 -0
- assets/Demonstrator/Fig_15.png +3 -0
- assets/Demonstrator/Fig_16.png +3 -0
- assets/Demonstrator/Fig_17.png +3 -0
- assets/Readme/model_capabilities.gif +3 -0
- assets/Readme/wham_gen_1.gif +3 -0
- assets/Readme/wham_gen_2.gif +3 -0
- assets/Readme/wham_gen_3.gif +3 -0
- assets/Readme/wham_gen_4.gif +3 -0
- assets/Readme/wham_gen_5.gif +3 -0
- assets/Readme/wham_gen_6.gif +3 -0
- assets/Readme/wham_gen_7.gif +3 -0
- assets/Readme/wham_gen_8.gif +3 -0
- assets/Readme/wham_gen_9.gif +3 -0
- configs/metadata_custom_tag.config +5 -0
- data_summary_card.md +145 -0
- models/WHAM_1.6B_v1.ckpt +3 -0
- models/WHAM_200M.ckpt +3 -0
- models/config.json +0 -0
- requirements.txt +48 -0
- run_dreaming.py +264 -0
- run_server.py +519 -0
- setup_local.sh +21 -0
- wham/models/nn/model_blocks.py +49 -0
- wham/models/nn/nanoGPT.py +665 -0
- wham/models/pl/__init__.py +0 -0
- wham/models/pl/pl_base_model.py +5 -0
- wham/models/vqgan/taming/LICENSE +24 -0
- wham/models/vqgan/taming/model.py +696 -0
- wham/models/vqgan/taming/quantize.py +146 -0
.gitattributes
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
fonts/arial.ttf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Microsoft Open Source Code of Conduct
|
| 2 |
+
|
| 3 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 4 |
+
|
| 5 |
+
Resources:
|
| 6 |
+
|
| 7 |
+
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
| 8 |
+
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
| 9 |
+
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
| 10 |
+
- Employees can reach out at [aka.ms/opensource/moderation-support](https://aka.ms/opensource/moderation-support)
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing
|
| 2 |
+
|
| 3 |
+
This project welcomes contributions and suggestions. Most contributions require you to
|
| 4 |
+
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
|
| 5 |
+
and actually do, grant us the rights to use your contribution. For details, visit
|
| 6 |
+
https://cla.microsoft.com.
|
| 7 |
+
|
| 8 |
+
When you submit a pull request, a CLA-bot will automatically determine whether you need
|
| 9 |
+
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
|
| 10 |
+
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
|
| 11 |
+
|
| 12 |
+
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
| 13 |
+
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
| 14 |
+
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
LICENSE.md
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MICROSOFT RESEARCH LICENSE TERMS
|
| 2 |
+
|
| 3 |
+
**IF YOU LIVE IN THE UNITED STATES, PLEASE READ THE “BINDING ARBITRATION AND CLASS ACTION WAIVER” SECTION BELOW. IT AFFECTS HOW DISPUTES ARE RESOLVED.**
|
| 4 |
+
|
| 5 |
+
These license terms are an agreement between you and Microsoft Corporation (or one of its affiliates). They apply to the source code, object code, machine learning models, or data (collectively “Materials”) that accompany this license. IF YOU COMPLY WITH THESE LICENSE TERMS, YOU HAVE THE RIGHTS BELOW. BY USING THE MATERIALS, YOU ACCEPT THESE TERMS.
|
| 6 |
+
|
| 7 |
+
## 1) INSTALLATION AND USE RIGHTS TO THE MATERIALS.
|
| 8 |
+
|
| 9 |
+
Subject to the terms of this agreement, you have the below rights, if applicable, to use the Materials solely for non-commercial, non-revenue generating, research purposes:
|
| 10 |
+
|
| 11 |
+
a) **Source Code.** If source code is included, you may use and modify the source code, but you may not distribute the source code.
|
| 12 |
+
|
| 13 |
+
b) **Object Code.** If object code is included, you may use the object code, but you may not distribute the object code.
|
| 14 |
+
|
| 15 |
+
c) **Models.** If machine learning model(s) are included, you may use the model(s), but you may not distribute the models.
|
| 16 |
+
|
| 17 |
+
d) **Data.** If data is included, you may use the data, but your use must be consistent with the consent under which the data was provided and/or gathered and you may not modify or distribute the data.
|
| 18 |
+
|
| 19 |
+
## 2) SCOPE OF LICENSE.
|
| 20 |
+
|
| 21 |
+
The Materials are licensed, not sold. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you will not (and have no right to):
|
| 22 |
+
|
| 23 |
+
a) Work around any technical limitations in the Materials that only allow you to use it in certain ways;
|
| 24 |
+
|
| 25 |
+
b) Reverse engineer, decompile or disassemble the Materials;
|
| 26 |
+
|
| 27 |
+
c) Remove, minimize, block, or modify any notices of Microsoft or its suppliers in the Materials;
|
| 28 |
+
|
| 29 |
+
d) Use the Materials in any way that is against the law or to create or propagate malware; or
|
| 30 |
+
|
| 31 |
+
e) Share, publish, distribute or lend the Materials, provide the Materials as a stand-alone hosted solution for others to use, or transfer the Materials or this agreement to any third party.
|
| 32 |
+
|
| 33 |
+
## 3) PERSONAL DATA.
|
| 34 |
+
|
| 35 |
+
If the data (set forth in Section 1(d) above) includes or is found to include any data that enables any ability to identify an individual ("Personal Data"), you will not use such Personal Data for any purpose other than was authorized and consented to by the data subject/research participant. You will not use Personal Data to contact any person. You will keep Personal Data in strict confidence. You will not share any Personal Data that is collected or in your possession with any third party for any reason and as required under the original consent agreement. Further, you will destroy the Personal Data and any backup or copies, **immediately upon the completion of your research.**
|
| 36 |
+
|
| 37 |
+
## 4) LICENSE TO MICROSOFT.
|
| 38 |
+
|
| 39 |
+
Notwithstanding the limitations in Section 1, you may distribute your modifications back to Microsoft, and if you do provide Microsoft with modifications of the Materials, you hereby grant Microsoft, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer such modifications and derivatives for any purpose.
|
| 40 |
+
|
| 41 |
+
## 5) PUBLICATION.
|
| 42 |
+
|
| 43 |
+
You may publish (or present papers or articles) on your results from using the Materials provided that no material or substantial portion of the Materials is included in any such publication or presentation.
|
| 44 |
+
|
| 45 |
+
## 6) FEEDBACK.
|
| 46 |
+
|
| 47 |
+
Any feedback about the Materials provided by you to us is voluntarily given, and Microsoft shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential. **Additional** Such feedback shall be considered a contribution and licensed to Microsoft under the terms of Section 4 above.
|
| 48 |
+
|
| 49 |
+
## 7) COMPLIANCE WITH TRADE LAWS.
|
| 50 |
+
|
| 51 |
+
You acknowledge that the Materials may be subject to applicable trade laws in one or more countries. You will comply with all relevant laws and regulations applicable to the import or export of the Materials, including but not limited to, trade laws such as the U.S. Export Administration Regulations or other end-user, end use, and destination restrictions by the U.S. and other governments, as well as sanctions regulations administered by the U.S. Office of Foreign Assets Control. Microsoft may suspend or terminate the agreement immediately to the extent that Microsoft reasonably concludes that continued performance would violate trade laws or put it at risk of becoming subject to sanctions or penalties under trade laws. For additional information, see www.microsoft.com/exporting.
|
| 52 |
+
|
| 53 |
+
## 8) SUPPORT SERVICES.
|
| 54 |
+
|
| 55 |
+
Microsoft is not obligated under this agreement to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
| 56 |
+
|
| 57 |
+
## 9) BINDING ARBITRATION AND CLASS ACTION WAIVER.
|
| 58 |
+
|
| 59 |
+
**This Section applies if you live in (or, if a business, your principal place of business is in) the United States.** If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to **binding individual arbitration before the American Arbitration Association** under the Federal Arbitration Act ("FAA"), and not to **sue in court in front of a judge or jury.** Instead, a neutral arbitrator will decide. **Class action lawsuits, class-wide arbitrations, private attorney-general actions,** and any other proceeding where someone acts in a representative capacity **are not allowed;** nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at aka.ms/arb-agreement-1. You and Microsoft agree to these terms.
|
| 60 |
+
|
| 61 |
+
## 10) ENTIRE AGREEMENT.
|
| 62 |
+
|
| 63 |
+
This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the Materials.
|
| 64 |
+
|
| 65 |
+
## 11) APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES.
|
| 66 |
+
|
| 67 |
+
If you acquired the Materials in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the Materials in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration).
|
| 68 |
+
|
| 69 |
+
## 12) CONSUMER RIGHTS; REGIONAL VARIATIONS.
|
| 70 |
+
|
| 71 |
+
This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state, province, or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the Materials. This agreement does not change those other rights if the laws of your state, province, or country do not permit it to do so. For example, if you acquired the Materials in one of the below regions, or mandatory country law applies, then the following provisions apply to you:
|
| 72 |
+
|
| 73 |
+
a) **Australia.** You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights.
|
| 74 |
+
|
| 75 |
+
b) **Canada.** If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the Materials will resume checking for and installing updates), or uninstalling the Materials. The product documentation, if any, may also specify how to turn off updates for your specific device or software.
|
| 76 |
+
|
| 77 |
+
c) **Germany and Austria.**
|
| 78 |
+
i. **Warranty.** The properly licensed software will perform substantially as described in any Microsoft materials that accompany the Materials. However, Microsoft gives no contractual guarantee in relation to the licensed software.
|
| 79 |
+
ii. **Limitation of Liability.** In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law.
|
| 80 |
+
|
| 81 |
+
Subject to the foregoing clause (ii), Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence.
|
| 82 |
+
|
| 83 |
+
## 13) DISCLAIMER OF WARRANTY.
|
| 84 |
+
|
| 85 |
+
THE MATERIALS ARE LICENSED "AS IS." YOU BEAR THE RISK OF USING THEM. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.
|
| 86 |
+
|
| 87 |
+
## 14) LIMITATION ON AND EXCLUSION OF DAMAGES.
|
| 88 |
+
|
| 89 |
+
IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT OR INCIDENTAL DAMAGES.
|
| 90 |
+
|
| 91 |
+
This limitation applies to:
|
| 92 |
+
- (a) anything related to the Materials, services, content (including code) on third party Internet sites, or third party applications; and
|
| 93 |
+
- (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law.
|
| 94 |
+
|
| 95 |
+
It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages.
|
| 96 |
+
|
README.md
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
datasets:
|
| 3 |
+
- microsoft/bleeding-edge-gameplay-sample
|
| 4 |
+
tags:
|
| 5 |
+
- wham
|
| 6 |
+
- microsoft
|
| 7 |
+
language:
|
| 8 |
+
- en
|
| 9 |
+
license_link: LICENSE.md
|
| 10 |
+
---
|
| 11 |
+
# World and Human Action Model (WHAM)
|
| 12 |
+
📄 [Paper](https://www.nature.com/articles/s41586-025-08600-3) • 🔗 [Sample Data](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample)
|
| 13 |
+
<div align="center">
|
| 14 |
+
Anssi Kanervisto, Dave Bignell, Linda Yilin Wen, Martin Grayson, Raluca Georgescu, Sergio Valcarcel Macua, Shan Zheng Tan, Tabish Rashid, Tim Pearce, Yuhan Cao,
|
| 15 |
+
Abdelhak Lemkhenter, Chentian Jiang, Gavin Costello, Gunshi Gupta, Marko Tot, Shu Ishida, Tarun Gupta, Udit Arora,
|
| 16 |
+
Ryen W. White, Sam Devlin, Cecily Morrison, Katja Hofmann
|
| 17 |
+
</div><br>
|
| 18 |
+
<div align='center'>
|
| 19 |
+
Dynamic Generated Gameplay Sequence using WHAM. Showcasing diverse characters and actions across intricate maps.
|
| 20 |
+
<div style="display: flex; flex-wrap: wrap;">
|
| 21 |
+
<img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_1.gif">
|
| 22 |
+
<img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_2.gif">
|
| 23 |
+
<img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_3.gif">
|
| 24 |
+
<img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_4.gif">
|
| 25 |
+
<img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_5.gif">
|
| 26 |
+
<img style="width: calc(33.33%); margin-bottom: -35px;" src="assets/Readme/wham_gen_6.gif">
|
| 27 |
+
<img style="width: calc(33.33%);" src="assets/Readme/wham_gen_7.gif">
|
| 28 |
+
<img style="width: calc(33.33%);" src="assets/Readme/wham_gen_8.gif">
|
| 29 |
+
<img style="width: calc(33.33%);" src="assets/Readme/wham_gen_9.gif">
|
| 30 |
+
</div>
|
| 31 |
+
</div><br>
|
| 32 |
+
<div align='center'>
|
| 33 |
+
WHAM is capable of generating consistent, diverse, and persistent outputs, enabling various use cases for creative iteration.
|
| 34 |
+
<img style="width: 100%;" src="assets/Readme/model_capabilities.gif">
|
| 35 |
+
</div>
|
| 36 |
+
|
| 37 |
+
Muse is powered by a World and Human Action Model (WHAM), which is a generative model of gameplay (visuals and/or controller actions) trained on gameplay data of Ninja Theory’s Xbox game Bleeding Edge. Model development was informed by requirements of game creatives that we identified through a user study. Our goal is to explore the capabilities that generative AI models need to support human creative exploration. WHAM is developed by the [Game Intelligence group](https://www.microsoft.com/en-us/research/group/game-intelligence/) at [Microsoft Research](https://www.microsoft.com/en-us/research/), in collaboration with [TaiX](https://www.microsoft.com/en-us/research/project/taix/) and [Ninja Theory](https://ninjatheory.com/).
|
| 38 |
+
|
| 39 |
+
# Model Card
|
| 40 |
+
|
| 41 |
+
WHAM is an autoregressive model that has been trained to predict (tokenized) game visuals and controller actions given a prompt. Prompts here can be either visual (one or more initial game visuals) and / or controller actions. This allows the user to run the model in (a) world modelling mode (generate visuals given controller actions), (b) behavior policy (generate controller actions given past visuals), or (c) generate both visuals and behavior.
|
| 42 |
+
|
| 43 |
+
WHAM consists of two components, an encoder-decoder [VQ-GAN](https://compvis.github.io/taming-transformers/) trained to encode game visuals to a discrete representation, and a transformer backbone trained to perform next-token prediction. We train both components from scratch. The resulting model can generate consistent game sequences, and shows evidence of capturing the 3D structure of the game environment, the effects of controller actions, and the temporal structure of the game (up to the model’s context length).
|
| 44 |
+
|
| 45 |
+
WHAM was trained on human gameplay data to predict game visuals and players’ controller actions. We worked with the game studio Ninja Theory and their game [Bleeding Edge](https://www.bleedingedge.com/) – a 3D, 4v4 multiplayer video game. From the resulting data we extracted one year’s worth of anonymized gameplay from 27,990 players, capturing a wide range of behaviors and interactions. A sample of this data is provided [here](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample)
|
| 46 |
+
|
| 47 |
+
## Model Details
|
| 48 |
+
|
| 49 |
+
### Trained Models
|
| 50 |
+
|
| 51 |
+
In this release we provide the weights of two WHAM instances: 200M WHAM and 1.6B WHAM. Both have been trained from scratch on the same data set. 1.6B WHAM is evaluated in [our paper](https://www.nature.com/articles/s41586-025-08600-3). We additionally provide 200M WHAM as a more lightweight option for faster explorations.
|
| 52 |
+
- [WHAM with 200M parameters](models/WHAM_200M.ckpt), model size: 3.7GB
|
| 53 |
+
- [WHAM with 1.6B parameters](models/WHAM_1.6B_v1.ckpt), model size: 18.9GB
|
| 54 |
+
|
| 55 |
+
## Usage
|
| 56 |
+
|
| 57 |
+
### System Requirements
|
| 58 |
+
|
| 59 |
+
The steps below have been tested on the following setup:
|
| 60 |
+
- Linux workstation with Ubuntu 20.04.4 LTS
|
| 61 |
+
- Windows 11 workstation running WSL2 with Ubuntu 20.04.6 LTS
|
| 62 |
+
|
| 63 |
+
The current setup assumes that a CUDA-supported GPU is available for model inference. This has been tested on systems with `NVIDIA RTX A6000` and `NVIDIA A100` respectively. In addition, approximately `15GB` of free hard disk space is required for dowmloading the models.
|
| 64 |
+
|
| 65 |
+
The steps under Installation assume a python 3.9 installation that can be
|
| 66 |
+
called using the command `python3.9` and the venv package for creating virtual environments. If either of these is not present, you can install this version of python under Ubuntu using:
|
| 67 |
+
|
| 68 |
+
```bash
|
| 69 |
+
sudo apt install python3.9
|
| 70 |
+
sudo apt install python3.9-venv
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
If you are using the WHAM Demonstrator, please ensure that you have the required [.NET Core Runtime](https://dotnet.microsoft.com/en-us/download/dotnet/7.0). If this is not yet installed, an error message will pop up from which you can follow a link to download and install this package.
|
| 74 |
+
|
| 75 |
+
### Installation
|
| 76 |
+
|
| 77 |
+
1. Clone this repository. We recommend starting without the large model files, using `GIT_LFS_SKIP_SMUDGE=1 git clone git@hf.co:microsoft/WHAM`
|
| 78 |
+
2. `cd WHAM`
|
| 79 |
+
3. `./setup_local.sh`
|
| 80 |
+
|
| 81 |
+
This will set up a `python3.9` virtual environment and install the required packages (this includes packages required for the model server). The typical install time should be approximately 5 minutes.
|
| 82 |
+
|
| 83 |
+
4. Run `source venv/bin/activate` whenever you want to run model inference or the model server
|
| 84 |
+
|
| 85 |
+
5. Download model from this HuggingFace repository (See note below):
|
| 86 |
+
1. Go to Files and versions and navigate to the `models` folder.
|
| 87 |
+
2. Download the model checkpoint. The instructions below assume that the model checkpoints have been downloaded to your local `models` folder.
|
| 88 |
+
|
| 89 |
+
**Note:** On Linux systems, you can use `git clone` to clone the enire repository, including large files. Due to a limitation of `git lfs` on Windows, only files up to `4GB` are supported and we recommend downloading the model files manually from the `models` folder.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
### Local Model Inference
|
| 93 |
+
|
| 94 |
+
This section assumes that you have followed the installation steps above.
|
| 95 |
+
|
| 96 |
+
(Optional) Download [sample data](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample). For the local inference examples below, we recommend that you start with the `tiny-sample` set of only 4 trajectories for your initial exploration.
|
| 97 |
+
|
| 98 |
+
You can now run model inference to generate gameplay sequences as follows:
|
| 99 |
+
|
| 100 |
+
```python
|
| 101 |
+
python run_dreaming.py --model_path <path_to_checkpoint.ckpt> --data_path <path_to_sample_data_folder>
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
To run the 200M parameter (small) model (if you copied the tiny-sample folder to the root directory):
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python run_dreaming.py --model_path models/WHAM_200M.ckpt --data_path tiny-sample
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
This uses the data in `data_path` as initial prompt sequences. The script will create a `dreaming_output` directory which will create two files per ground truth data file:
|
| 111 |
+
- An `.npz` file that contains a number of entries, most important of which are:
|
| 112 |
+
- `encoded_decoded_ground_truth_images`: the original context images, encoded and decoded with the VQGAN.
|
| 113 |
+
- `dreamt_images`: the sequence of all dreamt images.
|
| 114 |
+
- An `.mp4` file of the context data + dreamt images for easier viewing.
|
| 115 |
+
|
| 116 |
+
This requires approximately 4.5GB of VRAM on a single A6000, but only uses batch size of one. To speed up the process, increase batch size with `--batch_size` argument. With a single A6000 and `--batch_size 12` this uses approximately 30GB of VRAM. Generating gameplay sequences from the full 512 video dataset takes around 24 hours.
|
| 117 |
+
|
| 118 |
+
Please note that the first output from the script is generated when the first gameplay sequence has been generated. This may take several minutes when using an `A6000` GPU, or longer for older generation GPUs.
|
| 119 |
+
|
| 120 |
+
See python `run_dreaming.py --help` for different settings.
|
| 121 |
+
|
| 122 |
+
### WHAM Demonstrator
|
| 123 |
+
|
| 124 |
+
#### Setting up the Model Server
|
| 125 |
+
|
| 126 |
+
We have tested the server code as provided on a single Linux machine with four `A6000 GPUs` (large model) as well as on a Windows machine running Ubuntu under `WSL2`, equipped with a single `GeForce GTX 1080` (small model). Model inferences can be run on lower spec NVIDIA GPUs by reducing the batchsize.
|
| 127 |
+
|
| 128 |
+
The steps below assume that the installation steps above have been followed and that the model files have been downloaded to your local machine.
|
| 129 |
+
|
| 130 |
+
In your terminal, activate the newly installed virtual environment (if it isn't already):
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
source venv/bin/activate
|
| 134 |
+
```
|
| 135 |
+
|
| 136 |
+
Start the server, pointing it to the model:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
python run_server.py --model <path_to_model_file>
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
To run the 200M parameter (small) model:
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
python run_server.py --model models/WHAM_200M.ckpt
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
To run the 1.6B parameter (large) model:
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
python run_server.py --model models/WHAM_1.6B_v1.ckpt
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
The server will start and by default listen on localhost port 5000 (this can be configured with `--port <port>`).
|
| 156 |
+
|
| 157 |
+
**Note:** If you run out of VRAM when running the server, you can reduce the `MAX_BATCH_SIZE` variable in `run_server.py`.
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
#### Install the WHAM Demonstrator App (Windows only)
|
| 161 |
+
|
| 162 |
+
After cloning or downloading this repository, navigate to the folder `wham/wham_demonstrator`, and start the Windows application `WHAMDemonstrator.exe` within that folder.
|
| 163 |
+
|
| 164 |
+
Follow the instructions in the provided README.md within WHAM Demonstrator to connect to your model server and get an overview of supported functionality.
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
## Intended Uses
|
| 168 |
+
|
| 169 |
+
This model and accompanying code are intended for academic research purposes only. WHAM has been trained on gameplay data from a single game, Bleeding Edge, and is intended to be used to generate plausible gameplay sequences resembling this game.
|
| 170 |
+
|
| 171 |
+
The model is not intended to be used to generate imagery outside of the game Bleeding Edge. Generated images include watermark and provenance metadata. Do not remove the watermark or provenance metadata..
|
| 172 |
+
|
| 173 |
+
WHAM can be used in multiple scenarios. The following list illustrates the types of tasks that WHAM can be used for:
|
| 174 |
+
- World Model: Visuals are predicted, given a real starting state and action sequence.
|
| 175 |
+
- Behaviour Policy: Given visuals, the model predicts the next controller action.
|
| 176 |
+
- Full Generation: The model generates both the visuals and the controller actions a human player might take in the game.
|
| 177 |
+
|
| 178 |
+
## Training
|
| 179 |
+
|
| 180 |
+
### Model
|
| 181 |
+
|
| 182 |
+
- Architecture: A decoder-only transformer that predicts the next token corresponding to an interleaved sequence of observations and actions. The image tokenizer is a VQ-GAN.
|
| 183 |
+
- Context length: 10 (observation, action) pairs / 5560 tokens
|
| 184 |
+
- Dataset size: The model was trained on data from approximately `500,000` Bleeding Edge games from all seven game maps (over 1 billion observation, action pairs 10Hz, equivalent to over 7 years of continuous human gameplay). A data sample is provided in [bleeding-edge-gameplay-sample](https://huggingface.co/datasets/microsoft/bleeding-edge-gameplay-sample). This is the test data used for our evaluation results, and has the same format as the training data.
|
| 185 |
+
- GPUs: 98xH100 GPUs
|
| 186 |
+
- Training time: 5 days
|
| 187 |
+
|
| 188 |
+
### Software
|
| 189 |
+
|
| 190 |
+
- [PyTorch Lightning](https://github.com/pytorch/pytorch)
|
| 191 |
+
- [Flash-Attention](https://github.com/HazyResearch/flash-attention)
|
| 192 |
+
- [ffmpeg](https://github.com/FFmpeg/FFmpeg)
|
| 193 |
+
- [exiftool](https://github.com/exiftool/exiftool)
|
| 194 |
+
|
| 195 |
+
## Bias, Risks and Limitations
|
| 196 |
+
|
| 197 |
+
- The training data represents gameplay recordings from a variety of skilled and unskilled gameplayers, representing diverse demographic characteristics. Not all possible player characteristics are represented and model performance may therefore vary.
|
| 198 |
+
- The model, as it is, can only be used to generate visuals and controller inputs. Users should not manipulate images and attempt to generate offensive scenes.
|
| 199 |
+
|
| 200 |
+
### Technical limitations, operational factors, and ranges
|
| 201 |
+
|
| 202 |
+
Model:
|
| 203 |
+
- Trained on a single game, very specialized, not intended for image prompts that are out of context or from other domains
|
| 204 |
+
- Limited context length (10s)
|
| 205 |
+
- Limited image resolution (300px x 180px), the model can only generate images at this fixed resolution.
|
| 206 |
+
- Generated images and controls can incorrect or unrecognizable.
|
| 207 |
+
- Inference time is currently too slow for real-time use.
|
| 208 |
+
|
| 209 |
+
WHAM Demonstrator:
|
| 210 |
+
- Developed as a way to explore potential interactions. This is not intended as a fully-fledged user experience or demo.
|
| 211 |
+
|
| 212 |
+
Models trained using game data may potentially behave in ways that are unfair, unreliable, or offensive, in turn causing harms. We emphasize that these types of harms are not mutually exclusive. A single model can exhibit more than one type of harm, potentially relating to multiple different groups of people. For example, the output of the model can be nonsensical or might look reasonable but is inaccurate with respect to external validation sources.
|
| 213 |
+
Although users can input any image as a starting point, the model is only trained to generate images and controller actions based on the structure of the Bleeding Edge game environment that it has learned from the training data. Out of domain inputs lead to unpredictable results. For example, this could include a sequence of images that dissolve into unrecognizable blobs .
|
| 214 |
+
Model generations when “out of scope” image elements are introduced will either:
|
| 215 |
+
- Dissolve into unrecognizable blobs of color.
|
| 216 |
+
- Morphed into game-relevant items such as game characters.
|
| 217 |
+
|
| 218 |
+
## Evaluating WHAM
|
| 219 |
+
WHAM is evaluated based on its consistency, diversity, and persistency. Consistency is measured using Fréchet Video Distance (FVD), while diversity is assessed by comparing the marginal distribution of real human actions to those generated by the model using the Wasserstein distance. Persistency is tested using two scenarios: by adding a static power-up object to a game visual and by adding another player character to a game visual used for prompting the model. For detailed evaluation results, see the paper that [introduces the model](https://www.nature.com/articles/s41586-025-08600-3).
|
| 220 |
+
|
| 221 |
+
### Responsible AI testing
|
| 222 |
+
WHAM has been tested with out of context prompt images to evaluate the risk of outputting harmful or nonsensical images. The generated image sequences did not retain the initial image, but rather dissolved into either unrecognizable blobs or to scenes resembling the training environment.
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
## License
|
| 226 |
+
|
| 227 |
+
The model is licensed under the [Microsoft Research License](LICENSE.md)
|
| 228 |
+
|
| 229 |
+
this work has been funded by Microsoft Research
|
| 230 |
+
|
| 231 |
+
## Privacy & Ethics Statement
|
| 232 |
+
|
| 233 |
+
[Microsoft Privacy Statement](https://go.microsoft.com/fwlink/?LinkId=521839)
|
| 234 |
+
|
| 235 |
+
## Trademark Notice
|
| 236 |
+
|
| 237 |
+
**Trademarks** This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow [Microsoft’s Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party’s policies.
|
| 238 |
+
|
| 239 |
+
## Contact Information
|
| 240 |
+
For questions please email to muse@microsoft.com
|
| 241 |
+
|
| 242 |
+
## Data Summary
|
| 243 |
+
https://huggingface.co/microsoft/wham/blob/main/data_summary_card.md
|
SECURITY.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Security
|
| 2 |
+
|
| 3 |
+
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
|
| 4 |
+
|
| 5 |
+
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
|
| 6 |
+
|
| 7 |
+
## Reporting Security Issues
|
| 8 |
+
|
| 9 |
+
**Please do not report security vulnerabilities through public GitHub issues.**
|
| 10 |
+
|
| 11 |
+
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
|
| 12 |
+
|
| 13 |
+
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
|
| 14 |
+
|
| 15 |
+
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
| 16 |
+
|
| 17 |
+
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
| 18 |
+
|
| 19 |
+
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
| 20 |
+
* Full paths of source file(s) related to the manifestation of the issue
|
| 21 |
+
* The location of the affected source code (tag/branch/commit or direct URL)
|
| 22 |
+
* Any special configuration required to reproduce the issue
|
| 23 |
+
* Step-by-step instructions to reproduce the issue
|
| 24 |
+
* Proof-of-concept or exploit code (if possible)
|
| 25 |
+
* Impact of the issue, including how an attacker might exploit the issue
|
| 26 |
+
|
| 27 |
+
This information will help us triage your report more quickly.
|
| 28 |
+
|
| 29 |
+
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
|
| 30 |
+
|
| 31 |
+
## Preferred Languages
|
| 32 |
+
|
| 33 |
+
We prefer all communications to be in English.
|
| 34 |
+
|
| 35 |
+
## Policy
|
| 36 |
+
|
| 37 |
+
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
|
WHAM_Demonstrator.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8d19ef23d22081044202e464e48dd4f5f6f232215a4cc797ab1a75dd3eb0e0d9
|
| 3 |
+
size 3565673
|
assets/Demonstrator/Fig_01.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_02.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_03.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_04.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_05.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_06.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_07.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_08.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_09.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_10.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_11.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_12.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_13.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_14.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_15.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_16.png
ADDED
|
Git LFS Details
|
assets/Demonstrator/Fig_17.png
ADDED
|
Git LFS Details
|
assets/Readme/model_capabilities.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_1.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_2.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_3.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_4.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_5.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_6.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_7.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_8.gif
ADDED
|
Git LFS Details
|
assets/Readme/wham_gen_9.gif
ADDED
|
Git LFS Details
|
configs/metadata_custom_tag.config
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
%Image::ExifTool::UserDefined = (
|
| 2 |
+
'Image::ExifTool::XMP::xmp' => {
|
| 3 |
+
'ProgramName' => { Name => 'ProgramName', Writable => 'string' }
|
| 4 |
+
}
|
| 5 |
+
);
|
data_summary_card.md
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
# Data Summary for microsoft_wham
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
## 1. General information
|
| 10 |
+
|
| 11 |
+
**1.0.1 Version of the Summary:** 1.0
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
**1.0.2 Last update:** 16-Dec-2025
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
## 1.1 Model Developer Identification
|
| 20 |
+
|
| 21 |
+
**1.1.1 Model Developer name and contact details:** Microsoft Corporation at One Microsoft Way, Redmond, WA 98052. Tel: 425-882-8080
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## 1.2 Model Identification
|
| 26 |
+
|
| 27 |
+
**1.2.1 Versioned model name(s):** WHAM 1.6B, WHAM 200M
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
**1.2.2 Model release date:** 19-Feb-2025
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## 1.3 Overall training data size and characteristics
|
| 36 |
+
|
| 37 |
+
### 1.3.1 Size of dataset and characteristics
|
| 38 |
+
|
| 39 |
+
**1.3.1.A Text training data size:** Not applicable. Text data is not part of the training data.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
**1.3.1.B Text training data content:** Not applicable.
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
**1.3.1.C Image training data size:** More than 1 billion images
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
**1.3.1.D Image training data content:** Sequences of rendered gameplay frames, constructed by replaying telemetry data to render visuals. Visuals were captured at 60 fps and downsampled to 10 Hz. They depict the 3D game environment, characters within the environment, UI elements, and dynamic in-game interactions across seven maps.
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
**1.3.1.E Audio training data size:** Not applicable. Audio data is not part of the training data.
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
**1.3.1.F Audio training data content:** Not applicable.
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
**1.3.1.G Video training data size:** Not applicable. Videos are not part of the training data.
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
**1.3.1.H Video training data content:** Not applicable
|
| 68 |
+
|
| 69 |
+
**1.3.1.I Other training data size:** Controller action sequences paired with each frame, totaling approximately 1.4 billion action entries after downsampling to 10 Hz.
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
**1.3.1.J Other training data content:** Controller input logs (telemetry) for gameplay including button states and discretized joystick positions aligned to each video frame to represent player behavior.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
**1.3.2 Latest date of data acquisition/collection for model training:** 31-Oct-2022
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
**1.3.3 Is data collection ongoing to update the model with new data collection after deployment?** No
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
**1.3.4 Date the training dataset was first used to train the model:** 12-Jan-2022
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
**1.3.5 Rationale or purpose of data selection:** The dataset was constructed using gameplay telemetry. This telemetry data was used to render and export visuals (frames) and meta-data, including events corresponding to controller actions. The resulting data was used to learn consistent 3D world dynamics, effects of controller actions, and temporal structure directly from the telemetry. The dataset provides rich, temporally correlated multimodal signals (frames and controller actions) aligning with the model’s goal to support creative ideation of dynamic 3D environments via consistency, diversity, and persistency.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
## 2. List of data sources
|
| 94 |
+
|
| 95 |
+
### 2.1 Publicly available datasets
|
| 96 |
+
|
| 97 |
+
**2.1.1 Have you used publicly available datasets to train the model?** No
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
## 2.2 Private non-publicly available datasets obtained from third parties
|
| 102 |
+
|
| 103 |
+
### 2.2.1 Datasets commercially licensed by rights holders or their representatives
|
| 104 |
+
|
| 105 |
+
**2.2.1.A Have you concluded transactional commercial licensing agreement(s) with rights holder(s) or with their representatives?** Not applicable
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
### 2.2.2 Private datasets obtained from other third-parties
|
| 110 |
+
|
| 111 |
+
**2.2.2.A Have you obtained private datasets from third parties that are not licensed as described in Section 2.2.1, such as data obtained from providers of private databases, or data intermediaries?** No
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
## 2.3 Personal Information
|
| 116 |
+
|
| 117 |
+
**2.3.1 Was personal data used to train the model?** Microsoft follows all relevant laws and regulations pertaining to personal information
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
## 2.4 Synthetic data
|
| 122 |
+
|
| 123 |
+
**2.4.1 Was any synthetic AI-generated data used to train the model?** No
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
## 3. Data processing aspects
|
| 128 |
+
|
| 129 |
+
### 3.1 Respect of reservation of rights from text and data mining exception or limitation
|
| 130 |
+
|
| 131 |
+
**3.1.1 Does this dataset include any data protected by copyright, trademark, or patent?** Microsoft follows all required regulations and laws for processing data protected by copyright, trademark, or patent
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## 3.2 Other information
|
| 136 |
+
|
| 137 |
+
**3.2.1 Does the dataset include information about consumer groups without revealing individual consumer identities?** Microsoft follows all required regulations and laws for protecting consumer identities
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
**3.2.2 Was the dataset cleaned or modified before model training?** Yes
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
models/WHAM_1.6B_v1.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9c4997074883aa1a39a5994a7dea91fb62b2382fc039523458827adb777af8e9
|
| 3 |
+
size 20339650059
|
models/WHAM_200M.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5ddb8e03a33f0849a63da030fea3de4994d95e16888993b8ab92faa904f3b31f
|
| 3 |
+
size 3980245067
|
models/config.json
ADDED
|
File without changes
|
requirements.txt
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--find-links https://download.pytorch.org/whl/torch_stable.html
|
| 2 |
+
aiohttp==3.9.3
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
async-timeout==4.0.3
|
| 5 |
+
attrs==23.2.0
|
| 6 |
+
blinker==1.7.0
|
| 7 |
+
certifi==2024.2.2
|
| 8 |
+
charset-normalizer==3.3.2
|
| 9 |
+
click==8.1.7
|
| 10 |
+
cloudpickle==3.0.0
|
| 11 |
+
cmake==3.28.3
|
| 12 |
+
einops==0.6.0
|
| 13 |
+
ffmpegcv==0.3.10
|
| 14 |
+
filelock==3.13.1
|
| 15 |
+
Flask==3.0.2
|
| 16 |
+
frozenlist==1.4.1
|
| 17 |
+
fsspec==2024.2.0
|
| 18 |
+
idna==3.6
|
| 19 |
+
importlib_metadata==7.0.2
|
| 20 |
+
itsdangerous==2.1.2
|
| 21 |
+
Jinja2==3.1.3
|
| 22 |
+
lightning-utilities==0.10.1
|
| 23 |
+
lit==17.0.6
|
| 24 |
+
MarkupSafe==2.1.5
|
| 25 |
+
mpmath==1.3.0
|
| 26 |
+
multidict==6.0.5
|
| 27 |
+
networkx==3.2.1
|
| 28 |
+
numpy==1.25.2
|
| 29 |
+
opencv-python==4.6.0.66
|
| 30 |
+
opencv-python-headless==4.9.0.80
|
| 31 |
+
packaging==23.2
|
| 32 |
+
pillow==10.2.0
|
| 33 |
+
pytorch-lightning==1.9.4
|
| 34 |
+
PyYAML==6.0.1
|
| 35 |
+
requests==2.31.0
|
| 36 |
+
sympy==1.12
|
| 37 |
+
tensordict==0.1.2
|
| 38 |
+
torch==2.0.1+cu118
|
| 39 |
+
torchinfo==1.7.1
|
| 40 |
+
torchmetrics==0.11.4
|
| 41 |
+
torchvision==0.15.2+cu118
|
| 42 |
+
tqdm==4.66.2
|
| 43 |
+
triton==2.0.0
|
| 44 |
+
typing_extensions==4.10.0
|
| 45 |
+
urllib3==2.2.1
|
| 46 |
+
Werkzeug==3.0.1
|
| 47 |
+
yarl==1.9.4
|
| 48 |
+
zipp==3.17.0
|
run_dreaming.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example script for running dreaming on a dataset.
|
| 3 |
+
The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context.
|
| 4 |
+
|
| 5 |
+
After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players),
|
| 6 |
+
should be identical if model was ideal.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
|
| 14 |
+
import cv2
|
| 15 |
+
from tensordict import TensorDict
|
| 16 |
+
import torch as th
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
import numpy as np
|
| 19 |
+
import ffmpegcv
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
import wham.utils as utils
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
parser = argparse.ArgumentParser(description="Run dreaming.")
|
| 26 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
|
| 27 |
+
parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
|
| 28 |
+
parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
|
| 29 |
+
parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
|
| 30 |
+
parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--protocol",
|
| 35 |
+
type=str,
|
| 36 |
+
default="base",
|
| 37 |
+
choices=["base", "comprehensive"],
|
| 38 |
+
help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.")
|
| 41 |
+
parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use an initial context.")
|
| 42 |
+
parser.add_argument("--steps_to_dream", type=int, default=10, help="Batch size for dreaming.")
|
| 43 |
+
|
| 44 |
+
parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
|
| 45 |
+
parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
|
| 46 |
+
parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_context_data(image_context, action_context, action_sequences):
|
| 50 |
+
# Make sure we have CHW images:
|
| 51 |
+
assert image_context.shape[-3] == 3, "Image context should be CHW"
|
| 52 |
+
|
| 53 |
+
image_context = th.from_numpy(image_context).cuda()
|
| 54 |
+
action_data = th.from_numpy(action_context).float().cuda()
|
| 55 |
+
action_sequences = th.from_numpy(action_sequences).float().cuda() if action_sequences is not None else None
|
| 56 |
+
|
| 57 |
+
return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def add_video_metadata(file_path, metadata_config):
|
| 61 |
+
# Construct the exiftool command
|
| 62 |
+
cmd = [
|
| 63 |
+
'exiftool',
|
| 64 |
+
'-config', metadata_config,
|
| 65 |
+
f'-ProgramName=\"{utils.PROGRAM_NAME}\"',
|
| 66 |
+
'-overwrite_original',
|
| 67 |
+
file_path
|
| 68 |
+
]
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
# Execute the exiftool command
|
| 72 |
+
subprocess.run(cmd, check=True)
|
| 73 |
+
print(f"Metadata modified successfully.")
|
| 74 |
+
# Print the new file metadata
|
| 75 |
+
cmd_output = [
|
| 76 |
+
'exiftool',
|
| 77 |
+
file_path
|
| 78 |
+
]
|
| 79 |
+
subprocess.run(cmd_output, check=True)
|
| 80 |
+
except subprocess.CalledProcessError as e:
|
| 81 |
+
print(f"Error modifying metadata: {e}")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@th.no_grad()
|
| 85 |
+
def do_dreaming(model, image_context, action_context, args, action_sequences=None):
|
| 86 |
+
"""
|
| 87 |
+
image_contect and action_context provide the initial context for the model to dream from.
|
| 88 |
+
|
| 89 |
+
If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions.
|
| 90 |
+
"""
|
| 91 |
+
context_data = get_context_data(image_context, action_context, action_sequences)
|
| 92 |
+
encoded_context_data = model.encode_context(context_data)
|
| 93 |
+
|
| 94 |
+
encoded_action_sequences = None
|
| 95 |
+
if action_sequences is not None:
|
| 96 |
+
assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
|
| 97 |
+
action_sequences = TensorDict({"actions_output": action_sequences}, batch_size=action_sequences.shape[:2]).cuda()
|
| 98 |
+
encoded_action_sequences = model.encode_context(action_sequences)
|
| 99 |
+
|
| 100 |
+
encoded_dreamt_steps = []
|
| 101 |
+
|
| 102 |
+
for dream_step in range(args.steps_to_dream):
|
| 103 |
+
encoded_predicted_step, _ = model.predictor.predict_next_step(
|
| 104 |
+
encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Remove first step from context if we are at the max context length:
|
| 108 |
+
if encoded_context_data.shape[1] == args.context_length:
|
| 109 |
+
encoded_context_data = encoded_context_data[:, 1:]
|
| 110 |
+
|
| 111 |
+
# Add predicted image + action to the context
|
| 112 |
+
append_step = encoded_predicted_step
|
| 113 |
+
if encoded_action_sequences is not None:
|
| 114 |
+
# Replace predicted action with real action
|
| 115 |
+
append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
|
| 116 |
+
encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)
|
| 117 |
+
|
| 118 |
+
encoded_dreamt_steps.append(encoded_predicted_step)
|
| 119 |
+
|
| 120 |
+
# Decode everything
|
| 121 |
+
dreamed_images = []
|
| 122 |
+
actions_during_dream = []
|
| 123 |
+
for seq_i in range(args.steps_to_dream):
|
| 124 |
+
decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
|
| 125 |
+
dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
|
| 126 |
+
actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())
|
| 127 |
+
|
| 128 |
+
dreamed_images = np.concatenate(dreamed_images, axis=1)
|
| 129 |
+
actions_during_dream = np.concatenate(actions_during_dream, axis=1)
|
| 130 |
+
|
| 131 |
+
return dreamed_images, actions_during_dream
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@th.no_grad()
|
| 135 |
+
def encode_decode_images(model, images):
|
| 136 |
+
"""
|
| 137 |
+
Pass ground_truth images through the encoding/decoding process of the model.
|
| 138 |
+
"""
|
| 139 |
+
context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=images.shape[:2])
|
| 140 |
+
output_images = []
|
| 141 |
+
for seq_i in range(images.shape[1]):
|
| 142 |
+
encoded_images = model.encode_context(context[:, [seq_i]])
|
| 143 |
+
decoded_images = model.decode_context(encoded_images)
|
| 144 |
+
output_images.append(decoded_images["images"].cpu().numpy())
|
| 145 |
+
return np.concatenate(output_images, axis=1)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def main(args):
|
| 149 |
+
total_video_length = args.context_length + args.steps_to_dream
|
| 150 |
+
|
| 151 |
+
# Now, load the model:
|
| 152 |
+
model_path = Path(args.model_path)
|
| 153 |
+
assert model_path.is_file(), "Could not find the model!"
|
| 154 |
+
model = utils.load_model_from_checkpoint(model_path).cuda()
|
| 155 |
+
|
| 156 |
+
# Glob the dataset to find all the ground truth segments we want to construct a dream for:
|
| 157 |
+
data_path = Path(args.data_path)
|
| 158 |
+
ground_truth_files = list(data_path.rglob("*.npz"))
|
| 159 |
+
num_dreams = len(ground_truth_files)
|
| 160 |
+
|
| 161 |
+
if args.max_files is not None:
|
| 162 |
+
# Sort to make sure we always get the same files
|
| 163 |
+
ground_truth_files = sorted(ground_truth_files)
|
| 164 |
+
ground_truth_files = ground_truth_files[: args.max_files]
|
| 165 |
+
num_dreams = len(ground_truth_files)
|
| 166 |
+
|
| 167 |
+
output_path = Path(args.output)
|
| 168 |
+
os.makedirs(output_path, exist_ok=True)
|
| 169 |
+
|
| 170 |
+
print("=" * 100)
|
| 171 |
+
print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
|
| 172 |
+
print(f"WRITING TO {args.output}")
|
| 173 |
+
print("=" * 100)
|
| 174 |
+
|
| 175 |
+
dreams_created = 0
|
| 176 |
+
with tqdm(total=num_dreams, desc="Dreams") as pbar:
|
| 177 |
+
while ground_truth_files:
|
| 178 |
+
# Load batch_size headers:
|
| 179 |
+
batches = min(args.batch_size, len(ground_truth_files))
|
| 180 |
+
batched_image_context = []
|
| 181 |
+
batched_image_sequence = []
|
| 182 |
+
batched_action_context = []
|
| 183 |
+
batched_action_sequence = []
|
| 184 |
+
episode_names = []
|
| 185 |
+
for i in range(batches):
|
| 186 |
+
episode = ground_truth_files.pop()
|
| 187 |
+
episode_names.append(episode)
|
| 188 |
+
try:
|
| 189 |
+
data = np.load(episode)
|
| 190 |
+
images = data["images"]
|
| 191 |
+
actions = data["actions"]
|
| 192 |
+
except Exception:
|
| 193 |
+
print(f"Failed to load episode {episode} - skipping.")
|
| 194 |
+
continue
|
| 195 |
+
|
| 196 |
+
if actions.shape[0] < total_video_length:
|
| 197 |
+
# We want to make sure we have ground_truth comparisons for the entire dream, so we ensure the episode is long enough
|
| 198 |
+
raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
|
| 199 |
+
batched_image_context.append(images[: args.context_length])
|
| 200 |
+
batched_image_sequence.append(images[args.context_length: total_video_length])
|
| 201 |
+
batched_action_context.append(actions[: args.context_length])
|
| 202 |
+
batched_action_sequence.append(actions[args.context_length: total_video_length])
|
| 203 |
+
|
| 204 |
+
image_context = np.array(batched_image_context)
|
| 205 |
+
image_sequences = np.array(batched_image_sequence)
|
| 206 |
+
action_context = np.array(batched_action_context)
|
| 207 |
+
action_sequences = np.array(batched_action_sequence)
|
| 208 |
+
|
| 209 |
+
if args.protocol == "comprehensive":
|
| 210 |
+
# We do not need to pass in the action sequences for comprehensive protocol
|
| 211 |
+
action_sequences = None
|
| 212 |
+
|
| 213 |
+
full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)
|
| 214 |
+
|
| 215 |
+
dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
|
| 216 |
+
encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)
|
| 217 |
+
|
| 218 |
+
pbar.update(batches)
|
| 219 |
+
dreams_created += batches
|
| 220 |
+
|
| 221 |
+
# Save the dreams:
|
| 222 |
+
# We are aiming to mimic the folder structure of the ground truth dataset, so use the episode names
|
| 223 |
+
# but make them relative to our output folder:
|
| 224 |
+
for i, dream in enumerate(dreamt_images):
|
| 225 |
+
episode = episode_names[i]
|
| 226 |
+
output_file = output_path / episode.relative_to(data_path)
|
| 227 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
| 228 |
+
np.savez(
|
| 229 |
+
output_file,
|
| 230 |
+
context_length=args.context_length,
|
| 231 |
+
steps_to_dream=args.steps_to_dream,
|
| 232 |
+
raw_context=image_context[i],
|
| 233 |
+
dreamt_images=dream,
|
| 234 |
+
all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
|
| 235 |
+
encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
video_file = str(output_file.with_suffix(".mp4"))
|
| 239 |
+
writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
|
| 240 |
+
full_sequence = np.concatenate((image_context[i], dream), axis=0)
|
| 241 |
+
for frame in full_sequence:
|
| 242 |
+
img = frame.transpose(1, 2, 0).astype(np.uint8).copy()
|
| 243 |
+
# Please DO NOT remove this watermark. This will infringe upon the repo's license agreement
|
| 244 |
+
(text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
|
| 245 |
+
x = img.shape[1] - text_width - 10 # 10 pixels from the right edge
|
| 246 |
+
y = img.shape[0] - 10 # 10 pixels from the bottom edge
|
| 247 |
+
cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)
|
| 248 |
+
|
| 249 |
+
# Add image metadata
|
| 250 |
+
pil_image = Image.fromarray(img)
|
| 251 |
+
pil_image.info['Id'] = 0x0131
|
| 252 |
+
pil_image.info['Type'] = 2
|
| 253 |
+
pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
|
| 254 |
+
pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1
|
| 255 |
+
|
| 256 |
+
# Convert pil_image to a CV2 format for the video writer
|
| 257 |
+
cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
| 258 |
+
writer.write(cv_image)
|
| 259 |
+
writer.release()
|
| 260 |
+
add_video_metadata(video_file, args.metadata_config)
|
| 261 |
+
|
| 262 |
+
if __name__ == "__main__":
|
| 263 |
+
args = parser.parse_args()
|
| 264 |
+
main(args)
|
run_server.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
import json
|
| 4 |
+
import copy
|
| 5 |
+
import multiprocessing as mp
|
| 6 |
+
import uuid
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
from collections import defaultdict, deque
|
| 9 |
+
import io
|
| 10 |
+
import zipfile
|
| 11 |
+
import queue
|
| 12 |
+
import time
|
| 13 |
+
import random
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
from tensordict import TensorDict
|
| 17 |
+
import cv2
|
| 18 |
+
from flask import Flask, request, make_response, send_file
|
| 19 |
+
from PIL import Image
|
| 20 |
+
import torchvision.transforms as T
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch as th
|
| 23 |
+
|
| 24 |
+
from wham.utils import load_model_from_checkpoint, POS_BINS_BOUNDARIES, POS_BINS_MIDDLE
|
| 25 |
+
|
| 26 |
+
logging.basicConfig(level=logging.INFO)
|
| 27 |
+
|
| 28 |
+
parser = argparse.ArgumentParser(description="Simple Dreamer")
|
| 29 |
+
parser.add_argument("--model", type=str, required=True, help="Path to the model file for the local runs")
|
| 30 |
+
parser.add_argument("--debug", action="store_true", help="Enable flask debug mode.")
|
| 31 |
+
parser.add_argument("--random_model", action="store_true", help="Use randomly initialized model instead of the provided one")
|
| 32 |
+
parser.add_argument("--port", type=int, default=5000)
|
| 33 |
+
|
| 34 |
+
parser.add_argument("--max_concurrent_jobs", type=int, default=30, help="Maximum number of jobs that can be run concurrently on this server.")
|
| 35 |
+
parser.add_argument("--max_dream_steps_per_job", type=int, default=10, help="Maximum number of dream steps each job can request.")
|
| 36 |
+
parser.add_argument("--max_job_lifespan", type=int, default=60 * 10, help="Maximum number of seconds we keep run around if not polled.")
|
| 37 |
+
|
| 38 |
+
parser.add_argument("--image_width", type=int, default=300, help="Width of the image")
|
| 39 |
+
parser.add_argument("--image_height", type=int, default=180, help="Height of the image")
|
| 40 |
+
|
| 41 |
+
parser.add_argument("--max_batch_size", type=int, default=3, help="Maximum batch size for the dreamer workers")
|
| 42 |
+
|
| 43 |
+
PREDICTION_JSON_FILENAME = "predictions.json"
|
| 44 |
+
# Minimum time between times we check when to delete jobs. We do this when adding new jobs.
|
| 45 |
+
JOB_CLEANUP_CHECK_RATE = timedelta(seconds=10)
|
| 46 |
+
|
| 47 |
+
MAX_CANCELLED_ID_QUEUE_SIZE = 100
|
| 48 |
+
|
| 49 |
+
DEFAULT_SAMPLING_SETTINGS = {
|
| 50 |
+
"temperature": 0.9,
|
| 51 |
+
"top_k": None,
|
| 52 |
+
"top_p": 1.0,
|
| 53 |
+
"max_context_length": 10,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def float_or_none(string):
|
| 58 |
+
if string.lower() == "none":
|
| 59 |
+
return None
|
| 60 |
+
return float(string)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def be_image_preprocess(image, target_width, target_height):
|
| 64 |
+
# If target_width and target_height are specified, resize the image.
|
| 65 |
+
if target_width is not None and target_height is not None:
|
| 66 |
+
# Make sure we do not try to resize if the image is already the correct size.
|
| 67 |
+
if image.shape[1] != target_width or image.shape[0] != target_height:
|
| 68 |
+
image = cv2.resize(image, (target_width, target_height))
|
| 69 |
+
return np.transpose(image, (2, 0, 1))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def action_vector_to_be_action_vector(action):
|
| 73 |
+
# Preprocess a BE action vector from 16 numbers with:
|
| 74 |
+
# 12 buttons [0, 1] and 4 stick directions [-1, 1]
|
| 75 |
+
# to discrete actions valid for the token model
|
| 76 |
+
# 12 buttons [0, 1] and 4 stick directions {discrete bin}
|
| 77 |
+
action[-4:] = np.digitize(action[-4:], bins=POS_BINS_BOUNDARIES) - 1
|
| 78 |
+
return action
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def be_action_vector_to_action_vector(action):
|
| 82 |
+
# Preprocess a BE action vector into unified space
|
| 83 |
+
for stick_index in range(-4, 0):
|
| 84 |
+
action[stick_index] = POS_BINS_MIDDLE[int(action[stick_index])]
|
| 85 |
+
return action
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@dataclass
|
| 90 |
+
class DreamJob:
|
| 91 |
+
job_id: str
|
| 92 |
+
sampling_settings: dict
|
| 93 |
+
num_predictions_remaining: int
|
| 94 |
+
num_predictions_done: int
|
| 95 |
+
# (B, T, C, H, W)
|
| 96 |
+
context_images: th.Tensor
|
| 97 |
+
context_actions: th.Tensor
|
| 98 |
+
# Tokens that will replace the context_images if they are provided
|
| 99 |
+
context_tokens: list
|
| 100 |
+
# This will replace the dreamed action if provided.
|
| 101 |
+
# For every step, we remove the first action until exhausted
|
| 102 |
+
actions_to_take: th.Tensor = None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass
|
| 106 |
+
class DreamJobResult:
|
| 107 |
+
job_id: str
|
| 108 |
+
dream_step_index: int
|
| 109 |
+
# (B, 1, C, H, W)
|
| 110 |
+
dreamt_image: th.Tensor
|
| 111 |
+
dreamt_action: th.Tensor
|
| 112 |
+
dreamt_tokens: th.Tensor
|
| 113 |
+
result_creation_time: datetime = field(default_factory=datetime.now)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def setup_and_load_model_be_model(args):
|
| 118 |
+
model = load_model_from_checkpoint(args.model)
|
| 119 |
+
th.set_float32_matmul_precision("high")
|
| 120 |
+
th.backends.cuda.matmul.allow_tf32 = True
|
| 121 |
+
return model
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_job_batchable_information(job):
|
| 125 |
+
"""Return comparable object of job information. Used for batching"""
|
| 126 |
+
context_length = job.context_images.shape[1]
|
| 127 |
+
return (context_length, job.sampling_settings)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size, timeout=1):
|
| 131 |
+
"""Return a list of jobs (or empty list) that can be batched together"""
|
| 132 |
+
batchable_jobs = []
|
| 133 |
+
required_job_info = None
|
| 134 |
+
while len(batchable_jobs) < max_batch_size:
|
| 135 |
+
try:
|
| 136 |
+
job = job_queue.get(timeout=timeout)
|
| 137 |
+
except queue.Empty:
|
| 138 |
+
break
|
| 139 |
+
# If pipe breaks, also gracefully return
|
| 140 |
+
except OSError:
|
| 141 |
+
break
|
| 142 |
+
if job.job_id in cancelled_ids_set:
|
| 143 |
+
# This job was cancelled, so discard it completely
|
| 144 |
+
continue
|
| 145 |
+
job_info = get_job_batchable_information(job)
|
| 146 |
+
if required_job_info is None:
|
| 147 |
+
required_job_info = job_info
|
| 148 |
+
elif required_job_info != job_info:
|
| 149 |
+
# This job is not batchable, put it back
|
| 150 |
+
job_queue.put(job)
|
| 151 |
+
# we assume here that, generally, the others jobs would also be
|
| 152 |
+
# invalid. So we just return the batchable jobs we have instead
|
| 153 |
+
# of going through more.
|
| 154 |
+
break
|
| 155 |
+
batchable_jobs.append(job)
|
| 156 |
+
return batchable_jobs
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def update_cancelled_jobs(cancelled_ids_queue, cancelled_ids_deque, cancelled_ids_set):
|
| 160 |
+
"""IN-PLACE Update cancelled_ids_set with new ids from the queue"""
|
| 161 |
+
has_changed = False
|
| 162 |
+
while not cancelled_ids_queue.empty():
|
| 163 |
+
try:
|
| 164 |
+
cancelled_id = cancelled_ids_queue.get_nowait()
|
| 165 |
+
except queue.Empty:
|
| 166 |
+
break
|
| 167 |
+
cancelled_ids_deque.append(cancelled_id)
|
| 168 |
+
has_changed = True
|
| 169 |
+
|
| 170 |
+
if has_changed:
|
| 171 |
+
cancelled_ids_set.clear()
|
| 172 |
+
cancelled_ids_set.update(cancelled_ids_deque)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def predict_step(context_data, sampling_settings, model, tokens=None):
|
| 176 |
+
with th.no_grad():
|
| 177 |
+
predicted_step = model.predict_next_step(context_data, min_tokens_to_keep=1, tokens=tokens, **sampling_settings)
|
| 178 |
+
return predicted_step
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def dreamer_worker(job_queue, result_queue, cancelled_jobs_queue, quit_flag, device_to_use, args):
|
| 182 |
+
logger = logging.getLogger(f"dreamer_worker {device_to_use}")
|
| 183 |
+
logger.info("Loading up model...")
|
| 184 |
+
model = setup_and_load_model_be_model(args)
|
| 185 |
+
model = model.to(device_to_use)
|
| 186 |
+
logger.info("Model loaded. Fetching results")
|
| 187 |
+
|
| 188 |
+
cancelled_ids_deque = deque(maxlen=MAX_CANCELLED_ID_QUEUE_SIZE)
|
| 189 |
+
cancelled_ids_set = set()
|
| 190 |
+
|
| 191 |
+
while not quit_flag.is_set():
|
| 192 |
+
update_cancelled_jobs(cancelled_jobs_queue, cancelled_ids_deque, cancelled_ids_set)
|
| 193 |
+
batchable_jobs = fetch_list_of_batchable_jobs(job_queue, cancelled_ids_set, max_batch_size=args.max_batch_size)
|
| 194 |
+
if len(batchable_jobs) == 0:
|
| 195 |
+
continue
|
| 196 |
+
sampling_settings = batchable_jobs[0].sampling_settings
|
| 197 |
+
# make better way for passing these arguments around. sampling_settings
|
| 198 |
+
# is passed as kwargs to predicting step, but max_context_length is not part of valid
|
| 199 |
+
# keys there, so we need to pop it out.
|
| 200 |
+
max_context_length = sampling_settings.pop("max_context_length")
|
| 201 |
+
|
| 202 |
+
images = [job.context_images[:, :max_context_length] for job in batchable_jobs]
|
| 203 |
+
actions = [job.context_actions[:, :max_context_length] for job in batchable_jobs]
|
| 204 |
+
tokens = [job.context_tokens for job in batchable_jobs]
|
| 205 |
+
|
| 206 |
+
images = th.concat(images, dim=0).to(device_to_use)
|
| 207 |
+
actions = th.concat(actions, dim=0).to(device_to_use)
|
| 208 |
+
|
| 209 |
+
context_data = TensorDict({
|
| 210 |
+
"images": images,
|
| 211 |
+
"actions_output": actions
|
| 212 |
+
}, batch_size=images.shape[:2])
|
| 213 |
+
|
| 214 |
+
predicted_step, predicted_image_tokens = predict_step(context_data, sampling_settings, model, tokens)
|
| 215 |
+
|
| 216 |
+
predicted_step = predicted_step.cpu()
|
| 217 |
+
predicted_images = predicted_step["images"]
|
| 218 |
+
predicted_actions = predicted_step["actions_output"]
|
| 219 |
+
predicted_image_tokens = predicted_image_tokens.cpu()
|
| 220 |
+
|
| 221 |
+
for job_i, job in enumerate(batchable_jobs):
|
| 222 |
+
image_context = job.context_images
|
| 223 |
+
action_context = job.context_actions
|
| 224 |
+
token_context = job.context_tokens
|
| 225 |
+
# Keep batch dimension
|
| 226 |
+
dreamt_image = predicted_images[job_i].unsqueeze(0)
|
| 227 |
+
dreamt_action = predicted_actions[job_i].unsqueeze(0)
|
| 228 |
+
dreamt_tokens = predicted_image_tokens[job_i].unsqueeze(0)
|
| 229 |
+
|
| 230 |
+
# Replace the dreamed action if provided
|
| 231 |
+
actions_to_take = job.actions_to_take
|
| 232 |
+
if actions_to_take is not None and actions_to_take.shape[1] > 0:
|
| 233 |
+
dreamt_action = actions_to_take[:, 0:1]
|
| 234 |
+
# Remove the action we took
|
| 235 |
+
actions_to_take = actions_to_take[:, 1:]
|
| 236 |
+
if actions_to_take.shape[1] == 0:
|
| 237 |
+
actions_to_take = None
|
| 238 |
+
|
| 239 |
+
result_queue.put(DreamJobResult(
|
| 240 |
+
job_id=job.job_id,
|
| 241 |
+
dream_step_index=job.num_predictions_done,
|
| 242 |
+
dreamt_image=dreamt_image,
|
| 243 |
+
dreamt_action=dreamt_action,
|
| 244 |
+
dreamt_tokens=dreamt_tokens
|
| 245 |
+
))
|
| 246 |
+
|
| 247 |
+
# Add job back in the queue if we have more steps to do
|
| 248 |
+
if job.num_predictions_remaining > 0:
|
| 249 |
+
# Stack the dreamt image and action to the context
|
| 250 |
+
if image_context.shape[1] >= max_context_length:
|
| 251 |
+
image_context = image_context[:, 1:]
|
| 252 |
+
action_context = action_context[:, 1:]
|
| 253 |
+
token_context = token_context[1:]
|
| 254 |
+
image_context = th.cat([image_context, dreamt_image], dim=1)
|
| 255 |
+
action_context = th.cat([action_context, dreamt_action], dim=1)
|
| 256 |
+
token_context.append(dreamt_tokens[0, 0].tolist())
|
| 257 |
+
# We need to add context length back to sampling settings...
|
| 258 |
+
# add some better way of passing these settings around
|
| 259 |
+
job.sampling_settings["max_context_length"] = max_context_length
|
| 260 |
+
job_queue.put(DreamJob(
|
| 261 |
+
job_id=job.job_id,
|
| 262 |
+
sampling_settings=job.sampling_settings,
|
| 263 |
+
num_predictions_remaining=job.num_predictions_remaining - 1,
|
| 264 |
+
num_predictions_done=job.num_predictions_done + 1,
|
| 265 |
+
context_images=image_context,
|
| 266 |
+
context_actions=action_context,
|
| 267 |
+
context_tokens=token_context,
|
| 268 |
+
actions_to_take=actions_to_take
|
| 269 |
+
))
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class DreamerServer:
|
| 273 |
+
def __init__(self, num_workers, args):
|
| 274 |
+
self.num_workers = num_workers
|
| 275 |
+
self.args = args
|
| 276 |
+
self.model = None
|
| 277 |
+
self.jobs = mp.Queue(maxsize=args.max_concurrent_jobs)
|
| 278 |
+
self.results_queue = mp.Queue()
|
| 279 |
+
self.cancelled_jobs = set()
|
| 280 |
+
self.cancelled_jobs_queues = [mp.Queue() for _ in range(num_workers)]
|
| 281 |
+
# job_id -> results
|
| 282 |
+
self._last_result_cleanup = datetime.now()
|
| 283 |
+
self._max_job_lifespan_datetime = timedelta(seconds=args.max_job_lifespan)
|
| 284 |
+
self.local_results = defaultdict(list)
|
| 285 |
+
self.logger = logging.getLogger("DreamerServer")
|
| 286 |
+
|
| 287 |
+
def get_details(self):
|
| 288 |
+
details = {
|
| 289 |
+
"model_file": self.args.model,
|
| 290 |
+
"max_concurrent_jobs": self.args.max_concurrent_jobs,
|
| 291 |
+
"max_dream_steps_per_job": self.args.max_dream_steps_per_job,
|
| 292 |
+
"max_job_lifespan": self.args.max_job_lifespan,
|
| 293 |
+
}
|
| 294 |
+
return json.dumps(details)
|
| 295 |
+
|
| 296 |
+
def _check_if_should_remove_old_jobs(self):
|
| 297 |
+
time_now = datetime.now()
|
| 298 |
+
# Only cleanup every JOB_CLEANUP_CHECK_RATE seconds at most
|
| 299 |
+
if time_now - self._last_result_cleanup < JOB_CLEANUP_CHECK_RATE:
|
| 300 |
+
return
|
| 301 |
+
|
| 302 |
+
self._last_result_cleanup = time_now
|
| 303 |
+
# First add existing results to the local results
|
| 304 |
+
self._gather_new_results()
|
| 305 |
+
# Check if we should remove old jobs
|
| 306 |
+
job_ids = list(self.local_results.keys())
|
| 307 |
+
for job_id in job_ids:
|
| 308 |
+
results = self.local_results[job_id]
|
| 309 |
+
# If newest result is older than max_job_lifespan, remove the job
|
| 310 |
+
if time_now - results[-1].result_creation_time > self._max_job_lifespan_datetime:
|
| 311 |
+
self.logger.info(f"Deleted job {job_id} because it was too old. Last result was {results[-1].result_creation_time}")
|
| 312 |
+
del self.local_results[job_id]
|
| 313 |
+
|
| 314 |
+
def add_new_job(self, request, request_json):
|
| 315 |
+
"""
|
| 316 |
+
Add new dreaming job to the queues.
|
| 317 |
+
Request should have:
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
Returns: json object with new job id
|
| 321 |
+
"""
|
| 322 |
+
self._check_if_should_remove_old_jobs()
|
| 323 |
+
|
| 324 |
+
sampling_settings = copy.deepcopy(DEFAULT_SAMPLING_SETTINGS)
|
| 325 |
+
if "num_steps_to_predict" not in request_json:
|
| 326 |
+
return make_response("num_steps_to_predict not in request", 400)
|
| 327 |
+
num_steps_to_predict = request_json['num_steps_to_predict']
|
| 328 |
+
if num_steps_to_predict > self.args.max_dream_steps_per_job:
|
| 329 |
+
return make_response(f"num_steps_to_predict too large. Max {self.args.max_dream_steps_per_job}", 400)
|
| 330 |
+
|
| 331 |
+
num_parallel_predictions = int(request_json['num_parallel_predictions']) if 'num_parallel_predictions' in request_json else 1
|
| 332 |
+
|
| 333 |
+
if (self.jobs.qsize() + num_parallel_predictions) >= self.args.max_concurrent_jobs:
|
| 334 |
+
return make_response(f"Too many jobs already running. Max {self.args.max_concurrent_jobs}", 400)
|
| 335 |
+
|
| 336 |
+
for key in sampling_settings:
|
| 337 |
+
sampling_settings[key] = float_or_none(request_json[key]) if key in request_json else sampling_settings[key]
|
| 338 |
+
|
| 339 |
+
context_images = []
|
| 340 |
+
context_actions = []
|
| 341 |
+
context_tokens = []
|
| 342 |
+
future_actions = []
|
| 343 |
+
|
| 344 |
+
for step in request_json["steps"]:
|
| 345 |
+
image_path = step["image_name"]
|
| 346 |
+
image = np.array(Image.open(request.files[image_path].stream))
|
| 347 |
+
image = be_image_preprocess(image, target_width=self.args.image_width, target_height=self.args.image_height)
|
| 348 |
+
context_images.append(th.from_numpy(image))
|
| 349 |
+
|
| 350 |
+
action = step["action"]
|
| 351 |
+
action = action_vector_to_be_action_vector(action)
|
| 352 |
+
context_actions.append(th.tensor(action))
|
| 353 |
+
|
| 354 |
+
tokens = step["tokens"]
|
| 355 |
+
context_tokens.append(tokens)
|
| 356 |
+
|
| 357 |
+
future_actions = None
|
| 358 |
+
if "future_actions" in request_json:
|
| 359 |
+
future_actions = []
|
| 360 |
+
for step in request_json["future_actions"]:
|
| 361 |
+
# The rest is the action vector
|
| 362 |
+
action = step["action"]
|
| 363 |
+
action = action_vector_to_be_action_vector(action)
|
| 364 |
+
# Add sequence and batch dimensions
|
| 365 |
+
future_actions.append(th.tensor(action))
|
| 366 |
+
|
| 367 |
+
# Add batch dimensions
|
| 368 |
+
context_images = th.stack(context_images).unsqueeze(0)
|
| 369 |
+
context_actions = th.stack(context_actions).unsqueeze(0)
|
| 370 |
+
future_actions = th.stack(future_actions).unsqueeze(0) if future_actions is not None else None
|
| 371 |
+
|
| 372 |
+
list_of_job_ids = []
|
| 373 |
+
for _ in range(num_parallel_predictions):
|
| 374 |
+
job_id = uuid.uuid4().hex
|
| 375 |
+
self.jobs.put(DreamJob(
|
| 376 |
+
job_id=job_id,
|
| 377 |
+
sampling_settings=sampling_settings,
|
| 378 |
+
num_predictions_remaining=num_steps_to_predict,
|
| 379 |
+
num_predictions_done=0,
|
| 380 |
+
context_images=context_images,
|
| 381 |
+
context_actions=context_actions,
|
| 382 |
+
context_tokens=context_tokens,
|
| 383 |
+
actions_to_take=future_actions
|
| 384 |
+
))
|
| 385 |
+
list_of_job_ids.append(job_id)
|
| 386 |
+
|
| 387 |
+
job_queue_size = self.jobs.qsize()
|
| 388 |
+
return json.dumps({"job_ids": list_of_job_ids, "current_jobs_in_queue": job_queue_size})
|
| 389 |
+
|
| 390 |
+
def _gather_new_results(self):
|
| 391 |
+
if not self.results_queue.empty():
|
| 392 |
+
for _ in range(self.results_queue.qsize()):
|
| 393 |
+
result = self.results_queue.get()
|
| 394 |
+
if result.job_id in self.cancelled_jobs:
|
| 395 |
+
# Discard result if job was cancelled
|
| 396 |
+
continue
|
| 397 |
+
self.local_results[result.job_id].append(result)
|
| 398 |
+
|
| 399 |
+
def get_new_results(self, request, request_json):
|
| 400 |
+
if "job_ids" not in request_json:
|
| 401 |
+
return make_response("job_ids not in request", 400)
|
| 402 |
+
self._gather_new_results()
|
| 403 |
+
job_ids = request_json["job_ids"]
|
| 404 |
+
if not isinstance(job_ids, list):
|
| 405 |
+
job_ids = [job_ids]
|
| 406 |
+
return_results = []
|
| 407 |
+
for job_id in job_ids:
|
| 408 |
+
if job_id in self.local_results:
|
| 409 |
+
return_results.append(self.local_results[job_id])
|
| 410 |
+
del self.local_results[job_id]
|
| 411 |
+
|
| 412 |
+
if len(return_results) == 0:
|
| 413 |
+
return make_response("No new responses", 204)
|
| 414 |
+
|
| 415 |
+
output_json = []
|
| 416 |
+
output_image_bytes = {}
|
| 417 |
+
for job_results in return_results:
|
| 418 |
+
for result in job_results:
|
| 419 |
+
action = result.dreamt_action.numpy()
|
| 420 |
+
# Remember to remove batch and sequence dimensions
|
| 421 |
+
action = be_action_vector_to_action_vector(action[0, 0].tolist())
|
| 422 |
+
dreamt_tokens = result.dreamt_tokens[0, 0].tolist()
|
| 423 |
+
image_filename = f"{result.job_id}_{result.dream_step_index}.png"
|
| 424 |
+
output_json.append({
|
| 425 |
+
"job_id": result.job_id,
|
| 426 |
+
"dream_step_index": result.dream_step_index,
|
| 427 |
+
"action": action,
|
| 428 |
+
"tokens": dreamt_tokens,
|
| 429 |
+
"image_filename": image_filename
|
| 430 |
+
})
|
| 431 |
+
|
| 432 |
+
image_bytes = io.BytesIO()
|
| 433 |
+
# this probably is not as smooth as it could be
|
| 434 |
+
T.ToPILImage()(result.dreamt_image[0, 0]).save(image_bytes, format="PNG")
|
| 435 |
+
output_image_bytes[image_filename] = image_bytes.getvalue()
|
| 436 |
+
|
| 437 |
+
# Write a zip file with all the images
|
| 438 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
|
| 439 |
+
zip_bytes = io.BytesIO()
|
| 440 |
+
with zipfile.ZipFile(zip_bytes, "w") as z:
|
| 441 |
+
for filename, bytes in output_image_bytes.items():
|
| 442 |
+
z.writestr(filename, bytes)
|
| 443 |
+
# Write the json
|
| 444 |
+
z.writestr(PREDICTION_JSON_FILENAME, json.dumps(output_json))
|
| 445 |
+
|
| 446 |
+
zip_bytes.seek(0)
|
| 447 |
+
|
| 448 |
+
return send_file(
|
| 449 |
+
zip_bytes,
|
| 450 |
+
mimetype="zip",
|
| 451 |
+
as_attachment=True,
|
| 452 |
+
download_name=f"dreaming_results_{timestamp}.zip"
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
def cancel_job(self, request, request_json):
|
| 456 |
+
if "job_id" not in request_json:
|
| 457 |
+
return make_response("job_id not in request", 400)
|
| 458 |
+
job_id = request_json["job_id"]
|
| 459 |
+
self.cancelled_jobs.add(job_id)
|
| 460 |
+
# Cancel all jobs in the queue with this id
|
| 461 |
+
for job_queue in self.cancelled_jobs_queues:
|
| 462 |
+
job_queue.put(job_id)
|
| 463 |
+
return make_response("OK", 200)
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def main_run(args):
|
| 467 |
+
app = Flask(__name__)
|
| 468 |
+
|
| 469 |
+
num_workers = th.cuda.device_count()
|
| 470 |
+
if num_workers == 0:
|
| 471 |
+
raise RuntimeError("No CUDA devices found. Cannot run Dreamer.")
|
| 472 |
+
|
| 473 |
+
server = DreamerServer(num_workers, args)
|
| 474 |
+
quit_flag = mp.Event()
|
| 475 |
+
|
| 476 |
+
# Start the dreamer worker(s)
|
| 477 |
+
dreamer_worker_processes = []
|
| 478 |
+
for device_i in range(num_workers):
|
| 479 |
+
device = f"cuda:{device_i}"
|
| 480 |
+
dreamer_worker_process = mp.Process(
|
| 481 |
+
target=dreamer_worker,
|
| 482 |
+
args=(server.jobs, server.results_queue, server.cancelled_jobs_queues[device_i], quit_flag, device, args)
|
| 483 |
+
)
|
| 484 |
+
dreamer_worker_process.daemon = True
|
| 485 |
+
dreamer_worker_process.start()
|
| 486 |
+
dreamer_worker_processes.append(dreamer_worker_process)
|
| 487 |
+
|
| 488 |
+
# Add the API endpoints
|
| 489 |
+
@app.route('/')
|
| 490 |
+
def details():
|
| 491 |
+
return server.get_details()
|
| 492 |
+
|
| 493 |
+
@app.route('/new_job', methods=['POST'])
|
| 494 |
+
def new_job():
|
| 495 |
+
request_json = json.loads(request.form["json"])
|
| 496 |
+
return server.add_new_job(request, request_json)
|
| 497 |
+
|
| 498 |
+
@app.route('/get_job_results', methods=['GET'])
|
| 499 |
+
def get_results():
|
| 500 |
+
# the "Json" is now in regular GET payload/parameters
|
| 501 |
+
request_json = {"job_ids": request.args.getlist("job_ids")}
|
| 502 |
+
return server.get_new_results(request, request_json)
|
| 503 |
+
|
| 504 |
+
@app.route('/cancel_job', methods=['GET'])
|
| 505 |
+
def cancel_job():
|
| 506 |
+
request_json = request.args.to_dict()
|
| 507 |
+
return server.cancel_job(request, request_json)
|
| 508 |
+
|
| 509 |
+
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
| 510 |
+
|
| 511 |
+
# Cleanup
|
| 512 |
+
quit_flag.set()
|
| 513 |
+
for dreamer_worker_process in dreamer_worker_processes:
|
| 514 |
+
dreamer_worker_process.join()
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
if __name__ == '__main__':
|
| 518 |
+
args = parser.parse_args()
|
| 519 |
+
main_run(args)
|
setup_local.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Tested using Python 3.9
|
| 2 |
+
|
| 3 |
+
echo "Making and activating a new virtual environment..."
|
| 4 |
+
python3.9 -m venv venv
|
| 5 |
+
|
| 6 |
+
echo "Activating the virtual environment..."
|
| 7 |
+
source venv/bin/activate
|
| 8 |
+
|
| 9 |
+
echo "Upgrading pip..."
|
| 10 |
+
pip install --upgrade pip
|
| 11 |
+
|
| 12 |
+
echo "Instaling the required packages..."
|
| 13 |
+
pip install -r requirements.txt
|
| 14 |
+
|
| 15 |
+
echo "Instaling the exiftool package for adding file metadata on Linux..."
|
| 16 |
+
sudo apt install -y exiftool
|
| 17 |
+
|
| 18 |
+
echo "Installing ffmpeg..."
|
| 19 |
+
sudo apt install ffmpeg
|
| 20 |
+
|
| 21 |
+
echo "All packages installed successfully!"
|
wham/models/nn/model_blocks.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Some Utility blocks for ViT-VQGAN.
|
| 5 |
+
|
| 6 |
+
ConvNeXt blocks are based on:
|
| 7 |
+
Liu, Zhuang, et al. "A convnet for the 2020s."
|
| 8 |
+
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ConvNextDownsampleBig(nn.Module):
|
| 13 |
+
def __init__(self, c_in, c_out):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.group_norm = nn.GroupNorm(c_in, c_in)
|
| 16 |
+
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=8, stride=4, padding=0)
|
| 17 |
+
|
| 18 |
+
def forward(self, x):
|
| 19 |
+
return self.conv1(self.group_norm(x))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class ConvNextBlock(nn.Module):
|
| 23 |
+
def __init__(self, channels):
|
| 24 |
+
super().__init__()
|
| 25 |
+
self.conv1 = nn.Conv2d(channels, channels, kernel_size=7, stride=1, padding=7 // 2, groups=channels) # 'Depthwise' conv
|
| 26 |
+
self.group_norm = nn.GroupNorm(channels, channels) # Should be equivalent to layernorm
|
| 27 |
+
|
| 28 |
+
# Transformer-style non-linearity
|
| 29 |
+
self.conv2 = nn.Conv2d(channels, channels * 4, kernel_size=1, stride=1, padding=0)
|
| 30 |
+
self.activation = nn.GELU()
|
| 31 |
+
self.conv3 = nn.Conv2d(channels * 4, channels, kernel_size=1, stride=1, padding=0)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
y = self.conv1(x)
|
| 35 |
+
y = self.group_norm(y)
|
| 36 |
+
y = self.conv2(y)
|
| 37 |
+
y = self.activation(y)
|
| 38 |
+
y = self.conv3(y)
|
| 39 |
+
return x + y
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ConvNextDownsample(nn.Module):
|
| 43 |
+
def __init__(self, c_in, c_out):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.group_norm = nn.GroupNorm(c_in, c_in)
|
| 46 |
+
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
return self.conv1(self.group_norm(x))
|
wham/models/nn/nanoGPT.py
ADDED
|
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# From https://github.com/karpathy/nanoGPT/blob/master/model.py - Thanks Andrej Karpathy
|
| 2 |
+
|
| 3 |
+
# MIT License
|
| 4 |
+
# Copyright (c) 2022 Andrej Karpathy
|
| 5 |
+
# 2023 Microsoft Research
|
| 6 |
+
|
| 7 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 8 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 9 |
+
# in the Software without restriction, including without limitation the rights
|
| 10 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 11 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 12 |
+
# furnished to do so, subject to the following conditions:
|
| 13 |
+
|
| 14 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 15 |
+
# copies or substantial portions of the Software.
|
| 16 |
+
|
| 17 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 18 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 19 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 20 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
| 21 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
| 22 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
| 23 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
"""
|
| 27 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
| 28 |
+
References:
|
| 29 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
| 30 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
| 31 |
+
2) huggingface/transformers PyTorch implementation:
|
| 32 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
from dataclasses import dataclass
|
| 36 |
+
import inspect
|
| 37 |
+
import math
|
| 38 |
+
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn as nn
|
| 41 |
+
from torch.nn import functional as F
|
| 42 |
+
|
| 43 |
+
NEGATIVE_INFINITE_FLOAT = -float("inf")
|
| 44 |
+
CROSS_ENTROPY_INVALID_CLASS_TARGET = -1
|
| 45 |
+
|
| 46 |
+
# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
|
| 47 |
+
def new_gelu(x):
|
| 48 |
+
"""
|
| 49 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
|
| 50 |
+
Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
|
| 51 |
+
"""
|
| 52 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def limit_logits_to_valid_range(logits, valid_token_range):
|
| 56 |
+
"""
|
| 57 |
+
MODIFIES logits INPLACE.
|
| 58 |
+
Mask out invalid positions in the logits tensor with -inf so they are not considered by the softmax.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
logits: logits tensor of shape (batch_size, vocab_size)
|
| 62 |
+
valid_token_range: tuple of (start, end) indices of valid positions in the logits tensor (inclusive).
|
| 63 |
+
Everything outside is masked out with -inf.
|
| 64 |
+
"""
|
| 65 |
+
logits[:, : valid_token_range[0]] = NEGATIVE_INFINITE_FLOAT
|
| 66 |
+
logits[:, valid_token_range[1] + 1 :] = NEGATIVE_INFINITE_FLOAT
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def default_sample_token(logits, valid_token_range=None, temperature=1.0, deterministic=False, top_k=None, top_p=None, min_tokens_to_keep=1):
|
| 70 |
+
"""
|
| 71 |
+
Given a vector of logits, sample and return an index according to settings.
|
| 72 |
+
|
| 73 |
+
logits: tensor of shape (batch_size, vocab_size)
|
| 74 |
+
|
| 75 |
+
valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
|
| 76 |
+
If None, we'll sample from the full vocab.
|
| 77 |
+
|
| 78 |
+
If deterministic is True, we'll take the argmax of the logits which implies top-k sampling with top_k = 1, therefore user inputted values of top_p and top_k will be ignored.
|
| 79 |
+
|
| 80 |
+
Otherwise, either top-p (float) value can be specified or top-k (int) value can be specified.
|
| 81 |
+
Top-p (float top_p) : only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
|
| 82 |
+
Top-k (int top_k) : selects top_k tokens for generation.
|
| 83 |
+
min_tokens_to_keep: Used with both top_p and top_k sampling.
|
| 84 |
+
"""
|
| 85 |
+
assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
|
| 86 |
+
if temperature < 0.1:
|
| 87 |
+
# Avoid too low a temp, especially 0
|
| 88 |
+
temperature = 0.1
|
| 89 |
+
logits = logits / temperature
|
| 90 |
+
if valid_token_range is not None:
|
| 91 |
+
limit_logits_to_valid_range(logits, valid_token_range)
|
| 92 |
+
if deterministic:
|
| 93 |
+
selected_logits = select_logits(logits, top_k=1)
|
| 94 |
+
else:
|
| 95 |
+
selected_logits = select_logits(logits, top_p=top_p, top_k=top_k, min_tokens_to_keep=min_tokens_to_keep)
|
| 96 |
+
probs = F.softmax(selected_logits, dim=-1)
|
| 97 |
+
# More robustly handle errors in the sampling here
|
| 98 |
+
sampled_idx = torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 99 |
+
return sampled_idx
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def select_logits(logits, top_k=None, top_p=None, min_tokens_to_keep=1):
|
| 103 |
+
"""
|
| 104 |
+
Select from original logits using top-k or top-p sampling.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
logits (torch.Tensor): Logits to sample from.
|
| 108 |
+
k (int, optional): Number of top elements to consider in top-k sampling.
|
| 109 |
+
p (float, optional): Threshold probability for top-p sampling.
|
| 110 |
+
min_tokens_to_keep (int, optional): Minimum number of tokens to keep in the output.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
logits: Selected logits after top-k or top-p sampling. Sets all logits outside the selected ones to NEGATIVE_INFINITE_FLOAT.
|
| 114 |
+
"""
|
| 115 |
+
assert top_k is None or top_p is None, "Can only specify one of top-k or top-p sampling."
|
| 116 |
+
min_tokens_to_keep = min(min_tokens_to_keep, logits.size(-1))
|
| 117 |
+
if top_k is not None:
|
| 118 |
+
if not isinstance(top_k, int) or top_k <= 0:
|
| 119 |
+
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
| 120 |
+
|
| 121 |
+
# Top-k sampling
|
| 122 |
+
top_k = max(top_k, min_tokens_to_keep)
|
| 123 |
+
top_k = min(top_k, logits.size(-1))
|
| 124 |
+
top_k_logits, _ = torch.topk(logits, top_k)
|
| 125 |
+
indices_to_remove = logits < top_k_logits[..., -1:]
|
| 126 |
+
logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits)
|
| 127 |
+
|
| 128 |
+
elif top_p is not None:
|
| 129 |
+
top_p = float(top_p)
|
| 130 |
+
if top_p < 0 or top_p > 1.0:
|
| 131 |
+
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
| 132 |
+
|
| 133 |
+
# Top-p sampling
|
| 134 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 135 |
+
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
| 136 |
+
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 137 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 138 |
+
|
| 139 |
+
# Remove tokens with cumulative probability above the threshold
|
| 140 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = False
|
| 141 |
+
|
| 142 |
+
# scatter sorted tensors to original indexing
|
| 143 |
+
indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
|
| 144 |
+
logits = torch.where(indices_to_remove, NEGATIVE_INFINITE_FLOAT, logits)
|
| 145 |
+
|
| 146 |
+
else:
|
| 147 |
+
# Return logits as is
|
| 148 |
+
pass
|
| 149 |
+
|
| 150 |
+
return logits
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class LayerNorm(nn.Module):
|
| 154 |
+
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""
|
| 155 |
+
|
| 156 |
+
def __init__(self, ndim, bias):
|
| 157 |
+
super().__init__()
|
| 158 |
+
self.weight = nn.Parameter(torch.ones(ndim))
|
| 159 |
+
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
|
| 160 |
+
|
| 161 |
+
def forward(self, input):
|
| 162 |
+
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
|
| 163 |
+
|
| 164 |
+
class LayerNormMinimal(nn.Module):
|
| 165 |
+
"""LayerNorm like above, but without learnable parameters"""
|
| 166 |
+
|
| 167 |
+
def __init__(self, ndim, bias):
|
| 168 |
+
super().__init__()
|
| 169 |
+
self.ndim = (ndim,)
|
| 170 |
+
|
| 171 |
+
def forward(self, input):
|
| 172 |
+
return F.layer_norm(input, self.ndim, eps=1e-5)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class CausalSelfAttention(nn.Module):
|
| 176 |
+
def __init__(self, config):
|
| 177 |
+
super().__init__()
|
| 178 |
+
assert config.n_embd % config.n_head == 0
|
| 179 |
+
# key, query, value projections for all heads, but in a batch
|
| 180 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 181 |
+
# output projection
|
| 182 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 183 |
+
# regularization
|
| 184 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 185 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 186 |
+
self.n_head = config.n_head
|
| 187 |
+
self.n_embd = config.n_embd
|
| 188 |
+
self.dropout = config.dropout
|
| 189 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
| 190 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
|
| 191 |
+
# causal mask to ensure that attention is only applied to the left in the input sequence
|
| 192 |
+
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size), persistent=False)
|
| 193 |
+
|
| 194 |
+
self.cached_k = None
|
| 195 |
+
self.cached_v = None
|
| 196 |
+
self.current_cache_size = 0
|
| 197 |
+
|
| 198 |
+
def _manual_causal_attention(self, q, k, v, mask):
|
| 199 |
+
# q, k and v should be of shape (B, nh, T, hs)
|
| 200 |
+
token_len = q.size(-2)
|
| 201 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
| 202 |
+
att = att.masked_fill(mask[:, :, :token_len, :token_len] == 0, float("-inf"))
|
| 203 |
+
att = F.softmax(att, dim=-1)
|
| 204 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
| 205 |
+
return y
|
| 206 |
+
|
| 207 |
+
def forward(self, x, cache=False):
|
| 208 |
+
batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
| 209 |
+
|
| 210 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
| 211 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 212 |
+
k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 213 |
+
q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 214 |
+
v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 215 |
+
|
| 216 |
+
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
| 217 |
+
if self.flash and not cache:
|
| 218 |
+
# efficient attention using Flash Attention CUDA kernels
|
| 219 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
| 220 |
+
elif cache:
|
| 221 |
+
# manual implemention of attention (as below), but cache arrays we can reuse
|
| 222 |
+
assert token_len == 1, "Cache only works for single step"
|
| 223 |
+
assert self.cached_k is not None, "Must call reset_cache() before using cache"
|
| 224 |
+
assert self.current_cache_size < self.cached_k.size(2), "Trying to generate more steps than provided in reset_cache() `num_steps_to_come`"
|
| 225 |
+
assert self.dropout == 0.0, "Dropout not supported with caching"
|
| 226 |
+
this_step_q = q
|
| 227 |
+
self.cached_k[:, :, self.current_cache_size, :] = k[:, :, 0, :]
|
| 228 |
+
self.cached_v[:, :, self.current_cache_size, :] = v[:, :, 0, :]
|
| 229 |
+
# Remove the zero parts
|
| 230 |
+
k = self.cached_k[:, :, : self.current_cache_size + 1, :]
|
| 231 |
+
# compute last row of the attention mask
|
| 232 |
+
this_step_att_row = (this_step_q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
| 233 |
+
this_step_att_row = F.softmax(this_step_att_row, dim=-1)
|
| 234 |
+
# We only need output for the current step
|
| 235 |
+
y = this_step_att_row @ self.cached_v[:, :, : self.current_cache_size + 1, :]
|
| 236 |
+
# Update cache
|
| 237 |
+
self.current_cache_size += 1
|
| 238 |
+
else:
|
| 239 |
+
y = self._manual_causal_attention(q, k, v, self.bias)
|
| 240 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side
|
| 241 |
+
|
| 242 |
+
# output projection
|
| 243 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 244 |
+
return y
|
| 245 |
+
|
| 246 |
+
def reset_cache(self, x, num_steps_to_come):
|
| 247 |
+
"""
|
| 248 |
+
Reset caches by doing initial pass with x data (returning same output as forward).
|
| 249 |
+
Also set the number of steps to come, which is used to initialize the buffers
|
| 250 |
+
"""
|
| 251 |
+
batch_size, token_len, n_embd = x.size()
|
| 252 |
+
|
| 253 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 254 |
+
k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 255 |
+
q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 256 |
+
v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 257 |
+
|
| 258 |
+
# Use SDPA instead of a manual implementation
|
| 259 |
+
# y = self._manual_causal_attention(q, k, v, self.bias)
|
| 260 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
| 261 |
+
|
| 262 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd)
|
| 263 |
+
# output projection
|
| 264 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 265 |
+
|
| 266 |
+
# Create full k,q,v for predicting all future steps.
|
| 267 |
+
# Just null-out the last num_steps_to_come-1 steps
|
| 268 |
+
pad_size = num_steps_to_come
|
| 269 |
+
self.current_cache_size = token_len
|
| 270 |
+
self.cached_k = torch.cat([k, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=k.device)], dim=2)
|
| 271 |
+
self.cached_v = torch.cat([v, torch.zeros(batch_size, self.n_head, pad_size, n_embd // self.n_head, device=v.device)], dim=2)
|
| 272 |
+
|
| 273 |
+
return y
|
| 274 |
+
|
| 275 |
+
class SelfAttention(nn.Module):
|
| 276 |
+
"""
|
| 277 |
+
Non-causal self-attention layer, the same as CausalSelfAttention but without the causal mask.
|
| 278 |
+
Duplicating the code to keep this separate for clarity.
|
| 279 |
+
"""
|
| 280 |
+
|
| 281 |
+
def __init__(self, config):
|
| 282 |
+
super().__init__()
|
| 283 |
+
assert config.n_embd % config.n_head == 0
|
| 284 |
+
# key, query, value projections for all heads, but in a batch
|
| 285 |
+
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
|
| 286 |
+
# output projection
|
| 287 |
+
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
| 288 |
+
# regularization
|
| 289 |
+
self.attn_dropout = nn.Dropout(config.dropout)
|
| 290 |
+
self.resid_dropout = nn.Dropout(config.dropout)
|
| 291 |
+
self.n_head = config.n_head
|
| 292 |
+
self.n_embd = config.n_embd
|
| 293 |
+
self.dropout = config.dropout
|
| 294 |
+
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
| 295 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
|
| 296 |
+
assert self.flash, "SelfAttention only supports flash attention for now."
|
| 297 |
+
|
| 298 |
+
self.register_buffer("attn_mask", torch.ones((config.block_size, config.block_size)).bool().unsqueeze(0).unsqueeze(0))
|
| 299 |
+
|
| 300 |
+
def forward(self, x):
|
| 301 |
+
batch_size, token_len, n_embd = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
| 302 |
+
|
| 303 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
| 304 |
+
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
|
| 305 |
+
k = k.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 306 |
+
q = q.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 307 |
+
v = v.view(batch_size, token_len, self.n_head, n_embd // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
| 308 |
+
|
| 309 |
+
# self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
| 310 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask, dropout_p=self.dropout, is_causal=False)
|
| 311 |
+
y = y.transpose(1, 2).contiguous().view(batch_size, token_len, n_embd) # re-assemble all head outputs side by side
|
| 312 |
+
|
| 313 |
+
# output projection
|
| 314 |
+
y = self.resid_dropout(self.c_proj(y))
|
| 315 |
+
return y
|
| 316 |
+
|
| 317 |
+
class MLP(nn.Module):
|
| 318 |
+
def __init__(self, config):
|
| 319 |
+
super().__init__()
|
| 320 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 321 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 322 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 323 |
+
|
| 324 |
+
def forward(self, x):
|
| 325 |
+
x = self.c_fc(x)
|
| 326 |
+
x = new_gelu(x)
|
| 327 |
+
x = self.c_proj(x)
|
| 328 |
+
x = self.dropout(x)
|
| 329 |
+
return x
|
| 330 |
+
|
| 331 |
+
class GELU_MLP(nn.Module):
|
| 332 |
+
"""MLP Block using PyTorch's native GELU activation function"""
|
| 333 |
+
def __init__(self, config):
|
| 334 |
+
super().__init__()
|
| 335 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 336 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 337 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 338 |
+
|
| 339 |
+
def forward(self, x):
|
| 340 |
+
x = self.c_fc(x)
|
| 341 |
+
x = F.gelu(x, approximate="tanh")
|
| 342 |
+
x = self.c_proj(x)
|
| 343 |
+
x = self.dropout(x)
|
| 344 |
+
return x
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class Block(nn.Module):
|
| 348 |
+
def __init__(self, config):
|
| 349 |
+
super().__init__()
|
| 350 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
| 351 |
+
self.attn = CausalSelfAttention(config)
|
| 352 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
| 353 |
+
self.mlp = MLP(config)
|
| 354 |
+
|
| 355 |
+
def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
|
| 356 |
+
"""
|
| 357 |
+
Args:
|
| 358 |
+
cache: If True, use the cache to predict the next token (assumes model was initialized with `reset_cache`).
|
| 359 |
+
reset_cache_with_num_steps_to_come:
|
| 360 |
+
If not None, reset and prepare the cache for cached prediction of the next `reset_cache_with_num_steps_to_come` tokens.
|
| 361 |
+
This is same as calling `reset_cache` with the same argument, but we include option here in `forward` to support torch hook functions (used to get embeddings from this module output).
|
| 362 |
+
|
| 363 |
+
Caching example:
|
| 364 |
+
```
|
| 365 |
+
# Initialize model with reset_cache_with_num_steps_to_come=10
|
| 366 |
+
outputs[0] = model(inputs, reset_cache_with_num_steps_to_come=10)
|
| 367 |
+
# Predict next 10 tokens using cache
|
| 368 |
+
for i in range(10):
|
| 369 |
+
outputs[i+1] = model(inputs, cache=True)
|
| 370 |
+
```
|
| 371 |
+
"""
|
| 372 |
+
if reset_cache_with_num_steps_to_come:
|
| 373 |
+
return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
|
| 374 |
+
x = x + self.attn(self.ln_1(x), cache=cache)
|
| 375 |
+
x = x + self.mlp(self.ln_2(x))
|
| 376 |
+
return x
|
| 377 |
+
|
| 378 |
+
def reset_cache(self, x, num_steps_to_come):
|
| 379 |
+
x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
|
| 380 |
+
x = x + self.mlp(self.ln_2(x))
|
| 381 |
+
return x
|
| 382 |
+
|
| 383 |
+
class BlockV2(nn.Module):
|
| 384 |
+
"""
|
| 385 |
+
Compared to the Block in the original implementation, this one uses non-parametric LayerNorm and Pytorch's GELU.
|
| 386 |
+
These two changes save significant vram but are incompatible with previously trained models.
|
| 387 |
+
Hence the separate class.
|
| 388 |
+
"""
|
| 389 |
+
|
| 390 |
+
def __init__(self, config):
|
| 391 |
+
super().__init__()
|
| 392 |
+
self.ln_1 = LayerNormMinimal(config.n_embd, bias=config.bias)
|
| 393 |
+
self.attn = CausalSelfAttention(config)
|
| 394 |
+
self.ln_2 = LayerNormMinimal(config.n_embd, bias=config.bias)
|
| 395 |
+
self.mlp = GELU_MLP(config)
|
| 396 |
+
|
| 397 |
+
def forward(self, x, cache=False, reset_cache_with_num_steps_to_come=None):
|
| 398 |
+
if reset_cache_with_num_steps_to_come:
|
| 399 |
+
return self.reset_cache(x, num_steps_to_come=reset_cache_with_num_steps_to_come)
|
| 400 |
+
x = x + self.attn(self.ln_1(x), cache=cache)
|
| 401 |
+
x = x + self.mlp(self.ln_2(x))
|
| 402 |
+
return x
|
| 403 |
+
|
| 404 |
+
def reset_cache(self, x, num_steps_to_come):
|
| 405 |
+
x = x + self.attn.reset_cache(self.ln_1(x), num_steps_to_come=num_steps_to_come)
|
| 406 |
+
x = x + self.mlp(self.ln_2(x))
|
| 407 |
+
return x
|
| 408 |
+
|
| 409 |
+
class SelfAttentionBlock(nn.Module):
|
| 410 |
+
def __init__(self, config):
|
| 411 |
+
super().__init__()
|
| 412 |
+
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
|
| 413 |
+
self.attn = SelfAttention(config)
|
| 414 |
+
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
|
| 415 |
+
self.mlp = MLP(config)
|
| 416 |
+
|
| 417 |
+
def forward(self, x):
|
| 418 |
+
x = x + self.attn(self.ln_1(x))
|
| 419 |
+
x = x + self.mlp(self.ln_2(x))
|
| 420 |
+
return x
|
| 421 |
+
|
| 422 |
+
@dataclass
|
| 423 |
+
class GPTConfig:
|
| 424 |
+
block_size: int = 1024
|
| 425 |
+
vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
|
| 426 |
+
n_layer: int = 12
|
| 427 |
+
n_head: int = 12
|
| 428 |
+
n_embd: int = 768
|
| 429 |
+
dropout: float = 0.0
|
| 430 |
+
bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
| 431 |
+
version: int = 1 # Version 1 is the original GPT, Version 2 is the one with non-parametric LayerNorm and Pytorch's GELU
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class GPT(nn.Module):
|
| 435 |
+
def __init__(self, config):
|
| 436 |
+
super().__init__()
|
| 437 |
+
assert config.vocab_size is not None
|
| 438 |
+
assert config.block_size is not None
|
| 439 |
+
self.config = config
|
| 440 |
+
|
| 441 |
+
self.version = config.version
|
| 442 |
+
|
| 443 |
+
print(f"[nanoGPT] creating model with version {self.version}")
|
| 444 |
+
|
| 445 |
+
if self.version == 1:
|
| 446 |
+
transformer_dict = dict(
|
| 447 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
| 448 |
+
drop=nn.Dropout(config.dropout),
|
| 449 |
+
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
|
| 450 |
+
ln_f=LayerNorm(config.n_embd, bias=config.bias),
|
| 451 |
+
)
|
| 452 |
+
elif self.version == 2:
|
| 453 |
+
transformer_dict = dict(
|
| 454 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
| 455 |
+
drop=nn.Dropout(config.dropout),
|
| 456 |
+
h=nn.ModuleList([BlockV2(config) for _ in range(config.n_layer)]),
|
| 457 |
+
ln_f=LayerNorm(config.n_embd, bias=config.bias), # This one is still parametric due to user error
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
transformer_dict["wte"] = nn.Embedding(config.vocab_size, config.n_embd)
|
| 461 |
+
self.transformer = nn.ModuleDict(transformer_dict)
|
| 462 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 463 |
+
# with weight tying when using torch.compile() some warnings get generated:
|
| 464 |
+
# "UserWarning: functional_call was passed multiple values for tied weights.
|
| 465 |
+
# This behavior is deprecated and will be an error in future versions"
|
| 466 |
+
# not 100% sure what this is, so far seems to be harmless.
|
| 467 |
+
self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
|
| 468 |
+
|
| 469 |
+
# init all weights
|
| 470 |
+
self.apply(self._init_weights)
|
| 471 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
| 472 |
+
for pn, p in self.named_parameters():
|
| 473 |
+
if pn.endswith("c_proj.weight"):
|
| 474 |
+
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
|
| 475 |
+
|
| 476 |
+
def get_num_params(self, non_embedding=True):
|
| 477 |
+
"""
|
| 478 |
+
Return the number of parameters in the model.
|
| 479 |
+
For non-embedding count (default), the position embeddings get subtracted.
|
| 480 |
+
The token embeddings would too, except due to the parameter sharing these
|
| 481 |
+
params are actually used as weights in the final layer, so we include them.
|
| 482 |
+
"""
|
| 483 |
+
n_params = sum(p.numel() for p in self.parameters())
|
| 484 |
+
if non_embedding:
|
| 485 |
+
n_params -= self.transformer.wpe.weight.numel()
|
| 486 |
+
return n_params
|
| 487 |
+
|
| 488 |
+
def _init_weights(self, module):
|
| 489 |
+
if isinstance(module, nn.Linear):
|
| 490 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 491 |
+
if module.bias is not None:
|
| 492 |
+
torch.nn.init.zeros_(module.bias)
|
| 493 |
+
elif isinstance(module, nn.Embedding):
|
| 494 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 495 |
+
|
| 496 |
+
def _apply_pos_encoding(self, x):
|
| 497 |
+
device = x.device
|
| 498 |
+
token_len = x.size(1)
|
| 499 |
+
pos = torch.arange(0, token_len, dtype=torch.long, device=device).unsqueeze(0)
|
| 500 |
+
pos_emb = self.transformer.wpe(pos)
|
| 501 |
+
x = x + pos_emb
|
| 502 |
+
return x
|
| 503 |
+
|
| 504 |
+
def original_forward(self, idx, targets=None, loss_mask=None, loss_reduction="mean"):
|
| 505 |
+
batch_size, seq_len = idx.shape[:2]
|
| 506 |
+
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
| 507 |
+
x = self.transformer.drop(self._apply_pos_encoding(tok_emb))
|
| 508 |
+
for block in self.transformer.h:
|
| 509 |
+
x = block(x)
|
| 510 |
+
x = self.transformer.ln_f(x)
|
| 511 |
+
|
| 512 |
+
if targets is not None:
|
| 513 |
+
# if we are given some desired targets also calculate the loss
|
| 514 |
+
logits = self.lm_head(x)
|
| 515 |
+
if loss_mask is not None:
|
| 516 |
+
# Feeding target = CROSS_ENTROPY_INVALID_CLASS_TARGET to cross_entropy will ignore the loss
|
| 517 |
+
# for that position. This is useful for padding tokens.
|
| 518 |
+
targets[loss_mask == 0] = CROSS_ENTROPY_INVALID_CLASS_TARGET
|
| 519 |
+
loss = F.cross_entropy(
|
| 520 |
+
logits.view(batch_size * seq_len, self.config.vocab_size), targets.view(-1), ignore_index=CROSS_ENTROPY_INVALID_CLASS_TARGET, reduction=loss_reduction
|
| 521 |
+
)
|
| 522 |
+
if loss_reduction == "none":
|
| 523 |
+
# Reshape back into batch_size and seq_len
|
| 524 |
+
loss = loss.view(batch_size, seq_len)
|
| 525 |
+
else:
|
| 526 |
+
# inference-time mini-optimization: only forward the lm_head on the very last position
|
| 527 |
+
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
| 528 |
+
loss = None
|
| 529 |
+
|
| 530 |
+
return logits, loss
|
| 531 |
+
|
| 532 |
+
def forward(self, x, targets=None, loss_mask=None, loss_reduction="mean"):
|
| 533 |
+
token_len = x.size(1)
|
| 534 |
+
assert token_len <= self.config.block_size, f"Cannot forward sequence of length {token_len}, block size is only {self.config.block_size}"
|
| 535 |
+
return self.original_forward(x, targets, loss_mask, loss_reduction)
|
| 536 |
+
|
| 537 |
+
@torch.no_grad()
|
| 538 |
+
def generate(self, idx, max_new_tokens, valid_token_range=None, temperature=1.0, top_k=None, raise_cropping=False, deterministic=False):
|
| 539 |
+
"""
|
| 540 |
+
valid_token_range should be a tuple, specifying start and end indices we'd like to sample from (inclusive).
|
| 541 |
+
if None, we'll sample from the full vocab.
|
| 542 |
+
|
| 543 |
+
If raise_cropping is True, we'll raise an error if we need to crop the sequence context.
|
| 544 |
+
"""
|
| 545 |
+
if valid_token_range is None:
|
| 546 |
+
valid_token_range = (0, self.config.vocab_size - 1)
|
| 547 |
+
assert len(valid_token_range) == 2
|
| 548 |
+
assert valid_token_range[0] < valid_token_range[1]
|
| 549 |
+
for _ in range(max_new_tokens):
|
| 550 |
+
# if the sequence context is growing too long we must crop it at block_size
|
| 551 |
+
idx_cond = idx
|
| 552 |
+
if idx.size(1) > self.config.block_size:
|
| 553 |
+
if raise_cropping:
|
| 554 |
+
raise ValueError("Tried to crop idxs but flag told to raise this")
|
| 555 |
+
else:
|
| 556 |
+
idx_cond = idx[:, -self.config.block_size :]
|
| 557 |
+
# forward the model to get the logits for the index in the sequence
|
| 558 |
+
logits, _ = self(idx_cond)
|
| 559 |
+
# pluck the logits at the final step and scale by desired temperature
|
| 560 |
+
logits = logits[:, -1, :] / temperature # logits is B T Vocabsize -> B Vocabsize
|
| 561 |
+
# optionally crop the logits to only the top k options
|
| 562 |
+
if top_k is not None:
|
| 563 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 564 |
+
logits[logits < v[:, [-1]]] = NEGATIVE_INFINITE_FLOAT
|
| 565 |
+
|
| 566 |
+
# Crop out the logits we don't want to sample from
|
| 567 |
+
if valid_token_range is not None:
|
| 568 |
+
limit_logits_to_valid_range(logits, valid_token_range)
|
| 569 |
+
|
| 570 |
+
# apply softmax to convert logits to (normalized) probabilities
|
| 571 |
+
probs = F.softmax(logits, dim=-1)
|
| 572 |
+
|
| 573 |
+
if deterministic:
|
| 574 |
+
# Take max of the results
|
| 575 |
+
idx_next = torch.argmax(probs, dim=-1, keepdim=True)
|
| 576 |
+
else:
|
| 577 |
+
# sample from the distribution
|
| 578 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 579 |
+
# append sampled index to the running sequence and continue
|
| 580 |
+
idx = torch.cat((idx, idx_next), dim=1)
|
| 581 |
+
|
| 582 |
+
return idx
|
| 583 |
+
|
| 584 |
+
@torch.no_grad()
|
| 585 |
+
def optimized_generate(
|
| 586 |
+
self,
|
| 587 |
+
idx,
|
| 588 |
+
num_new_tokens,
|
| 589 |
+
valid_token_ranges=None,
|
| 590 |
+
temperature=1.0,
|
| 591 |
+
deterministic=False,
|
| 592 |
+
raise_cropping=False,
|
| 593 |
+
top_k=None,
|
| 594 |
+
top_p=None,
|
| 595 |
+
min_tokens_to_keep=1,
|
| 596 |
+
):
|
| 597 |
+
"""
|
| 598 |
+
Generate function but optimized by caching the results in transformer blocks (think this is referred to as "attention caching").
|
| 599 |
+
The higher the num_new_tokens, the more the speedup compared to original generate.
|
| 600 |
+
|
| 601 |
+
Caveat: the context length + num_new_tokens must be less than the block size. This means that the first
|
| 602 |
+
generated tokens do not have full context length.
|
| 603 |
+
|
| 604 |
+
valid_token_ranges should be None or list of length num_new_tokens, specifying valid range for tokens for every step
|
| 605 |
+
"""
|
| 606 |
+
# Properly compile the modules used and/or quantize for improved speed.
|
| 607 |
+
logit_layer = self.lm_head
|
| 608 |
+
embedder_fn = self.transformer.wte
|
| 609 |
+
|
| 610 |
+
if valid_token_ranges is None:
|
| 611 |
+
valid_token_ranges = [[0, self.config.vocab_size] for _ in range(num_new_tokens)]
|
| 612 |
+
assert len(valid_token_ranges) == num_new_tokens, "valid_token_ranges should be list of length num_new_tokens or None"
|
| 613 |
+
|
| 614 |
+
_, token_len = idx.size()
|
| 615 |
+
if token_len + num_new_tokens > self.config.block_size:
|
| 616 |
+
raise ValueError("Can't use optimized generation with num_new_tokens + context_length > block_size")
|
| 617 |
+
new_idxs = torch.zeros(idx.size(0), num_new_tokens, dtype=torch.long, device=idx.device)
|
| 618 |
+
# First, we need to cull the sequence to the block size
|
| 619 |
+
# and remove first max_new_tokens so we can reuse same position embeddings
|
| 620 |
+
# and not have to recompute them
|
| 621 |
+
num_original_tokens = idx.size(1)
|
| 622 |
+
original_idx = idx
|
| 623 |
+
if (num_original_tokens + num_new_tokens) > self.config.block_size:
|
| 624 |
+
if raise_cropping:
|
| 625 |
+
raise ValueError("Tried to crop idxs but flag told to raise this")
|
| 626 |
+
original_idx = idx[:, -self.config.block_size + num_new_tokens :]
|
| 627 |
+
original_pos = torch.arange(0, original_idx.size(1), dtype=torch.long, device=idx.device).unsqueeze(0)
|
| 628 |
+
# Now cache results with the original context
|
| 629 |
+
original_tok_emb = embedder_fn(original_idx)
|
| 630 |
+
original_pos_emb = self.transformer.wpe(original_pos)
|
| 631 |
+
original_x = original_tok_emb + original_pos_emb
|
| 632 |
+
for block in self.transformer.h:
|
| 633 |
+
# Reset the cache for each block, and cache new result
|
| 634 |
+
original_x = block(original_x, reset_cache_with_num_steps_to_come=num_new_tokens)
|
| 635 |
+
|
| 636 |
+
# Sample the first token
|
| 637 |
+
original_x = self.transformer.ln_f(original_x)
|
| 638 |
+
last_logit = logit_layer(original_x[:, [-1], :])
|
| 639 |
+
new_idxs[:, 0] = default_sample_token(
|
| 640 |
+
last_logit[:, -1, :], valid_token_ranges[0], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
# Generate rest of the steps
|
| 644 |
+
for generation_idx in range(1, num_new_tokens):
|
| 645 |
+
# forward the model to get the logits for the index in the sequence
|
| 646 |
+
# This is the position of the latest generated token, not the currently going-to-be-generated token
|
| 647 |
+
latest_token_pos = num_original_tokens + generation_idx - 1
|
| 648 |
+
# We only need to pass in the latest token
|
| 649 |
+
newest_idx = new_idxs[:, generation_idx - 1].unsqueeze(-1)
|
| 650 |
+
newest_tok_emb = embedder_fn(newest_idx)
|
| 651 |
+
newest_pos_emb = self.transformer.wpe(torch.tensor(latest_token_pos, dtype=torch.long, device=idx.device).unsqueeze(0))
|
| 652 |
+
newest_x = newest_tok_emb + newest_pos_emb
|
| 653 |
+
for block in self.transformer.h:
|
| 654 |
+
newest_x = block(newest_x, cache=True)
|
| 655 |
+
|
| 656 |
+
newest_x = self.transformer.ln_f(newest_x)
|
| 657 |
+
newest_logit = logit_layer(newest_x)
|
| 658 |
+
# Check this function isn't slowing things down noticeably
|
| 659 |
+
new_idxs[:, generation_idx] = default_sample_token(
|
| 660 |
+
newest_logit[:, -1, :], valid_token_ranges[generation_idx], temperature, deterministic, top_k=top_k, top_p=top_p, min_tokens_to_keep=min_tokens_to_keep
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# Combine indices
|
| 664 |
+
new_idxs = torch.cat((idx, new_idxs), dim=1)
|
| 665 |
+
return new_idxs
|
wham/models/pl/__init__.py
ADDED
|
File without changes
|
wham/models/pl/pl_base_model.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
|
| 3 |
+
class BaseTrainingModel(pl.LightningModule):
|
| 4 |
+
def __init__(self, **kwargs):
|
| 5 |
+
super().__init__(**kwargs)
|
wham/models/vqgan/taming/LICENSE
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
All files under this directory are originally from the taming-transformers repository:
|
| 2 |
+
https://github.com/CompVis/taming-transformers
|
| 3 |
+
|
| 4 |
+
Below is a copy of the original license
|
| 5 |
+
------------------------------------------------------------------------------
|
| 6 |
+
Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
| 7 |
+
|
| 8 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
in the Software without restriction, including without limitation the rights
|
| 11 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
The above copyright notice and this permission notice shall be included in all
|
| 16 |
+
copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 19 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 20 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 21 |
+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
| 22 |
+
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
| 23 |
+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
| 24 |
+
OR OTHER DEALINGS IN THE SOFTWARE./
|
wham/models/vqgan/taming/model.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# All files under this directory are originally from the taming-transformers repository:
|
| 2 |
+
# https://github.com/CompVis/taming-transformers
|
| 3 |
+
|
| 4 |
+
# MIT License
|
| 5 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
| 6 |
+
# 2023 Microsoft Research
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
+
# copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 19 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 20 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 21 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
| 22 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
| 23 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
| 24 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
| 25 |
+
|
| 26 |
+
import math
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
| 33 |
+
"""
|
| 34 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
| 35 |
+
From Fairseq.
|
| 36 |
+
Build sinusoidal embeddings.
|
| 37 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
| 38 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
| 39 |
+
"""
|
| 40 |
+
assert len(timesteps.shape) == 1
|
| 41 |
+
|
| 42 |
+
half_dim = embedding_dim // 2
|
| 43 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 44 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
| 45 |
+
emb = emb.to(device=timesteps.device)
|
| 46 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
| 47 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 48 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 49 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 50 |
+
return emb
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def nonlinearity(x):
|
| 54 |
+
# swish
|
| 55 |
+
return x * torch.sigmoid(x)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def Normalize(in_channels):
|
| 59 |
+
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Upsample(nn.Module):
|
| 63 |
+
def __init__(self, in_channels, with_conv):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.with_conv = with_conv
|
| 66 |
+
if self.with_conv:
|
| 67 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
| 68 |
+
|
| 69 |
+
def forward(self, x):
|
| 70 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
| 71 |
+
if self.with_conv:
|
| 72 |
+
x = self.conv(x)
|
| 73 |
+
return x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Downsample(nn.Module):
|
| 77 |
+
def __init__(self, in_channels, with_conv):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.with_conv = with_conv
|
| 80 |
+
if self.with_conv:
|
| 81 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
| 82 |
+
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
if self.with_conv:
|
| 86 |
+
pad = (0, 1, 0, 1)
|
| 87 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
| 88 |
+
x = self.conv(x)
|
| 89 |
+
else:
|
| 90 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
| 91 |
+
return x
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ResnetBlock(nn.Module):
|
| 95 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.in_channels = in_channels
|
| 98 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 99 |
+
self.out_channels = out_channels
|
| 100 |
+
self.use_conv_shortcut = conv_shortcut
|
| 101 |
+
|
| 102 |
+
self.norm1 = Normalize(in_channels)
|
| 103 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 104 |
+
if temb_channels > 0:
|
| 105 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
| 106 |
+
self.norm2 = Normalize(out_channels)
|
| 107 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 108 |
+
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 109 |
+
if self.in_channels != self.out_channels:
|
| 110 |
+
if self.use_conv_shortcut:
|
| 111 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 112 |
+
else:
|
| 113 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
| 114 |
+
|
| 115 |
+
def forward(self, x, temb):
|
| 116 |
+
h = x
|
| 117 |
+
h = self.norm1(h)
|
| 118 |
+
h = nonlinearity(h)
|
| 119 |
+
h = self.conv1(h)
|
| 120 |
+
|
| 121 |
+
if temb is not None:
|
| 122 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
| 123 |
+
|
| 124 |
+
h = self.norm2(h)
|
| 125 |
+
h = nonlinearity(h)
|
| 126 |
+
h = self.dropout(h)
|
| 127 |
+
h = self.conv2(h)
|
| 128 |
+
|
| 129 |
+
if self.in_channels != self.out_channels:
|
| 130 |
+
if self.use_conv_shortcut:
|
| 131 |
+
x = self.conv_shortcut(x)
|
| 132 |
+
else:
|
| 133 |
+
x = self.nin_shortcut(x)
|
| 134 |
+
|
| 135 |
+
return x + h
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class AttnBlock(nn.Module):
|
| 139 |
+
def __init__(self, in_channels):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.in_channels = in_channels
|
| 142 |
+
|
| 143 |
+
self.norm = Normalize(in_channels)
|
| 144 |
+
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 145 |
+
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 146 |
+
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 147 |
+
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
| 148 |
+
|
| 149 |
+
def forward(self, x):
|
| 150 |
+
h_ = x
|
| 151 |
+
h_ = self.norm(h_)
|
| 152 |
+
q = self.q(h_)
|
| 153 |
+
k = self.k(h_)
|
| 154 |
+
v = self.v(h_)
|
| 155 |
+
|
| 156 |
+
# compute attention
|
| 157 |
+
b, c, h, w = q.shape
|
| 158 |
+
q = q.reshape(b, c, h * w)
|
| 159 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
| 160 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
| 161 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
| 162 |
+
w_ = w_ * (int(c) ** (-0.5))
|
| 163 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
| 164 |
+
|
| 165 |
+
# attend to values
|
| 166 |
+
v = v.reshape(b, c, h * w)
|
| 167 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
| 168 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
| 169 |
+
h_ = h_.reshape(b, c, h, w)
|
| 170 |
+
|
| 171 |
+
h_ = self.proj_out(h_)
|
| 172 |
+
|
| 173 |
+
return x + h_
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class Model(nn.Module):
|
| 177 |
+
def __init__(
|
| 178 |
+
self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True
|
| 179 |
+
):
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.ch = ch
|
| 182 |
+
self.temb_ch = self.ch * 4
|
| 183 |
+
self.num_resolutions = len(ch_mult)
|
| 184 |
+
self.num_res_blocks = num_res_blocks
|
| 185 |
+
self.resolution = resolution
|
| 186 |
+
self.in_channels = in_channels
|
| 187 |
+
|
| 188 |
+
self.use_timestep = use_timestep
|
| 189 |
+
if self.use_timestep:
|
| 190 |
+
# timestep embedding
|
| 191 |
+
self.temb = nn.Module()
|
| 192 |
+
self.temb.dense = nn.ModuleList(
|
| 193 |
+
[
|
| 194 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
| 195 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
| 196 |
+
]
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# downsampling
|
| 200 |
+
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 201 |
+
|
| 202 |
+
curr_res = resolution
|
| 203 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 204 |
+
self.down = nn.ModuleList()
|
| 205 |
+
for i_level in range(self.num_resolutions):
|
| 206 |
+
block = nn.ModuleList()
|
| 207 |
+
attn = nn.ModuleList()
|
| 208 |
+
block_in = ch * in_ch_mult[i_level]
|
| 209 |
+
block_out = ch * ch_mult[i_level]
|
| 210 |
+
for i_block in range(self.num_res_blocks):
|
| 211 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 212 |
+
block_in = block_out
|
| 213 |
+
if curr_res in attn_resolutions:
|
| 214 |
+
attn.append(AttnBlock(block_in))
|
| 215 |
+
down = nn.Module()
|
| 216 |
+
down.block = block
|
| 217 |
+
down.attn = attn
|
| 218 |
+
if i_level != self.num_resolutions - 1:
|
| 219 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 220 |
+
curr_res = curr_res // 2
|
| 221 |
+
self.down.append(down)
|
| 222 |
+
|
| 223 |
+
# middle
|
| 224 |
+
self.mid = nn.Module()
|
| 225 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 226 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 227 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 228 |
+
|
| 229 |
+
# upsampling
|
| 230 |
+
self.up = nn.ModuleList()
|
| 231 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 232 |
+
block = nn.ModuleList()
|
| 233 |
+
attn = nn.ModuleList()
|
| 234 |
+
block_out = ch * ch_mult[i_level]
|
| 235 |
+
skip_in = ch * ch_mult[i_level]
|
| 236 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 237 |
+
if i_block == self.num_res_blocks:
|
| 238 |
+
skip_in = ch * in_ch_mult[i_level]
|
| 239 |
+
block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 240 |
+
block_in = block_out
|
| 241 |
+
if curr_res in attn_resolutions:
|
| 242 |
+
attn.append(AttnBlock(block_in))
|
| 243 |
+
up = nn.Module()
|
| 244 |
+
up.block = block
|
| 245 |
+
up.attn = attn
|
| 246 |
+
if i_level != 0:
|
| 247 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 248 |
+
curr_res = curr_res * 2
|
| 249 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 250 |
+
|
| 251 |
+
# end
|
| 252 |
+
self.norm_out = Normalize(block_in)
|
| 253 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 254 |
+
|
| 255 |
+
def forward(self, x, t=None):
|
| 256 |
+
# assert x.shape[2] == x.shape[3] == self.resolution
|
| 257 |
+
|
| 258 |
+
if self.use_timestep:
|
| 259 |
+
# timestep embedding
|
| 260 |
+
assert t is not None
|
| 261 |
+
temb = get_timestep_embedding(t, self.ch)
|
| 262 |
+
temb = self.temb.dense[0](temb)
|
| 263 |
+
temb = nonlinearity(temb)
|
| 264 |
+
temb = self.temb.dense[1](temb)
|
| 265 |
+
else:
|
| 266 |
+
temb = None
|
| 267 |
+
|
| 268 |
+
# downsampling
|
| 269 |
+
hs = [self.conv_in(x)]
|
| 270 |
+
for i_level in range(self.num_resolutions):
|
| 271 |
+
for i_block in range(self.num_res_blocks):
|
| 272 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 273 |
+
if len(self.down[i_level].attn) > 0:
|
| 274 |
+
h = self.down[i_level].attn[i_block](h)
|
| 275 |
+
hs.append(h)
|
| 276 |
+
if i_level != self.num_resolutions - 1:
|
| 277 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 278 |
+
|
| 279 |
+
# middle
|
| 280 |
+
h = hs[-1]
|
| 281 |
+
h = self.mid.block_1(h, temb)
|
| 282 |
+
h = self.mid.attn_1(h)
|
| 283 |
+
h = self.mid.block_2(h, temb)
|
| 284 |
+
|
| 285 |
+
# upsampling
|
| 286 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 287 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 288 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
| 289 |
+
if len(self.up[i_level].attn) > 0:
|
| 290 |
+
h = self.up[i_level].attn[i_block](h)
|
| 291 |
+
if i_level != 0:
|
| 292 |
+
h = self.up[i_level].upsample(h)
|
| 293 |
+
|
| 294 |
+
# end
|
| 295 |
+
h = self.norm_out(h)
|
| 296 |
+
h = nonlinearity(h)
|
| 297 |
+
h = self.conv_out(h)
|
| 298 |
+
return h
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class Encoder(nn.Module):
|
| 302 |
+
def __init__(
|
| 303 |
+
self,
|
| 304 |
+
*,
|
| 305 |
+
ch,
|
| 306 |
+
out_ch,
|
| 307 |
+
ch_mult=(1, 2, 4, 8),
|
| 308 |
+
num_res_blocks,
|
| 309 |
+
attn_resolutions,
|
| 310 |
+
dropout=0.0,
|
| 311 |
+
resamp_with_conv=True,
|
| 312 |
+
in_channels,
|
| 313 |
+
resolution,
|
| 314 |
+
z_channels,
|
| 315 |
+
double_z=True,
|
| 316 |
+
**ignore_kwargs
|
| 317 |
+
):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.ch = ch
|
| 320 |
+
self.temb_ch = 0
|
| 321 |
+
self.num_resolutions = len(ch_mult)
|
| 322 |
+
self.num_res_blocks = num_res_blocks
|
| 323 |
+
self.resolution = resolution
|
| 324 |
+
self.in_channels = in_channels
|
| 325 |
+
|
| 326 |
+
# downsampling
|
| 327 |
+
self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 328 |
+
|
| 329 |
+
curr_res = resolution
|
| 330 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 331 |
+
self.down = nn.ModuleList()
|
| 332 |
+
for i_level in range(self.num_resolutions):
|
| 333 |
+
block = nn.ModuleList()
|
| 334 |
+
attn = nn.ModuleList()
|
| 335 |
+
block_in = ch * in_ch_mult[i_level]
|
| 336 |
+
block_out = ch * ch_mult[i_level]
|
| 337 |
+
for i_block in range(self.num_res_blocks):
|
| 338 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 339 |
+
block_in = block_out
|
| 340 |
+
if curr_res in attn_resolutions:
|
| 341 |
+
attn.append(AttnBlock(block_in))
|
| 342 |
+
down = nn.Module()
|
| 343 |
+
down.block = block
|
| 344 |
+
down.attn = attn
|
| 345 |
+
if i_level != self.num_resolutions - 1:
|
| 346 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 347 |
+
curr_res = curr_res // 2
|
| 348 |
+
self.down.append(down)
|
| 349 |
+
|
| 350 |
+
# middle
|
| 351 |
+
self.mid = nn.Module()
|
| 352 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 353 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 354 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 355 |
+
|
| 356 |
+
# end
|
| 357 |
+
self.norm_out = Normalize(block_in)
|
| 358 |
+
self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1)
|
| 359 |
+
|
| 360 |
+
def forward(self, x):
|
| 361 |
+
# assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
|
| 362 |
+
|
| 363 |
+
# timestep embedding
|
| 364 |
+
temb = None
|
| 365 |
+
|
| 366 |
+
# downsampling
|
| 367 |
+
hs = [self.conv_in(x)]
|
| 368 |
+
for i_level in range(self.num_resolutions):
|
| 369 |
+
for i_block in range(self.num_res_blocks):
|
| 370 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 371 |
+
if len(self.down[i_level].attn) > 0:
|
| 372 |
+
h = self.down[i_level].attn[i_block](h)
|
| 373 |
+
hs.append(h)
|
| 374 |
+
if i_level != self.num_resolutions - 1:
|
| 375 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 376 |
+
|
| 377 |
+
# middle
|
| 378 |
+
h = hs[-1]
|
| 379 |
+
h = self.mid.block_1(h, temb)
|
| 380 |
+
h = self.mid.attn_1(h)
|
| 381 |
+
h = self.mid.block_2(h, temb)
|
| 382 |
+
|
| 383 |
+
# end
|
| 384 |
+
h = self.norm_out(h)
|
| 385 |
+
h = nonlinearity(h)
|
| 386 |
+
h = self.conv_out(h)
|
| 387 |
+
return h
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class Decoder(nn.Module):
|
| 391 |
+
def __init__(
|
| 392 |
+
self,
|
| 393 |
+
*,
|
| 394 |
+
ch,
|
| 395 |
+
out_ch,
|
| 396 |
+
ch_mult=(1, 2, 4, 8),
|
| 397 |
+
num_res_blocks,
|
| 398 |
+
attn_resolutions,
|
| 399 |
+
dropout=0.0,
|
| 400 |
+
resamp_with_conv=True,
|
| 401 |
+
in_channels,
|
| 402 |
+
resolution,
|
| 403 |
+
z_channels,
|
| 404 |
+
give_pre_end=False,
|
| 405 |
+
**ignorekwargs
|
| 406 |
+
):
|
| 407 |
+
super().__init__()
|
| 408 |
+
self.ch = ch
|
| 409 |
+
self.temb_ch = 0
|
| 410 |
+
self.num_resolutions = len(ch_mult)
|
| 411 |
+
self.num_res_blocks = num_res_blocks
|
| 412 |
+
self.resolution = resolution
|
| 413 |
+
self.in_channels = in_channels
|
| 414 |
+
self.give_pre_end = give_pre_end
|
| 415 |
+
|
| 416 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 417 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 418 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
| 419 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 420 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
| 421 |
+
|
| 422 |
+
# z to block_in
|
| 423 |
+
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
| 424 |
+
|
| 425 |
+
# middle
|
| 426 |
+
self.mid = nn.Module()
|
| 427 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 428 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 429 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 430 |
+
|
| 431 |
+
# upsampling
|
| 432 |
+
self.up = nn.ModuleList()
|
| 433 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 434 |
+
block = nn.ModuleList()
|
| 435 |
+
attn = nn.ModuleList()
|
| 436 |
+
block_out = ch * ch_mult[i_level]
|
| 437 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 438 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 439 |
+
block_in = block_out
|
| 440 |
+
if curr_res in attn_resolutions:
|
| 441 |
+
attn.append(AttnBlock(block_in))
|
| 442 |
+
up = nn.Module()
|
| 443 |
+
up.block = block
|
| 444 |
+
up.attn = attn
|
| 445 |
+
if i_level != 0:
|
| 446 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 447 |
+
curr_res = curr_res * 2
|
| 448 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 449 |
+
|
| 450 |
+
# end
|
| 451 |
+
self.norm_out = Normalize(block_in)
|
| 452 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 453 |
+
|
| 454 |
+
def forward(self, z):
|
| 455 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
| 456 |
+
self.last_z_shape = z.shape
|
| 457 |
+
|
| 458 |
+
# timestep embedding
|
| 459 |
+
temb = None
|
| 460 |
+
|
| 461 |
+
# z to block_in
|
| 462 |
+
h = self.conv_in(z)
|
| 463 |
+
|
| 464 |
+
# middle
|
| 465 |
+
h = self.mid.block_1(h, temb)
|
| 466 |
+
h = self.mid.attn_1(h)
|
| 467 |
+
h = self.mid.block_2(h, temb)
|
| 468 |
+
|
| 469 |
+
# upsampling
|
| 470 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 471 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 472 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 473 |
+
if len(self.up[i_level].attn) > 0:
|
| 474 |
+
h = self.up[i_level].attn[i_block](h)
|
| 475 |
+
if i_level != 0:
|
| 476 |
+
h = self.up[i_level].upsample(h)
|
| 477 |
+
|
| 478 |
+
# end
|
| 479 |
+
if self.give_pre_end:
|
| 480 |
+
return h
|
| 481 |
+
|
| 482 |
+
h = self.norm_out(h)
|
| 483 |
+
h = nonlinearity(h)
|
| 484 |
+
h = self.conv_out(h)
|
| 485 |
+
return h
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class VUNet(nn.Module):
|
| 489 |
+
def __init__(
|
| 490 |
+
self,
|
| 491 |
+
*,
|
| 492 |
+
ch,
|
| 493 |
+
out_ch,
|
| 494 |
+
ch_mult=(1, 2, 4, 8),
|
| 495 |
+
num_res_blocks,
|
| 496 |
+
attn_resolutions,
|
| 497 |
+
dropout=0.0,
|
| 498 |
+
resamp_with_conv=True,
|
| 499 |
+
in_channels,
|
| 500 |
+
c_channels,
|
| 501 |
+
resolution,
|
| 502 |
+
z_channels,
|
| 503 |
+
use_timestep=False,
|
| 504 |
+
**ignore_kwargs
|
| 505 |
+
):
|
| 506 |
+
super().__init__()
|
| 507 |
+
self.ch = ch
|
| 508 |
+
self.temb_ch = self.ch * 4
|
| 509 |
+
self.num_resolutions = len(ch_mult)
|
| 510 |
+
self.num_res_blocks = num_res_blocks
|
| 511 |
+
self.resolution = resolution
|
| 512 |
+
|
| 513 |
+
self.use_timestep = use_timestep
|
| 514 |
+
if self.use_timestep:
|
| 515 |
+
# timestep embedding
|
| 516 |
+
self.temb = nn.Module()
|
| 517 |
+
self.temb.dense = nn.ModuleList(
|
| 518 |
+
[
|
| 519 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
| 520 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
| 521 |
+
]
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# downsampling
|
| 525 |
+
self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 526 |
+
|
| 527 |
+
curr_res = resolution
|
| 528 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 529 |
+
self.down = nn.ModuleList()
|
| 530 |
+
for i_level in range(self.num_resolutions):
|
| 531 |
+
block = nn.ModuleList()
|
| 532 |
+
attn = nn.ModuleList()
|
| 533 |
+
block_in = ch * in_ch_mult[i_level]
|
| 534 |
+
block_out = ch * ch_mult[i_level]
|
| 535 |
+
for i_block in range(self.num_res_blocks):
|
| 536 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 537 |
+
block_in = block_out
|
| 538 |
+
if curr_res in attn_resolutions:
|
| 539 |
+
attn.append(AttnBlock(block_in))
|
| 540 |
+
down = nn.Module()
|
| 541 |
+
down.block = block
|
| 542 |
+
down.attn = attn
|
| 543 |
+
if i_level != self.num_resolutions - 1:
|
| 544 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 545 |
+
curr_res = curr_res // 2
|
| 546 |
+
self.down.append(down)
|
| 547 |
+
|
| 548 |
+
self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0)
|
| 549 |
+
# middle
|
| 550 |
+
self.mid = nn.Module()
|
| 551 |
+
self.mid.block_1 = ResnetBlock(in_channels=2 * block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 552 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 553 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
|
| 554 |
+
|
| 555 |
+
# upsampling
|
| 556 |
+
self.up = nn.ModuleList()
|
| 557 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 558 |
+
block = nn.ModuleList()
|
| 559 |
+
attn = nn.ModuleList()
|
| 560 |
+
block_out = ch * ch_mult[i_level]
|
| 561 |
+
skip_in = ch * ch_mult[i_level]
|
| 562 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 563 |
+
if i_block == self.num_res_blocks:
|
| 564 |
+
skip_in = ch * in_ch_mult[i_level]
|
| 565 |
+
block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 566 |
+
block_in = block_out
|
| 567 |
+
if curr_res in attn_resolutions:
|
| 568 |
+
attn.append(AttnBlock(block_in))
|
| 569 |
+
up = nn.Module()
|
| 570 |
+
up.block = block
|
| 571 |
+
up.attn = attn
|
| 572 |
+
if i_level != 0:
|
| 573 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 574 |
+
curr_res = curr_res * 2
|
| 575 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 576 |
+
|
| 577 |
+
# end
|
| 578 |
+
self.norm_out = Normalize(block_in)
|
| 579 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
| 580 |
+
|
| 581 |
+
def forward(self, x, z):
|
| 582 |
+
# assert x.shape[2] == x.shape[3] == self.resolution
|
| 583 |
+
|
| 584 |
+
if self.use_timestep:
|
| 585 |
+
# timestep embedding
|
| 586 |
+
assert t is not None
|
| 587 |
+
temb = get_timestep_embedding(t, self.ch)
|
| 588 |
+
temb = self.temb.dense[0](temb)
|
| 589 |
+
temb = nonlinearity(temb)
|
| 590 |
+
temb = self.temb.dense[1](temb)
|
| 591 |
+
else:
|
| 592 |
+
temb = None
|
| 593 |
+
|
| 594 |
+
# downsampling
|
| 595 |
+
hs = [self.conv_in(x)]
|
| 596 |
+
for i_level in range(self.num_resolutions):
|
| 597 |
+
for i_block in range(self.num_res_blocks):
|
| 598 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 599 |
+
if len(self.down[i_level].attn) > 0:
|
| 600 |
+
h = self.down[i_level].attn[i_block](h)
|
| 601 |
+
hs.append(h)
|
| 602 |
+
if i_level != self.num_resolutions - 1:
|
| 603 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 604 |
+
|
| 605 |
+
# middle
|
| 606 |
+
h = hs[-1]
|
| 607 |
+
z = self.z_in(z)
|
| 608 |
+
h = torch.cat((h, z), dim=1)
|
| 609 |
+
h = self.mid.block_1(h, temb)
|
| 610 |
+
h = self.mid.attn_1(h)
|
| 611 |
+
h = self.mid.block_2(h, temb)
|
| 612 |
+
|
| 613 |
+
# upsampling
|
| 614 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 615 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 616 |
+
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
|
| 617 |
+
if len(self.up[i_level].attn) > 0:
|
| 618 |
+
h = self.up[i_level].attn[i_block](h)
|
| 619 |
+
if i_level != 0:
|
| 620 |
+
h = self.up[i_level].upsample(h)
|
| 621 |
+
|
| 622 |
+
# end
|
| 623 |
+
h = self.norm_out(h)
|
| 624 |
+
h = nonlinearity(h)
|
| 625 |
+
h = self.conv_out(h)
|
| 626 |
+
return h
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
class SimpleDecoder(nn.Module):
|
| 630 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
| 631 |
+
super().__init__()
|
| 632 |
+
self.model = nn.ModuleList(
|
| 633 |
+
[
|
| 634 |
+
nn.Conv2d(in_channels, in_channels, 1),
|
| 635 |
+
ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
|
| 636 |
+
ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
|
| 637 |
+
ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
|
| 638 |
+
nn.Conv2d(2 * in_channels, in_channels, 1),
|
| 639 |
+
Upsample(in_channels, with_conv=True),
|
| 640 |
+
]
|
| 641 |
+
)
|
| 642 |
+
# end
|
| 643 |
+
self.norm_out = Normalize(in_channels)
|
| 644 |
+
self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
| 645 |
+
|
| 646 |
+
def forward(self, x):
|
| 647 |
+
for i, layer in enumerate(self.model):
|
| 648 |
+
if i in [1, 2, 3]:
|
| 649 |
+
x = layer(x, None)
|
| 650 |
+
else:
|
| 651 |
+
x = layer(x)
|
| 652 |
+
|
| 653 |
+
h = self.norm_out(x)
|
| 654 |
+
h = nonlinearity(h)
|
| 655 |
+
x = self.conv_out(h)
|
| 656 |
+
return x
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
class UpsampleDecoder(nn.Module):
|
| 660 |
+
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
|
| 661 |
+
super().__init__()
|
| 662 |
+
# upsampling
|
| 663 |
+
self.temb_ch = 0
|
| 664 |
+
self.num_resolutions = len(ch_mult)
|
| 665 |
+
self.num_res_blocks = num_res_blocks
|
| 666 |
+
block_in = in_channels
|
| 667 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
| 668 |
+
self.res_blocks = nn.ModuleList()
|
| 669 |
+
self.upsample_blocks = nn.ModuleList()
|
| 670 |
+
for i_level in range(self.num_resolutions):
|
| 671 |
+
res_block = []
|
| 672 |
+
block_out = ch * ch_mult[i_level]
|
| 673 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 674 |
+
res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout))
|
| 675 |
+
block_in = block_out
|
| 676 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
| 677 |
+
if i_level != self.num_resolutions - 1:
|
| 678 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
| 679 |
+
curr_res = curr_res * 2
|
| 680 |
+
|
| 681 |
+
# end
|
| 682 |
+
self.norm_out = Normalize(block_in)
|
| 683 |
+
self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
| 684 |
+
|
| 685 |
+
def forward(self, x):
|
| 686 |
+
# upsampling
|
| 687 |
+
h = x
|
| 688 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
| 689 |
+
for i_block in range(self.num_res_blocks + 1):
|
| 690 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
| 691 |
+
if i_level != self.num_resolutions - 1:
|
| 692 |
+
h = self.upsample_blocks[k](h)
|
| 693 |
+
h = self.norm_out(h)
|
| 694 |
+
h = nonlinearity(h)
|
| 695 |
+
h = self.conv_out(h)
|
| 696 |
+
return h
|
wham/models/vqgan/taming/quantize.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# All files under this directory are originally from the taming-transformers repository:
|
| 2 |
+
# https://github.com/CompVis/taming-transformers
|
| 3 |
+
|
| 4 |
+
# MIT License
|
| 5 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
|
| 6 |
+
# 2023 Microsoft Research
|
| 7 |
+
|
| 8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 10 |
+
# in the Software without restriction, including without limitation the rights
|
| 11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 13 |
+
# furnished to do so, subject to the following conditions:
|
| 14 |
+
|
| 15 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 16 |
+
# copies or substantial portions of the Software.
|
| 17 |
+
|
| 18 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 19 |
+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 20 |
+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 21 |
+
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
| 22 |
+
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
| 23 |
+
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
| 24 |
+
# OR OTHER DEALINGS IN THE SOFTWARE.
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import numpy as np
|
| 29 |
+
from einops import rearrange
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class VectorQuantizer2(nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
| 35 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
| 39 |
+
# backwards compatibility we use the buggy version by default, but you can
|
| 40 |
+
# specify legacy=False to fix it.
|
| 41 |
+
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.n_e = n_e
|
| 44 |
+
self.e_dim = e_dim
|
| 45 |
+
self.beta = beta
|
| 46 |
+
self.legacy = legacy
|
| 47 |
+
|
| 48 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
| 49 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
| 50 |
+
|
| 51 |
+
self.remap = remap
|
| 52 |
+
if self.remap is not None:
|
| 53 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
| 54 |
+
self.re_embed = self.used.shape[0]
|
| 55 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
| 56 |
+
if self.unknown_index == "extra":
|
| 57 |
+
self.unknown_index = self.re_embed
|
| 58 |
+
self.re_embed = self.re_embed + 1
|
| 59 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices.")
|
| 60 |
+
else:
|
| 61 |
+
self.re_embed = n_e
|
| 62 |
+
|
| 63 |
+
self.sane_index_shape = sane_index_shape
|
| 64 |
+
|
| 65 |
+
def remap_to_used(self, inds):
|
| 66 |
+
ishape = inds.shape
|
| 67 |
+
assert len(ishape) > 1
|
| 68 |
+
inds = inds.reshape(ishape[0], -1)
|
| 69 |
+
used = self.used.to(inds)
|
| 70 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
| 71 |
+
new = match.argmax(-1)
|
| 72 |
+
unknown = match.sum(2) < 1
|
| 73 |
+
if self.unknown_index == "random":
|
| 74 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
|
| 75 |
+
else:
|
| 76 |
+
new[unknown] = self.unknown_index
|
| 77 |
+
return new.reshape(ishape)
|
| 78 |
+
|
| 79 |
+
def unmap_to_all(self, inds):
|
| 80 |
+
ishape = inds.shape
|
| 81 |
+
assert len(ishape) > 1
|
| 82 |
+
inds = inds.reshape(ishape[0], -1)
|
| 83 |
+
used = self.used.to(inds)
|
| 84 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
| 85 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
| 86 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
| 87 |
+
return back.reshape(ishape)
|
| 88 |
+
|
| 89 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
| 90 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
| 91 |
+
assert rescale_logits == False, "Only for interface compatible with Gumbel"
|
| 92 |
+
assert return_logits == False, "Only for interface compatible with Gumbel"
|
| 93 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
| 94 |
+
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
| 95 |
+
z_flattened = z.view(-1, self.e_dim)
|
| 96 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
| 97 |
+
|
| 98 |
+
d = (
|
| 99 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
| 100 |
+
+ torch.sum(self.embedding.weight**2, dim=1)
|
| 101 |
+
- 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
| 105 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
| 106 |
+
perplexity = None
|
| 107 |
+
min_encodings = None
|
| 108 |
+
|
| 109 |
+
# compute loss for embedding
|
| 110 |
+
if not self.legacy:
|
| 111 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
|
| 112 |
+
else:
|
| 113 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
| 114 |
+
|
| 115 |
+
# preserve gradients
|
| 116 |
+
z_q = z + (z_q - z).detach()
|
| 117 |
+
|
| 118 |
+
# reshape back to match original input shape
|
| 119 |
+
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
| 120 |
+
|
| 121 |
+
if self.remap is not None:
|
| 122 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
|
| 123 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
| 124 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
| 125 |
+
|
| 126 |
+
if self.sane_index_shape:
|
| 127 |
+
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
| 128 |
+
|
| 129 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
| 130 |
+
|
| 131 |
+
def get_codebook_entry(self, indices, shape):
|
| 132 |
+
# shape specifying (batch, height, width, channel)
|
| 133 |
+
if self.remap is not None:
|
| 134 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
| 135 |
+
indices = self.unmap_to_all(indices)
|
| 136 |
+
indices = indices.reshape(-1) # flatten again
|
| 137 |
+
|
| 138 |
+
# get quantized latent vectors
|
| 139 |
+
z_q = self.embedding(indices)
|
| 140 |
+
|
| 141 |
+
if shape is not None:
|
| 142 |
+
z_q = z_q.view(shape)
|
| 143 |
+
# reshape back to match original input shape
|
| 144 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
| 145 |
+
|
| 146 |
+
return z_q
|