/**
 * @author Ashish Dwivedi <ashish.dwivedi@314ecorp.com>
 * @description docTR Utils
 */

import * as tf from '@tensorflow/tfjs';
import _ from 'lodash';

/**
 * The docTR character vocabulary used for CRNN (CTC) decoding: class index `k` of the
 * recognition model's output maps to `VOCAB[k]`.
 *
 * Layout: digits (0-9), lowercase letters (10-35), uppercase letters (36-61), then the
 * 32 characters of Python's `string.punctuation` (62-93, including the backtick between
 * `_` and `{`), then `°£€¥¢฿` and the accented letters. The string is 126 characters long,
 * so index 126 is free for the CTC blank used in `extractWordsFromCrop`. Dropping any
 * character (e.g. the backtick) shifts every later index by one.
 */
const VOCAB =
	'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~°£€¥¢฿àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ';

/**
 * Detection + Recognition normalization constants.
 *
 * The values are on a 0-1 scale. The tensor-preparation helpers multiply them by 255
 * and apply them to raw 0-255 pixel values: `(pixel - 255 * MEAN) / (255 * STD)`.
 * DET_* is used for the detection input, REC_* for the recognition crops.
 */
const DET_MEAN = 0.785;
const DET_STD = 0.275;

const REC_MEAN = 0.694;
const REC_STD = 0.298;

/**
 * Recognition model configuration: model input `height` x `width` and the `model.json` path.
 *
 * NOTE(review): the key is `crnn_vgg16_bn`, but the entry's `value`, `label` and `path`
 * all describe a CRNN MobileNet V2 model. The key is kept because code in this file
 * (`loadModels`, `getRecognitionResult`) looks the entry up by this name.
 */
export const RECO_CONFIG = {
	crnn_vgg16_bn: {
		value: 'crnn_mobilenet_v2',
		label: 'CRNN (MobileNet V2)',
		height: 32,
		width: 128,
		path: 'models/crnn_mobilenet_v2/model.json',
	},
};
/**
 * Detection model configuration: DB (MobileNet V2) with a square 512 x 512 input and its
 * `model.json` path. Its `width`/`height` are also used to map normalized box coordinates
 * back to pixels in `getRecognitionResult`.
 */
export const DET_CONFIG = {
	db_mobilenet_v2: {
		value: 'db_mobilenet_v2',
		label: 'DB (MobileNet V2)',
		height: 512,
		width: 512,
		path: 'models/db_mobilenet_v2/model.json',
	},
};

/** Stroke colours that box annotations are randomly drawn from. */
const ANNOTATION_STROKE_COLORS: string[] = [
	'#FF6B6B',
	'#4ECDC4',
	'#45B7D1',
	'#96CEB4',
	'#FFEEAD',
	'#D4A5A5',
	'#9B59B6',
	'#3498DB',
	'#E67E22',
	'#2ECC71',
];

/**
 * Pick a stroke colour for a detected-box annotation, uniformly at random from
 * `ANNOTATION_STROKE_COLORS` (used by `extractBoundingBoxesFromHeatmap`).
 */
const generateColor = () => {
	const paletteIndex = Math.floor(Math.random() * ANNOTATION_STROKE_COLORS.length);
	return ANNOTATION_STROKE_COLORS[paletteIndex];
};

/**
 * Start loading the detection (DB) and recognition (CRNN) graph models at the same time
 * and wait for both.
 *
 * @returns both models once loaded; rejects as soon as either load rejects.
 */
export const loadModels = async (): Promise<{
	detectionModel: tf.GraphModel;
	recognitionModel: tf.GraphModel;
}> => {
	const detectionLoad = tf.loadGraphModel(DET_CONFIG.db_mobilenet_v2.path);
	const recognitionLoad = tf.loadGraphModel(RECO_CONFIG.crnn_vgg16_bn.path);

	const loaded = await Promise.all([detectionLoad, recognitionLoad]);
	return { detectionModel: loaded[0], recognitionModel: loaded[1] };
};

/**
 * Prepare a single image for the DB detection model.
 *
 * The image is read as pixels, resized with nearest-neighbour sampling to `size`
 * (`[height, width]`), and normalized with `DET_MEAN` / `DET_STD` scaled to the 0-255
 * range. A leading batch axis of length 1 is then added.
 *
 * All intermediates (raw pixels, resized tensor, mean/std scalars) are released by
 * `tf.tidy`. Previously they leaked backend memory on every call. Only the returned
 * tensor survives, and the caller must dispose it.
 */
