# -*- coding: utf-8 -*-
"""
Created on Mon Jul  3 09:22:40 2023

@author: pkiefer
"""
import pytest
import emzed
import os
import numpy as np
import pickle

# from src.targeted_wf import extract_peaks as ep
from src.tadamz import coeluting_peaks as cp

here = os.path.abspath(os.path.dirname(__file__))


@pytest.fixture
def kwargs():
    """Provide the default keyword settings handed to ``adapt_rt_by_coelution``.

    A fresh dict is built for every test, so tests may modify it freely.
    """
    return {
        "group_id_column": "peptide_standard_sequence",
        "only_use_ref_peaks": True,
    }


@pytest.fixture
def table():
    """Load ``data/coelution.table`` and keep only the rows of four peptides.

    Rows are selected by their ``peptide_standard_sequence`` value.
    """
    wanted = {
        "ASDTAMYYCAR",
        "ATEHLSTLSEK",
        "ESDTSYVSLK",
        "FNAVLTNPQGDYDTSTGK",
    }
    loaded = emzed.io.load_table(os.path.join(here, "data", "coelution.table"))
    is_wanted = loaded.peptide_standard_sequence.is_in(wanted)
    return loaded.filter(is_wanted)


@pytest.fixture
def t1():
    """Load ``data/coelution.table`` restricted to the peptide ``ADQVCINLR``."""
    loaded = emzed.io.load_table(os.path.join(here, "data", "coelution.table"))
    is_target = loaded.peptide_standard_sequence == "ADQVCINLR"
    return loaded.filter(is_target)


def test_reference_coelution_0(table, kwargs, regtest):
    """Regression-test the printed result table for the default settings."""
    result = cp.adapt_rt_by_coelution(table, kwargs)
    # The printed table is written to the ``regtest`` stream for comparison.
    print(result, file=regtest)


def test_median_coelution_0(table, kwargs, regtest):
    """Regression-test the result with ``only_use_ref_peaks`` switched off.

    The result is converted to a pandas frame and its string rendering is
    written to the ``regtest`` stream.
    """
    kwargs.update(only_use_ref_peaks=False)
    result = cp.adapt_rt_by_coelution(table, kwargs)
    frame = result.to_pandas()
    print(frame.to_string(), file=regtest)


def test_reference_coelution_1(table, kwargs):
    """Check the default (reference peak) mode on a table without rows.

    The fixture table is cut down to zero rows; the table returned by
    ``adapt_rt_by_coelution`` must still provide a ``cosine_distance`` column.
    """
    empty = table[:0].consolidate()
    t = cp.adapt_rt_by_coelution(empty, kwargs)
    print("t1:", len(empty))
    # The membership test must be asserted: as a bare expression it is
    # evaluated and thrown away, so the test could never fail on it.
    assert "cosine_distance" in t.col_names


def test_median_coelution_1(table, kwargs):
    """Check the ``only_use_ref_peaks=False`` mode on a table without rows.

    The fixture table is cut down to zero rows; the table returned by
    ``adapt_rt_by_coelution`` must still provide a
    ``median_cosine_distance`` column.
    """
    table = table[:0].consolidate()
    kwargs["only_use_ref_peaks"] = False
    t = cp.adapt_rt_by_coelution(table, kwargs)
    # The membership test must be asserted: as a bare expression it is
    # evaluated and thrown away, so the test could never fail on it.
    assert "median_cosine_distance" in t.col_names


def test_median_coelution_2(t1, kwargs):
    """Check ``median_cosine_distance`` for rows lacking a model chromatogram.

    ``model_chromatogram`` is set to ``None`` for every row whose
    ``standard_type`` is ``"Labelled"``. With ``only_use_ref_peaks=False`` the
    ``median_cosine_distance`` values of exactly those rows must all equal 1.
    """
    t1.replace_column(
        "model_chromatogram",
        (t1.standard_type == "Labelled").then_else(None, t1.model_chromatogram),
        object,
    )
    kwargs["only_use_ref_peaks"] = False
    t = cp.adapt_rt_by_coelution(t1, kwargs)
    check = t.filter(t.standard_type == "Labelled", keep_view=True)
    # The comparison must be asserted: as a bare expression its result is
    # discarded and the test passes regardless of the computed values.
    assert set(check.median_cosine_distance) == {1}


