#!/usr/bin/env python
# -*- coding: utf-8 -*-

from wrf import to_np, getvar, smooth2d, get_basemap, ll_to_xy, latlon_coords, interplevel, get_cartopy, cartopy_xlim, cartopy_ylim
import numpy as np
import netCDF4
import datetime
import matplotlib
from matplotlib import pyplot as plt
import cartopy.feature as cfeature
import cartopy.crs as ccrs
from matplotlib.offsetbox import AnchoredText
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
import sys

MAP_BUFFER = 0.1  # fraction of the map extent cut off in total (half of it at each edge)


def t2datetime(x):
    '''
    Parse a timestamp of the form 'YYYY-mm-dd_HH:MM:SS' into a datetime.

    x: timestamp string

    Raises ValueError (from strptime) if x does not match the format.
    '''
    return datetime.datetime.strptime(x, '%Y-%m-%d_%H:%M:%S')

def cartopy_lim_buffered(field, direction, buffer=MAP_BUFFER):
    '''
    Return the cartopy axis limits of field, shrunk symmetrically.

    field:     variable handed on to cartopy_xlim / cartopy_ylim
    direction: "x" or "y", selecting which axis limits are fetched
    buffer:    fraction of the full extent removed in total; half of it is
               cut off at each end

    Returns a new two-element list [lower, upper].
    Raises ValueError if direction is neither "x" nor "y".
    '''
    if direction == "x":
        lims = cartopy_xlim(field)
    elif direction == "y":
        lims = cartopy_ylim(field)
    else:
        raise ValueError("direction must be 'x' or 'y', got {!r}".format(direction))
    margin = 0.5 * buffer * (lims[1] - lims[0])
    # Build a fresh list rather than modifying the object returned by wrf in place
    return [lims[0] + margin, lims[1] - margin]

def wrf_time(nc):
    '''
    Return the first entry of the 'Times' variable of nc as a datetime.

    nc: dataset indexable as nc['Times'][0], yielding the timestamp as a
        sequence of byte strings
    '''
    pieces = (raw.decode('utf-8') for raw in nc['Times'][0])
    stamp = ''.join(pieces)
    return t2datetime(stamp)

def total_prec(nc, nc_prev):
    '''
    Sum the precipitation fields of nc and subtract the same sum of nc_prev.

    nc:      dataset of the current time step
    nc_prev: dataset of the preceding time step; only read when the time of
             nc differs from its START_DATE

    Returns the summed fields of nc unchanged if nc is at START_DATE,
    otherwise the difference between the sums of nc and nc_prev.
    '''
    # NOTE(review): confirm that RAINNC does not already contain the frozen
    # parts (SNOWNC, GRAUPELNC, HAILNC); if it does, they are counted twice.
    prec_vars = ['RAINC', 'RAINSH', 'RAINNC', 'SNOWNC', 'GRAUPELNC', 'HAILNC']

    def summed(dataset):
        # Add up all precipitation fields at time index 0
        return np.sum([dataset[name][0, :, :] for name in prec_vars], axis=0)

    run_start = t2datetime(nc.START_DATE)
    current = summed(nc)
    if run_start == wrf_time(nc):
        return current
    return current - summed(nc_prev)

def add_title_and_description(nc, title, ax):
    '''
    Add the title, a box with the time of nc and a copyright note to ax.

    nc:    dataset providing START_DATE and the time read by wrf_time()
    title: including {init:s} for timestamp of run initialization, and {fcsthour:d} for forecast hour
    ax:    axes to decorate
    '''
    init_time  = t2datetime(nc.START_DATE)
    valid_time = wrf_time(nc)
    # Whole hours between run start and the time of nc (truncated by int())
    fcstHour = int((valid_time - init_time).total_seconds()/3600.)

    ax.set_title(title.format(init=init_time.strftime("%d.%m.%Y %H:%M UTC"),
                              fcsthour=fcstHour),
                 fontdict={ 'fontsize':8, 'horizontalalignment': 'left'}, loc='left')

    texttime = AnchoredText(valid_time.strftime("%H UTC %a"),
                            loc=2, pad=0.2, prop={'size': 10}, frameon=True)
    ax.add_artist(texttime)

    # Raw string: "\c" is not a valid escape sequence in a normal string literal
    textcopyright = AnchoredText(r"$\copyright$ MBEES, Faculty of Medicine, University of Augsburg",
                        loc=4, pad=0.1, prop={'size': 6, 'color':'white'}, frameon=False)
    ax.add_artist(textcopyright)

