#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''Prints differences between two files per variable (and optionally per
dimension). Skips non-numerical variables. Compared variables must have
identical dimensions.

This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program. If not, see <https://www.gnu.org/licenses/>.

Copyright 2019-2025 Pavel Krc <src@pkrc.net>
'''
# Module metadata taken from the copyright line that closes the docstring,
# which has the form "Copyright <years> <author name> <e-mail>". The previous
# fixed three-way split put the author's first name into __date__ and only the
# surname into __author__.
__date__, *_author_words, _ = __doc__.rsplit('Copyright', 1)[1].split()
__author__ = ' '.join(_author_words)
__version__ = '2.0.1'

import sys
import argparse
import numpy as np
import netCDF4

# Outcome codes of a comparison; each one is also used as an index into
# var_bins, so the values must stay 0..3.
EQUAL = 0
DIFFERENT = 1
ALMOST = 2
SKIPPED = 3

def print_v(v, s):
    '''Print a one-line summary (minimum, maximum, mean) of array v.

    v -- object providing min(), max() and mean() methods
    s -- label printed in front of the statistics
    '''
    lowest, highest, average = v.min(), v.max(), v.mean()
    print('{}: min={} max={} mean={}'.format(s, lowest, highest, average))

def cmp_v(vn, v1, v2, sl):
    '''Compare v1[sl] with v2[sl], print a report and return the outcome.

    vn -- name used for the variable (or slice) in the printed report
    v1 -- indexable variable from F1 (the base)
    v2 -- indexable variable from F2 (the changed one)
    sl -- index applied identically to both variables

    Returns EQUAL, DIFFERENT or ALMOST. Reporting is controlled by the
    module-level args (brief, max_diff_coords, accept_below_sd,
    percent_diff).
    '''
    a = v1[sl]
    b = v2[sl]
    mask_a = np.ma.getmask(a)
    mask_b = np.ma.getmask(b)
    if not (mask_a is np.ma.nomask and mask_b is np.ma.nomask):
        # At least one side carries a mask: both masks have to agree.
        mask_a = np.ma.getmaskarray(a)
        mask_b = np.ma.getmaskarray(b)
        if not np.array_equiv(mask_a, mask_b):
            print()
            print('Inconsistent masks for variable {} (F1: {}/{}, F2: {}/{})'
                    .format(vn, mask_a.sum(), mask_a.size,
                            mask_b.sum(), mask_b.size))
            return DIFFERENT

    same = np.ma.asanyarray(a == b)
    n_valid = same.count()  # number of unmasked elements
    if n_valid == 0:
        print()
        print('Empty (or fully masked) variable {} (total size={})'.format(vn, same.size))
        return EQUAL

    n_same = same.sum()
    del same  # release the boolean array before allocating the difference
    if n_same == n_valid:
        if args.brief:
            return EQUAL
        print()
        print('Equal variable {}:'.format(vn))
        print_v(a, 'F1=F2')
        return EQUAL

    diff = b - a
    want_coords = args.max_diff_coords
    if want_coords:
        # Locate the extremes so that their coordinates can be reported.
        at_min = np.unravel_index(np.argmin(diff), diff.shape)
        lowest = diff[at_min]
        at_max = np.unravel_index(np.argmax(diff), diff.shape)
        highest = diff[at_max]
    else:
        lowest = diff.min()
        highest = diff.max()
    peak = max(-lowest, highest)  # largest absolute difference

    threshold = args.accept_below_sd
    status = DIFFERENT
    if threshold is not None:
        inv_sd = 1. / np.std(a)
        if peak*inv_sd < threshold:
            status = ALMOST
            if args.brief:
                return status

    n_diff = n_valid - n_same
    verdict = 'Almost equal' if status == ALMOST else 'Different'
    print()
    print('{} ({} / {} = {:.2f} %) variable {}:'.format(
        verdict, n_diff, n_valid, 100.*n_diff/n_valid, vn))
    print_v(a, 'F1')
    print_v(b, 'F2')

    abs_diff = np.abs(diff)
    rmse = np.sqrt(np.square(diff).mean())
    print('F2-F1: min={} max={} bias={} MAE={} RMSE={} min_abs={}'.format(
        lowest, highest, diff.mean(), abs_diff.mean(), rmse, abs_diff.min()))
    if want_coords:
        print('Coords* of min F2-F1: {}'.format(at_min))
        print('Coords* of max F2-F1: {}'.format(at_max))
    if args.percent_diff:
        inv_max_abs = 1. / np.abs(a).max()
        if threshold is None:
            # Not computed above because no SD threshold was requested.
            inv_sd = 1. / np.std(a)
        print('max(abs(F2-F1)) = {0} * max(abs(F1)) = {1} * SD(F1)'.format(peak*inv_max_abs, peak*inv_sd))
        print('RMSE(F2-F1) = {0} * max(abs(F1)) = {1} * SD(F1)'.format(rmse*inv_max_abs, rmse*inv_sd))

    return status

def compare(vn):
    '''Compare variable vn between F1 and F2 and file the result in var_bins.

    vn -- name of the variable to compare

    A variable missing from either file is reported on stderr and counted as
    DIFFERENT. Variables that are not of float or signed integer type, that
    are scalar, or that would become scalar after slicing are SKIPPED. When
    the module-level dims mapping names any dimension of the variable, the
    variable is compared slice by slice along those dimensions; otherwise it
    is compared as a whole. Returns None.
    '''
    try:
        v1 = f1.variables[vn]
    except KeyError:
        # Can only happen for names given explicitly on the command line.
        sys.stderr.write('F1 is missing variable {}.\n'.format(vn))
        var_bins[DIFFERENT].append(vn)
        return
    try:
        v2 = f2.variables[vn]
    except KeyError:
        sys.stderr.write('F2 is missing variable {}.\n'.format(vn))
        var_bins[DIFFERENT].append(vn)
        return

    # netCDF4 reports the plain `str` type (which has no `.kind`) as the dtype
    # of variable-length string variables, hence the defensive lookup.
    if getattr(v1.dtype, 'kind', None) not in ('f', 'i'):
        var_bins[SKIPPED].append(vn)
        return
    if not v1.shape:
        var_bins[SKIPPED].append(vn) #ignore scalars
        return

    if not dims:
        # Compare full var
        var_bins[cmp_v(vn, v1, v2, slice(None))].append(vn)
        return

    # (dimension name, axis position) of every dimension to slice by
    dd = [(d, i) for i, d in enumerate(v1.dimensions) if d in dims]
    if not dd:
        # No sliced dims, just compare full var
        var_bins[cmp_v(vn, v1, v2, slice(None))].append(vn)
        return
    if len(dd) == len(v1.dimensions):
        var_bins[SKIPPED].append(vn) #would become scalar - ignore
        return

    bases = [dims[d][0] for d, _ in dd]
    shp = tuple(dims[d][1]-dims[d][0]+1 for d, _ in dd)
    # Iterate over shape of sliced dims
    for idx in np.ndindex(shp):
        sl = [slice(None)] * len(v1.dimensions)
        dshape = []
        # For each sliced dim, replace full slice by index
        for (d, i), ii, b in zip(dd, idx, bases):
            sl[i] = ii + b
            # NOTE(review): the label shows the index relative to DMIN, not
            # the absolute index used for slicing - confirm this is intended.
            dshape.append('{}={}'.format(d, ii))
        slice_name = '{} ({})'.format(vn, ', '.join(dshape))
        var_bins[cmp_v(slice_name, v1, v2, tuple(sl))].append(slice_name)


if __name__ == '__main__':
    # Command-line entry point: parse options, compare the variables (and
    # optionally the global attributes) of F1 and F2, print a summary and
    # exit with status 2 if any variable differs, 3 if only attributes differ.
    parser = argparse.ArgumentParser(description=__doc__.rsplit('\n', 15)[0],
            epilog=f'author: {__author__}, {__date__}')
    parser.add_argument('F1', help='NetCDF file 1 (base)')
    parser.add_argument('F2', help='NetCDF file 2 (changed)')
    parser.add_argument('-v', '--variables', nargs='+', default=[],
            help='variables to process (default: all from F1)',
            metavar='VAR')
    parser.add_argument('-d', '--dimensions', nargs='+', default=[],
            help='slice output by these dimensions', metavar='DIM')
    parser.add_argument('--dmin', nargs='+', default=[], type=int,
            help='compare only where DIM >= DMIN', metavar='DMIN')
    parser.add_argument('--dmax', nargs='+', default=[], type=int,
            help='compare only where DIM <= DMAX', metavar='DMAX')
    parser.add_argument('-a', '--attributes', nargs='*', default=None,
            help='also compare listed attributes (default: all from F1)',
            metavar='ARG')
    parser.add_argument('-c', '--max-diff-coords', action='store_true',
            help='print coordinates of maximum differences')
    parser.add_argument('-p', '--percent-diff', action='store_true',
            help='print differences relative to SD and to max abs value')
    parser.add_argument('-s', '--accept-below-sd', type=float, nargs='?',
            help='accept records with max abs diff below X*SD (default X=1e-6)',
            metavar='X', const=1e-6)
    parser.add_argument('-b', '--brief', action='store_true',
            help='print only different variables or slices where F1==F2')

    args = parser.parse_args()
    # DMIN/DMAX are paired with DIM by position, so a count mismatch would
    # make zip() below silently drop dimensions; reject it instead.
    if args.dmin and len(args.dmin) != len(args.dimensions):
        parser.error('Number of DMIN and DIM arguments must be identical.')
    if args.dmax and len(args.dmax) != len(args.dimensions):
        parser.error('Number of DMAX and DIM arguments must be identical.')
    # dims maps dimension name -> (first index, last index); a negative last
    # index counts from the end and is resolved once F1 is open.
    dims = dict(zip(args.dimensions, zip(
                        args.dmin or [0]*len(args.dimensions),
                        args.dmax or [-1]*len(args.dimensions),
                        )))

    print('F1 =', args.F1)
    f1 = netCDF4.Dataset(args.F1)
    try:
        # Load sliced dims
        for dn, (d0, d1) in list(dims.items()):
            if d1 < 0:
                dims[dn] = d0, len(f1.dimensions[dn])+d1

        print('F2 =', args.F2)
        f2 = netCDF4.Dataset(args.F2)
        try:
            # Variables
            # One list of names per comparison result, indexed by EQUAL,
            # DIFFERENT, ALMOST and SKIPPED.
            var_bins = tuple([] for _ in range(4))
            for vn in args.variables or f1.variables.keys():
                compare(vn)

            # Attributes
            equal_attrs = []
            different_attrs = []
            if args.attributes is not None:
                print()
                for atname in args.attributes or f1.ncattrs():
                    try:
                        a1 = f1.getncattr(atname)
                        a2 = f2.getncattr(atname)
                    except AttributeError:
                        # Report the attribute name (not the last variable).
                        sys.stderr.write('Missing attribute {}, skipping.\n'.format(atname))
                        different_attrs.append(atname)
                    else:
                        if a1 == a2:
                            equal_attrs.append(atname)
                            if not args.brief:
                                print('Equal attribute {}: F1=F2={}'.format(atname, a1))
                        else:
                            different_attrs.append(atname)
                            print('Different attribute {}: F1={}, F2={}'.format(atname, a1, a2))
        finally:
            f2.close()
    finally:
        f1.close()

    # Summary of all results
    if var_bins[EQUAL]:
        print()
        print('Equal variables or slices: {}.'.format(';  '.join(var_bins[EQUAL])))
    if var_bins[ALMOST]:
        print()
        print('Variables or slices with max abs diff below {}*SD: {}.'.format(
            args.accept_below_sd, ';  '.join(var_bins[ALMOST])))
    if var_bins[DIFFERENT]:
        print()
        print('Different variables or slices: {}.'.format(';  '.join(var_bins[DIFFERENT])))
    if var_bins[SKIPPED]:
        print()
        print('Skipped variables or slices: {}.'.format(';  '.join(var_bins[SKIPPED])))

    if args.max_diff_coords:
        print()
        print('* Coordinates do not include already sliced dimensions where applicable.')

    if equal_attrs:
        print()
        print('Equal attributes: {}.'.format(';  '.join(equal_attrs)))
    if different_attrs:
        print()
        print('Different attributes: {}.'.format(';  '.join(different_attrs)))

    # Exit status: differing variables take precedence over attributes.
    if var_bins[DIFFERENT]:
        sys.exit(2)
    if different_attrs:
        sys.exit(3)
