package io.scanbot.tools.utils.android

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.ImageFormat
import android.graphics.Matrix
import android.graphics.Rect
import android.graphics.YuvImage
import android.media.Image
import io.scanbot.tools.utils.kotlin.isEven
import io.scanbot.tools.utils.kotlin.reduceToEven
import java.io.ByteArrayOutputStream

/** Util methods to process camera preview images. */
object ImageUtils {

    /**
     * Converts the image byte array in NV21 format to [android.graphics.Bitmap].
     *
     * The frame is compressed to JPEG (quality 100) through [YuvImage] and decoded back with
     * [BitmapFactory], so the result goes through one JPEG encode/decode round trip.
     *
     * @param nv21Image frame bytes, handed to [YuvImage] as [ImageFormat.NV21]
     * @param imageWidth width of the frame in pixels
     * @param imageHeight height of the frame in pixels
     * @param imageOrientation rotation in degrees applied via [Matrix.postRotate]; `0` skips the rotation step
     * @return the decoded bitmap, rotated when [imageOrientation] is not `0`
     * @throws IllegalStateException if the frame cannot be compressed to JPEG or the JPEG cannot be decoded
     */
    @JvmStatic
    fun convertNV21ToBitmap(nv21Image: ByteArray, imageWidth: Int, imageHeight: Int, imageOrientation: Int = 0): Bitmap {
        val jpegBytes = ByteArrayOutputStream().use { os ->
            val yuvImage = YuvImage(nv21Image, ImageFormat.NV21, imageWidth, imageHeight, null)
            val compressed = yuvImage.compressToJpeg(Rect(0, 0, imageWidth, imageHeight), 100, os)
            check(compressed) { "Failed to compress the NV21 frame to JPEG" }
            os.toByteArray()
        }

        // decodeByteArray returns null when the data cannot be decoded; fail with a clear message instead of an NPE.
        val decoded: Bitmap = checkNotNull(BitmapFactory.decodeByteArray(jpegBytes, 0, jpegBytes.size)) {
            "Failed to decode the JPEG produced from the NV21 frame"
        }
        if (imageOrientation == 0) {
            return decoded
        }

        val matrix = Matrix().apply { postRotate(imageOrientation.toFloat()) }
        val rotated = Bitmap.createBitmap(decoded, 0, 0, decoded.width, decoded.height, matrix, false)
        // createBitmap may hand back the source object itself; only release the intermediate when a copy was made.
        if (rotated !== decoded) {
            decoded.recycle()
        }
        return rotated
    }

    /**
     * Converts the image in YUV_420_888 to NV21 format.
     *
     * Only the area described by [Image.getCropRect] is copied. Plane 0 is written contiguously at the start
     * of the output; plane 2 is written to every second byte starting at `width * height` and plane 1 to every
     * second byte starting at `width * height + 1`, which interleaves the two subsampled planes.
     *
     * Note: the positions of the plane buffers of [image] are modified while reading.
     *
     * @param image source image; assumed to expose three planes (Y, U, V) as YUV_420_888 does - TODO confirm for callers
     * @return a new array of `width * height * bitsPerPixel / 8` bytes
     */
    @JvmStatic
    fun convertYUV420toNV21(image: Image): ByteArray {
        val crop = image.cropRect
        val format = image.format
        val width = crop.width()
        val height = crop.height()
        val planes = image.planes
        val data = ByteArray(width * height * ImageFormat.getBitsPerPixel(format) / 8)
        // Scratch row used when a plane cannot be copied in bulk (pixel stride or output stride other than 1).
        val rowData = ByteArray(planes[0].rowStride)
        var channelOffset = 0
        var outputStride = 1
        for (i in planes.indices) {
            when (i) {
                0 -> {
                    // Plane 0 fills the start of the output contiguously.
                    channelOffset = 0
                    outputStride = 1
                }
                1 -> {
                    // Plane 1 lands on the odd bytes after the first plane.
                    channelOffset = width * height + 1
                    outputStride = 2
                }
                2 -> {
                    // Plane 2 lands on the even bytes after the first plane.
                    channelOffset = width * height
                    outputStride = 2
                }
            }
            val buffer = planes[i].buffer
            val rowStride = planes[i].rowStride
            val pixelStride = planes[i].pixelStride
            // Planes after the first are read at half resolution in both directions.
            val shift = if (i == 0) 0 else 1
            val w = width shr shift
            val h = height shr shift
            // Jump to the top-left corner of the crop rectangle inside this plane.
            buffer.position(rowStride * (crop.top shr shift) + pixelStride * (crop.left shr shift))
            for (row in 0 until h) {
                val length: Int
                if (pixelStride == 1 && outputStride == 1) {
                    // Tightly packed source and destination: copy the whole row at once.
                    length = w
                    buffer.get(data, channelOffset, length)
                    channelOffset += length
                } else {
                    // Read up to the last needed sample only, then pick every pixelStride-th byte.
                    length = (w - 1) * pixelStride + 1
                    buffer.get(rowData, 0, length)
                    for (col in 0 until w) {
                        data[channelOffset] = rowData[col * pixelStride]
                        channelOffset += outputStride
                    }
                }
                // Skip the rest of the row (padding / cropped columns), except after the last row.
                if (row < h - 1) {
                    buffer.position(buffer.position() + rowStride - length)
                }
            }
        }
        return data
    }

