import datetime

from argparse import ArgumentParser, ArgumentTypeError

def valid_date(s):
        """Parse a ``YYYY-mm-dd`` string into a naive ``datetime.datetime``.

        Intended as an argparse ``type=`` converter for the ``date`` argument.

        Args:
                s: the raw command-line string.

        Returns:
                ``datetime.datetime`` at midnight of the given day.

        Raises:
                ArgumentTypeError: if ``s`` does not match ``%Y-%m-%d``; argparse
                        reports its message as a usage error.
        """
        try:
                return datetime.datetime.strptime(s, "%Y-%m-%d")
        except ValueError as err:
                # Chain the parse error so the underlying cause is preserved.
                raise ArgumentTypeError("Not a valid date: '{0}'.".format(s)) from err

# Command-line interface. Arguments are parsed before the plotting imports
# further down, presumably so --help and usage errors don't wait on them.
parser = ArgumentParser()
parser.add_argument('wrf_data_fpath_pattern', type=str,        help='WRF input data path pattern (including strftime pattern for date and {domain:02d} for domain (01, 02, ...))')
parser.add_argument('date',                   type=valid_date, help='Date to process (YYYY-mm-dd)')
parser.add_argument('interval_hours',         type=int,        help='Time interval in hours for plotting')
parser.add_argument('fcst_time_hours',        type=int,        help='Forecast time in hours')
# Help text previously lacked its final closing parenthesis.
parser.add_argument('plot_fpath_pattern',     type=str,        help='Plot output path pattern (including strftime pattern for date, {plot:s} for plot identifier (e.g., 500hPa...) and {domain:02d} for domain (01, 02))')

args = parser.parse_args()

from plots import plevel_map, sfc_map, slp_map
import pois
import netCDF4
from wrf import getvar
import numpy as np
import cartopy.feature as cfeature
import contextily
import rasterio

# Valid times to plot: args.date plus every interval_hours-th lead hour.
# NOTE(review): range() excludes its stop value, so lead hour fcst_time_hours
# itself is never plotted -- confirm that is intended.
times = [
        args.date + datetime.timedelta(hours=lead_hours)
        for lead_hours in range(0, args.fcst_time_hours, args.interval_hours)
]

# Precipitation carried from the previous step so a per-interval (delta)
# precipitation can be computed; only the disabled SLP plot uses it.
prec_prev = None

def bg_function(ax, wrfcrs):
        """Draw a Stamen TonerLite tile basemap underneath the contents of ``ax``.

        Args:
                ax: the axes the map is drawn on (handed to contextily.add_basemap).
                wrfcrs: projection of the WRF grid; must expose ``proj4_init``
                        (presumably a cartopy CRS -- confirm against sfc_map).
        """
        # Strip the '+nadgrids=@null' term before building a rasterio CRS.
        # NOTE(review): presumably rasterio/PROJ cannot resolve it -- confirm.
        proj4_string = wrfcrs.proj4_init.replace('+nadgrids=@null', '')
        tile_crs = rasterio.crs.CRS.from_proj4(proj4_string)
        # zorder=0 keeps the tiles beneath the plotted fields.
        # NOTE(review): Stamen tiles are now served via Stadia Maps and may need
        # an API key; confirm this provider still resolves.
        contextily.add_basemap(
                ax,
                crs=tile_crs,
                source=contextily.providers.Stamen.TonerLite,
                zorder=0,
        )

def get_pm(nc, bins=(1, 2, 3)):
        """Return the particulate-matter mass concentration [ug m-3] summed over ``bins``.

        For every species in ``pmspecs`` and every bin index in ``bins``, reads
        the variable ``<spec>_a<bin:02d>``, takes index 0 of its leading axis
        (presumably the lowest model level -- the plot labels call it "ground
        level") and divides by the inverse density ``ALT`` to convert the
        mixing ratio to a concentration.

        Args:
                nc: open WRF-Chem dataset, passed to wrf.getvar.
                bins: iterable of aerosol size-bin indices. This script uses
                        1-3 for PM2.5 and 1-4 for PM10.

        Returns:
                np.sum of the per-species, per-bin fields along axis 0.
        """
        # Mixing ratios in ug kg-1. Every species is listed once: an earlier
        # version listed "cl" twice and so double-counted chloride.
        # NOTE(review): the duplicate may have been meant as "ca" (calcium);
        # add it only if "ca_aNN" variables exist in the output files.
        pmspecs = ["so4", "no3", "smpa", "smpbb", "glysoa_sfc", "biog1_c", "biog1_o",
                   "cl", "co3", "nh4", "na", "oin", "oc", "bc", "water"]
        # Inverse density in m3 kg-1, so (ug kg-1) / (m3 kg-1) = ug m-3.
        alt = getvar(nc, "ALT")[0, :, :]
        layers = [getvar(nc, "{}_a{:02d}".format(spec, b))[0, :, :] / alt
                  for spec in pmspecs for b in bins]
        return np.sum(layers, axis=0)


