import itertools

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.dates as mdates

import pandas
import analyzer


def _colorblind_colors(numcolors=9):
    return itertools.cycle(['#377eb8', '#ff7f00', '#4daf4a',
                            '#f781bf', '#a65628', '#984ea3',
                            '#999999', '#e41a1c', '#dede00',
                            '#000000'][:numcolors])


def _graphlabel(encoding, tweak, compression):
    label = ""
    if encoding == "normal":
        if tweak == "full":
            label = "Baseline"
        elif tweak == "nullall":
            label = "NULL Coding"
    elif encoding == "deltacoded":
        if tweak == "min":
            label = "Min-length Delta Coding"
        elif tweak == "16compact":
            label = "16-bit Compact Delta Coding"
        else:
            label = tweak + "-bit Delta Coding"

    label += ", " + compression
    return label


"""
Assumes the APDUs passed in have already been combined into a single
set by an earlier call to analyzer.collapse_dataframe.
"""
def scatterplot_total_set(apdus,
                          encodings=["normal", "deltacoded"],
                          compressions=["uncompressed", "compressed"]):
    for encoding in encodings:
        encoding_types = apdus[encoding].columns.get_level_values(0).unique()
        for encoding_type in encoding_types:
            fig = plt.figure(num="combined {} encoding, type {}".format(encoding, encoding_type))
            x = apdus[[("power_use", "", "", "")]]
            colors = _colorblind_colors(len(compressions))
            for compression in compressions:
                y = apdus[[(encoding, encoding_type, compression, "length")]]
                plt.scatter(x, y, label="{}, {}, {}".format(encoding, encoding_type, compression), color=next(colors), alpha=0.3)

            plt.legend()
    plt.show()


def scatterplot_customer_set(data,
                             klantnum,
                             encodings=(("normal", "full"),
                                        ("normal", "nullall"),
                                        ("deltacoded", "min"),
                                        ("deltacoded", "16"),
                                        ("deltacoded", "16compact")),
                             compressions=("uncompressed", "compressed"),
                             combined=False):
    """Scatter one customer's APDU lengths against daily power consumption.

    :param data: DataFrame with ``(encoding, tweak, compression, "length")``
        columns and a ``("power_use", "", "", "")`` column.
    :param klantnum: customer identifier; concatenated into the figure name,
        so it must be a string.
    :param encodings: ``(encoding, tweak)`` pairs to plot.
    :param compressions: compression variants to plot. (Previously this
        argument was ignored and a hard-coded list was used.)
    :param combined: if False, each compression variant is drawn and shown
        on its own; if True, all variants are drawn into one figure and
        shown once at the end.
    """
    figname = klantnum + " year"
    x = data[[("power_use", "", "", "")]]

    # Use colourblind-proof cycle: one colour per series in a figure, i.e.
    # per encoding, or per (compression, encoding) when combined.
    ncolors = len(encodings) * len(compressions) if combined else len(encodings)
    colors = _colorblind_colors(ncolors)

    for compression in compressions:
        # Creates the named figure, or re-activates it if it is still open.
        # After a blocking plt.show() closed it, this gives the next window
        # the same title instead of a generic "Figure N".
        plt.figure(num=figname)
        ax = plt.gca()
        for encoding, tweak in encodings:
            y = data[[(encoding, tweak, compression, "length")]]
            ax.scatter(x, y, label=_graphlabel(encoding, tweak, compression),
                       color=next(colors), alpha=0.3)
        legend = ax.legend()
        legend.set_draggable(True)
        ax.set_ylabel("Data length (bytes)")
        ax.set_xlabel("Daily power consumption (Wh)")

        if not combined:
            plt.show()

    if combined:
        plt.show()


def lineplot_customer_year(data,
                           klant,
                           encodings=[("normal", "full", "compressed"),
                                      ("normal", "nullall", "compressed"),
                                      ("deltacoded", "16compact", "compressed"),
                                      ("deltacoded", "16", "compressed"),
                                      ("deltacoded", "min", "compressed"),
                                      ("deltacoded", "min", "uncompressed"),
                                      ("deltacoded", "16compact", "uncompressed")]):
    """Plot a customer's APDU lengths and power use over the whole data range.

    Each entry of *encodings* is an ``(encoding, tweak, compression)`` triple
    whose ``"length"`` column is drawn as a thin line. Every series after the
    first gets its own ``twinx()`` axis, so each is autoscaled independently.
    When more than one series is drawn, those y-axes carry no ticks because
    their scales are not comparable.

    Power use is drawn last as a thick dashed line on a separate, labelled
    right-hand axis. It uses the palette's first colour, which is reserved
    for it. A single draggable figure legend lists every line.
    """
    fig, current_ax = plt.subplots(num=klant + " entire year APDU length & power use")
    current_ax.set_ylabel("Data length")
    current_ax.set_xlabel("Date")
    dates = data.index

    # Use colourblind-proof cycle; the first colour is kept for power use.
    palette = _colorblind_colors(len(encodings)+1)
    power_color = next(palette)

    handles = []
    for position, series_key in enumerate(encodings):
        lengths = data[[series_key + ("length",)]]
        if position > 0:
            current_ax = current_ax.twinx()

        (line,) = current_ax.plot(dates, lengths,
                                  label=_graphlabel(*series_key[:3]),
                                  linewidth=0.5, color=next(palette))
        handles.append(line)
        if len(encodings) > 1:
            current_ax.yaxis.set_ticks([])

    power_ax = current_ax.twinx()
    power = data[[("power_use", "", "", "")]]
    power_ax.set_ylabel("Power use (Wh)", color=power_color)
    (line,) = power_ax.plot(dates, power, label="Power use (Wh)",
                            linestyle="dashed", linewidth=2, color=power_color)
    handles.append(line)
    power_ax.tick_params(axis="y", labelcolor=power_color)

    fig.tight_layout()
    legend = fig.legend(
        handles=handles,
        labels=[handle.get_label() for handle in handles]
    )
    legend.set_draggable(True)
    plt.show()


def lineplot_power_consumption_day(profiles,
                                   klant,
                                   date):
    """Show customer *klant*'s measurement curve for one calendar day.

    :param profiles: DataFrame with a datetime-like row index (its
        ``.index.date`` is used) and columns addressable as
        ``profiles[klant]["measurements"]``.
    :param klant: label selecting the customer's column group in *profiles*.
    :param date: the day to plot. It is compared element-wise with
        ``profiles.index.date``, whose elements are ``datetime.date`` values.
    """
    time_format = mdates.DateFormatter('%H:%M')
    _, axes = plt.subplots()
    customer = profiles[klant]
    selected = customer.loc[profiles.index.date == date]
    times = selected.index
    readings = selected["measurements"]
    # Tick labels show only the time of day, since all points share one date.
    axes.xaxis.set_major_formatter(time_format)
    axes.plot(times, readings)
    plt.show()


def _plot_profile_day(ax, profiles, klant, date):
    """Draw customer *klant*'s measurements for calendar day *date* on *ax*."""
    # The dates come from another DataFrame's index and may be pandas
    # Timestamps, while profiles.index.date holds datetime.date objects.
    # Recent pandas treats Timestamp == date as never equal, so normalise to a
    # plain date first. This works for date, Timestamp and ISO-string inputs.
    day_of_interest = pandas.Timestamp(date).date()
    day = profiles[klant].loc[profiles.index.date == day_of_interest]
    # A fresh formatter per axis: matplotlib tickers are bound to one Axis.
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
    ax.plot(day.index, day["measurements"])


