#cython: language_level=3

import numpy as np
cimport numpy as np
from libc cimport math

np.import_array()

cdef double[:] _scale(double[:] out, double[:] source, double fromlow, double fromhigh, double tolow, double tohigh, bint log):
    """Map every value of ``source`` from the input range onto the output range.

    Each value is normalised to its position ``v`` inside the input range
    and then placed at the same position inside the output range. Both
    ranges are treated as unordered pairs: their bounds are sorted with
    min/max, so ``(tolow, tohigh) = (1, 0)`` gives the same result as
    ``(0, 1)``.

    When ``log`` is true, ``v`` is bent through ``log(v * (e - 1) + 1)``.
    That curve keeps 0 -> 0 and 1 -> 1 but rises faster near the low end.
    Inputs far enough below the input range make the log argument <= 0;
    libc ``log`` then yields -inf/NaN rather than raising.

    ``out`` must be at least as long as ``source``. ``out`` may alias
    ``source``, because each element is read before it is overwritten.

    Returns ``out``. If ``fromlow == fromhigh`` the division by zero raises
    ZeroDivisionError under Cython's default ``cdivision=False``.
    NOTE(review): confirm no build directive enables cdivision.
    """
    cdef Py_ssize_t i = 0
    cdef Py_ssize_t length = len(source)
    cdef double v = 0
    cdef double _tolow = min(tolow, tohigh)
    cdef double _tohigh = max(tolow, tohigh)
    cdef double _fromlow = min(fromlow, fromhigh)
    cdef double _fromhigh = max(fromlow, fromhigh)

    cdef double todiff = _tohigh - _tolow
    cdef double fromdiff = _fromhigh - _fromlow

    if log:
        for i in range(length):
            v = (source[i] - _fromlow) / fromdiff
            v = math.log(v * (math.e - 1) + 1)
            # todiff is always >= 0, so the offset must be the sorted lower
            # bound; using the raw ``tolow`` pushes reversed ranges out of bounds.
            out[i] = v * todiff + _tolow
    else:
        for i in range(length):
            out[i] = ((source[i] - _fromlow) / fromdiff) * todiff + _tolow

    return out

cdef double[:] _scaleinplace(double[:] out, double fromlow, double fromhigh, double tolow, double tohigh, bint log):
    """Rescale every value of ``out`` in place from the input range onto the output range.

    Each value is normalised to its position ``v`` inside the input range
    and then placed at the same position inside the output range. Both
    ranges are treated as unordered pairs: their bounds are sorted with
    min/max, so ``(tolow, tohigh) = (1, 0)`` behaves like ``(0, 1)``.

    When ``log`` is true, ``v`` is bent through ``log(v * (e - 1) + 1)``
    (0 -> 0, 1 -> 1). Values far enough below the input range make the log
    argument <= 0; libc ``log`` then yields -inf/NaN rather than raising.

    Returns ``out`` (the same buffer that was modified). If
    ``fromlow == fromhigh`` the division by zero raises ZeroDivisionError
    under Cython's default ``cdivision=False``.
    NOTE(review): confirm no build directive enables cdivision.
    """
    cdef Py_ssize_t i = 0
    cdef Py_ssize_t length = len(out)
    cdef double v = 0
    cdef double _tolow = min(tolow, tohigh)
    cdef double _tohigh = max(tolow, tohigh)
    cdef double _fromlow = min(fromlow, fromhigh)
    cdef double _fromhigh = max(fromlow, fromhigh)

    cdef double todiff = _tohigh - _tolow
    cdef double fromdiff = _fromhigh - _fromlow

    if log:
        for i in range(length):
            v = (out[i] - _fromlow) / fromdiff
            v = math.log(v * (math.e - 1) + 1)
            # todiff is always >= 0, so offset from the sorted lower bound;
            # the raw ``tolow`` would shift reversed ranges out of bounds.
            out[i] = v * todiff + _tolow
    else:
        for i in range(length):
            out[i] = ((out[i] - _fromlow) / fromdiff) * todiff + _tolow

    return out


