#!/usr/bin/env python3 """ Tunnel Crack Detection Streamlit App A Streamlit-based web interface for tunnel crack detection using YOLOv12-DINO. Upload images and get real-time crack detection results with interactive controls. """ import streamlit as st import cv2 import numpy as np import pandas as pd from pathlib import Path import tempfile import os import sys from typing import Dict, List, Tuple import time from PIL import Image import plotly.express as px import plotly.graph_objects as go # Add the current directory to the Python path current_dir = Path(__file__).parent sys.path.insert(0, str(current_dir)) try: from inference import YOLOInference except ImportError as e: st.error(f"Error importing inference module: {e}") st.stop() # Page configuration st.set_page_config( page_title="Tunnel Crack Detection", page_icon="🔍", layout="wide", initial_sidebar_state="expanded" ) # Configure large file uploads import streamlit.web.bootstrap try: # Set large upload size (5GB in MB) if hasattr(st, '_config'): st._config.set_option('server.maxUploadSize', 5000) else: # For newer Streamlit versions from streamlit import config config.set_option('server.maxUploadSize', 5000) config.set_option('server.maxMessageSize', 5000) except Exception: pass # Set environment variable for large files (5GB in MB) os.environ['STREAMLIT_SERVER_MAX_UPLOAD_SIZE'] = '5000' # Custom CSS for better styling st.markdown(""" """, unsafe_allow_html=True) # Initialize session state if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False if 'model_instance' not in st.session_state: st.session_state.model_instance = None if 'detection_history' not in st.session_state: st.session_state.detection_history = [] # Default model weight path DEFAULT_WEIGHTS_PATH = "/Users/sompoteyouwai/env/model_weight/segment_defect.pt" def load_model(weights_path: str, device: str = "cpu") -> Tuple[bool, str]: """Load the YOLO model with specified weights.""" try: if not Path(weights_path).exists(): return False, f"❌ Model file not found: {weights_path}" with st.spinner("Loading tunnel crack detection model..."): st.session_state.model_instance = YOLOInference( weights=weights_path, conf=0.25, iou=0.7, imgsz=640, device=device, verbose=True ) st.session_state.model_loaded = True # Get model information model_info = f"✅ Model loaded successfully\n" model_info += f"📋 Task: {st.session_state.model_instance.model.task}\n" if hasattr(st.session_state.model_instance.model.model, 'names'): class_names = list(st.session_state.model_instance.model.model.names.values()) model_info += f"🏷️ Classes ({len(class_names)}): {', '.join(class_names)}" return True, model_info except Exception as e: return False, f"❌ Error loading model: {str(e)}" def perform_detection( image: np.ndarray, conf_threshold: float, iou_threshold: float, image_size: int ) -> Tuple[np.ndarray, Dict, str]: """Perform crack detection using the loaded model.""" if st.session_state.model_instance is None: return None, {}, "❌ No model loaded" try: # Update model parameters st.session_state.model_instance.conf = conf_threshold st.session_state.model_instance.iou = iou_threshold st.session_state.model_instance.imgsz = image_size # Save image to temporary file with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp_file: # Convert RGB to BGR for OpenCV image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) cv2.imwrite(tmp_file.name, image_bgr) tmp_path = tmp_file.name start_time = time.time() try: # Use the exact same method as inference.py results = st.session_state.model_instance.predict_single( source=tmp_path, save=False, show=False, save_txt=False, save_conf=False, save_crop=False, output_dir=None ) inference_time = time.time() - start_time if not results: return image, {}, "❌ No results returned from model" result = results[0] # Get annotated image annotated_img = result.plot() annotated_img = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB) # Process detection results detection_data = process_detection_results(result, inference_time) # Generate summary text summary_text = generate_detection_summary(result, detection_data, inference_time) return annotated_img, detection_data, summary_text finally: # Clean up temporary file if os.path.exists(tmp_path): os.unlink(tmp_path) except Exception as e: return None, {}, f"❌ Error during detection: {str(e)}" def process_detection_results(result, inference_time: float) -> Dict: """Process detection results into structured data.""" if result.boxes is None or len(result.boxes) == 0: return { 'total_detections': 0, 'class_counts': {}, 'detections': [], 'inference_time': inference_time } detections = result.boxes # Get class names if hasattr(st.session_state.model_instance.model.model, 'names'): class_names = st.session_state.model_instance.model.model.names else: class_names = getattr(result, 'names', {i: f"Class_{i}" for i in range(100)}) # Process each detection detection_list = [] class_counts = {} for i, (box, conf, cls) in enumerate(zip(detections.xyxy, detections.conf, detections.cls)): cls_id = int(cls) cls_name = class_names.get(cls_id, f"Class_{cls_id}") confidence = float(conf) x1, y1, x2, y2 = box.tolist() detection_list.append({ 'id': i + 1, 'class': cls_name, 'confidence': confidence, 'x1': int(x1), 'y1': int(y1), 'x2': int(x2), 'y2': int(y2), 'width': int(x2 - x1), 'height': int(y2 - y1), 'area': int((x2 - x1) * (y2 - y1)) }) class_counts[cls_name] = class_counts.get(cls_name, 0) + 1 return { 'total_detections': len(detection_list), 'class_counts': class_counts, 'detections': detection_list, 'inference_time': inference_time } def generate_detection_summary(result, detection_data: Dict, inference_time: float) -> str: """Generate detection summary text.""" total_detections = detection_data['total_detections'] if total_detections == 0: return "🔍 No cracks or defects detected in the image." summary = f"✅ **Detection Results:**\n\n" summary += f"📊 **Images processed:** 1\n" summary += f"📊 **Total detections:** {total_detections}\n" summary += f"⏱️ **Inference time:** {inference_time:.3f}s\n\n" summary += "📋 **Detections by class:**\n" for cls_name, count in sorted(detection_data['class_counts'].items()): summary += f" • {cls_name}: {count}\n" return summary def create_detection_chart(detection_data: Dict): """Create interactive charts for detection results.""" if detection_data['total_detections'] == 0: st.info("No detections to visualize") return # Class distribution pie chart class_counts = detection_data['class_counts'] fig_pie = px.pie( values=list(class_counts.values()), names=list(class_counts.keys()), title="Detection Distribution by Class", color_discrete_sequence=px.colors.qualitative.Set3 ) fig_pie.update_layout(height=400) st.plotly_chart(fig_pie, use_container_width=True) # Confidence distribution confidences = [det['confidence'] for det in detection_data['detections']] classes = [det['class'] for det in detection_data['detections']] fig_conf = px.box( x=classes, y=confidences, title="Confidence Distribution by Class", color=classes ) fig_conf.update_layout(height=400) fig_conf.update_xaxes(title="Class") fig_conf.update_yaxes(title="Confidence Score") st.plotly_chart(fig_conf, use_container_width=True) def create_detection_table(detection_data: Dict): """Create detailed detection table.""" if detection_data['total_detections'] == 0: st.info("No detections to display") return # Convert to DataFrame df = pd.DataFrame(detection_data['detections']) # Format confidence as percentage df['confidence_pct'] = df['confidence'].apply(lambda x: f"{x:.1%}") # Reorder columns for better display display_columns = ['id', 'class', 'confidence_pct', 'x1', 'y1', 'x2', 'y2', 'width', 'height', 'area'] df_display = df[display_columns].copy() # Rename columns for better readability df_display.columns = ['ID', 'Class', 'Confidence', 'X1', 'Y1', 'X2', 'Y2', 'Width', 'Height', 'Area'] st.dataframe(df_display, use_container_width=True) # Download button for results csv = df_display.to_csv(index=False) st.download_button( label="📥 Download Detection Results (CSV)", data=csv, file_name=f"crack_detection_results_{int(time.time())}.csv", mime="text/csv" ) def main(): """Main Streamlit application.""" # Header st.markdown('
{DEFAULT_WEIGHTS_PATH}