#!/usr/bin/env python

from wrf import getvar, ll_to_xy, to_np, extract_times
import numpy as np


def fT2m(D, xmin, xmax, ymin, ymax):
    """Return the 2 m temperature of the first time step on a sub-grid.

    ``D["T2"][0]`` is cut with the half-open bounds ``xmin:xmax`` (first
    axis) and ``ymin:ymax`` (second axis), then shifted by -273.15
    (Kelvin to degrees Celsius).
    """
    kelvin = D["T2"][0, xmin:xmax, ymin:ymax]
    return kelvin - 273.15


def fTd2m(D, xmin, xmax, ymin, ymax):
    """Return the wrf-python ``td2`` (2 m dew point) field on a sub-grid.

    The field from ``getvar`` is cut with the half-open bounds
    ``xmin:xmax`` (first axis) and ``ymin:ymax`` (second axis).
    """
    dew_point = getvar(D, "td2")
    return dew_point[xmin:xmax, ymin:ymax]


def wind10m(D, lat, lon):
    """Returns the wind speed and the wind direction at 10 m.

    Values are taken from wrf-python's ``uvmet10_wspd_wdir`` at the grid
    point nearest to (lat, lon).

    Returns:
        ``(speed, direction)``; ``(nan, nan)`` if any step fails.
    """
    try:
        x, y = ll_to_xy(D, lat, lon)
        w_speed, w_dir = getvar(D, "uvmet10_wspd_wdir")
        # ll_to_xy yields (x=west_east, y=south_north) indices, while WRF 2-D
        # fields are indexed [south_north, west_east], i.e. [y, x].
        row, col = int(y), int(x)
        return w_speed.values[row, col], w_dir.values[row, col]
    except Exception:
        # best effort: a failed look-up shows up as a gap in the meteogram
        return np.nan, np.nan


def wind_uv(D, lat, lon):
    """Returns the u and v wind components at 10 m.

    Values are taken from wrf-python's ``uvmet10`` at the grid point nearest
    to (lat, lon).

    Returns:
        ``(u, v)``; ``(nan, nan)`` if any step fails.
    """
    try:
        x, y = ll_to_xy(D, lat, lon)
        w_u, w_v = getvar(D, "uvmet10")
        # ll_to_xy yields (x=west_east, y=south_north) indices, while WRF 2-D
        # fields are indexed [south_north, west_east], i.e. [y, x].
        row, col = int(y), int(x)
        return w_u.values[row, col], w_v.values[row, col]
    except Exception:
        # best effort: a failed look-up shows up as a missing barb
        return np.nan, np.nan


def Mslp(D, lat, lon):
    """Returns the mean sea level pressure.

    The value is wrf-python's ``slp`` at the grid point nearest to
    (lat, lon), or NaN if any step fails.
    """
    try:
        x, y = ll_to_xy(D, lat, lon)
        slp = getvar(D, "slp")
        # ll_to_xy yields (x=west_east, y=south_north) indices, while WRF 2-D
        # fields are indexed [south_north, west_east], i.e. [y, x].
        return slp.values[int(y), int(x)]
    except Exception:
        return np.nan


def cin_cape(D, lat, lon):
    """Returns the 2D max cin and cape as ``(cin, cape)``.

    Both come from wrf-python's ``cape_2d`` at the grid point nearest to
    (lat, lon): element 1 of the stacked result is returned as cin and
    element 0 as cape. ``(nan, nan)`` is returned if any step fails.
    """
    try:
        x, y = ll_to_xy(D, lat, lon)
        var = getvar(D, "cape_2d")
        # ll_to_xy yields (x=west_east, y=south_north) indices, while WRF 2-D
        # fields are indexed [south_north, west_east], i.e. [y, x].
        row, col = int(y), int(x)
        return (var[1].values[row, col], var[0].values[row, col])
    except Exception:
        return np.nan, np.nan


def vert_wind(D, lat, lon):
    """Return the u and v wind profiles at the grid point nearest (lat, lon).

    The ``uvmet`` components are interpolated with ``vertcross`` onto the
    pressure levels 400, 460, ..., 1000 hPa; the first point of the
    cross-section (the location itself) is returned.

    Returns:
        ``(u, v)``: two lists with one value per level. Levels that cannot be
        interpolated are NaN; if any step fails both lists are all NaN.
    """
    levels = np.arange(400, 1060, 60)
    try:
        from wrf import vertcross, CoordPair

        x, y = (int(c) for c in ll_to_xy(D, lat, lon))
        start = CoordPair(x=x, y=y)
        # a one-pixel diagonal cross-section; only its first column is used
        end = CoordPair(x=x + 1, y=y + 1)
        u, v = getvar(D, "uvmet")
        p = getvar(D, "p", units="hPa")

        def _cross(field):
            # interpolate ``field`` to ``levels`` along the short section
            return vertcross(
                field,
                p,
                levels=levels,
                start_point=start,
                end_point=end,
                meta=False,
                missing=np.nan,
            )

        u_vert = _cross(u)
        v_vert = _cross(v)
        return list(u_vert[:, 0]), list(v_vert[:, 0])
    except Exception:
        return [np.nan] * len(levels), [np.nan] * len(levels)


