Upload 全流程.ipynb

#2
by marisming - opened
Files changed (1)
  1. 全流程.ipynb +536 -0
全流程.ipynb ADDED
@@ -0,0 +1,536 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8b176d65-99f7-42a8-a6b6-4ec7ecceadf2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install the base packages:\n",
+ "!pip install transformers sentencepiece google protobuf deepspeed peft datasets "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4702e6bb-8ade-4929-9981-f83b95d92606",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set the Hugging Face mirror endpoint:\n",
+ "import os\n",
+ "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
+ "print(os.environ.get('HF_ENDPOINT'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6f64e588-a7d3-4009-bc98-f45703781ae8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# AutoDL academic network acceleration; run this in a terminal:\n",
+ "# source /etc/network_turbo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f4de3e11-6de0-4741-afa4-69dc73abd191",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Git LFS support, needed to git clone repos that require LFS\n",
+ "!apt-get update\n",
+ "!apt-get install -y git-lfs\n",
+ "!git lfs install"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a774d0a1-0582-443f-89c6-cd7f4a84966a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# After downloading the data, load the DNA text and split it into train and test sets (shuffled by default)\n",
+ "from datasets import load_dataset\n",
+ "dna_dataset = load_dataset('text', data_files='data/dna_1g.txt')['train'].train_test_split(test_size=0.05)\n",
+ "dna_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5243ba2d-1e95-4161-98d9-a403f7270c74",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dna_dataset[\"train\"][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a5b5e607-31b1-433f-a094-2fad9e4bc472",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The datasets above are plain text and are typically used as pretraining data. Downstream fine-tuning\n",
+ "# tasks such as classification usually come with labels, most often stored as JSON or CSV. An example:\n",
+ "ft_dataset = load_dataset('json', data_files='data/dna_protein_my.json')\n",
+ "ft_dataset[\"train\"][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d1ebc301-a222-4de9-be79-0150434f25f5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# If the dataset is too large and we only need part of it (a common need), Dataset.select() is the usual tool\n",
+ "from datasets import load_dataset, DatasetDict\n",
+ "dna_dataset_sample = DatasetDict(\n",
+ "    {\n",
+ "        \"train\": dna_dataset[\"train\"].shuffle().select(range(50000)),\n",
+ "        \"valid\": dna_dataset[\"test\"].shuffle().select(range(500)),\n",
+ "        \"eval\": dna_dataset[\"test\"].shuffle().select(range(500))\n",
+ "    }\n",
+ ")\n",
+ "dna_dataset_sample\n",
+ "# We build a DatasetDict directly: shuffle() randomizes the order, then select() keeps the first n examples.\n",
+ "# select() takes indices (a list or a range object) specifying which samples to keep,\n",
+ "# e.g. dataset.select([0, 2, 4]) picks the 1st, 3rd and 5th records"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2ffbe618-7146-49c6-a8bd-f1d7e9f0ad4d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Share the dataset on the Hugging Face Hub\n",
+ "dna_dataset.push_to_hub(\"org_name/your_dataset_name\", token=\"hf_yourtoken\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f8512c9e-c673-49db-a60a-f818a546852f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Train a BPE-based DNA tokenizer from scratch\n",
+ "from tokenizers import (\n",
+ "    decoders,\n",
+ "    models,\n",
+ "    normalizers,\n",
+ "    pre_tokenizers,\n",
+ "    processors,\n",
+ "    trainers,\n",
+ "    Tokenizer,\n",
+ ")\n",
+ "from transformers import AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5937e169-ee42-44b4-9939-3792cde80ac5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The main Tokenizer object wraps the tokenization algorithm (BPE)\n",
+ "tokenizer = Tokenizer(models.BPE())\n",
+ "# Pre-tokenization: ByteLevel splits on raw bytes; use_regex=False treats whitespace as an ordinary character\n",
+ "tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)\n",
+ "# The trainer learns the merge rules and builds a vocabulary of size 30k\n",
+ "trainer1 = trainers.BpeTrainer(vocab_size=30000, special_tokens=[\"<|endoftext|>\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a7249164-b5f7-4a6f-9b2d-a30ae5a3e8b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Train the tokenizer on the DNA data\n",
+ "tokenizer.train([\"../01-data_env/data/dna_1g.txt\"], trainer=trainer1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8d3ddbec-ef6f-4d46-a82c-3b6e1eeacb74",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# encode() tokenizes the sequence and converts the tokens to IDs\n",
+ "encoding = tokenizer.encode(\"TGGCGTGAACCCGGGATCGGG\")\n",
+ "print(encoding.tokens)"
+ ]
+ },
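+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-encode-ids",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch (not in the original notebook): encode() also exposes the numeric IDs\n",
+ "# that correspond to the tokens printed above.\n",
+ "print(encoding.ids)"
+ ]
+ },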
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5539bc51-0d7d-4a31-a54a-597483b9861f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simple save to a single JSON file\n",
+ "tokenizer.save(\"dna_bpe_dict.json\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ecfed110-8840-4703-a659-b2d4a4d11f7d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The tokenizer can then be reloaded from that file with from_file():\n",
+ "new_tokenizer = Tokenizer.from_file(\"dna_bpe_dict.json\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "15611626-c431-4f32-8f75-b3c2a6a67138",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# To use this tokenizer with Hugging Face Transformers, wrap it in a PreTrainedTokenizerFast subclass\n",
+ "from transformers import GPT2TokenizerFast\n",
+ "dna_tokenizer = GPT2TokenizerFast(tokenizer_object=new_tokenizer)\n",
+ "# save_pretrained writes a complete, standard copy to disk, including these key files:\n",
+ "# 1. a vocabulary JSON file: a dict mapping each token string to a unique ID\n",
+ "# 2. merges.txt: every merge operation learned during BPE training\n",
+ "# 3. special_tokens_map.json and tokenizer_config.json: configuration files defining the tokenizer's settings and behaviour,\n",
+ "#    e.g. which tokens are special (padding <pad>, unknown <unk>, start of sequence <s>), the model name, the maximum length, etc.\n",
+ "#    This keeps the tokenizer's behaviour consistent across environments.\n",
+ "dna_tokenizer.save_pretrained(\"dna_bpe_dict\")\n",
+ "#dna_tokenizer.push_to_hub(\"dna_bpe_dict_1g\", organization=\"dnagpt\", use_auth_token=\"hf_*****\") "
+ ]
+ },
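+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-check-saved-files",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch (not in the original notebook): list what save_pretrained actually wrote,\n",
+ "# assuming the \"dna_bpe_dict\" directory from the previous cell exists.\n",
+ "import os\n",
+ "print(sorted(os.listdir(\"dna_bpe_dict\")))"
+ ]
+ },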
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2624e2a7-7204-4c80-8806-74cb52ad11a1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The standard inverse of save_pretrained: load and instantiate the previously saved tokenizer\n",
+ "tokenizer_new = AutoTokenizer.from_pretrained('dna_bpe_dict')\n",
+ "tokenizer_new.pad_token = tokenizer_new.eos_token"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c93cd0f4-4b4b-4a67-b35d-668f3b920806",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer_new.tokenize(\"TGGCGTGAACCCGGGATCGGG\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "768a983c-fdc4-4f67-90e4-6ff164e6c029",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pretrain a GPT-2-style DNA model from scratch\n",
+ "from transformers import AutoConfig, GPT2LMHeadModel\n",
+ "max_length = 256  # maximum input length\n",
+ "# Load the GPT-2 config and adapt its parameters to our tokenizer\n",
+ "config = AutoConfig.from_pretrained(\n",
+ "    \"gpt2\",\n",
+ "    vocab_size=len(tokenizer_new),  # standard GPT-2 was trained on English text with a fixed vocabulary (50257); ours is 30k\n",
+ "    n_ctx=max_length,  # context length: the maximum number of tokens the model processes at once\n",
+ "    bos_token_id=tokenizer_new.bos_token_id,  # beginning-of-sequence token\n",
+ "    eos_token_id=tokenizer_new.eos_token_id,  # end-of-sequence token\n",
+ ")\n",
+ "model = GPT2LMHeadModel(config)  # fresh weight initialization, for pretraining from scratch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c002b619-07ca-41fc-b6c0-abfd17676934",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import DataCollatorForLanguageModeling\n",
+ "\n",
+ "# 1. Load the data\n",
+ "raw_dataset = load_dataset('text', data_files=\"../01-data_env/data/dna_1g.txt\")\n",
+ "dna_dataset = load_dataset('text', data_files='data/dna_1g.txt')['train'].train_test_split(test_size=0.05)\n",
+ "\n",
+ "# 2. Tokenize (the explicit version of encode): truncation=True cuts over-long sequences, padding='max_length' pads short ones, both bounded by max_length=256\n",
+ "def tokenize_function(examples):\n",
+ "    return tokenizer_new(examples['text'], truncation=True, padding='max_length', max_length=max_length)\n",
+ "# What the tokenizer returns:\n",
+ "# input_ids: a 2-D list (List[List[int]]) of padded/truncated token ID sequences,\n",
+ "# e.g. [[105, 206, 307, ..., 0, 0, 0], [408, 509, 0, 0, ...], ...]\n",
+ "# attention_mask: an equally important companion output of the same shape, containing only 0s and 1s:\n",
+ "# 1 marks a real token,\n",
+ "# 0 marks a padding token ([PAD])\n",
+ "\n",
+ "# 3. Apply the tokenization function to the dataset and drop the raw 'text' column\n",
+ "tokenized_datasets = dna_dataset.map(tokenize_function, batched=True, remove_columns=['text'], num_proc=15)  # set num_proc to your CPU core count or adjust as needed\n",
+ "\n",
+ "# 4. Create a data collator for dynamic padding and masking; tokenizer=tokenizer_new specifies the tokenizer used for encoding and decoding\n",
+ "# After the previous step produces the raw input_ids, the data_collator turns them into actual (inputs, labels) pairs\n",
+ "data_collator = DataCollatorForLanguageModeling(\n",
+ "    tokenizer=tokenizer_new, mlm=False\n",
+ ")"
+ ]
+ },
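+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-inspect-tokenized",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch (not in the original notebook): peek at one tokenized example and at what the\n",
+ "# collator produces, assuming tokenized_datasets and data_collator from the previous cell.\n",
+ "example = tokenized_datasets[\"train\"][0]\n",
+ "print(len(example[\"input_ids\"]), example[\"input_ids\"][:10], example[\"attention_mask\"][:10])\n",
+ "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(2)])\n",
+ "print(batch[\"input_ids\"].shape, batch[\"labels\"].shape)  # for causal LM, the labels mirror input_ids"
+ ]
+ },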
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0ccd72d9-f4a2-4792-8176-96afee553d27",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Start training\n",
+ "from transformers import Trainer, TrainingArguments\n",
+ "\n",
+ "run_path = \"gpt2_run\"\n",
+ "train_epoches = 5\n",
+ "batch_size = 10\n",
+ "\n",
+ "# TrainingArguments defines the training \"rules\": how long to train, how, and where to save\n",
+ "training_args = TrainingArguments(\n",
+ "    output_dir=run_path,  # directory for all outputs: the final model, checkpoints, logs and evaluation results\n",
+ "    overwrite_output_dir=True,  # overwrite the directory if it already exists\n",
+ "    num_train_epochs=train_epoches,  # number of training epochs\n",
+ "    per_device_train_batch_size=batch_size,  # batch size per GPU: 10 each, so 20 in total with two GPUs\n",
+ "    save_steps=2000,\n",
+ "    save_total_limit=2,  # save a checkpoint every 2000 steps, keeping only the 2 most recent\n",
+ "    prediction_loss_only=True,  # during evaluation compute only the loss, no other metrics (e.g. accuracy)\n",
+ "    fp16=True,  # not usable on V100\n",
+ ")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ "    model=model,  # model = GPT2LMHeadModel(config)\n",
+ "    args=training_args,  # the arguments defined above\n",
+ "    train_dataset=tokenized_datasets[\"train\"],\n",
+ "    eval_dataset=tokenized_datasets[\"test\"],\n",
+ "    data_collator=data_collator,\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7ff9c5a7-ce64-4bd6-92dd-7eaaa8b714c3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# After training completes\n",
+ "import math\n",
+ "eval_results = trainer.evaluate()  # evaluate on the eval_dataset given to the Trainer above\n",
+ "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")\n",
+ "# perplexity = exp(eval_loss); :.2f prints the result with two decimal places\n",
+ "# Perplexity measures, on average, how uncertain the model is when predicting the next token"
+ ]
+ },
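+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-perplexity-example",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch (not in the original notebook): a worked example of the perplexity formula.\n",
+ "# A cross-entropy loss of 2.0 corresponds to exp(2.0) ≈ 7.39, i.e. the model is roughly as\n",
+ "# uncertain as choosing uniformly among ~7 tokens at each step.\n",
+ "import math\n",
+ "for loss in [0.5, 1.0, 2.0, 3.0]:\n",
+ "    print(f\"loss={loss:.1f} -> perplexity={math.exp(loss):.2f}\")"
+ ]
+ },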
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5a6bced4-2c0b-4ae1-911b-3a0eae83f202",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Upload the model to the Hub\n",
+ "model.push_to_hub(\"dna_gpt2_v0\", organization=\"dnagpt\", use_auth_token=\"hf_*******\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "20ab5784-8488-4991-8f16-85343d766baa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Once training is done, the model can be used directly:\n",
+ "from transformers import AutoTokenizer, AutoModel\n",
+ "import torch\n",
+ "model = AutoModel.from_pretrained('dna_gpt2_v0')\n",
+ "model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1ae12f89-7887-4a71-8edf-a74982f0c2c1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Application 1: extract features, e.g. for promoter classification\n",
+ "from transformers import AutoTokenizer, AutoModel\n",
+ "tokenizer = AutoTokenizer.from_pretrained('dna_bpe_dict')\n",
+ "tokenizer.tokenize(\"GAGCACATTCGCCTGCGTGCGCACTCACACACACGTTCAAAAAGAGTCCATTCGATTCTGGCAGTAG\")\n",
+ "# result: ['G', 'AGCAC', 'ATTCGCC', ...]\n",
+ "\n",
+ "# As I understand it: tokenizer.encode outputs token IDs,\n",
+ "# tokenizer.tokenize outputs the human-readable tokens,\n",
+ "# and tokenizer(dna) outputs the IDs plus the attention mask for the steps below\n",
+ "model = AutoModel.from_pretrained('dna_gpt2_v0')\n",
+ "import torch\n",
+ "dna = \"ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC\"\n",
+ "inputs = tokenizer(dna, return_tensors='pt')  # return PyTorch tensors\n",
+ "print(inputs)  # the tokenized input\n",
+ "\n",
+ "# Run the model on the input\n",
+ "outputs = model(inputs[\"input_ids\"])\n",
+ "\n",
+ "# Feature extraction: average the hidden-state vectors of all tokens in a DNA sequence\n",
+ "# to obtain a fixed-size embedding that represents the whole sequence.\n",
+ "hidden_states = outputs.last_hidden_state  # last layer: [batch_size=1, sequence length in tokens, hidden_size=768]\n",
+ "\n",
+ "# embedding with mean pooling\n",
+ "embedding_mean = torch.mean(hidden_states[0], dim=0)  # [0] picks the first sequence in the batch; dim=0 averages over the sequence-length dimension\n",
+ "print(embedding_mean.shape)  # expect to be 768\n",
+ "\n",
+ "# embedding with max pooling\n",
+ "embedding_max = torch.max(hidden_states[0], dim=0)[0]\n",
+ "print(embedding_max.shape)  # expect to be 768\n",
+ "\n",
+ "# embedding with first token\n",
+ "embedding_first_token = hidden_states[0][0]\n",
+ "print(embedding_first_token.shape)  # expect to be 768"
+ ]
+ },
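+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-embedding-matrix",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch (not in the original notebook): the classifiers below expect a matrix of shape\n",
+ "# (n_samples, 768), so repeat the mean-pooling step over a list of sequences. `sequences` here is\n",
+ "# a hypothetical placeholder for your own labelled DNA sequences.\n",
+ "import numpy as np\n",
+ "sequences = [\"ACGTAGCATCGG\", \"TTGGCCAATCGA\"]  # placeholder examples\n",
+ "embeddings = []\n",
+ "with torch.no_grad():\n",
+ "    for seq in sequences:\n",
+ "        ids = tokenizer(seq, return_tensors='pt')[\"input_ids\"]\n",
+ "        hs = model(ids).last_hidden_state\n",
+ "        embeddings.append(torch.mean(hs[0], dim=0).numpy())\n",
+ "X_all = np.stack(embeddings)\n",
+ "print(X_all.shape)  # (n_samples, 768)"
+ ]
+ },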
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4ffacd74-15cb-4789-9c9b-d970be3915cf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# With embeddings in hand, train a classifier; this example uses simple linear models\n",
+ "import numpy as np\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.svm import SVC\n",
+ "from sklearn.metrics import accuracy_score\n",
+ "\n",
+ "X = np.array(embedding_mean)  # the embedding_mean/max/first_token from above; for real use, stack one embedding per sample into an array of shape (n_samples, 768)\n",
+ "y = np.array(labels)  # the corresponding labels, assumed to be prepared elsewhere\n",
+ "\n",
+ "# Split into train and test sets\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
+ "\n",
+ "# Train a logistic-regression classifier\n",
+ "clf = LogisticRegression(random_state=42, max_iter=1000)\n",
+ "clf.fit(X_train, y_train)\n",
+ "\n",
+ "# Evaluate\n",
+ "accuracy = clf.score(X_test, y_test)\n",
+ "print(f\"Logistic Regression Accuracy: {accuracy:.4f}\")\n",
+ "\n",
+ "# Or use an SVM\n",
+ "svm_clf = SVC(kernel='linear', random_state=42)  # a linear kernel usually works well\n",
+ "svm_clf.fit(X_train, y_train)\n",
+ "svm_accuracy = svm_clf.score(X_test, y_test)\n",
+ "print(f\"SVM Accuracy: {svm_accuracy:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b731af8d-0f9d-496a-9008-3e96bc98d671",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# With embeddings in hand, train a classifier; this example uses a small neural network\n",
+ "import torch.nn as nn\n",
+ "import torch.optim as optim\n",
+ "\n",
+ "# Define a simple neural-network classifier\n",
+ "class MLPClassifier(nn.Module):\n",
+ "    def __init__(self, input_dim=768, hidden_dim=256, num_classes=2):\n",
+ "        super().__init__()\n",
+ "        self.layers = nn.Sequential(\n",
+ "            nn.Linear(input_dim, hidden_dim),\n",
+ "            nn.ReLU(),\n",
+ "            nn.Dropout(0.2),  # guard against overfitting\n",
+ "            nn.Linear(hidden_dim, num_classes)\n",
+ "        )\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        # x is a batch of embeddings, shape [batch_size, 768]\n",
+ "        return self.layers(x)\n",
+ "\n",
+ "# Usage\n",
+ "model = MLPClassifier(num_classes=3)  # e.g. a 3-class task\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
+ "\n",
+ "# Assume train_loader is a prepared PyTorch DataLoader (a sketch follows in the next cell)\n",
+ "for epoch in range(10):\n",
+ "    for batch_embeddings, batch_labels in train_loader:\n",
+ "        optimizer.zero_grad()\n",
+ "        outputs = model(batch_embeddings)\n",
+ "        loss = criterion(outputs, batch_labels)\n",
+ "        loss.backward()\n",
+ "        optimizer.step()"
+ ]
+ },
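+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added-dataloader-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sketch (not in the original notebook): one way to build the train_loader assumed above,\n",
+ "# from an embedding matrix X_all and integer labels y_all (both hypothetical placeholders).\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch.utils.data import TensorDataset, DataLoader\n",
+ "\n",
+ "X_all = np.random.rand(100, 768).astype(\"float32\")  # stand-in for real embeddings\n",
+ "y_all = np.random.randint(0, 3, size=100)           # stand-in for real labels\n",
+ "dataset = TensorDataset(torch.from_numpy(X_all), torch.from_numpy(y_all).long())\n",
+ "train_loader = DataLoader(dataset, batch_size=16, shuffle=True)"
+ ]
+ },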
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e0e65a6f-faeb-4333-80f5-219ac2e0211e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/root/autodl-tmp/dnagpt2/01-data_env/data\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pwd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bfc1b2f8-62a8-422e-9a83-14a56c19272e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }