#!/usr/bin/env python3

import fileinput
import os
import sys
import csv
import datetime
import math
import pandas
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.dates as mdates

from multiprocessing import Pool, TimeoutError

import lzma

import encoder
import analyzer
import grapher


# Global switch for diagnostic output. When False, _debugprint() returns
# without printing anything and the periodic "Day N" progress messages in
# _apdus_for_period_chunked() are suppressed. Toggle by swapping which of the
# two lines below is commented out.
DEBUG=True
#DEBUG=False


##############################################################################
##############################################################################
###                                                                        ###
###  GENERIC UTILITY FUNCTIONS                                             ###
###                                                                        ###
##############################################################################
##############################################################################

def _debugprint(dataframe, types=True, data=False, memory=False):
    """Print diagnostic information about *dataframe* when DEBUG is enabled.

    All output is produced inside a temporary pandas display context with
    two-digit precision, non-sparsified multi-index labels and no row limit.

    :param dataframe: the pandas object to inspect.
    :param types: print the column dtypes.
    :param data: print the complete contents.
    :param memory: print the deep memory footprint in MiB.
    """
    if DEBUG:
        display_options = ('display.precision', 2,
                           'display.multi_sparse', False,
                           'display.max_rows', None)
        with pandas.option_context(*display_options):
            if types:
                print(dataframe.dtypes)
            if data:
                print(dataframe)
            if memory:
                mebibytes = dataframe.memory_usage(deep=True).sum() / 1024 / 1024
                print("Memory usage: {} MiB".format(mebibytes))


##############################################################################
##############################################################################
###                                                                        ###
###  LOADING THE ZONNEDAEL DATASET & GENERATING APDUS                      ###
###                                                                        ###
##############################################################################
##############################################################################

def _compress(apdu):
    """Compress an APDU as a raw LZMA2 stream (no container, no integrity check).

    :param apdu: the encoded APDU as a bytes-like object, or a missing value
        (None/NaN).
    :returns: the compressed bytes, or None when *apdu* is missing.
    """
    if not pandas.notnull(apdu):
        return None
    raw_lzma2 = [{"id": lzma.FILTER_LZMA2}]
    return lzma.compress(apdu,
                         format=lzma.FORMAT_RAW,
                         check=lzma.CHECK_NONE,
                         filters=raw_lzma2)


def _measurements_to_apdus(measurements, apdus):
    """Encode one window of measurements into APDUs for every customer.

    Every customer (level 0 of ``measurements.columns``) is encoded with each
    variant listed in ``encodings`` below. Both the raw APDU and its
    LZMA-compressed form (``_compress``) are stored in the row of ``apdus``
    labelled with the window's first timestamp. The window's summed
    "measurements" column is stored under ``(klant, "power_use", "", "", "")``.

    :param measurements: the rows of one window; columns are a
        (klant, quantity) MultiIndex containing a "measurements" quantity.
    :param apdus: result frame, updated in place; missing rows and columns are
        created by the assignments.
    :returns: ``apdus``. An empty window is returned unchanged instead of
        raising ``IndexError``.
    """
    # (column family, column variant) -> keyword arguments for generate_apdu.
    encodings = (
        (("normal", "full"), {}),
        (("normal", "nulldate"), {"nullcoding": "dates"}),
        (("normal", "nullall"), {"nullcoding": "all"}),
        (("deltacoded", "min"), {"nullcoding": "dates", "deltacoding": "min"}),
        (("deltacoded", "8"), {"nullcoding": "dates", "deltacoding": 8}),
        (("deltacoded", "16"), {"nullcoding": "dates", "deltacoding": 16}),
        (("deltacoded", "16compact"),
         {"nullcoding": "dates", "deltacoding": 16, "compactarray": True}),
        (("deltacoded", "32"), {"nullcoding": "dates", "deltacoding": 32}),
    )

    if measurements.empty:
        # There is no first timestamp to key the result row on.
        return apdus

    # Named ``timestamp`` rather than ``datetime`` so the file-level
    # ``datetime`` module is not shadowed.
    timestamp = measurements.index[0]

    # Iterate the customers in sorted order, the same order a level-0 groupby
    # would produce, without the deprecated groupby(axis='columns').
    for klant in sorted(measurements.columns.get_level_values(0).unique()):
        # NOTE(review): the same selection is passed to every encoder call;
        # this assumes encoder.generate_apdu does not modify its argument.
        customer = measurements[klant]
        for (family, variant), options in encodings:
            apdu = encoder.generate_apdu(customer, **options)
            apdus.at[timestamp, (klant, family, variant, "uncompressed", "apdu")] = apdu
            apdus.at[timestamp, (klant, family, variant, "compressed", "apdu")] = _compress(apdu)

        apdus.at[timestamp, (klant, "power_use", "", "", "")] = customer["measurements"].sum()

    return apdus