def overviewplot(data,
                 profiles,
                 klant):
    """Show a one-figure overview of a customer's APDU sizes and extreme days.

    The figure is a 10x3 grid of subplots:

    * Column 0 scatters APDU length against power use for the "normal"
      encoding variants.
    * Column 1 does the same for the "deltacoded" variants.
    * In both scatter columns, rows 0-4 show uncompressed data and rows 5-9
      compressed data.
    * Column 2 line-plots the measurements of the five days whose compressed,
      min-length delta-coded APDU is smallest (rows 0-4) and of the five
      largest (rows 5-9).

    :param data: DataFrame with ``(encoding, tweak, compression, "length")``
        columns and a ``("power_use", "", "", "")`` column, indexed by day.
    :param profiles: measurement DataFrame accepted by ``_plot_profile_day``.
    :param klant: customer label; must be a string (used in the figure name).
    """
    plt.figure(num=klant + " year")
    x = data[[("power_use", "", "", "")]]
    compressions = ("uncompressed", "compressed")

    # Column 0: normal encoding using minimum-sized integers.
    for row, compression in enumerate(compressions):
        ax = plt.subplot2grid((10, 3), (row * 5, 0), rowspan=5)
        for nulldate in ("full", "nulldate", "nullall"):
            y = data[[("normal", nulldate, compression, "length")]]
            ax.scatter(x, y, label="normal, " + nulldate + ", " + compression, alpha=0.3)
        ax.legend()

    # Column 1: delta-coded variants.
    # FIXME extend to all encodings again (previously "min", "8", "16", "32").
    for row, compression in enumerate(compressions):
        ax = plt.subplot2grid((10, 3), (row * 5, 1), rowspan=5)
        for deltacoding in ("min", "16", "16compact"):
            y = data[[("deltacoded", deltacoding, compression, "length")]]
            ax.scatter(x, y, label="deltacoded, " + deltacoding + ", " + compression, alpha=0.3)
        ax.legend()

    # Column 2: the five least and most compressible days as line plots.
    data_sorted_length = data.sort_values(by=[("deltacoded", "min", "compressed", "length")])
    smallest = data_sorted_length.head(5).index
    largest = data_sorted_length.tail(5).index

    for row, date in enumerate(smallest):
        _plot_profile_day(plt.subplot2grid((10, 3), (row, 2)), profiles, klant, date)

    for row, date in enumerate(largest):
        _plot_profile_day(plt.subplot2grid((10, 3), (row + 5, 2)), profiles, klant, date)

    plt.show()


"""
WIP, nothing of value here yet.
"""
def weekdayplots(data, klant):
    # TODO: boxplots.
    #
    # Lineplot the APDU lengths and power use for weeks overlapping
    weekdays = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
    fig, axes = plt.subplots(3, 3, num=klant + " weekly APDU length & power use")
    x = weekdays

    colors = _colorblind_colors(2)
    lencol = next(colors)
    usecol = next(colors)
    for week in range(2, 53):
        if (week-2) % 6 == 0:
            ax1 = axes[(week-2) // 18, ((week-2) // 6) % 3]
            ax1.set_title(str(week))
            ax1.set_xlabel("Day of week")
            ax1.set_ylabel("Data length", color=lencol)
            ax2 = ax1.twinx()
            ax2.set_ylabel("Power use (Wh)", color=usecol)

        weekdata = data.loc[data.index.isocalendar().week==week]
        y = weekdata[[("deltacoded", "min", "compressed", "length")]]
        ax1.plot(x, y, color=lencol)
        ax1.tick_params(axis="y", labelcolor=lencol)

        y = weekdata[[("power_use", "", "", "")]]
        ax2.plot(x, y, color=usecol)
        ax2.tick_params(axis="y", labelcolor=usecol)
    plt.show()


"""
WIP, nothing of value here yet.
"""
def day_distribution_plot(data, klant):
    # TODO: boxplots, other compressions
    #
    # Scatterplot the distribution of days on a yearly and monthly basis.
    fig = plt.figure(num=klant + " day-distribution")

    colors = _colorblind_colors(7)
    weekdays = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
    for day in range(0,7):
        daily = data.loc[data.index.dayofweek==day]

        x = daily[[("power_use", "", "", "")]]
        y = daily[[("deltacoded", "min", "compressed", "length")]]
        plt.scatter(x, y, label=weekdays[day], alpha=0.3, color=next(colors))

    plt.legend()
    plt.show()

    fig, axes = plt.subplots(3, 4, num=klant + " monthly day-distribution")
    for month in range(0,12):
        ax = axes[month // 4, month % 4]
        ax.set_title(str(month))
        monthly = data.loc[data.index.month==month+1]
        colors = _colorblind_colors(7)
        for day in range(0,7):
            daily = monthly.loc[monthly.index.dayofweek==day]
            #print(daily)

            x = daily[[("power_use", "", "", "")]]
            y = daily[[("deltacoded", "min", "compressed", "length")]]
            ax.scatter(x, y, label=weekdays[day], alpha=0.3, color=next(colors))

    plt.show()

