'''
Script to merge an existing WRF-Chem emission input file ("wrfchemi_*", NetCDF)
with data from the GRETA inventory, replacing existing data with GRETA values
where applicable.

Christoph Knote, MBEES, Uni Augsburg, Germany, 2022
'''

import argparse

# Command line interface: four positional paths plus two optional
# locations for cached intermediate results.
parser = argparse.ArgumentParser(description='Emissary')
parser.add_argument('uba_greta_emission_db_path', type=str, help="Path to the GRETA geodatabase")
parser.add_argument('wrfinput_file_path', type=str, help="Path to the wrfinput file (includes vertical grid definition of WRF run)")
parser.add_argument('wrfchemi_file_path', type=str, help="Path to the wrfchemi file (pre-existing emission input file)")
parser.add_argument('output_file_path',  type=str, help="Path to the output file (the wrfchemi data merged with GRETA data)")
# the default is resolved further below, where the grid mapping caches are set up
parser.add_argument('--grid_mappings_path', type=str, default=None, help="Path where to cache grid associations between input grid and GRETA grid, defaults to the directory of wrfchemi_file_path")
parser.add_argument('--emission_data_tables_path', type=str, default=None, help="Path to attribute table data extracted as .csv from GRETA (using this is much faster than through the osgeo functions); the pickled tables are cached here as well. Defaults to the directory of uba_greta_emission_db_path.")

# NOTE: parsing happens at import time, before the heavy libraries are loaded,
# so that "--help" and usage errors return quickly.
args = parser.parse_args()

# Load libraries

import netCDF4
import numpy as np

from osgeo import ogr, osr
import os
import pickle
import datetime

from status import Status # only for status printing
from tools import gridFromWrfInput, netCDF4FromWrfInputAndWrfChemi

# Initialize drivers

fgdb_driver     = ogr.GetDriverByName("OpenFileGDB")
mem_driver      = ogr.GetDriverByName('MEMORY')

# GetDriverByName() hands back None for an unknown driver; fail here with a
# clear message instead of with an AttributeError somewhere further down.
if fgdb_driver is None:
    raise RuntimeError("OGR driver 'OpenFileGDB' is not available")
if mem_driver is None:
    raise RuntimeError("OGR driver 'MEMORY' is not available")

# Open GRETA files

fgdb            = fgdb_driver.Open(args.uba_greta_emission_db_path)
if fgdb is None:
    raise RuntimeError("Could not open GRETA geodatabase '{:s}'".format(args.uba_greta_emission_db_path))

# gridded (area) emissions and point source emissions live in separate layers
area_emissions  = fgdb.GetLayerByName('RASTER_EMI_Teil1')
point_emissions = fgdb.GetLayerByName('F_PRTR_RESULT_Teil1')

if area_emissions is None:
    raise RuntimeError("Layer 'RASTER_EMI_Teil1' not found in '{:s}'".format(args.uba_greta_emission_db_path))
if point_emissions is None:
    raise RuntimeError("Layer 'F_PRTR_RESULT_Teil1' not found in '{:s}'".format(args.uba_greta_emission_db_path))

# (1) Get GRETA emission data attributes (the actual emission amounts from the attribute tables)

# Reading these is actually the most time-consuming operation, hence
# we first extracted them using QGIS into .csv files, which we
# read in below and save into a .pickle (for even faster access).

# area emissions (RASTER_EMI in GRETA parlance)

if args.emission_data_tables_path is None:
    args.emission_data_tables_path = os.path.dirname( args.uba_greta_emission_db_path )


