import { Socket } from "socket.io-client"
import { ModalityDataSource } from "../Types/ModalityDataSource"
import { PageType } from "../Types/PageType"
import { DecodedPixels, SocketResponse } from "../Types/SocketResponse"
import { decode } from "@msgpack/msgpack"
import { medianDifference } from "../Components/Visualizations/math"
import { range } from "lodash"
import { TraceData } from "./TimeSeriesData"

// Module-wide counter: every Page instance takes the current value as its `id`, then increments it.
let pageId = 0

/**
 * The set of properties a page is configured with.
 * `Page.updateProperties` copies every field of this object onto the page instance via `Object.assign`.
 */
export type ModalityPageProperties = {
	socketId: string // Passed to `getDataQuerySocket` to obtain the socket used for loading.
	startTime: number // Start of the page's time range (presumably milliseconds, since `timeZoneOffsetMs` is added to it before it is sent).
	endTime: number // End of the page's time range, same unit as `startTime`.
	width: number // Sent to the backend with every data request. A page with width 0 is never requested.
	patientId: string // Sent to the backend with every data request.
	timeZoneOffsetMs: number // Added to the requested time range, and subtracted from the timestamps that come back.
	modalityDataSources: ModalityDataSource[] // The (data object, modality) pairs that this page should hold data for.
	patientIsAdmitted: boolean // Sent to the backend with every data request.
	maxReadTime: number | undefined // Upper bound on the requested end time. No bound is applied when undefined.
	getDataQuerySocket: null | ((id?: string) => Socket) // Socket accessor. While null, the page does not request any data.
}

// Edge tuples stored alongside a cached render.
// NOTE(review): the meaning of the individual tuple elements is not visible in this file — confirm against the renderer.
type SingleModalityEdges = [number | undefined, number][]
type MultipleModalityEdges = [number | undefined, number, number][]

// One decoded chunk of data for a single modality, before it is merged into the page's existing data.
type DecodedTimeSeriesData = {
	times: (number | undefined)[] // One timestamp per sample, as produced by `Page.extractTimes`.
	data: number[] | (number | null)[][] // Flat samples, or one array per composite column (null entries are gaps).
}

// A cached rendering for one entry of `Page.renderCache`.
type CachedRender = {
	bitmap: ImageBitmap // Closed by the page when the entry is replaced or the cache is cleared.
	dirty: boolean // If we need to completely redraw using the original data
	edges: SingleModalityEdges | MultipleModalityEdges
}

/**
 * Timing metadata that the page computes for each modality whenever new data is received.
 */
export type TimingInformation = {
	gapIndexes: number[] | number[][] // Sample indexes that start a gap. Composite data gets one list per column.
	samplingPeriod: number | null // Median difference between the defined timestamps. Null when there are fewer than two.
}

/**
 * Base class for one page of modality data.
 *
 * A page covers the time range [startTime, endTime] for one patient and a list of
 * (data object, modality) data sources. It requests the data over a socket, decodes the
 * msgpack-encoded response, stores the decoded data per data object and modality, and keeps
 * the per-modality loading state, timing information and render cache next to it.
 *
 * Subclasses provide `getType()` and may override `socketEventName`, `requestData` and
 * `receiveSocketResponse`.
 */
export abstract class Page<DataType, PageProperties extends ModalityPageProperties = ModalityPageProperties> {
	public socketId: string = "?"
	public id: number = pageId++
	public index: number = 0
	public startTime: number = 0
	public endTime: number = 0
	public width: number = 0
	public patientId: string = "?"
	public timeZoneOffsetMs: number = 0
	public modalityDataSources: ModalityDataSource[] = []
	public maxReadTime: number | undefined = undefined // The maximum timestamp we are allowed to read from the file. Useful for live review and analysis cache invalidation.
	public needsToReRequestData: boolean = false // if true, the data needs to be re-requested from the backend

