from osgeo import ogr, osr

import netCDF4
import numpy as np
import datetime
import wrf

def netCDF4FromWrfInputAndWrfChemi(inFile, chemiFile, outFile, zmax=0.0):
    """Copy a wrfchemi file into a new NetCDF4 file with enough emission levels.

    All global attributes, dimensions and variables of ``chemiFile`` are
    copied to ``outFile``. The ``emissions_zdim_stag`` dimension is enlarged
    so that the emission levels include the first model layer whose top lies
    at or above ``zmax`` (metres above terrain) at every grid point. The
    layer heights are taken from ``inFile``. The dimension is never made
    smaller than in ``chemiFile``. Variables using that dimension are
    zero-padded along it.

    Three variables are added on
    ('Time', 'emissions_zdim_stag', 'south_north', 'west_east'):

    * ZBOTTOM / ZTOP: layer bottom / top height above terrain in metres,
      stored in record 0 (all other records are zero).
    * GRETA_fraction: all zeros.

    Parameters
    ----------
    inFile : str
        Path of the WRF input file providing ``zstag`` and ``terrain``.
    chemiFile : str
        Path of the wrfchemi file to copy.
    outFile : str
        Path of the NetCDF4 file to create (opened with mode "w").
    zmax : float, optional
        Height above terrain (m) the emission levels must reach.

    Returns
    -------
    netCDF4.Dataset
        The output dataset, still open; the caller is responsible for
        closing it.

    Raises
    ------
    ValueError
        If more emission levels are required than model layers are
        available from ``inFile``.
    """
    # --- layer heights above terrain from the WRF input file ---------------
    src_in = netCDF4.Dataset(inFile, "r")
    try:
        nztot = src_in.dimensions['bottom_top'].size
        # Read each field once and detach it from the file before closing.
        zstag = np.asarray(wrf.getvar(src_in, "zstag", units="m"))
        terrain = np.asarray(wrf.getvar(src_in, "terrain"))
    finally:
        src_in.close()

    # Same layer selection as before: nztot - 1 layers, bottoms from staggered
    # levels 0..nztot-2 and tops from levels 1..nztot-1.
    zbot = zstag[:nztot - 1] - terrain
    ztop = zstag[1:nztot] - terrain
    nlayers = ztop.shape[0]

    # Find the first layer whose top is >= zmax everywhere. The bound is
    # checked first so ztop is never indexed past its last layer.
    levmax = 0
    while levmax < nlayers and np.min(ztop[levmax]) < zmax:
        levmax += 1
    # Include that layer, but never ask for more layers than exist.
    wanted_levels = min(levmax + 1, nlayers)

    src_chemi = netCDF4.Dataset(chemiFile, "r")
    dst = None
    try:
        emiss_dimlen_in = src_chemi.dimensions['emissions_zdim_stag'].size
        # Never shrink the emission levels already present in the chemi file.
        emiss_dimlen_out = max(wanted_levels, emiss_dimlen_in)
        if emiss_dimlen_out > nlayers:
            raise ValueError(
                f"{chemiFile} has {emiss_dimlen_in} emission levels but only "
                f"{nlayers} model layers are available from {inFile}")

        zbot = zbot[:emiss_dimlen_out]
        ztop = ztop[:emiss_dimlen_out]

        dst = netCDF4.Dataset(outFile, "w", format="NETCDF4")

        # copy global attributes
        for attr in src_chemi.ncattrs():
            dst.setncattr(attr, src_chemi.getncattr(attr))

        # copy dimensions; the emission level dimension gets its new length
        for dim_name, dim in src_chemi.dimensions.items():
            if dim.isunlimited():
                size = None
            elif dim_name == "emissions_zdim_stag":
                size = emiss_dimlen_out
            else:
                size = len(dim)
            dst.createDimension(dim_name, size)

        # copy variables
        for var_name, var in src_chemi.variables.items():
            attrs = {a: var.getncattr(a) for a in var.ncattrs()}
            # netCDF4 only accepts _FillValue at variable creation time.
            fill_value = attrs.pop("_FillValue", None)
            out = dst.createVariable(var_name, var.datatype, var.dimensions,
                                     fill_value=fill_value)
            # Set attributes before writing, so packing/encoding attributes
            # (e.g. scale_factor) are applied to the data written below.
            out.setncatts(attrs)

            data = var[:]
            if "emissions_zdim_stag" in var.dimensions:
                # The emission level dimension has grown: zero-pad the new
                # levels explicitly. Writing zeros into the variable first does
                # not work when 'Time' is unlimited and still has no records.
                axis = var.dimensions.index("emissions_zdim_stag")
                shape = list(np.shape(data))
                shape[axis] = emiss_dimlen_out
                padded = np.ma.zeros(shape, dtype=data.dtype)
                region = [slice(None)] * len(shape)
                region[axis] = slice(0, np.shape(data)[axis])
                padded[tuple(region)] = data
                data = padded
            out[:] = data

        # additional variables: layer heights and GRETA fraction
        zdims = ('Time', 'emissions_zdim_stag', 'south_north', 'west_east')
        zdtype = src_chemi.variables["XLONG"].datatype
        for var_name, values in (("ZBOTTOM", zbot),
                                 ("ZTOP", ztop),
                                 ("GRETA_fraction", np.zeros_like(zbot))):
            out = dst.createVariable(var_name, zdtype, zdims)
            # Zero every existing record so records other than 0 are not left
            # as masked fill values; then write record 0 (this also creates it
            # if 'Time' has no records yet).
            out[:] = 0.0
            out[0, :, :, :] = values
    except Exception:
        # Do not leave a half-written output file open on failure.
        if dst is not None:
            dst.close()
        raise
    finally:
        src_chemi.close()

    return dst

