#!/usr/bin/env python3
"""
Full Working Speech-to-Text Server
Real-time transcription with FastAPI, WebSockets, and Whisper
"""

import asyncio
import base64
import io
import json
import logging
import os
import tempfile
import threading
import wave
from queue import Queue
from typing import Dict, Optional, Set

import numpy as np

# Optional Chinese-optimisation config. When the local `chinese_config` module
# is importable it supplies `get_chinese_config()` and `CHINESE_TEST_PHRASES`;
# otherwise CHINESE_CONFIG_AVAILABLE is False and code that checks the flag
# falls back to built-in defaults.
try:
    from chinese_config import get_chinese_config, CHINESE_TEST_PHRASES
    CHINESE_CONFIG_AVAILABLE = True
except ImportError:
    CHINESE_CONFIG_AVAILABLE = False

# Optional `ffmpeg` Python module, used to decode compressed browser audio.
# FFMPEG_AVAILABLE only records that the module imported.
# NOTE(review): presumably the ffmpeg-python bindings, which also need the
# ffmpeg executable at runtime - confirm against the project's requirements.
try:
    import ffmpeg
    FFMPEG_AVAILABLE = True
except ImportError:
    FFMPEG_AVAILABLE = False

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
import uvicorn

# Optional faster-whisper backend. WHISPER_AVAILABLE records whether the import
# succeeded; the outcome is printed at import time (before logging is
# configured below).
try:
    from faster_whisper import WhisperModel
    WHISPER_AVAILABLE = True
    print("✓ Faster-Whisper available - using real AI transcription")
except ImportError:
    WHISPER_AVAILABLE = False
    print("✗ Faster-Whisper not available - using mock transcription")