export const getImageTensorForDetectionModel = (imageObject: HTMLImageElement, size: [number, number]) =>
	tf.tidy(() => {
		const resized = tf.browser.fromPixels(imageObject).resizeNearestNeighbor(size).toFloat();
		const mean = tf.scalar(255 * DET_MEAN);
		const std = tf.scalar(255 * DET_STD);
		return resized.sub(mean).div(std).expandDims();
	});

/**
 * Centre an image on a white canvas of the detection model's input size, convert it to
 * grayscale with a contrast boost, and return the result as a new image element.
 *
 * The canvas is `DET_CONFIG.db_mobilenet_v2.width` x `height` (512 x 512). This is the
 * same size `getRecognitionResult` uses to turn normalized boxes back into pixels.
 * The source is drawn at its intrinsic size at a centred offset. An image larger than
 * the canvas gets a negative offset and is cropped.
 *
 * Each pixel's RGB is replaced by `clamp((mean(R,G,B) - 128) * 1.5 + 128, 0, 255)`.
 * Alpha is left untouched.
 *
 * @param image source image; its `width`/`height` are used to compute the centring offset.
 * @returns the processed image, once it has loaded from the canvas data URL.
 * @throws Error (rejects) if the generated data URL cannot be decoded as an image.
 *   Previously the promise never settled in that case.
 */
export const createEnlargedImage = async (image: HTMLImageElement): Promise<HTMLImageElement> => {
	const { width: canvasWidth, height: canvasHeight } = DET_CONFIG.db_mobilenet_v2;
	const CONTRAST_FACTOR = 1.5;

	const canvas = document.createElement('canvas');
	canvas.width = canvasWidth;
	canvas.height = canvasHeight;

	const offsetX = Math.round((canvasWidth - image.width) / 2);
	const offsetY = Math.round((canvasHeight - image.height) / 2);

	const ctx = canvas.getContext('2d');
	if (ctx) {
		ctx.fillStyle = '#ffffff';
		ctx.fillRect(0, 0, canvasWidth, canvasHeight);
		ctx.drawImage(image, offsetX, offsetY);

		const imageData = ctx.getImageData(0, 0, canvasWidth, canvasHeight);
		const pixels = imageData.data;
		for (let i = 0; i < pixels.length; i += 4) {
			const gray = (pixels[i] + pixels[i + 1] + pixels[i + 2]) / 3;
			const adjustedGray = Math.min(255, Math.max(0, (gray - 128) * CONTRAST_FACTOR + 128));
			// Uint8ClampedArray rounds the fractional value on assignment.
			pixels[i] = pixels[i + 1] = pixels[i + 2] = adjustedGray;
		}
		ctx.putImageData(imageData, 0, 0);
	}

	return new Promise<HTMLImageElement>((resolve, reject) => {
		const img = new Image();
		img.onload = () => resolve(img);
		img.onerror = () => reject(new Error('createEnlargedImage: failed to load the processed canvas image'));
		img.src = canvas.toDataURL();
	});
};
/**
 * Clamp a computed resize dimension to at least one pixel.
 * NaN also maps to 1, because `NaN >= 1` is false.
 */
const atLeastOnePixel = (value: number): number => (value >= 1 ? value : 1);

/**
 * Prepare crops for the CRNN recognition model.
 *
 * Each crop is resized with nearest-neighbour sampling to fit inside `size` while keeping
 * its aspect ratio. It is then zero-padded on the bottom or right to exactly `size`,
 * stacked into one batch, and normalized with `REC_MEAN` / `REC_STD` scaled to 0-255.
 * Resize dimensions that would round to 0 (very thin crops) are raised to 1 pixel.
 *
 * Intermediate tensors are released by `tf.tidy`. Only the returned batch survives,
 * and the caller must dispose it.
 *
 * @param crops images to recognise; must be non-empty.
 * @param size `[height, width]` expected by the recognition model.
 * @returns float tensor `[crops.length, height, width, channels]`.
 * @throws Error when `crops` is empty.
 */
