import {useMutation, useQueryClient} from '@tanstack/react-query'
import {addDatasetTag, removeDatasetTag} from '..'
import {getDatasetsListKey} from './cache'
import {Dataset} from '@/types/index'

/**
 * Tag operation carried by an `UndoError`: a tag name and an `'add'`/`'remove'` marker.
 * NOTE(review): presumably the operation to reverse after a failure — nothing in
 * this file constructs the value, so confirm against whatever produces it.
 */
type UndoAction = {
  name: string
  op: 'add' | 'remove'
}

/** Error shape for tag mutations: the underlying `Error` paired with the related tag operation. */
type UndoError = {
  error: Error
  undo: UndoAction
}

/** Snapshot that `onMutate` hands to `onError` so a failed mutation can be rolled back. */
type TagMutationContext = {
  previousDatasets: Dataset[] | undefined
}

/**
 * Returns `dataset` with a `{name, title}` tag for `tagName` appended.
 * The dataset is returned unchanged when a tag with that name is already present,
 * so the optimistic cache never lists the same tag twice.
 */
function withTagAdded(dataset: Dataset, tagName: string): Dataset {
  const tags = dataset.tags ?? []
  if (tags.some((tag) => tag.name === tagName)) return dataset
  return {...dataset, tags: [...tags, {name: tagName, title: tagName}]}
}

/** Returns a copy of `dataset` whose tags no longer include any tag named `tagName`. */
function withTagRemoved(dataset: Dataset, tagName: string): Dataset {
  return {...dataset, tags: (dataset.tags ?? []).filter((tag) => tag.name !== tagName)}
}

/**
 * Mutations for adding and removing a tag on one dataset, with optimistic
 * updates of the cached datasets list for the project.
 *
 * Each mutation patches the cached list before the request is sent, restores
 * the previous list if the request fails, and invalidates the list once the
 * request settles.
 *
 * @param projectId - Project whose datasets list cache entry is patched and invalidated.
 * @param datasetName - Name of the dataset whose tags are changed; matched against `dataset.name`.
 * @returns `addDatasetTag` / `removeDatasetTag` mutate functions taking `{tagName}`,
 *   `isLoading` (true while either mutation is pending) and `error` (the add
 *   mutation's error if set, otherwise the remove mutation's).
 */
export function useDatasetTag(projectId: string, datasetName: string) {
  const queryClient = useQueryClient()

  // Shared `onMutate` body: snapshot the cached list, then patch the target dataset.
  const applyOptimistically = async (
    update: (dataset: Dataset) => Dataset,
  ): Promise<TagMutationContext> => {
    const listKey = getDatasetsListKey(projectId)
    // Cancel in-flight fetches first so a late response cannot overwrite the optimistic value.
    await queryClient.cancelQueries({queryKey: listKey})
    const previousDatasets = queryClient.getQueryData<Dataset[]>(listKey)
    // With nothing cached the updater yields `undefined`, which leaves the cache untouched.
    queryClient.setQueryData<Dataset[]>(listKey, (cached) =>
      cached?.map((dataset) => (dataset.name === datasetName ? update(dataset) : dataset)),
    )
    return {previousDatasets}
  }

  // Shared `onError` body: put the snapshot back, if one was taken.
  const rollback = (context: TagMutationContext | undefined) => {
    if (context?.previousDatasets === undefined) return
    queryClient.setQueryData<Dataset[]>(getDatasetsListKey(projectId), context.previousDatasets)
  }

  // Shared `onSettled` body. The promise is returned (not left floating) so the
  // mutation stays pending until the list has been refetched.
  const refreshList = () => queryClient.invalidateQueries({queryKey: getDatasetsListKey(projectId)})

  const addMutation = useMutation({
    mutationKey: ['addDatasetTag', projectId, datasetName],
    mutationFn: ({tagName}: {tagName: string}) => addDatasetTag(projectId, datasetName, tagName),
    onMutate: ({tagName}) => applyOptimistically((dataset) => withTagAdded(dataset, tagName)),
    onError: (_err, _variables, context) => rollback(context),
    onSettled: refreshList,
  })

  const removeMutation = useMutation({
    mutationKey: ['removeDatasetTag', projectId, datasetName],
    mutationFn: ({tagName}: {tagName: string}) => removeDatasetTag(projectId, datasetName, tagName),
    onMutate: ({tagName}) => applyOptimistically((dataset) => withTagRemoved(dataset, tagName)),
    onError: (_err, _variables, context) => rollback(context),
    onSettled: refreshList,
  })

  return {
    addDatasetTag: addMutation.mutate,
    removeDatasetTag: removeMutation.mutate,
    isLoading: addMutation.isPending || removeMutation.isPending,
    // NOTE(review): unchecked cast — nothing in this hook produces the `UndoError`
    // shape; confirm that the underlying requests actually reject with it.
    error: (addMutation.error || removeMutation.error) as UndoError | null,
  }
}
