#!/usr/bin/env python
#
# This file is part of PyOP2
#
# PyOP2 is Copyright (c) 2012, Imperial College London and
# others. Please see the AUTHORS file in the main source directory for
# a full list of copyright holders.  All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#     * The name of Imperial College London or that of other
#       contributors may not be used to endorse or promote products
#       derived from this software without specific prior written
#       permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS
# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.

"""Show a spy plot from a binary PETSc matrix dump or compare two dumps as spy
plots if two input file names are given."""

import matplotlib
import numpy as np
import pylab
from scipy.sparse import csr_matrix

# Layout constants used by readmat() when decoding a binary PETSc matrix dump.
# COOKIE is the value readmat() requires in the first header integer.
COOKIE = 1211216    # from petscmat.h
# numpy dtype strings: every integer in the file is read as IntType and every
# matrix value as ScalarType.
IntType = '>i4'     # big-endian, 4 byte integer
ScalarType = '>f8'  # big-endian, 8 byte real floating


# after http://lists.mcs.anl.gov/pipermail/petsc-users/2010-February/005935.html
def readmat(filename):
    """Read the raw CSR arrays from a binary PETSc matrix dump.

    The file is read in this order: a header of four integers
    ``(COOKIE, M, N, nz)``, then ``M`` per-row nonzero counts, then ``nz``
    column indices and finally ``nz`` values.  Integers are decoded as
    ``IntType`` and values as ``ScalarType``.

    :arg filename: path of the dump file to read.
    :returns: a tuple ``((M, N), (I, J, V))`` where ``M`` and ``N`` are the
        row and column counts from the header, ``I`` is the row pointer
        array of length ``M + 1`` (cumulative sum of the per-row counts,
        starting at 0), ``J`` holds the column indices and ``V`` the values.
    :raises ValueError: if the file is truncated, does not start with
        ``COOKIE``, declares negative sizes, or if the per-row counts do not
        sum to the number of nonzeros given in the header.
    """
    with open(filename, 'rb') as fh:
        header = np.fromfile(fh, dtype=IntType, count=4)
        # np.fromfile returns a short array rather than failing when the
        # file ends early, so every read is checked for its expected length.
        if header.size != 4:
            raise ValueError("%s: truncated header (read %d of 4 integers)"
                             % (filename, header.size))
        # Explicit checks instead of assert: asserts vanish under python -O.
        if header[0] != COOKIE:
            raise ValueError("%s: not a PETSc binary matrix "
                             "(header starts with %d, expected %d)"
                             % (filename, header[0], COOKIE))
        M, N, nz = (int(x) for x in header[1:])
        # A negative count must not reach np.fromfile, where count=-1 means
        # "read everything that is left".
        if M < 0 or N < 0 or nz < 0:
            raise ValueError("%s: invalid sizes in header (M=%d, N=%d, nz=%d)"
                             % (filename, M, N, nz))
        #
        I = np.empty(M+1, dtype=IntType)
        I[0] = 0
        rownz = np.fromfile(fh, dtype=IntType, count=M)
        if rownz.size != M:
            raise ValueError("%s: truncated row counts (read %d of %d)"
                             % (filename, rownz.size, M))
        # Row pointers are the running total of the per-row nonzero counts.
        np.cumsum(rownz, out=I[1:])
        if I[-1] != nz:
            raise ValueError("%s: row counts sum to %d but header declares "
                             "%d nonzeros" % (filename, I[-1], nz))
        #
        J = np.fromfile(fh, dtype=IntType, count=nz)
        V = np.fromfile(fh, dtype=ScalarType, count=nz)
        if J.size != nz or V.size != nz:
            raise ValueError("%s: truncated matrix data (read %d column "
                             "indices and %d values, expected %d of each)"
                             % (filename, J.size, V.size, nz))
    return (M, N), (I, J, V)


