// Copyright 2023 Descript, Inc

/* eslint-disable no-bitwise */
// Written by @anthony

import type { Block, DemuxResult, Segment, TrackInfo } from './types';
import { TrackType } from './types';
import { DescriptError, ErrorCategory } from '@descript/errors';

/**
 * Incremental parser for Matroska/WebM byte streams.
 *
 * - `segment(blob)` consumes successive chunks of a recording and returns the
 *   segments that became complete with that chunk: an init segment (the bytes
 *   before the first Cluster) when a chunk starts with an EBML header, and media
 *   segments cut at Cluster boundaries. Audio-only streams are additionally cut
 *   once a segment's duration exceeds MAX_AUDIO_CLUSTER_DURATION.
 * - `demux(blob)` parses a blob and returns the track table plus the SimpleBlocks
 *   of the track used for duration tracking.
 *
 * Both entry points share parser state on the instance and are serialized
 * through an internal promise queue, so chunks are processed in call order.
 */
export class DemuxMKV {
    private clusterTimecode = 0;

    // EBML / Matroska element IDs (compared against readVint(..., keepMarker = true))
    private readonly EBML_ID = 0x1a45dfa3;
    private readonly TRACK_ID = 0x1654ae6b;
    private readonly TRACK_NUMBER_ID = 0xd7;
    private readonly TRACK_TYPE_ID = 0x83;
    private readonly CODEC_ID = 0x86;
    private readonly CLUSTER_ID = 0x1f43b675;
    private readonly SIMPLEBLOCK_ID = 0xa3;
    private readonly VIDEO_SETTINGS_ID = 0xe0;
    private readonly AUDIO_SETTINGS_ID = 0xe1;
    private readonly PIXEL_WIDTH_ID = 0xb0;
    private readonly PIXEL_HEIGHT_ID = 0xba;
    // Value readVint yields for the 8-byte "unknown size" marker 0x01FFFFFFFFFFFFFF
    // (2^56 - 1 rounds to 2^56 as a double).
    // NOTE(review): shorter all-ones size encodings (e.g. 0xFF) are not recognized.
    private readonly UNKNOWN_LENGTH = 0x100000000000000;
    private readonly SAMPLE_RATE_ID = 0xb5;
    private readonly CHANNELS_ID = 0x9f;
    private readonly BIT_DEPTH_ID = 0x6264;
    // Audio-only segments longer than this (in block timecode units — presumably ms,
    // given the timecodeBase of 1000 attached to blocks) are split.
    private readonly MAX_AUDIO_CLUSTER_DURATION = 4000;

    // Set by stop(); the next segment() call flushes the buffer as a final segment
    private recordingStopped = false;
    // Cleared as soon as a video track is found in the track table
    private audioOnly = true;

    // Stream bytes not yet handed out as part of a finished segment
    private buffer: Blob | undefined;
    // Segments completed during the current segment() call
    private readySegments: Segment[] = [];

    private currentSegment: Partial<Segment> = {
        frameCount: 0,
    };
    private previousSegment: Partial<Segment> = {};
    private segmentNumber = 1;

    // Track number (and its type) whose block timecodes drive segment timing
    private trackForDuration = 0;
    private trackDurationType = TrackType.UNKNOWN;
    private tracks: Map<number, TrackInfo> = new Map();
    // Blocks collected by demux()
    private blocks: Block[] = [];

    // Bytes of the current Cluster's ID, size and Timecode; prepended to split audio segments
    private clusterBytes: Blob | undefined;
    // Offset into `buffer` where parsing resumes on the next chunk
    private savedPosition = 0;
    // Offset into `buffer` where the in-progress segment's unassigned bytes begin
    private clusterSlicePoint = 0;
    // A lone 1-byte chunk held back and prepended to the next chunk
    private firstByte: Blob | undefined = undefined;

    // Tail of the serialization queue used by enqueue()
    private lastPromise = Promise.resolve();

    /** Mark the recording as stopped; the next `segment()` call flushes all buffered data. */
    public stop(): void {
        this.recordingStopped = true;
    }