def gridFromWrfInput(grid_shp, fname):
    """Create a polygon layer with one square cell per WRF mass grid point.

    The domain geometry is read from the global attributes of ``fname``:
    MAP_PROJ, TRUELAT1, TRUELAT2, MOAD_CEN_LAT, STAND_LON, CEN_LAT, CEN_LON,
    DX and DY, plus the 'west_east' and 'south_north' dimensions. Cells are
    defined in the Lambert Conformal Conic projection on a sphere of radius
    6370 km. The cell of grid point (x, y) is centred on that mass point
    and measures DX by DY.

    Parameters
    ----------
    grid_shp : OGR data source
        Any object with ``CreateLayer`` (presumably an ``ogr.DataSource``);
        a layer named 'grid' is created in it.
    fname : str
        Path of the WRF NetCDF file describing the domain.

    Returns
    -------
    ogr.Layer
        The 'grid' layer. Each feature has integer fields 'x' (west_east
        index) and 'y' (south_north index). Features are created with x as
        the outer loop.

    Raises
    ------
    ValueError
        If the domain does not use the LCC projection (MAP_PROJ != 1).
    """
    import pyproj

    with netCDF4.Dataset(fname) as ds:
        if ds.MAP_PROJ != 1:
            raise ValueError("Can only work with LCC projection!")

        wrf_proj = pyproj.Proj(proj='lcc',  # Lambert Conformal Conic
                               lat_1=ds.TRUELAT1, lat_2=ds.TRUELAT2,  # cone/sphere intersections
                               lat_0=ds.MOAD_CEN_LAT, lon_0=ds.STAND_LON,  # projection origin
                               a=6370000, b=6370000)
        wgs_proj = pyproj.Proj(proj='latlong', datum='WGS84')

        # Transformer replaces the deprecated pyproj.transform; always_xy
        # keeps the (lon, lat) input order.
        to_wrf = pyproj.Transformer.from_proj(wgs_proj, wrf_proj, always_xy=True)
        e, n = to_wrf.transform(ds.CEN_LON, ds.CEN_LAT)

        # Attributes come back as numpy float32 scalars; use Python floats so
        # the coordinates are computed in double precision.
        dx, dy = float(ds.DX), float(ds.DY)
        nx = len(ds.dimensions['west_east'])
        ny = len(ds.dimensions['south_north'])

    # Lower-left corner of the domain. The first mass point is
    # (n - 1) / 2 grid spacings from the domain centre, and its cell extends
    # half a grid spacing further.
    x0 = e - (nx - 1) / 2. * dx - dx / 2.
    y0 = n - (ny - 1) / 2. * dy - dy / 2.

    srs = osr.SpatialReference()
    srs.ImportFromProj4(wrf_proj.srs)
    grid = grid_shp.CreateLayer('grid', srs, geom_type=ogr.wkbPolygon)
    grid.CreateField(ogr.FieldDefn("x", ogr.OFTInteger))
    grid.CreateField(ogr.FieldDefn("y", ogr.OFTInteger))

    feature_defn = grid.GetLayerDefn()

    for ix in range(nx):
        # Edges are computed as x0 + k * dx so neighbouring cells share edges
        # exactly.
        x_left = x0 + ix * dx
        x_right = x0 + (ix + 1) * dx
        for iy in range(ny):
            y_bottom = y0 + iy * dy
            y_top = y0 + (iy + 1) * dy

            ring = ogr.Geometry(ogr.wkbLinearRing)
            ring.AddPoint(x_left, y_top)
            ring.AddPoint(x_right, y_top)
            ring.AddPoint(x_right, y_bottom)
            ring.AddPoint(x_left, y_bottom)
            ring.AddPoint(x_left, y_top)  # close the ring
            poly = ogr.Geometry(ogr.wkbPolygon)
            poly.AddGeometry(ring)

            feature = ogr.Feature(feature_defn)
            feature.SetField("x", ix)
            feature.SetField("y", iy)
            feature.SetGeometry(poly)
            grid.CreateFeature(feature)
            feature = None  # release the OGR feature (GDAL Python idiom)

    return grid