	// TODO: Maybe these all belong in their own big object that can be accessed at once, ModalityMetadata or something like that.
	public data = new Map<number, Map<string, DataType>>() // Data Object Id, then modality.
	public renderCache = new Map<number, Map<string, CachedRender>>() // Data Object Id, then modality.
	public timingInformation = new Map<number, Map<string, TimingInformation>>() // Data Object Id, then modality

	protected patientIsAdmitted: boolean = true
	protected getDataQuerySocket: null | ((id?: string) => Socket) = null

	// Data Object ID : ["ABP", "ICP", ...]
	// When working with multiple data sources, the same "modality" can be loaded from multiple sources.
	protected dataObjectsLoaded = new Map<number, Map<string, boolean>>()
	protected dataObjectsLoading = new Map<number, Map<string, boolean>>()

	// UTC timestamps of the latest read end time that we requested and received from the backend.
	// Used in live review to make sure we don't re-request the same intervals of data in partial page loading.
	private lastRequestedPartialLoadEndTime: number | null = null
	private lastReceivedPartialLoadEndTime: number | null = null

	abstract getType(): PageType

	/** True when every configured data source is flagged as loaded AND has an entry in `data`. */
	get loaded() {
		return this.modalityDataSources.every(dataSource => this.isDataSourceLoaded(dataSource))
	}

	/** True when at least one configured data source has an outstanding request. */
	get loading() {
		return this.modalityDataSources.some(dataSource => this.isDataSourceLoading(dataSource))
	}

	/** A page needs a load when nothing is in flight and it is either incomplete or flagged for a re-request. */
	public needsToBeLoaded() {
		return !this.loading && (!this.loaded || this.needsToReRequestData)
	}

	/**
	 * Applies a new set of properties to the page.
	 *
	 * Data sources that are no longer listed lose their data and loading state. Every field of
	 * `properties` is then copied onto the page, and data sources that have neither data nor an
	 * outstanding request get their loaded / loading flags initialized to `false`.
	 *
	 * @param properties The new page properties. They are copied onto `this` with `Object.assign`.
	 */
	updateProperties(properties: PageProperties) {
		// Clear out data sources that are no longer needed
		const retainedKeys = new Set(properties.modalityDataSources.map(dataSource => this.getModalityDataSourceKey(dataSource)))

		this.modalityDataSources.forEach(dataSource => {
			if (retainedKeys.has(this.getModalityDataSourceKey(dataSource))) {
				return
			}

			this.dataObjectsLoaded.get(dataSource.dataObjectId)?.delete(dataSource.modality)
			this.dataObjectsLoading.get(dataSource.dataObjectId)?.delete(dataSource.modality)
			this.data.get(dataSource.dataObjectId)?.delete(dataSource.modality)
		})

		// Update properties
		Object.assign(this, properties)

		// Each data source can be partially loaded, by modality.
		// Only modalities without data and without an outstanding request get their state (re)initialized.
		properties.modalityDataSources.forEach(dataSource => {
			const { dataObjectId, modality } = dataSource
			const hasData = this.data.get(dataObjectId)?.get(modality) !== undefined

			if (hasData || this.isDataSourceLoading(dataSource)) {
				return
			}

			Page.getOrCreateInnerMap(this.dataObjectsLoaded, dataObjectId).set(modality, false)
			Page.getOrCreateInnerMap(this.dataObjectsLoading, dataObjectId).set(modality, false)
		})
	}

	/** Name of the socket event used both to request the data and to receive the response. */
	socketEventName = () => "render_modalities"