    /** Queue a demux of `blob`; resolves with the track table and duration-track blocks. */
    public async demux(blob: Blob): Promise<DemuxResult> {
        return await this.enqueue(() => this.demuxInternal(blob));
    }

    /** Queue a recording chunk; resolves with the segments completed by this chunk. */
    public async segment(blob: Blob): Promise<Segment[]> {
        return await this.enqueue(() => this.segmentInternal(blob));
    }

    /**
     * Parse `blob` (optionally starting with an EBML header) and collect the
     * blocks of the duration track.
     */
    private demuxInternal = async (blob: Blob): Promise<DemuxResult> => {
        this.blocks = [];

        const { data } = await this.extractInit(blob);
        this.buffer = data;
        if (data.size) {
            await this.parseData(true);
        }

        return {
            headerInfo: {
                tracks: this.tracks,
            },
            blocks: this.blocks,
        };
    };

    /**
     * Append a recording chunk to the buffer and return every segment that became
     * complete: the init segment if the chunk carries a header, finished media
     * segments, and — after stop() — the final flushed segment.
     */
    private segmentInternal = async (blob: Blob): Promise<Segment[]> => {
        this.readySegments = [];

        // Handle chrome sending 1 byte because timeslice is too small
        if (blob.size === 1) {
            this.firstByte = blob;
            return [];
        } else if (this.firstByte) {
            blob = new Blob([this.firstByte, blob]);
            this.firstByte = undefined;
        }

        // Split blob into header and data
        const { header, data } = await this.extractInit(blob);
        if (header) {
            this.readySegments.push({
                data: header,
                duration: 0,
                startTimecode: 0,
                isInit: true,
                number: 0,
                tracks: this.tracks,
            });
        }

        // Append to data buffer
        this.buffer = new Blob([this.buffer ?? new Blob(), data]);

        // Read duration from data blocks
        await this.parseData();

        // Flush the buffer into the last segment on stop
        if (this.recordingStopped) {
            this.newSegment(this.buffer);
            this.segmentReady();
            this.recordingStopped = false;
        }

        return this.readySegments;
    };

    /**
     * Split `blob` into the EBML header (everything before the first Cluster,
     * parsing Tracks along the way) and the remaining data. Blobs that are empty or
     * do not start with the EBML ID are returned whole as data with no header.
     */
    private extractInit = async (
        blob: Blob,
    ): Promise<{ header: Blob | undefined; data: Blob }> => {
        const buf = (await this.blobToArrayBuffer(blob)) as ArrayBuffer;

        // An empty chunk cannot hold a header (and reading its first byte would throw)
        if (buf.byteLength === 0) {
            return { header: undefined, data: blob };
        }

        const ebmlId = this.readVint(buf, 0, true);
        if (ebmlId.value !== this.EBML_ID) {
            return { header: undefined, data: blob };
        }

        let position = 0;
        try {
            let elId: number;
            while ((elId = this.readVint(buf, position, true).value) !== this.CLUSTER_ID) {
                if (elId === this.TRACK_ID) {
                    position = this.parseTracks(buf, position);
                } else {
                    position = this.skipOverElement(buf, position);
                }
                if (position === buf.byteLength) {
                    break;
                }
            }
        } catch (e: unknown) {
            if (e instanceof Error) {
                e.message = 'Error parsing init segment: ' + e.message;
            }
            throw e;
        }
        return { header: blob.slice(0, position), data: blob.slice(position) };
    };

