| | from PIL import Image |
| | from io import BytesIO |
| | from matplotlib.figure import Figure |
| | from torchvision import transforms |
| | from tqdm import tqdm |
| | from typing import Literal, Any |
| | from urllib.request import urlopen |
| | import gradio as gr |
| | import matplotlib.pyplot as plt |
| | import os |
| | import spaces |
| | import sys |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| |
|
# Litton's seven visual landscape types. predict() plots probability i
# against LABELS[i], so this order must match the classifier's output columns.
LABELS = "Panoramic Feature Detail Enclosed Focal Ephemeral Canopied".split()

# Checkpoint file name; it is looked up in the current working directory and
# is also the file name requested from the lab's model zoo when missing.
MODELFILE = "Litton-7type-visual-landscape-model.pth"
| |
|
| |
|
# Run inference on the first CUDA device when one is visible, else on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|
def _download_model(url: str, dest: str, chunk_size: int = 1 << 16) -> None:
    """Stream *url* into the file *dest*, showing a tqdm progress bar.

    The payload is first written to ``dest + ".part"`` and only renamed onto
    *dest* after the whole body has been read. An interrupted transfer
    therefore never leaves a truncated checkpoint that a later
    ``os.path.exists(dest)`` check would mistake for a complete one.

    Args:
        url: Address to fetch.
        dest: Final path of the downloaded file.
        chunk_size: Number of bytes requested per ``read`` call.

    Raises:
        urllib.error.URLError / OSError: on network or filesystem failure.
            The partial ``.part`` file is removed before re-raising.
    """
    tmp_path = f"{dest}.part"
    try:
        with urlopen(url) as resp:
            # HTTPResponse is not subscriptable; headers live on `.headers`.
            # Servers may omit Content-Length (e.g. chunked transfer), in
            # which case tqdm shows an open-ended byte counter.
            length = resp.headers.get("Content-Length")
            total = int(length) if length is not None else None
            with tqdm(
                total=total, desc="Downloading", unit="B", unit_scale=True
            ) as progress, open(tmp_path, "wb") as out:
                while chunk := resp.read(chunk_size):
                    out.write(chunk)
                    progress.update(len(chunk))
    except BaseException:
        # Don't leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    # Atomic on the same filesystem: dest appears only once complete.
    os.replace(tmp_path, dest)


# Fetch the checkpoint from the lab's model zoo on first start.
if not os.path.exists(MODELFILE):
    model_url = f"https://lclab.thu.edu.tw/modelzoo/{MODELFILE}"
    print(f"fetch model from {model_url}...", file=sys.stderr)
    _download_model(model_url, MODELFILE)
| |
|
# NOTE(review): weights_only=False un-pickles arbitrary Python objects, so
# loading runs any code embedded in the file. That is tolerable only because
# MODELFILE comes from the lab's own model zoo (or a trusted local copy);
# never point this at an untrusted checkpoint.
_checkpoint = torch.load(MODELFILE, map_location=device, weights_only=False)
# The pickled object wraps the network; the bare model is its `.module`
# attribute (looks like torch.nn.DataParallel — TODO confirm with training code).
model = _checkpoint.module
del _checkpoint
model.eval()  # inference mode for dropout / batch-norm layers

# Per-channel normalisation constants (the commonly used ImageNet mean/std).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Evaluation transform: resize the shorter side to 256, centre-crop 224x224,
# convert to a float tensor, then normalise each RGB channel.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])
| |
|
@spaces.GPU
def predict(image: Image.Image) -> Figure:
    """Classify *image* into Litton's landscape types and chart the result.

    Args:
        image: PIL image from the upload component. It is ``None`` when the
            user presses the start button before providing an image.

    Returns:
        A matplotlib figure with one bar per entry of ``LABELS``. Each bar
        shows that class's softmax probability in percent.

    Raises:
        gr.Error: if no image was supplied. Gradio then shows a readable
            message instead of failing on ``None.convert``.
    """
    if image is None:
        raise gr.Error("請先上傳影像")

    # Normalise RGBA/greyscale/palette inputs to the 3-channel layout that
    # the 3-channel Normalize step in `preprocess` expects.
    rgb = image.convert("RGB")
    batch = preprocess(rgb).unsqueeze(0).to(device)  # batch of one

    with torch.no_grad():
        logits = model(batch)

    # Keep one logit per named class. The head may emit extra columns
    # (TODO confirm). Softmax over just these so the percentages sum to 100.
    n_classes = len(LABELS)
    probs = F.softmax(logits[:, :n_classes], dim=1).cpu()

    return draw_bar_chart(
        {
            "class": LABELS,
            "probs": (probs[0] * 100).tolist(),
        }
    )
| |
|
| |
|
def draw_bar_chart(data: dict[str, list[str | float]]) -> Figure:
    """Render a bar chart of per-class probabilities.

    Args:
        data: Mapping with two keys:
            ``"class"``: the bar labels.
            ``"probs"``: the bar heights in percent, with the same length and
            order as ``"class"``.

    Returns:
        A standalone matplotlib ``Figure``. It is not registered with pyplot,
        so it is garbage-collected once the caller drops it.
    """
    classes = data["class"]
    probabilities = data["probs"]

    # Build the Figure directly rather than via plt.subplots(). pyplot keeps
    # every figure it creates alive in a global registry until plt.close() is
    # called, so a long-running server would leak one figure per request.
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.bar(classes, probabilities, color="skyblue")

    ax.set_xlabel("Class")
    ax.set_ylabel("Probability (%)")
    ax.set_title("Class Probability")

    # Categorical bars sit at x = 0..n-1; print each value just above its bar.
    for i, prob in enumerate(probabilities):
        ax.text(i, prob + 0.01, f"{prob:.2f}%", ha="center", va="bottom")

    fig.tight_layout()

    return fig
