import React, { useEffect, useMemo, useState } from "react";
import * as ReactRouterDom from "react-router-dom";
import { toRelativeUrl } from "@okta/okta-auth-js";
import { useOktaAuth } from "@okta/okta-react";
import { OnAuthRequiredFunction } from "@okta/okta-react/bundles/types/OktaContext";
import OktaError from "./OktaError";
import { useDispatch, useSelector } from "react-redux";
import { setAuth } from "./authSlice";
import { Groups, GroupsList } from "./authentication.constant";
import {
	selectAuthenticatedUser,
	selectUserGroupNames,
	selectUserIdentity,
} from "./auth.selector";
import ForbiddenAccess from "../../common/components/ForbiddenAccess";
import { getGepoFeatureFlags } from "../parameters/ParametersPage.thunk";
import { getMe } from "./auth.thunk";
import { RequiresFeatureFlags } from "../parameters/components/RequiresFeatureFlags";

/**
 * Normalises an Okta `groups` claim into a list of group names.
 * The claim may be an array, a single value or missing entirely; a missing
 * claim yields `[]` and non-string entries are dropped.
 */
const normalizeGroups = (rawGroups: unknown): string[] => {
	if (rawGroups == null) {
		return [];
	}
	const list: unknown[] = Array.isArray(rawGroups) ? rawGroups : [rawGroups];
	return list.filter((group): group is string => typeof group === "string");
};

/** Maps Okta group names to the matching `Groups` enum keys, dropping names with no match. */
const toGroupKeys = (groupNames: readonly string[]): string[] => {
	const entries = Object.entries(Groups);
	return groupNames
		.map((name) => entries.find(([, value]) => value === name)?.[0])
		.filter((key): key is string => !!key);
};

/** True when both lists contain the same group names, ignoring order. */
const haveSameGroups = (
	left: readonly string[],
	right: readonly string[]
): boolean => {
	if (left.length !== right.length) {
		return false;
	}
	const sortedLeft = [...left].sort();
	const sortedRight = [...right].sort();
	return sortedLeft.every((group, index) => group === sortedRight[index]);
};

/**
 * Route element wrapper that requires an authenticated Okta session.
 *
 * - Okta reports no authenticated session: stores the current URL as Okta's
 *   original URI and starts login through `onAuthRequired`, the provider's
 *   `_onAuthRequired`, or `signInWithRedirect` (at most once per pending login).
 * - Authenticated but no user identity in the store yet: loads the Okta
 *   profile, silently refreshes the tokens when their `groups` claim differs
 *   from the profile's groups, syncs the auth slice via `setAuth`, then
 *   dispatches `getGepoFeatureFlags` and `getMe`.
 * - Renders nothing until both Okta and the store report an authenticated
 *   user; renders `ForbiddenAccess` when `allowedGroups` shares no group with
 *   the user; otherwise renders `element` inside `RequiresFeatureFlags`.
 *
 * Errors from the login or profile-sync flows are rendered with
 * `errorComponent` (default `OktaError`).
 */
const ProtectedRoute: React.FC<
	{
		onAuthRequired?: OnAuthRequiredFunction;
		errorComponent?: React.ComponentType<{ error: Error }>;
		element: React.ReactElement | null;
		allowedGroups?: GroupsList[];
	} & ReactRouterDom.RouteProps &
		React.HTMLAttributes<HTMLDivElement>