    /**
     * Walk Cluster and SimpleBlock elements in `buffer` from `savedPosition`,
     * cutting segments at Cluster starts, then drop bytes already assigned to
     * segments and rebase the saved offsets.
     *
     * @throws DescriptError when there is no buffer or an unexpected element ID is met.
     */
    private parseData = async (saveBlocks = false): Promise<void> => {
        if (!this.buffer) {
            throw new DescriptError('No buffer data', ErrorCategory.Recording);
        }
        const data = (await this.blobToArrayBuffer(this.buffer)) as ArrayBuffer;
        let position = this.savedPosition;

        // NOTE(review): the 5-byte margin does not guarantee a complete block header
        // (size VINTs can be up to 8 bytes); a header split across chunks may misparse.
        while (position < data.byteLength - 5) {
            const elId = this.readVint(data, position, true);
            if (elId.value === this.CLUSTER_ID) {
                this.newSegment(this.buffer.slice(this.clusterSlicePoint, position));
                this.clusterSlicePoint = position;
                position = this.parseCluster(data, position);
                this.clusterBytes = this.buffer.slice(this.clusterSlicePoint, position);
            } else if (elId.value === this.SIMPLEBLOCK_ID) {
                position = this.parseSimpleBlock(data, position, saveBlocks);
            } else {
                throw new DescriptError(
                    `Invalid element ID: ${elId.value.toString(16)} at ${position}/${
                        data.byteLength
                    }`,
                    ErrorCategory.Recording,
                );
            }
            this.savedPosition = position;
        }
        this.buffer = this.buffer.slice(this.clusterSlicePoint);
        this.savedPosition = this.savedPosition - this.clusterSlicePoint;
        this.clusterSlicePoint = 0;
    };

    /**
     * Enter a Cluster at `position`, record its timecode and return the offset just
     * past the Timecode element. Assumes Timecode is the Cluster's first child.
     */
    private parseCluster(data: ArrayBuffer, position: number): number {
        position = this.readIntoElement(data, position).dataStart; // Read cluster ID

        const timecode = this.readIntoElement(data, position); // Read the timecode
        position = timecode.dataStart;
        this.clusterTimecode = this.readInt(data, position, timecode.length);
        position += timecode.length;

        return position;
    }

    /**
     * Parse the SimpleBlock at `position` and return the offset just past it.
     *
     * Video blocks (any video track) bump the current segment's frame count. Blocks
     * of the duration track update segment start/duration, finalize the previous
     * segment when a new one receives its first block, may split long audio-only
     * segments, and are saved to `blocks` when `saveBlocks` is set.
     */
    private parseSimpleBlock(data: ArrayBuffer, position: number, saveBlocks = false): number {
        const block = this.readIntoElement(data, position); // Read simple block ID
        position = block.dataStart;
        const blockEnd = position + block.length;

        const blockTrack = this.readVint(data, position);

        // Check if this is a video track
        const trackInfo = this.tracks.get(blockTrack.value);
        if (trackInfo?.type === TrackType.VIDEO) {
            // Increment frame count for video blocks
            this.currentSegment.frameCount = (this.currentSegment.frameCount || 0) + 1;
        }

        if (blockTrack.value !== this.trackForDuration) {
            return blockEnd;
        }

        position += blockTrack.length;
        // The block timecode is a big-endian *signed* 16-bit offset from the cluster timecode
        const rawBlockTimecode = this.readInt(data, position, 2);
        const blockTimecode =
            rawBlockTimecode >= 0x8000 ? rawBlockTimecode - 0x10000 : rawBlockTimecode;
        const timecode = blockTimecode + this.clusterTimecode;
        position += 2;

        if (saveBlocks) {
            const blockFlags = data.slice(position, position + 1);
            // Keyframe is bit 0x80 of the flags
            const keyframeMask = new Uint8Array(blockFlags)[0]! & 0x80;
            const keyframe = keyframeMask === 0x80;

            const blockData: Block = {
                trackId: blockTrack.value,
                timecode,
                timecodeBase: 1000,
                keyframe,
                // Frame payload starts right after the single flags byte
                data: data.slice(position + 1, blockEnd),
            };
            this.blocks.push(blockData);
        }

        // Compare against undefined: a start timecode of 0 is valid
        if (this.currentSegment.startTimecode === undefined) {
            // Save startTimecode and finalize previous segment
            this.currentSegment.startTimecode = timecode;

            // tentatively assign duration here. For screen share sometimes the final segment
            // does not get a duration from the else block, so this covers us in that case
            this.currentSegment.duration = 0;

            if (this.previousSegment.startTimecode !== undefined) {
                this.previousSegment.duration = timecode - this.previousSegment.startTimecode;
                this.segmentReady();
            }
        } else {
            this.currentSegment.duration = timecode - this.currentSegment.startTimecode;

            // Split audio only clusters if they are too long
            if (
                this.audioOnly &&
                this.currentSegment.duration > this.MAX_AUDIO_CLUSTER_DURATION &&
                !this.recordingStopped &&
                this.buffer &&
                this.clusterBytes
            ) {
                this.newSegment(
                    this.buffer.slice(this.clusterSlicePoint, blockEnd),
                    this.clusterBytes,
                );
                this.clusterSlicePoint = blockEnd;
            }
        }
        return blockEnd;
    }