def test_reference_coelution_2(table, kwargs):
    """Expect an ``AssertionError`` once the reference peak column is gone.

    The column ``is_coelution_reference_peak`` is dropped from the fixture
    table before ``adapt_rt_by_coelution`` is called with the default settings.
    """
    table.drop_columns("is_coelution_reference_peak")
    n_rows = len(table)
    print(n_rows)
    expect_failure = pytest.raises(AssertionError)
    with expect_failure:
        cp.adapt_rt_by_coelution(table, kwargs)


def test_reference_coelution_3(t1, kwargs):
    """Check ``cosine_distance`` when labelled rows lack a model chromatogram.

    ``model_chromatogram`` is set to ``None`` for every row whose
    ``standard_type`` is ``"Labelled"``. With the default settings every
    ``cosine_distance`` value of the result must equal 1.
    """
    t1.replace_column(
        "model_chromatogram",
        (t1.standard_type == "Labelled").then_else(None, t1.model_chromatogram),
        object,
    )
    t = cp.adapt_rt_by_coelution(t1, kwargs)
    # The comparison must be asserted: as a bare expression its result is
    # discarded and the test passes regardless of the computed values.
    assert set(t.cosine_distance) == {1}


@pytest.fixture
def model1():
    """Provide a ``Model`` built from two empty lists."""
    no_rts, no_ints = [], []
    return Model(no_rts, no_ints)


