"""
Server GUI for EEG Streaming Server
Allows configuration of dataset and montage before starting the server.
Using PySide6 (Qt for Python) with Qt Designer .ui file
"""

import os
import subprocess
import sys
import json
import time
import socket
import mne
import requests
from pathlib import Path

# Directory containing this script (portable when shared as a standalone folder).
SCRIPT_DIR = Path(__file__).resolve().parent


def _resource_path(name: str) -> str:
    return str(SCRIPT_DIR / name)


# Try to import psutil for better process management
try:
    import psutil
    PSUTIL_AVAILABLE = True
except ImportError:
    PSUTIL_AVAILABLE = False
    print("Warning: psutil not available. Server stopping may be less robust.")
from PySide6.QtWidgets import (QApplication, QMainWindow, QFileDialog, QMessageBox, 
                                QLabel, QSpinBox, QProgressBar, QVBoxLayout, QHBoxLayout, QGroupBox, QCheckBox)
from PySide6.QtUiTools import QUiLoader
from PySide6.QtCore import QFile, QIODevice, QTimer


def get_local_ip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str:
    """Return this machine's outbound IPv4 address, or "Unable to detect".

    Connecting a UDP socket sends no packets; it only makes the OS choose the
    local interface that would route to ``probe_host``, whose address is then
    read back with ``getsockname()``.

    Args:
        probe_host: Address used only for route selection (default: Google DNS).
        probe_port: Port paired with ``probe_host``; it is never contacted.

    Returns:
        The local IP address as a string, or ``"Unable to detect"`` when no
        route is available (e.g. offline) or the socket cannot be created.
    """
    try:
        # The context manager closes the socket even when connect() fails,
        # which the previous manual close() did not.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.connect((probe_host, probe_port))
            return sock.getsockname()[0]
    except OSError:
        # socket.gaierror and other network failures are all OSError subclasses.
        return "Unable to detect"


