#!/usr/bin/env python3
"""
UART Image Sender for ONEAI CNN Interface
==========================================

This script reads an image file or a webcam stream, resizes it to 128x128, converts it to grayscale,
and sends it via UART to the VHDL UART_Interface module.

Protocol:
- Header: 0xFF 0xAA 0x55 (3 bytes)
- Data: 128*128*1 = 16,384 grayscale bytes (pixel values scaled to 0...127)
- Total: 16,387 bytes per frame
- Baud rate: 115200 (default)
- Response: 1 byte per class (10 classes, each a 0...127 scaled confidence)

Usage:
    python uart_image_sender.py
    python uart_image_sender.py --port /dev/ttyUSB0 --baud 115200
"""

import sys
import argparse
import time
import serial
import numpy as np
from PIL import Image
import cv2
from pathlib import Path
from threading import Thread


class WebcamCapture:
    """
    Webcam front end: grabs frames, sends them to the FPGA via UART and shows the prediction.

    stream() loops while ``self.run`` is True; another thread stops it by
    setting ``run`` to False.
    """

    def __init__(self, uart, camera_index=0, fps=60, num_classes = 10):
        """
        Initialize webcam capture

        Args:
            uart: UART sender used by stream(); its ``fps`` attribute and its
                ``create_uart_data`` / ``send_frame`` methods are used
            camera_index: Camera device index (0 for default camera)
            fps: Nominal camera frame rate; together with ``uart.fps`` it decides
                how many captured frames are shown per UART transfer
            num_classes: length of response to expect (1 byte per class)
        """
        self.camera_index = camera_index
        self.fps = fps
        self.cap = None
        self.FRAME_WIDTH = 128
        self.FRAME_HEIGHT = 128

        self.UART = uart
        # Loop flag for stream(); cleared externally (or by stream() itself) to stop.
        self.run = False

        self.num_classes = num_classes

    @staticmethod
    def _compute_frame_divider(camera_fps, uart_fps):
        """
        Return how many camera frames pass per UART transfer (always >= 1).

        ``uart_fps`` can be below 1, or 0 if a caller rounded it down. At 115200 baud
        one 16,387-byte frame takes about 1.4 s. Dividing by 0 would raise
        ZeroDivisionError, and a divider of 0 would break the modulo test in stream().
        """
        if uart_fps <= 0:
            return max(1, int(camera_fps))
        return max(1, int(camera_fps / uart_fps))

    def _is_prediction(self, data):
        """Return True if ``data`` is a raw FPGA reply holding one byte per class.

        send_frame() returns error strings such as "TIMEOUT " on failure; those are rejected.
        """
        return isinstance(data, (bytes, bytearray)) and len(data) == self.num_classes

    def preprocess_frame(self, frame):
        """
        Process frame to right format:
        1. Center-crop to a square
        2. Resize 128x128
        3. Convert to gray scale
        4. Convert to uint8 range 0...127

        Args:
            frame: Input frame from webcam (BGR, as delivered by OpenCV)

        Returns:
            processed: preprocessed frame (128x128, gray scale uint8, range 0...127)
        """
        height, width = frame.shape[:2]

        # Calculate crop
        if width > height:
            # cut sides
            crop_size = height
            x_offset = (width - crop_size) // 2
            y_offset = 0
        else:
            # cut top/bottom
            crop_size = width
            x_offset = 0
            y_offset = (height - crop_size) // 2

        # Crop to a square
        cropped = frame[y_offset:y_offset+crop_size, x_offset:x_offset+crop_size]

        # Resize to 128x128
        resized = cv2.resize(cropped, (self.FRAME_WIDTH, self.FRAME_HEIGHT))

        # Convert to gray scale
        frame_gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)

        # Convert to numpy array and scale 0...255 down to 0...127
        img_array = (np.array(frame_gray) / 2).astype(np.uint8)

        return img_array

    def draw_prediction(self, frame, data):
        """
        Draw prediction on image.

        Args:
            frame: Frame as background (modified in place and returned)
            data: raw bytes returned from FPGA. 10 bytes, each giving a 0...127 scaled probability for the 10 possible digits (0...9)

        Returns:
            frame: Frame with prediction
        """
        height, width = frame.shape[:2]

        # prediction: the class with the highest score
        confidence = max(data)
        predicted_class = data.index(confidence)

        # Draw half transparent background
        overlay = frame.copy()
        header_height = int(height * 0.15)  # 15% of the height for the header bar
        cv2.rectangle(overlay, (0, 0), (width, header_height), (0, 0, 0), -1)
        cv2.addWeighted(overlay, 0.6, frame, 0.4, 0, frame)

        # Predicted digit
        text = str(predicted_class)
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = min(width, height) / 250  # small scale keeps the strokes thin
        thickness = max(2, int(font_scale * 4))  # thinner lines

        # calculate text size
        (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)

        # Position for text: top, horizontally centered
        margin = 20
        x = width//2 - text_width//2
        y = margin + text_height

        # thick black pass first, then green on top -> green digit with a black outline
        cv2.putText(frame, text, (x, y), font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
        cv2.putText(frame, text, (x, y), font, font_scale, (0, 255, 0), thickness, cv2.LINE_AA)

        # Confidence, top right
        conf_text = f"{confidence/1.27:.1f}%" # confidence is [0...127]; dividing by 1.27 maps this to percent
        conf_font_scale = font_scale * 0.6
        conf_thickness = max(1, int(conf_font_scale * 4))

        (conf_width, conf_height), _ = cv2.getTextSize(conf_text, font, conf_font_scale, conf_thickness)
        conf_x = width - conf_width - margin
        conf_y = margin + conf_height

        cv2.putText(frame, conf_text, (conf_x, conf_y), font, conf_font_scale, (255, 255, 255), conf_thickness, cv2.LINE_AA)

        return frame

    def stream(self):
        """
        Continuously capture webcam frames and show the latest prediction on top of the live image.

        Every ``frame_divider``-th frame is sent via UART.
        Runs until ``self.run`` is cleared or the camera stops delivering frames.
        The camera and the preview window are released on exit, even after an error.
        """
        # CAP_DSHOW selects OpenCV's DirectShow backend, which only exists on Windows.
        # If the camera cannot be opened with it, fall back to the default backend.
        self.cap = cv2.VideoCapture(self.camera_index, cv2.CAP_DSHOW)
        if not self.cap.isOpened():
            self.cap.release()
            self.cap = cv2.VideoCapture(self.camera_index)
        if not self.cap.isOpened():
            print(f"  Could not open camera {self.camera_index}")
            self.cap.release()
            self.run = False
            return

        # helper variables
        frame_count = 0
        frame_divider = self._compute_frame_divider(self.fps, self.UART.fps)
        data = None

        try:
            while self.run:
                # read frame
                ret, frame = self.cap.read()
                if not ret or frame is None:
                    print("  Could not read frame from camera, stopping stream.")
                    break

                # only process as many frames as the UART can handle. @ 1MBaud it sends with ~4 fps.
                if frame_count % frame_divider == 0:
                    # preprocess frame and send via UART (blocks until reply or timeout)
                    send_frame = self.preprocess_frame(frame)
                    data = self.UART.send_frame(self.UART.create_uart_data(send_frame), raw=True, verbose=False)

                # the last valid prediction is reused for the frames in between
                if self._is_prediction(data):
                    frame = self.draw_prediction(frame, data)

                # show image with prediction
                cv2.imshow("Digit Classifier", frame)
                cv2.waitKey(1)
                frame_count += 1
        finally:
            self.run = False
            self.cap.release()
            cv2.destroyAllWindows()
            
            
class UARTImageSender:
    """
    Send 128x128 grayscale frames to the FPGA's VHDL UART_Interface and read the reply.

    Frame format: 3 header bytes (0xFF 0xAA 0x55), then 128*128 pixel bytes in
    row-major order, each scaled to 0...127. The reply is expected to be one
    byte per class.
    """

    def __init__(self, port='/dev/ttyUSB0', baud_rate=115200, timeout=10, num_classes=10):
        """
        Initialize UART connection parameters (the port is opened by connect()).

        Args:
            port: Serial port name (e.g., '/dev/ttyUSB0' on Linux, 'COM3' on Windows)
            baud_rate: UART baud rate (must match VHDL interface)
            timeout: Communication timeout in seconds
            num_classes: Number of bytes in a prediction reply (one per class)
        """
        self.port = port
        self.baud_rate = baud_rate
        self.timeout = timeout
        self.serial_conn = None

        # Protocol constants
        self.HEADER_BYTES = [0xFF, 0xAA, 0x55]
        self.FRAME_WIDTH = 128
        self.FRAME_HEIGHT = 128
        self.CHANNELS = 1  # gray
        self.TOTAL_PIXELS = self.FRAME_WIDTH * self.FRAME_HEIGHT
        self.TOTAL_BYTES = self.TOTAL_PIXELS * self.CHANNELS
        self.NUM_CLASSES = num_classes

        # Frames per second. Each UART byte costs 10 bit times (start + 8 data + stop).
        # The 1.5 factor presumably leaves headroom for the FPGA's processing and reply.
        # Kept as a float: at 115200 baud this is ~0.47 fps, and int() would truncate it
        # to 0, which makes callers that divide by it crash.
        max_fps = self.baud_rate / (self.TOTAL_BYTES * 10)
        self.fps = max_fps / 1.5

    def connect(self):
        """
        Open the serial port (8N1, no flow control) and clear its buffers.

        Returns:
            bool: True on success, False if pyserial could not open the port.
        """
        try:
            self.serial_conn = serial.Serial(
                port=self.port,
                baudrate=self.baud_rate,
                bytesize=serial.EIGHTBITS,
                parity=serial.PARITY_NONE,
                stopbits=serial.STOPBITS_ONE,
                timeout=self.timeout,
                xonxoff=False,
                rtscts=False,
                dsrdtr=False
            )

            # Clear any existing data
            self.serial_conn.flush()
            time.sleep(0.1)
            self.serial_conn.reset_input_buffer()
            self.serial_conn.reset_output_buffer()

            print(f"  Connected to {self.port} at {self.baud_rate} baud")
            return True

        except serial.SerialException as e:
            print(f"  Failed to connect to {self.port}: {e}")
            return False

    def disconnect(self):
        """Close UART connection if it is open."""
        if self.serial_conn and self.serial_conn.is_open:
            self.serial_conn.close()
            print("  UART connection closed")

    def load_and_process_image(self, image_path):
        """
        Load image from file and process it to 128x128 grayscale format.

        Args:
            image_path: Path to image file

        Returns:
            numpy.ndarray: 128x128 uint8 array with values 0...127,
            or None if the image could not be loaded/processed (the error is printed).
        """
        try:
            # Load image using PIL (supports many formats); the context manager closes the file.
            with Image.open(image_path) as img:
                print(f"  Loaded image: {img.size} pixels, mode: {img.mode}")

                # Convert to grayscale if necessary
                if img.mode != 'L':
                    gray = img.convert('L')
                    print("  Converted to grayscale")
                else:
                    gray = img

                # Resize to 128x128 (produces an independent image, usable after the file is closed)
                img_resized = gray.resize((self.FRAME_WIDTH, self.FRAME_HEIGHT), resample=Image.BILINEAR)

            # Convert to numpy array and scale 0...255 down to 0...127
            img_array = (np.array(img_resized) / 2).astype(np.uint8)

            # Verify shape (explicit check: asserts are stripped under python -O)
            expected_shape = (self.FRAME_HEIGHT, self.FRAME_WIDTH)
            if img_array.shape != expected_shape:
                raise ValueError(f"unexpected shape {img_array.shape}, expected {expected_shape}")

            print(f"  Image processed: shape={img_array.shape}, dtype={img_array.dtype}")
            print(f"  Pixel value range: {img_array.min()} - {img_array.max()}")

            return img_array

        except Exception as e:
            print(f"  Error processing image {image_path}: {e}")
            return None

    def create_uart_data(self, image_array):
        """
        Convert image array to UART transmission format.

        Args:
            image_array: array-like holding FRAME_HEIGHT*FRAME_WIDTH pixels (e.g. 128x128)
                with values expected in 0...127. Non-uint8 input, such as the float result
                of ``array / 2``, is truncated to uint8.

        Returns:
            bytes: Complete UART frame (header + grayscale pixel bytes in row-major order)

        Raises:
            ValueError: if the array does not contain exactly TOTAL_BYTES pixels.
        """
        # reshape(-1) flattens in logical row-major (C) order regardless of memory layout
        pixels = np.asarray(image_array).astype(np.uint8).reshape(-1)
        if pixels.size != self.TOTAL_BYTES:
            expected_size = len(self.HEADER_BYTES) + self.TOTAL_BYTES
            actual_size = len(self.HEADER_BYTES) + pixels.size
            raise ValueError(f"Data size mismatch: {actual_size} != {expected_size}")

        return bytes(self.HEADER_BYTES) + pixels.tobytes()

    def process_result(self, byte_array):
        """
        Extract the prediction from data returned by the FPGA.

        Args:
            byte_array: reply bytes from the FPGA

        Returns:
            If the reply has one byte per class, a string with the predicted digit
            (index of the highest score) and the scores of all classes.
            Any other reply is returned unchanged.
        """
        if len(byte_array) != self.NUM_CLASSES:
            return byte_array
        scores = list(byte_array)
        best_class = scores.index(max(scores))
        return f"Found number: {best_class} ({','.join(map(str, scores))})"

    def send_frame(self, uart_data, raw=False, verbose = True):
        """
        Send frame data via UART and wait for the FPGA's reply.

        The reply is read until the line has been quiet for one poll interval, so a
        multi-byte reply is not split. Stale input left over from an earlier
        (e.g. timed-out) frame is discarded before sending.

        Args:
            uart_data: Complete frame data including header (or any raw bytes, e.g. a test byte)
            raw: if True return the reply bytes unchanged, otherwise pass them to process_result()
            verbose: print progress and timing messages

        Returns:
            bytes (raw=True) or the result of process_result() on success; otherwise
            one of the strings "ERROR: UART not connected", "TIMEOUT ",
            "UART_ERROR: ..." or "ERROR: ...".
        """
        if not self.serial_conn or not self.serial_conn.is_open:
            return "ERROR: UART not connected"

        # Poll interval, and also the silence that marks the end of a reply:
        # at least 10 ms and at least ~20 byte times at the configured baud rate.
        poll_interval = max(0.01, 200.0 / max(self.baud_rate, 1))

        try:
            # Drop bytes that are not a reply to this frame (e.g. a late reply to a timed-out frame).
            self.serial_conn.reset_input_buffer()

            if verbose:
                print(f"  Sending {len(uart_data)} bytes...")

            # Send data
            start_time = time.time()
            bytes_sent = self.serial_conn.write(uart_data)
            self.serial_conn.flush()  # Ensure all data is sent

            send_time = time.time() - start_time
            if verbose:
                rate = bytes_sent / send_time if send_time > 0 else float('inf')
                print(f"  Sent {bytes_sent} bytes in {send_time:.2f}s ({rate:.0f} bytes/s)")

            # Wait for response
            if verbose:
                print("  Waiting for FPGA response...")
            response = bytearray()
            response_start = time.time()

            while (time.time() - response_start) < self.timeout:
                waiting = self.serial_conn.in_waiting
                if waiting > 0:
                    response += self.serial_conn.read(waiting)
                elif response:
                    # data arrived and the line has gone quiet: reply complete
                    break
                time.sleep(poll_interval)  # avoid busy waiting

            if response:
                data = bytes(response)
                return data if raw else self.process_result(data)

            # Timeout occurred
            elapsed = time.time() - response_start
            if verbose:
                print(f"  Timeout after {elapsed:.2f}s")
            return "TIMEOUT "

        except serial.SerialException as e:
            return f"UART_ERROR: {e}"
        except Exception as e:
            return f"ERROR: {e}"

    def send_image(self, image_path):
        """
        Complete process: load image, process, and send via UART.

        Args:
            image_path: Path to image file

        Returns:
            Result of send_frame() (processed prediction or error/timeout string),
            or "FAILED: Could not process image" if loading failed.
        """
        print(f"\n  Processing image: {image_path}")

        # Load and process image
        image_array = self.load_and_process_image(image_path)
        if image_array is None:
            return "FAILED: Could not process image"

        # Create UART data
        uart_data = self.create_uart_data(image_array)
        print(f"  Created UART frame: {len(uart_data)} bytes")

        # Send via UART
        return self.send_frame(uart_data)

    def process_webcam_frame(self, frame_rgb):
        """
        Process a webcam frame into the protocol's 128x128 grayscale format.

        Args:
            frame_rgb: Raw webcam frame (numpy array accepted by PIL.Image.fromarray)

        Returns:
            numpy.ndarray: 128x128 uint8 array with values 0...127, or None on error.
        """
        try:
            # Convert numpy array to PIL Image
            img = Image.fromarray(frame_rgb)

            # The protocol carries one grayscale byte per pixel
            if img.mode != 'L':
                img = img.convert('L')

            # Resize to 128x128
            img_resized = img.resize((self.FRAME_WIDTH, self.FRAME_HEIGHT))

            # Convert back to numpy array and scale 0...255 down to 0...127
            return (np.array(img_resized) / 2).astype(np.uint8)

        except Exception as e:
            print(f"  Error processing webcam frame: {e}")
            return None

    def send_webcam_frame(self, image_array):
        """
        Scale a 128x128 grayscale frame to 0...127 and send it via UART.

        Args:
            image_array: 128x128 grayscale frame with values 0...255

        Returns:
            Result of send_frame(): processed prediction or an error/timeout string.
        """
        # convert 0...255 to 0...127 (uint8, as the UART frame carries one byte per pixel)
        scaled = (np.asarray(image_array) / 2).astype(np.uint8)

        # Create UART data and send
        uart_data = self.create_uart_data(scaled)
        return self.send_frame(uart_data)

def main():
    """
    Command-line entry point: connect to the FPGA over UART and run an interactive prompt.

    Prompt commands:
        'e', 'q', 'exit', 'quit' -- leave the prompt (a running webcam stream is stopped first)
        '?'                      -- send the test byte 0x24 and print the reply as hex
        'c', 'cam'               -- start the webcam stream in a background thread
        's'                      -- stop the webcam stream
        anything else            -- file name looked up below ./img/ and sent as one frame

    The UART connection is closed on every exit path, including Ctrl+C or EOF
    at the prompt and errors while sending an image.
    """
    parser = argparse.ArgumentParser(
        description="Send image to ONEAI CNN via UART",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python uart_image_sender.py
    python uart_image_sender.py --port COM3 --baud 115200
    python uart_image_sender.py --port /dev/ttyACM0 --timeout 60
    python uart_image_sender.py --camera 1
    python uart_image_sender.py --list-ports

Image files are entered at the interactive prompt and are looked up in ./img/.
        """
    )

    parser.add_argument('--port',
                       default='/dev/ttyUSB0',
                       help='Serial port (default: /dev/ttyUSB0)')
    parser.add_argument('--baud',
                       type=int,
                       default=115200,
                       help='Baud rate (default: 115200)')
    parser.add_argument('--timeout',
                       type=int,
                       default=3,
                       help='Response timeout in seconds (default: 3)')
    parser.add_argument('--list-ports',
                       action='store_true',
                       help='List available serial ports')
    parser.add_argument('--camera',
                       type=int,
                       default=0,
                       help='Camera index (default: 0)')

    args = parser.parse_args()

    # List available ports if requested
    if args.list_ports:
        try:
            import serial.tools.list_ports
            ports = serial.tools.list_ports.comports()
            print("Available serial ports:")
            for port in ports:
                print(f"  {port.device} - {port.description}")
        except ImportError:
            print("pyserial not installed. Install with: pip install pyserial")
        return

    # Create UART sender
    sender = UARTImageSender(
        port=args.port,
        baud_rate=args.baud,
        timeout=args.timeout
    )

    # Create Webcam Object (using the camera selected on the command line)
    webcam = WebcamCapture(sender, camera_index=args.camera)

    try:
        # Connect to UART
        if not sender.connect():
            print("  No Connection...")
            sys.exit(1)

    except Exception as e:
        print(f"  Unexpected error: {e}")
        sys.exit(1)

    print("Possible commands: 'e' or 'q' to exit, 'c' or 'cam' for camera mode, 's' to stop the camera, "
          "'?' to send a test byte. Everything else is interpreted as a file name in ./img/.")

    stream_thread = None

    def stream_running():
        # True while the background webcam thread is still alive.
        return stream_thread is not None and stream_thread.is_alive()

    try:
        while True:

            # Taking image path or command from user input; Ctrl+C / EOF leave the prompt cleanly
            try:
                uinput = input("Enter image path or command: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nExiting...")
                break
            command = uinput.lower()

            # exit on user request
            if command in ['exit', 'quit', 'q', 'e']:
                print("Exiting...")
                break

            # Stop webcam stream.
            if command == 's':
                webcam.run = False
                continue

            # Start webcam streaming (only one stream thread at a time)
            if command in ['c', 'cam']:
                if stream_running():
                    print("  Webcam stream is already running. Enter 's' to stop stream.")
                    continue
                print("  Webcam streaming started. Enter 's' to stop stream.")
                webcam.run = True
                stream_thread = Thread(target=webcam.stream, daemon=True)
                stream_thread.start()
                continue

            # The stream thread owns the serial port while it runs; concurrent
            # writes from here would interleave with its frames.
            if stream_running():
                print("  Webcam stream is running. Enter 's' to stop it before using the UART.")
                continue

            # send test byte to FPGA. Should return "42".
            if command == '?':
                result = sender.send_frame(bytearray([0x24]))
                # pyserial returns bytes, not bytearray
                if isinstance(result, (bytes, bytearray)):
                    result = result.hex()
                print(f"\n  Returned Response: {result}")
                continue

            if not uinput:
                continue

            # Validate image file exists
            image_path = Path("img") / uinput
            if not image_path.is_file():
                print(f"  Image file not found: {image_path}")
                continue

            try:
                # Send image
                result = sender.send_image(image_path)

                # Print final result
                print(f"\n  Returned Result: {result}")

            except KeyboardInterrupt:
                print("\n   Interrupted by user")
                sys.exit(1)
            except Exception as e:
                print(f"  Unexpected error: {e}")
                sys.exit(1)
    finally:
        # Stop the stream thread before closing the port it is using, then disconnect UART
        webcam.run = False
        if stream_thread is not None:
            stream_thread.join()
        sender.disconnect()

if __name__ == "__main__":
    main()