    /**
     * Converts the top-left [requiredWidth] x [requiredHeight] area of [bitmap] to a YUV 4:2:0 semi-planar
     * byte array: a full-resolution Y plane followed by interleaved V/U samples.
     *
     * Note: [bitmap] is recycled by this call and must not be used afterwards.
     *
     * @param bitmap source bitmap; recycled before returning
     * @param requiredWidth width of the area to convert, must be even
     * @param requiredHeight height of the area to convert, must be even
     * @return array of `requiredWidth * requiredHeight * 3 / 2` bytes
     * @throws IllegalStateException if [requiredWidth] or [requiredHeight] is not even
     */
    fun convertBitmapToYuv(bitmap: Bitmap, requiredWidth: Int, requiredHeight: Int): ByteArray {
        check(requiredWidth.isEven() && requiredHeight.isEven()) { "Please provide even width and height!" }

        val pixelCount = requiredWidth * requiredHeight
        val argb = IntArray(pixelCount)
        bitmap.getPixels(argb, 0, requiredWidth, 0, 0, requiredWidth, requiredHeight)

        // Both sizes are even, so the pixel count is divisible by 4 and the integer division is exact.
        val yuv = ByteArray(pixelCount + pixelCount / 2)
        encodeYUV420SP(yuv, argb, requiredWidth, requiredHeight)
        bitmap.recycle()
        return yuv
    }

    /**
     * Returns the sizes of this [Bitmap] after passing each through `reduceToEven`
     * (presumably an odd value is reduced by 1 - see that helper).
     *
     * @return pair with [Pair.first] as `width` and [Pair.second] as `height`
     */
    fun Bitmap.getEvenSizes(): Pair<Int, Int> = Pair(this.width.reduceToEven(), this.height.reduceToEven())

    /**
     * Fills [yuv420sp] from packed ARGB pixels: one Y byte per pixel, followed (from offset
     * `width * height`) by a V byte and a U byte for every pixel in an even row and even column.
     * The alpha byte of each pixel is ignored.
     *
     * [width] and [height] are expected to be even (enforced by the only caller) so that
     * `width * height * 3 / 2` bytes are exactly filled.
     */
    private fun encodeYUV420SP(yuv420sp: ByteArray, argb: IntArray, width: Int, height: Int) {
        var yIndex = 0
        var uvIndex = width * height
        var index = 0
        for (j in 0 until height) {
            for (i in 0 until width) {
                val pixel = argb[index++]
                val r = (pixel shr 16) and 0xff
                val g = (pixel shr 8) and 0xff
                val b = pixel and 0xff

                // Fixed-point conversion: weights are scaled by 256, +128 rounds before the shift.
                val y = ((66 * r + 129 * g + 25 * b + 128) shr 8) + 16
                val u = ((-38 * r - 74 * g + 112 * b + 128) shr 8) + 128
                val v = ((112 * r - 94 * g - 18 * b + 128) shr 8) + 128

                yuv420sp[yIndex++] = y.coerceIn(0, 255).toByte()
                // Subsample by row and column parity (the column, not the running pixel index,
                // so the pattern does not depend on the width being even).
                if (j % 2 == 0 && i % 2 == 0) {
                    yuv420sp[uvIndex++] = v.coerceIn(0, 255).toByte()
                    yuv420sp[uvIndex++] = u.coerceIn(0, 255).toByte()
                }
            }
        }
    }
}