@pytest.fixture
def model2():
    """Provide a ``Model`` holding two hard-coded numpy arrays of 300 values.

    ``rts`` runs from 344.64755722 to 355.64755722 in ascending order and
    ``ints`` is a peak-shaped trace with its largest value 1702.10178515.
    Presumably these are retention times and intensities of a fitted peak
    (inferred from the names only -- TODO confirm).
    """
    # x values returned as first element of ``Model.graph()``
    rts = np.array(
        [
            344.64755722,
            344.68434651,
            344.72113581,
            344.75792511,
            344.79471441,
            344.8315037,
            344.868293,
            344.9050823,
            344.9418716,
            344.97866089,
            345.01545019,
            345.05223949,
            345.08902879,
            345.12581808,
            345.16260738,
            345.19939668,
            345.23618598,
            345.27297528,
            345.30976457,
            345.34655387,
            345.38334317,
            345.42013247,
            345.45692176,
            345.49371106,
            345.53050036,
            345.56728966,
            345.60407895,
            345.64086825,
            345.67765755,
            345.71444685,
            345.75123614,
            345.78802544,
            345.82481474,
            345.86160404,
            345.89839334,
            345.93518263,
            345.97197193,
            346.00876123,
            346.04555053,
            346.08233982,
            346.11912912,
            346.15591842,
            346.19270772,
            346.22949701,
            346.26628631,
            346.30307561,
            346.33986491,
            346.37665421,
            346.4134435,
            346.4502328,
            346.4870221,
            346.5238114,
            346.56060069,
            346.59738999,
            346.63417929,
            346.67096859,
            346.70775788,
            346.74454718,
            346.78133648,
            346.81812578,
            346.85491507,
            346.89170437,
            346.92849367,
            346.96528297,
            347.00207227,
            347.03886156,
            347.07565086,
            347.11244016,
            347.14922946,
            347.18601875,
            347.22280805,
            347.25959735,
            347.29638665,
            347.33317594,
            347.36996524,
            347.40675454,
            347.44354384,
            347.48033313,
            347.51712243,
            347.55391173,
            347.59070103,
            347.62749033,
            347.66427962,
            347.70106892,
            347.73785822,
            347.77464752,
            347.81143681,
            347.84822611,
            347.88501541,
            347.92180471,
            347.958594,
            347.9953833,
            348.0321726,
            348.0689619,
            348.10575119,
            348.14254049,
            348.17932979,
            348.21611909,
            348.25290839,
            348.28969768,
            348.32648698,
            348.36327628,
            348.40006558,
            348.43685487,
            348.47364417,
            348.51043347,
            348.54722277,
            348.58401206,
            348.62080136,
            348.65759066,
            348.69437996,
            348.73116926,
            348.76795855,
            348.80474785,
            348.84153715,
            348.87832645,
            348.91511574,
            348.95190504,
            348.98869434,
            349.02548364,
            349.06227293,
            349.09906223,
            349.13585153,
            349.17264083,
            349.20943012,
            349.24621942,
            349.28300872,
            349.31979802,
            349.35658732,
            349.39337661,
            349.43016591,
            349.46695521,
            349.50374451,
            349.5405338,
            349.5773231,
            349.6141124,
            349.6509017,
            349.68769099,
            349.72448029,
            349.76126959,
            349.79805889,
            349.83484818,
            349.87163748,
            349.90842678,
            349.94521608,
            349.98200538,
            350.01879467,
            350.05558397,
            350.09237327,
            350.12916257,
            350.16595186,
            350.20274116,
            350.23953046,
            350.27631976,
            350.31310905,
            350.34989835,
            350.38668765,
            350.42347695,
            350.46026625,
            350.49705554,
            350.53384484,
            350.57063414,
            350.60742344,
            350.64421273,
            350.68100203,
            350.71779133,
            350.75458063,
            350.79136992,
            350.82815922,
            350.86494852,
            350.90173782,
            350.93852711,
            350.97531641,
            351.01210571,
            351.04889501,
            351.08568431,
            351.1224736,
            351.1592629,
            351.1960522,
            351.2328415,
            351.26963079,
            351.30642009,
            351.34320939,
            351.37999869,
            351.41678798,
            351.45357728,
            351.49036658,
            351.52715588,
            351.56394517,
            351.60073447,
            351.63752377,
            351.67431307,
            351.71110237,
            351.74789166,
            351.78468096,
            351.82147026,
            351.85825956,
            351.89504885,
            351.93183815,
            351.96862745,
            352.00541675,
            352.04220604,
            352.07899534,
            352.11578464,
            352.15257394,
            352.18936324,
            352.22615253,
            352.26294183,
            352.29973113,
            352.33652043,
            352.37330972,
            352.41009902,
            352.44688832,
            352.48367762,
            352.52046691,
            352.55725621,
            352.59404551,
            352.63083481,
            352.6676241,
            352.7044134,
            352.7412027,
            352.777992,
            352.8147813,
            352.85157059,
            352.88835989,
            352.92514919,
            352.96193849,
            352.99872778,
            353.03551708,
            353.07230638,
            353.10909568,
            353.14588497,
            353.18267427,
            353.21946357,
            353.25625287,
            353.29304216,
            353.32983146,
            353.36662076,
            353.40341006,
            353.44019936,
            353.47698865,
            353.51377795,
            353.55056725,
            353.58735655,
            353.62414584,
            353.66093514,
            353.69772444,
            353.73451374,
            353.77130303,
            353.80809233,
            353.84488163,
            353.88167093,
            353.91846023,
            353.95524952,
            353.99203882,
            354.02882812,
            354.06561742,
            354.10240671,
            354.13919601,
            354.17598531,
            354.21277461,
            354.2495639,
            354.2863532,
            354.3231425,
            354.3599318,
            354.39672109,
            354.43351039,
            354.47029969,
            354.50708899,
            354.54387829,
            354.58066758,
            354.61745688,
            354.65424618,
            354.69103548,
            354.72782477,
            354.76461407,
            354.80140337,
            354.83819267,
            354.87498196,
            354.91177126,
            354.94856056,
            354.98534986,
            355.02213915,
            355.05892845,
            355.09571775,
            355.13250705,
            355.16929635,
            355.20608564,
            355.24287494,
            355.27966424,
            355.31645354,
            355.35324283,
            355.39003213,
            355.42682143,
            355.46361073,
            355.50040002,
            355.53718932,
            355.57397862,
            355.61076792,
            355.64755722,
        ]
    )
    # y values returned as second element of ``Model.graph()``; same length as rts
    ints = np.array(
        [
            2.81303535,
            2.98343519,
            3.16415647,
            3.35582428,
            3.55910158,
            3.77469143,
            4.00333947,
            4.24583647,
            4.50302106,
            4.77578263,
            5.06506438,
            5.37186658,
            5.69725002,
            6.04233968,
            6.40832857,
            6.79648184,
            7.20814116,
            7.64472932,
            8.1077551,
            8.59881846,
            9.11961605,
            9.67194699,
            10.25771907,
            10.87895524,
            11.53780055,
            12.23652947,
            12.97755365,
            13.7634301,
            14.59686992,
            15.48074749,
            16.4181102,
            17.41218877,
            18.46640811,
            19.58439892,
            20.77000977,
            22.02732009,
            23.36065368,
            24.77459316,
            26.2739951,
            27.86400602,
            29.5500793,
            31.33799301,
            33.23386859,
            35.24419064,
            37.3758277,
            39.63605405,
            42.03257265,
            44.5735392,
            47.26758731,
            50.12385493,
            53.15201184,
            56.3622885,
            59.7655059,
            63.37310677,
            67.1971878,
            71.25053307,
            75.54664845,
            80.09979695,
            84.92503498,
            90.03824918,
            95.45619381,
            101.19652838,
            107.27785521,
            113.71975661,
            120.54283127,
            127.76872925,
            135.42018513,
            143.52104845,
            152.09631067,
            161.1721277,
            170.77583685,
            180.93596683,
            191.68223933,
            203.04556047,
            215.05800012,
            227.75275673,
            241.16410538,
            255.32732585,
            270.27860789,
            286.0549298,
            302.6939067,
            320.23360405,
            338.71231197,
            358.16827528,
            378.63937427,
            400.16275066,
            422.7743734,
            446.50853895,
            471.39730072,
            497.46982335,
            524.75165772,
            553.26393424,
            583.02247334,
            614.03681436,
            646.30916658,
            679.83328974,
            714.59331523,
            750.56252386,
            787.70210166,
            825.95990079,
            865.26923925,
            905.54777937,
            946.69653132,
            988.5990339,
            1031.12076889,
            1074.10886858,
            1117.39217626,
            1160.78171684,
            1204.0716286,
            1247.04059674,
            1289.45381477,
            1331.06548086,
            1371.62181395,
            1410.86454864,
            1448.53484149,
            1484.37749434,
            1518.14537628,
            1549.60390557,
            1578.53543947,
            1604.74341405,
            1628.05607959,
            1648.32968987,
            1665.4510255,
            1679.33916038,
            1689.94641562,
            1697.25848258,
            1701.29373498,
            1702.10178515,
            1699.76137032,
            1694.37767825,
            1686.07923767,
            1675.01450648,
            1661.34829054,
            1645.25811891,
            1626.9306883,
            1606.55847339,
            1584.33657997,
            1560.45989857,
            1535.12059644,
            1508.50596787,
            1480.79664752,
            1452.16517778,
            1422.77491155,
            1392.7792239,
            1362.32100147,
            1331.53237543,
            1300.53466328,
            1269.4384851,
            1238.3440219,
            1207.3413858,
            1176.51107515,
            1145.92449079,
            1115.64449262,
            1085.72597935,
            1056.21647687,
            1027.15672368,
            998.58124409,
            970.51890217,
            942.99343129,
            916.02393557,
            889.625361,
            863.8089348,
            838.58257285,
            813.95125526,
            789.91737091,
            766.48103217,
            743.64036115,
            721.39174909,
            699.73009061,
            678.64899456,
            658.14097326,
            638.19761177,
            618.80971897,
            599.9674619,
            581.66048509,
            563.87801602,
            546.60895825,
            529.8419734,
            513.56555297,
            497.76808126,
            482.43789011,
            467.56330651,
            453.1326937,
            439.13448662,
            425.55722222,
            412.38956527,
            399.62033018,
            387.2384993,
            375.2332381,
            363.59390754,
            352.31007418,
            341.371518,
            330.76823851,
            320.49045915,
            310.52863031,
            300.87343116,
            291.51577037,
            282.44678595,
            273.65784427,
            265.14053844,
            256.88668611,
            248.88832676,
            241.13771861,
            233.6273352,
            226.34986164,
            219.2981907,
            212.46541866,
            205.84484108,
            199.42994838,
            193.21442143,
            187.19212708,
            181.35711358,
            175.70360611,
            170.22600224,
            164.91886743,
            159.77693063,
            154.79507983,
            149.96835772,
            145.29195744,
            140.76121835,
            136.37162189,
            132.11878756,
            127.99846891,
            124.00654966,
            120.13903994,
            116.39207254,
            112.7618993,
            109.24488758,
            105.83751684,
            102.53637524,
            99.33815641,
            96.23965626,
            93.2377699,
            90.32948861,
            87.51189694,
            84.78216986,
            82.13757,
            79.57544496,
            77.09322472,
            74.68841913,
            72.3586154,
            70.10147578,
            67.91473521,
            65.7961991,
            63.74374113,
            61.75530117,
            59.82888321,
            57.96255339,
            56.15443807,
            54.40272196,
            52.70564631,
            51.06150718,
            49.46865368,
            47.9254864,
            46.43045576,
            44.98206047,
            43.57884603,
            42.21940329,
            40.90236701,
            39.62641452,
            38.3902644,
            37.19267515,
            36.03244402,
            34.90840573,
            33.81943134,
            32.76442714,
            31.74233349,
            30.75212383,
            29.79280359,
            28.86340923,
            27.96300726,
            27.09069329,
            26.24559116,
            25.42685201,
            24.63365349,
            23.86519887,
            23.12071628,
            22.39945793,
            21.70069936,
            21.02373868,
            20.36789593,
            19.73251234,
            19.11694969,
            18.52058966,
            17.94283324,
            17.38310009,
            16.84082797,
            16.31547219,
            15.80650504,
            15.31341527,
            14.83570759,
            14.37290216,
            13.9245341,
            13.49015302,
            13.06932261,
            12.66162015,
            12.2666361,
            11.88397373,
            11.51324864,
            11.15408845,
        ]
    )
    return Model(rts, ints)