def _open_datasets(paths):
        """Open every path with netCDF4.Dataset and return the datasets in order.

        If any open raises, the datasets opened so far are closed before the
        exception propagates, so a partial failure does not leak file handles.
        """
        opened = []
        try:
                for path in paths:
                        opened.append(netCDF4.Dataset(path))
        except Exception:
                for dataset in opened:
                        dataset.close()
                raise
        return opened


for t in times:
        input_path_d01   = t.strftime(args.wrf_data_fpath_pattern).format(domain=1)
        input_path_d02   = t.strftime(args.wrf_data_fpath_pattern).format(domain=2)

        # Best effort: skip any step for which either domain cannot be opened.
        try:
                nc_d01, nc_d02 = _open_datasets([input_path_d01, input_path_d02])
        except Exception as e:
                print("Not all data available for step {:s}: {:s}".format(t.strftime("%Y-%m-%d %H"), str(e)))
                continue

        print(t)

        # Close both domains' files even if a plot raises.
        try:
                # 500 hPa Europe domain
                # output_path = t.strftime(args.plot_fpath_pattern).format(plot='500hPa', domain=1)
                # plevel_map(t, nc_d01, output_path, 500)

                # Surface Europe domain
                # temperature
                # sfc_map(t, nc_d01, lambda nc: getvar(nc, 'T2') - 273.15,
                #         levels=[-25, -20, -15, -12.5, -10, -7.5, -5, -2.5, 0, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20, 22.5, 25, 27.5, 30, 32.5, 35],
                #         ticks=[-25, -15, -10, -5, 0, 5, 10, 15, 20, 25, 30, 35],
                #         label="Temperature (2m) [°C]", pois=pois.capitals,
                #         output_path=t.strftime(args.plot_fpath_pattern).format(plot='T', domain=1))

                # Chemistry maps for both domains (poilist is not used by these plots yet).
                for domain, nc_d, poilist in zip([1, 2], [nc_d01, nc_d02], [pois.capitals, pois.southern_germany]):
                        # Ozone (scaled by 1e3, presumably ppmv -> ppbv as the label says)
                        sfc_map(t, nc_d, lambda nc: getvar(nc, 'o3')[0, :, :] * 1e3,
                                levels=range(0, 160, 10),
                                ticks=range(0, 150, 20),
                                label="Ozone (ground level) [ppbv]", bg_function=bg_function,
                                output_path=t.strftime(args.plot_fpath_pattern).format(plot='O3', domain=domain),
                                extend='max')
                        # NO2
                        sfc_map(t, nc_d, lambda nc: getvar(nc, 'no2')[0, :, :] * 1e3,
                                levels=range(0, 110, 10),
                                ticks=range(0, 110, 20),
                                label="NO$_2$ (ground level) [ppbv]", bg_function=bg_function,
                                output_path=t.strftime(args.plot_fpath_pattern).format(plot='NO2', domain=domain),
                                extend='max')
                        # PM. Raw strings keep "\mu" literal without an invalid-escape
                        # warning; the doubled braces are kept as before (presumably
                        # sfc_map str.format()s the label -- confirm).
                        # sfc_map(t, nc_d, lambda nc: get_pm(nc, bins=[1]),
                        #         levels=range(0, 130, 10),
                        #         ticks=range(0, 130, 20),
                        #         label=r"UFP (ground level) [$\mu$g m$^{{-3}}$]", bg_function=bg_function,
                        #         output_path=t.strftime(args.plot_fpath_pattern).format(plot='UFP', domain=domain),
                        #         extend='max')
                        sfc_map(t, nc_d, lambda nc: get_pm(nc, bins=[1, 2, 3]),
                                levels=range(0, 130, 10),
                                ticks=range(0, 130, 20),
                                label=r"PM$_{{2.5}}$ (ground level) [$\mu$g m$^{{-3}}$]", bg_function=bg_function,
                                output_path=t.strftime(args.plot_fpath_pattern).format(plot='PM2_5', domain=domain),
                                extend='max')
                        sfc_map(t, nc_d, lambda nc: get_pm(nc, bins=[1, 2, 3, 4]),
                                levels=range(0, 130, 10),
                                ticks=range(0, 130, 20),
                                label=r"PM$_{{10}}$ (ground level) [$\mu$g m$^{{-3}}$]", bg_function=bg_function,
                                output_path=t.strftime(args.plot_fpath_pattern).format(plot='PM10', domain=domain),
                                extend='max')

                # SLP Europe
                # output_path = t.strftime(args.plot_fpath_pattern).format(plot='SLP', domain=1)
                # prec_prev = slp_map(t, nc_d01, prec_prev, output_path)
        finally:
                nc_d01.close()
                nc_d02.close()