	/**
	 * Emits one data request per data object on `socket`.
	 *
	 * The requested range starts at the page start, or at the end of the last fully received
	 * partial load if that is later, and ends at the page end, capped by `maxReadTime` when set.
	 * `timeZoneOffsetMs` is added to both ends before they are sent.
	 *
	 * Does nothing while no socket accessor is set.
	 *
	 * @param batchedDataSources The data sources to request, grouped by data object id.
	 * @param socket The socket to emit the requests on.
	 * @throws Error when the resulting range is empty or inverted.
	 */
	requestData(batchedDataSources: Map<number, ModalityDataSource[]>, socket: Socket) {
		if (this.getDataQuerySocket === null) {
			return
		}

		// Only request data that we haven't already asked for and received
		const readStartTime = this.lastReceivedPartialLoadEndTime
			? Math.max(this.startTime, this.lastReceivedPartialLoadEndTime)
			: this.startTime

		// Don't read data that doesn't exist
		const readEndTime = this.maxReadTime
			? Math.min(this.maxReadTime, this.endTime)
			: this.endTime

		if (readEndTime <= readStartTime) {
			// Not sure why this happens, but we don't want to request an invalid range of data.
			throw new Error("Data request end time must be after the start time.")
		}

		batchedDataSources.forEach((dataSources, dataObjectId) => {
			socket.emit(
				this.socketEventName(),
				this.patientId,
				dataObjectId,
				this.id,
				dataSources,
				readStartTime + this.timeZoneOffsetMs,
				readEndTime + this.timeZoneOffsetMs,
				this.width,
				this.patientIsAdmitted
			)
		})

		this.lastRequestedPartialLoadEndTime = readEndTime
		this.needsToReRequestData = false
	}

	/**
	 * Handles the response for one data object: decodes it, marks the received modalities as
	 * loaded, merges the new samples into the existing data and refreshes the timing information.
	 *
	 * Resolves with the page once every tracked modality is loaded and has data.
	 * Rejects with a string reason otherwise.
	 *
	 * NOTE(review): when a load spans several data objects, the response for the first one
	 * rejects with "not all data objects loaded!" even though the remaining responses are still
	 * processed when they arrive. Confirm whether callers rely on that.
	 *
	 * @param response The raw socket response.
	 * @param resolve Settles the pending load successfully.
	 * @param reject Settles the pending load with a failure reason.
	 */
	receiveSocketResponse(response: SocketResponse, resolve: (value: Page<DataType> | PromiseLike<Page<DataType>>) => void, reject: (reason?: any) => void) {
		// The pixels are encoded in an Array Buffer and we have to reconstruct the JSON object by decoding it.
		// This operation is very fast.
		let decodedPixels: DecodedPixels

		try {
			decodedPixels = decode(response.pixels) as DecodedPixels
		} catch {
			reject("Failed to decode data")
			return
		}

		const dataObjectId = response.data_object_id

		// Every modality named in the response is no longer loading, whether or not it carried data.
		for (const modality of Object.keys(decodedPixels)) {
			Page.getOrCreateInnerMap(this.dataObjectsLoading, dataObjectId).set(modality, false)
			Page.getOrCreateInnerMap(this.dataObjectsLoaded, dataObjectId).set(modality, true)
		}

		const dataObjectDataMap = Page.getOrCreateInnerMap(this.data, dataObjectId)
		const timingInformationMap = Page.getOrCreateInnerMap(this.timingInformation, dataObjectId)

		// Safety net. The timer is cleared synchronously below, so it can only fire when the parsing
		// in between throws. In that case the pending load is rejected instead of hanging forever.
		const loadTimeout = setTimeout(() => {
			reject("Page load timed out.")
		}, 3000)

		// Parse the response
		Object.entries(decodedPixels).forEach(([modality, [data, compressedTimes]]) => {
			const oldData = dataObjectDataMap.get(modality) as unknown as TraceData | undefined

			// Sometimes the new data is undefined. Do nothing.
			if (!data) {
				return
			}

			// In the case where we are receiving composite data, we have to handle the data length accordingly.
			const dataLength = data.length > 0 && Array.isArray(data[0]) ? data[0].length : data.length
			const times = this.extractTimes(compressedTimes, dataLength)

			this.handleDataConversionInPlace(data, response.conversion_factor)

			// Append new data to the existing data
			const newData: DecodedTimeSeriesData = { data, times }
			const traceData = this.combineTimeSeriesData(oldData, newData)

			// Page Timing Information and metadata
			const definedTimes: number[] = traceData.times.filter(time => time !== undefined) as number[]
			const samplingPeriod = definedTimes.length > 1 ? medianDifference(definedTimes) : null

			dataObjectDataMap.set(modality, traceData as unknown as DataType)

			timingInformationMap.set(modality, {
				gapIndexes: this.getGapIndexes(traceData, samplingPeriod),
				samplingPeriod
			})

			this.updateCompositeGaps(data, modality, timingInformationMap)
		})

		const allDataObjectsLoaded = this.allTrackedModalitiesHaveData()

		clearTimeout(loadTimeout)

		if (allDataObjectsLoaded) {
			this.lastReceivedPartialLoadEndTime = this.lastRequestedPartialLoadEndTime
			resolve(this)
		} else {
			reject("not all data objects loaded!")
		}
	}