def vert_T(D, lat, lon):
    """Return the temperature profile at the grid point nearest (lat, lon).

    wrf-python's ``tc`` is interpolated with ``vertcross`` onto the pressure
    levels 400, 460, ..., 1000 hPa; the first point of the cross-section
    (the location itself) is returned.

    Returns:
        One value per level (a 1-D array on success). Levels that cannot be
        interpolated are NaN; if any step fails, a list of NaN of the same
        length is returned so callers can always stack the profiles.
    """
    levels = np.arange(400, 1060, 60)
    try:
        from wrf import vertcross, CoordPair

        x, y = (int(c) for c in ll_to_xy(D, lat, lon))
        T = getvar(D, "tc")
        p = getvar(D, "p", units="hPa")
        T_vert = to_np(
            vertcross(
                T,
                p,
                levels=levels,
                start_point=CoordPair(x=x, y=y),
                # a one-pixel diagonal cross-section; only column 0 is used
                end_point=CoordPair(x=x + 1, y=y + 1),
                meta=False,
                missing=np.nan,
            )
        )
        return T_vert[:, 0]
    except Exception:
        # a single profile of NaN, matching the shape of the success path
        return [np.nan] * len(levels)


def fcfrac(D, lat, lon):
    """Return the cloud-fraction profile at the grid point nearest (lat, lon).

    ``CLDFRA`` is interpolated with ``interplevel`` onto the pressure levels
    400, 430, ..., 1000 hPa; levels that cannot be interpolated are NaN.

    Returns:
        One value per level (an array on success, a list of NaN if any step
        fails).
    """
    levels = np.arange(400, 1030, 30)
    try:
        from wrf import interplevel

        x, y = ll_to_xy(D, lat, lon)
        p = getvar(D, "p", units="hPa")
        # Both fields come from getvar with its default time index, so they
        # have the same shape (the raw netCDF variable keeps a Time axis).
        cf = getvar(D, "CLDFRA")
        # to_np works whether interplevel returns a DataArray or an ndarray
        cf_levels = to_np(interplevel(cf, p, levels, missing=np.nan))
        # ll_to_xy yields (x=west_east, y=south_north) indices; the result is
        # indexed [level, south_north, west_east].
        return cf_levels[:, int(y), int(x)]
    except Exception:
        return [np.nan for _ in range(len(levels))]


def ftot_accum_rain_2d(D, xmin, xmax, ymin, ymax):
    """Accumulated total precipitation of the first time step on a sub-grid.

    Sums the convective (``RAINC``), shallow-convective (``RAINSH``) and
    grid-scale (``RAINNC``) accumulations over the half-open window
    ``xmin:xmax`` x ``ymin:ymax``. ``SNOWNC``/``GRAUPELNC``/``HAILNC`` are
    deliberately not added: in WRF output ``RAINNC`` is the total grid-scale
    precipitation and already contains the frozen species, so adding them
    would count snow, graupel and hail twice.

    Variables that are absent from ``D`` (not written by every model
    configuration) are skipped. If none of them is present, a NaN field of
    shape ``(xmax - xmin, ymax - ymin)`` is returned.
    """
    fields = []
    for var in ("RAINC", "RAINSH", "RAINNC"):
        try:
            fields.append(D[var][0, xmin:xmax, ymin:ymax])
        except (KeyError, IndexError):
            # netCDF4 raises IndexError for a missing variable, dicts KeyError
            continue
    if not fields:
        return np.mgrid[xmin:xmax, ymin:ymax][0] * np.nan
    return np.sum(fields, axis=0)