export const getImageTensorForRecognitionModel = (crops: HTMLImageElement[], size: [number, number]) => {
	if (crops.length === 0) {
		throw new Error('getImageTensorForRecognitionModel: at least one crop is required');
	}
	const [targetHeight, targetWidth] = size;
	const aspectRatio = targetWidth / targetHeight;

	return tf.tidy(() => {
		const batch = crops.map((imageObject) => {
			const h = imageObject.height;
			const w = imageObject.width;

			let resizedHeight: number;
			let resizedWidth: number;
			if (aspectRatio * h > w) {
				// Height-dominant: full target height, width scaled (padding goes on the right).
				resizedHeight = targetHeight;
				resizedWidth = atLeastOnePixel(Math.round((targetHeight * w) / h));
			} else {
				// Width-dominant: full target width, height scaled (padding goes at the bottom).
				resizedHeight = atLeastOnePixel(Math.round((targetWidth * h) / w));
				resizedWidth = targetWidth;
			}

			const padding: Array<[number, number]> = [
				[0, targetHeight - resizedHeight],
				[0, targetWidth - resizedWidth],
				[0, 0],
			];

			return tf.browser
				.fromPixels(imageObject)
				.resizeNearestNeighbor([resizedHeight, resizedWidth])
				.pad(padding, 0)
				.toFloat()
				.expandDims();
		});

		const mean = tf.scalar(255 * REC_MEAN);
		const std = tf.scalar(255 * REC_STD);
		return tf.concat(batch).sub(mean).div(std);
	});
};

/**
 * Crop one detected box out of the enlarged image and run text recognition on it.
 *
 * `box.coordinates` holds normalized (0-1) corners in the order produced by
 * `extractBoundingBoxesFromHeatmap`: `[0]` is top-left and `[2]` is bottom-right. They are
 * scaled by the detection input size (`DET_CONFIG.db_mobilenet_v2`), so
 * `enlargedImageElement` is presumably the output of `createEnlargedImage`. Confirm at
 * call sites.
 *
 * @returns the first recognized word, or `''` in these cases:
 *   - the region is smaller than 1 pixel or has invalid coordinates;
 *   - no 2D canvas context is available;
 *   - the crop image fails to load;
 *   - recognition throws.
 */
export const getRecognitionResult = async (
	box: any,
	enlargedImageElement: HTMLImageElement,
	recognitionModel: tf.GraphModel,
): Promise<string> => {
	const { width: detWidth, height: detHeight } = DET_CONFIG.db_mobilenet_v2;
	const x1 = box.coordinates[0][0] * detWidth;
	const y1 = box.coordinates[0][1] * detHeight;
	const x2 = box.coordinates[2][0] * detWidth;
	const y2 = box.coordinates[2][1] * detHeight;
	const regionWidth = x2 - x1;
	const regionHeight = y2 - y1;

	// Canvas dimensions are integers. A region narrower than 1 px would give a 0-sized
	// canvas whose data URL ("data:,") never loads as an image, so reject it up front.
	// The `!(... >= 1)` form also rejects NaN.
	const cropWidth = Math.round(regionWidth);
	const cropHeight = Math.round(regionHeight);
	if (!(cropWidth >= 1 && cropHeight >= 1)) {
		return '';
	}

	const canvasCrop = document.createElement('canvas');
	canvasCrop.width = cropWidth;
	canvasCrop.height = cropHeight;

	const ctxCrop = canvasCrop.getContext('2d');
	if (!ctxCrop) {
		return '';
	}
	ctxCrop.drawImage(enlargedImageElement, x1, y1, regionWidth, regionHeight, 0, 0, cropWidth, cropHeight);

	// Resolve null on load failure so the function returns '' instead of hanging forever.
	const htmlImageElementCrop = await new Promise<HTMLImageElement | null>((resolve) => {
		const img = new Image();
		img.onload = () => resolve(img);
		img.onerror = () => resolve(null);
		img.src = canvasCrop.toDataURL();
	});

	if (!htmlImageElementCrop) {
		return '';
	}

	try {
		const words = await extractWordsFromCrop({
			recognitionModel,
			crops: [htmlImageElementCrop],
			size: [RECO_CONFIG.crnn_vgg16_bn.height, RECO_CONFIG.crnn_vgg16_bn.width],
		});
		return words?.[0] ?? '';
	} catch (error) {
		console.error(error);
		return '';
	}
};