| |
|
| |
|
def choose_example(imgpath: str) -> gr.Image:
    """Load an example picture for the upload slot.

    The image is scaled so its longer side is 512 px.

    Args:
        imgpath: Filesystem path of the clicked example thumbnail.

    Returns:
        A ``gr.Image`` update carrying the resized PIL image.
    """
    # Context manager closes the underlying file handle; the original code
    # left it open until garbage collection.
    with Image.open(imgpath) as original:
        width, height = original.size
        ratio = 512 / max(width, height)
        # Clamp to 1 px so an extreme aspect ratio never requests a 0-sized image.
        new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
        # resize() returns a new, fully loaded image that outlives the with-block.
        resized = original.resize(new_size)
    return gr.Image(value=resized, label="輸入影像(不支援 SVG 格式)", type="pil")
| |
|
| |
|
def get_layout():
    """Build the Gradio Blocks page for the Litton7 landscape classifier.

    The page contains, top to bottom:
    - a title with a citation link;
    - a row holding the upload slot (with three clickable example
      thumbnails) next to the result plot;
    - a start button;
    - a footer.

    Pressing the start button runs ``predict`` on the uploaded image.
    Selecting a thumbnail loads that example into the upload slot via
    ``choose_example``.

    Returns:
        The un-launched ``gr.Blocks`` app.
    """
    stylesheet = """
    .main-title {
        font-size: 24px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 20px;
    }
    .reference {
        text-align: center;
        font-size: 1.2em;
        color: #d1d5db;
        margin-bottom: 20px;
    }
    .reference a {
        color: #FB923C;
        text-decoration: none;
    }
    .reference a:hover {
        text-decoration: underline;
        color: #FB923C;
    }
    .title {
        border-bottom: 1px solid;
    }
    .footer {
        text-align: center;
        margin-top: 30px;
        padding-top: 20px;
        border-top: 1px solid #ddd;
        color: #d1d5db;
        font-size: 14px;
    }
    .example-image {
        height: 220px;
        padding: 25px;
    }
    """

    # Colour overrides on top of the Base theme: light text on dark neutral
    # backgrounds, with orange/cyan accents.
    theme_overrides = {
        "body_text_color": "*neutral_100",
        "body_text_color_subdued": "*neutral_600",
        "background_fill_primary": "*neutral_950",
        "background_fill_secondary": "*neutral_600",
        "border_color_accent": "*secondary_800",
        "color_accent": "*primary_50",
        "color_accent_soft": "*secondary_800",
        "code_background_fill": "*neutral_700",
        "block_background_fill_dark": "*body_background_fill",
        "block_info_text_color": "#6b7280",
        "block_label_text_color": "*neutral_300",
        "block_label_text_weight": "700",
        "block_title_text_color": "*block_label_text_color",
        "block_title_text_weight": "300",
        "panel_background_fill": "*neutral_800",
        "table_text_color_dark": "*secondary_800",
        "checkbox_background_color_selected": "*primary_500",
        "checkbox_label_background_fill": "*neutral_500",
        "checkbox_label_background_fill_hover": "*neutral_700",
        "checkbox_label_text_color": "*neutral_200",
        "input_background_fill": "*neutral_700",
        "input_background_fill_focus": "*neutral_600",
        "slider_color": "*primary_500",
        "table_even_background_fill": "*neutral_700",
        "table_odd_background_fill": "*neutral_600",
        "table_row_focus": "*neutral_800",
    }
    palette = gr.themes.Base(
        primary_hue="orange",
        secondary_hue="cyan",
        neutral_hue="gray",
    ).set(**theme_overrides)

    citation_link = (
        '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">'
        "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)"
        "</a>"
    )
    header_html = (
        '<div class="main-title">Litton7景觀分類模型</div>'
        '<div class="reference">引用資料:' + citation_link + "</div>"
    )
    footer_html = '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>'

    example_paths = ("examples/beach.jpg", "examples/field.jpg", "examples/sky.jpg")
    # Shared settings for the read-only, clickable example thumbnails.
    thumbnail_options = dict(
        show_label=False,
        type="filepath",
        elem_classes="example-image",
        interactive=False,
        show_download_button=False,
        show_fullscreen_button=False,
        show_share_button=False,
    )

    with gr.Blocks(css=stylesheet, theme=palette) as demo:
        with gr.Column():
            gr.HTML(value=header_html)

            with gr.Row(equal_height=True):
                with gr.Group():
                    upload = gr.Image(label="上傳影像", type="pil", height="256px")
                    gr.Label("範例影像", show_label=False)
                    with gr.Row():
                        thumbnails = [
                            gr.Image(value=path, **thumbnail_options)
                            for path in example_paths
                        ]
                chart = gr.Plot(label="分類結果")

            start_button = gr.Button("開始", variant="primary")
            gr.HTML(footer_html)

        start_button.click(fn=predict, inputs=upload, outputs=chart)

        # Each thumbnail, when selected, passes its own file path to
        # choose_example and replaces the upload slot's content.
        for thumb in thumbnails:
            thumb.select(fn=choose_example, inputs=thumb, outputs=upload)

    return demo
| |
|
| |
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    app = get_layout()
    app.launch()
| |
|