	/** Drops all data, loading state and cached renders, and forgets the partial-load progress. */
	public unload = () => {
		this.lastRequestedPartialLoadEndTime = null
		this.lastReceivedPartialLoadEndTime = null
		this.needsToReRequestData = false
		this.data.clear()
		this.dataObjectsLoaded.clear()
		this.dataObjectsLoading.clear()
		this.clearRenderCache()
	}

	/** Closes every cached bitmap and empties the render cache. */
	public clearRenderCache = () => {
		this.renderCache.forEach(dataObjectCache => {
			dataObjectCache.forEach(cache => cache.bitmap.close())
			dataObjectCache.clear()
		})
		this.renderCache.clear()
	}

	/**
	 * Stores `newValue` in the render cache, closing the bitmap of the entry it replaces.
	 *
	 * @param dataObjectId Data object the cached render belongs to.
	 * @param key Cache key within that data object.
	 * @param newValue The render to cache.
	 */
	public updateRenderCache = (dataObjectId: number, key: string, newValue: CachedRender) => {
		const dataObjectCache = Page.getOrCreateInnerMap(this.renderCache, dataObjectId)

		dataObjectCache.get(key)?.bitmap.close()
		dataObjectCache.set(key, newValue)
	}

	/**
	 * Requests whatever data the page is missing and returns a promise for the outcome.
	 *
	 * Resolves immediately when nothing needs to be loaded, the page has no width, no socket
	 * accessor is set, or there is no data source to request. Otherwise the promise is settled by
	 * `receiveSocketResponse`, rejected when the socket disconnects, or resolved when the request
	 * could not be issued.
	 *
	 * NOTE(review): a modality that is added after a completed load is still requested from
	 * `lastReceivedPartialLoadEndTime` onwards, because `requestData` derives the range from the
	 * page-level progress. Confirm that callers unload the page in that case.
	 */
	public load = () => {
		return new Promise<Page<DataType>>((resolve, reject) => {
			if (!this.needsToBeLoaded() || this.width === 0 || this.getDataQuerySocket === null) {
				resolve(this)
				return
			}

			// A re-request (live review) extends every modality. Otherwise only the modalities that are
			// missing are requested, so data that is already loaded is not fetched and appended again.
			const reRequestEverything = this.needsToReRequestData

			const dataSourcesToLoad = this.modalityDataSources.filter(dataSource => (
				!this.isDataSourceLoading(dataSource)
				&& (reRequestEverything || !this.isDataSourceLoaded(dataSource))
			))

			if (dataSourcesToLoad.length === 0) {
				// Nothing to ask for, so no response would ever settle this promise.
				this.needsToReRequestData = false
				resolve(this)
				return
			}

			// Data Object ID: [{ modality: "ABP"}, { modality: "ICP" }, { modality: "PRx", onDemandAnalysis: {...} } ...]
			// Making one socket call with all of the modalities is more efficient than sending each one separately.
			// e.g. reading 7 modalities only requires setting everything up one time instead of 7 times.
			const batchedModalitiesByDataObjectId = new Map<number, ModalityDataSource[]>()

			for (const dataSource of dataSourcesToLoad) {
				const batch = batchedModalitiesByDataObjectId.get(dataSource.dataObjectId)

				if (batch) {
					batch.push(dataSource)
				} else {
					batchedModalitiesByDataObjectId.set(dataSource.dataObjectId, [dataSource])
				}

				// Update the loading state
				Page.getOrCreateInnerMap(this.dataObjectsLoading, dataSource.dataObjectId).set(dataSource.modality, true)
			}

			const socketEventName = this.socketEventName()
			const socket = this.getDataQuerySocket(this.socketId)

			// The disconnect listener lives on a shared socket, so it is removed as soon as this load settles.
			// The response listeners are deliberately left alone: responses that arrive after the promise
			// has settled still have to be stored.
			const settleRejected = (reason?: unknown) => {
				socket.off("disconnect", settleRejected)
				reject(reason)
			}

			const settleResolved = (value: Page<DataType> | PromiseLike<Page<DataType>>) => {
				socket.off("disconnect", settleRejected)
				resolve(value)
			}

			const allListeners: Array<(socketResponse: SocketResponse) => void> = []

			batchedModalitiesByDataObjectId.forEach((_, dataObjectId) => {
				const listener = (socketResponse: SocketResponse) => {
					if (socketResponse.page_id === this.id && socketResponse.data_object_id === dataObjectId) {
						this.receiveSocketResponse(socketResponse, settleResolved, settleRejected)
						socket.off(socketEventName, listener)
					}
				}

				allListeners.push(listener)
				socket.on(socketEventName, listener)
			})

			const cancelLoad = () => {
				allListeners.forEach(listener => socket.off(socketEventName, listener))
				this.dataObjectsLoading.forEach(modalityMap => {
					modalityMap.forEach((_, modality) => {
						modalityMap.set(modality, false)
					})
				})
				settleResolved(this)
			}

			socket.on("disconnect", settleRejected)

			try {
				this.requestData(batchedModalitiesByDataObjectId, socket)
			} catch {
				// The request could not be issued (e.g. an empty time range). Undo the bookkeeping and resolve with the page as it is.
				cancelLoad()
			}
		})
	}