    /** Parse a Video element; returns its end offset and PixelWidth/PixelHeight (0 if absent). */
    private parseVideoSettings(
        data: ArrayBuffer,
        position: number,
    ): { position: number; videoSettings: { width: number; height: number } } {
        const videoSettings = this.readIntoElement(data, position);
        position = videoSettings.dataStart;
        const videoSettingsEnd = position + videoSettings.length;
        let width = 0,
            height = 0;
        while (position < videoSettingsEnd) {
            const elId = this.readVint(data, position, true);
            if (elId.value === this.PIXEL_WIDTH_ID) {
                const pixelWidth = this.readIntoElement(data, position, elId);
                position = pixelWidth.dataStart;
                width = this.readInt(data, position, pixelWidth.length);
                position += pixelWidth.length;
            } else if (elId.value === this.PIXEL_HEIGHT_ID) {
                const pixelHeight = this.readIntoElement(data, position, elId);
                position = pixelHeight.dataStart;
                height = this.readInt(data, position, pixelHeight.length);
                position += pixelHeight.length;
            } else {
                position = this.skipOverElement(data, position);
            }
        }
        return {
            position: videoSettingsEnd,
            videoSettings: { width, height },
        };
    }

    /**
     * Parse an Audio element; returns its end offset plus sampling frequency
     * (float), channel count and bit depth (0 if absent).
     */
    private parseAudioSettings(
        data: ArrayBuffer,
        position: number,
    ): {
        position: number;
        audioSettings: { sampleRate: number; channels: number; bitDepth: number };
    } {
        const audioSettings = this.readIntoElement(data, position);
        position = audioSettings.dataStart;
        const audioSettingsEnd = position + audioSettings.length;
        let sampleRate = 0,
            channels = 0,
            bitDepth = 0;
        while (position < audioSettingsEnd) {
            const elId = this.readVint(data, position, true);
            if (elId.value === this.SAMPLE_RATE_ID) {
                const sampleRateId = this.readIntoElement(data, position, elId);
                position = sampleRateId.dataStart;
                sampleRate = this.readFloat(data, position, sampleRateId.length);
                position += sampleRateId.length;
            } else if (elId.value === this.CHANNELS_ID) {
                const channelsId = this.readIntoElement(data, position, elId);
                position = channelsId.dataStart;
                channels = this.readInt(data, position, channelsId.length);
                position += channelsId.length;
            } else if (elId.value === this.BIT_DEPTH_ID) {
                const bitDepthId = this.readIntoElement(data, position, elId);
                position = bitDepthId.dataStart;
                bitDepth = this.readInt(data, position, bitDepthId.length);
                position += bitDepthId.length;
            } else {
                position = this.skipOverElement(data, position);
            }
        }
        return {
            position: audioSettingsEnd,
            audioSettings: { sampleRate, channels, bitDepth },
        };
    }

    /** Parse every TrackEntry inside the Tracks element at `position`; returns its end offset. */
    private parseTracks(data: ArrayBuffer, position: number): number {
        const tracksId = this.readVint(data, position, true);
        position += tracksId.length;

        const tracksLength = this.readVint(data, position);
        position += tracksLength.length;

        const tracksEnd = position + tracksLength.value;
        while (position < tracksEnd) {
            position = this.saveTrackData(data, position);
        }
        return tracksEnd;
    }

