from argparse import ArgumentParser, ArgumentTypeError


def valid_species_tuple(input):
    """Parse an ``INPUT,OUTPUT`` species pair given on the command line.

    Used as the ``type=`` converter of the ``--species`` option, so raising
    ``ArgumentTypeError`` makes argparse report a clean usage error.

    Args:
        input: Comma-separated pair, e.g. ``"PM2_5_DRY,PM2.5"``. The first
            part is the NetCDF variable to read; the second is substituted
            for ``{species:s}`` in the output path pattern.

    Returns:
        ``(netcdf_variable, output_name)`` as a tuple of two strings.

    Raises:
        ArgumentTypeError: If *input* is not a string or does not split into
            exactly two non-empty parts.
    """
    # The original built the exception without raising it (and swallowed all
    # errors with a bare except), so invalid values slipped through as None
    # or as tuples of the wrong length that crashed later when unpacked.
    if not isinstance(input, str):
        raise ArgumentTypeError("Not a valid species tuple: {}".format(input))
    parts = tuple(input.split(","))
    if len(parts) != 2 or not all(parts):
        raise ArgumentTypeError("Not a valid species tuple: {}".format(input))
    return parts


# Command-line interface. Arguments are parsed here, before the NetCDF/GDAL
# imports further down (presumably so that --help works without loading them).
parser = ArgumentParser(description="Create GeoTIFF from WRF(-Chem) NetCDFs.")

# Positional arguments: the source file and the output naming pattern.
parser.add_argument("input_path", type=str, help="wrfout NetCDF file path")
parser.add_argument(
    "output_path_pattern",
    type=str,
    help="geotiff output file path with {species:s} pattern for naming.",
)

# Repeatable, mandatory species mapping: each -s adds one (input, output) pair.
species_options = dict(
    type=valid_species_tuple,
    action="append",
    required=True,
    help='input/output species tuple, separated by comma ("PM2_5_DRY,PM2.5")',
)
parser.add_argument("-s", "--species", **species_options)

args = parser.parse_args()
# Example for interactive testing:
#   args = parser.parse_args(['-s', 'PM2_5_DRY,PM2.5', 'test.nc', 'moek_{species:s}.tif'])

import netCDF4
import numpy as np
from osgeo import osr, ogr, gdal

# Open the WRF(-Chem) output; its global attributes provide the parameters
# of the Lambert conformal conic (+proj=lcc) projection used below.
nc = netCDF4.Dataset(args.input_path)

truelat1, truelat2 = nc.TRUELAT1, nc.TRUELAT2  # -> +lat_1 / +lat_2
lat_0 = nc.MOAD_CEN_LAT  # -> +lat_0
lon_0 = nc.STAND_LON  # -> +lon_0

# WRF projection string
wrf_proj_string = (
    "+proj=lcc +lat_1={truelat1:} +lat_2={truelat2:} +lat_0={lat_0:} +lon_0={lon_0:} "
    "+datum=WGS84 +ellps=sphere +a=6370000 +b=6370000 +units=m +no_defs"
).format(lat_0=lat_0, lon_0=lon_0, truelat1=truelat1, truelat2=truelat2)

wrf_srs = osr.SpatialReference()
wrf_srs.ImportFromProj4(wrf_proj_string)

# Corner points: project the lon/lat of the lower-left ([0, 0]) and
# upper-right ([-1, -1]) grid points at time 0 into the WRF projection to get
# the projected extent of the grid.
wgs84_srs = osr.SpatialReference()
wgs84_srs.ImportFromEPSG(4326)
# GDAL >= 3 follows the EPSG axis order (lat, lon) for EPSG:4326 by default,
# while GDAL 2 always uses (lon, lat). Force the traditional (x=lon, y=lat)
# order so the transform below behaves identically on both.
if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
    wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    wrf_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

transform = osr.CoordinateTransformation(wgs84_srs, wrf_srs)


def _wgs84_to_wrf(lon, lat):
    """Project a WGS84 (lon, lat) pair to (x, y) in the WRF projection."""
    # float() because the inputs are numpy scalars read from the NetCDF file.
    x, y = transform.TransformPoint(float(lon), float(lat))[:2]
    return x, y


xll_wgs84 = nc["XLONG"][0, 0, 0]
yll_wgs84 = nc["XLAT"][0, 0, 0]
xmin_wrf, ymin_wrf = _wgs84_to_wrf(xll_wgs84, yll_wgs84)