def _load_emission_table(csv_suffix, pickle_path, label):
    """Return one GRETA attribute table as ``{OBJECTID: {column: value}}``.

    If ``pickle_path`` does not exist yet, the table is first built from the
    semicolon-separated .csv export and written to ``pickle_path``. The result
    is always read back from the pickle.

    csv_suffix  -- file name suffix of the .csv export; it is appended to the
                   base name of the GRETA geodatabase
    pickle_path -- location of the cached table
    label       -- text handed to Status for progress printing
    """
    if not os.path.exists( pickle_path ):
        # Look for the export in emission_data_tables_path first, then fall
        # back to the location right next to the geodatabase. Both are the
        # same file when the option is left at its default.
        csv_path = os.path.join( args.emission_data_tables_path, os.path.basename( args.uba_greta_emission_db_path ) + csv_suffix )
        if not os.path.exists( csv_path ):
            csv_path = args.uba_greta_emission_db_path + csv_suffix

        # atleast_1d: a file with a single data row comes back as a 0-d
        # array, which can neither be measured with len() nor iterated
        raw         = np.atleast_1d( np.genfromtxt(csv_path, delimiter=";", names=True) )
        value_names = [ name for name in raw.dtype.names if name != 'OBJECTID' ]

        stats       = Status(label, len(raw))
        table       = {}

        for row in raw:
            stats.iterate()
            table[row['OBJECTID']] = { name: row[name] for name in value_names }

        with open(pickle_path, "wb") as f:
            pickle.dump(table, file=f)

    # NOTE: pickle executes whatever the file tells it to; only use cache
    # files that were written by this script.
    with open(pickle_path, "rb") as f:
        return pickle.load(f)


# area emissions (RASTER_EMI in GRETA parlance)

area_emissions_data_pickle = os.path.join( args.emission_data_tables_path, os.path.basename( args.uba_greta_emission_db_path ) + "_AREA_DATA.pickle" )
area_emissions_data = _load_emission_table("_RASTER_EMI_Teil1.csv", area_emissions_data_pickle, "Area emissions data")

# point emissions (F_PRTR_RESULT in GRETA parlance)

point_emissions_data_pickle = os.path.join( args.emission_data_tables_path, os.path.basename( args.uba_greta_emission_db_path ) + "_POINT_DATA.pickle" )
point_emissions_data = _load_emission_table("_F_PRTR_RESULT_Teil1.csv", point_emissions_data_pickle, "Point emissions data")

# (2) Figure out highest emission source

# In WRF-Chem, emissions input can be a 3D (x, y, z) file to add elevated point sources.
# No plume rise, instantaneous dilution to the grid cell, but better than nothing. In
# order to be able to do this, we need to figure out the level (interface) altitudes
# in the already existing wrfinput file.

# Largest 'emission_height' attribute found among the GRETA point sources;
# it is handed over as zmax when the output file is created.
emission_heights = [ point_feature['emission_height'] for point_feature in point_emissions ]
max_emiss_height = np.max(emission_heights)

# the loop above consumed the layer, rewind it for later iterations
point_emissions.ResetReading()

def find_level_in_wrfinput(z, nc, x, y):
    """Return the index of the level at column (x, y) that contains height z.

    The bottom and top of each level are read from the variables "ZBOTTOM"
    and "ZTOP" of ``nc`` (first time index). A level contains ``z`` when
    bottom <= z < top. The first matching level wins; -1 is returned when no
    level matches.
    """
    column_bottoms = nc.variables["ZBOTTOM"][0,:,y,x]
    column_tops    = nc.variables["ZTOP"][0,:,y,x]

    for index, (bottom, top) in enumerate(zip(column_bottoms, column_tops)):
        if bottom <= z and top > z:
            return index

    # z lies outside of every level of this column
    return -1

# Create output file

# In-memory OGR data source that holds the output grid polygons
grid_shp                = mem_driver.CreateDataSource('memData')
tmp                     = mem_driver.Open('memData',1)

# A collection of polygons describing the output ("raster")
# grid cells, each carrying the indices of its grid cell
# in the NetCDF file
output_grid             = gridFromWrfInput(grid_shp, args.wrfinput_file_path)

# The output file: basically a copy of the wrfchemi input file
# with an additional 3rd dimension for elevated point sources
nc                      = netCDF4FromWrfInputAndWrfChemi(
    args.wrfinput_file_path,
    args.wrfchemi_file_path,
    args.output_file_path,
    zmax=max_emiss_height,
)

# by convention, every emission variable in the wrfchemi file starts with "E_"
output_variables        = [ variable_name for variable_name in nc.variables if variable_name.startswith("E_") ]

