import jwtDecode from 'jwt-decode';
import { array, mixed, object } from 'yup';

import { objKeysToCamelCase } from 'dtg-shared/utils/object';
import { defineSchema, ValidationSchema } from 'dtg-shared/validation';

import { ApiAuthParsedAccessToken } from '../../modules/auth';
import { ApiCompanyEngagementType } from '../../modules/company';
import { ApiUserRole } from '../../modules/user';
import { TokenStorage } from '../TokenStorage';

import { TokenManagerParsedTokenRef, TokenManagerStorageKeys } from './types';

/**
 * Persists an access/refresh token pair in a {@link TokenStorage} and mirrors the decoded
 * access token into `parsedTokenRef`. Consumers holding the ref therefore see the current
 * claims, or `null` when no valid access token is stored.
 *
 * @typeParam ParsedTokenRef - Ref-like object whose `value` receives the parsed access token.
 */
export class TokenManager<ParsedTokenRef extends TokenManagerParsedTokenRef = TokenManagerParsedTokenRef> {
    // Built lazily on first use and shared by all instances. The schema depends only on the
    // imported enum values, so rebuilding it on every parse would be wasted work.
    private static claimsValidationSchema: ValidationSchema | null = null;

    private readonly storage: TokenStorage;
    private readonly storageKeys: TokenManagerStorageKeys;

    /** Kept in sync with the parsed access token; `null` when none is stored or it fails to parse/validate. */
    readonly parsedTokenRef: ParsedTokenRef;

    /**
     * @param storage - Backend used to read and write the raw token strings.
     * @param parsedTokenRef - Ref that receives the parsed access token. It is initialized
     *   here from whatever access token the storage already holds.
     * @param storageKeys - Keys under which the access and refresh tokens are stored.
     */
    constructor(
        storage: TokenStorage,
        parsedTokenRef: ParsedTokenRef,
        storageKeys: TokenManagerStorageKeys = { access: 'accessToken', refresh: 'refreshToken' },
    ) {
        this.storage = storage;
        this.storageKeys = storageKeys;
        this.parsedTokenRef = parsedTokenRef;

        this.parsedTokenRef.value = this.parsedAccessToken;
    }

    /**
     * Stores both tokens and refreshes `parsedTokenRef`.
     *
     * If either token is missing or empty, both are cleared instead, so a partial pair is
     * never stored.
     */
    set(accessToken: string | null, refreshToken: string | null): void {
        if (!accessToken || !refreshToken) {
            this.clear();

            return;
        }

        this._accessToken = accessToken;
        this._refreshToken = refreshToken;

        // May still be null if the new access token fails to decode or validate.
        this.parsedTokenRef.value = this.parsedAccessToken;
    }

    /** Writes `null` for both tokens in storage and resets `parsedTokenRef.value` to `null`. */
    clear(): void {
        this._accessToken = null;
        this._refreshToken = null;

        this.parsedTokenRef.value = null;
    }

    private get _accessToken(): string | null {
        return this.storage.getToken(this.storageKeys.access);
    }

    private set _accessToken(token: string | null) {
        this.storage.setToken(this.storageKeys.access, token);
    }

    private get _refreshToken(): string | null {
        return this.storage.getToken(this.storageKeys.refresh);
    }

    private set _refreshToken(token: string | null) {
        this.storage.setToken(this.storageKeys.refresh, token);
    }

    /** Raw access token currently in storage, or `null`. */
    get accessToken(): string | null {
        return this._accessToken;
    }

    /** Raw refresh token currently in storage, or `null`. */
    get refreshToken(): string | null {
        return this._refreshToken;
    }

    // Used to invalidate old tokens after the encoded claims have changed.
    // It is not meant to validate the entire token.
    private get _validationSchema(): ValidationSchema {
        let schema = TokenManager.claimsValidationSchema;

        if (!schema) {
            schema = defineSchema({
                claims: object({
                    roles: array()
                        .of(mixed().oneOf(Object.values(ApiUserRole)))
                        .required()
                        .min(1),
                    company: object({
                        engagement: array()
                            .of(mixed().oneOf(Object.values(ApiCompanyEngagementType)))
                            .required(),
                    }),
                    engagement: mixed().oneOf(Object.values(ApiCompanyEngagementType)).required(),
                }),
            });
            TokenManager.claimsValidationSchema = schema;
        }

        return schema;
    }

    /** Parsed form of the stored access token, or `null` if none is stored or it is invalid. */
    get parsedAccessToken(): ApiAuthParsedAccessToken | null {
        const { _accessToken } = this;

        if (!_accessToken) {
            return null;
        }

        return this.parseAccessToken(_accessToken);
    }

    /**
     * Decodes a JWT with `jwtDecode`, camel-cases its keys and checks the claims against the
     * validation schema. No signature check is performed here.
     *
     * @returns The parsed token, or `null` if decoding or claim validation fails.
     */
    parseAccessToken(accessToken: string): ApiAuthParsedAccessToken | null {
        try {
            const parsedToken = objKeysToCamelCase(jwtDecode(accessToken));

            this._validationSchema.validateSync(parsedToken);

            return parsedToken as ApiAuthParsedAccessToken;
        } catch {
            // A malformed JWT or outdated/invalid claims are both treated as "no usable token".
            return null;
        }
    }

    /**
     * Whether the stored access token's `exp` claim is at or before the current time.
     *
     * Returns `false` when there is no valid parsed access token. Callers must check
     * `parsedAccessToken` / `accessToken` separately to detect the "no token" case.
     */
    get isAccessTokenExpired(): boolean {
        const { parsedAccessToken } = this;

        if (!parsedAccessToken) {
            return false;
        }

        // `exp` is compared in seconds. Rounding "now" up means a token in its final
        // fractional second is already treated as expired, which errs on the safe side.
        const nowInSeconds = Math.ceil(Date.now() / 1000);

        return parsedAccessToken.exp <= nowInSeconds;
    }
}