    /**
     * Parse one TrackEntry, store it in `tracks`, and update the audio-only flag and
     * the duration track. Entries missing a number or type are ignored.
     * Returns the offset just past the entry.
     */
    private saveTrackData = (data: ArrayBuffer, position: number): number => {
        const trackEntryId = this.readIntoElement(data, position);
        position = trackEntryId.dataStart;
        const trackEnds = position + trackEntryId.length;
        let trackNumber, trackType, videoSettings, audioSettings;
        let codec = 'UNKNOWN';

        while (position < trackEnds) {
            const elId = this.readVint(data, position, true);
            if (elId.value === this.TRACK_NUMBER_ID) {
                const trackNumId = this.readIntoElement(data, position, elId);
                position = trackNumId.dataStart;
                trackNumber = this.readInt(data, position, trackNumId.length);
                position += trackNumId.length;
            } else if (elId.value === this.TRACK_TYPE_ID) {
                const trackTypeId = this.readIntoElement(data, position, elId);
                position = trackTypeId.dataStart;
                trackType = this.readInt(data, position, trackTypeId.length);
                position += trackTypeId.length;
            } else if (elId.value === this.CODEC_ID) {
                const codecId = this.readIntoElement(data, position, elId);
                position = codecId.dataStart;
                const codecRaw = new Uint8Array(
                    data.slice(position, position + codecId.length),
                );
                codec = this.uint8ArrToString(codecRaw);
                position += codecId.length;
            } else if (elId.value === this.VIDEO_SETTINGS_ID) {
                ({ position, videoSettings } = this.parseVideoSettings(data, position));
            } else if (elId.value === this.AUDIO_SETTINGS_ID) {
                ({ position, audioSettings } = this.parseAudioSettings(data, position));
            } else {
                position = this.skipOverElement(data, position);
            }
        }

        if (trackNumber === undefined || trackType === undefined) {
            //TODO: handle warning cases and emit warning and try to bail out gracefully with the data up until this point
            // console.warn('Missing track EBML info');
        } else {
            this.tracks.set(trackNumber, {
                type: trackType,
                codec,
                ...videoSettings,
                ...audioSettings,
            });

            if (trackType === TrackType.VIDEO) {
                this.audioOnly = false;
            }

            // Use the first track seen, but switch to a video track if the current
            // choice is an audio or unknown-type track
            if (!this.trackForDuration) {
                this.trackForDuration = trackNumber;
                this.trackDurationType = trackType;
            } else if (
                trackType === TrackType.VIDEO &&
                (this.trackDurationType === TrackType.UNKNOWN ||
                    this.trackDurationType === TrackType.AUDIO)
            ) {
                this.trackForDuration = trackNumber;
                this.trackDurationType = TrackType.VIDEO;
            }
        }

        return trackEnds;
    };

    /**
     * Close the current segment by appending `data` to it, make it the previous
     * segment, and start a fresh one (seeded with `clusterBytes` when given).
     */
    private newSegment = (data: Blob, clusterBytes?: Blob): void => {
        // Keep the cluster bytes if we have them
        if (this.currentSegment.data) {
            this.currentSegment.data = new Blob([this.currentSegment.data, data]);
        } else {
            this.currentSegment.data = data;
        }

        this.previousSegment = this.currentSegment;
        this.currentSegment = {
            isInit: false,
            frameCount: 0,
        };

        if (clusterBytes) {
            this.currentSegment.data = clusterBytes;
        }
    };

    /** Number the previous segment and add it to the ready list. */
    private segmentReady = (): void => {
        this.previousSegment.number = this.segmentNumber;
        this.readySegments.push(this.previousSegment as Segment);
        this.segmentNumber += 1;
    };

    /**
     * Return the offset just past the element at `position`. For an unknown-size
     * element, returns the start of its data (i.e. steps into it).
     */
    private skipOverElement = (buf: ArrayBuffer, position: number): number => {
        const elId = this.readVint(buf, position, true);
        position += elId.length;
        const elLength = this.readVint(buf, position);
        if (elLength.value === this.UNKNOWN_LENGTH) {
            position += elLength.length;
        } else {
            position += elLength.length + elLength.value;
        }
        return position;
    };

    /** Read an element's ID and size; returns where its data starts and its size. */
    private readIntoElement = (
        buf: ArrayBuffer,
        position: number,
        elId = this.readVint(buf, position, true),
    ): { dataStart: number; length: number } => {
        position += elId.length;
        const elLength = this.readVint(buf, position);
        position += elLength.length;
        return {
            dataStart: position,
            length: elLength.value,
        };
    };