	/** Returns the inner map stored under `dataObjectId`, creating and registering an empty one if needed. */
	private static getOrCreateInnerMap<V>(outer: Map<number, Map<string, V>>, dataObjectId: number): Map<string, V> {
		let inner = outer.get(dataObjectId)

		if (inner === undefined) {
			inner = new Map<string, V>()
			outer.set(dataObjectId, inner)
		}

		return inner
	}

	/** Identifies a data source within a page. Only used to compare two lists of data sources. */
	private getModalityDataSourceKey(dataSource: ModalityDataSource): string {
		return `${dataSource.modality}-${dataSource.dataObjectId}`
	}

	/** A data source counts as loaded when it is flagged as loaded and its data is present. */
	private isDataSourceLoaded(dataSource: ModalityDataSource): boolean {
		return this.dataObjectsLoaded.get(dataSource.dataObjectId)?.get(dataSource.modality) === true
			&& this.data.get(dataSource.dataObjectId)?.get(dataSource.modality) !== undefined
	}

	/** A data source counts as loading while a request for it is outstanding. */
	private isDataSourceLoading(dataSource: ModalityDataSource): boolean {
		return this.dataObjectsLoading.get(dataSource.dataObjectId)?.get(dataSource.modality) === true
	}

	/** True when every modality with a loaded flag (for any data object) is flagged `true` and has a data entry. */
	private allTrackedModalitiesHaveData(): boolean {
		for (const [dataObjectId, modalityMap] of this.dataObjectsLoaded) {
			for (const [modality, isLoaded] of modalityMap) {
				if (!isLoaded || !this.data.get(dataObjectId)?.has(modality)) {
					return false
				}
			}
		}

		return true
	}