/**
 * Greedy (best-path) CTC decoding of one sequence of per-timestep class indices.
 *
 * Runs of the same index collapse to one character, and blank indices are dropped. A
 * blank between two equal indices therefore yields a doubled character:
 * `a a _ a` -> "aa", and `a b` -> "ab".
 * Indices outside `VOCAB` contribute nothing, because `charAt` returns `''`.
 */
const decodeCtcBestPath = (indices: ArrayLike<number>, blank: number): string => {
	let text = '';
	let previous = blank;
	for (let i = 0; i < indices.length; i++) {
		const k = indices[i];
		if (k !== blank && k !== previous) {
			text += VOCAB.charAt(k);
		}
		previous = k;
	}
	return text;
};

/**
 * Extract recognized text from one or more crops, using the CRNN recognition model.
 *
 * The crops are batched with `getImageTensorForRecognitionModel` and passed through the
 * model. The arg-max class per timestep is then CTC-decoded into one string per crop.
 * Previously two different characters with no blank between them lost the second one;
 * repeats are now collapsed per the standard greedy CTC rule.
 *
 * All tensors created here (input batch, model output(s), softmax/argMax intermediates)
 * are disposed before returning.
 *
 * @returns one decoded string per crop, in input order, or `undefined` when
 *   `recognitionModel` is null.
 */
export const extractWordsFromCrop = async ({
	recognitionModel,
	crops,
	size,
}: {
	recognitionModel: tf.GraphModel | null;
	crops: HTMLImageElement[];
	size: [number, number];
}): Promise<string[] | undefined> => {
	if (!recognitionModel) return;

	// Class index of the CTC blank. Assumed to be the class right after the last VOCAB
	// character; TODO confirm against the model's output size.
	const BLANK_INDEX = 126;

	const input = getImageTensorForRecognitionModel(crops, size);
	let output: tf.Tensor | tf.Tensor[] | undefined;
	try {
		output = await recognitionModel.executeAsync(input);
		const logits = Array.isArray(output) ? output[0] : output;
		// Assumes logits are (batch, timesteps, classes); each row of the argMax result is one crop.
		const bestPaths = tf.tidy(() =>
			tf.unstack(tf.argMax(tf.softmax(logits, -1), -1), 0).map((row) => Array.from(row.dataSync())),
		);
		return bestPaths.map((indices) => decodeCtcBestPath(indices, BLANK_INDEX));
	} finally {
		tf.dispose(input);
		if (output) tf.dispose(output);
	}
};

/**
 * Run the detection model on an image and draw the resulting heatmap on the `#heatmap`
 * canvas, when that element exists in the document.
 *
 * If the model returns several outputs, the first one is used. It is selected *before*
 * squeezing; the old code squeezed first, which throws on an array. The input tensor and
 * the raw output(s) are released by `tf.tidy`.
 *
 * The returned heatmap is NOT disposed here. The caller owns it (e.g. passes it to
 * `extractBoundingBoxesFromHeatmap`) and should dispose it when done.
 *
 * @param detectionModel loaded detection graph model.
 * @param imageObject image to analyse; its width/height also size the `#heatmap` canvas.
 * @param size `[height, width]` the image is resized to before inference.
 * @returns the first model output with its leading batch axis removed.
 * @throws Error when `detectionModel` is null.
 */
export const getHeatMapFromImage = async ({
	detectionModel,
	imageObject,
	size,
}: {
	detectionModel: tf.GraphModel | null;
	imageObject: HTMLImageElement;
	size: [number, number];
}): Promise<tf.Tensor> => {
	const model = detectionModel;
	if (!model) {
		throw new Error('getHeatMapFromImage: detection model is not loaded');
	}

	const heatmap = tf.tidy(() => {
		const input = getImageTensorForDetectionModel(imageObject, size);
		const rawOutput = model.execute(input);
		const firstOutput = Array.isArray(rawOutput) ? rawOutput[0] : rawOutput;
		return tf.squeeze(firstOutput, [0]);
	});

	const heatmapContainer = document.getElementById('heatmap') as HTMLCanvasElement | null;
	if (heatmapContainer) {
		heatmapContainer.width = imageObject.width;
		heatmapContainer.height = imageObject.height;
		await tf.browser.toPixels(heatmap as tf.Tensor3D, heatmapContainer);
	}
	return heatmap;
};