# Output species that are lumped into a single GRETA quantity and
# therefore need special processing
output_nox_variables    = [ "E_NO", "E_NO2" ]
output_voc_variables    = [
    "E_BIGALK", "E_BIGENE", "E_C2H4", "E_C2H5OH", "E_C2H6", "E_CH2O",
    "E_CH3CHO", "E_CH3COCH3", "E_CH3OH", "E_MEK", "E_TOLUENE", "E_C3H6",
    "E_C3H8", "E_ISOP", "E_C10H16", "E_C2H2", "E_BENZENE", "E_XYLENE",
    "E_GLY", "E_MACR", "E_MGLY", "E_MVK", "E_HCOOH",
]
output_pm25_variables   = [
    "E_ECI", "E_ECJ", "E_ORGI", "E_ORGJ", "E_PM25I", "E_PM25J",
    "E_SO4I", "E_SO4J", "E_NO3I", "E_NO3J", "E_NH4I", "E_NH4J",
    "E_NAI", "E_NAJ", "E_CLI", "E_CLJ",
]


def isnox(spec):
    """True if spec is one of the NOx output variables."""
    return spec in output_nox_variables


def isvoc(spec):
    """True if spec is one of the VOC output variables."""
    return spec in output_voc_variables


def ispm25(spec):
    """True if spec is one of the PM2.5 output variables."""
    return spec in output_pm25_variables

# translate an output species to a GRETA input species
def translate_output_to_input_species(name):
    """Return the GRETA "E_SUM_*" field that feeds output variable ``name``.

    Speciated NOx, VOC and PM2.5 output variables are first collapsed onto
    their lumped species. None is returned if GRETA provides nothing for
    ``name``.
    """
    # speciated output variables share one lumped GRETA quantity
    if isnox(name):
        lumped_name = "E_NOx"
    elif isvoc(name):
        lumped_name = "E_NMVOC"
    elif ispm25(name):
        lumped_name = "E_PM2_5"
    else:
        lumped_name = name

    greta_field_of = {
        "E_NMVOC": "E_SUM_NMVOC",
        "E_CO":    "E_SUM_CO",
        "E_NH3":   "E_SUM_NH3",
        "E_NOx":   "E_SUM_NOX",
        "E_SO2":   "E_SUM_SO2",
        "E_PM_10": "E_SUM_PM10",
        "E_PM2_5": "E_SUM_PM2_5",
    }

    # dict.get() yields None for species without a GRETA counterpart
    return greta_field_of.get(lumped_name)

# Coordinate transformations

# geographic WGS84 reference system
wgs_srs = osr.SpatialReference()
wgs_srs.ImportFromEPSG(4326)

# reference systems of the two grids that are matched against each other
output_grid_srs = output_grid.GetSpatialRef()
emiss_srs       = area_emissions.GetSpatialRef()

transform_output_grid_to_emiss   = osr.CoordinateTransformation(output_grid_srs, emiss_srs)
transform_emiss_to_output_grid   = osr.CoordinateTransformation(emiss_srs, output_grid_srs)
transform_output_grid_to_wgs84   = osr.CoordinateTransformation(output_grid_srs, wgs_srs)

# Unit conversions

def unit_conversion_factor_GRETA_to_WRF(species, area):
    """Return the factor that turns a GRETA amount into a WRF emission flux.

    gases     in GRETA: kt a^-1  in WRF: mol km^-2 hr^-1
    particles in GRETA: kt a^-1  in WRF: ug m^-2 s^-1

    species -- GRETA field name ("E_SUM_*")
    area    -- area the amount is spread over; the factors below treat it
               as m^2 (the gas branch rescales it to km^2)

    For a species that is neither a known gas nor a known particle species a
    message is printed and -1.0 is returned.
    """
    # molar masses of the gas phase species, in kg mol^-1
    molar_mass = {
        "E_SUM_NMVOC": 12.01e-3,
        "E_SUM_CO":    28.01e-3,
        "E_SUM_NH3":   17.03e-3,
        "E_SUM_NOX":   46.05e-3,
        "E_SUM_SO2":   64.06e-3,
    }
    hours_per_year   = 24*365
    seconds_per_year = 365*86400

    if species in molar_mass:
        # kt -> kg, kg -> mol, per m^2 -> per km^2, per year -> per hour
        return 1e6 / molar_mass[species] * 1e6 / area / hours_per_year

    if species in ( "E_SUM_PM10", "E_SUM_PM2_5" ):
        # kt -> ug, per m^2, per year -> per second
        return 1e15 / area / seconds_per_year

    print("UNKNOWN species {:s}".format(species))
    return -1.0