	// Timestamps are compressed into continuous chunks of timestamps to save bits.
	// This function just de-compresses the timestamps back into a list of the actual timestamp values.
	// Each chunk is [start (microseconds), end (microseconds), number of points]. The result is in
	// milliseconds with the time zone offset removed.
	// NOTE(review): slots that are not covered by any chunk keep the pre-filled value 0, not undefined — confirm this is intended.
	private extractTimes(compressedTimes: number[][], dataLength: number): Array<number | undefined> {
		const times = new Array<number | undefined>(dataLength).fill(0)
		let writeIndex = 0

		for (const [startTimeMicroseconds, endTimeMicroseconds, numberOfPoints] of compressedTimes) {
			if (startTimeMicroseconds === undefined || endTimeMicroseconds === undefined || numberOfPoints === undefined) {
				continue
			}

			const start = startTimeMicroseconds / 1e3 - this.timeZoneOffsetMs
			const end = endTimeMicroseconds / 1e3 - this.timeZoneOffsetMs

			if (numberOfPoints === 1) {
				// A single point has no step. It is placed at the end of the chunk.
				times[writeIndex++] = end
				continue
			}

			const step = (end - start) / (numberOfPoints - 1)

			for (let i = 0; i < numberOfPoints; i++) {
				times[writeIndex++] = start + i * step
			}
		}

		return times
	}

	// For data compression purposes, data can come over the network as integers, even if the data is floating point.
	// This offers a significant performance gain because we can use fewer bits to represent approximately the same number.
	// The conversion factor is the number that we multiply by to get the floating point number out of the integer.
	//
	// Flat data is scaled sample by sample, composite data column by column. Entries that are not
	// numbers (null marks a gap in composite data) are left untouched: multiplying null would
	// produce 0 and turn the gap into a real sample.
	private handleDataConversionInPlace(data: number[] | (number | null)[][], conversionFactor: number) {
		if (conversionFactor === 1) {
			return
		}

		// Same array, viewed element-wise so that both shapes can be handled by one loop.
		const entries: (number | (number | null)[])[] = data

		for (let i = 0; i < entries.length; i++) {
			const entry = entries[i]

			if (Array.isArray(entry)) {
				for (let j = 0; j < entry.length; j++) {
					const value = entry[j]

					if (typeof value === "number") {
						entry[j] = value * conversionFactor
					}
				}
			} else if (typeof entry === "number") {
				entries[i] = entry * conversionFactor
			}
		}
	}

	// Composite data can have null values, which can be misinterpreted as values instead of gaps.
	// Instead, those null values should be interpreted as gaps, so we include them in the timing information.
	// NOTE(review): the row indexes are relative to the newly received chunk, while the other gap indexes
	// refer to the merged data. Confirm how this should behave for partial loads.
	private updateCompositeGaps(data: number[] | (number | null)[][], modality: string, timingInformationMap: Map<string, TimingInformation>) {
		if (data.length === 0 || !Array.isArray(data[0])) {
			return
		}

		const previous = timingInformationMap.get(modality)

		if (previous === undefined) {
			return
		}

		const compositeData = data as (number | null)[][]
		const gapIndexes = (previous.gapIndexes ?? []) as number[]
		const compositeGapIndexes: number[][] = range(0, compositeData.length).map(() => [...gapIndexes]) // one for each composite index

		for (let column = 0; column < compositeData.length; column++) {
			for (let row = 0; row < compositeData[column].length; row++) {
				if (compositeData[column][row] === null) {
					compositeGapIndexes[column].push(row)
				}
			}
		}

		timingInformationMap.set(modality, { ...previous, gapIndexes: compositeGapIndexes })
	}

