package io.scanbot.tools.mediafpsplayer

import android.graphics.Bitmap
import android.media.MediaMetadataRetriever
import io.scanbot.sdk.image.ImageRef
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull
import java.io.File
import kotlin.time.DurationUnit
import kotlin.time.toDuration

/**
 * [VideoProcessor] that samples [input] at a fixed [fps] using [MediaMetadataRetriever.getFrameAtTime].
 *
 * Frames are requested at timestamps `0, step, 2 * step, ...` ms, where `step = (1000 / fps)` truncated to whole
 * milliseconds, up to and including the reported video duration.
 *
 * Failure handling: a frame that fails to extract is retried once at the same timestamp. If it fails again it is
 * skipped and counted against [extractionFailureMaxRetry]. That counter is never reset, so it is an error budget
 * for the whole video, not per frame. Once the budget is exhausted, every further [nextFrame] call throws
 * [FrameExtractionException].
 *
 * @param input video file to read; it is opened in the constructor.
 * @param fps sampling rate in frames per second. Must be in `1..1000` so the per-frame step is at least 1 ms.
 * @param frameExtractionOption option forwarded to [MediaMetadataRetriever.getFrameAtTime].
 * @param extractorListener optional callbacks for progress, debug frames and failures.
 * @param extractionTimeoutMs timeout passed to [withTimeoutOrNull] around each extraction.
 * @param extractionFailureMaxRetry number of skipped (twice-failed) frames tolerated before giving up.
 * @throws IllegalArgumentException if [fps] is outside `1..1000`.
 * @throws IllegalStateException if the video duration cannot be read from the file metadata.
 */
