import { Socket } from "socket.io-client"
import { ModalityDataSource } from "../Types/ModalityDataSource"
import { PageType } from "../Types/PageType"
import { DecodedPixels, SocketResponse } from "../Types/SocketResponse"
import { decode } from "@msgpack/msgpack"
import { medianDifference } from "../Components/Visualizations/math"
import { range } from "lodash"

// Monotonic counter used to hand every Page instance a unique `id`.
let pageId = 0

// Properties that `Page.updateProperties` copies onto a page via `Object.assign`.
export type ModalityPageProperties = {
	// Passed to `getDataQuerySocket` to select the socket used for data requests.
	socketId: string
	// Start/end of the page's time range. `timeZoneOffsetMs` is added to both before they are sent to the server.
	startTime: number
	endTime: number
	// Sent with every data request; `Page.load` does not request anything while this is 0.
	// NOTE(review): presumably the render width in pixels — confirm.
	width: number
	// Sent with every data request.
	patientId: string
	// Added to outgoing start/end times and subtracted from the timestamps that come back.
	timeZoneOffsetMs: number
	// The (dataObjectId, modality) pairs this page should hold data for.
	modalityDataSources: ModalityDataSource[]
	// Forwarded to the server as the last argument of the data request.
	patientIsAdmitted: boolean
	// Returns the socket to query data on; `null` means no data can be requested.
	getDataQuerySocket: null | ((id?: string) => Socket)
}

// Edge tuples stored alongside a cached render (see `CachedRender.edges`).
// NOTE(review): the meaning of the individual tuple elements is not visible in this file — confirm against the renderers.
type SingleModalityEdges = [number | undefined, number][]
type MultipleModalityEdges = [number | undefined, number, number][]

// One entry of `Page.renderCache`.
type CachedRender = {
	// Rendered image. The page calls `close()` on it when the entry is replaced or the cache is cleared.
	bitmap: ImageBitmap
	dirty: boolean // If we need to completely redraw using the original data
	// Edge tuples kept with the bitmap; single- or multi-modality form.
	edges: SingleModalityEdges | MultipleModalityEdges
}

// Timing metadata the page derives for each received modality.
export type TimingInformation = {
	// Indexes into the sample array that mark gaps: the offset reached after each decoded timestamp chunk.
	// For composite (nested) data this becomes one list per series, which additionally holds the row indexes of null samples.
	gapIndexes: number[] | number[][]
	// Median difference between the defined timestamps, or null when there are fewer than two of them.
	samplingPeriod: number | null
}

/**
 * Base class for a page of modality data that is requested from the server over a socket.
 *
 * A page tracks, per data object id and modality, whether data has been requested and received,
 * stores the decoded samples together with their timestamps and timing metadata, and owns a cache
 * of rendered bitmaps. Subclasses supply `getType()` and choose the concrete `DataType`.
 */
export abstract class Page<DataType, PageProperties extends ModalityPageProperties=ModalityPageProperties> {
	public socketId: string = "?"
	public id: number = pageId++ // Sent with every request and matched against `page_id` on responses.
	public index: number = 0
	public startTime: number = 0
	public endTime: number = 0
	public width: number = 0
	public patientId: string = "?"
	public timeZoneOffsetMs: number = 0
	public modalityDataSources: ModalityDataSource[] = []

	// TODO: Maybe these all belong in their own big object that can be accessed at once, ModalityMetadata or something like that.
	public data = new Map<number, Map<string, DataType>>() // Data Object Id, then modality.
	public renderCache = new Map<number, Map<string, CachedRender>>() // Data Object Id, then modality.
	public timingInformation = new Map<number, Map<string, TimingInformation>>() // Data Object Id, then modality

	public lastReceivedData: number | null = null // the timestamp when new data was fetched

	protected patientIsAdmitted: boolean = true
	protected getDataQuerySocket: null | ((id?: string) => Socket) = null

	// Data Object ID : ["ABP", "ICP", ...]
	// When working with multiple data sources, the same "modality" can be loaded from multiple sources.
	// Both maps are keyed by data object id first and by the bare modality name second.
	protected dataObjectsLoaded = new Map<number, Map<string, boolean>>()
	protected dataObjectsLoading = new Map<number, Map<string, boolean>>()

	abstract getType(): PageType