	// Appends a newly decoded chunk to the data that the page already holds for a modality.
	// Returns the old data unchanged when there is nothing new, and converts the samples to Float32Array(s).
	private combineTimeSeriesData(oldData: TraceData | undefined, newData: DecodedTimeSeriesData | undefined): TraceData {
		if (oldData === undefined) {
			oldData = { data: new Float32Array(), times: [] }
		}

		if (newData === undefined || newData.data.length === 0) {
			return oldData
		}

		const isComposite = Array.isArray(newData.data[0])

		if (oldData.data.length === 0) {
			// First chunk: the shape of the new data decides whether the trace is composite.
			if (isComposite) {
				return {
					times: newData.times,
					data: (newData.data as number[][]).map(column => this.mergeData(new Float32Array(), column as (number | null)[])),
				}
			}

			return {
				times: newData.times,
				data: this.mergeData(new Float32Array(), newData.data as (number | null)[]),
			}
		}

		const data: Float32Array | Float32Array[] = isComposite
			? (oldData.data as Float32Array[]).map((oldColumn, i) => this.mergeData(oldColumn, newData.data[i] as (number | null)[]))
			: this.mergeData(oldData.data as Float32Array, newData.data as (number | null)[])

		return { times: [...oldData.times, ...newData.times], data } as TraceData
	}

	// Concatenates one column of old and new samples into a fresh Float32Array. Null samples become NaN.
	// Returns an empty array when either input is missing.
	private mergeData(oldColumn: Float32Array | undefined, newColumn: (number | null)[] | undefined): Float32Array {
		if (!oldColumn || !newColumn) {
			return new Float32Array()
		}

		const combinedArray = new Float32Array(oldColumn.length + newColumn.length)

		// Copy old data
		combinedArray.set(oldColumn)

		// Map new data, replacing null with NaN
		for (let j = 0; j < newColumn.length; j++) {
			combinedArray[oldColumn.length + j] = newColumn[j] === null ? NaN : newColumn[j]!
		}

		return combinedArray
	}

	// Finds the sample indexes at which a gap starts: samples without a timestamp, and samples that
	// are at least two sampling periods after their predecessor.
	private getGapIndexes(traceData: TraceData, initialSamplingPeriod: number | null): number[] {
		const times = traceData.times

		if (times.length === 0) {
			return []
		}

		const gapIndexes: number[] = []

		let samplingPeriod = initialSamplingPeriod
		let previousTime: number | undefined = undefined

		for (let i = 0; i < times.length; i++) {
			const time = times[i]

			if (i > 0) {
				previousTime = times[i - 1]
			}

			if (time === undefined) {
				gapIndexes.push(i)
				continue
			}

			// Without a known sampling period, fall back to the first pair of consecutive timestamps.
			if (samplingPeriod === null && previousTime) {
				samplingPeriod = time - previousTime
			}

			if (samplingPeriod && previousTime && time - previousTime >= 2 * samplingPeriod) {
				gapIndexes.push(i)

				// Set the new sample period. We use the next time because we need to skip over the gap.
				if (times[i+1]) {
					samplingPeriod = times[i+1]! - time // ! lets Typescript know we are sure the value is defined.
				}
			}
		}

		return gapIndexes
	}

	/** Sets (or clears, with `null`) the accessor used to obtain the data query socket. */
	public setSocketAccessor(getDataQuerySocket: null | ((id?: string) => Socket)) {
		this.getDataQuerySocket = getDataQuerySocket
	}

	// The maximum read time is the farthest that the page is allowed to read.
	// This is to prevent reading and caching analytics for data that doesn't exist yet.
	// A page that has received data, but not yet up to its end time, is flagged for a re-request.
	public updateMaxReadTime(newMaxReadTimeMilliseconds: number) {
		this.maxReadTime = newMaxReadTimeMilliseconds

		if (this.lastReceivedPartialLoadEndTime !== null && this.lastReceivedPartialLoadEndTime < this.endTime) {
			this.needsToReRequestData = true
		}
	}
}
