| | |
| |
|
import torch
from transformers import AutoTokenizer

from mCLM.model.qwen_based.model import Qwen2ForCausalLM
from mCLM.tokenizer.molecule_tokenizer import MoleculeTokenizer

# Load the causal-LM checkpoint in bfloat16 and let the device_map="auto"
# machinery shard/place the layers across whatever hardware is available.
model = Qwen2ForCausalLM.from_pretrained(
    "YOUR_REPO_ID",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Text tokenizer for the same checkpoint; reuse the EOS token for padding
# (no separate pad token is configured on the checkpoint here).
tokenizer = AutoTokenizer.from_pretrained("YOUR_REPO_ID")
tokenizer.pad_token = tokenizer.eos_token
| |
|
| | |
# Allowlist MoleculeTokenizer for torch's safe unpickler.
# NOTE(review): with weights_only=False on the next line this allowlist is
# not actually consulted — it only applies to weights_only=True loads; it is
# presumably kept so the load keeps working if the mode is tightened later.
torch.serialization.add_safe_globals([MoleculeTokenizer])
# SECURITY: weights_only=False uses the full pickle machinery, which can
# execute arbitrary code on load — only load molecule_tokenizer.pth from a
# trusted source.
molecule_tokenizer = torch.load("molecule_tokenizer.pth", weights_only=False)
| |
|
| | |
# Build a single-turn chat prompt, generate a reply, and print it.
user_input = "What is aspirin used for?"
messages = [{"role": "user", "content": user_input}]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
)
# Fix: apply_chat_template returns CPU tensors, but with device_map="auto"
# the model's embedding layer may live on an accelerator — move the prompt
# ids onto the model's device before generating.
inputs = inputs.to(model.device)

outputs = model.generate(input_ids=inputs, max_new_tokens=256)
# Fix: outputs[0] echoes the prompt ids followed by the generated ids;
# decode only the newly generated slice so `response` is just the answer,
# not the chat template + question prepended to it.
response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
print(response)
| |
|