
from wrf import to_np, getvar, smooth2d, latlon_coords, interplevel, get_cartopy
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from matplotlib.colors import LinearSegmentedColormap, ListedColormap


from tools import total_prec, cartopy_lim_buffered, make_fig_and_ax, add_title_and_description

def slp_map(t, nc, prec_prev, output_path, plevels=np.arange(950, 1050, 5)):
    """Plot sea-level pressure, mid-level cloud cover and precipitation; save to file.

    Parameters
    ----------
    t : object
        Not used in this function (presumably kept so all map functions share a
        call signature -- confirm with the caller).
    nc :
        WRF output handle passed to ``wrf.getvar`` and ``tools.total_prec``.
    prec_prev :
        Passed straight through to ``tools.total_prec``.
    output_path : str or path-like
        Destination of the figure written with ``plt.savefig``.
    plevels : sequence of numbers, optional
        Isobar levels in hPa. Lists/tuples are accepted as well as arrays.
        The 1015 hPa isobar, if present, is drawn with a thicker line.

    Returns
    -------
    The precipitation field computed by ``tools.total_prec(nc, prec_prev)``
    (presumably fed back as ``prec_prev`` on the next call -- TODO confirm).
    """
    data_crs = ccrs.PlateCarree()
    # ``.shape`` and the element-wise comparison below need an ndarray, so
    # normalise whatever sequence the caller passed.
    plevels = np.asarray(plevels)

    slp        = getvar(nc, "slp")
    mid_clouds = getvar(nc, "mid_cloudfrac")
    prec       = total_prec(nc, prec_prev)

    # Smooth the sea level pressure since it tends to be noisy near the
    # mountains.
    smooth_slp = smooth2d(slp, 5, cenweight=4)

    # Latitude/longitude points and the WRF projection (CRS) object.
    lats, lons = latlon_coords(slp)
    x, y = to_np(lons), to_np(lats)
    crs = get_cartopy(slp)

    fig, ax, ax_cbarv, ax_cbarh = make_fig_and_ax(plt, crs, add_horiz_cbar=True)

    # Filled contours: mid-level cloud fraction (white->silver) and precipitation.
    cloud_levels = [1/10, 1/4, 1/2, 3/4, 1]
    cloud_cmap = LinearSegmentedColormap.from_list("", ["white", "silver"])
    prec_cmap = matplotlib.cm.viridis_r
    con = ax.contourf(x, y, to_np(mid_clouds), levels=cloud_levels,
                      cmap=cloud_cmap, zorder=2, transform=data_crs)
    con2 = ax.contourf(x, y, to_np(prec), levels=[0.5, 1, 2, 5, 10, 25, 50],
                       cmap=prec_cmap, zorder=2, extend='max',
                       transform=data_crs)

    # Isobars; the 1015 hPa line is emphasised.
    linewidths = np.ones(plevels.shape)
    linewidths[plevels == 1015] = 3
    cs = ax.contour(x, y, to_np(smooth_slp), levels=plevels,
                    linewidths=linewidths, colors="black", zorder=4,
                    transform=data_crs)
    plt.clabel(cs, fmt='%1.0f', fontsize=8)

    ax.coastlines(linewidth=0.3, zorder=2)
    countries = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_countries', scale='50m')
    ax.add_feature(countries, facecolor='none', edgecolor='black', linewidth=0.3, zorder=3)

    cbar = plt.colorbar(con, cax=ax_cbarv, ticks=cloud_levels)
    plt.colorbar(con2, cax=ax_cbarh, orientation="horizontal")

    # Fraction-style tick labels are cosmetic: report the failure and keep
    # going rather than losing the whole chart. ``Exception`` (not a bare
    # except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        cbar.ax.set_yticklabels(["1/10", "1/4", "1/2", "3/4", "1"], fontsize=10)
    except Exception as err:
        print("Plotting tick labels failed: {}".format(err))

    ax.set_xlim(cartopy_lim_buffered(slp, "x"))
    ax.set_ylim(cartopy_lim_buffered(slp, "y"))

    add_title_and_description(nc, "Sea Level Pressure [hPa], Mid-Level Cloud Cover, Total prec. in the last hour [mm] \n"
                      "WRF {init:s} +{fcsthour:d}", ax)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

    return prec


def plevel_map(t, nc, output_path, plevel, phi_contour_intervals=np.arange(400, 800, 8)):
    """Plot geopotential, temperature and wind on one pressure level; save to file.

    Parameters
    ----------
    t : object
        Not used in this function (presumably kept for a signature shared with
        the other map functions -- confirm with the caller).
    nc :
        WRF output handle passed to ``wrf.getvar``.
    output_path : str or path-like
        Destination of the figure written with ``plt.savefig``.
    plevel : int or float
        Pressure level in hPa (pressure is requested with ``units='hPa'``).
        Integer and non-integer values are both accepted in the title.
    phi_contour_intervals : sequence of numbers, optional
        Geopotential contour levels in gpdm. Lists/tuples are accepted as well
        as arrays. The 552 gpdm line, if present, is drawn thicker.
    """
    data_crs = ccrs.PlateCarree()
    # ``.shape`` and the element-wise comparison below need an ndarray.
    phi_contour_intervals = np.asarray(phi_contour_intervals)

    p = getvar(nc, "p", units='hPa')

    def interp_var(variablename):
        # Vertically interpolate a WRF diagnostic onto the requested level.
        return interplevel(getvar(nc, variablename), p, plevel)

    geop = smooth2d(interp_var("geopt"), 5, cenweight=4)
    temp = interp_var("tc")
    # "uvmet" is unpacked into its two horizontal wind components.
    u, v = interp_var("uvmet")

    # Latitude/longitude points and the WRF projection (CRS) object.
    lats, lons = latlon_coords(temp)
    x, y = to_np(lons), to_np(lats)
    crs = get_cartopy(temp)

    fig, ax, ax_cbarv, ax_cbarh = make_fig_and_ax(plt, crs)

    # Temperature (filled contours, degrees Celsius per the title).
    temp_levels = np.array([-46, -41, -36, -31, -26, -21, -16, -11, -6, -1, 4])
    con = ax.contourf(x, y, to_np(temp), levels=temp_levels,
                      transform=data_crs,
                      cmap='viridis', zorder=0, extend='both')

    # Geopotential -> geopotential decametres (the title's [gpdm]): divide by
    # standard gravity g0 = 9.80665 m s-2 and by 10 (assumes "geopt" is in
    # m2 s-2, as wrf-python documents). The former divisor 98.0 was ~0.07 %
    # short of g0 * 10, shifting contours by roughly 0.4 gpdm.
    gpdm = to_np(geop) / (9.80665 * 10.0)
    linewidths = np.ones(phi_contour_intervals.shape)
    linewidths[phi_contour_intervals == 552] = 3
    cs = ax.contour(x, y, gpdm, levels=phi_contour_intervals,
                    linewidths=linewidths, colors="black",
                    zorder=1,
                    transform=data_crs)
    plt.clabel(cs, fmt='%1.0f', fontsize=10)

    ax.coastlines(linewidth=0.3, zorder=2)
    countries = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_countries', scale='50m')
    ax.add_feature(countries, facecolor='none', edgecolor='black', linewidth=0.3, zorder=3)

    # Wind barbs on every 20th grid point in each direction.
    nb = 20
    ax.barbs(to_np(lons[::nb, ::nb]), to_np(lats[::nb, ::nb]),
             to_np(u[::nb, ::nb]), to_np(v[::nb, ::nb]),
             transform=data_crs,
             length=4.5, zorder=4, linewidth=0.3)

    ax.set_xlim(cartopy_lim_buffered(temp, "x"))
    ax.set_ylim(cartopy_lim_buffered(temp, "y"))

    plt.colorbar(con, cax=ax_cbarv)
    # "{:g}" renders 500 and 500.0 both as "500"; "{:d}" raised on floats.
    add_title_and_description(nc, "{:g}".format(plevel) + " hPa level Geopotential [gpdm], Temperature [°C] and Wind  \n"
                     "WRF {init:s} +{fcsthour:d}", ax)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