def _apdus_for_period_chunked(dataframe, period, startday=1, numdays=365):
    """Encode every ``period``-hour window of a range of days into APDUs.

    Processes days of the year ``startday`` .. ``startday + numdays - 1``,
    never past day 365 (the 2013 dataset has no leap day). Within each day it
    processes the hour windows ``[0, period)``, ``[period, 2*period)``, ...
    that end at or before hour 24. Each non-empty window is encoded by
    ``_measurements_to_apdus``. Windows without data are skipped.

    :param dataframe: measurements indexed by timestamp, with
        (klant, quantity) MultiIndex columns.
    :param period: window length in hours. Hours left over when ``period``
        does not divide 24 are not covered.
    :param startday: first day of the year (1-based) to process.
    :param numdays: number of days to process.
    :returns: DataFrame whose rows are the encoded windows.
    """
    encodings = (("normal", "full"),
                 ("normal", "nulldate"),
                 ("normal", "nullall"),
                 ("deltacoded", "min"),
                 ("deltacoded", "8"),
                 ("deltacoded", "16"),
                 ("deltacoded", "16compact"),
                 ("deltacoded", "32"))

    # Unfortunately building the multi-index beforehand is the only way I've
    # come up with to ensure correct dtyping for all columns that might start
    # out with a None. The power_use columns are not pre-created here.
    customers = sorted(dataframe.columns.get_level_values(0).unique())
    mi_tuples = [(klant, family, variant, packing, "apdu")
                 for klant in customers
                 for family, variant in encodings
                 for packing in ("uncompressed", "compressed")]
    columns = pandas.MultiIndex.from_tuples(mi_tuples)
    apdus = pandas.DataFrame(index=pandas.to_datetime([]), columns=columns, dtype=object)

    # We could have done chunking by simply passing individual days to this
    # method and using 365 chunks. However, this increases the copy-overhead
    # drastically both for dispatch and for consolidation, so let's leave it
    # as-is for now.
    hours_per_day = 24
    lastday = min(startday + numdays, 366)
    for day in range(startday, lastday):
        day_measurements = dataframe.loc[dataframe.index.dayofyear == day]
        hours = day_measurements.index.hour

        for lowhour in range(0, hours_per_day + 1 - period, period):
            highhour = lowhour + period
            window = day_measurements.loc[(hours >= lowhour) & (hours < highhour)]
            if window.empty:
                # A gap in the data: there is nothing to encode and no
                # timestamp to key a result row on.
                continue
            apdus = _measurements_to_apdus(window, apdus)

        if DEBUG and day % 10 == 1:
            _debugprint(apdus, types=False, memory=True)
            print("Day {}".format(day))

    return apdus


def _apdus_for_period(dataframe, period, numchunks=16):
    """Encode the whole year into APDUs, processing day ranges in parallel.

    The 365 days are split into at most ``numchunks`` contiguous ranges of
    ``ceil(365 / numchunks)`` days. Each range is handled by
    ``_apdus_for_period_chunked`` in a pool of ``numchunks`` worker processes,
    and the partial results are concatenated in day order.

    :param dataframe: measurements indexed by timestamp with
        (klant, quantity) MultiIndex columns.
    :param period: window length in hours.
    :param numchunks: number of worker processes and upper bound on the
        number of day ranges (default 16).
    :returns: the concatenated APDU DataFrame.
    :raises ValueError: if ``numchunks`` is smaller than 1.
    """
    if numchunks < 1:
        raise ValueError("numchunks must be at least 1, got {}".format(numchunks))

    numdays = math.ceil(365 / numchunks)
    # Pool tasks are pickled on their way to the workers, so every worker
    # already gets its own copy of the frame; copying it here as well would
    # only hold numchunks extra copies in the parent process.
    jobs = [(dataframe, period, startday, numdays)
            for startday in range(1, 366, numdays)]

    with Pool(processes=numchunks) as pool:
        partial_results = pool.starmap(_apdus_for_period_chunked, jobs)
        pool.close()
        pool.join()

    # starmap returns results in job order, so concatenation keeps days sorted.
    apdus = pandas.concat(partial_results)
    _debugprint(apdus, data=True)
    return apdus