def make_fig_and_ax(plt, crs, add_horiz_cbar=False):
    '''
    Create a 7x7 figure with a map axes and axes for colorbars.

    plt:            object providing figure(), here the pyplot module
    crs:            projection used for the map axes
    add_horiz_cbar: also create the axes below the map

    Returns (fig, ax, ax_cbarv, ax_cbarh); ax_cbarh is None unless
    add_horiz_cbar is set.
    '''
    fig = plt.figure(figsize=(7,7))

    # 2x2 grid: large map cell plus a narrow column and a flat row
    layout = gridspec.GridSpec(2, 2, width_ratios=[20,1], height_ratios=[20,1],
                               wspace=0.05, hspace=0.05)

    ax = fig.add_subplot(layout[0,0], projection=crs)
    ax_cbarv = fig.add_subplot(layout[0,1])
    ax_cbarh = fig.add_subplot(layout[1,0]) if add_horiz_cbar else None
    return fig, ax, ax_cbarv, ax_cbarh

def slp_map(t, nc, nc_prev, output_path, plevels=np.arange(950, 1050, 5)):
    '''
    Plot sea level pressure contours over mid-level cloud fraction and
    precipitation, and save the figure to output_path.

    t:           not used inside this function
    nc:          dataset of the time step to plot
    nc_prev:     dataset handed to total_prec() together with nc
    output_path: file path passed to plt.savefig
    plevels:     numpy array of pressure contour levels; the 1015 contour is
                 drawn with a thicker line
    '''
    slp        = getvar(nc, "slp")
    mid_clouds = getvar(nc, "mid_cloudfrac")
    prec       = total_prec(nc, nc_prev)

    # Smooth the sea level pressure since it tends to be noisy near the
    # mountains
    smooth_slp = smooth2d(slp, 5, cenweight=4)
    # Get the latitude and longitude points
    lats, lons = latlon_coords(slp)

    # Get the WRF projection (CRS) object
    crs = get_cartopy(slp)

    # Create a figure
    fig, ax, ax_cbarv, ax_cbarh = make_fig_and_ax(plt, crs, add_horiz_cbar=True)

    # Draw the contours and filled contours
    cmap  = LinearSegmentedColormap.from_list("", ["white", "silver"])
    cmap2 = matplotlib.cm.viridis_r
    # Shared by the cloud contours and the ticks of their colorbar
    cloud_levels = [1/10, 1/4, 1/2, 3/4, 1]

    con   = ax.contourf(to_np(lons), to_np(lats), to_np(mid_clouds),
                        levels=cloud_levels, cmap=cmap, zorder=2,
                        transform=ccrs.PlateCarree())
    con2  = ax.contourf(to_np(lons), to_np(lats), to_np(prec),
                        levels=[0.5, 1, 2, 5, 10, 25, 50], cmap=cmap2, zorder=2, extend='max',
                        transform=ccrs.PlateCarree())

    linewidths = np.ones(plevels.shape)
    linewidths[plevels == 1015] = 3
    cs = ax.contour(to_np(lons), to_np(lats), to_np(smooth_slp), levels=plevels,
                    linewidths=linewidths, colors="black", zorder=4,
                    transform=ccrs.PlateCarree())
    plt.clabel(cs, fmt='%1.0f', fontsize=8)

    ax.coastlines(linewidth=0.3, zorder=2)
    countries = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_countries', scale='50m')
    ax.add_feature(countries, facecolor='none', edgecolor='black', linewidth=0.3, zorder=3)

    cbar = plt.colorbar(con, cax=ax_cbarv, ticks=cloud_levels)
    plt.colorbar(con2, cax=ax_cbarh, orientation="horizontal")

    # Best effort only: a failure to relabel the ticks must not abort the plot.
    # Catch Exception rather than everything so Ctrl-C still gets through.
    try:
        cbar.ax.set_yticklabels(["1/10", "1/4", "1/2", "3/4", "1"], fontsize=10)
    except Exception:
        print("Plotting tick labels failed")

    ax.set_xlim(cartopy_lim_buffered(slp, "x"))
    ax.set_ylim(cartopy_lim_buffered(slp, "y"))

    add_title_and_description(nc, "Sea Level Pressure [hPa], Mid-Level Cloud Cover, Total prec. in the last hour [mm] \n"
                      "WRF {init:s} +{fcsthour:d}", ax)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()