    /** Read a big-endian 4- or 8-byte IEEE float. */
    private readFloat = (buf: ArrayBuffer, offset: number, length: number): number => {
        const dv = new DataView(buf, offset, length);
        if (length === 4) {
            // Read as 32-bit float (Float32), big-endian
            return dv.getFloat32(0, false);
        } else if (length === 8) {
            // Read as 64-bit float (Float64), big-endian
            return dv.getFloat64(0, false);
        } else {
            throw new DescriptError(`Invalid float length: ${length}`, ErrorCategory.Recording);
        }
    };

    /** Read a big-endian unsigned integer of `length` bytes (truncated at the buffer end). */
    private readInt = (buf: ArrayBuffer, offset: number, length: number): number => {
        return this.uint8ArrToDec(new Uint8Array(buf.slice(offset, offset + length)));
    };

    /**
     * Read an EBML variable-length integer at `offset`.
     *
     * @param keepMarker keep the length-marker bit (used for element IDs).
     * @returns the decoded value and the number of bytes it occupies.
     * @throws DescriptError when the first byte is 0 (no length marker).
     */
    private readVint = (
        buf: ArrayBuffer,
        offset: number,
        keepMarker = false,
    ): { value: number; length: number } => {
        const dv = new DataView(buf, offset);

        let leading = dv.getUint8(0);
        // A zero first byte carries no length marker; without this guard the
        // marker search below would never terminate.
        if (leading === 0) {
            throw new DescriptError(`Invalid VINT at ${offset}`, ErrorCategory.Recording);
        }

        // Get the number of bytes for this VINT
        let bytes = 1;
        while (!(leading & 0x80)) {
            bytes++;
            leading = (leading << 1) & 0xff;
        }

        // Get rid of the 1 bit length marker
        if (!keepMarker) {
            leading &= 0x7f;
        }
        leading >>>= bytes - 1;

        // Add the variable part of the VINT
        const value = new Uint8Array(bytes);
        value[0] = leading;
        for (let i = 1; i < bytes; i++) {
            value[i] = dv.getUint8(i);
        }
        return { value: this.uint8ArrToDec(value), length: bytes };
    };

    /** Read a Blob's bytes via FileReader; rejects on read errors or a null result. */
    private blobToArrayBuffer(blob: Blob): Promise<ArrayBuffer | string> {
        return new Promise((resolve, reject) => {
            const reader = new FileReader();
            reader.onload = function (event) {
                const result = event.target?.result;
                if (result) {
                    resolve(result);
                } else {
                    // Never leave the promise (and the enqueue chain) pending
                    reject(
                        new DescriptError('Failed to read blob data', ErrorCategory.Recording),
                    );
                }
            };
            reader.onerror = function (error) {
                reject(error);
            };
            reader.readAsArrayBuffer(blob);
        });
    }

    /** Interpret bytes as a big-endian unsigned integer (exact up to 2^53). */
    private uint8ArrToDec(uint8Array: Uint8Array): number {
        if (!uint8Array.length) {
            throw new DescriptError(
                'Tried to convert empty array to decimal',
                ErrorCategory.Recording,
            );
        }
        return uint8Array.reduce((acc, val) => acc * 256 + val, 0);
    }

    /** Decode bytes as UTF-8 text. */
    private uint8ArrToString(uint8Array: Uint8Array): string {
        const decoder = new TextDecoder();
        return decoder.decode(uint8Array);
    }

    /**
     * Run `task` after every previously enqueued task has settled and return its
     * result. A failing task rejects only its own promise; the queue keeps going.
     */
    private enqueue<T>(task: () => Promise<T>): Promise<T> {
        let taskResolve: (value: T | PromiseLike<T>) => void = () => {
            // noop
        };
        let taskReject: () => void = () => {
            // noop
        };

        const taskPromise = new Promise<T>((resolve, reject) => {
            taskResolve = resolve;
            taskReject = reject;
        });

        this.lastPromise = this.lastPromise.then(
            () => task().then(taskResolve).catch(taskReject),
            taskReject,
        );

        return taskPromise;
    }
}
