import io
import math
import os
import re
from typing import Tuple

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from pydub import AudioSegment


def preprocess_image_tensor(image_path, device, target_dtype, h_w_multiple_of=32, resize_total_area=720 * 720):
    """Preprocess image data into a standardized tensor format and (optionally) resize to a target area."""

    def _parse_area(val):
        """Accept an int/float area, an (h, w) pair, or a 'WxH' / 'W*H' string."""
        if val is None:
            return None
        if isinstance(val, (int, float)):
            return int(val)
        if isinstance(val, (tuple, list)) and len(val) == 2:
            return int(val[0]) * int(val[1])
        if isinstance(val, str):
            m = re.match(r"\s*(\d+)\s*[x\*\s]\s*(\d+)\s*$", val, flags=re.IGNORECASE)
            if m:
                return int(m.group(1)) * int(m.group(2))
            if val.strip().isdigit():
                return int(val.strip())
        raise ValueError(f"resize_total_area={val!r} is not a valid area or WxH.")

    def _best_hw_for_area(h, w, area_target, multiple):
        """Pick (H, W), both multiples of `multiple`, whose area is closest to
        `area_target` while staying close to the original aspect ratio."""
        if area_target <= 0:
            return h, w
        ratio_wh = w / float(h)
        area_unit = multiple * multiple
        tgt_units = max(1, area_target // area_unit)
        # Search a small neighborhood around the ideal number of height units.
        p0 = max(1, int(round(np.sqrt(tgt_units / max(ratio_wh, 1e-8)))))
        candidates = []
        for dp in range(-3, 4):
            p = max(1, p0 + dp)
            q = max(1, int(round(p * ratio_wh)))
            candidates.append((p * multiple, q * multiple))
        # Also consider a direct uniform rescale snapped to the grid.
        scale = np.sqrt(area_target / (h * float(w)))
        H_sc = max(multiple, int(round(h * scale / multiple)) * multiple)
        W_sc = max(multiple, int(round(w * scale / multiple)) * multiple)
        candidates.append((H_sc, W_sc))

        def score(HW):
            H, W = HW
            # Rank by area error first, then by aspect-ratio error.
            return (abs(H * W - area_target), abs((W / max(H, 1e-8)) - ratio_wh))

        H_best, W_best = min(candidates, key=score)
        return H_best, W_best

    if isinstance(image_path, str):
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image from {image_path!r}.")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        assert isinstance(image_path, Image.Image)
        if image_path.mode != "RGB":
            image_path = image_path.convert("RGB")
        image = np.array(image_path)

    # HWC -> CHW, scale to [0, 1], then map to [-1, 1].
    image = image.transpose(2, 0, 1)
    image = image.astype(np.float32) / 255.0

    image_tensor = torch.from_numpy(image).float().to(device, dtype=target_dtype).unsqueeze(0)
    image_tensor = image_tensor * 2.0 - 1.0

    _, c, h, w = image_tensor.shape
    area_target = _parse_area(resize_total_area)
    if area_target is not None:
        target_h, target_w = _best_hw_for_area(h, w, area_target, h_w_multiple_of)
    else:
        # No target area: just floor each side to the required multiple.
        target_h = (h // h_w_multiple_of) * h_w_multiple_of
        target_w = (w // h_w_multiple_of) * h_w_multiple_of

    target_h = max(h_w_multiple_of, int(target_h))
    target_w = max(h_w_multiple_of, int(target_w))

    if (h != target_h) or (w != target_w):
        image_tensor = torch.nn.functional.interpolate(
            image_tensor,
            size=(target_h, target_w),
            mode="bicubic",
            align_corners=False,
        )

    return image_tensor
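# Example sketch (hypothetical file and device; both depend on your setup):
#   >>> x = preprocess_image_tensor("frame.png", torch.device("cuda"), torch.bfloat16)
#   >>> x.shape  # (1, 3, H, W) with H, W multiples of 32 and H * W near 720 * 720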


def preprocess_audio_tensor(audio, device):
    """Preprocess audio data into a standardized (1, num_samples) tensor on `device`."""
    if isinstance(audio, np.ndarray):
        audio_tensor = torch.from_numpy(audio).float().squeeze().unsqueeze(0).to(device)
    else:
        audio_tensor = audio.squeeze().unsqueeze(0).to(device)
    return audio_tensor
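# Example sketch: a (num_samples,) numpy array (or an already-batched torch
# tensor) is normalized to shape (1, num_samples) on the target device:
#   >>> preprocess_audio_tensor(np.zeros(16000, dtype=np.float32), "cpu").shape
#   torch.Size([1, 16000])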


def calc_dims_from_area(
    aspect_ratio: str,
    total_area: int = 720 * 720,
    divisible_by: int = 32,
) -> Tuple[int, int]:
    """
    Calculate height and width given an aspect ratio (h:w), total area,
    and divisibility constraint.

    Args:
        aspect_ratio (str): Aspect ratio string in the format "h:w" (e.g., "9:16").
        total_area (int): Target maximum area (height * width ≤ total_area).
        divisible_by (int): Force height and width to be divisible by this value.

    Returns:
        (height, width): Tuple of integers that satisfy the constraints.
    """
    h_ratio, w_ratio = map(int, aspect_ratio.split(":"))

    # Reduce the ratio to lowest terms.
    gcd = math.gcd(h_ratio, w_ratio)
    h_ratio //= gcd
    w_ratio //= gcd

    # Scale factor k such that (k * h_ratio) * (k * w_ratio) == total_area.
    k = math.sqrt(total_area / (h_ratio * w_ratio))

    # Floor each side onto the divisibility grid so the area stays ≤ total_area.
    height = (int(k * h_ratio) // divisible_by) * divisible_by
    width = (int(k * w_ratio) // divisible_by) * divisible_by

    height = max(height, divisible_by)
    width = max(width, divisible_by)

    return height, width
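# Worked example: "9:16" with the default 720 * 720 area gives
# k = sqrt(518400 / 144) = 60, so height = (540 // 32) * 32 = 512 and
# width = (960 // 32) * 32 = 960:
#   >>> calc_dims_from_area("9:16")
#   (512, 960)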


def snap_hw_to_multiple_of_32(h: int, w: int, area=720 * 720) -> tuple[int, int]:
    """
    Scale (h, w) to match a target area if provided, then snap both
    dimensions to the nearest multiple of 32 (minimum 32).

    Args:
        h (int): Original height.
        w (int): Original width.
        area (int, optional): Target area to scale to. If None, no scaling is applied.

    Returns:
        (new_h, new_w): Adjusted dimensions.
    """
    if h <= 0 or w <= 0:
        raise ValueError(f"h and w must be positive, got {(h, w)}")

    if area is not None and area > 0:
        # Uniformly rescale so that h * w ≈ area, preserving aspect ratio.
        scale = math.sqrt(area / float(h * w))
        h = int(round(h * scale))
        w = int(round(w * scale))

    def _n32(x: int) -> int:
        return max(32, int(round(x / 32)) * 32)

    return _n32(h), _n32(w)
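# Worked example: (720, 1280) scaled toward the default 720 * 720 area gives
# scale = 0.75 -> (540, 960), which then snaps to the nearest multiples of 32:
#   >>> snap_hw_to_multiple_of_32(720, 1280)
#   (544, 960)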


def scale_hw_to_area_divisible(h, w, area=1024 * 1024, n=16):
    """
    Scale (h, w) so that the area is approximately `area` while keeping the
    aspect ratio, then round so both dimensions are divisible by n.

    Args:
        h (int): Original height.
        w (int): Original width.
        area (int or float): Target area.
        n (int): Divisibility requirement.

    Returns:
        (new_h, new_w): Scaled and adjusted dimensions.
    """
    current_area = h * w
    if current_area == 0:
        raise ValueError("Height and width must be positive")

    scale = math.sqrt(area / current_area)

    new_h = h * scale
    new_w = w * scale

    # Round to the nearest multiple of n, with a floor of n.
    new_h = int(round(new_h / n) * n)
    new_w = int(round(new_w / n) * n)

    new_h = max(new_h, n)
    new_w = max(new_w, n)

    return new_h, new_w
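# Worked example: (720, 1280) scaled toward the default 1024 * 1024 area gives
# scale ≈ 1.0667 -> (768.0, 1365.3), rounded to multiples of 16:
#   >>> scale_hw_to_area_divisible(720, 1280)
#   (768, 1360)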


def validate_and_process_user_prompt(text_prompt: str, image_path: str = None, mode: str = "t2v") -> Tuple[list, list]:
    """Resolve a prompt string, or a path to a .csv/.tsv prompt file, into
    parallel lists of text prompts and (optional) image paths."""
    if not isinstance(text_prompt, str):
        raise ValueError("User input must be a string")

    text_prompt = text_prompt.strip()

    # If the prompt is actually a path to a prompt file, read one prompt per row.
    if os.path.isfile(text_prompt):
        _, ext = os.path.splitext(text_prompt.lower())
        if ext == ".csv":
            df = pd.read_csv(text_prompt).fillna("")
        elif ext == ".tsv":
            df = pd.read_csv(text_prompt, sep="\t").fillna("")
        else:
            raise ValueError(f"Unsupported file type: {ext}. Only .csv and .tsv are allowed.")

        assert "text_prompt" in df.columns, "Missing required 'text_prompt' column in the prompt file."
        text_prompts = list(df["text_prompt"])
        if mode == "i2v" and "image_path" in df.columns:
            image_paths = list(df["image_path"])
            assert all(
                p is None or len(p) == 0 or os.path.isfile(p) for p in image_paths
            ), "One or more image paths in the prompt file do not exist."
        else:
            print("Warning: image_path column not found; assuming t2v or t2i2v mode...")
            image_paths = [None] * len(text_prompts)
    else:
        assert image_path is None or os.path.isfile(image_path), (
            f"Image path is not None but {image_path} does not exist."
        )
        text_prompts = [text_prompt]
        image_paths = [image_path]

    return text_prompts, image_paths
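# Example sketch (hypothetical .tsv name): a plain prompt yields single-element
# lists; a .csv/.tsv path with a "text_prompt" column (plus "image_path" in
# i2v mode) yields one entry per row:
#   >>> validate_and_process_user_prompt("a cat playing piano")
#   (['a cat playing piano'], [None])
#   >>> validate_and_process_user_prompt("prompts.tsv", mode="i2v")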


def format_prompt_for_filename(text: str) -> str:
    """Make a prompt safe to use in a filename: drop <...> tags, replace
    spaces and slashes with underscores, and truncate to 50 characters."""
    no_tags = re.sub(r"<.*?>", "", text)
    safe = no_tags.replace(" ", "_").replace("/", "_")
    return safe[:50]
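# Example: the tags themselves are dropped (but not the text between them),
# then separators are replaced before truncation:
#   >>> format_prompt_for_filename("A cat <S>says hi<E> on/off")
#   'A_cat_says_hi_on_off'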


def audio_bytes_to_tensor(audio_bytes, target_sr=16000):
    """
    Convert WAV-encoded audio bytes to a mono torch tensor in [-1, 1],
    resampled to `target_sr` (16 kHz by default).

    Args:
        audio_bytes (bytes): Raw WAV audio bytes.
        target_sr (int): Target sample rate.

    Returns:
        torch.Tensor: shape (num_samples,)
        int: sample rate
    """
    audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format="wav")

    # Downmix to mono if needed.
    if audio.channels != 1:
        audio = audio.set_channels(1)

    # Resample if needed.
    if audio.frame_rate != target_sr:
        audio = audio.set_frame_rate(target_sr)

    samples = np.array(audio.get_array_of_samples())
    # np.iinfo must see the original integer dtype, so read it before casting.
    int_max = np.iinfo(samples.dtype).max
    samples = samples.astype(np.float32) / int_max

    tensor = torch.from_numpy(samples)

    return tensor, target_sr


def audio_path_to_tensor(path, target_sr=16000):
    """Read an audio file from disk and convert it via audio_bytes_to_tensor."""
    with open(path, "rb") as f:
        audio_bytes = f.read()
    return audio_bytes_to_tensor(audio_bytes, target_sr=target_sr)
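# Example sketch (hypothetical path; the bytes must be WAV-encoded to match
# the format="wav" hint passed to pydub above):
#   >>> waveform, sr = audio_path_to_tensor("speech.wav")
#   >>> waveform.dtype, sr
#   (torch.float32, 16000)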


def clean_text(text: str) -> str:
    """
    Remove all text between <S>...<E> and <AUDCAP>...<ENDAUDCAP> tags,
    including the tags themselves.
    """
    text = re.sub(r"<S>.*?<E>", "", text, flags=re.DOTALL)
    text = re.sub(r"<AUDCAP>.*?<ENDAUDCAP>", "", text, flags=re.DOTALL)
    return text.strip()
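# Example: tagged spans are removed wholesale; whitespace that surrounded them
# is left in place, so an inner double space can remain:
#   >>> clean_text("Hello <S>skip<E> world <AUDCAP>birds chirping<ENDAUDCAP>")
#   'Hello  world'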