cpdef list scale(list source, double fromlow=-1, double fromhigh=1, double tolow=0, double tohigh=1, bint log=False):
    """Return a new list with ``source`` rescaled between two numeric ranges.

    The values are copied into a float64 buffer, handed to ``_scale``
    together with the range bounds and the ``log`` flag, and the filled
    result buffer is converted back into a Python list of floats.
    """
    cdef Py_ssize_t count = len(source)
    cdef double[:] result = np.zeros(count, dtype='d')
    cdef double[:] values = np.array(source, dtype='d')

    # _scale fills ``result`` and hands the same view back; it is ignored here.
    _scale(result, values, fromlow, fromhigh, tolow, tohigh, log)
    return np.asarray(result).tolist()

cdef list _snap_pattern(list source, list pattern):
    """Replace each value in ``source`` by the closest entry of ``pattern``.

    Distance is the absolute difference. On a tie, the entry that appears
    earliest in ``pattern`` wins, because only a strictly smaller distance
    replaces the current choice. Results are returned as floats.
    ``pattern`` must be non-empty whenever ``source`` is non-empty.
    """
    cdef double value = 0, best = 0, candidate = 0
    cdef list snapped = []

    for value in source:
        best = pattern[0]
        for candidate in pattern:
            if abs(value - candidate) < abs(value - best):
                best = candidate
        snapped.append(best)

    return snapped

cdef list _snap_mult(list source, double mult):
    """Round each value in ``source`` up to a whole multiple of ``mult``.

    Every value becomes the smallest ``mult * m`` (with ``m >= 1``) that is
    >= the value, so anything at or below ``mult`` (including zero and
    negatives) becomes ``mult``. Each element is handled independently.
    NaN inputs propagate as NaN. ``mult`` must be > 0; the caller is
    expected to enforce that.
    """
    cdef double v = 0
    cdef double m = 0
    cdef list out = []

    for v in source:
        if v <= mult:
            out.append(mult)
            continue

        # Computed per element: the old shared counter was never reset, so a
        # large value made every later, smaller value snap too high. A double
        # also avoids the C int overflow of an incremental counter.
        m = math.ceil(v / mult)

        # v / mult may be off by an ulp; nudge once so ``mult * m`` is the
        # smallest product that actually reaches v.
        if m > 1 and mult * (m - 1) >= v:
            m -= 1
        elif mult * m < v:
            m += 1

        out.append(mult * m)

    return out

cpdef list snap(list source, double mult=0, object pattern=None):
    """Quantize every value in ``source``.

    If ``mult`` > 0, values are rounded up to a multiple of ``mult`` via
    ``_snap_mult`` and ``pattern`` is ignored. Otherwise each value is
    snapped to its nearest entry in ``pattern`` via ``_snap_pattern``.
    ``pattern`` may be any iterable of numbers (list, tuple, array, ...).

    Raises ValueError when neither a positive ``mult`` nor a ``pattern`` is
    given, or when ``pattern`` is empty.
    """
    cdef list steps

    if mult <= 0 and pattern is None:
        raise ValueError('Please provide a valid quantization multiple or pattern')

    if mult > 0:
        return _snap_mult(source, mult)

    # pattern is known to be non-None here. _snap_pattern is typed to take a
    # ``list``, so copy tuples/arrays instead of failing with a TypeError.
    steps = list(pattern)
    if len(steps) == 0:
        raise ValueError('Invalid (empty) pattern')

    return _snap_pattern(source, steps)

cpdef rotate(list l, int offset=0):
    """ Rotate a list by a given offset and return a new list.

    A positive offset moves items toward the end (``rotate([1, 2, 3], 1)``
    gives ``[3, 1, 2]``); a negative offset rotates the other way. Offsets
    larger than the list wrap around. An empty list yields a new empty list
    instead of failing on the modulo by zero.
    """
    cdef Py_ssize_t n = len(l)
    cdef Py_ssize_t split = 0

    if n == 0:
        return []

    split = -offset % n
    return l[split:] + l[:split]


