import { Badge, Button } from '@/cove-ui';
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from '@/cove-ui/Select';
import {
  useGQLItemTypesQuery,
  useGQLPoliciesWithModelsQuery,
  useGQLSpotTestModelLazyQuery,
  type GQLFieldType,
  type GQLScalarType,
} from '@/graphql/generated';
import { assertUnreachable } from '@/utils/misc';
import { gql } from '@apollo/client';
import { Input } from 'antd';
import { useEffect, useState } from 'react';
import { useParams } from 'react-router-dom';

import DashboardHeader from '../components/DashboardHeader';
import ComponentLoading from '@/components/common/ComponentLoading';

import { spotTestForbiddenFieldTypes } from '../rules/info/RuleTestModal';

// antd's multi-line text input, rendered once per spot-test field below.
const TextArea = Input.TextArea;

// GraphQL operation that scores a single hand-entered item against one
// model version. The tagged value is not assigned or used at runtime; it is
// presumably picked up by GraphQL codegen to generate
// `useGQLSpotTestModelLazyQuery` (imported from '@/graphql/generated').
// Only ModelExecutionSuccessResult carries a `score`.
// NOTE(review): `_` looks like a placeholder field on the unscoreable/error
// result types, selected only because GraphQL forbids empty selection sets —
// confirm against the schema.
gql`
  query SpotTestModel(
    $modelVersion: Int!
    $modelId: ID!
    $item: SpotTestItemInput!
  ) {
    spotTestModel(modelVersion: $modelVersion, modelId: $modelId, item: $item) {
      ... on ModelExecutionSuccessResult {
        score
      }
      ... on ModelExecutionUnscoreableResult {
        _
      }
      ... on ModelExecutionErrorResult {
        _
      }
    }
  }
`;

/** Scores at or above this value are shown with the destructive badge variant. */
const HIGH_SCORE_THRESHOLD = 0.9;

/** A single field's value in the spot-test payload: text, or a list for ARRAY fields. */
type FieldValue = string | string[];

/** True when a field value is absent, an empty string, or an empty list. */
function isBlankValue(value: FieldValue | undefined): boolean {
  return value === undefined || value.length === 0;
}

/**
 * Converts the raw text typed into a field's textarea into the value sent to
 * the model.
 *
 * ARRAY fields are split on commas, each entry is trimmed, and blank entries
 * are dropped. This means clearing the input yields `[]` rather than `['']`,
 * so required-field validation still sees the field as empty.
 *
 * @param fieldType - The field's declared type.
 * @param rawValue - The textarea's current text.
 * @returns The string as typed, or the parsed list for ARRAY fields.
 */
function parseFieldInput(fieldType: GQLFieldType, rawValue: string): FieldValue {
  if (fieldType !== 'ARRAY') {
    return rawValue;
  }
  return rawValue
    .split(',')
    .map((entry) => entry.trim())
    .filter((entry) => entry.length > 0);
}

/**
 * Example input shown as the placeholder for a spot-test field.
 *
 * @param fieldType - The field's declared type.
 * @param containerValueScalarType - For ARRAY fields, the element type. It is
 *   used to build a two-element comma-separated example.
 * @returns Placeholder text for the field's textarea.
 * @throws Error for MAP, DATETIME, RELATED_ITEM and POLICY_ID, which have no
 *   text representation in this form.
 */
function getFieldPlaceholder(
  fieldType: GQLFieldType,
  containerValueScalarType: GQLScalarType | undefined,
): string {
  switch (fieldType) {
    case 'AUDIO':
      return 'https://test.com/audio.mp3';
    case 'BOOLEAN':
      return 'false';
    case 'GEOHASH':
      return '9q8yy';
    case 'ID':
    case 'NUMBER':
      return '123';
    case 'IMAGE':
      return 'https://test.com/image.jpg';
    case 'VIDEO':
      return 'https://test.com/video.mp4';
    case 'STRING':
      return 'Some text...';
    case 'URL':
      return 'https://test.com';
    case 'USER_ID':
      return 'user-id';
    case 'ARRAY': {
      // Without element-type metadata, recursing would pass `undefined`,
      // which is not a valid field type, so fall back to a generic example.
      if (containerValueScalarType == null) {
        return 'value 1, value 2';
      }
      const example = getFieldPlaceholder(containerValueScalarType, undefined);
      return `${example}, ${example}`;
    }
    case 'MAP':
    case 'DATETIME':
    case 'RELATED_ITEM':
    case 'POLICY_ID':
      // NOTE(review): presumably these types are listed in
      // spotTestForbiddenFieldTypes and are filtered out before rendering —
      // confirm, since throwing here would break the render.
      throw new Error('Unsupported field type');
    default:
      assertUnreachable(fieldType);
  }
}

/**
 * Spot-test page for a single model.
 *
 * The user picks an item type, fills in that type's base fields, and runs the
 * model identified by the `:modelId` route param on the entered data. The
 * resulting score is displayed and highlighted at or above
 * HIGH_SCORE_THRESHOLD. The `:policyId` route param only supplies the policy
 * name shown in the header.
 */