/*
 * Bounding-box extraction in plain TypeScript (instead of OpenCV.js): the helpers below
 * cover clamping, IoU, non-maximum suppression, box expansion, Otsu thresholding and
 * connected components.
 */

/**
 * Limit `number` to the range [0, size].
 * If `size` is negative, the result is 0.
 */
const clamp = (number: number, size: number): number => {
	const cappedAtSize = Math.min(number, size);
	return Math.max(0, cappedAtSize);
};

/**
 * Intersection-over-union of two axis-aligned boxes given as `[xmin, ymin, xmax, ymax, ...]`.
 * Only the first four elements are read.
 *
 * Returns 0 when the union area is not positive (e.g. two zero-area boxes). Previously
 * this computed 0/0 = NaN, which made such boxes compare false against any threshold.
 */
const iou = (box1: ArrayLike<number>, box2: ArrayLike<number>): number => {
	const x1 = Math.max(box1[0], box2[0]);
	const y1 = Math.max(box1[1], box2[1]);
	const x2 = Math.min(box1[2], box2[2]);
	const y2 = Math.min(box1[3], box2[3]);

	const intersection = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);

	const box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1]);
	const box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1]);

	const union = box1Area + box2Area - intersection;
	return union > 0 ? intersection / union : 0;
};

/** `[xmin, ymin, xmax, ymax, score]`, as produced by `transformBoundingBox`. */
type ScoredBox = [number, number, number, number, number];

/**
 * Greedy non-maximum suppression.
 *
 * Boxes are visited from highest to lowest score (element 4). Each visited box is kept,
 * and every remaining box whose IoU with it exceeds `iouThreshold` is discarded.
 * Previously the boxes were sorted ascending, so the lowest-scoring box won each overlap.
 * The input array is no longer sorted or emptied in place.
 *
 * @param boxes candidate boxes; not modified.
 * @param iouThreshold a box overlapping an already-kept box by more than this IoU is dropped.
 * @returns the kept boxes in descending score order.
 */
const nonMaximumSuppression = (boxes: readonly ScoredBox[], iouThreshold: number): ScoredBox[] => {
	let candidates = [...boxes].sort((a, b) => b[4] - a[4]);
	const keep: ScoredBox[] = [];
	while (candidates.length > 0) {
		const best = candidates[0];
		keep.push(best);
		candidates = candidates.slice(1).filter((box) => iou(best, box) <= iouThreshold);
	}
	return keep;
};

/**
 * Grow a pixel-space rectangle by a margin and normalize it to the map size.
 *
 * The margin is `area * 1.8 / perimeter` of the rectangle. Each edge is clamped to
 * `[0, size]` and then shifted by -1 px, so a left/top edge touching 0 becomes -1 and its
 * normalized value is slightly negative. The right/bottom edges are computed from the
 * already-adjusted left/top edges.
 *
 * @param contour pixel rectangle with `x`, `y`, `width`, `height`.
 * @param size `[height, width]` of the map the rectangle was found in.
 * @param score value passed through as the fifth element.
 * @returns `[xmin, ymin, xmax, ymax, score]` with coordinates divided by the map size.
 */
export const transformBoundingBox = (contour: any, size: [number, number], score: number): any => {
	const [mapHeight, mapWidth] = size;
	const margin = (contour.width * contour.height * 1.8) / (2 * (contour.width + contour.height));

	const left = clamp(contour.x - margin, mapWidth) - 1;
	const top = clamp(contour.y - margin, mapHeight) - 1;
	const right = clamp(left + contour.width + 2 * margin, mapWidth) - 1;
	const bottom = clamp(top + contour.height + 2 * margin, mapHeight) - 1;

	return [left / mapWidth, top / mapHeight, right / mapWidth, bottom / mapHeight, score];
};
/**
 * Otsu's threshold for a set of probabilities.
 *
 * Values are scaled to 0-255 (`floor(p * 255)`) and bucketed into a 256-bin histogram.
 * Negative and NaN values go to bin 0, and values above 1 go to bin 255. Previously
 * such values were indexed outside the histogram and silently lost.
 *
 * For each candidate split `t` in 1..254, the background is bins `[0, t)` and the
 * foreground is the rest. The `t` maximising `wB * wF * (meanB - meanF)^2` is chosen;
 * the first `t` wins ties. Background weight and sum are accumulated incrementally, so
 * the scan is linear in the bin count rather than quadratic.
 *
 * @param grayValues probabilities, nominally in [0, 1].
 * @returns the chosen threshold back on the probability scale (`t / 255`), or 0 when no
 *   split separates two non-empty classes (e.g. empty or constant input).
 */