def plevel_map(t, nc, output_path, plevel, phi_contour_intervals=np.arange(400, 800, 8)):
    '''
    Plot temperature, geopotential contours and wind barbs interpolated to
    the pressure level plevel, and save the figure to output_path.

    t:                     not used inside this function
    nc:                    dataset of the time step to plot
    output_path:           file path passed to plt.savefig
    plevel:                integer pressure level, in the units of the "p"
                           variable requested below (hPa)
    phi_contour_intervals: numpy array of contour levels for geopotential/98;
                           the 552 contour is drawn with a thicker line
    '''
    pressure = getvar(nc, "p", units='hPa')

    def at_level(variablename):
        # Fetch a variable and interpolate it onto the requested pressure level
        return interplevel(getvar(nc, variablename), pressure, plevel)

    geop = smooth2d(at_level("geopt"), 5, cenweight=4)
    temp = at_level("tc")
    u, v = at_level("uvmet")

    # Coordinates and projection are taken from the temperature field
    lats, lons = latlon_coords(temp)
    crs = get_cartopy(temp)

    fig, ax, ax_cbarv, ax_cbarh = make_fig_and_ax(plt, crs)

    # Temperature as filled contours
    temp_levels = np.array([-46, -41, -36, -31, -26, -21, -16, -11, -6, -1, 4])
    con = ax.contourf(to_np(lons), to_np(lats), temp, levels=temp_levels,
                      transform=ccrs.PlateCarree(),
                      cmap='viridis', zorder=0, extend='both')

    # Geopotential as labelled black lines
    widths = np.ones(phi_contour_intervals.shape)
    widths[phi_contour_intervals == 552] = 3
    cs = ax.contour(to_np(lons), to_np(lats), geop/98.0,
                    levels=phi_contour_intervals,
                    linewidths=widths, colors="black",
                    zorder=1,
                    transform=ccrs.PlateCarree())
    plt.clabel(cs, fmt='%1.0f', fontsize=10)

    ax.coastlines(linewidth=0.3, zorder=2)
    borders = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_countries', scale='50m')
    ax.add_feature(borders, facecolor='none', edgecolor='black', linewidth=0.3, zorder=3)

    # Wind barbs on every 20th grid point in both directions
    stride = 20
    ax.barbs(to_np(lons[::stride,::stride]), to_np(lats[::stride,::stride]),
             to_np(u[::stride,::stride]), to_np(v[::stride,::stride]),
             transform=ccrs.PlateCarree(),
             length=4.5, zorder=4, linewidth=0.3)

    ax.set_xlim(cartopy_lim_buffered(temp, "x"))
    ax.set_ylim(cartopy_lim_buffered(temp, "y"))

    plt.colorbar(con, cax=ax_cbarv)
    heading = ("{:d}".format(plevel) + " hPa level Geopotential [gpdm], Temperature [°C] and Wind  \n"
               "WRF {init:s} +{fcsthour:d}")
    add_title_and_description(nc, heading, ax)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