# Configure root logging at INFO and create the module-level logger used by
# every class and function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class WhisperTranscriptionService:
    def __init__(self):
        """Start with no model, read the optional Chinese config, then load a model.

        ``chinese_config`` is whatever ``get_chinese_config()`` returns when the
        config module was importable, and ``None`` otherwise.
        """
        self.model = None
        self.model_name = None
        if CHINESE_CONFIG_AVAILABLE:
            self.chinese_config = get_chinese_config()
        else:
            self.chinese_config = None
        self.initialize_model()

    def initialize_model(self):
        """Load the first usable Whisper model, trying candidates in priority order.

        Candidates come from ``self.chinese_config["model_priority"]`` when a
        Chinese config is present (assumed to be an iterable of
        ``(model_name, description)`` pairs, since the loop unpacks two values),
        otherwise from a built-in list ordered from largest to smallest model.

        On success ``self.model`` and ``self.model_name`` are set. If nothing
        can be loaded ``self.model`` stays ``None`` and an error is logged.
        This method never raises.
        """
        if not WHISPER_AVAILABLE:
            logger.warning("Whisper not available, using mock transcription")
            return

        try:
            # Ensure models directory exists
            os.makedirs("models", exist_ok=True)

            # Use Chinese config if available, otherwise fallback
            if self.chinese_config:
                model_options = self.chinese_config["model_priority"]
                logger.info("🇨🇳 Using Chinese optimized model configuration")
            else:
                model_options = [
                    ("large-v3", "Best quality - Chinese optimized"),
                    ("large-v2", "High quality - Good Chinese support"),
                    ("medium", "Balanced - Good Chinese performance"),
                    ("small", "Fast - Decent Chinese support"),
                    ("base", "Fastest - Basic Chinese support"),
                ]
                logger.info("Using fallback model configuration")

            for model_name, description in model_options:
                try:
                    if self._load_model_candidate(model_name, description):
                        break
                except Exception as e:
                    # Anything unexpected for one candidate must not stop the
                    # search through the remaining ones.
                    logger.warning(f"Skipping {model_name}: {e}")

            if self.model is None:
                logger.error("Failed to load any Whisper model: No Whisper model could be loaded")

        except Exception as e:
            # Covers failures outside the per-candidate handling, e.g. creating
            # the models directory or reading the config.
            logger.error(f"Failed to load any Whisper model: {e}")
            self.model = None

    def _load_model_candidate(self, model_name, description):
        """Try to load one model candidate; return True when ``self.model`` was set.

        A copy found under ``models/<model_name>`` is tried on GPU (float16)
        first and then on CPU (int8). When no local copy is found the model is
        requested from ``WhisperModel`` by name with ``download_root="models"``.
        """
        actual_model_path = self._find_actual_model_path(f"models/{model_name}")

        if actual_model_path:
            logger.info(f"✓ Found existing model at: {actual_model_path}")
            attempts = (("cuda", "float16", "GPU"), ("cpu", "int8", "CPU"))
            for device, compute_type, label in attempts:
                try:
                    model = WhisperModel(actual_model_path, device=device, compute_type=compute_type)
                except Exception as e:
                    if device == "cuda":
                        # Previously swallowed silently; the reason is useful
                        # when diagnosing why a GPU is not being used.
                        logger.info(f"GPU load failed for {model_name} ({e}); falling back to CPU")
                    else:
                        logger.warning(f"Failed to load existing model {model_name}: {e}")
                    continue
                self.model = model
                self.model_name = model_name
                logger.info(f"✓ Whisper {model_name} loaded successfully on {label} - {description}")
                return True
            return False

        logger.info(f"Model {model_name} not found locally, downloading to models/{model_name}/...")
        # Create the model-specific directory
        os.makedirs(f"models/{model_name}", exist_ok=True)

        try:
            self.model = WhisperModel(model_name, device="auto", compute_type="auto", download_root="models")
        except Exception as e:
            logger.warning(f"Failed to download/load {model_name}: {e}")
            return False

        self.model_name = model_name
        logger.info(f"✓ Whisper {model_name} downloaded and loaded successfully - {description}")

        # Check organization of downloaded files
        self._organize_downloaded_model(model_name)
        return True
    def _find_actual_model_path(self, model_path: str) -> Optional[str]:
        """Return the directory under *model_path* that holds the model files.

        Args:
            model_path: Directory to inspect.

        Returns:
            *model_path* itself when it directly contains ``model.bin``;
            otherwise the first directory found by walking *model_path* that
            contains both ``model.bin`` and ``config.json``; ``None`` when no
            such directory exists or an error occurs while searching.
        """
        try:
            # Flat layout: model.bin sits directly in the given directory.
            if os.path.exists(os.path.join(model_path, "model.bin")):
                return model_path

            # Nested layout (existing downloads): search the tree below it.
            for root, _dirs, files in os.walk(model_path):
                if "model.bin" in files and os.path.exists(os.path.join(root, "config.json")):
                    logger.info(f"Found model files in nested structure: {root}")
                    return root

            return None
        except Exception as e:
            logger.warning(f"Error finding model path: {e}")
            return None

    def _organize_downloaded_model(self, model_name: str):
        """Report where the files of a downloaded model ended up.

        Looks under ``models/<model_name>`` and only logs the outcome: either
        the model is already laid out flat, or it lives in a nested directory
        that is left untouched. No files are moved. Errors are logged as
        warnings and never propagated.
        """
        try:
            base_dir = f"models/{model_name}"
            located = self._find_actual_model_path(base_dir)

            if located:
                if os.path.exists(os.path.join(base_dir, "model.bin")):
                    # Flat layout already in place; nothing to report beyond this.
                    logger.info(f"Model {model_name} is already organized")
                elif located != base_dir:
                    # The nested structure works as-is, so it is kept; users
                    # can flatten it manually if they prefer.
                    logger.info(f"Model files found in nested structure at: {located}")
                    logger.info(f"Model {model_name} will be loaded from existing location")

        except Exception as e:
            logger.warning(f"Model organization check failed for {model_name}: {e}")

    def transcribe_audio(self, audio_data: bytes, is_first_chunk: bool = True) -> str:
        """Transcribe one chunk of audio and return the recognised text.

        The chunk is decoded to PCM with ``_decode_audio_data``, written to a
        temporary mono / 16-bit / 16 kHz WAV file and handed to the loaded
        Whisper model. When no model is loaded, ``_mock_transcription`` is
        used instead.

        Args:
            audio_data: Audio bytes received from the client.
            is_first_chunk: Forwarded unchanged to ``_decode_audio_data``.

        Returns:
            The stripped transcription. An empty string is returned when the
            chunk is shorter than 1000 bytes, decodes to an empty sample
            array, or when transcription fails while a real model is loaded
            (a canned mock phrase is never reported as a real result).
        """
        temp_path = None
        try:
            if not audio_data or len(audio_data) < 1000:  # Too short
                logger.info(f"Audio too short: {len(audio_data) if audio_data else 0} bytes")
                return ""

            # First, try to decode the compressed audio data
            processed_audio = self._decode_audio_data(audio_data, is_first_chunk)
            if processed_audio is None:
                logger.error("Failed to decode audio data")
                return ""

            # Analyze processed audio data for debugging. `audio_array` stays
            # None if the analysis itself fails (e.g. odd byte count).
            audio_array = None
            try:
                audio_array = np.frombuffer(processed_audio, dtype=np.int16)
                if len(audio_array) > 0:
                    rms = np.sqrt(np.mean(audio_array.astype(float) ** 2))
                    max_amplitude = np.max(np.abs(audio_array))
                    min_amplitude = np.min(audio_array)
                    duration_samples = len(audio_array)
                    duration_seconds = duration_samples / 16000  # Assuming 16kHz

                    logger.info(f"Processed audio - RMS: {rms:.2f}, Max: {max_amplitude}, Min: {min_amplitude}, "
                               f"Samples: {duration_samples}, Duration: {duration_seconds:.3f}s, "
                               f"Data bytes: {len(processed_audio)}")
                else:
                    logger.warning("Empty audio array after decoding")
                    return ""
            except Exception as e:
                logger.warning(f"Audio analysis failed: {e}")

            if self.model is None:
                return self._mock_transcription(processed_audio)

            # Reserve a temp path and close the handle before reopening the
            # file by name; removal happens in the `finally` below so the file
            # is cleaned up on every exit path.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                temp_path = temp_file.name

            # Write processed audio as WAV
            with wave.open(temp_path, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(16000)  # 16kHz
                wav_file.writeframes(processed_audio)

            # Check WAV file properties and save sample for debugging
            try:
                with wave.open(temp_path, 'rb') as wav_check:
                    frames = wav_check.getnframes()
                    sample_rate = wav_check.getframerate()
                    channels = wav_check.getnchannels()
                    sample_width = wav_check.getsampwidth()
                    duration = frames / float(sample_rate)

                    logger.info(f"WAV file created - Frames: {frames}, Rate: {sample_rate}Hz, "
                               f"Channels: {channels}, Width: {sample_width}, Duration: {duration:.3f}s")

                # Save first few samples for debugging (only once)
                debug_file = "debug_audio_sample.wav"
                if not os.path.exists(debug_file):
                    import shutil
                    shutil.copy2(temp_path, debug_file)
                    logger.info(f"Saved audio sample to {debug_file} for debugging")

                    # Also log first few audio samples
                    if audio_array is not None:
                        first_samples = audio_array[:20]
                        logger.info(f"First 20 audio samples: {first_samples.tolist()}")

            except Exception as e:
                logger.error(f"WAV file check failed: {e}")

            # Transcribe with Whisper - Use Chinese optimized parameters
            if self.chinese_config:
                transcribe_params = self.chinese_config["transcription_params"]
                logger.info(f"🇨🇳 Using Chinese optimized transcription with model: {self.model_name}")
            else:
                # Fallback parameters
                transcribe_params = {
                    "beam_size": 5,
                    "best_of": 5,
                    "language": "zh",
                    "task": "transcribe",
                    "condition_on_previous_text": False,
                    "compression_ratio_threshold": 2.4,
                    "log_prob_threshold": -1.0,
                    "no_speech_threshold": 0.6,
                    "initial_prompt": "以下是中文语音转录:",
                    "vad_filter": True,
                    "vad_parameters": dict(min_silence_duration_ms=500)
                }

            segments, info = self.model.transcribe(temp_path, **transcribe_params)

            # Log transcription info
            logger.info(f"Whisper info: language={info.language}, probability={info.language_probability:.3f}")

            # Combine all segments and log each one. The segments are consumed
            # here, while the temp file still exists.
            segment_texts = []
            for i, segment in enumerate(segments):
                logger.info(f"Segment {i}: [{segment.start:.2f}-{segment.end:.2f}s] '{segment.text}' (prob: {segment.avg_logprob:.3f})")
                segment_texts.append(segment.text.strip())

            transcription = " ".join(segment_texts)
            logger.info(f"Final transcription: '{transcription}' (length: {len(transcription)})")

            return transcription.strip()

        except Exception as e:
            logger.error(f"Transcription error: {e}")
            if self.model is None:
                # Mock mode: keep the previous fallback behaviour.
                return self._mock_transcription(audio_data)
            # A real model is loaded: report "no text" rather than a canned
            # phrase that the client would take for a genuine transcription.
            return ""

        finally:
            # Clean up temp file on every path, including failures.
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except OSError as e:
                    logger.warning(f"Could not remove temp file {temp_path}: {e}")

    def _mock_transcription(self, audio_data: bytes) -> str:
        """Return a canned phrase for audible input; used when no model is loaded.

        The bytes are read as 16-bit samples. Empty input or an RMS level
        below 500 yields ``""``. Otherwise a phrase is picked from
        ``CHINESE_TEST_PHRASES`` (when the config is available and non-empty)
        or from a built-in list, indexed by ``int(rms / 1000)`` modulo the
        list length. Any error yields ``"Audio processing completed"``.
        """
        try:
            samples = np.frombuffer(audio_data, dtype=np.int16)
            if len(samples) == 0:
                return ""

            # Simple voice activity detection
            level = np.sqrt(np.mean(samples.astype(float) ** 2))
            if level < 500:
                return ""

            builtin_phrases = [
                "这是一个语音识别系统的测试。",
                "实时转录功能运行正常。",
                "检测到语音活动并已处理。",
                "音频成功转换为文字。",
                "语音转文字系统运行正常。",
                "中文语音识别正在工作。",
                "智能语音处理系统已激活。"
            ]
            # Use Chinese test phrases from config if available
            use_configured = CHINESE_CONFIG_AVAILABLE and CHINESE_TEST_PHRASES
            phrases = CHINESE_TEST_PHRASES if use_configured else builtin_phrases

            return phrases[int(level / 1000) % len(phrases)]

        except Exception:
            return "Audio processing completed"

    def _decode_audio_data(self, audio_data: bytes, is_first_chunk: bool = True) -> bytes:
        """Turn incoming audio bytes into raw PCM, falling back to the input itself.

        FFmpeg decoding is attempted first when the ffmpeg module is
        available. If that yields nothing, the original bytes are returned
        unchanged; the raw-PCM heuristic only decides which message is logged.
        ``is_first_chunk`` is accepted but not used here.
        """
        try:
            # Always try FFmpeg decoding first for browser audio
            decoded = self._decode_with_ffmpeg(audio_data) if FFMPEG_AVAILABLE else None
            if decoded:
                return decoded

            if self._looks_like_raw_pcm(audio_data):
                logger.info("Data appears to be raw PCM - using directly")
            else:
                # Last resort - assume raw PCM
                logger.warning("Could not decode audio - assuming raw PCM data")
            return audio_data
        except Exception as e:
            logger.error(f"Audio decoding failed: {e}")
            return audio_data

    def _decode_with_ffmpeg_opus(self, audio_data: bytes) -> Optional[bytes]:
        """Decode Opus audio to 16 kHz mono 16-bit PCM using FFmpeg.

        The bytes are written to a temporary ``.opus`` file, converted to WAV
        (with a ``loudnorm`` filter) and the PCM frames are read back. If
        FFmpeg rejects the input as Opus, ``_decode_with_ffmpeg`` is tried
        instead.

        Returns:
            The PCM bytes, or ``None`` when decoding fails. Both temporary
            files are removed on every path.
        """
        input_path = None
        output_path = None
        try:
            # Create temp file as Opus. The path is recorded before writing so
            # the file is still cleaned up if the write itself fails.
            with tempfile.NamedTemporaryFile(suffix=".opus", delete=False) as input_file:
                input_path = input_file.name
                input_file.write(audio_data)

            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as output_file:
                output_path = output_file.name

            try:
                # Decode Opus to PCM
                (
                    ffmpeg
                    .input(input_path, f='opus')  # Treat as raw Opus
                    .output(output_path,
                           acodec='pcm_s16le',
                           ac=1,
                           ar=16000,
                           af='loudnorm=I=-16:LRA=11:TP=-1.5',
                           f='wav')
                    .overwrite_output()
                    .run(capture_stdout=True, capture_stderr=True, quiet=False)
                )
            except ffmpeg.Error:
                # If Opus decoding fails, try as WebM chunk
                return self._decode_with_ffmpeg(audio_data)

            with wave.open(output_path, 'rb') as wav_file:
                pcm_data = wav_file.readframes(wav_file.getnframes())

            logger.info(f"FFmpeg Opus decoded: {len(audio_data)} -> {len(pcm_data)} bytes")
            return pcm_data

        except Exception as e:
            logger.error(f"Opus decoding failed: {e}")
            return None

        finally:
            # Remove each temp file independently so that a failure on one
            # does not leave the other behind.
            for path in (input_path, output_path):
                if path is None:
                    continue
                try:
                    os.unlink(path)
                except OSError as e:
                    logger.warning(f"Could not remove temp file {path}: {e}")

    def _looks_like_webm(self, audio_data: bytes) -> bool:
        """Tell whether *audio_data* begins with the WebM/Matroska signature."""
        # WebM/Matroska starts with 0x1A 0x45 0xDF 0xA3
        signature = b'\x1a\x45\xdf\xa3'
        return len(audio_data) >= 4 and audio_data.startswith(signature)

    def _looks_like_raw_pcm(self, audio_data: bytes) -> bool:
        """Guess whether *audio_data* is headerless PCM rather than a container.

        Returns ``False`` for inputs shorter than 16 bytes and for data whose
        first bytes match a known signature (WebM/Matroska, the MediaRecorder
        chunk prefix, MP4 ``ftyp`` or three leading zero bytes, OGG, RIFF/WAV).
        Returns ``True`` when none of them match.
        """
        if len(audio_data) < 16:
            return False

        prefix = audio_data[:16]

        # WebM/Matroska starts with 0x1A 0x45 0xDF 0xA3
        if prefix.startswith(b'\x1a\x45\xdf\xa3'):
            return False

        # MediaRecorder chunks often start with 0x1F 0x43 0xB6 0x75 (custom container)
        if prefix.startswith(b'\x1f\x43\xb6\x75'):
            logger.info("Detected MediaRecorder container format - NOT raw PCM")
            return False

        # MP4 ('ftyp' near the start or three leading zero bytes), OGG ('OggS')
        # and WAV ('RIFF') are all treated as containers.
        has_container_signature = (
            b'ftyp' in prefix[:8]
            or prefix.startswith((b'\x00\x00\x00', b'OggS', b'RIFF'))
        )
        if has_container_signature:
            return False

        # If none of the above, likely raw PCM
        logger.info("No compressed format signatures detected - assuming raw PCM")
        return True

    def _decode_with_ffmpeg(self, audio_data: bytes) -> Optional[bytes]:
        """Decode compressed audio to 16 kHz mono 16-bit PCM using FFmpeg.

        The bytes are written to a temporary file whose extension is chosen
        from the header (WebM, MP4, OGG, or none). FFmpeg is then run with
        format auto-detection, and on failure retried with the input forced
        to ``webm`` and then ``matroska``.

        Returns:
            The decoded PCM bytes, or ``None`` when every attempt fails or any
            other error occurs. Both temporary files are removed on every path.
        """
        input_path = None
        output_path = None
        try:
            # Log incoming data info
            logger.info(f"FFmpeg input: {len(audio_data)} bytes")

            # Try to detect format from first few bytes
            header = audio_data[:16]
            logger.info(f"Audio header bytes: {[hex(b) for b in header]}")

            # Determine file extension based on header
            file_ext = ""
            if header.startswith(b'\x1a\x45\xdf\xa3'):  # WebM
                file_ext = ".webm"
                logger.info("Detected WebM format")
            elif b'ftyp' in header[:8]:  # MP4
                file_ext = ".mp4"
            elif header.startswith(b'OggS'):  # OGG
                file_ext = ".ogg"

            # Create temporary input file with proper extension. The path is
            # recorded before writing so cleanup still happens if the write fails.
            with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as input_file:
                input_path = input_file.name
                input_file.write(audio_data)

            # Save the first chunk carrying the MediaRecorder prefix for analysis
            debug_chunk_path = "debug_failing_chunk.bin"
            if not os.path.exists(debug_chunk_path) and header.startswith(b'\x1f\x43\xb6\x75'):
                with open(debug_chunk_path, 'wb') as debug_file:
                    debug_file.write(audio_data)
                logger.info(f"Saved failing chunk to {debug_chunk_path} for analysis")

            # Create temporary output file for PCM
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as output_file:
                output_path = output_file.name

            # First try auto-detection, then force WebM, then Matroska.
            last_error = None
            for input_format in (None, 'webm', 'matroska'):
                input_kwargs = {} if input_format is None else {'f': input_format}
                try:
                    (
                        ffmpeg
                        .input(input_path, **input_kwargs)
                        .output(output_path,
                               acodec='pcm_s16le',  # 16-bit PCM
                               ac=1,                # Mono
                               ar=16000,            # 16kHz sample rate
                               f='wav')             # Force WAV output
                        .overwrite_output()
                        .run(capture_stdout=True, capture_stderr=True, quiet=True)
                    )
                except ffmpeg.Error as e:
                    last_error = e
                else:
                    last_error = None
                    break

            if last_error is not None:
                # Every attempt failed; report the output of the last one.
                logger.error(f"FFmpeg error - stdout: {last_error.stdout}")
                logger.error(f"FFmpeg error - stderr: {last_error.stderr}")
                return None

            # Read the converted PCM data
            with wave.open(output_path, 'rb') as wav_file:
                pcm_data = wav_file.readframes(wav_file.getnframes())

            logger.info(f"FFmpeg decoded: {len(audio_data)} -> {len(pcm_data)} bytes")
            return pcm_data

        except Exception as e:
            logger.error(f"FFmpeg decoding failed: {e}")
            import traceback
            logger.error(f"Full traceback: {traceback.format_exc()}")
            return None

        finally:
            # Clean up temporary files independently so that a failure on one
            # does not leave the other behind.
            for path in (input_path, output_path):
                if path is None:
                    continue
                try:
                    os.unlink(path)
                except OSError as e:
                    logger.warning(f"Could not remove temp file {path}: {e}")

class ConnectionManager:
    """Tracks connected WebSocket clients and routes their audio to the transcriber."""

    def __init__(self):
        # Currently connected clients.
        self.active_connections: Set[WebSocket] = set()
        # Single transcription service shared by every connection.
        self.transcription_service = WhisperTranscriptionService()
        # Whether each websocket is in the middle of a recording session.
        self.recording_sessions: Dict[WebSocket, bool] = {}
        # Number of audio chunks processed per websocket in the current session.
        self.chunk_counts: Dict[WebSocket, int] = {}

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the client."""
        await websocket.accept()
        self.active_connections.add(websocket)
        logger.info(f"Client connected. Total connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        """Forget a client and its per-connection state; safe to call repeatedly."""
        self.active_connections.discard(websocket)
        for per_socket_state in (self.recording_sessions, self.chunk_counts):
            per_socket_state.pop(websocket, None)
        logger.info(f"Client disconnected. Total connections: {len(self.active_connections)}")

    async def send_personal_message(self, message: dict, websocket: WebSocket):
        """Send *message* as JSON text to one client; drop the client on failure."""
        try:
            payload = json.dumps(message)
            await websocket.send_text(payload)
        except Exception as e:
            logger.error(f"Error sending message: {e}")
            self.disconnect(websocket)

    async def process_audio(self, audio_data: bytes, websocket: WebSocket):
        """Transcribe one audio chunk and send the text back to the client.

        Transcription runs in the default executor so the event loop is not
        blocked. Nothing is sent when the transcription is empty; on any
        error an ``error`` message is sent instead.
        """
        try:
            # Track chunk count for this websocket
            chunks_seen = self.chunk_counts.get(websocket, 0)
            self.chunk_counts[websocket] = chunks_seen + 1

            # Run transcription in thread to avoid blocking
            loop = asyncio.get_event_loop()
            text = await loop.run_in_executor(
                None,
                self.transcription_service.transcribe_audio,
                audio_data,
                chunks_seen == 0,
            )
            if not text:
                return

            reply = {
                'type': 'transcription',
                'text': text,
                'timestamp': loop.time(),
            }
            await self.send_personal_message(reply, websocket)
            logger.info(f"Transcribed: {text}")

        except Exception as e:
            logger.error(f"Audio processing error: {e}")
            failure = {
                'type': 'error',
                'message': f'Audio processing failed: {str(e)}',
            }
            await self.send_personal_message(failure, websocket)

# Initialize FastAPI app
app = FastAPI(title="Speech-to-Text Server")

# Mount static files. The directory is anchored to this file's location rather
# than the current working directory: this statement runs at import time,
# before the os.chdir() in the __main__ guard, so a CWD-relative "static"
# would not be found when the script is launched from another directory.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")),
    name="static",
)

# Connection manager shared by all endpoints. Constructing it also builds the
# transcription service, which attempts to load a Whisper model.
manager = ConnectionManager()

@app.get("/")
async def get_index():
    """Return the main page, read from ``templates/index.html``."""
    index_page = "templates/index.html"
    return FileResponse(index_page)

@app.get("/test")
async def get_test():
    """Return the test page, read from ``templates/test.html``."""
    test_page = "templates/test.html"
    return FileResponse(test_page)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one WebSocket client until it disconnects.

    Text frames must be JSON objects with a ``type`` field:

    * ``audio`` - ``data`` holds base64-encoded audio, which is transcribed.
    * ``start_recording`` / ``stop_recording`` - update the session flag,
      reset the chunk counter and acknowledge with a ``status`` message.

    Binary frames are treated as audio and transcribed directly. Malformed
    JSON, non-object JSON and invalid base64 produce an ``error`` reply and
    leave the connection open. The client is always unregistered on exit.
    """
    await manager.connect(websocket)

    try:
        while True:
            # Receive data from client
            data = await websocket.receive()

            # The low-level receive() hands back the disconnect message
            # instead of raising, so a clean close has to be detected here.
            if data.get('type') == 'websocket.disconnect':
                break

            # Either key may be absent or present with a None value.
            text_payload = data.get('text')
            binary_payload = data.get('bytes')

            if text_payload is not None:
                # Handle JSON messages
                try:
                    message = json.loads(text_payload)
                except json.JSONDecodeError as e:
                    error_response = {'type': 'error', 'message': f'Invalid JSON: {str(e)}'}
                    await manager.send_personal_message(error_response, websocket)
                    continue

                if not isinstance(message, dict):
                    error_response = {'type': 'error', 'message': 'Invalid message: expected a JSON object'}
                    await manager.send_personal_message(error_response, websocket)
                    continue

                message_type = message.get('type', '')

                if message_type == 'audio':
                    # Handle base64 encoded audio data
                    audio_b64 = message.get('data', '')
                    if audio_b64:
                        try:
                            audio_data = base64.b64decode(audio_b64)
                        except (TypeError, ValueError) as e:
                            # binascii.Error is a ValueError subclass.
                            error_response = {'type': 'error', 'message': f'Invalid base64 audio data: {str(e)}'}
                            await manager.send_personal_message(error_response, websocket)
                            continue
                        await manager.process_audio(audio_data, websocket)

                elif message_type == 'start_recording':
                    manager.recording_sessions[websocket] = True
                    manager.chunk_counts[websocket] = 0  # Reset chunk count
                    response = {'type': 'status', 'message': 'Recording started'}
                    await manager.send_personal_message(response, websocket)

                elif message_type == 'stop_recording':
                    manager.recording_sessions[websocket] = False
                    manager.chunk_counts[websocket] = 0  # Reset chunk count
                    response = {'type': 'status', 'message': 'Recording stopped'}
                    await manager.send_personal_message(response, websocket)

            elif binary_payload is not None:
                # Handle binary audio data directly
                await manager.process_audio(binary_payload, websocket)

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
    finally:
        manager.disconnect(websocket)

@app.get("/health")
async def health_check():
    """Report liveness, Whisper availability and the number of connected clients."""
    connection_count = len(manager.active_connections)
    return {
        "status": "healthy",
        "whisper_available": WHISPER_AVAILABLE,
        "active_connections": connection_count,
    }

@app.get("/model-info")
async def model_info():
    """Describe the loaded model and, when present, the Chinese configuration."""
    service = manager.transcription_service
    payload = {
        "whisper_available": WHISPER_AVAILABLE,
        "chinese_config_available": CHINESE_CONFIG_AVAILABLE,
        "model_loaded": service.model is not None,
        "model_name": service.model_name,
        "language": "Chinese (zh)",
        "optimization": "Chinese language optimized",
    }

    chinese_config = service.chinese_config
    if chinese_config:
        # Expose only these three sections of the config.
        exposed_sections = ("transcription_params", "audio_config", "chinese_settings")
        payload["configuration"] = {section: chinese_config[section] for section in exposed_sections}

    return payload

def run_server(host="0.0.0.0", port=8000, ssl=False):
    """Print a startup banner and run the app under uvicorn.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
        ssl: When truthy, serve over TLS using ``key.pem`` and ``cert.pem``.
    """
    if ssl:
        protocol, ws_protocol = "https", "wss"
        transport_line = 'HTTPS enabled (SSL/TLS)'
    else:
        protocol, ws_protocol = "http", "ws"
        transport_line = 'HTTP mode'

    # Get model information
    service = manager.transcription_service
    if WHISPER_AVAILABLE and service.model is not None:
        loaded_name = service.model_name or "unknown"
        model_line = f"Whisper Model: {loaded_name} (中文模式)"
    else:
        model_line = "Mock transcription (install faster-whisper for AI)"

    print(f"""
🎤 中文语音识别服务器 / Chinese Speech-to-Text Server
====================================================

✓ FastAPI backend with WebSocket support
✓ {model_line}
✓ Real-time Chinese audio processing
✓ Browser microphone integration
✓ Production-ready architecture
✓ {transport_line}

Server: {protocol}://{host}:{port}
WebSocket: {ws_protocol}://{host}:{port}/ws
Health: {protocol}://{host}:{port}/health
Language: Chinese (zh)

Press Ctrl+C to stop
""")

    uvicorn_kwargs = dict(
        app=app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )
    if ssl:
        uvicorn_kwargs["ssl_keyfile"] = "key.pem"
        uvicorn_kwargs["ssl_certfile"] = "cert.pem"

    uvicorn.run(**uvicorn_kwargs)

if __name__ == "__main__":
    import argparse

    # Work from the directory containing this script so that relative paths
    # used while serving resolve next to it.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    cli = argparse.ArgumentParser(description='Speech-to-Text Server')
    cli.add_argument('--ssl', action='store_true', help='Enable HTTPS/SSL')
    cli.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    cli.add_argument('--port', type=int, default=8000, help='Port to bind to')
    options = cli.parse_args()

    run_server(host=options.host, port=options.port, ssl=options.ssl)