def _generate_zonnedael_apdus(dataframe, period=24):
    """Generate all APDU variants for *dataframe* and add their lengths.

    Next to every ``(klant, family, variant, packing, "apdu")`` column, the
    result has a matching ``(klant, family, variant, packing, "length")``
    column. It holds ``len()`` of that APDU, or a missing value where the APDU
    is null. The columns are sorted with the customer as the top level.

    :param dataframe: measurement profiles, passed to ``_apdus_for_period``.
    :param period: window length in hours.
    :returns: the APDU DataFrame extended with length columns.
    """
    def apdu_length(apdu):
        return len(apdu) if pandas.notnull(apdu) else None

    generated = _apdus_for_period(dataframe, period)

    # Temporarily move the innermost "apdu" marker level to the top so that
    # all APDU columns can be selected at once. The power_use columns carry
    # an empty marker there and are therefore left out.
    by_marker = generated.swaplevel(0, 4, axis='columns').sort_index(level=0, axis='columns')
    lengths = by_marker["apdu"].applymap(apdu_length)
    labelled_lengths = pandas.concat({"length" : lengths}, axis='columns')
    by_marker = pandas.concat([by_marker, labelled_lengths], axis='columns')

    # Restore the customer as the top level.
    return by_marker.swaplevel(0, 4, axis='columns').sort_index(level=0, axis='columns')


def _read_zonnedael_profiles(csvinput):
    """Read the Zonnedael smart-meter CSV into a per-customer profile frame.

    The CSV is semicolon separated, has day-first timestamps in its first
    column and uses ``#WAARDE!`` for missing values. The following are
    dropped: rows without a timestamp, and the aggregate columns "SOM",
    "leverende klanten" and "niet leverenden".

    :param csvinput: path to the CSV file.
    :returns: DataFrame indexed by timestamp. Its columns are a
        (klant, quantity) MultiIndex with the quantities "measurements" (the
        raw values) and "cumsum" (running total of the measurements), sorted
        by customer and then by quantity.
    """
    with open(csvinput) as zonnedael_input:
        raw = pandas.read_csv(zonnedael_input,
                              sep=";",
                              header=0,
                              index_col=0,
                              parse_dates=[0],
                              dayfirst=True,  # timestamps are day-first; pandas defaults to month-first
                              na_values="#WAARDE!")
    raw = raw[raw.index.notnull()]  # filter superfluous rows without time index
    measurements = raw.drop(columns=["SOM", "leverende klanten", "niet leverenden"])  # filter uninteresting columns

    # Compute all running totals in one vectorised call and combine once.
    # This avoids inserting one column per customer, which fragments the frame
    # and relied on the deprecated groupby(axis='columns').
    profiles = pandas.concat({"measurements": measurements,
                              "cumsum": measurements.cumsum()},
                             axis='columns')
    profiles = profiles.swaplevel(0, 1, axis='columns')
    return profiles.sort_index(level=0, axis='columns')


def load_zonnedael(filebase="./meterdata", periods=(24,)):
    """Load the Zonnedael profiles and the APDUs generated from them.

    The measurements are always read from the CSV in
    ``<filebase>/measurements``. APDUs are cached per period in
    ``<filebase>/zonnedael.pickle.<period>``. Periods with a cache file are
    loaded from it; the others are generated and then written to their cache
    file.

    :param filebase: directory holding the measurements and the APDU cache.
    :param periods: window lengths in hours to provide APDUs for. Smaller
        chunks (e.g. 12, 8, 6, 4, 3, 2, 1) are future work.
    :returns: tuple ``(profiles, apdus_per_period)``, where the latter maps
        each period to its APDU DataFrame.
    """
    csvinput = os.path.join(filebase, "measurements",
                            "Zonnedael - slimme meter dataset - 2013 - Levering.csv")
    profiles = _read_zonnedael_profiles(csvinput)

    def picklefile(period):
        return os.path.join(filebase, "zonnedael.pickle.{}".format(period))

    apdus_per_period = {}
    missing_periods = []
    for period in periods:
        try:
            # NOTE: unpickling can execute arbitrary code; the cache must be
            # trusted output previously written by this function.
            apdus_per_period[period] = pandas.read_pickle(picklefile(period))
        except FileNotFoundError:
            missing_periods.append(period)

    for period in missing_periods:
        apdus = _generate_zonnedael_apdus(profiles, period=period)
        apdus_per_period[period] = apdus
        apdus.to_pickle(picklefile(period))

    return profiles, apdus_per_period