def sfc_map(t, nc, get_var_fun, levels, ticks, label, pois, output_path, cmap="viridis", extend='both'):
    '''
    Plot a field with contours, 10 m wind barbs and optional points of
    interest, and save the figure to output_path.

    t:           not used inside this function
    nc:          dataset of the time step to plot
    get_var_fun: callable taking nc and returning the 2D field to plot
    levels:      levels of the filled contours; every second one is also
                 drawn as a labelled line, with the line at 0 emphasised
    ticks:       tick positions of the colorbar
    label:       first part of the title; it is later passed through
                 str.format, so literal braces must be doubled
    pois:        iterable of dicts with the keys name, lon, lat, ms, fontsize
    output_path: file path passed to plt.savefig
    cmap:        colormap of the filled contours
    extend:      extend setting of the filled contours and the colorbar
    '''
    var  = get_var_fun(nc)
    u, v = getvar(nc, 'uvmet10')

    # Get the latitude and longitude points
    lats, lons = latlon_coords(u)

    # Get the WRF projection (CRS) object
    crs = get_cartopy(u)

    # Create a figure
    fig, ax, ax_cbarv, ax_cbarh = make_fig_and_ax(plt, crs)

    # Draw the contours and filled contours
    lab = np.array(levels)
    con = ax.contourf(to_np(lons), to_np(lats), var, levels=lab,
                      transform=ccrs.PlateCarree(),
                      cmap=cmap, zorder=0, extend=extend)

    # Line widths are matched to the levels by position, so they have to be
    # derived from the levels that are actually drawn (every second one).
    line_levels = lab[::2]
    linewidths = np.full(line_levels.shape, 0.5)
    linewidths[line_levels == 0] = 1
    cs = ax.contour(to_np(lons), to_np(lats), var, levels=line_levels,
                    linewidths=linewidths, colors="black", alpha=0.4,
                    zorder=1,
                    transform=ccrs.PlateCarree())
    plt.clabel(cs, fmt='%1.0f', fontsize=8)

    ax.coastlines(linewidth=0.3, zorder=2)
    countries = cfeature.NaturalEarthFeature(category='cultural', name='admin_0_countries', scale='50m')
    ax.add_feature(countries, facecolor='none', edgecolor='black', linewidth=0.3, zorder=3)

    # Wind barbs on every 10th grid point in both directions
    nb = 10
    ax.barbs(to_np(lons[::nb,::nb]), to_np(lats[::nb,::nb]),
             to_np(u[::nb,::nb]), to_np(v[::nb,::nb]),
             transform=ccrs.PlateCarree(),
             length=4.5, zorder=4, linewidth=0.3)

    def plot_poi(name, lon, lat, ms, fontsize):
        # Marker plus name label for one point of interest
        ax.plot(lon, lat, 'o', zorder=5, color='darkred', ms=ms, transform=ccrs.PlateCarree())
        ax.annotate(name, xy=(lon, lat), xytext=(5,5), textcoords="offset points",
                    fontsize=fontsize, zorder=6, color='darkred',
                    xycoords=ccrs.PlateCarree()._as_mpl_transform(ax))

    for item in pois:
        plot_poi(**item)

    ax.set_xlim(cartopy_lim_buffered(u, "x"))
    ax.set_ylim(cartopy_lim_buffered(u, "y"))

    cbar = plt.colorbar(con, cax=ax_cbarv, extend=extend)
    cbar.set_ticks(ticks)

    add_title_and_description(nc, label + " and Wind (10m) \n" + "WRF {init:s} +{fcsthour:d}", ax)

    plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

# --------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    # Command line driver: for every requested time step, open the domain 1
    # file and the file one hour earlier, then write all maps for that step.

    import datetime

    from argparse import ArgumentParser, ArgumentTypeError
    from contextlib import ExitStack

    def valid_date(s):
        '''Parse a YYYY-mm-dd command line argument into a datetime.'''
        try:
            return datetime.datetime.strptime(s, "%Y-%m-%d")
        except ValueError:
            raise ArgumentTypeError("Not a valid date: '{0}'.".format(s))

    def get_pm(nc, bins=(1, 2, 3)):
        '''
        Sum all aerosol species over the given size bins at index 0 of the
        first dimension, each divided by ALT.

        nc:   dataset to read from
        bins: size bin numbers, used to build the "<species>_aNN" variable names
        '''
        # NOTE(review): the original list named "cl" twice, which added it to
        # the sum twice. If the second entry was meant to be another species,
        # add that species here.
        pmspecs = [ "so4", "no3", "smpa", "smpbb", "glysoa_sfc", "biog1_c", "biog1_o", "cl", "co3", "nh4", "na", "oin", "oc", "bc", "water" ] # ug kg-1
        alt     = getvar(nc, "ALT")[0,:,:] # inverse density m3 kg-1
        parts   = [ getvar(nc, spec + "_a{:02d}".format(b))[0,:,:] / alt for spec in pmspecs for b in bins ]
        return np.sum(parts, axis=0)

    parser = ArgumentParser()
    parser.add_argument('wrf_data_fpath_pattern', type=str,        help='WRF input data path pattern (including strptime pattern for date and {domain:s} for domain (01, 02, ...))')
    parser.add_argument('date',                   type=valid_date, help='Date to process (YYYY-mm-dd)')
    parser.add_argument('interval_hours',         type=int,        help='Time interval in hours for plotting')
    parser.add_argument('fcst_time_hours',        type=int,        help='Forecast time in hours')
    parser.add_argument('plot_fpath_pattern',     type=str,        help='Plot output path pattern (including strptime pattern for date and {plot:s} for plot identifier (e.g., 500hPa_europe...)')

    args = parser.parse_args()
    # Example invocation for interactive testing:
#    args = parser.parse_args( [ '/alcc/gpfs2/scratch/mbees/knotechr/archive/WRF/operational_chemistry/wrfout_d{domain:02d}_%Y-%m-%d_%H:%M:%S', '2021-02-17', '1', '72', '/alcc/gpfs2/scratch/mbees/knotechr/WRF/postprocessing/fcst_movie/{plot:s}_%Y%m%d%H.png' ])
    times = [ args.date + datetime.timedelta(hours=t) for t in range(0, args.fcst_time_hours, args.interval_hours) ]

    def plot_path(t, plot):
        # Fill in the date codes first, then the plot identifier
        return t.strftime(args.plot_fpath_pattern).format(plot=plot)

    # Currently not passed to any of the plots below
    pois_southern_germany = [   { 'name':'Munich',     'lon':11.582,    'lat':48.135,    'ms':5,   'fontsize':9 },
                                { 'name':'Augsburg',   'lon':10.894446, 'lat':48.366512, 'ms':2.5, 'fontsize':8 },
                                { 'name':'Regensburg', 'lon':12.101624, 'lat':49.013432, 'ms':2.5, 'fontsize':8 },
                                { 'name':'Ingolstadt', 'lon':11.426311, 'lat':48.761423, 'ms':2.5, 'fontsize':8 },
                                { 'name':'Innsbruck',  'lon':11.388397, 'lat':47.258320, 'ms':2.5, 'fontsize':8 }
                                ]

    for t in times:
        tprev = t + datetime.timedelta(hours=-1)

        input_path_d01   = t.strftime(args.wrf_data_fpath_pattern).format(domain=1)
        previnp_path_d01 = tprev.strftime(args.wrf_data_fpath_pattern).format(domain=1)

        # The stack closes every dataset opened so far, also when the second
        # open fails or one of the plots raises.
        with ExitStack() as stack:
            try:
                nc_d01      = stack.enter_context(netCDF4.Dataset(input_path_d01))
                nc_d01_prev = stack.enter_context(netCDF4.Dataset(previnp_path_d01))
            except Exception as e:
                print("Not all data available for step {:s}: {:s}".format(t.strftime("%Y-%m-%d %H"), str(e)))
                continue

            print(t)

            # 500 hPa Europe domain
            plevel_map(t, nc_d01, plot_path(t, '500hPa_Europe'), 500)

            # Surface Europe domain
            # temperature
            sfc_map(t, nc_d01, lambda nc: getvar(nc, 'T2') - 273.15,
                    levels=[-25, -20, -15, -12.5, -10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25, 27.5, 30, 32.5, 35],
                    ticks=[-25, -15, -10, -5, 0, 5, 10, 15, 20, 25, 30, 35],
                    label="Temperature (2m) [°C]", pois=[],
                    output_path=plot_path(t, 'T_Europe'))
            # Ozone
            # NOTE(review): the factor 1e3 presumably converts ppmv to the
            # ppbv of the label -- confirm the units of the model output.
            sfc_map(t, nc_d01, lambda nc: getvar(nc, 'o3')[0,:,:] * 1e3,
                    levels=range(0, 160, 10),
                    ticks=range(0, 150, 20),
                    label="Ozone (ground level) [ppbv]", pois=[],
                    output_path=plot_path(t, 'O3_Europe'),
                    extend='max')
            # NO2
            sfc_map(t, nc_d01, lambda nc: getvar(nc, 'no2')[0,:,:] * 1e3,
                    levels=range(0, 110, 10),
                    ticks=range(0, 110, 20),
                    label="NO$_2$ (ground level) [ppbv]", pois=[],
                    output_path=plot_path(t, 'NO2_Europe'),
                    extend='max')
            # PM
            # Raw strings keep the backslash of \mu; the doubled braces
            # survive the str.format call applied to the title later on.
            sfc_map(t, nc_d01, lambda nc: get_pm(nc, bins=[1]),
                    levels=range(0, 130, 10),
                    ticks=range(0, 130, 20),
                    label=r"UFP (ground level) [$\mu$g m$^{{-3}}$]", pois=[],
                    output_path=plot_path(t, 'UFP_Europe'),
                    extend='max')
            sfc_map(t, nc_d01, lambda nc: get_pm(nc, bins=[1,2,3]),
                    levels=range(0, 130, 10),
                    ticks=range(0, 130, 20),
                    label=r"PM$_{{2.5}}$ (ground level) [$\mu$g m$^{{-3}}$]", pois=[],
                    output_path=plot_path(t, 'PM2_5_Europe'),
                    extend='max')
            sfc_map(t, nc_d01, lambda nc: get_pm(nc, bins=[1,2,3,4]),
                    levels=range(0, 130, 10),
                    ticks=range(0, 130, 20),
                    label=r"PM$_{{10}}$ (ground level) [$\mu$g m$^{{-3}}$]", pois=[],
                    output_path=plot_path(t, 'PM10_Europe'),
                    extend='max')
            # SLP Europe
            slp_map(t, nc_d01, nc_d01_prev, plot_path(t, 'SLP_Europe'))
        