function calculateOtsuThresholdOnProbabilities(grayValues: ArrayLike<number>): number {
	const BIN_COUNT = 256;
	const histogram = new Array<number>(BIN_COUNT).fill(0);
	for (let i = 0; i < grayValues.length; i++) {
		const bin = Math.floor(grayValues[i] * 255);
		// `bin >= 0` is false for NaN as well as for negatives.
		histogram[bin >= 0 ? Math.min(bin, BIN_COUNT - 1) : 0] += 1;
	}

	const totalPixels = grayValues.length;
	let totalSum = 0;
	for (let i = 0; i < BIN_COUNT; i++) {
		totalSum += i * histogram[i];
	}

	let weightBackground = 0;
	let sumBackground = 0;
	let maxVariance = 0;
	let threshold = 0;
	for (let t = 1; t < 255; t++) {
		// Extend the background class [0, t) by bin t - 1.
		weightBackground += histogram[t - 1];
		sumBackground += (t - 1) * histogram[t - 1];
		if (weightBackground === 0) continue;

		const weightForeground = totalPixels - weightBackground;
		if (weightForeground === 0) break;

		const meanBackground = sumBackground / weightBackground;
		const meanForeground = (totalSum - sumBackground) / weightForeground;
		const variance = weightBackground * weightForeground * Math.pow(meanBackground - meanForeground, 2);
		if (variance > maxVariance) {
			maxVariance = variance;
			threshold = t;
		}
	}
	return threshold / 255;
}
/**
 * Put annotation shapes into reading order.
 *
 * Shapes are sorted by their top edge. Consecutive shapes whose top edges differ by less
 * than `rowTolerance` from the previous shape in the same row are grouped into a row.
 * Each row is then sorted left to right, and the rows are concatenated.
 * `shapes` and the row arrays are sorted in place.
 */
const orderByRowsThenColumns = <T extends { _topLeft: { x: number; y: number } }>(
	shapes: T[],
	rowTolerance: number,
): T[] => {
	shapes.sort((first, second) => first._topLeft.y - second._topLeft.y);

	const rows: T[][] = [];
	for (const shape of shapes) {
		const openRow = rows.length > 0 ? rows[rows.length - 1] : undefined;
		const rowTail = openRow ? openRow[openRow.length - 1] : undefined;
		if (openRow && rowTail && Math.abs(shape._topLeft.y - rowTail._topLeft.y) < rowTolerance) {
			openRow.push(shape);
		} else {
			rows.push([shape]);
		}
	}

	return rows.flatMap((row) => row.sort((first, second) => first._topLeft.x - second._topLeft.x));
};

/**
 * Turn a detection probability map into box annotations in reading order.
 *
 * Steps:
 * 1. Read the first `height * width` values of `predictionTensor` (row-major).
 * 2. Binarize them with an Otsu threshold.
 * 3. Find 8-connected blobs and keep those wider and taller than 2 px.
 * 4. Grow each blob with `transformBoundingBox` and normalize it to 0-1.
 * 5. Run NMS at IoU 0.5.
 * 6. Emit one annotation per surviving box and sort them into rows (top-edge tolerance
 *    0.03 in normalized units), left to right within each row.
 *
 * @param predictionTensor probability map; not disposed here.
 * @param size `[height, width]` of the map.
 * @returns annotations with `id`, a random `config.stroke`, four corner `coordinates`
 *   (top-left, top-right, bottom-right, bottom-left) and `_topLeft`.
 */