class ServerGUI(QMainWindow):
    """Main window for configuring and running the EEG streaming server.

    The user selects one or more EEG recordings (.edf/.bdf/.fif), a host
    address and a port. ``start_server`` then writes ``server_config.json``
    next to this script and launches ``configurable_server.py`` with it in a
    subprocess. ``stop_server``, or closing the window, terminates that
    subprocess.
    """

    # Host values for which the server only accepts local connections.
    _LOOPBACK_ADDRESSES = ("localhost", "127.0.0.1")

    # Temporary files receiving the server's stdout/stderr (None when no
    # server has been launched). They are declared at class level so the
    # helpers below are safe to call even before __init__ assigns them.
    _server_stdout = None
    _server_stderr = None

    def __init__(self):
        super().__init__()

        # Initialize variables
        # A single path, or a JSON-encoded list of paths for multi-file selections
        self.data_file_path = ""
        self.n_channels = None
        self.port = 8000
        self.server_address = "localhost"  # Default to localhost
        self.server_process = None
        self._server_stdout = None
        self._server_stderr = None

        # Data streaming parameters
        self.sampling_frequency = None  # Will be detected from file
        # Default: 1 second per packet. Set here (not only in load_ui) so a .ui
        # file without packetIntervalCombo still yields a usable value.
        self.packet_interval = 1.0
        self.wrap_around = True          # Default: restart stream at EOF

        self.load_ui()
        self.setup_connections()

    def load_ui(self):
        """Load the UI from the .ui file and add widgets not defined in it.

        Loads ``server_gui.ui`` from SCRIPT_DIR as the central widget. The
        process exits if the file cannot be opened or parsed. This method also
        hides the packet-interval controls (fixed at 1.0 s), adds the
        wrap-around checkbox and sets the default server address.
        """
        loader = QUiLoader()
        ui_file = QFile(_resource_path("server_gui.ui"))

        if not ui_file.open(QIODevice.ReadOnly):
            print("Error: Cannot open server_gui.ui file")
            sys.exit(-1)

        # Load UI - important: pass None as parent to get a standalone widget
        self.ui = loader.load(ui_file, None)
        ui_file.close()

        if not self.ui:
            print("Error: Failed to load UI from server_gui.ui")
            sys.exit(-1)

        # Set the loaded widget as central widget
        self.setCentralWidget(self.ui)

        # Set window properties from the loaded UI
        self.setWindowTitle(self.ui.windowTitle())
        self.setMinimumSize(self.ui.minimumSize())

        # Packet interval is fixed at 1.0 s: select it and hide its controls
        if hasattr(self.ui, 'packetIntervalCombo'):
            self.ui.packetIntervalCombo.setCurrentText("1.0 sec")
            self.ui.packetIntervalCombo.setEnabled(False)
            self.ui.packetIntervalCombo.setVisible(False)
            self.packet_interval = 1.0
        if hasattr(self.ui, 'packetIntervalLabel'):
            self.ui.packetIntervalLabel.setVisible(False)
        if hasattr(self.ui, 'streamSettingsLabel'):
            self.ui.streamSettingsLabel.setVisible(False)

        # Stream mode option (default: wrap-around)
        self.wrap_around_checkbox = QCheckBox("Wrap around at end of file")
        self.wrap_around_checkbox.setChecked(True)
        self.wrap_around_checkbox.setToolTip(
            "If enabled, streaming restarts from the beginning at EOF.\n"
            "If disabled, streaming stops and closes the WebSocket at EOF."
        )
        self.wrap_around_checkbox.toggled.connect(self.on_wrap_around_toggled)
        # Prefer the layout holding the (hidden) packet-interval combo; fall
        # back to the central widget's layout.
        inserted = False
        if hasattr(self.ui, 'packetIntervalCombo'):
            parent_widget = self.ui.packetIntervalCombo.parent()
            if parent_widget and parent_widget.layout():
                parent_widget.layout().addWidget(self.wrap_around_checkbox)
                inserted = True
        if not inserted and hasattr(self.ui, 'centralwidget') and self.ui.centralwidget.layout():
            self.ui.centralwidget.layout().addWidget(self.wrap_around_checkbox)

        # Set default server address
        if hasattr(self.ui, 'serverAddressEdit'):
            self.ui.serverAddressEdit.setText("localhost")
            self.server_address = "localhost"

    def _selected_file_paths(self):
        """Return the selected data file(s) as a list of path strings.

        ``self.data_file_path`` holds either a single path or, for a multi-file
        selection, a JSON-encoded list of paths (see ``browse_data_file``).
        Returns an empty list when nothing is selected.
        """
        if not self.data_file_path:
            return []
        try:
            parsed = json.loads(self.data_file_path)
        except (json.JSONDecodeError, TypeError):
            # An ordinary file path is not valid JSON.
            return [self.data_file_path]
        if isinstance(parsed, list):
            return [str(p) for p in parsed]
        return [self.data_file_path]

    def _compute_data_dir(self):
        """Return the directory containing the selected data file(s).

        For a multi-file selection, the parent of the first file is used. For a
        single path, its parent is returned when it is a file; otherwise the
        path itself is returned. Returns "" when nothing is selected or the
        path cannot be inspected.
        """
        if not self.data_file_path:
            return ""
        try:
            parsed = json.loads(self.data_file_path)
            if isinstance(parsed, list) and parsed:
                return str(Path(parsed[0]).parent)
        except (json.JSONDecodeError, TypeError):
            pass
        try:
            p = Path(self.data_file_path)
            return str(p.parent if p.is_file() else p)
        except (OSError, ValueError):
            return ""

    def setup_connections(self):
        """Connect all widget signals to their slots"""
        # Browse button
        self.ui.browseDataBtn.clicked.connect(self.browse_data_file)

        # Server address helper
        if hasattr(self.ui, 'detectIPBtn'):
            self.ui.detectIPBtn.clicked.connect(self.detect_ip_address)

        # Server info display button
        if hasattr(self.ui, 'showServerInfoBtn'):
            self.ui.showServerInfoBtn.clicked.connect(self.show_server_info)

        # Data streaming parameter widgets
        if hasattr(self.ui, 'packetIntervalCombo'):
            self.ui.packetIntervalCombo.currentTextChanged.connect(self.on_packet_interval_changed)

        # Control buttons
        self.ui.startBtn.clicked.connect(self.start_server)
        self.ui.stopBtn.clicked.connect(self.stop_server)

    def on_packet_interval_changed(self, value_text):
        """Parse a combo entry such as "0.5 sec" into ``self.packet_interval``.

        Falls back to 1.0 s when the text has no leading number.
        """
        try:
            self.packet_interval = float(value_text.split()[0])  # Extract number from "0.5 sec" etc.
        except (ValueError, IndexError, AttributeError):
            self.packet_interval = 1.0
            return
        print(f"Packet interval set to: {self.packet_interval} sec")
        self.update_buffer_size_display()

    def on_wrap_around_toggled(self, checked):
        """Handle stream mode selection at end-of-file."""
        self.wrap_around = bool(checked)
        mode = "wrap-around (restart)" if self.wrap_around else "stop at end-of-stream"
        print(f"Stream mode set to: {mode}")

    def update_buffer_size_display(self):
        """Show the segment size (samples per packet) and packet rate.

        The buffer is one segment of ``sampling_frequency * packet_interval``
        samples. Does nothing until a sampling frequency has been detected or
        when the UI has no ``bufferSizeValueLabel``.
        """
        if self.sampling_frequency is None or not hasattr(self.ui, 'bufferSizeValueLabel'):
            return

        try:
            # Segment size equals sampling frequency multiplied by packet interval
            n_per_seg = int(round(self.sampling_frequency * self.packet_interval))
            segment_duration = self.packet_interval

            # Update display
            update_rate = 1.0 / self.packet_interval if self.packet_interval else 0
            buffer_text = f"{n_per_seg} samples ({segment_duration:.2f}s) | {update_rate:.1f} packets/sec"
            self.ui.bufferSizeValueLabel.setText(buffer_text)

        except Exception as e:
            print(f"Warning: Could not calculate segment size: {e}")
            self.ui.bufferSizeValueLabel.setText("Calculation error")

    def _clear_detected_info(self):
        """Forget detected file metadata and show placeholders in the UI."""
        self.sampling_frequency = None
        self.n_channels = None
        if hasattr(self.ui, 'sfreqValueLabel'):
            self.ui.sfreqValueLabel.setText("Unknown")
        if hasattr(self.ui, 'channelsValueLabel'):
            self.ui.channelsValueLabel.setText("Unknown")
        if hasattr(self.ui, 'bufferSizeValueLabel'):
            self.ui.bufferSizeValueLabel.setText("Not calculated")

    def detect_sampling_frequency(self):
        """Read the sampling frequency and channel count from the selected file(s).

        Only the header of the first file is read (``preload=False``); that
        file's header is the one used when several files are concatenated.
        On any failure the detected values are cleared, so metadata from a
        previous selection cannot leak into the server configuration.
        """
        if not self.data_file_path:
            return

        try:
            file_paths = [Path(p) for p in self._selected_file_paths()]
            if not file_paths:
                print("Warning: No data files selected")
                self._clear_detected_info()
                return

            # Check all files exist
            missing = [fp for fp in file_paths if not fp.exists()]
            if missing:
                print(f"Warning: File does not exist: {missing[0]}")
                self._clear_detected_info()
                return

            # Load first file to get metadata (will be used as header for concatenation)
            first_file = file_paths[0]
            ext = first_file.suffix.lower()

            # read_raw_edf accepts EDF only; BDF needs its dedicated reader.
            if ext == ".edf":
                raw = mne.io.read_raw_edf(first_file, preload=False, verbose=False)
            elif ext == ".bdf":
                raw = mne.io.read_raw_bdf(first_file, preload=False, verbose=False)
            elif ext == ".fif":
                raw = mne.io.read_raw_fif(first_file, preload=False, verbose=False)
            else:
                print(f"Unsupported file format: {ext}")
                self._clear_detected_info()
                return

            # Get sampling frequency and number of channels from first file
            self.sampling_frequency = raw.info["sfreq"]
            self.n_channels = len(raw.ch_names)

            # Update UI to show sampling frequency
            if hasattr(self.ui, 'sfreqValueLabel'):
                self.ui.sfreqValueLabel.setText(f"{self.sampling_frequency:.0f} Hz")

            # Update UI to show number of channels
            if hasattr(self.ui, 'channelsValueLabel'):
                self.ui.channelsValueLabel.setText(str(self.n_channels))

            if len(file_paths) > 1:
                print(f"Detected {len(file_paths)} files for concatenation")
                print(f"Using header from: {first_file.name}")
            print(f"Detected sampling frequency: {self.sampling_frequency} Hz")
            print(f"Detected number of channels: {self.n_channels}")

            # Calculate and display buffer size with current settings
            self.update_buffer_size_display()

        except Exception as e:
            print(f"Warning: Could not detect file information: {e}")
            import traceback
            traceback.print_exc()
            self._clear_detected_info()

    def detect_ip_address(self):
        """Detect and fill in the local IP address for network access"""
        local_ip = get_local_ip()
        if local_ip != "Unable to detect":
            if hasattr(self.ui, 'serverAddressEdit'):
                self.ui.serverAddressEdit.setText(local_ip)
                self.server_address = local_ip
                print(f"Detected IP address: {local_ip}")
                QMessageBox.information(
                    self, "IP Address Detected",
                    f"Detected IP address: {local_ip}\n\n"
                    f"You can also use:\n"
                    f"- 'localhost' or '127.0.0.1' (local only)\n"
                    f"- '0.0.0.0' (all network interfaces)"
                )
        else:
            QMessageBox.warning(
                self, "IP Detection Failed",
                "Could not automatically detect IP address.\n\n"
                "Please enter manually:\n"
                "- 'localhost' or '127.0.0.1' (local only)\n"
                "- Your machine's IP address (e.g., 192.168.1.100)\n"
                "- '0.0.0.0' (all network interfaces)"
            )

    def show_server_info(self):
        """Display server connection information in a popup window.

        Only available while the launched server process is still alive.
        """
        if not self.server_process or self.server_process.poll() is not None:
            QMessageBox.warning(
                self, "Server Not Running",
                "Server is not currently running.\n\nPlease start the server first."
            )
            return

        def format_data_file_display():
            """Format selected data file(s) for human-friendly display."""
            paths = self._selected_file_paths()
            if not paths:
                return "--"
            if len(paths) == 1:
                return paths[0]
            names = [Path(p).name for p in paths]
            preview = ", ".join(names[:3])
            if len(paths) > 3:
                preview += f" ... (+{len(paths)-3} more)"
            return f"{len(paths)} files: {preview}"

        # Get current server configuration
        local_ip = get_local_ip()
        sfreq = self.sampling_frequency
        interval = self.packet_interval
        is_loopback = self.server_address in self._LOOPBACK_ADDRESSES

        # Calculate segment settings (guarded: values may be undetected/unset)
        n_per_seg = int(round(sfreq * interval)) if sfreq and interval else 0
        sfreq_text = f"{sfreq:.0f} Hz" if sfreq else "Unknown"
        interval_text = f"{interval:.1f} sec" if interval else "Unknown"
        rate_text = f"{1.0/interval:.1f} packets/sec" if interval else "Unknown"

        # Build info message
        info_text = "═" * 60 + "\n"
        info_text += "SERVER CONNECTION INFORMATION\n"
        info_text += "═" * 60 + "\n\n"

        info_text += "🖥️  Server Configuration:\n"
        info_text += f"   • Address: {self.server_address}\n"
        info_text += f"   • Port: {self.port}\n"
        info_text += f"   • Local IP: {local_ip}\n\n"

        info_text += "📊 Data Stream Information (sent to clients):\n"
        info_text += f"   • Data File: {format_data_file_display()}\n"
        info_text += f"   • Data Folder: {self._compute_data_dir() or '--'}\n"
        info_text += f"   • Number of Channels: {self.n_channels}\n"
        info_text += f"   • Sampling Frequency: {sfreq_text}\n"
        info_text += f"   • Packet Interval: {interval_text}\n"
        info_text += f"   • Update Rate: {rate_text}\n"
        info_text += f"   • End-of-file Mode: {'Wrap around (restart)' if self.wrap_around else 'Stop stream'}\n\n"
        info_text += "📦 Segment Configuration:\n"
        info_text += f"   • Segment Size (n_per_seg): {n_per_seg} samples ({interval_text})\n\n"

        if is_loopback:
            info_text += "📡 Access URLs:\n"
            info_text += f"   • Local only: http://localhost:{self.port}\n\n"

            info_text += "⚠️  Network Access:\n"
            info_text += "   Server is configured for LOCAL ONLY access.\n"
            info_text += "   To enable network access, stop the server and\n"
            info_text += "   change Server Address to:\n"
            info_text += f"     • '{local_ip}' (this machine's IP)\n"
            info_text += "     • '0.0.0.0' (all interfaces)\n\n"

            info_text += "💻 Client Configuration (Same Machine):\n"
            info_text += "   Server URL: localhost\n"
            info_text += f"   Port: {self.port}\n"

        elif self.server_address == "0.0.0.0":
            info_text += "📡 Access URLs:\n"
            info_text += f"   • Local: http://localhost:{self.port}\n"
            info_text += f"   • Network: http://{local_ip}:{self.port}\n\n"

            info_text += "✅ Network Access: ENABLED\n"
            info_text += "   Server is listening on all network interfaces.\n\n"

            info_text += "💻 Client Configuration:\n"
            info_text += "   Same machine:\n"
            info_text += "     Server URL: localhost\n"
            info_text += f"     Port: {self.port}\n\n"
            info_text += "   Different machine (same network):\n"
            info_text += f"     Server URL: {local_ip}\n"
            info_text += f"     Port: {self.port}\n"

        else:
            info_text += "📡 Access URL:\n"
            info_text += f"   • http://{self.server_address}:{self.port}\n\n"

            info_text += "✅ Network Access: ENABLED\n"
            info_text += f"   Server is listening on: {self.server_address}\n\n"

            info_text += "💻 Client Configuration:\n"
            info_text += f"   Server URL: {self.server_address}\n"
            info_text += f"   Port: {self.port}\n"

        info_text += "\n" + "─" * 60 + "\n"
        info_text += "📋 Quick Test:\n"
        test_url = f"http://localhost:{self.port}" if self.server_address == "0.0.0.0" else f"http://{self.server_address}:{self.port}"
        info_text += "   1. Test server is running:\n"
        info_text += f"      {test_url}\n"
        info_text += "      Should show: {\"message\":\"EEG Streaming Server\"}\n\n"
        info_text += "   2. Check data stream info:\n"
        info_text += f"      {test_url}/info\n"
        info_text += "      Should return JSON with channel count, sampling\n"
        info_text += "      frequency, buffer settings, etc.\n"

        # Create message box with monospace font for better formatting
        msg_box = QMessageBox(self)
        msg_box.setWindowTitle("Server Connection Information")
        msg_box.setText(info_text)
        msg_box.setIcon(QMessageBox.Icon.Information)
        msg_box.setStyleSheet("QLabel{font-family: 'Courier New', monospace; min-width: 600px;}")
        msg_box.exec()

    def browse_data_file(self):
        """Let the user pick one or more EEG files and detect their metadata.

        A single file is stored as a plain path. Several files are stored as
        a JSON-encoded list (they are concatenated by the server) and summarised
        in the path field, with the full list in its tooltip.
        """
        # File selection (allow multiple files)
        filenames, _ = QFileDialog.getOpenFileNames(
            self,
            "Select EEG Data File(s) - Multiple files will be concatenated",
            "",
            "EEG Files (*.edf *.bdf *.fif);;All Files (*)"
        )
        if not filenames:
            return

        if len(filenames) == 1:
            self.data_file_path = filenames[0]
            self.ui.dataPathEdit.setText(filenames[0])
            # Replace any tooltip left over from a previous multi-file selection
            self.ui.dataPathEdit.setToolTip(filenames[0])
        else:
            # Multiple files selected - store as JSON list
            self.data_file_path = json.dumps(filenames)
            # Display summary in UI
            file_list = [Path(f).name for f in filenames]
            display_text = f"{len(filenames)} files: {', '.join(file_list[:3])}"
            if len(filenames) > 3:
                display_text += f" ... (+{len(filenames)-3} more)"
            self.ui.dataPathEdit.setText(display_text)
            self.ui.dataPathEdit.setToolTip('\n'.join(filenames))

        # Detect sampling frequency after file(s) selected
        self.detect_sampling_frequency()

    def validate_inputs(self):
        """Check that at least one data file is selected and all selected files exist.

        Shows an error dialog and returns False on the first problem found;
        returns True otherwise.
        """
        paths = self._selected_file_paths()
        if not paths:
            QMessageBox.critical(self, "Error", "Please select an EEG data file.")
            return False

        for file_path in paths:
            if not os.path.exists(file_path):
                if len(paths) > 1:
                    message = f"File does not exist: {file_path}"
                else:
                    message = "Selected data file does not exist."
                QMessageBox.critical(self, "Error", message)
                return False

        return True

    def _read_server_output(self):
        """Return the captured ``(stdout, stderr)`` text of the server process.

        Intended to be called after the process has exited. Undecodable bytes
        are replaced rather than raising.
        """
        def read_all(handle):
            if handle is None:
                return ""
            try:
                handle.flush()
                handle.seek(0)
                return handle.read().decode("utf-8", errors="replace")
            except (OSError, ValueError):
                return ""

        return read_all(self._server_stdout), read_all(self._server_stderr)

    def _close_server_output(self):
        """Close (and thereby delete) the server's temporary output files."""
        for handle in (self._server_stdout, self._server_stderr):
            if handle is not None:
                try:
                    handle.close()
                except OSError:
                    pass
        self._server_stdout = None
        self._server_stderr = None

    def start_server(self):
        """Start the EEG streaming server.

        Reads the port and host from the UI, validates the selection, writes
        ``server_config.json`` next to this script and launches
        ``configurable_server.py --config <path>`` in a subprocess. The
        server's output is captured in temporary files, which are read back
        only if it crashes. A follow-up check is scheduled 5 s later via
        ``verify_server_running``.
        """
        # Note: self.data_file_path is already set by browse_data_file()
        # Don't override it from the text field as it may contain a summary
        self.port = self.ui.portSpin.value()

        # Get server address from UI (default to localhost if empty)
        if hasattr(self.ui, 'serverAddressEdit'):
            self.server_address = self.ui.serverAddressEdit.text().strip() or "localhost"
        else:
            self.server_address = "localhost"

        # Ensure n_channels was detected
        if self.n_channels is None:
            QMessageBox.critical(
                self, "Error",
                "Number of channels not detected. Please select a valid EEG file."
            )
            return

        if not self.validate_inputs():
            return

        # Save configuration to JSON file
        data_dir = self._compute_data_dir()
        config = {
            "data_file": self.data_file_path,
            "data_dir": data_dir,
            "n_channels": self.n_channels,
            "host": self.server_address,
            "port": self.port,
            "packet_interval": self.packet_interval,
            "wrap_around": self.wrap_around,
        }

        config_path = SCRIPT_DIR / "server_config.json"
        try:
            with open(config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2)
        except OSError as e:
            QMessageBox.critical(self, "Error", f"Failed to write server configuration: {e}")
            return

        import tempfile  # only needed here, to capture the server's output

        # Start server in a subprocess
        try:
            # Capture output in files rather than pipes: nothing drains a pipe
            # while the server runs, so a chatty server would fill the OS pipe
            # buffer and block on its next write.
            self._close_server_output()
            self._server_stdout = tempfile.TemporaryFile()
            self._server_stderr = tempfile.TemporaryFile()
            self.server_process = subprocess.Popen(
                [
                    sys.executable,
                    str(SCRIPT_DIR / "configurable_server.py"),
                    "--config",
                    str(config_path),
                ],
                stdout=self._server_stdout,
                stderr=self._server_stderr,
                cwd=str(SCRIPT_DIR),
            )

            # Give server a moment to start and check if it crashes immediately.
            # NOTE: this sleep blocks the Qt event loop for its duration.
            time.sleep(2)

            # Check if process is still running
            if self.server_process.poll() is not None:
                # Process exited - get error messages
                output_msg, error_msg = self._read_server_output()

                full_msg = "Server crashed during startup!\n\n"
                if error_msg:
                    full_msg += f"Error:\n{error_msg}\n\n"
                if output_msg:
                    full_msg += f"Output:\n{output_msg}"

                QMessageBox.critical(self, "Server Startup Failed", full_msg)
                self.server_process = None
                self._close_server_output()
                return

            self.ui.startBtn.setEnabled(False)
            self.ui.stopBtn.setEnabled(True)
            if hasattr(self.ui, 'showServerInfoBtn'):
                self.ui.showServerInfoBtn.setEnabled(False)  # Wait for verification

            # Get and display network information
            local_ip = get_local_ip()
            self.ui.statusLabel.setText(f"Status: Server starting on {self.server_address}:{self.port}... (check after 5s)")
            self.ui.statusLabel.setStyleSheet("color: orange; font-weight: bold;")

            is_loopback = self.server_address in self._LOOPBACK_ADDRESSES

            print(f"\n{'='*60}")
            print("SERVER NETWORK INFORMATION:")
            print(f"{'='*60}")
            print(f"Server configured to listen on: {self.server_address}")
            print(f"Port: {self.port}")

            if is_loopback:
                print(f"Access: http://localhost:{self.port} (local only)")
                print("\nFor network access, change Server Address to:")
                print("  - '0.0.0.0' (all interfaces)")
                print(f"  - '{local_ip}' (this machine's IP)")
            else:
                print(f"Access: http://{self.server_address}:{self.port}")
                if self.server_address == "0.0.0.0":
                    print("  (Listening on all network interfaces)")
                    print(f"  Local: http://localhost:{self.port}")
                    print(f"  Network: http://{local_ip}:{self.port}")
                else:
                    print(f"  (Listening on specific interface: {self.server_address})")

            # Clients must use the address the server actually listens on.
            if is_loopback:
                client_host = "localhost"
            elif self.server_address == "0.0.0.0":
                client_host = local_ip
            else:
                client_host = self.server_address
            print("\nFor clients on other machines:")
            print(f"  Server URL: {client_host}")
            print(f"  Port: {self.port}")
            print(f"{'='*60}\n")

            # Schedule a check after 5 seconds to verify server is actually running
            QTimer.singleShot(5000, self.verify_server_running)

        except Exception as e:
            if self.server_process is None:
                # Launch never succeeded: release the capture files
                self._close_server_output()
            QMessageBox.critical(self, "Error", f"Failed to start server: {e}")

    def verify_server_running(self):
        """Verify the server is still alive a few seconds after launch.

        This check is scheduled by ``start_server``. If the process has exited,
        the tail of its captured output is shown and the controls are reset.
        Otherwise the server is marked as running and the info button is
        enabled.
        """
        if self.server_process is None:
            return

        # Check if process is still alive
        if self.server_process.poll() is not None:
            # Process died - get error messages
            output_msg, error_msg = self._read_server_output()

            # Show error dialog
            full_msg = "Server crashed after startup!\n\n"
            if error_msg:
                full_msg += f"Error output:\n{error_msg[-2000:]}\n\n"  # Last 2000 chars
            if output_msg:
                full_msg += f"Console output:\n{output_msg[-2000:]}"  # Last 2000 chars

            QMessageBox.critical(self, "Server Crashed", full_msg)

            # Update UI
            self.ui.statusLabel.setText("Status: Server crashed")
            self.ui.statusLabel.setStyleSheet("color: red; font-weight: bold;")
            self.ui.startBtn.setEnabled(True)
            self.ui.stopBtn.setEnabled(False)
            if hasattr(self.ui, 'showServerInfoBtn'):
                self.ui.showServerInfoBtn.setEnabled(False)
            self.server_process = None
            self._close_server_output()
            return

        # Process still running - likely OK
        local_ip = get_local_ip()
        address_display = self.server_address if self.server_address != "0.0.0.0" else f"0.0.0.0 (IP: {local_ip})"
        self.ui.statusLabel.setText(f"Status: Server running on {address_display}:{self.port}")
        self.ui.statusLabel.setStyleSheet("color: green; font-weight: bold;")

        # Enable server info button now that server is confirmed running
        if hasattr(self.ui, 'showServerInfoBtn'):
            self.ui.showServerInfoBtn.setEnabled(True)

        print("✅ Server confirmed running!")
        if self.server_address == "0.0.0.0":
            print("   Listening on: All interfaces")
            print(f"   Local: http://localhost:{self.port}")
            print(f"   Network: http://{local_ip}:{self.port}")
        else:
            print(f"   Listening on: {self.server_address}:{self.port}")
        print()

    def _terminate_tracked_process(self):
        """Terminate the launched server process and any processes it spawned.

        Always clears ``self.server_process`` and closes the output files.
        """
        proc = self.server_process
        if proc is None:
            self._close_server_output()
            return
        try:
            # Snapshot children BEFORE stopping the parent: once the parent has
            # exited and been reaped, psutil can no longer look it up and the
            # children would be left running as orphans.
            children = []
            if PSUTIL_AVAILABLE:
                try:
                    children = psutil.Process(proc.pid).children(recursive=True)
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    children = []

            # Try graceful termination first
            proc.terminate()
            try:
                proc.wait(timeout=3)
            except subprocess.TimeoutExpired:
                # Force kill if it doesn't terminate
                proc.kill()
                try:
                    proc.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    pass

            if children:
                for child in children:
                    try:
                        child.terminate()
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        pass
                # Wait a bit for children to terminate, then force-kill survivors
                _gone, alive = psutil.wait_procs(children, timeout=2)
                for child in alive:
                    try:
                        child.kill()
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        pass
        except Exception as e:
            print(f"Warning: Error stopping server process: {e}")
        finally:
            self.server_process = None
            self._close_server_output()

    def _free_port(self):
        """Best-effort termination of any other process still LISTENING on ``self.port``.

        Uses psutil when available. Otherwise, on non-Windows systems, it uses
        ``lsof``. The GUI's own process is never targeted.
        """
        own_pid = os.getpid()
        if PSUTIL_AVAILABLE:
            # psutil >= 6.0 renamed Process.connections() to net_connections();
            # request whichever name this psutil version provides.
            conn_attr = "net_connections" if hasattr(psutil.Process, "net_connections") else "connections"
            try:
                for proc in psutil.process_iter(['pid', 'name', conn_attr]):
                    try:
                        pid = proc.info['pid']
                        if pid == own_pid:
                            continue
                        # None when access to the process' sockets is denied
                        connections = proc.info.get(conn_attr) or []
                        listening = any(
                            conn.status == psutil.CONN_LISTEN
                            and conn.laddr
                            and conn.laddr.port == self.port
                            for conn in connections
                        )
                        if not listening:
                            continue
                        print(f"⚠️  Killing process {pid} ({proc.info['name']}) using port {self.port}")
                        proc.terminate()
                        try:
                            proc.wait(timeout=2)
                        except psutil.TimeoutExpired:
                            proc.kill()
                    except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess, AttributeError):
                        continue
            except Exception as e:
                print(f"Warning: Error checking for processes on port {self.port}: {e}")
            return

        # Fallback: use lsof on Unix systems to find and kill processes on the port
        try:
            import platform
            import signal
            if platform.system() == "Windows":
                return
            # Restrict to TCP sockets in LISTEN state: a bare ":PORT" filter
            # also matches clients connected to the server (e.g. a browser).
            result = subprocess.run(
                ['lsof', '-ti', f'tcp:{self.port}', '-sTCP:LISTEN'],
                capture_output=True,
                text=True,
                timeout=2
            )
            if result.returncode != 0 or not result.stdout.strip():
                return
            for pid_str in result.stdout.split():
                try:
                    pid = int(pid_str)
                    if pid == own_pid:
                        continue
                    print(f"⚠️  Killing process {pid} using port {self.port}")
                    os.kill(pid, signal.SIGTERM)
                    time.sleep(0.5)
                    # Force kill if still running
                    try:
                        os.kill(pid, 0)  # Check if still exists
                        os.kill(pid, signal.SIGKILL)
                    except ProcessLookupError:
                        pass  # Already dead
                except (ValueError, ProcessLookupError, PermissionError):
                    pass
        except Exception as e:
            print(f"Warning: Error checking for processes on port {self.port}: {e}")

    def _clean_log_files(self):
        """Delete known log files from the script directory and the current working directory."""
        log_files = ["configurable_server.log", "server_gui.log"]
        # The server runs with cwd=SCRIPT_DIR (see start_server), so any log it
        # writes to a relative path lands there; cwd is checked as before.
        log_dirs = dict.fromkeys([str(SCRIPT_DIR), os.getcwd()])
        for log_dir in log_dirs:
            for log_file in log_files:
                log_path = os.path.join(log_dir, log_file)
                if not os.path.exists(log_path):
                    continue
                try:
                    os.remove(log_path)
                    print(f"🧹 Cleaned log file: {log_file}")
                except Exception as e:
                    print(f"⚠️  Could not remove log file {log_file}: {e}")

    def stop_server(self):
        """Stop the EEG streaming server and ensure its port is freed.

        Terminates the tracked server process and its children, kills any
        other process still listening on the port, removes known log files and
        resets the controls.
        """
        self._terminate_tracked_process()
        self._free_port()
        self._clean_log_files()

        # Update UI
        self.ui.startBtn.setEnabled(True)
        self.ui.stopBtn.setEnabled(False)
        if hasattr(self.ui, 'showServerInfoBtn'):
            self.ui.showServerInfoBtn.setEnabled(False)
        self.ui.statusLabel.setText("Status: Server stopped")
        self.ui.statusLabel.setStyleSheet("color: blue; font-weight: bold;")

        print(f"✅ Server stopped and port {self.port} should be free")

    def closeEvent(self, event):
        """Terminate a still-running server when the window is closed.

        Without this, the launched server would outlive the GUI, and nothing
        would be left to stop it.
        """
        self._terminate_tracked_process()
        super().closeEvent(event)


if __name__ == "__main__":
    # Build the Qt application, show the configuration window and hand control
    # to the Qt event loop; its return code becomes the process exit status.
    qt_app = QApplication(sys.argv)
    main_window = ServerGUI()
    main_window.show()
    exit_code = qt_app.exec()
    sys.exit(exit_code)