# "Payload function" to update a grid cell with new emissions
# Here, we know already:
#  - which output grid cell to fill (x, y, z)
#  - which input cell provides data (input_oid)
#  - the size of the output grid cell (for unit conversion)
#  - the fraction of the input grid cell area that 
#    intersects with the output grid cell (to calculate fraction of GRETA emissions to be put in this cell)
#  - the fraction of the output grid cell area 
#    covered by the intersecting input grid cell (to calculate the relative contributions of existing emissions and new GRETA emissions)
#  - whether we have an E_SUM_* input variable (for area data)
#    or we need to sum up all E_[source sector]_* values (for point data)
#  - whether we replace (fractions of) the output grid cell with new input data (area)
#    or just add up (points)

def _greta_input_value(input_data, input_oid, input_variable, speciation_fraction, needToSumInputVariables):
    """Return the (speciated) GRETA amount of input_variable for one input feature.

    With needToSumInputVariables all fields of the feature whose name ends
    with the species part of input_variable ("E_SUM_NOX" -> "_NOX") are
    added up, otherwise only the field input_variable itself is read.
    """
    record = input_data[input_oid]
    if needToSumInputVariables:
        suffix = input_variable.replace("E_SUM", "")
        fields = [ field for field in record.keys() if field.endswith(suffix) ]
    else:
        fields = [ input_variable ]

    return np.sum( [ speciation_fraction * record[field] for field in fields ] )


def add_emissions(output_data, t, z, y, x, input_oid, output_cell_area, input_data, greta_area_fraction, wrf_area_fraction, needToSumInputVariables=False, cumulative=False):
    """Update output grid cell (t, z, y, x) with the emissions of one GRETA feature.

    output_data             -- dict of arrays, one per output variable plus
                               'GRETA_fraction'; modified in place
    t, z, y, x              -- indices of the output grid cell
    input_oid               -- key of the GRETA feature in input_data
    output_cell_area        -- area of the output grid cell (for unit conversion)
    input_data              -- GRETA attribute table, {oid: {field: value}}
    greta_area_fraction     -- fraction of the input cell that intersects
                               with the output cell
    wrf_area_fraction       -- fraction of the output cell covered by the
                               intersecting part of the input cell
    needToSumInputVariables -- False: read the E_SUM_* field (area data);
                               True: add up all matching fields (point data)
    cumulative              -- False: (fractionally) replace the existing
                               value (area); True: add to it (points)

    Returns output_data.
    """
    # loop through all possible emission output variables
    for output_variable in output_variables:
        input_variable = translate_output_to_input_species(output_variable)
        if input_variable is None:
            # GRETA has nothing to offer for this output variable
            continue

        existing_value = output_data[output_variable][t,z,y,x]

        # lumped GRETA species are distributed over the members of their
        # group, any other species forms a group of its own
        group_variables = [ output_variable ]
        if isnox(output_variable):
            group_variables = output_nox_variables
        if isvoc(output_variable):
            group_variables = output_voc_variables
        if ispm25(output_variable):
            group_variables = output_pm25_variables

        # The group lists are fixed, but the emission file at hand does not
        # need to contain every member: only use those that are present.
        group_variables = [ group_variable for group_variable in group_variables if group_variable in output_data ]

        # always use ground level for speciation (also for elevated point sources...)
        group_total = np.sum( [ output_data[group_variable][t,0,y,x] for group_variable in group_variables ] )

        # share of this variable in the existing emissions of its group;
        # stays at 1.0 when the group does not emit anything in this column
        speciation_fraction = 1.0
        if group_total > 0.0:
            speciation_fraction = output_data[output_variable][t,0,y,x] / group_total

        input_value = _greta_input_value(input_data, input_oid, input_variable, speciation_fraction, needToSumInputVariables)

        if input_variable == "E_SUM_PM10":
            # PM 10 is actually the difference between PM 10 and PM 2.5 ...
            input_value -= _greta_input_value(input_data, input_oid, "E_SUM_PM2_5", speciation_fraction, needToSumInputVariables)

        input_value = unit_conversion_factor_GRETA_to_WRF(input_variable, output_cell_area) * input_value * greta_area_fraction

        if cumulative:
            # point sources are not (fractionally) replacing existing data, but add up
            output_data[output_variable][t,z,y,x] += input_value
        else:
            # area sources replace existing values (fractionally)
            # NOTE(review): input_value is already scaled by
            # greta_area_fraction and divided by the output cell area (inside
            # the unit conversion). It is weighted with wrf_area_fraction once
            # more here, and the blend is applied per GRETA feature on top of
            # the running value. Please confirm this is the intended weighting
            # for output cells that are covered by several GRETA cells.
            output_data[output_variable][t,z,y,x] = (1.0 - wrf_area_fraction) * existing_value + wrf_area_fraction * input_value

    # diagnostic: keep track of how much of the cell was fed from GRETA
    output_data['GRETA_fraction'][t, z, y, x] += wrf_area_fraction

    return output_data

