/**
 * @author Ashish Dwivedi <ashish.dwivedi@314ecorp.com>
 * @description Store for loading models
 */

import { create } from 'zustand';
import * as tf from '@tensorflow/tfjs';
import { loadModels } from 'utils/docTrUtils';

/**
 * State and actions of the docTR model store.
 *
 * Holds the detection and recognition graph models together with flags that
 * describe their loading status, plus the action that loads them.
 */
type ModelStore = {
  /** Detection model; `null` until a load succeeds, and reset to `null` if a load fails. */
  detectionModel: tf.GraphModel | null;
  /** Recognition model; `null` until a load succeeds, and reset to `null` if a load fails. */
  recognitionModel: tf.GraphModel | null;
  /** `true` once both models have been loaded successfully. */
  modelsLoaded: boolean;
  /** `true` while a load is in progress. */
  isLoading: boolean;
  /**
   * Loads both models. Does nothing once the models are loaded.
   * Load failures are logged and leave the store in its unloaded state;
   * they are not surfaced as a rejected promise.
   */
  loadModels: () => Promise<void>;
};

/**
 * Zustand store holding the docTR detection and recognition models.
 *
 * Models are loaded on demand through `loadModels()`. Concurrent callers share
 * a single in-flight load, so every `await loadModels()` settles only after
 * that load has finished (successfully or not), instead of returning early
 * while models are still `null`.
 */
export const useModelStore = create<ModelStore>((set, get) => {
  // In-flight load shared by concurrent callers. Kept in the closure rather
  // than in store state because a Promise is not meaningful UI state.
  let pendingLoad: Promise<void> | null = null;

  /** Performs one load attempt and records its outcome in the store. */
  const performLoad = async (): Promise<void> => {
    set({ isLoading: true });
    try {
      // This is the `loadModels` imported from 'utils/docTrUtils'; the store
      // action below shares the name but object keys do not create bindings.
      const { detectionModel, recognitionModel } = await loadModels();

      set({
        detectionModel,
        recognitionModel,
        modelsLoaded: true,
        isLoading: false,
      });
    } catch (error: unknown) {
      // Best-effort: log and fall back to the unloaded state; do not reject.
      console.error('Error loading models', error);
      set({
        detectionModel: null,
        recognitionModel: null,
        modelsLoaded: false,
        isLoading: false,
      });
    }
  };

  return {
    detectionModel: null,
    recognitionModel: null,
    modelsLoaded: false,
    isLoading: false,

    loadModels: async () => {
      // Join a load that is already running rather than resolving early.
      if (pendingLoad) return pendingLoad;

      const { isLoading, modelsLoaded } = get();
      if (modelsLoaded || isLoading) return;

      pendingLoad = performLoad().finally(() => {
        pendingLoad = null;
      });
      return pendingLoad;
    },
  };
});