# https://stackoverflow.com/questions/37327308/add-alpha-to-an-existing-matplotlib-colormap
def alphafy_colormap(name, mina=0.0, maxa=0.8):
    """Return a copy of colormap *name* whose opacity ramps linearly.

    Parameters
    ----------
    name : str, Colormap or None
        Registered colormap name or a Colormap instance. Both are accepted,
        as they were with the old ``matplotlib.cm.get_cmap``.
    mina, maxa : float
        Alpha given to the first and last colour of the map. Colours in
        between get linearly interpolated alpha values.

    Returns
    -------
    ListedColormap
        The ``cmap.N`` RGBA colours of the source map with the new alpha channel.
    """
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry, "get_cmap"):
        # Matplotlib >= 3.6. ``matplotlib.cm.get_cmap`` was deprecated in 3.7
        # and removed in 3.9, so it cannot be the primary lookup path.
        cmap = registry.get_cmap(name)
    else:
        # Older Matplotlib releases that lack the registry lookup.
        cmap = matplotlib.cm.get_cmap(name)

    # Sample every entry of the colormap as an (N, 4) RGBA array ...
    rgba = cmap(np.arange(cmap.N))
    # ... and overwrite the alpha column with a linear ramp.
    rgba[:, -1] = np.linspace(mina, maxa, cmap.N)

    return ListedColormap(rgba)


def sfc_map(t, nc, get_var_fun, levels, ticks, label, bg_function, output_path, cmap="YlOrRd", extend='both'):
    """Plot a surface field with 10 m wind barbs over a custom background; save to file.

    Parameters
    ----------
    t : object
        Not used in this function (presumably kept for a signature shared with
        the other map functions -- confirm with the caller).
    nc :
        WRF output handle, passed to ``get_var_fun`` and ``wrf.getvar``.
    get_var_fun : callable
        ``get_var_fun(nc)`` returns the field to plot. The field is assumed to
        be on the same grid as ``uvmet10`` -- TODO confirm.
    levels : sequence of numbers
        Filled-contour levels. Every second level is also drawn as a line
        contour, and the 0 line (if it is among those) is emphasised.
    ticks : sequence of numbers
        Colorbar tick positions.
    label : str
        Field description placed at the start of the title.
    bg_function : callable
        Called as ``bg_function(ax, crs)`` after the field has been drawn.
    output_path : str or path-like
        Destination of the figure written with ``plt.savefig``.
    cmap : str, optional
        Colormap passed to ``alphafy_colormap``.
    extend : str, optional
        ``extend`` argument forwarded to ``contourf``.
    """
    data_crs = ccrs.PlateCarree()

    # to_np leaves plain arrays untouched and unwraps DataArrays, matching
    # how the other map functions feed data to matplotlib.
    var  = to_np(get_var_fun(nc))
    u, v = getvar(nc, 'uvmet10')

    # Latitude/longitude points and the WRF projection (CRS) object.
    lats, lons = latlon_coords(u)
    x, y = to_np(lons), to_np(lats)
    crs = get_cartopy(u)

    fig, ax, ax_cbarv, ax_cbarh = make_fig_and_ax(plt, crs)

    # Filled contours with an alpha ramp so the background shows through.
    fill_levels = np.array(levels)
    con = ax.contourf(x, y, var, levels=fill_levels,
                      transform=data_crs,
                      cmap=alphafy_colormap(cmap), zorder=1, extend=extend)

    # Line contours at every second fill level. ``linewidths`` needs exactly
    # one entry per *line* level: matplotlib truncates a longer sequence, so
    # sizing it from all fill levels (as before) put the thick line on the
    # wrong contour.
    line_levels = fill_levels[::2]
    linewidths = np.full(line_levels.shape, 0.5)
    linewidths[line_levels == 0] = 1
    cs = ax.contour(x, y, var, levels=line_levels,
                    linewidths=linewidths, colors="black", alpha=0.5,
                    zorder=2,
                    transform=data_crs)
    plt.clabel(cs, fmt='%1.0f', fontsize=8)

    # Add background (needs to happen after map is already created...)
    bg_function(ax, crs)

    # Wind barbs on every 10th grid point in each direction.
    nb = 10
    ax.barbs(to_np(lons[::nb, ::nb]), to_np(lats[::nb, ::nb]),
             to_np(u[::nb, ::nb]), to_np(v[::nb, ::nb]),
             transform=data_crs,
             length=4.5, zorder=4, linewidth=0.5)

    ax.set_xlim(cartopy_lim_buffered(u, "x"))
    ax.set_ylim(cartopy_lim_buffered(u, "y"))

    cbar = plt.colorbar(con, cax=ax_cbarv)
    cbar.set_ticks(ticks)

    add_title_and_description(nc, label + " and Wind (10m) \n" + "WRF {init:s} +{fcsthour:d}", ax)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()