##############################################################################
##############################################################################
###                                                                        ###
###  ANALYSIS FOR THE ZONNEDAEL DATASET                                    ###
###                                                                        ###
##############################################################################
##############################################################################


"""
For every period in the dataset, combine the data of all customers and check
for correlation between power use and length.
"""
def analyze_zonnedael_total_set(apdus_per_period):
    for period in apdus_per_period:
        print("\n\n\n==================== {}-hourly periods: =================================\n\n".format(period))
        apdus = apdus_per_period[period]
        _debugprint(apdus)
        collapsed_apdus = analyzer.collapse_dataframe(apdus)
        _debugprint(collapsed_apdus)

        print(analyzer.correlate(collapsed_apdus))
        grapher.scatterplot_total_set(collapsed_apdus)


def analyze_zonnedael_per_customer(profiles, apdus_per_period,
                                   customers=(17, 46), period=24):
    """Print statistics and draw plots for individual Zonnedael customers.

    For every selected customer this prints:

    - the correlation between power use and APDU length,
    - summary statistics,
    - the absence indices found for several encodings.

    It then draws overview, scatter, line, weekday and day-distribution plots
    through ``grapher``.

    :param profiles: measurement profiles, passed on to the plotting functions.
    :param apdus_per_period: mapping of period (hours) to APDU DataFrame.
    :param customers: customer numbers to analyse, e.g. ``17`` for
        "Klant 17". An empty sequence (or None) selects every customer.
    :param period: which entry of ``apdus_per_period`` to analyse.
    """
    # Encodings whose APDU lengths are examined one by one.
    encodings = (("deltacoded", "min", "uncompressed"),
                 ("deltacoded", "min", "compressed"),
                 ("deltacoded", "16compact", "compressed"),
                 ("deltacoded", "16", "compressed"),
                 ("normal", "full", "compressed"),
                 ("normal", "nullall", "compressed"))

    apdus = apdus_per_period[period]
    if customers is not None and len(customers) > 0:
        selected = {"Klant {}".format(x) for x in customers}
    else:
        selected = None

    # Sorted level-0 values match the order of a level-0 groupby, without the
    # deprecated groupby(axis='columns').
    for klant in sorted(apdus.columns.get_level_values(0).unique()):
        if selected is not None and klant not in selected:
            continue

        print("\n\n\n= {} ===================================================\n".format(klant.upper()))
        data = apdus[klant]
        _debugprint(data)
        correlations = analyzer.correlate(data)
        print("Correlations between power use and APDU length")
        print(correlations)

        print(data.describe())
        # The APDU columns hold bytes. Since pandas 2.0, mean()/median() no
        # longer drop such columns silently and raise TypeError instead, so
        # restrict the reductions to the numeric columns.
        print(data.mean(numeric_only=True))
        print(data.median(numeric_only=True))

        for ytype in encodings:
            print("Absence indices in encoding {}".format(ytype))
            absencedates = analyzer.find_absences(data[ytype + ("length",)])
            print(absencedates)

        # Overview of the customer
        grapher.overviewplot(data, profiles, klant)

        # Scatterplot the compressed APDUs separately, then combined into one plot.
        grapher.scatterplot_customer_set(data, klant)
        grapher.scatterplot_customer_set(data, klant, combined=True)

        # Lineplot the APDU lengths and power use for the entire year, then per encoding.
        grapher.lineplot_customer_year(data, klant)
        for ytype in encodings:
            grapher.lineplot_customer_year(data, klant, [ytype])

        # Show the five least and the five most compressible days as line plots.
        by_length = data.sort_values(by=[("deltacoded", "min", "compressed", "length")])
        for date in list(by_length.head(5).index) + list(by_length.tail(5).index):
            grapher.lineplot_power_consumption_day(profiles, klant, date)

        # WIP, for future work.
        grapher.weekdayplots(data, klant)
        grapher.day_distribution_plot(data, klant)


if __name__ == "__main__":
    # Load the profiles and APDUs (generating and caching missing APDUs),
    # then run the per-customer analysis.
    zonnedael_profiles, zonnedael_apdus = load_zonnedael()
    # analyze_zonnedael_total_set(zonnedael_apdus)
    analyze_zonnedael_per_customer(zonnedael_profiles, zonnedael_apdus)