class VideoProcessorV2(
    input: File,
    fps: Int,
    @FrameExtractionOption
    private val frameExtractionOption: Int = MediaMetadataRetriever.OPTION_NEXT_SYNC,
    private val extractorListener: FrameExtractorListener? = null,
    private val extractionTimeoutMs: Long = EXTRACTION_TIMEOUT_MS,
    private val extractionFailureMaxRetry: Int = EXTRACTION_FAILURE_MAX_RETRY_DEFAULT,
) : VideoProcessor {

    init {
        // Validate before the native retriever below is allocated:
        // - fps > 1000 truncates the step to 0 ms, which divides by zero when computing totalFrames;
        // - fps <= 0 yields a meaningless (infinite or negative) step.
        require(fps in 1..1_000) { "fps must be in 1..1000, was $fps" }
    }

    private val fpsPlayer = MediaMetadataRetriever()

    /** Video length in milliseconds, read from the container metadata in `init`. */
    private val videoDuration: Long

    /** Milliseconds between two sampled frames (truncated, so later timestamps drift slightly early). */
    private val oneFrameTimeStep = (1_000.0 / fps).toLong() // ms per frame
    private val totalFrames: Int
    private val videoOrientation: Int

    /** Index and timestamp (ms) of the frame the next extraction attempt will target. */
    private var lastFrameIndex: Int = 0
    private var lastFrameTimestamp = 0L

    /** True after the current frame failed once; a second failure skips it. */
    private var currentFrameIsBad = false

    /** Skipped frames so far; never reset, compared against [extractionFailureMaxRetry]. */
    private var frameExtractFailureCount = 0

    /**
     * Whether the target frame index is still below [totalFrames].
     *
     * Side effect: every read that returns `false` notifies [FrameExtractorListener.onNoMoreFrames].
     */
    override val hasMoreFrames: Boolean
        get() {
            val moreFrames = lastFrameIndex < totalFrames
            if (moreFrames.not() && extractorListener != null) {
                extractorListener.onNoMoreFrames(
                    "No more frames to extract: video ended (lastFrameTimestamp=$lastFrameTimestamp, " +
                            "totalDuration=$videoDuration, lastFrameIndex=$lastFrameIndex, totalFrames=$totalFrames)."
                )
            }
            return moreFrames
        }

    init {
        try {
            fpsPlayer.setDataSource(input.path)

            // toLongOrNull(): unparsable metadata should surface as the descriptive error below,
            // not as a bare NumberFormatException.
            videoDuration = fpsPlayer.extractMetadata(MediaMetadataRetriever.METADATA_KEY_DURATION)?.toLongOrNull()
                ?.toDuration(DurationUnit.MILLISECONDS)?.inWholeMilliseconds
                ?: throw IllegalStateException("Could not retrieve video duration")

            /** '+1' below to include the initial frame (at 0ms):
             * 1_000ms with FPS=3 -> 333ms per frame -> 0ms, 333ms, 666ms, 999ms = 4 frames
             */
            totalFrames = (videoDuration / oneFrameTimeStep).toInt() + 1

            videoOrientation =
                fpsPlayer.extractMetadata(MediaMetadataRetriever.METADATA_KEY_VIDEO_ROTATION)?.toIntOrNull() ?: 0

            extractorListener?.onFrameExtractorInit(
                filePath = input.absolutePath,
                totalDuration = videoDuration,
                oneFrameTimeStep = oneFrameTimeStep,
                totalFrames = totalFrames,
                videoOrientation = videoOrientation,
            )
        } catch (e: Throwable) {
            // The constructor is failing, so close() can never be called: release the native retriever here.
            // A failure while releasing must not hide the original error.
            runCatching { fpsPlayer.release() }.exceptionOrNull()?.let { e.addSuppressed(it) }
            throw e
        }
    }

    /**
     * Extracts the next frame, blocking the calling thread.
     *
     * @throws IllegalStateException if there are no more frames (check [hasMoreFrames] first).
     * @throws FrameExtractionException if extraction keeps failing or the video ends while skipping bad frames.
     */
    override fun nextFrame(): VideoFrame {
        if (hasMoreFrames.not())
            throw IllegalStateException("Video is ended. Check if more frames are available before calling for the next frame!")

        return runBlocking { extractFrameErrorsWrapped() }
    }

    /** Runs [extractFrame] with the retry-once-then-skip policy described on the class. */
    private suspend fun extractFrameErrorsWrapped(): VideoFrame {
        while (frameExtractFailureCount < extractionFailureMaxRetry) {
            try {
                val frame = extractFrame()
                // Clear the retry flag on success. Otherwise the next frame's first failure would be
                // treated as its second one and skipped without a retry.
                currentFrameIsBad = false
                return frame
            } catch (e: Exception) { // extraction failures are already reported by extractFrame - here we only control the execution flow
                if (currentFrameIsBad.not()) {
                    // First failure for this frame: retry the same timestamp once.
                    currentFrameIsBad = true
                } else {
                    // Second failure: give up on this frame, charge the error budget and move on.
                    frameExtractFailureCount++
                    moveToNextFramePosition()
                    currentFrameIsBad = false
                    if (lastFrameIndex >= totalFrames) {
                        // The skipped frame was the last one; do not request timestamps past the video end.
                        val message = "No more frames to extract: video ended after skipping a bad frame " +
                                "(lastFrameTimestamp=$lastFrameTimestamp, totalDuration=$videoDuration, " +
                                "lastFrameIndex=$lastFrameIndex, totalFrames=$totalFrames)."
                        extractorListener?.onNoMoreFrames(message)
                        throw FrameExtractionException(message)
                    }
                }
            }
        }
        val message = "No more frames to extract: too many errors (errors $frameExtractFailureCount/$extractionFailureMaxRetry, " +
                    "lastFrameTimestamp=$lastFrameTimestamp, totalDuration=$videoDuration, lastFrameIndex=$lastFrameIndex, " +
                    "estimated totalFrames=$totalFrames)."

        extractorListener?.onNoMoreFrames(message)
        throw FrameExtractionException(message)
    }

    /**
     * Extracts the frame at [lastFrameTimestamp] and, on success, advances to the next frame position.
     *
     * @throws FrameExtractionException if the retriever throws or returns no bitmap.
     */
    private suspend fun extractFrame(): VideoFrame {
        // NOTE(review): getFrameAtTime is a blocking call with no suspension point, and cancellation is
        // cooperative. This timeout therefore appears unable to interrupt a slow extraction; confirm before
        // relying on it.
        val imageRef: ImageRef? = withTimeoutOrNull(extractionTimeoutMs) {
            extractorListener?.onFrameExtractionStart(lastFrameIndex, totalFrames)

            var bmp: Bitmap? = null
            try {
                // Must be in Microseconds, not milliseconds!
                bmp = fpsPlayer.getFrameAtTime(lastFrameTimestamp * 1000, frameExtractionOption)
                if (bmp != null) {
                    ImageRef.fromBitmap(bmp)
                } else null
            } catch (e: Exception) {
                extractorListener?.onFrameExtractionException(
                    "Frame extraction failed at ${lastFrameTimestamp}ms with some cause.",
                    e,
                )
                throw FrameExtractionException("Frame extraction failed!", e)
            } finally {
                // NOTE(review): assumes ImageRef.fromBitmap copies the pixels rather than keeping a reference
                // to bmp. Verify this; otherwise recycling here would invalidate the returned ImageRef.
                bmp?.recycle()
            }
        }

        if (imageRef == null) {
            val message = "Frame extraction failed at ${lastFrameTimestamp}ms: got null or timeout."
            extractorListener?.onFrameExtractionException(message)
            throw FrameExtractionException(message)
        }

        extractorListener?.onFrameExtractedDebug(imageRef)
        extractorListener?.onFrameExtractionEnd(lastFrameIndex, totalFrames)

        // Capture index/timestamp of the extracted frame before advancing the cursor.
        val frame = VideoFrame(imageRef, videoOrientation, lastFrameIndex, lastFrameTimestamp)
        moveToNextFramePosition()
        return frame
    }

    /** Advances the extraction cursor by one frame. */
    private fun moveToNextFramePosition() {
        lastFrameIndex++
        lastFrameTimestamp += oneFrameTimeStep
    }

    /** Releases the underlying [MediaMetadataRetriever]. */
    override fun close() = fpsPlayer.release()

    companion object {
        const val EXTRACTION_TIMEOUT_MS = 3000L
        const val EXTRACTION_FAILURE_MAX_RETRY_DEFAULT = 2
    }
}

/**
 * Raised when a video frame cannot be extracted or frame extraction has to stop.
 *
 * Callers must supply a [message], a [cause], or both.
 *
 * @throws IllegalArgumentException if both [message] and [cause] are `null`.
 */
class FrameExtractionException(message: String? = null, cause: Throwable? = null) : Exception(message, cause) {

    init {
        // An exception carrying neither a description nor a cause would be useless for diagnosis.
        require(message != null || cause != null) {
            "Both message and cause are null - give at least something!"
        }
    }
}