def fdrain(DS, xmin, xmax, ymin, ymax):
    """Precipitation increments between consecutive datasets on a sub-grid.

    Args:
        DS: datasets in time order; each must expose ``START_DATE``.
        xmin, xmax, ymin, ymax: half-open sub-grid bounds.

    Returns:
        Array with one 2-D increment per dataset; the first one is zero.
    """
    tot_accum_rain = np.array(
        [ftot_accum_rain_2d(D, xmin, xmax, ymin, ymax) for D in DS]
    )

    # The series is stitched together from several model runs and every run
    # restarts its accumulation at zero. When a file's START_DATE is later
    # than its predecessor's, it belongs to a new run: its accumulated total
    # is already the increment (the plain difference would be negative).
    starts = np.array([t2datetime(D.START_DATE) for D in DS])
    dt = np.array([x.total_seconds() for x in (starts[1:] - starts[:-1])])
    # one flag per step, broadcast over the spatial axes
    restarted = (dt > 0).reshape((-1,) + (1,) * (tot_accum_rain.ndim - 1))

    drain = np.zeros_like(tot_accum_rain)
    drain[1:] = np.where(
        restarted, tot_accum_rain[1:], tot_accum_rain[1:] - tot_accum_rain[:-1]
    )
    return drain


def ftot_accum_snow_2d(D, xmin, xmax, ymin, ymax):
    """Accumulated ``SNOWNC`` + ``GRAUPELNC`` of the first time step on a sub-grid.

    The half-open window ``xmin:xmax`` x ``ymin:ymax`` is used. Variables
    absent from ``D`` (not written by every microphysics configuration) are
    skipped. If neither is present, a NaN field of shape
    ``(xmax - xmin, ymax - ymin)`` is returned.
    """
    fields = []
    for var in ("SNOWNC", "GRAUPELNC"):
        try:
            fields.append(D[var][0, xmin:xmax, ymin:ymax])
        except (KeyError, IndexError):
            # netCDF4 raises IndexError for a missing variable, dicts KeyError
            continue
    if not fields:
        return np.mgrid[xmin:xmax, ymin:ymax][0] * np.nan
    return np.sum(fields, axis=0)


def fdsnow(DS, xmin, xmax, ymin, ymax):
    """Snow/graupel increments between consecutive datasets on a sub-grid.

    Uses the same restart handling as the rain increments. When a file's
    ``START_DATE`` is later than its predecessor's (a new model run whose
    accumulation restarted), its accumulated total is taken as the increment.

    Returns:
        Array with one 2-D increment per dataset; the first one is zero.
    """
    tot_accum_snow = np.array(
        [ftot_accum_snow_2d(D, xmin, xmax, ymin, ymax) for D in DS]
    )

    starts = np.array([t2datetime(D.START_DATE) for D in DS])
    dt = np.array([x.total_seconds() for x in (starts[1:] - starts[:-1])])
    # one flag per step, broadcast over the spatial axes
    restarted = (dt > 0).reshape((-1,) + (1,) * (tot_accum_snow.ndim - 1))

    dsnow = np.zeros_like(tot_accum_snow)
    dsnow[1:] = np.where(
        restarted, tot_accum_snow[1:], tot_accum_snow[1:] - tot_accum_snow[:-1]
    )
    return dsnow


def t2datetime(x):
    """Parse a WRF timestamp string like ``2021-03-11_06:00:00``.

    Raises:
        ValueError: if ``x`` does not match ``%Y-%m-%d_%H:%M:%S``.
    """
    from datetime import datetime as _datetime

    wrf_format = "%Y-%m-%d_%H:%M:%S"
    return _datetime.strptime(x, wrf_format)


def ftimes(D):
    """Convert first WRF timestamp description into a datetime.

    ``D["Times"][0]`` is joined character by character after decoding each
    element as UTF-8 (presumably a netCDF char array of single bytes).
    """
    first_record = D["Times"][0]
    stamp = "".join(ch.decode("utf-8") for ch in first_record)
    return t2datetime(stamp)


def corresponding_pixels(D, location_lon, location_lat, deg_round_loc):
    """Find the index box enclosing all grid pixels near a location.

    A pixel is selected when both its latitude and its longitude lie
    strictly within ``deg_round_loc`` degrees of the location. The box is
    the smallest index rectangle containing every selected pixel.

    Args:
        D: mapping/dataset with ``XLONG`` and ``XLAT``; only index 0 of the
            leading axis is used.
        location_lon, location_lat: location in degrees.
        deg_round_loc: half-width of the selection window in degrees.

    Returns:
        ``(xmin, xmax), (ymin, ymax), xlon, xlat``. The bounds are half-open,
        so they can be used directly as slice ends along the first and
        second grid axis. ``xlon``/``xlat`` are the coordinates inside the box.

    Raises:
        ValueError: if no grid pixel lies within the window.
    """
    import numpy as np

    xlon = D["XLONG"][0]
    xlat = D["XLAT"][0]

    around_loc = (np.abs(xlat - location_lat) < deg_round_loc) & (
        np.abs(xlon - location_lon) < deg_round_loc
    )

    rows, cols = np.nonzero(around_loc)
    if rows.size == 0:
        raise ValueError(
            "No grid pixel within {} deg of lon={}, lat={}".format(
                deg_round_loc, location_lon, location_lat
            )
        )

    # np.nonzero lists matches in row-major order, so only the row indices
    # are sorted; take explicit min/max. The +1 turns the last matching
    # index into an exclusive slice end.
    xmin, xmax = int(rows.min()), int(rows.max()) + 1
    ymin, ymax = int(cols.min()), int(cols.max()) + 1

    xlon = xlon[xmin:xmax, ymin:ymax]
    xlat = xlat[xmin:xmax, ymin:ymax]

    return (xmin, xmax), (ymin, ymax), xlon, xlat