	/** True when every configured data source is flagged as loaded AND has data stored for it. */
	get loaded() {
		return this.modalityDataSources.every(({ dataObjectId, modality }) => (
			this.dataObjectsLoaded.get(dataObjectId)?.get(modality) === true
			&& this.data.get(dataObjectId)?.get(modality) !== undefined
		))
	}

	/** True when at least one configured data source is currently flagged as loading. */
	get loading() {
		return this.modalityDataSources.some(({ dataObjectId, modality }) => (
			this.dataObjectsLoading.get(dataObjectId)?.get(modality) === true
		))
	}

	/**
	 * Applies new page properties.
	 *
	 * State and data of data sources that are no longer listed are dropped, the properties are copied
	 * onto the page, and every newly listed data source that has neither data nor a request in flight
	 * gets its loaded/loading flags initialised to `false`.
	 *
	 * @param properties The new property values; all of them are assigned onto `this`.
	 */
	updateProperties(properties: PageProperties) {
		// Clear out data sources that are no longer needed
		const newDataSources = new Set(properties.modalityDataSources.map(dataSource => `${dataSource.dataObjectId}-${dataSource.modality}`))

		this.modalityDataSources.forEach(({ dataObjectId, modality }) => {
			if (!newDataSources.has(`${dataObjectId}-${modality}`)) {
				this.dataObjectsLoaded.get(dataObjectId)?.delete(modality)
				this.dataObjectsLoading.get(dataObjectId)?.delete(modality)
				this.data.get(dataObjectId)?.delete(modality)
			}
		})

		// Update properties
		Object.assign(this, properties)

		// Update the loading or loaded state for each new data source
		properties.modalityDataSources.forEach(({ dataObjectId, modality }) => {
			// Each data source can be partially loaded, by modality.
			// Modalities that already have data, or that are being requested, keep their current state.
			const isModalityLoading = this.dataObjectsLoading.get(dataObjectId)?.get(modality) === true
			const isModalityLoaded = this.data.get(dataObjectId)?.get(modality) !== undefined

			if (isModalityLoaded || isModalityLoading) {
				return
			}

			// Create the inner maps on first use, then start tracking this modality as "not loaded, not loading".
			const loaded = this.dataObjectsLoaded.get(dataObjectId) ?? new Map<string, boolean>()
			const loading = this.dataObjectsLoading.get(dataObjectId) ?? new Map<string, boolean>()

			this.dataObjectsLoaded.set(dataObjectId, loaded.set(modality, false))
			this.dataObjectsLoading.set(dataObjectId, loading.set(modality, false))
		})
	}

	// Name of the socket event used both to request data and to receive the response.
	socketEventName = () => "render_modalities"

	/**
	 * Emits one data request per data object id. Does nothing when no socket accessor is set.
	 *
	 * @param batchedDataSources Data object id -> the data sources to request for that object.
	 */
	requestData(batchedDataSources: Map<number, ModalityDataSource[]>) {
		if (this.getDataQuerySocket === null) {
			return
		}

		const socket = this.getDataQuerySocket(this.socketId)

		batchedDataSources.forEach((dataSources, dataObjectId) => {
			socket.emit(
				this.socketEventName(),
				this.patientId,
				dataObjectId,
				this.id,
				dataSources,
				this.startTime + this.timeZoneOffsetMs,
				this.endTime + this.timeZoneOffsetMs,
				this.width,
				this.patientIsAdmitted,
			)
		})
	}

