import unittest

import biogeme.biogeme as bio
from biogeme import models
from biogeme.data.swissmetro import (
    CAR_AV_SP,
    CAR_CO,
    CAR_TT,
    GA,
    PURPOSE,
    SM_AV,
    SM_CO,
    SM_TT,
    TRAIN_AV_SP,
    TRAIN_CO,
    TRAIN_TT,
    read_data,
)
from biogeme.expressions import Beta, Derive, Elem

# Load the Swissmetro sample data set.
database = read_data()

# Only trip purposes 1 (commuter) and 3 (business) are kept. An observation
# is flagged for removal when both inequality tests hold, i.e. when their
# product is strictly positive.
not_commuter = PURPOSE != 1
not_business = PURPOSE != 3
exclude = (not_commuter * not_business) > 0
database.remove(exclude)

# Model parameters. The numeric argument of each Beta is the same value that
# the_betas_values supplies for that parameter at simulation time.
ASC_TRAIN = Beta('ASC_TRAIN', -0.701188, None, None, 0)
B_TIME = Beta('B_TIME', -1.27786, None, None, 0)
B_COST = Beta('B_COST', -1.08379, None, None, 0)
ASC_SM = Beta('ASC_SM', 0, None, None, 0)
ASC_CAR = Beta('ASC_CAR', -0.154633, None, None, 0)

# The train and Swissmetro costs only count for observations where GA is 0;
# they are multiplied by zero otherwise.
no_ga = GA == 0
SM_COST = SM_CO * no_ga
TRAIN_COST = TRAIN_CO * no_ga

# Every time and cost variable is divided by the same factor before it
# enters a utility function.
_SCALE = 100.0
TRAIN_TT_SCALED = TRAIN_TT / _SCALE
TRAIN_COST_SCALED = TRAIN_COST / _SCALE
SM_TT_SCALED = SM_TT / _SCALE
SM_COST_SCALED = SM_COST / _SCALE
CAR_TT_SCALED = CAR_TT / _SCALE
CAR_CO_SCALED = CAR_CO / _SCALE

# Utility functions: alternative specific constant + time term + cost term.
V1 = ASC_TRAIN + B_TIME * TRAIN_TT_SCALED + B_COST * TRAIN_COST_SCALED
V2 = ASC_SM + B_TIME * SM_TT_SCALED + B_COST * SM_COST_SCALED
V3 = ASC_CAR + B_TIME * CAR_TT_SCALED + B_COST * CAR_CO_SCALED

# Utility of each alternative, keyed by the alternative's number.
V = {
    1: V1,  # train
    2: V2,  # Swissmetro
    3: V3,  # car
}

# Availability of each alternative, using the same numbering as V.
av = {
    1: TRAIN_AV_SP,
    2: SM_AV,
    3: CAR_AV_SP,
}

# Probability of alternative 1 (train). The choice model is a logit with
# availability conditions; Elem then selects the entry whose key equals
# av[1], i.e. 0 when av[1] is 0 and the logit probability when it is 1.
train_logit = models.logit(V, av, 1)
prob1 = Elem({0: 0, 1: train_logit}, av[1])

# Two formulas for the elasticity of prob1 with respect to TRAIN_TT are
# defined below. The test in this file compares the sum of each of them
# against the same reference value.

# 1. The general definition of elasticities. It illustrates the use of the
#    Derive expression and can be used with any model, however complicated
#    it is. Note that the variable name is passed to the Derive operator as
#    a quoted string.
prob1_wrt_train_tt = Derive(prob1, 'TRAIN_TT')
genelas1 = prob1_wrt_train_tt * TRAIN_TT / prob1

# 2. The elasticity formula specific to logit models. See Ben-Akiva and
#    Lerman for the formula.
complement_prob1 = 1.0 - prob1
logitelas1 = TRAIN_AV_SP * complement_prob1 * TRAIN_TT_SCALED * B_TIME

# Expressions to simulate, keyed by the name under which each one is looked
# up in the simulation results.
simulate = {
    'P1': prob1,
    'logit elas. 1': logitelas1,
    'generic elas. 1': genelas1,
}


# Parameter values at which the expressions above are simulated.
the_betas_values = {
    'ASC_TRAIN': -0.701188,
    'B_TIME': -1.27786,
    'B_COST': -1.08379,
    'ASC_SM': 0,
    'ASC_CAR': -0.154633,
}


class test_01simul(unittest.TestCase):
    """Regression test for the simulation of the logit model expressions."""

    def testSimulation(self):
        """Simulate the expressions and compare column sums with references.

        The expressions in ``simulate`` are evaluated at the parameter values
        in ``the_betas_values``. The sum of each simulated quantity must match
        its reference value to two decimal places; both elasticity formulas
        are compared with the same reference.
        """
        simulator = bio.BIOGEME(
            database,
            simulate,
            parameters=None,
            save_iterations=False,
            generate_html=False,
            generate_yaml=False,
        )
        simulator.model_name = '01logit_simul'
        simulated = simulator.simulate(the_betas_values)

        total_p1 = sum(simulated['P1'])
        self.assertAlmostEqual(total_p1, 907.9992101964821, places=2)

        total_logit_elas = sum(simulated['logit elas. 1'])
        self.assertAlmostEqual(total_logit_elas, -12673.838605478186, places=2)

        total_generic_elas = sum(simulated['generic elas. 1'])
        self.assertAlmostEqual(total_generic_elas, -12673.838605478186, places=2)


# Run the tests with the unittest runner when this file is executed as a
# script.
if __name__ == '__main__':
    unittest.main()