> = ({ onAuthRequired, errorComponent, element, allowedGroups }) => {
	const dispatch = useDispatch();
	const currentUserGroup = useSelector(selectUserGroupNames);
	const currentUser = useSelector(selectAuthenticatedUser);
	const userIdentity = useSelector(selectUserIdentity);
	const {
		oktaAuth,
		authState,
		_onAuthRequired: authRequiredHandler,
	} = useOktaAuth();
	const [loginError, setLoginError] = useState<Error | null>(null);
	// Prevents starting a second login while one is already in progress:
	// the login effect re-runs on every authState update.
	const pendingLogin = React.useRef(false);
	const ErrorReporter = errorComponent || (OktaError as React.ComponentType);
	const isAuthenticated = useMemo(
		() => !!authState?.isAuthenticated,
		[authState]
	);
	// A missing authState is treated as "Okta still loading".
	const isOktaLoading = useMemo(() => !authState, [authState]);

	// Trigger the login flow once Okta reports an unauthenticated session.
	useEffect(() => {
		if (isOktaLoading) {
			return;
		}
		if (isAuthenticated) {
			pendingLogin.current = false;
			return;
		}
		if (pendingLogin.current) {
			return;
		}
		pendingLogin.current = true;

		const handleLogin = async () => {
			// Record the current location as Okta's original URI so the app can
			// return here after login.
			const originalUri = toRelativeUrl(
				window.location.href,
				window.location.origin
			);
			oktaAuth.setOriginalUri(originalUri);
			const onAuthRequiredFn = onAuthRequired || authRequiredHandler;
			if (onAuthRequiredFn) {
				await onAuthRequiredFn(oktaAuth);
			} else {
				await oktaAuth.signInWithRedirect();
			}
		};
		handleLogin().catch((err) => {
			setLoginError(err);
		});
	}, [
		authState,
		oktaAuth,
		isAuthenticated,
		isOktaLoading,
		onAuthRequired,
		authRequiredHandler,
	]);

	// After login: load the Okta profile, make sure the access token carries the
	// same groups as the profile, then sync the store.
	useEffect(() => {
		if (!isAuthenticated || userIdentity) {
			return;
		}
		const syncUserInfos = async () => {
			const user = await oktaAuth.getUser();
			const profileGroups = normalizeGroups(user.groups);
			const userGroups = toGroupKeys(profileGroups);
			const tokenGroups = normalizeGroups(
				authState?.accessToken?.claims.groups
			);

			if (!haveSameGroups(profileGroups, tokenGroups)) {
				// The token's groups differ from the profile: fetch fresh tokens
				// without prompting and store the NEW access token, not the old one.
				const { tokens } = await oktaAuth.token.getWithoutPrompt();
				oktaAuth.tokenManager.setTokens(tokens);
				dispatch(
					setAuth({
						token:
							tokens.accessToken?.accessToken ??
							authState?.accessToken?.accessToken,
						user,
						groups: userGroups,
					})
				);
			} else if (!currentUserGroup.length) {
				dispatch(
					setAuth({
						token: authState?.accessToken?.accessToken,
						user,
						groups: userGroups,
					})
				);
			}

			dispatch(getGepoFeatureFlags());
			dispatch(getMe());
		};
		syncUserInfos().catch((err) => {
			setLoginError(err);
		});
		// currentUserGroup is deliberately not a dependency: presumably setAuth
		// updates it, which would restart the sync while getMe is still in flight.
	}, [authState, userIdentity, isAuthenticated, oktaAuth, dispatch]);

	if (loginError) {
		return <ErrorReporter error={loginError} />;
	}

	if (!currentUser || !authState || !authState.isAuthenticated) {
		return null;
	}

	// NOTE(review): the check is skipped when the store holds no groups for the
	// user, so such users reach group-restricted routes (fail-open). Presumably
	// this avoids a ForbiddenAccess flash before groups are synced — confirm, and
	// make sure the backend enforces the same restriction.
	if (allowedGroups && currentUserGroup.length > 0) {
		// Store groups are presumably the Groups keys set via setAuth above;
		// allowedGroups are compared against their "gepo-"-prefixed form.
		const intactGroups = currentUserGroup.map((group) => "gepo-" + group);
		const isAllowed = allowedGroups.some((allowed) =>
			intactGroups.includes(allowed)
		);
		if (!isAllowed) {
			return <ForbiddenAccess />;
		}
	}

	return <RequiresFeatureFlags>{element}</RequiresFeatureFlags>;
};

export default ProtectedRoute;