def haversine(lat1, lon1, lat2, lon2):
    """
    Great-circle distance (in metres) between lat/lon points given in degrees.

    Works element-wise on scalars or numpy arrays, using a sphere with radius
    6378137 m (the WGS84 equatorial radius).
    """
    import numpy as np

    earth_radius = 6378137
    phi1 = np.deg2rad(lat1)
    phi2 = np.deg2rad(lat2)
    half_dphi = np.sin((phi2 - phi1) / 2)
    half_dlam = np.sin((np.deg2rad(lon2) - np.deg2rad(lon1)) / 2.0)
    a = half_dphi * half_dphi + np.cos(phi1) * np.cos(phi2) * half_dlam * half_dlam
    central_angle = 2 * np.arctan2(np.sqrt(a), np.sqrt(1.0 - a))
    return earth_radius * central_angle


def compute_average_resolution(D, xmin, xmax, ymin, ymax):
    """Mean distance (metres) between neighbouring pixels of a sub-grid.

    Neighbour distances are averaged separately along the second and the
    first grid axis inside the half-open window ``xmin:xmax`` x
    ``ymin:ymax``; the result is the mean of those two averages.
    """
    import numpy as np

    lon = D["XLONG"][0]
    lat = D["XLAT"][0]

    def mean_spacing(first, second):
        # average great-circle distance between two equally shaped windows
        return np.average(
            haversine(lat[first], lon[first], lat[second], lon[second])
        )

    rows = slice(xmin, xmax)
    cols = slice(ymin, ymax)
    along_axis1 = mean_spacing(
        (rows, slice(ymin, ymax - 1)), (rows, slice(ymin + 1, ymax))
    )
    along_axis0 = mean_spacing(
        (slice(xmin, xmax - 1), cols), (slice(xmin + 1, xmax), cols)
    )
    return (along_axis1 + along_axis0) / 2


def datetime_to_float(d):
    """Return the POSIX timestamp (seconds) of the datetime ``d``.

    Naive datetimes are interpreted as UTC, not as the machine's local time.
    The WRF timestamps this is used with are presumably UTC. More
    importantly, a local-time interpretation would insert or remove an hour
    at DST transitions and shift plotted points.
    """
    import datetime

    if d.tzinfo is None:
        d = d.replace(tzinfo=datetime.timezone.utc)
    return d.timestamp()


# --------------------------------------------------------------------------------