# (*) Grid mappings

# Here we determine the corresponding input and output grid cells.
# Can be cached as long as grids do not change.

# area emissions (RASTER_EMI in GRETA parlance)

if args.grid_mappings_path is None:
    args.grid_mappings_path = os.path.dirname( args.wrfchemi_file_path )

# the cache file name identifies both the GRETA database and the wrfchemi file
area_grid_mapping_pickle = os.path.join( args.grid_mappings_path, os.path.basename( args.uba_greta_emission_db_path ) + "_" + os.path.basename( args.wrfchemi_file_path ) + "_AREA_grid_mapping.pickle" )

if not os.path.exists( area_grid_mapping_pickle ):

    n       = output_grid.GetFeatureCount()
    stats   = Status("Area grid mapping", n)

    # output cell FID -> cell indices, cell area and intersecting GRETA cells
    area_gridcell_mapping = {}

    for output_gridcell in output_grid:
        stats.iterate()

        # work on a copy, Transform() modifies the geometry in place
        output_gridcell_geometry = output_gridcell.geometry().Clone()
        output_gridcell_geometry.Transform(transform_output_grid_to_emiss) # output_gridcell now in Emiss coords

        output_gridcell_area = output_gridcell_geometry.GetArea()

        # only look at those features that intersect with geom
        area_emissions.SetSpatialFilter(output_gridcell_geometry)

        y = output_gridcell.GetField("y")
        x = output_gridcell.GetField("x")

        output_fid = output_gridcell.GetFID()
        area_gridcell_mapping[output_fid] = { "x": x, "y": y, "z": 0, "area": output_gridcell_area, "input": {} }

        for emiss_gridcell in area_emissions:
            emiss_gridcell_geometry = emiss_gridcell.geometry()

            intersection        = output_gridcell_geometry.Intersection(emiss_gridcell_geometry)
            intersection_area   = intersection.GetArea()

            # (1) fraction of the GRETA grid cell in this WRF grid cell
            #     --> emitted amounts are to be reduced to this fraction of the total
            greta_area_fraction = intersection_area / emiss_gridcell_geometry.GetArea()
            # (2) fraction of the WRF grid cell covered by the intersecting part of the emissions grid cell
            #     --> this determines the relative contributions of existing emissions and the new stuff coming here
            wrf_area_fraction   = intersection_area / output_gridcell_area

            emiss_oid = emiss_gridcell.GetFID()
            area_gridcell_mapping[output_fid]["input"][emiss_oid] = { "greta_area_fraction": greta_area_fraction, "wrf_area_fraction": wrf_area_fraction }

        area_emissions.ResetReading()

    output_grid.ResetReading()

    # Do not leave the filter of the last grid cell on the layer, any later
    # read of area_emissions would silently be limited to that cell.
    area_emissions.SetSpatialFilter(None)
    area_emissions.ResetReading()

    with open(area_grid_mapping_pickle, "wb") as f:
        pickle.dump(area_gridcell_mapping, file=f)

# NOTE: pickle executes whatever the file tells it to; only use cache
# files that were written by this script.
with open(area_grid_mapping_pickle, "rb") as f:
    area_gridcell_mapping = pickle.load(f)

# point emissions (F_PRTR_RESULT in GRETA parlance)

if args.grid_mappings_path is None:
    args.grid_mappings_path = os.path.dirname( args.wrfchemi_file_path )

# the cache file name identifies both the GRETA database and the wrfchemi file
point_grid_mapping_pickle = os.path.join( args.grid_mappings_path, os.path.basename( args.uba_greta_emission_db_path ) + "_" + os.path.basename( args.wrfchemi_file_path ) + "_POINT_grid_mapping.pickle" )