def test__fit_peak_0():
    """``_fit_peak`` returns ``None`` for an empty ``rts`` list and no model."""
    fitted = cp._fit_peak([], None)
    assert fitted is None


def test__fit_peak_1(model1):
    """``_fit_peak`` returns ``None`` for an empty ``rts`` list and an empty model."""
    fitted = cp._fit_peak([], model1)
    assert fitted is None


def test__fit_peak_2(model2):
    """``_fit_peak`` evaluated at the model's own ``rts`` reproduces its ``ints``.

    The largest absolute deviation between computed and stored values must
    not exceed ``1e-10``.
    """
    rts, ints_expected = model2.graph()
    ints_is = cp._fit_peak(rts, model2)
    largest_deviation = np.max(np.abs(ints_is - ints_expected))
    assert largest_deviation <= 1e-10


class Model:
    """Test double that stores two sequences and hands them back via ``graph()``."""

    def __init__(self, rts, ints):
        # Both arguments are kept as given; nothing is copied or converted.
        self.rts, self.ints = rts, ints

    def graph(self):
        """Return the stored objects as the tuple ``(rts, ints)``."""
        pair = (self.rts, self.ints)
        return pair


@pytest.fixture
def empty_model():
    """Provide a ``Model`` whose ``rts`` and ``ints`` are both empty lists."""
    # Disabled alternative, kept for reference: load a pickled model from disk.
    # path = os.path.join(here, "data", "failed_peak_shape_model.pickle")
    # with open(path, "rb") as fp:
    #     model = pickle.load(fp)
    return Model([], [])


def test__rts(empty_model):
    """``_rts`` applied to a list holding one empty model gives a result of size 0."""
    models = [empty_model]
    result = cp._rts(models)
    assert result.size == 0