def _plot_meteogram(DS, outfile, deg_round_loc, location_lon, location_lat):
    """Extract the meteogram series from the open WRF datasets ``DS`` (in
    time order), then draw the figure and save it to ``outfile``.

    Helper of ``plot_var``, which owns opening and closing the datasets.
    """
    import datetime
    import os

    import matplotlib as mpl

    mpl.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    from matplotlib.colors import LinearSegmentedColormap
    from matplotlib.font_manager import FontProperties

    (xmin, xmax), (ymin, ymax), xlon, xlat = corresponding_pixels(
        DS[0], location_lon, location_lat, deg_round_loc
    )
    extent = (xmin, xmax, ymin, ymax)

    time = np.array([ftimes(D) for D in DS])

    def draw_vlines(ax, labels=False):
        """Draw dotted midnight lines on ``ax``; with ``labels`` also write the
        weekday/date below the axis at noon of every day inside the series."""
        for t in time[1:]:
            if t.hour == 0 and t.minute == 0:
                ax.vlines(
                    datetime_to_float(t),
                    *ax.get_ylim(),
                    color="black",
                    linestyle=":",
                    lw=1,
                )
        if not labels:
            return
        # Anchor each label at noon of its day whenever noon lies inside the
        # series. This works for any output interval, not only for series
        # that contain a 01:00 step.
        for day in sorted({t.date() for t in time}):
            noon = datetime.datetime.combine(day, datetime.time(12))
            if time[0] <= noon <= time[-1]:
                ycoo = ax.get_ylim()
                ax.text(
                    datetime_to_float(noon),
                    -ycoo[1] / 3,
                    # portable form of "%A %-d %b" (``%-d`` is glibc-only)
                    "{:%A} {} {:%b}".format(noon, noon.day, noon),
                    ha="center",
                )

    # -----------------reading data from input files-----------------------
    T2m = [fT2m(D, *extent) for D in DS]
    Tavg, Tmin, Tmax = [x(T2m, axis=(1, 2)) for x in [np.average, np.min, np.max]]

    Td2m = [fTd2m(D, *extent) for D in DS]
    Tdavg, Tdmin, Tdmax = [x(Td2m, axis=(1, 2)) for x in [np.average, np.min, np.max]]

    wind_speed = [wind10m(D, location_lat, location_lon)[0] for D in DS]
    uv = [wind_uv(D, location_lat, location_lon) for D in DS]  # one look-up per file
    wind_u = [u for u, _ in uv]
    wind_v = [v for _, v in uv]
    mslp = [Mslp(D, location_lat, location_lon) for D in DS]

    drain = fdrain(DS, *extent)
    ravg = np.average(drain, axis=(1, 2))

    dsnow = fdsnow(DS, *extent)
    snow = np.average(dsnow, axis=(1, 2))

    conv = [cin_cape(D, location_lat, location_lon) for D in DS]
    cape = [x[1] for x in conv]
    cin = [x[0] for x in conv]

    w_vert = [vert_wind(D, location_lat, location_lon) for D in DS]
    u_vert = [x[0] for x in w_vert]
    v_vert = [x[1] for x in w_vert]

    T_vert = [vert_T(D, location_lat, location_lon) for D in DS]

    cf_vert = [list(fcfrac(D, location_lat, location_lon)) for D in DS]

    # figure settings:
    fig = plt.figure(figsize=(8.27 * 1.25, 11.69 * 1.25), num=1)
    plt.clf()

    mplotkw = {"ls": "-", "marker": "."}

    ax = plt.subplot2grid((8, 20), (0, 19), rowspan=3)
    ax0 = plt.subplot2grid((8, 20), (0, 0), colspan=19, rowspan=3)
    ax1 = plt.subplot2grid((8, 20), (3, 0), colspan=19, rowspan=2)
    ax2 = plt.subplot2grid((8, 20), (5, 0), colspan=19)
    ax3 = plt.subplot2grid((8, 20), (6, 0), colspan=19)
    ax4 = plt.subplot2grid((8, 20), (7, 0), colspan=19)

    # shared x axis settings
    time_mod = [datetime_to_float(t) for t in time]
    hours = np.array([t.hour for t in time])
    ind = list(np.where(hours % 6 == 0)[0])  # ticks at 00/06/12/18
    xticks = [time_mod[i] for i in ind]
    margin = datetime.timedelta(hours=3 * len(time) / (24 * 3))
    xlim = (datetime_to_float(time[0] - margin), datetime_to_float(time[-1] + margin))
    # Draw roughly 24 barbs/labels per series; never use a zero stride
    # (short series would otherwise raise ZeroDivisionError).
    barb_every = max(1, len(time) // 24)

    # 0.--------------vertical plot------------------------

    ax0.grid(axis="y")

    # (vertical) cloud fraction
    levels = np.arange(400, 1030, 30)
    cf_grid = np.array(cf_vert, dtype=float).T  # (level, time)
    cmap = LinearSegmentedColormap.from_list("", ["steelblue", "white"])
    # NOTE(review): the field is filled twice. The second call (levels up to
    # 1) paints over the first, whose contour set feeds the colorbar.
    # Confirm this is intended.
    cs = ax0.contourf(
        time_mod, levels, cf_grid, levels=[x / 8 for x in range(10)], cmap=cmap
    )
    ax0.contourf(time_mod, levels, cf_grid, levels=[x / 8 for x in range(9)], cmap=cmap)
    cbar = fig.colorbar(cs, cax=ax, orientation="vertical")
    cbar.set_ticks([x / 8 for x in range(0, 9)])
    try:
        cbar.ax.set_yticklabels(["{}/8".format(x) for x in range(0, 9)])
    except Exception:
        print("Plotting tick labels failed")
    ax.set_ylabel("cloud cover")

    # (vertical) temperature contours
    levels = np.arange(400, 1060, 60)
    T_grid = np.array(T_vert, dtype=float).T  # (level, time)
    cs = ax0.contour(time_mod, levels, T_grid, cmap="plasma")
    ax0.clabel(cs, fmt="%1.0f", fontsize=10)

    # (vertical) wind barbs:
    for i, j in enumerate(time_mod):
        if i % barb_every == 0:
            for k, l in enumerate(levels):
                ax0.barbs(
                    j,
                    l,
                    u_vert[i][k],
                    v_vert[i][k],
                    length=5,
                    pivot="middle",
                    linewidth=0.5,
                )

    ax0.set_yticks(np.arange(400, 1050, 50))
    ax0.set_ylim(980, 400)
    ax0.set_xlim(*xlim)
    ax0.set_ylabel("pressure [hPa]")
    try:
        ax0.set_xticklabels([], minor=False)
        ax0.set_xticklabels([], minor=True)
    except Exception:
        print("Plotting tick labels failed")
    ax0.set_xticks(xticks)
    draw_vlines(ax0)

    # 1.--------------temperature plot------------------------

    ax1.fill_between(time_mod, Tdmin, Tdmax, alpha=0.3, color="lightblue")
    ax1.fill_between(time_mod, Tmin, Tmax, alpha=0.3, color="grey")
    ax1.grid(axis="y")
    ax1.plot(time_mod, Tdavg, color="blue", label="Td(2m)", **mplotkw)
    ax1.plot(time_mod, Tavg, color="black", label="T(2m)", **mplotkw)

    # min and max of the area-averaged temperature for every calendar day,
    # grouped by date so any output interval and month boundaries work
    steps_per_day = {}
    for i, t in enumerate(time):
        steps_per_day.setdefault(t.date(), []).append(i)
    for idx in steps_per_day.values():
        values = Tavg[idx]
        if np.isnan(values).all():
            continue
        i_min = idx[int(np.nanargmin(values))]
        i_max = idx[int(np.nanargmax(values))]
        ax1.text(
            time_mod[i_min],
            Tavg[i_min] + 2,
            str(np.round(Tavg[i_min], 1)),
            color="blue",
            ha="center",
        )
        if i_max != i_min:
            ax1.text(
                time_mod[i_max],
                Tavg[i_max] + 1.5,
                str(np.round(Tavg[i_max], 1)),
                color="red",
                ha="center",
            )

    ax1.set_ylabel(r"temp. [$^\circ$C]")
    ax1.legend(loc="best", fontsize="small", ncol=3)
    ax1.set_ylim(min(min(Tmin), min(Tdmin)) - 3, max(max(Tmax), max(Tdmax)) + 7.5)
    ax1.set_xlim(*xlim)
    try:
        ax1.xaxis.set_ticklabels([], minor=True)
        ax1.xaxis.set_ticklabels([], minor=False)
    except Exception:
        print("Plotting tick labels failed")
    ax1.set_xticks(xticks)
    draw_vlines(ax1)

    # 2.--------------10m wind plot------------------------

    finite_ws = [w for w in wind_speed if np.isfinite(w)]
    ws_max = max(finite_ws) if finite_ws else 0.0

    ax2.grid(axis="y")
    ax2.plot(time_mod, wind_speed, color="black", **mplotkw)
    ax2.set_xlim(*xlim)
    for i, j in enumerate(time_mod):
        if i % barb_every == 0:
            # a failed look-up (NaN) cannot be rounded to an int label
            if np.isfinite(wind_speed[i]):
                ax2.text(
                    j,
                    wind_speed[i] - 4.9,
                    str(int(round(wind_speed[i]))),
                    color="black",
                    ha="center",
                )
            ax2.barbs(
                j,
                max(15, ws_max + 5),
                wind_u[i],
                wind_v[i],
                length=5,
                pivot="middle",
                linewidth=0.5,
            )
    ax2.set_ylabel(r"10 m ws [ms$^{-1}$]")
    ax2.set_ylim(-4.9, max(20, ws_max + 8))
    try:
        ax2.xaxis.set_ticklabels([], minor=True)
        ax2.xaxis.set_ticklabels([], minor=False)
    except Exception:
        print("Plotting tick labels failed")
    ax2.set_xticks(xticks)
    draw_vlines(ax2)

    # 3.--------------mslp plot------------------------

    ax3.grid(axis="y")
    ax3.plot(time_mod, mslp, color="red", **mplotkw)
    ax3.set_ylabel(r"MSLP [hPa]")
    try:
        ax3.xaxis.set_ticklabels([], minor=True)
        ax3.xaxis.set_ticklabels([], minor=False)
    except Exception:
        print("Plotting tick labels failed")
    ax3.set_xlim(*xlim)
    mslp_values = np.asarray(mslp, dtype=float)
    if np.isfinite(mslp_values).any():
        # ignore failed look-ups; NaN axis limits make matplotlib raise
        ax3.set_ylim(np.nanmin(mslp_values) - 2, np.nanmax(mslp_values) + 2)
    ax3.ticklabel_format(useOffset=False, axis="y")
    ax3.set_xticks(xticks)
    draw_vlines(ax3)

    # 4.--------------prec. plot------------------------
    # rain:

    finite_rain = np.isfinite(ravg)
    rain_max = float(np.max(ravg[finite_rain])) if finite_rain.any() else 0.0
    rain_top = max(int(rain_max), 3)

    ax4.grid(axis="y")
    # NOTE(review): bar width and the "/h" units assume hourly output steps.
    ax4.bar(time_mod, ravg, color="c", label="tot. prec.", width=3600)
    ax4.set_ylabel(r"total prec. [mm/h]")
    ax4.set_ylim(0, max(rain_max, 3) + 1)
    ax4.set_xlim(*xlim)
    ax4.set_xticks(xticks, minor=True)
    ax4.set_yticks(list(range(0, rain_top + 2, rain_top // 3)))
    try:
        # label every tick with its actual hour ("00", "06", "12", "18")
        ax4.xaxis.set_ticklabels([time[i].strftime("%H") for i in ind], minor=True)
    except Exception:
        print("Plotting tick labels failed")
    draw_vlines(ax4, labels=True)

    # print a red (magenta) thunderstorm symbol if cape > 500 (> 1500) Jkg-1 and cin < 100 Jkg-1
    prop = FontProperties()
    font_file = "/usr/share/fonts/truetype/freefont/FreeMono.ttf"
    if os.path.exists(font_file):
        # fall back to the default font where FreeMono is not installed
        prop.set_file(font_file)
    for i, j in enumerate(time_mod):
        weak_cin = cin[i] < 100 or np.isnan(cin[i])
        if weak_cin and cape[i] >= 1500:
            color = "magenta"
        elif weak_cin and cape[i] >= 500:
            color = "red"
        else:
            continue
        ax4.text(
            j,
            ax4.get_ylim()[1] * 0.62,
            "\u2608",
            fontproperties=prop,
            color=color,
            ha="center",
            fontsize="large",
        )

    # snow (same vertical scale as the rain axis):
    ax5 = ax4.twinx()
    ax5.bar(time_mod, snow, color="magenta", label="snow", width=3600)
    ax5.set_ylabel(r"snow [cm/h]")
    ax5.set_ylim(0, max(rain_max, 3) + 1)
    ax5.set_xlim(*xlim)
    ax5.set_yticks([x for x in range(rain_top + 2) if x % (rain_top / 3) == 0])
    ax5.set_xticks(xticks)
    try:
        ax5.yaxis.set_ticklabels([])
        ax5.xaxis.set_ticklabels([])
    except Exception:
        print("Plotting tick labels failed")
    lines, labels = ax4.get_legend_handles_labels()
    lines2, labels2 = ax5.get_legend_handles_labels()
    ax4.legend(lines + lines2, labels + labels2, loc="best", fontsize="small", ncol=2)

    # ---------------additional-----------------

    date = datetime.datetime.now(datetime.timezone.utc).strftime(
        "%H:%M, %B %d, %Y, UTC"
    )
    ax0.text(
        1,
        1.01,
        "{0:} \n experimental WRF setup \n $\\Delta x = {1:.3f}$ km".format(
            date, compute_average_resolution(DS[0], *extent) / 1e3
        ),
        fontsize="x-small",
        transform=ax0.transAxes,
        horizontalalignment="right",
    )

    plt.figtext(
        0.483,
        0.948,
        r"Christoph Knote (MBEES, Uni Augsburg), Fabian Jakub, Matjaz Puh (Meteorological Institute, LMU Munich)",
        fontsize="x-small",
        alpha=0.4,
    )
    ax0.text(
        0.05,
        1.15,
        "MIM Forecast",
        transform=ax0.transAxes,
        fontsize="xx-large",
        horizontalalignment="left",
    )
    ax0.text(
        0.05,
        1.05,
        " Meteogram @{0:.5f},{1:.5f} \n using {2:} pts for output statistics".format(
            location_lon, location_lat, np.size(xlon)
        ),
        transform=ax0.transAxes,
        fontsize="small",
        horizontalalignment="left",
    )

    plt.savefig(outfile, bbox_inches="tight")
    plt.close(fig)


def plot_var(
    filenames, outfile, deg_round_loc=0.2, location_lon=10.8978, location_lat=48.3705
):
    """
    Plot meteogram for given location (location_lon, location_lat).

    All grid pixels whose latitude and longitude both lie within
    ``deg_round_loc`` degrees of the location are used to attain model
    spread (min/max envelopes and area averages).

    Args:
        filenames: WRF output files, processed in sorted order. Files that
            cannot be opened or contain no timestamps are skipped.
        outfile: path of the image written by ``plt.savefig``.
        deg_round_loc: half-width (degrees) of the lat/lon selection window.
        location_lon, location_lat: meteogram location in degrees.

    Raises:
        ValueError: if none of ``filenames`` is a readable WRF file.
    """
    import netCDF4

    # -----------------reading input files:-------------------------

    def get_only_valid_file(f):
        """Open ``f``; return None (closing it again) if it is unreadable or
        contains no timestamps."""
        try:
            nc = netCDF4.Dataset(f)
        except Exception:
            return None
        try:
            if len(extract_times(nc, None)) > 0:
                return nc
        except Exception:
            # best effort: files whose times cannot be decoded are skipped
            pass
        nc.close()
        return None

    DS = [get_only_valid_file(f) for f in sorted(filenames)]
    DS = [nc for nc in DS if nc is not None]
    if not DS:
        raise ValueError(
            "None of the {} input files is a readable WRF file".format(len(filenames))
        )

    try:
        _plot_meteogram(DS, outfile, deg_round_loc, location_lon, location_lat)
    finally:
        # release the netCDF file handles even if plotting fails
        for nc in DS:
            nc.close()


if __name__ == "__main__":
    import datetime
    from argparse import ArgumentParser, ArgumentTypeError

    def valid_date(s):
        """argparse type: parse ``YYYY-mm-dd`` into a datetime."""
        try:
            return datetime.datetime.strptime(s, "%Y-%m-%d")
        except ValueError:
            raise ArgumentTypeError("Not a valid date: '{0}'.".format(s))

    def positive_int(s):
        """argparse type: a strictly positive integer (a zero or negative
        step would never reach ``end_date``)."""
        try:
            value = int(s)
        except ValueError:
            raise ArgumentTypeError("Not an integer: '{0}'.".format(s))
        if value <= 0:
            raise ArgumentTypeError("Must be a positive number of hours: '{0}'.".format(s))
        return value

    parser = ArgumentParser()
    parser.add_argument(
        "wrf_data_fpath_pattern",
        type=str,
        help="WRF input data path pattern (including strptime pattern for date and {domain:s} for domain (01, 02, ...))",
    )
    parser.add_argument(
        "start_date", type=valid_date, help="First date to process (YYYY-mm-dd)"
    )
    parser.add_argument(
        "end_date", type=valid_date, help="Last date to process (YYYY-mm-dd)"
    )
    parser.add_argument(
        "interval_hours", type=positive_int, help="Time interval in hours for plotting"
    )
    parser.add_argument("domain", type=int, help="Which domain to process")
    parser.add_argument(
        "plot_fpath_pattern",
        type=str,
        help="Plot output path pattern (including strptime pattern for date and {plot:s} for plot identifier (e.g., 500hPa_europe...)",
    )
    parser.add_argument(
        "--spatial_patch_deg",
        type=float,
        default=0.1,
        help="Size of the patch (in degrees) to use for uncertainty calculations",
    )

    args = parser.parse_args()
    # Example:
    # args = parser.parse_args( [ "/alcc/gpfs2/scratch/mbees/knotechr/archive/WRF/operational_chemistry/wrfout_d{domain:02d}_%Y-%m-%d_%H:%M:%S", "2021-03-11", "2021-03-14", "3", "1", "{plot:s}_%Y%m%d%H.png" ])

    # One file per step from start_date up to, but excluding, end_date.
    step = datetime.timedelta(hours=args.interval_hours)
    filenames = []
    current = args.start_date
    while current < args.end_date:
        filenames.append(
            current.strftime(args.wrf_data_fpath_pattern).format(domain=args.domain)
        )
        current += step

    outfile = args.start_date.strftime(args.plot_fpath_pattern.format(plot="met"))
    try:
        plot_var(filenames, outfile, deg_round_loc=args.spatial_patch_deg)
    except Exception:
        # A stale font cache can make rendering fail; older matplotlib
        # versions offer a private ``_rebuild`` to refresh it. It may not exist
        # in newer versions, in which case the original error is re-raised
        # instead of being masked by an AttributeError.
        import matplotlib.font_manager as fm

        rebuild = getattr(fm, "_rebuild", None)
        if rebuild is None:
            raise
        print("Rebuilding font cache")
        rebuild()
        plot_var(filenames, outfile, deg_round_loc=args.spatial_patch_deg)