	/**
	 * Decodes one socket response, stores its data and timing information, and settles the load promise.
	 *
	 * The promise is resolved with this page when every tracked modality of every data object is loaded
	 * after this response; otherwise it is rejected with "not all data objects loaded!".
	 *
	 * @param response The raw response; `pixels` holds the msgpack-encoded payload.
	 * @param resolve Resolver of the promise returned by `load()`.
	 * @param reject Rejecter of the promise returned by `load()`.
	 */
	receiveSocketResponse(response: SocketResponse, resolve: (value: Page<DataType> | PromiseLike<Page<DataType>>) => void, reject: (reason?: any) => void) {
		// The pixels are encoded in an Array Buffer and we have to reconstruct the JSON object by decoding it.
		// This operation is very fast.
		const decodedPixels = decode(response.pixels) as DecodedPixels

		// Everything that came back is no longer loading, and counts as loaded.
		Object.keys(decodedPixels).forEach(modality => {
			const loading = this.dataObjectsLoading.get(response.data_object_id) ?? new Map<string, boolean>()
			const loaded = this.dataObjectsLoaded.get(response.data_object_id) ?? new Map<string, boolean>()

			this.dataObjectsLoading.set(response.data_object_id, loading.set(modality, false))
			this.dataObjectsLoaded.set(response.data_object_id, loaded.set(modality, true))
		})

		// Get or create the per-data-object maps. Doing it this way leaves no code path that
		// could return without settling the promise.
		const dataObjectDataMap = this.data.get(response.data_object_id) ?? new Map<string, DataType>()
		const timingInformationMap = this.timingInformation.get(response.data_object_id) ?? new Map<string, TimingInformation>()

		this.data.set(response.data_object_id, dataObjectDataMap)
		this.timingInformation.set(response.data_object_id, timingInformationMap)

		// Parse the response
		Object.entries(decodedPixels).forEach(([modality, [data, compressedTimes]]) => {
			// Sometimes "data" is undefined
			if (!data) {
				dataObjectDataMap.set(modality, { data: new Float32Array(), times: [] } as unknown as DataType)
				return
			}

			// In the case where we are receiving composite data, we have to handle the data length accordingly.
			const dataLength = data.length > 0 && Array.isArray(data[0]) ? data[0].length : data.length
			const [times, chunkEndIndexes] = this.extractTimes(compressedTimes, dataLength, this.timeZoneOffsetMs)

			// A single sample cannot contain a gap.
			const gapIndexes = dataLength === 1 ? [] : chunkEndIndexes

			this.handleDataConversionInPlace(data, response.conversion_factor)

			const definedTimes: number[] = times.filter(time => time !== undefined) as number[]
			const samplingPeriod = definedTimes.length > 1 ? medianDifference(definedTimes) : null

			dataObjectDataMap.set(modality, { data, times } as unknown as DataType)
			timingInformationMap.set(modality, { gapIndexes, samplingPeriod })

			this.calculateCompositeGaps(data, modality, timingInformationMap)
		})

		// The page only counts as loaded when every tracked modality is flagged loaded and has data.
		let allDataObjectsLoaded = true

		this.dataObjectsLoaded.forEach((modalityMap, dataObjectId) => {
			modalityMap.forEach((isLoaded, modality) => {
				if (!isLoaded || !this.data.get(dataObjectId)?.has(modality)) {
					allDataObjectsLoaded = false
				}
			})
		})

		if (allDataObjectsLoaded) {
			this.lastReceivedData = Date.now()
			resolve(this)
		} else {
			reject("not all data objects loaded!")
		}
	}

	/** Drops all data and load state, and releases every cached bitmap. */
	public unload = () => {
		this.data.clear()
		this.dataObjectsLoaded.clear()
		this.dataObjectsLoading.clear()
		this.clearRenderCache()
	}

	/** Drops all data and resets the load flags of the configured data sources, but keeps the render cache. */
	public unloadWithoutClearingCache = () => {
		this.data.clear()

		this.modalityDataSources.forEach(({ dataObjectId, modality }) => {
			this.dataObjectsLoaded.get(dataObjectId)?.set(modality, false)
			this.dataObjectsLoading.get(dataObjectId)?.set(modality, false)
		})
	}

	/** Closes every cached bitmap and empties the render cache. */
	public clearRenderCache = () => {
		this.renderCache.forEach(dataObjectCache => {
			dataObjectCache.forEach(cache => cache.bitmap.close())
			dataObjectCache.clear()
		})
		this.renderCache.clear()
	}

	/**
	 * Stores a cached render, closing the bitmap of the entry it replaces (if any).
	 *
	 * @param dataObjectId Data object the render belongs to.
	 * @param key Cache key within that data object.
	 * @param newValue The render to store.
	 */
	public updateRenderCache = (dataObjectId: number, key: string, newValue: CachedRender) => {
		const dataObjectCache = this.renderCache.get(dataObjectId) ?? new Map<string, CachedRender>()
		this.renderCache.set(dataObjectId, dataObjectCache)

		dataObjectCache.get(key)?.bitmap.close()
		dataObjectCache.set(key, newValue)
	}