if not os.path.exists( point_grid_mapping_pickle ):

    n       = point_emissions.GetFeatureCount()
    stats   = Status("Point grid mapping", n)

    # One entry per point source, keyed by the FID of the point. Keying by
    # the output cell would drop all but the last point source of a cell.
    # NOTE: cache files written before this change may still lack point
    # sources; delete them to have the mapping rebuilt.
    point_gridcell_mapping = {}

    for point in point_emissions:
        stats.iterate()

        # work on a copy, Transform() modifies the geometry in place
        geom = point.geometry().Clone()
        geom.Transform(transform_emiss_to_output_grid) # point now in grid coords

        emiss_oid = point.GetFID()

        # find the feature that contains the point (should be only one, the target output_grid output_gridcell)
        cellCount = 0
        for output_gridcell in output_grid:
            output_gridcell_geometry = output_gridcell.geometry()
            if geom.Within(output_gridcell_geometry):

                y = output_gridcell.GetField("y")
                x = output_gridcell.GetField("x")
                # vertical level that contains the release height (-1 if none does)
                z = find_level_in_wrfinput( point['emission_height'], nc, x, y )

                # NOTE(review): this area is taken in output grid coordinates,
                # whereas the area mapping uses the GRETA coordinate system.
                # Confirm that both yield the units the unit conversion expects.
                point_gridcell_mapping[emiss_oid] = { "x": x, "y": y, "z": z, "area": output_gridcell_geometry.GetArea(),
                                                      "input": { emiss_oid: { "greta_area_fraction": 1.0, "wrf_area_fraction": 1.0 }} }

                cellCount += 1

        output_grid.ResetReading()

        if cellCount > 1:
            print("Something fishy with feature {:s}, multiple assignments!".format(point['plant_name']))
            # a bare "exit" does nothing, actually stop here
            raise SystemExit(1)

    with open(point_grid_mapping_pickle, "wb") as f:
        pickle.dump(point_gridcell_mapping, file=f)

# NOTE: pickle executes whatever the file tells it to; only use cache
# files that were written by this script.
with open(point_grid_mapping_pickle, "rb") as f:
    point_gridcell_mapping = pickle.load(f)

# (*) Update emission data

# As soon as processing for an input / output grid combination has been run once,
# everything done up to this point should be cached already.

# reading in data from the output NetCDF into a dictionary saves computation time
# reading in data from the output NetCDF into a dictionary saves computation time
output_data = { name: nc[name][:] for name in nc.variables if name.startswith("E_") }

# diagnostic output, to see which fraction of each grid cell was replaced with GRETA data
output_data['GRETA_fraction'] = np.zeros_like(nc['E_NO2'][:])

# area sources, ground level, first time index only

n       = len(area_gridcell_mapping)
stats   = Status("Area sources", n)

for output_cellmapping in area_gridcell_mapping.values():
    stats.iterate()

    t = 0; z = 0; y = output_cellmapping["y"]; x = output_cellmapping["x"]

    for emiss_oid, input_relations in output_cellmapping["input"].items():
        output_data = add_emissions(output_data, t, z, y, x, emiss_oid, output_cellmapping["area"], area_emissions_data, input_relations["greta_area_fraction"], input_relations["wrf_area_fraction"])

# point sources, elevated

# the loop below runs over the mapping entries, so count those (as done for
# the area sources) and not the features of the point layer
n       = len(point_gridcell_mapping)
stats   = Status("Point sources", n)

for output_cellmapping in point_gridcell_mapping.values():
    stats.iterate()

    t = 0; z = output_cellmapping["z"]; y = output_cellmapping["y"]; x = output_cellmapping["x"]

    for emiss_oid, input_relations in output_cellmapping["input"].items():
        output_data = add_emissions(output_data, t, z, y, x, emiss_oid, output_cellmapping["area"], point_emissions_data, input_relations["greta_area_fraction"], input_relations["wrf_area_fraction"], needToSumInputVariables=True, cumulative=True)

# putting data back into the NetCDF
for name in nc.variables:
    if name.startswith("E_"):
        nc[name][:] = output_data[name][:]

nc['GRETA_fraction'][:] = output_data['GRETA_fraction'][:]

nc.close()
