import unittest

import biogeme.biogeme as bio
from biogeme import models
from biogeme.data.swissmetro import (
    CAR_AV_SP,
    CAR_CO_SCALED,
    CAR_TT_SCALED,
    CHOICE,
    GA,
    PURPOSE,
    SM_AV,
    SM_CO,
    SM_COST_SCALED,
    SM_TT_SCALED,
    TRAIN_AV_SP,
    TRAIN_CO,
    TRAIN_COST_SCALED,
    TRAIN_TT_SCALED,
    read_data,
)
from biogeme.expressions import Beta
from biogeme.nests import NestsForNestedLogit, OneNestForNestedLogit

# Load the Swissmetro data set shipped with Biogeme.
database = read_data()
# Keep only trip purposes 1 (commuter) and 3 (business).
# The product of the two inequality tests is non-zero only when PURPOSE is
# neither 1 nor 3, so exactly those observations are flagged for removal.
exclude = ((PURPOSE != 1) * (PURPOSE != 3)) > 0
database.remove(exclude)


# Parameters of the model. The positional arguments of Beta are, in order:
# name, starting value, lower bound, upper bound, status.
# ASC_SM is the only parameter declared with status 1 (kept at its starting
# value instead of being estimated); it serves as the reference for the
# alternative specific constants.
ASC_CAR = Beta('ASC_CAR', 0, None, None, 0)
ASC_TRAIN = Beta('ASC_TRAIN', 0, None, None, 0)
ASC_SM = Beta('ASC_SM', 0, None, None, 1)
B_TIME = Beta('B_TIME', 0, None, None, 0)
B_COST = Beta('B_COST', 0, None, None, 0)

# Nest parameter: starts at 2.05, bounded below by 1, no upper bound.
MU = Beta('MU', 2.05, 1, None, 0)

# Additional parameter designed to estimate the bias due to choice
# based sampling
SB_TRAIN = Beta('SB_TRAIN', 0, None, None, 0)

# Correction term per alternative: only alternative 1 receives SB_TRAIN,
# alternatives 2 and 3 receive no correction.
correction = {1: SB_TRAIN, 2: 0, 3: 0}


# Costs set to zero when GA is non-zero.
# NOTE(review): SM_COST and TRAIN_COST are not referenced anywhere below;
# the utility functions use the imported SM_COST_SCALED and
# TRAIN_COST_SCALED instead. Confirm whether these two definitions are
# still needed.
SM_COST = SM_CO * (GA == 0)
TRAIN_COST = TRAIN_CO * (GA == 0)


# Utility functions: an alternative specific constant plus generic time and
# cost coefficients applied to the scaled attributes of each alternative.
V1 = ASC_TRAIN + B_TIME * TRAIN_TT_SCALED + B_COST * TRAIN_COST_SCALED
V2 = ASC_SM + B_TIME * SM_TT_SCALED + B_COST * SM_COST_SCALED
V3 = ASC_CAR + B_TIME * CAR_TT_SCALED + B_COST * CAR_CO_SCALED

# Associate utility functions with the numbering of alternatives
V = {1: V1, 2: V2, 3: V3}


# Associate the availability conditions with the numbering of alternatives
av = {1: TRAIN_AV_SP, 2: SM_AV, 3: CAR_AV_SP}

# Definition of nests:
# a single nest, named 'existing', with nest parameter MU, that groups
# alternatives 1 and 3. Alternative 2 is not listed in any nest.
existing = OneNestForNestedLogit(
    nest_param=MU, list_of_alternatives=[1, 3], name='existing'
)
nests = NestsForNestedLogit(tuple_of_nests=(existing,), choice_set=[1, 2, 3])

# The choice model is a nested logit, with corrections for endogenous sampling
# Gi is built from the utilities, availabilities and nests, and is then
# passed, together with the correction terms and the observed CHOICE, to
# the function returning the expression that is estimated.
Gi = models.get_mev_for_nested(V, av, nests)
logprob = models.logmev_endogenous_sampling(V, Gi, av, correction, CHOICE)


class test_14(unittest.TestCase):
    """Regression test for the estimation of the module-level ``logprob``."""

    def testEstimation(self):
        """Estimate the model and check the final log likelihood.

        The value obtained from the estimation must agree with the
        stored reference value to two decimal places.
        """
        expected_log_likelihood = -5169.641515468776
        # Switch off every output file so that the test leaves nothing behind.
        estimation_options = {
            'save_iterations': False,
            'generate_html': False,
            'generate_yaml': False,
        }
        estimator = bio.BIOGEME(database, logprob, **estimation_options)
        outcome = estimator.estimate()
        self.assertAlmostEqual(
            outcome.final_log_likelihood, expected_log_likelihood, places=2
        )


# Run the test case through unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