def dump2csr(filename):
    """Load a binary PETSc matrix dump as a :class:`scipy.sparse.csr_matrix`.

    :arg filename: path of the dump file, passed on to :func:`readmat`.
    :returns: a ``csr_matrix`` built from the row pointers, column indices
        and values returned by :func:`readmat`, with the shape taken from
        the file header.
    """
    (M, N), (I, J, V) = readmat(filename)
    # Pass the shape explicitly: left to itself csr_matrix infers the column
    # count from the largest column index, which drops trailing empty columns
    # and can give two dumps of equally sized matrices different shapes.
    return csr_matrix((V, J, I), shape=(int(M), int(N)))


def compare_dump(files, outfile=None, marker='.', markersize=.5):
    """Show one binary PETSc matrix dump, or compare two, as spy plots.

    With a single file name a spy plot of that matrix is drawn.  With two or
    more, only the first two are used and a 2x2 figure is drawn: a spy plot
    of each matrix, a spy plot of their difference, and a line plot of the
    stored nonzero values of both matrices together with their element-wise
    difference.

    The difference spy plot is skipped (the panel only gets a title saying
    so) when the two matrices have different shapes, and the 'Difference'
    trace is skipped when they store a different number of entries, since
    neither subtraction is defined in those cases.

    :arg files: sequence of dump file names; must contain at least one.
    :arg outfile: if given (and non-empty), save the figure to this path
        instead of showing it interactively.
    :arg marker: marker passed to the spy and value plots.
    :arg markersize: marker size passed to the spy and value plots.
    """

    opts = {'marker': marker, 'markersize': markersize}
    csr1 = dump2csr(files[0])
    comparing = len(files) > 1

    if comparing:
        matplotlib.rc('font', size=4)
        pylab.figure(figsize=(12, 5), dpi=300)
        pylab.subplot(221)
    else:
        matplotlib.rc('font', size=10)
        pylab.figure(figsize=(5, 5), dpi=300)
    pylab.spy(csr1, **opts)
    pylab.title(files[0])

    if comparing:
        csr2 = dump2csr(files[1])
        pylab.subplot(222)
        pylab.spy(csr2, **opts)
        pylab.title(files[1])

        pylab.subplot(223)
        if csr1.shape == csr2.shape:
            pylab.spy(csr1 - csr2, **opts)
            pylab.title(files[0] + ' - ' + files[1])
        else:
            # Sparse subtraction raises on mismatched shapes; report the
            # mismatch instead of aborting before anything is shown.
            pylab.title('No difference plot: shapes %s and %s differ'
                        % (csr1.shape, csr2.shape))

        pylab.subplot(224)
        pylab.plot(csr1.data, label=files[0], **opts)
        pylab.plot(csr2.data, label=files[1], **opts)
        # The value arrays can only be subtracted element-wise when both
        # matrices store the same number of entries.
        if csr1.data.shape == csr2.data.shape:
            pylab.plot(csr1.data - csr2.data, label='Difference', **opts)
        pylab.legend()
        pylab.title('Nonzero values')

    if outfile:
        pylab.savefig(outfile)
    else:
        pylab.show()


def main():
    """Parse the command line and hand the options to compare_dump."""
    import argparse

    # (positional names/flags, keyword options) for each command-line argument
    argument_table = [
        (('files',),
         {'nargs': '+', 'help': 'Matrix dump files'}),
        (('--output', '-o'),
         {'help': 'Output plot to file instead of showing interactively'}),
        (('--marker',),
         {'default': '.', 'choices': ['s', 'o', '.', ','],
          'help': 'Specify marker to use for spyplot'}),
        (('--markersize',),
         {'type': float, 'default': .5,
          'help': 'Specify marker size to use for spyplot'}),
    ]

    cli = argparse.ArgumentParser(description=__doc__, add_help=True)
    for names, settings in argument_table:
        cli.add_argument(*names, **settings)
    options = cli.parse_args()

    compare_dump(options.files, options.output,
                 marker=options.marker, markersize=options.markersize)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