xur_wgs84 = nc["XLONG"][0, -1, -1]
yur_wgs84 = nc["XLAT"][0, -1, -1]
xmax_wrf, ymax_wrf = _wgs84_to_wrf(xur_wgs84, yur_wgs84)

# Geographic bounding box of the grid points, taken at time 0 to match the
# corner points above and the fields extracted below.
xmin_wgs84 = np.min(nc["XLONG"][0])
ymin_wgs84 = np.min(nc["XLAT"][0])
xmax_wgs84 = np.max(nc["XLONG"][0])
ymax_wgs84 = np.max(nc["XLAT"][0])

# Fields: time index 0, level index 0, with the south_north axis reversed so
# that row 0 is the northernmost row (GDAL rasters are stored north-up).
# NOTE: ``vars`` shadows the builtin; the name is kept because the writer
# loop reads it.
vars = {name: nc[variable][0, 0, ::-1, :] for variable, name in args.species}

nc.close()

var = vars[args.species[0][1]]

# Despite their names, nx is the number of rows (south_north) and ny the
# number of columns (west_east); the rasters are created with
# (xsize=ny, ysize=nx).
nx = var.shape[0]
ny = var.shape[1]

# Target WGS84 grid: same pixel counts, spread over the lon/lat bounding box.
# Longitude extent goes over the columns (ny) and latitude extent over the
# rows (nx). The original divided by the wrong counts, which distorted the
# grid whenever nx != ny.
xres = (xmax_wgs84 - xmin_wgs84) / float(ny)
yres = (ymax_wgs84 - ymin_wgs84) / float(nx)
geotransform_wgs84 = (xmin_wgs84, xres, 0, ymax_wgs84, 0, -yres)

# Source WRF grid: the unstaggered XLONG/XLAT corner points are cell centres,
# so they are (n - 1) cells apart. A GDAL geotransform origin is the outer
# corner of the top-left pixel, i.e. half a cell beyond the outermost centre.
# max(..., 1) only avoids a ZeroDivisionError on a degenerate 1-cell axis.
xres = (xmax_wrf - xmin_wrf) / float(max(ny - 1, 1))
yres = (ymax_wrf - ymin_wrf) / float(max(nx - 1, 1))
geotransform_wrf = (xmin_wrf - xres / 2.0, xres, 0, ymax_wrf + yres / 2.0, 0, -yres)

# Write one GeoTIFF per requested species: wrap the field in an in-memory
# dataset on the native WRF grid, then warp it onto the regular WGS84 grid.
NODATA = -9999.0

for name, var in vars.items():
    # Source raster on the WRF grid (ny columns x nx rows).
    wrf_ds = gdal.GetDriverByName("MEM").Create("", ny, nx, 1, gdal.GDT_Float32)
    wrf_ds.SetGeoTransform(geotransform_wrf)
    wrf_ds.SetProjection(wrf_srs.ExportToWkt())
    src_band = wrf_ds.GetRasterBand(1)
    # netCDF4 may return masked arrays. Write masked cells as NODATA instead
    # of the raw value underneath, and flag that value on the source band so
    # the warper ignores those cells rather than blending them into valid ones.
    src_band.SetNoDataValue(NODATA)
    src_band.WriteArray(np.ma.filled(var, NODATA))

    # Destination GeoTIFF on the WGS84 grid, pre-filled with NODATA.
    output_path = args.output_path_pattern.format(species=name)
    wgs84_ds = gdal.GetDriverByName("GTiff").Create(
        output_path, ny, nx, 1, gdal.GDT_Float32
    )
    if wgs84_ds is None:
        # Without GDAL exceptions enabled, Create() signals failure with None.
        raise RuntimeError("Could not create GeoTIFF: {}".format(output_path))
    wgs84_ds.SetGeoTransform(geotransform_wgs84)
    wgs84_ds.SetProjection(wgs84_srs.ExportToWkt())
    wgs84_ds.GetRasterBand(1).SetNoDataValue(NODATA)
    wgs84_ds.GetRasterBand(1).Fill(NODATA)

    gdal.ReprojectImage(
        wrf_ds,
        wgs84_ds,
        wrf_srs.ExportToWkt(),
        wgs84_srs.ExportToWkt(),
        gdal.GRA_Bilinear,
    )

    # Flush, then drop the references so GDAL closes both datasets.
    wgs84_ds.FlushCache()
    wgs84_ds = None
    wrf_ds = None