	/**
	 * Requests the data of every data source that is not loaded yet.
	 *
	 * Resolves immediately with this page when it is already loaded, has a width of 0, or has no
	 * socket accessor. Otherwise the returned promise is settled by `receiveSocketResponse` on the
	 * first matching response, or rejected with "Page load timed out." when no matching response
	 * arrives within three seconds.
	 */
	public load = () => {
		return new Promise<Page<DataType>>((resolve, reject) => {
			if (this.loaded || this.width === 0 || this.getDataQuerySocket === null) {
				resolve(this)
				return
			}

			// Data Object ID: [{ modality: "ABP"}, { modality: "ICP" }, { modality: "PRx", onDemandAnalysis: {...} } ...]
			// Making one socket call with all of the modalities is more efficient than sending each one separately.
			// e.g. reading 7 modalities only requires setting everything up one time instead of 7 times.
			const batchedModalitiesByDataObjectId = new Map<number, ModalityDataSource[]>()

			// Skip what is already loaded; this is the same per-source test the `loaded` getter uses, and the
			// state maps are keyed by the bare modality name. Because `loaded` was false above, this list is
			// never empty. Sources that are flagged as loading are requested again on purpose: skipping them
			// could leave this call without any listener (so its promise would never settle), and a response
			// that got dropped would keep the source flagged as loading forever.
			const dataSourcesToLoad = this.modalityDataSources.filter(({ dataObjectId, modality }) => (
				this.dataObjectsLoaded.get(dataObjectId)?.get(modality) !== true
				|| this.data.get(dataObjectId)?.get(modality) === undefined
			))

			const socketEventName = this.socketEventName()

			dataSourcesToLoad.forEach(dataSource => {
				// Add to the list of modalities
				const existingList = batchedModalitiesByDataObjectId.get(dataSource.dataObjectId) ?? []
				batchedModalitiesByDataObjectId.set(dataSource.dataObjectId, [...existingList, dataSource])

				// Update the loading state
				const existingLoadingModalities = this.dataObjectsLoading.get(dataSource.dataObjectId) ?? new Map<string, boolean>()
				this.dataObjectsLoading.set(dataSource.dataObjectId, existingLoadingModalities.set(dataSource.modality, true))
			})

			const socket = this.getDataQuerySocket(this.socketId)

			// Prevents the Promise from hanging indefinitely. Sometimes messages get dropped, but it's not a problem.
			// The timer has to be armed here, when the request goes out: a timer that is only started once a
			// response has arrived can never catch a response that does not arrive.
			const loadTimeoutMs = 3000
			const loadTimeout = setTimeout(() => {
				// Give up on this attempt: the sources it requested are no longer considered in flight.
				dataSourcesToLoad.forEach(({ dataObjectId, modality }) => {
					const loadingModalities = this.dataObjectsLoading.get(dataObjectId)

					if (loadingModalities?.get(modality) === true) {
						loadingModalities.set(modality, false)
					}
				})

				// The listeners stay registered so that a response that is merely late is still stored.
				reject("Page load timed out.")
			}, loadTimeoutMs)

			batchedModalitiesByDataObjectId.forEach((_, dataObjectId) => {
				const listener = (socketResponse: SocketResponse) => {
					if (socketResponse.page_id === this.id && socketResponse.data_object_id === dataObjectId) {
						this.receiveSocketResponse(socketResponse, resolve, reject)
						// Only reached when the response was processed (and the promise settled) without throwing;
						// if processing throws, the timeout above still rejects the promise.
						clearTimeout(loadTimeout)
						socket.off(socketEventName, listener)
					}
				}

				socket.on(socketEventName, listener)
			})

			this.requestData(batchedModalitiesByDataObjectId)
		})
	}

