| |
|
|
import bisect
from typing import List

import torch
from wtpsplit import SaT
|
|
|
|
| |
| _sat_model = None |
|
|
|
|
def get_sat_model(
    model_name: str = "sat-3l",
    device: str = "cuda",
    model_path: str = "models/SaT_cunit_with_maze/model_finetuned/sat-3l_full_ENNI/pytorch_model.bin",
) -> SaT:
    """
    Get or create global SaT model instance

    On the first call this builds a base SaT model from ``model_name`` and
    overlays the fine-tuned weights stored at ``model_path``. When a CUDA
    device is requested and available, it also converts the model to half
    precision on that device. The result is cached in the module-level
    ``_sat_model``. Later calls return the cached instance and ignore their
    arguments.

    Args:
        model_name: Base model name from segment-any-text. NOTE(review): the
            default checkpoint appears to be fine-tuned from "sat-3l"; another
            base will likely fail to load it because of shape mismatches.
        device: Device to run model on. A value starting with "cuda" (e.g.
            "cuda" or "cuda:1") uses the GPU when one is available; anything
            else leaves the model on CPU.
        model_path: Path to the fine-tuned ``state_dict`` file. Relative paths
            are resolved against the current working directory.

    Returns:
        SaT model instance

    Raises:
        Any error from ``SaT(...)``, ``torch.load`` or ``load_state_dict``
        (e.g. a missing checkpoint file). The cache is populated only after
        a successful load, so a failed attempt can be retried instead of
        leaving a half-initialised model behind.
    """
    global _sat_model

    if _sat_model is not None:
        return _sat_model

    print(f"Loading SaT 3l full fine-tuned model: {model_name}")
    sat = SaT(model_name)

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (weights_only=True is available on torch >= 1.13).
    state_dict = torch.load(model_path, map_location="cpu")

    # Keys saved with a "backbone." prefix are mapped to their bare names so
    # they match the SaT model's own parameter names; other keys are kept.
    prefix = "backbone."
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # NOTE(review): SaT.model may be a thin wrapper that forwards attribute
    # reads to the underlying nn.Module but keeps attribute writes on itself.
    # Resolve the real module so that the classifier swap below is seen by
    # load_state_dict. If sat.model already is an nn.Module, this is a no-op.
    module = sat.model
    if not isinstance(module, torch.nn.Module):
        module = getattr(module, "model", module)

    embeddings_key = "roberta.embeddings.word_embeddings.weight"
    if embeddings_key in new_state_dict:
        fine_tuned_vocab_size = new_state_dict[embeddings_key].shape[0]
        current_vocab_size = module.roberta.embeddings.word_embeddings.weight.shape[0]
        if fine_tuned_vocab_size != current_vocab_size:
            print(f"Resizing word embeddings from {current_vocab_size} to {fine_tuned_vocab_size}")
            module.resize_token_embeddings(fine_tuned_vocab_size)

    if "classifier.weight" in new_state_dict:
        fine_tuned_num_labels = new_state_dict["classifier.weight"].shape[0]
        current_num_labels = module.classifier.weight.shape[0]
        if fine_tuned_num_labels != current_num_labels:
            print(f"Resizing classifier from {current_num_labels} to {fine_tuned_num_labels}")
            # A fresh head sized for the checkpoint's label count; its values
            # are overwritten by load_state_dict below.
            module.classifier = torch.nn.Linear(
                module.classifier.in_features,
                fine_tuned_num_labels,
            )
            module.num_labels = fine_tuned_num_labels

    # strict=False tolerates keys the checkpoint lacks or adds. Report them so
    # a naming or architecture mismatch does not pass silently.
    incompatible = module.load_state_dict(new_state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"Warning: keys missing from fine-tuned checkpoint: {incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"Warning: unexpected keys in fine-tuned checkpoint: {incompatible.unexpected_keys}")

    if device.startswith("cuda") and torch.cuda.is_available():
        sat.half().to(device)
        print("SaT 3l full model loaded on GPU")
    else:
        print("SaT 3l full model loaded on CPU")

    # Cache only after every step has succeeded.
    _sat_model = sat
    return _sat_model
|
|
|
|
| |
| |
| |
def segment_SaT(text: str, *, sat_model=None) -> List[int]:
    """
    Segment text using wtpsplit SaT 3l full fine-tuned model

    The text is lower-cased and "." / "," are removed before segmentation.
    Labels refer to the whitespace-separated words of that cleaned text.

    Args:
        text: Input text to segment
        sat_model: Optional object exposing ``split(str) -> list[str]``.
            Defaults to the shared instance returned by ``get_sat_model()``.

    Returns:
        List of labels, one per cleaned word: 0 = word is not the last word
        of c-unit, 1 = word is the last word of c-unit. If the model call
        fails, the error is printed and every label is 0.
    """
    if not text.strip():
        return []

    cleaned_text = text.lower().translate(str.maketrans("", "", ".,"))

    # Character offset at which each whitespace-separated word begins.
    word_starts = [
        idx
        for idx, ch in enumerate(cleaned_text)
        if not ch.isspace() and (idx == 0 or cleaned_text[idx - 1].isspace())
    ]
    if not word_starts:
        return []

    if sat_model is None:
        sat_model = get_sat_model()

    try:
        sentences = sat_model.split(cleaned_text)
    except Exception as e:
        # Deliberate best-effort: report the failure and return "no boundaries".
        print(f"Error in SaT 3l full segmentation: {e}")
        return [0] * len(word_starts)

    word_labels = [0] * len(word_starts)

    # Align each segment to the cleaned text by character position rather than
    # by counting its words. SaT may cut inside a word, and word counting would
    # then drift and mislabel every later boundary.
    cursor = 0
    for sentence in sentences:
        piece = sentence.strip()
        if not piece:
            continue
        pos = cleaned_text.find(piece, cursor)
        if pos < 0:
            # Segment not found in the input text; it cannot be aligned.
            continue
        cursor = pos + len(piece)
        # The last word that starts before the segment's end closes the c-unit.
        word_labels[bisect.bisect_left(word_starts, cursor) - 1] = 1

    return word_labels
|
|
|
|
|
|
| |
def reorganize_transcription_c_unit(session_id, base_dir="session_data"):
    """Placeholder for reorganising a session's transcription into c-units.

    Not implemented yet: ``session_id`` and ``base_dir`` are ignored, nothing
    is read or written, and the function always returns ``None``.
    """
    # TODO: implement. The intended behaviour cannot be confirmed from here.
    return None
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check: segment a sample narrative and print its c-units.
    sample_text = "once a horse met elephant and then they saw a ball in a pool and then the horse tried to swim and get the ball they might be the same but they are doing something what do you think they are doing"

    print(f"Input text: {sample_text}")
    print(f"Words: {sample_text.split()}")

    labels = segment_SaT(sample_text)
    print(f"Segment labels: {labels}")

    tokens = sample_text.split()
    # Only positions present in both lists are used (same as zipping them).
    usable = min(len(tokens), len(labels))
    # Exclusive end index of every segment: one past each word labelled 1,
    # plus a final cut for trailing words that were never closed.
    cut_points = [pos + 1 for pos, flag in enumerate(labels[:usable]) if flag == 1]
    if usable and (not cut_points or cut_points[-1] < usable):
        cut_points.append(usable)
    segments = [" ".join(tokens[lo:hi]) for lo, hi in zip([0] + cut_points, cut_points)]

    print("\nSegmented text:")
    for number, segment in enumerate(segments, start=1):
        print(f"Segment {number}: {segment}")