export const extractBoundingBoxesFromHeatmap = async (predictionTensor: tf.Tensor, size: [number, number]) => {
	const [height, width] = size;
	const pixelCount = height * width;
	const probabilities = await predictionTensor.data();

	// Plain-array copy of exactly `pixelCount` probabilities.
	const grayValues = new Array(pixelCount);
	let cursor = pixelCount;
	while (cursor-- > 0) {
		grayValues[cursor] = probabilities[cursor];
	}

	const threshold = calculateOtsuThresholdOnProbabilities(grayValues);
	const binaryData = Uint8Array.from(grayValues, (probability: number) => (probability > threshold ? 255 : 0));

	const candidateBoxes = findConnectedComponents(binaryData, width, height)
		.filter(({ width: w, height: h }) => w > 2 && h > 2)
		.map(({ x, y, width: w, height: h }) => transformBoundingBox({ x, y, width: w, height: h }, size, 1));
	const keptBoxes = nonMaximumSuppression(candidateBoxes, 0.5);

	const shapes = _.map(keptBoxes, (box: any, id: number) => {
		const [xmin, ymin, xmax, ymax] = box;
		return {
			id,
			config: {
				stroke: generateColor(),
			},
			coordinates: [
				[xmin, ymin],
				[xmax, ymin],
				[xmax, ymax],
				[xmin, ymax],
			],
			_topLeft: { x: Math.min(xmin, xmax), y: Math.min(ymin, ymax) },
		};
	});

	return orderByRowsThenColumns(shapes, 0.03);
};


/**
 * Axis-aligned bounding rectangle of one connected component, in pixel units of the
 * binary mask it was found in.
 */
interface Region {
	/** Number of columns covered (right-most column - x + 1). */
	width: number;
	/** Number of rows covered (bottom-most row - y + 1). */
	height: number;
	/** Left-most column index (inclusive). */
	x: number;
	/** Top-most row index (inclusive). */
	y: number;
}

/**
 * Label 8-connected blobs of foreground pixels (value 255) in a row-major binary mask and
 * return each blob's bounding rectangle.
 *
 * This is a breadth-first flood fill. The queue is an `Int32Array` read through a moving
 * head index rather than `Array.shift()`, which is O(n) per call and made large blobs
 * quadratic. Visited pixels are tracked in a `Uint8Array`. Each pixel is enqueued at
 * most once, so the total work is linear in `width * height`.
 *
 * Regions are emitted in row-major order of each blob's first-encountered pixel (its
 * top-most, then left-most pixel).
 *
 * @param binaryData mask of length `width * height`; 255 marks foreground.
 * @param width mask width in pixels.
 * @param height mask height in pixels.
 * @returns one `Region` per blob; empty when either dimension is not positive.
 */
const findConnectedComponents = (binaryData: Uint8Array, width: number, height: number): Region[] => {
	const output: Region[] = [];
	if (!(width > 0 && height > 0)) {
		return output;
	}

	const FOREGROUND = 255;
	const pixelCount = width * height;
	const visited = new Uint8Array(pixelCount);
	// Every pixel enters the queue at most once over the whole scan, so pixelCount slots suffice.
	const queue = new Int32Array(pixelCount);

	for (let seed = 0; seed < pixelCount; seed++) {
		if (binaryData[seed] !== FOREGROUND || visited[seed]) continue;

		let minX = seed % width;
		let maxX = minX;
		let minY = (seed - minX) / width;
		let maxY = minY;

		let head = 0;
		let tail = 0;
		visited[seed] = 1;
		queue[tail++] = seed;

		while (head < tail) {
			const current = queue[head++];
			const cx = current % width;
			const cy = (current - cx) / width;

			for (let dy = -1; dy <= 1; dy++) {
				const ny = cy + dy;
				if (ny < 0 || ny >= height) continue;
				for (let dx = -1; dx <= 1; dx++) {
					const nx = cx + dx;
					if ((dx === 0 && dy === 0) || nx < 0 || nx >= width) continue;

					const neighbor = ny * width + nx;
					if (binaryData[neighbor] !== FOREGROUND || visited[neighbor]) continue;

					visited[neighbor] = 1;
					queue[tail++] = neighbor;

					if (nx < minX) minX = nx;
					if (nx > maxX) maxX = nx;
					if (ny < minY) minY = ny;
					if (ny > maxY) maxY = ny;
				}
			}
		}

		output.push({
			x: minX,
			y: minY,
			width: maxX - minX + 1,
			height: maxY - minY + 1,
		});
	}
	return output;
};