	/**
	 * Timestamps are compressed into continuous chunks of timestamps to save bits.
	 * This function de-compresses them back into a list of the actual timestamp values.
	 *
	 * @param compressedTimes Chunks of `[startTime, endTime, numberOfPoints]`; chunks with a missing element are skipped.
	 * @param dataLength Number of samples; the result starts out as this many zeros.
	 * @param timeZoneOffsetMs Subtracted from every timestamp after the chunk bounds are divided by 1000.
	 * @returns The timestamps, and the offset reached after each chunk (used as gap indexes).
	 */
	private extractTimes(compressedTimes: number[][], dataLength: number, timeZoneOffsetMs: number): [Array<number | undefined>, Array<number>] {
		const times = new Array<number | undefined>(dataLength).fill(0)
		const gapIndexes: number[] = []
		let timesOffset = 0

		for (const [startTime, endTime, numberOfPoints] of compressedTimes) {
			if (startTime === undefined || endTime === undefined || numberOfPoints === undefined) {
				continue
			}

			// NOTE(review): presumably converts microseconds to milliseconds — confirm against the server.
			const start = startTime / 1e3 - timeZoneOffsetMs
			const end = endTime / 1e3 - timeZoneOffsetMs

			// A chunk with a single point has no interval to divide. Without this guard the division
			// by zero makes the step NaN or Infinity, and `start + 0 * step` becomes NaN.
			const step = numberOfPoints > 1 ? (end - start) / (numberOfPoints - 1) : 0

			for (let i = 0; i < numberOfPoints; i++) {
				times[timesOffset++] = start + i * step
			}

			gapIndexes.push(timesOffset)
		}

		return [times, gapIndexes]
	}

	/**
	 * For data compression purposes, data can come over the network as integers, even if the data is floating point.
	 * This offers a significant performance gain because we can use fewer bits to represent approximately the same number.
	 * The conversion factor is the number that we multiply by to get the floating point number out of the integer.
	 *
	 * EEG pages are scaled as a list of series, every other page type as one flat series.
	 * Null samples are left untouched: multiplying would turn them into 0, and `calculateCompositeGaps`
	 * (which runs after this) relies on them still being null to record a gap.
	 *
	 * @param data The samples to scale in place.
	 * @param conversionFactor Multiplier applied to every sample; 1 means there is nothing to do.
	 */
	private handleDataConversionInPlace(data: Float32Array | Float32Array[], conversionFactor: number) {
		if (conversionFactor === 1) {
			return
		}

		switch (this.getType()) {
			case PageType.EEG: {
				const eegData = data as Float32Array[]

				for (let i = 0; i < eegData.length; i++) {
					for (let j = 0; j < eegData[i].length; j++) {
						if (eegData[i][j] !== null) {
							eegData[i][j] *= conversionFactor
						}
					}
				}

				break
			}
			default: {
				const timeSeriesData = data as Float32Array

				for (let i = 0; i < timeSeriesData.length; i++) {
					if (timeSeriesData[i] !== null) {
						timeSeriesData[i] *= conversionFactor
					}
				}
			}
		}
	}

	/**
	 * Composite data can have null values, which can be misinterpreted as values instead of gaps.
	 * Instead, those null values should be interpreted as gaps, so we include them in the timing information.
	 *
	 * For nested data, the modality's flat gap index list is replaced with one list per series: a copy of
	 * the existing indexes followed by the row index of every null sample in that series.
	 *
	 * @param data The samples of the modality; nothing happens unless its first element is an array.
	 * @param modality The modality whose timing information is updated.
	 * @param timingInformationMap Timing information of the data object, updated in place.
	 */
	private calculateCompositeGaps(data: Float32Array | Float32Array[], modality: string, timingInformationMap: Map<string, TimingInformation>) {
		if (data.length === 0) {
			return
		}

		const isComposite = Array.isArray(data[0])

		if (!isComposite || !timingInformationMap.has(modality)) {
			return
		}

		const compositeData = data as Float32Array[]
		const gapIndexes: number[] = (timingInformationMap.get(modality)?.gapIndexes ?? []) as number[]
		const compositeGapIndexes: number[][] = range(0, compositeData.length).map(() => [...gapIndexes]) // one for each composite index

		for (let column = 0; column < compositeData.length; column++) {
			for (let row = 0; row < compositeData[column].length; row++) {
				if (compositeData[column][row] === null) {
					compositeGapIndexes[column].push(row)
				}
			}
		}

		const previous = timingInformationMap.get(modality) as TimingInformation
		timingInformationMap.set(modality, {...previous, gapIndexes: compositeGapIndexes})
	}

	/** Replaces the accessor used to obtain the data query socket; `null` disables data requests. */
	public setSocketAccessor(getDataQuerySocket: null | ((id?: string) => Socket)) {
		this.getDataQuerySocket = getDataQuerySocket
	}
}