export default function ModelSpotTest() {
  const [selectedItemTypeId, setSelectedItemTypeId] = useState<
    string | undefined
  >(undefined);
  const [testData, setTestData] = useState<Record<string, FieldValue>>({});
  const [result, setResult] = useState<number | undefined>(undefined);
  const { policyId, modelId } = useParams<{
    policyId: string;
    modelId: string;
  }>();

  const { data, loading } = useGQLPoliciesWithModelsQuery();

  const {
    data: itemTypesData,
    loading: itemTypesLoading,
    error: itemTypesError,
  } = useGQLItemTypesQuery();

  const [spotTestModel, { loading: spotTestLoading, error: spotTestError }] =
    useGQLSpotTestModelLazyQuery();

  const itemTypes = itemTypesData?.myOrg?.itemTypes;

  // Default to the first item type once the list loads. Also fall back to it
  // if the current selection disappears after a refetch. A selection that is
  // still present is left alone, so refetches never override the user's choice.
  useEffect(() => {
    if (!itemTypes || itemTypes.length === 0) {
      return;
    }
    if (
      selectedItemTypeId !== undefined &&
      itemTypes.some((itemType) => itemType.id === selectedItemTypeId)
    ) {
      return;
    }
    setSelectedItemTypeId(itemTypes[0]?.id);
    setTestData({});
    setResult(undefined);
  }, [itemTypes, selectedItemTypeId]);

  if (loading || itemTypesLoading) {
    return <ComponentLoading />;
  }

  const policy = data?.myOrg?.policies.find((it) => it.id === policyId);
  const model = data?.myOrg?.models.find((it) => it.id === modelId);

  if (itemTypesError || !policy || !model) {
    return <div>Error: could not load model testing form</div>;
  }

  const selectedItemType = itemTypes?.find(
    (itemType) => itemType.id === selectedItemTypeId,
  );

  // Fields the form renders; forbidden field types get no input.
  const visibleFields =
    selectedItemType?.baseFields?.filter(
      (field) => !spotTestForbiddenFieldTypes.includes(field.type),
    ) ?? [];

  const hasMissingRequiredField = visibleFields.some(
    (field) => field.required && isBlankValue(testData[field.name]),
  );
  const isTestDisabled =
    selectedItemType === undefined ||
    Object.keys(testData).length === 0 ||
    hasMissingRequiredField;

  const isHighScore =
    result !== undefined && result >= HIGH_SCORE_THRESHOLD;

  const handleItemTypeChange = (itemTypeId: string) => {
    setSelectedItemTypeId(itemTypeId);
    // Values entered for the previous item type must not be submitted for
    // this one.
    setTestData({});
    setResult(undefined);
  };

  const handleFieldChange = (fieldName: string, value: FieldValue) => {
    // Functional update so rapid edits never build on a stale snapshot.
    setTestData((previous) => {
      const next = { ...previous };
      if (isBlankValue(value)) {
        // A cleared field counts as "not provided" rather than being sent as ''.
        delete next[fieldName];
      } else {
        next[fieldName] = value;
      }
      return next;
    });
  };

  const runSpotTest = async () => {
    if (!selectedItemType) {
      return;
    }
    // Clear any previous score so a failed run can't leave a stale one on screen.
    setResult(undefined);
    try {
      // The score is read from this call's own result rather than the hook's
      // `data`. When a rerun returns equal data, the hook state may not change,
      // and an effect on it would never fire.
      const { data: response } = await spotTestModel({
        variables: {
          modelVersion: model.version,
          modelId: model.id,
          item: {
            data: testData,
            itemTypeIdentifier: {
              id: selectedItemType.id,
              version: selectedItemType.version,
              schemaVariant: 'ORIGINAL',
            },
          },
        },
      });
      const outcome = response?.spotTestModel;
      setResult(
        outcome?.__typename === 'ModelExecutionSuccessResult'
          ? outcome.score
          : undefined,
      );
    } catch {
      // The failure is rendered through the hook's `spotTestError`. Catching
      // here only prevents an unhandled rejection from the click handler.
      // NOTE(review): assumes the lazy-query hook populates `error` for failed
      // requests — confirm for the Apollo version in use.
    }
  };

  return (
    <div className="flex flex-col items-start w-full">
      <DashboardHeader
        title="Spot Test"
        subtitle={`Test model results for ${policy.name}`}
      />
      <div className="flex flex-col w-full gap-4">
        <div className="flex items-center gap-2 font-semibold whitespace-nowrap w-fit">
          Item Type:
          <Select
            onValueChange={handleItemTypeChange}
            value={selectedItemTypeId}
          >
            <SelectTrigger>
              <SelectValue placeholder="Select item type" />
            </SelectTrigger>
            <SelectContent>
              {itemTypes?.map((itemType) => (
                <SelectItem key={itemType.id} value={itemType.id}>
                  {itemType.name}
                </SelectItem>
              ))}
            </SelectContent>
          </Select>
        </div>
        {/* Keyed by item type as well as field name, so the uncontrolled
            textareas remount empty whenever testData is reset for a new type. */}
        {visibleFields.map((field) => (
          <div
            key={`${selectedItemTypeId}:${field.name}`}
            className="flex flex-col"
          >
            <label className="pb-2 font-medium w-fit min-w-[180px] text-sm">
              {field.name}
              {field.required && <span className="text-red-500"> *</span>}
            </label>
            <TextArea
              className="!rounded-xl"
              placeholder={getFieldPlaceholder(
                field.type,
                field.container?.valueScalarType,
              )}
              onChange={(e) =>
                handleFieldChange(
                  field.name,
                  parseFieldInput(field.type, e.target.value),
                )
              }
            />
          </div>
        ))}
      </div>
      <div className="flex flex-row items-end justify-between w-full mt-8">
        <div className="flex flex-col gap-2">
          <div className="font-semibold">Model Score Output</div>
          {isHighScore ? (
            <Badge variant="soft-destructive" size="lg">
              {result}
            </Badge>
          ) : (
            <Badge variant="secondary" size="lg">
              {result !== undefined
                ? result
                : spotTestError
                ? 'Error retrieving score'
                : 'N/A'}
            </Badge>
          )}
        </div>
        <Button
          size="lg"
          loading={spotTestLoading}
          disabled={isTestDisabled}
          onClick={runSpotTest}
        >
          Test Model
        </Button>
      </div>
    </div>
  );
}
