# -*- coding: utf-8 -*-

from pyhrmc.core.structure import Structure
import numpy as np
import pandas as pd
import math
from scipy import integrate
from sklearn.linear_model import LinearRegression
import warnings
from pymatgen.core.periodic_table import Species
import logging

class Interpolator(Structure):
    def __init__(self, struct):
        self.struct = struct
        logger = logging.getLogger(__name__)

        # parameters for atoms from L. M. Peng, Micron, (1999), 30, 625–648, https://doi.org/10.1016/S0968-4328(99)00033-5
        self.atom_dict = {
            "species": [
                "H",
                "He",
                "Li",
                "Be",
                "B",
                "C",
                "N",
                "O",
                "F",
                "Ne",
                "Na",
                "Mg",
                "Al",
                "Si",
                "P",
                "S",
                "Cl",
                "Ar",
                "K",
                "Ca",
                "Sc",
                "Ti",
                "V",
                "Cr",
                "Mn",
                "Fe",
                "Co",
                "Ni",
                "Cu",
                "Zn",
                "Ga",
                "Ge",
                "As",
                "Se",
                "Br",
                "Kr",
                "Rb",
                "Sr",
                "Y",
                "Zr",
                "Nb",
                "Mo",
                "Tc",
                "Ru",
                "Rh",
                "Pd",
                "Ag",
                "Cd",
                "In",
                "Sn",
                "Sb",
                "Te",
                "I",
                "Xe",
                "Cs",
                "Ba",
                "La",
                "Ce",
                "Pr",
                "Nd",
                "Pm",
                "Sm",
                "Eu",
                "Gd",
                "Tb",
                "Dy",
                "Ho",
                "Er",
                "Tm",
                "Yb",
                "Lu",
                "Hf",
                "Ta",
                "W",
                "Re",
                "Os",
                "Ir",
                "Pt",
                "Au",
                "Hg",
                "Tl",
                "Pb",
                "Bi",
                "Po",
                "At",
                "Rn",
                "Fr",
                "Ra",
                "Ac",
                "Th",
                "Pa",
                "U",
                "Np",
                "Pu",
                "Am",
                "Cm",
                "Bk",
                "Cf",
            ],
            "a1": [
                0.0367,
                0.0406,
                0.1198,
                0.118,
                0.1298,
                0.1361,
                0.1372,
                0.1433,
                0.1516,
                0.1575,
                0.3319,
                0.3248,
                0.3582,
                0.3626,
                0.354,
                0.3478,
                0.3398,
                0.3409,
                0.5658,
                0.5474,
                0.5389,
                0.5398,
                0.5412,
                0.5478,
                0.5552,
                0.5627,
                0.573,
                0.5785,
                0.5932,
                0.5996,
                0.6737,
                0.666,
                0.6504,
                0.6343,
                0.6216,
                0.624,
                1.018,
                1.0127,
                0.9722,
                0.9592,
                0.9337,
                0.9334,
                0.9282,
                0.9175,
                0.9046,
                0.7832,
                0.9085,
                0.9056,
                0.966,
                0.9481,
                0.9323,
                0.906,
                0.8912,
                0.8749,
                1.2968,
                1.3111,
                1.2872,
                1.3019,
                1.3354,
                1.3473,
                1.3631,
                1.3702,
                1.3963,
                1.3784,
                1.4257,
                1.4104,
                1.3805,
                1.4391,
                1.4561,
                1.4617,
                1.4381,
                1.4209,
                1.4163,
                1.4109,
                1.3898,
                1.3877,
                1.3783,
                1.3489,
                1.3413,
                1.3592,
                1.4735,
                1.47,
                1.4616,
                1.4255,
                1.417,
                1.4001,
                1.7697,
                1.7704,
                1.7346,
                1.7012,
                1.7266,
                1.7417,
                1.7327,
                1.7491,
                1.7519,
                1.7284,
                1.7298,
                1.6401,
            ],
            "a2": [
                0.1269,
                0.1276,
                0.3952,
                0.4394,
                0.52,
                0.5482,
                0.5344,
                0.5103,
                0.5193,
                0.5041,
                0.9857,
                0.9243,
                0.9754,
                0.9737,
                0.9397,
                0.9158,
                0.8908,
                0.8966,
                2.4151,
                2.2793,
                2.2102,
                2.1568,
                2.1063,
                2.0737,
                2.0073,
                1.9685,
                1.9219,
                1.8679,
                1.8344,
                1.7763,
                1.8457,
                1.7662,
                1.6706,
                1.5698,
                1.4943,
                1.4352,
                2.8882,
                2.9403,
                2.8705,
                2.8531,
                2.8218,
                2.8581,
                2.8832,
                2.8691,
                2.8679,
                2.1753,
                2.9433,
                2.943,
                3.1971,
                3.1011,
                2.9938,
                2.8392,
                2.7442,
                2.6276,
                3.8609,
                3.9542,
                3.8478,
                3.8491,
                3.8652,
                3.8589,
                3.8462,
                3.827,
                3.8201,
                3.7349,
                3.7865,
                3.6934,
                3.5395,
                3.6351,
                3.6041,
                3.5735,
                3.4689,
                3.3984,
                3.3456,
                3.2997,
                3.2373,
                3.2082,
                3.1637,
                3.0674,
                3.0559,
                3.1327,
                3.6368,
                3.6387,
                3.6426,
                3.5214,
                3.4829,
                3.4082,
                4.682,
                4.6854,
                4.5959,
                4.5007,
                4.6123,
                4.7063,
                4.6782,
                4.7933,
                4.8096,
                4.7162,
                4.7198,
                4.3905,
            ],
            "a3": [
                0.236,
                0.1738,
                1.3794,
                1.4273,
                1.3767,
                1.2266,
                1.0862,
                0.937,
                0.822,
                0.733,
                1.4885,
                2.0039,
                2.6393,
                2.7209,
                2.6203,
                2.5066,
                2.3878,
                2.2636,
                2.4655,
                3.1934,
                3.1187,
                2.9961,
                2.8525,
                2.3527,
                2.5678,
                2.4527,
                2.3358,
                2.2229,
                1.808,
                2.024,
                2.6395,
                2.8876,
                2.993,
                3.0362,
                3.0763,
                3.0847,
                3.5125,
                3.992,
                4.1571,
                4.1761,
                3.9092,
                3.8278,
                3.8985,
                3.548,
                3.3997,
                3.2003,
                3.1021,
                3.1814,
                3.6243,
                3.8863,
                4.0829,
                4.1639,
                4.3316,
                4.4112,
                5.6613,
                5.8693,
                6.0581,
                5.8568,
                5.3,
                5.1226,
                4.9578,
                4.8068,
                4.6869,
                4.8549,
                4.4336,
                4.3012,
                4.4117,
                4.0572,
                3.936,
                3.8563,
                4.0824,
                4.2315,
                4.3281,
                4.3698,
                4.3683,
                4.3731,
                4.3549,
                4.2095,
                4.1659,
                4.1897,
                4.395,
                4.5674,
                4.7876,
                4.9058,
                5.0951,
                5.2516,
                6.6701,
                6.8034,
                7.0436,
                7.2504,
                6.7855,
                6.6509,
                6.4465,
                5.9841,
                5.8155,
                5.9338,
                5.7527,
                5.4504,
            ],
            "a4": [
                0.129,
                0.0758,
                1.3889,
                1.0661,
                0.7666,
                0.5971,
                0.4547,
                0.3923,
                0.3081,
                0.2572,
                1.9657,
                1.9507,
                1.9103,
                1.766,
                1.5707,
                1.3884,
                1.2376,
                1.0786,
                3.524,
                3.8824,
                3.4302,
                3.0751,
                2.7967,
                1.9866,
                2.3695,
                2.175,
                2.0177,
                1.894,
                1.3586,
                1.6598,
                1.9413,
                2.051,
                2.0018,
                1.9607,
                1.8642,
                1.7511,
                4.3375,
                5.1441,
                4.6466,
                4.1491,
                2.988,
                2.6302,
                3.1237,
                2.203,
                2.051,
                1.4127,
                1.7092,
                2.1946,
                2.6351,
                2.9142,
                2.9578,
                3.0782,
                2.9328,
                2.8758,
                5.6489,
                7.0998,
                6.5606,
                6.3222,
                6.4296,
                6.2264,
                6.0228,
                5.8402,
                5.6321,
                5.2564,
                5.2698,
                5.1904,
                5.0148,
                4.9059,
                4.7711,
                4.6254,
                4.4599,
                4.1,
                3.7406,
                3.4408,
                3.2472,
                2.9958,
                2.805,
                2.1738,
                2.0022,
                2.2748,
                2.5803,
                2.9103,
                3.1935,
                3.4925,
                3.4639,
                3.425,
                5.5268,
                7.2352,
                7.0521,
                6.6077,
                6.3876,
                5.9929,
                5.8528,
                5.6139,
                5.4119,
                5.2896,
                5.1576,
                5.4542,
            ],
            "b1": [
                0.5608,
                0.314,
                0.5908,
                0.4599,
                0.418,
                0.3731,
                0.3287,
                0.3055,
                0.2888,
                0.2714,
                0.495,
                0.4455,
                0.4529,
                0.4281,
                0.3941,
                0.3652,
                0.3379,
                0.3229,
                0.5061,
                0.47,
                0.4446,
                0.4281,
                0.4132,
                0.4032,
                0.3947,
                0.3859,
                0.3799,
                0.3704,
                0.368,
                0.3602,
                0.3901,
                0.3741,
                0.355,
                0.3374,
                0.3216,
                0.3144,
                0.485,
                0.4721,
                0.4427,
                0.4279,
                0.4074,
                0.3996,
                0.3894,
                0.3776,
                0.3646,
                0.3094,
                0.3541,
                0.3466,
                0.3648,
                0.3518,
                0.3405,
                0.3256,
                0.315,
                0.3046,
                0.4462,
                0.4432,
                0.4283,
                0.4256,
                0.4305,
                0.4266,
                0.4253,
                0.4204,
                0.423,
                0.4105,
                0.4182,
                0.4076,
                0.3938,
                0.4038,
                0.4028,
                0.3986,
                0.3874,
                0.3772,
                0.3714,
                0.3653,
                0.3548,
                0.3502,
                0.344,
                0.3322,
                0.3262,
                0.3265,
                0.35,
                0.3452,
                0.3396,
                0.3274,
                0.3224,
                0.3153,
                0.3969,
                0.3938,
                0.3816,
                0.3704,
                0.3729,
                0.3729,
                0.3683,
                0.3683,
                0.3654,
                0.3577,
                0.3554,
                0.3332,
            ],
            "b2": [
                3.7913,
                2.0952,
                6.1114,
                4.4816,
                3.9214,
                3.2814,
                2.6733,
                2.2683,
                2.0619,
                1.8403,
                4.0855,
                3.5744,
                3.7745,
                3.557,
                3.181,
                2.8915,
                2.619,
                2.4778,
                5.7656,
                5.0494,
                4.5701,
                4.2236,
                3.9256,
                3.7014,
                3.48,
                3.3103,
                3.1572,
                2.9964,
                2.8959,
                2.7665,
                3.0289,
                2.839,
                2.6169,
                2.4,
                2.2338,
                2.1394,
                5.1504,
                4.9802,
                4.5552,
                4.3237,
                4.0572,
                3.9476,
                3.8131,
                3.6434,
                3.4888,
                2.5857,
                3.3319,
                3.2215,
                3.4534,
                3.2474,
                3.0545,
                2.817,
                2.6616,
                2.4959,
                4.3194,
                4.3226,
                4.0782,
                4.0288,
                4.0368,
                3.9844,
                3.9386,
                3.8747,
                3.8621,
                3.7014,
                3.7931,
                3.6339,
                3.4131,
                3.5707,
                3.5468,
                3.4988,
                3.3304,
                3.2109,
                3.1289,
                3.0517,
                2.9382,
                2.8796,
                2.7975,
                2.6601,
                2.6076,
                2.6458,
                3.0849,
                3.034,
                2.9739,
                2.8031,
                2.7289,
                2.6261,
                3.995,
                3.9313,
                3.7454,
                3.5712,
                3.6159,
                3.6337,
                3.5375,
                3.5651,
                3.5215,
                3.3776,
                3.3269,
                2.983,
            ],
            "b3": [
                13.5557,
                7.1369,
                36.7672,
                21.5831,
                16.6634,
                13.0456,
                10.3165,
                8.2625,
                7.2628,
                6.3059,
                31.5107,
                24.2702,
                23.3862,
                19.3905,
                15.6579,
                13.0522,
                11.0684,
                9.7408,
                31.9169,
                29.6928,
                26.4597,
                24.1928,
                22.3625,
                20.0893,
                19.5862,
                18.7003,
                17.8168,
                16.8507,
                15.6333,
                15.5278,
                19.0448,
                17.2911,
                15.11,
                13.1385,
                11.7075,
                10.606,
                26.1762,
                26.8565,
                24.0646,
                22.151,
                19.4778,
                18.6153,
                18.2715,
                16.4538,
                15.5106,
                9.8481,
                14.7723,
                14.7749,
                18.3358,
                17.4348,
                16.2843,
                14.6041,
                13.6958,
                12.5908,
                24.1136,
                24.3857,
                22.4363,
                22.002,
                22.1273,
                21.8395,
                21.633,
                21.384,
                21.4604,
                20.2877,
                21.3185,
                20.3594,
                18.4989,
                20.0727,
                19.9521,
                19.905,
                18.8077,
                17.772,
                16.8719,
                15.963,
                14.85,
                14.1535,
                13.364,
                12.0128,
                11.5043,
                11.6919,
                14.9338,
                14.8562,
                14.8388,
                13.8305,
                13.4904,
                12.9003,
                23.0274,
                22.4168,
                21.133,
                19.8632,
                19.6433,
                19.6579,
                18.8556,
                18.8727,
                18.512,
                17.7168,
                17.3659,
                14.6892,
            ],
            "b4": [
                37.7229,
                19.4462,
                117.0314,
                66.115,
                51.7511,
                41.0202,
                32.7631,
                25.6645,
                22.0262,
                19.164,
                118.494,
                80.7304,
                80.5019,
                64.3334,
                49.5239,
                40.1848,
                33.5378,
                28.9354,
                151.259,
                109.5608,
                98.1283,
                90.6685,
                84.4689,
                87.923,
                74.9201,
                71.6638,
                68.4867,
                65.2843,
                70.68,
                60.283,
                74.1674,
                62.5149,
                50.8113,
                42.4344,
                36.7639,
                32.4403,
                155.1291,
                116.0307,
                99.1688,
                89.7694,
                85.5042,
                82.5062,
                74.4844,
                74.459,
                71.3737,
                37.4035,
                70.5441,
                62.3544,
                75.8859,
                66.0384,
                56.5387,
                48.088,
                43.2296,
                38.5666,
                168.3481,
                128.5976,
                109.8403,
                107.6524,
                117.0334,
                114.8891,
                112.8589,
                110.9527,
                110.3872,
                96.0364,
                107.633,
                103.6978,
                86.9074,
                100.4842,
                98.6682,
                97.9782,
                84.5078,
                76.0916,
                70.3632,
                65.7545,
                60.6048,
                57.3692,
                54.3356,
                50.3882,
                48.4954,
                47.9347,
                65.0072,
                60.2174,
                56.7362,
                50.1382,
                46.2094,
                42.3929,
                154.8886,
                120.9717,
                102.1159,
                89.3562,
                95.272,
                95.0805,
                91.5521,
                100.8062,
                98.8221,
                86.4936,
                84.3294,
                73.0723,
            ],
        }
        # parameters for ions from  L. M. Peng, Acta Cryst. (1998). A54, 481-485, https://doi.org/10.1107/S0108767398001901
        self.ion_dict = {
            "species": [
                "H",
                "Li",
                "Be",
                "O",
                "O",
                "F",
                "Na",
                "Mg",
                "Al",
                "Si",
                "Cl",
                "K",
                "Ca",
                "Sc",
                "Ti",
                "Ti",
                "Ti",
                "V",
                "V",
                "V",
                "Cr",
                "Cr",
                "Cr",
                "Mn",
                "Mn",
                "Mn",
                "Fe",
                "Fe",
                "Co",
                "Co",
                "Ni",
                "Ni",
                "Cu",
                "Cu",
                "Zn",
                "Ga",
                "Ge",
                "Br",
                "Rb",
                "Sr",
                "Y",
                "Zr",
                "Nb",
                "Nb",
                "Mo",
                "Mo",
                "Mo",
                "Ru",
                "Ru",
                "Rh",
                "Rh",
                "Pd",
                "Pd",
                "Ag",
                "Ag",
                "Cd",
                "In",
                "Sn",
                "Sn",
                "Sb",
                "Sb",
                "I",
                "Cs",
                "Ba",
                "La",
                "Ce",
                "Ce",
                "Pr",
                "Pr",
                "Nd",
                "Pm",
                "Sm",
                "Eu",
                "Eu",
                "Gd",
                "Tb",
                "Dy",
                "Ho",
                "Er",
                "Tm",
                "Yb",
                "Yb",
                "Lu",
                "Hf",
                "Ta",
                "W",
                "Os",
                "Ir",
                "Ir",
                "Pt",
                "Pt",
                "Au",
                "Au",
                "Hg",
                "Hg",
                "Tl",
                "Tl",
                "Pb",
                "Pb",
                "Bi",
                "Bi",
                "Ra",
                "Ac",
                "U",
                "U",
                "U",
            ],
            "Z": [
                1,
                3,
                4,
                8,
                8,
                9,
                11,
                12,
                13,
                14,
                17,
                19,
                20,
                21,
                22,
                22,
                22,
                23,
                23,
                23,
                24,
                24,
                24,
                25,
                25,
                25,
                26,
                26,
                27,
                27,
                28,
                28,
                29,
                29,
                30,
                31,
                32,
                35,
                37,
                38,
                39,
                40,
                41,
                41,
                42,
                42,
                42,
                44,
                44,
                45,
                45,
                46,
                46,
                47,
                47,
                48,
                49,
                50,
                50,
                51,
                51,
                53,
                55,
                56,
                57,
                58,
                85,
                59,
                59,
                60,
                61,
                62,
                63,
                63,
                64,
                65,
                66,
                67,
                68,
                69,
                70,
                70,
                71,
                72,
                73,
                74,
                76,
                77,
                77,
                78,
                78,
                79,
                79,
                80,
                80,
                81,
                81,
                82,
                82,
                83,
                83,
                88,
                89,
                92,
                92,
                92,
            ],
            "oxidation_state": [
                -1,
                1,
                2,
                -1,
                -2,
                -1,
                1,
                2,
                3,
                4,
                -1,
                1,
                2,
                3,
                2,
                3,
                4,
                2,
                3,
                5,
                2,
                3,
                4,
                2,
                3,
                4,
                2,
                3,
                2,
                3,
                2,
                3,
                1,
                2,
                2,
                3,
                4,
                -1,
                1,
                2,
                3,
                4,
                3,
                5,
                3,
                5,
                6,
                3,
                4,
                3,
                4,
                2,
                4,
                1,
                2,
                2,
                3,
                2,
                4,
                3,
                5,
                -1,
                1,
                2,
                3,
                3,
                4,
                3,
                4,
                3,
                3,
                3,
                2,
                3,
                3,
                3,
                3,
                3,
                3,
                3,
                2,
                3,
                3,
                4,
                5,
                6,
                4,
                3,
                4,
                2,
                4,
                1,
                3,
                1,
                2,
                1,
                3,
                2,
                4,
                3,
                5,
                2,
                3,
                3,
                4,
                6,
            ],
            "a1": [
                0.14,
                0.0046,
                0.0034,
                0.205,
                0.0421,
                0.134,
                0.0256,
                0.021,
                0.0192,
                0.192,
                0.265,
                0.199,
                0.164,
                0.163,
                0.399,
                0.364,
                0.116,
                0.317,
                0.341,
                0.0367,
                0.237,
                0.393,
                0.132,
                0.0576,
                0.116,
                0.381,
                0.307,
                0.198,
                0.213,
                0.331,
                0.338,
                0.347,
                0.312,
                0.224,
                0.252,
                0.391,
                0.346,
                0.125,
                0.368,
                0.346,
                0.465,
                0.234,
                0.377,
                0.0828,
                0.401,
                0.479,
                0.203,
                0.428,
                0.282,
                0.352,
                0.397,
                0.935,
                0.348,
                0.503,
                0.431,
                0.425,
                0.417,
                0.797,
                0.261,
                0.552,
                0.377,
                0.901,
                0.587,
                0.733,
                0.493,
                0.56,
                0.483,
                0.663,
                0.521,
                0.501,
                0.496,
                0.518,
                0.613,
                0.496,
                0.49,
                0.503,
                0.503,
                0.456,
                0.522,
                0.475,
                0.508,
                0.498,
                0.483,
                0.522,
                0.569,
                0.181,
                0.586,
                0.692,
                0.653,
                0.872,
                0.55,
                0.811,
                0.722,
                0.796,
                0.773,
                0.82,
                0.836,
                0.755,
                0.583,
                0.708,
                0.654,
                0.911,
                0.915,
                1.14,
                1.09,
                0.687,
            ],
            "a2": [
                0.649,
                0.0165,
                0.0103,
                0.628,
                0.21,
                0.391,
                0.0919,
                0.0672,
                0.0579,
                0.289,
                0.596,
                0.396,
                0.327,
                0.307,
                1.04,
                0.919,
                0.256,
                0.939,
                0.805,
                0.124,
                0.634,
                1.05,
                0.292,
                0.21,
                0.523,
                1.83,
                0.838,
                0.387,
                0.488,
                0.487,
                0.982,
                0.877,
                0.812,
                0.544,
                0.6,
                0.947,
                0.83,
                0.563,
                0.884,
                0.804,
                0.923,
                0.642,
                0.749,
                0.271,
                0.756,
                0.846,
                0.567,
                0.773,
                0.653,
                0.723,
                0.725,
                3.11,
                0.64,
                0.94,
                0.756,
                0.745,
                0.755,
                2.13,
                0.642,
                1.14,
                0.588,
                2.8,
                1.4,
                2.05,
                1.1,
                1.35,
                1.09,
                1.73,
                1.19,
                1.18,
                1.2,
                1.24,
                1.53,
                1.21,
                1.19,
                1.22,
                1.24,
                1.17,
                1.28,
                1.2,
                1.37,
                1.22,
                1.21,
                1.22,
                1.26,
                0.873,
                1.31,
                1.37,
                1.29,
                1.68,
                1.21,
                1.57,
                1.39,
                1.56,
                1.49,
                1.57,
                1.43,
                1.44,
                1.14,
                1.35,
                1.18,
                1.65,
                1.64,
                2.48,
                2.32,
                1.14,
            ],
            "a3": [
                1.37,
                0.0435,
                0.0233,
                1.17,
                0.852,
                0.814,
                0.297,
                0.198,
                0.163,
                0.1,
                1.6,
                0.928,
                0.743,
                0.716,
                1.21,
                1.35,
                0.565,
                1.49,
                0.942,
                0.244,
                1.23,
                1.62,
                0.703,
                0.604,
                0.881,
                -1.33,
                1.11,
                0.889,
                0.998,
                0.729,
                1.32,
                0.79,
                1.11,
                0.97,
                0.917,
                0.69,
                0.599,
                1.43,
                1.14,
                0.998,
                2.41,
                0.747,
                1.29,
                0.654,
                1.38,
                15.6,
                0.646,
                1.55,
                1.14,
                1.5,
                1.51,
                24.6,
                1.22,
                2.17,
                1.72,
                1.73,
                1.59,
                2.15,
                1.53,
                1.87,
                1.22,
                5.61,
                1.87,
                23,
                1.5,
                1.59,
                1.34,
                2.35,
                1.33,
                1.45,
                1.47,
                1.43,
                1.84,
                1.45,
                1.42,
                1.42,
                1.44,
                1.43,
                1.46,
                1.42,
                1.76,
                1.39,
                1.41,
                1.37,
                0.979,
                1.18,
                1.63,
                1.8,
                1.5,
                2.63,
                1.62,
                2.63,
                1.94,
                2.72,
                2.45,
                2.78,
                0.394,
                2.48,
                1.6,
                2.28,
                1.25,
                2.53,
                2.26,
                3.61,
                12,
                1.83,
            ],
            "a4": [
                0.337,
                0.0649,
                0.0325,
                1.03,
                1.82,
                0.928,
                0.514,
                0.368,
                0.284,
                -0.0728,
                2.69,
                1.45,
                1.16,
                0.88,
                -0.0797,
                -0.933,
                0.772,
                -1.31,
                0.0783,
                0.723,
                0.713,
                -1.15,
                0.692,
                1.32,
                0.589,
                0.995,
                0.28,
                0.709,
                0.828,
                0.608,
                -3.56,
                0.0538,
                0.794,
                0.727,
                0.663,
                0.0709,
                0.949,
                3.52,
                2.26,
                1.89,
                -2.31,
                1.47,
                1.61,
                1.24,
                1.58,
                -15.2,
                1.16,
                1.46,
                1.53,
                1.63,
                1.19,
                -43.6,
                1.45,
                1.99,
                1.78,
                1.74,
                1.36,
                -1.64,
                1.36,
                1.36,
                1.18,
                -8.69,
                3.48,
                -152,
                2.7,
                2.63,
                2.45,
                0.351,
                2.36,
                2.53,
                2.43,
                2.4,
                2.46,
                2.36,
                2.3,
                2.24,
                2.17,
                2.15,
                2.05,
                2.05,
                2.23,
                1.97,
                1.94,
                1.68,
                1.29,
                1.48,
                1.71,
                1.97,
                1.74,
                1.93,
                1.95,
                2.68,
                1.94,
                2.76,
                2.23,
                2.82,
                2.51,
                2.45,
                2.06,
                2.18,
                1.66,
                3.62,
                3.18,
                1.13,
                -9.11,
                2.53,
            ],
            "a5": [
                0.787,
                0.027,
                0.012,
                0.29,
                1.17,
                0.347,
                0.199,
                0.174,
                0.114,
                0.0012,
                1.23,
                0.45,
                0.307,
                0.139,
                0.352,
                0.589,
                0.132,
                1.47,
                0.156,
                0.435,
                0.0859,
                0.407,
                0.0959,
                0.659,
                0.214,
                0.0618,
                0.277,
                0.117,
                0.23,
                0.131,
                3.62,
                0.192,
                0.257,
                0.182,
                0.161,
                0.0653,
                -0.0217,
                3.22,
                0.881,
                0.609,
                2.48,
                0.377,
                0.481,
                0.829,
                0.497,
                1.6,
                0.171,
                0.486,
                0.418,
                0.499,
                0.251,
                21.2,
                0.427,
                0.726,
                0.526,
                0.487,
                0.451,
                2.72,
                0.177,
                0.414,
                0.244,
                12.6,
                1.67,
                134,
                1.08,
                0.706,
                0.797,
                1.59,
                0.69,
                0.92,
                0.943,
                0.781,
                0.714,
                0.774,
                0.795,
                0.71,
                0.643,
                0.692,
                0.508,
                0.584,
                0.584,
                0.559,
                0.522,
                0.312,
                0.551,
                0.562,
                0.54,
                0.804,
                0.683,
                0.475,
                0.61,
                0.998,
                0.699,
                1.18,
                0.57,
                1.31,
                1.5,
                1.03,
                0.662,
                0.797,
                0.778,
                1.58,
                1.25,
                0.9,
                2.15,
                0.957,
            ],
            "b1": [
                0.984,
                0.0358,
                0.0267,
                0.397,
                0.0609,
                0.228,
                0.0397,
                0.0331,
                0.0306,
                0.359,
                0.252,
                0.192,
                0.157,
                0.157,
                0.376,
                0.364,
                0.108,
                0.269,
                0.321,
                0.033,
                0.177,
                0.359,
                0.109,
                0.0398,
                0.0117,
                0.354,
                0.23,
                0.154,
                0.148,
                0.267,
                0.237,
                0.26,
                0.201,
                0.145,
                0.161,
                0.264,
                0.232,
                0.053,
                0.187,
                0.176,
                0.24,
                0.113,
                0.184,
                0.0369,
                0.191,
                0.241,
                0.0971,
                0.191,
                0.125,
                0.151,
                0.177,
                0.393,
                0.151,
                0.199,
                0.175,
                0.168,
                0.164,
                0.317,
                0.0957,
                0.212,
                0.151,
                0.312,
                0.2,
                0.258,
                0.167,
                0.19,
                0.165,
                0.226,
                0.177,
                0.162,
                0.156,
                0.163,
                0.19,
                0.152,
                0.148,
                0.15,
                0.148,
                0.129,
                0.15,
                1.32,
                0.136,
                0.138,
                0.131,
                0.145,
                0.161,
                0.0118,
                0.155,
                0.182,
                0.174,
                0.223,
                0.142,
                0.201,
                0.184,
                0.194,
                0.191,
                0.197,
                0.208,
                0.181,
                0.144,
                0.17,
                0.162,
                0.204,
                0.205,
                0.25,
                0.243,
                0.154,
            ],
            "b2": [
                8.67,
                0.239,
                0.162,
                2.64,
                0.559,
                1.47,
                0.287,
                0.222,
                0.198,
                1.96,
                1.56,
                1.1,
                0.894,
                0.899,
                2.74,
                2.67,
                0.655,
                2.09,
                2.23,
                0.222,
                1.35,
                2.57,
                0.695,
                0.284,
                0.876,
                2.72,
                1.62,
                0.893,
                0.939,
                1.41,
                1.67,
                1.71,
                1.31,
                0.933,
                1.01,
                1.65,
                1.45,
                0.469,
                1.12,
                1.04,
                1.43,
                0.736,
                1.02,
                0.261,
                0.106,
                1.46,
                0.647,
                1.09,
                0.753,
                0.878,
                1.01,
                4.06,
                0.832,
                1.19,
                0.979,
                0.944,
                0.96,
                2.51,
                0.625,
                1.42,
                0.812,
                2.59,
                1.38,
                1.96,
                1.11,
                1.3,
                1.1,
                1.61,
                1.17,
                1.08,
                1.05,
                1.08,
                1.27,
                1.01,
                0.974,
                0.982,
                0.97,
                0.869,
                0.964,
                0.864,
                0.922,
                0.881,
                0.845,
                0.896,
                0.972,
                0.442,
                0.938,
                1.04,
                0.992,
                1.35,
                0.833,
                1.18,
                1.06,
                1.14,
                1.12,
                1.16,
                1.2,
                1.05,
                0.796,
                0.981,
                0.905,
                1.26,
                1.28,
                1.84,
                1.75,
                0.861,
            ],
            "b3": [
                38.9,
                0.879,
                0.531,
                8.8,
                2.96,
                4.68,
                1.18,
                0.838,
                0.713,
                9.34,
                6.21,
                3.91,
                3.15,
                3.06,
                8.1,
                8.18,
                2.38,
                7.22,
                5.99,
                0.824,
                4.3,
                8.68,
                0.239,
                1.29,
                3.06,
                3.47,
                4.87,
                2.62,
                2.78,
                2.89,
                5.73,
                4.75,
                3.8,
                2.69,
                2.76,
                4.82,
                4.09,
                2.15,
                3.98,
                3.59,
                6.45,
                2.54,
                3.8,
                0.957,
                3.84,
                6.79,
                2.28,
                3.82,
                2.85,
                3.28,
                3.62,
                43.1,
                2.85,
                4.05,
                3.3,
                3.14,
                3.08,
                9.04,
                2.51,
                4.21,
                2.4,
                14.1,
                4.12,
                11.8,
                3.11,
                3.93,
                3.02,
                6.33,
                3.28,
                3.06,
                3.07,
                3.11,
                4.18,
                2.95,
                2.81,
                2.86,
                2.88,
                2.61,
                2.93,
                2.6,
                3.12,
                2.63,
                2.57,
                2.74,
                2.76,
                1.52,
                3.19,
                3.47,
                3.14,
                4.99,
                2.81,
                4.25,
                3.58,
                4.21,
                4,
                4.23,
                2.57,
                3.75,
                2.58,
                3.44,
                2.68,
                4.03,
                3.92,
                7.39,
                7.79,
                2.58,
            ],
            "b4": [
                111,
                2.64,
                1.48,
                27.1,
                1.15,
                13.2,
                3.75,
                2.48,
                2.04,
                11.1,
                17.8,
                9.75,
                7.67,
                7.05,
                14.2,
                11.8,
                5.51,
                15.2,
                13.4,
                2.8,
                12.2,
                11,
                5.65,
                4.23,
                6.44,
                5.47,
                10.7,
                6.65,
                7.31,
                6.45,
                11.4,
                7.51,
                10.5,
                7.11,
                7.08,
                10.7,
                13.2,
                11.1,
                10.9,
                9.32,
                9.97,
                6.72,
                9.44,
                3.94,
                9.38,
                7.13,
                5.61,
                9.08,
                7.01,
                8.16,
                8.56,
                54,
                6.59,
                11.3,
                8.24,
                7.84,
                7.03,
                24.2,
                6.31,
                12.5,
                5.27,
                34.4,
                13,
                14.4,
                9.61,
                10.7,
                8.85,
                11,
                8.94,
                8.8,
                8.56,
                8.52,
                10.7,
                8.18,
                7.78,
                7.77,
                7.73,
                7.24,
                7.72,
                7.09,
                8.72,
                6.99,
                6.88,
                6.91,
                5.4,
                4.35,
                7.84,
                8.51,
                7.22,
                13.6,
                7.21,
                12.1,
                8.56,
                12.4,
                10.8,
                12.7,
                4.86,
                10.6,
                6.22,
                9.41,
                5.14,
                12.6,
                11.3,
                18,
                8.31,
                7.7,
            ],
            "b5": [
                166,
                7.09,
                3.88,
                91.8,
                37.7,
                36,
                10.8,
                6.75,
                5.25,
                13.4,
                47.8,
                23.4,
                17.7,
                16.1,
                23.2,
                14.9,
                12.3,
                17.6,
                16.9,
                6.7,
                39,
                15.8,
                14.7,
                14.5,
                14.3,
                16.1,
                19.2,
                18,
                20.7,
                15.8,
                12.1,
                13,
                28.2,
                19.4,
                19,
                15.2,
                29.5,
                38.9,
                26.6,
                21.4,
                12.2,
                14.7,
                25.7,
                9.44,
                24.6,
                10.4,
                12.4,
                21.7,
                17.5,
                20.7,
                18.9,
                69.8,
                15.6,
                32.4,
                21.4,
                20.4,
                16.1,
                26.4,
                15.9,
                29,
                11.9,
                39.5,
                31.8,
                14.9,
                21.2,
                23.8,
                18.8,
                16.9,
                19.3,
                19.6,
                19.2,
                19.1,
                26.2,
                18.5,
                17.7,
                17.7,
                17.6,
                16.7,
                17.8,
                16.6,
                23.7,
                16.3,
                16.2,
                16.1,
                10.9,
                9.42,
                19.3,
                21.2,
                17.2,
                33,
                17.7,
                34.4,
                20.4,
                36.2,
                27.6,
                35.7,
                13.5,
                27.9,
                14.8,
                23.7,
                11.2,
                30,
                25.1,
                22.7,
                16.5,
                15.9,
            ],
        }
        # create dfs of neutral and ionic species separately
        self.atom_df = pd.DataFrame(data=self.atom_dict)
        self.ion_df = pd.DataFrame(data=self.ion_dict)

    # method to select appropriate Gaussian parameters for the species of interest
    def select_terms(self, species, OS: int = 0):
        """Return the Gaussian fit parameters (a_i, b_i) for one species.

        Neutral atoms (``OS == 0``) are looked up in ``self.atom_df``; any
        other ``OS`` is looked up in ``self.ion_df``, matching both the
        element symbol and the ``oxidation_state`` column.

        Parameters
        ----------
        species : str
            Element symbol, e.g. ``"O"``.
        OS : int, optional
            Oxidation state. 0 (the default) selects the neutral atom.

        Returns
        -------
        tuple of list
            ``(a_terms, b_terms)``, one entry per ``b*`` column of the
            selected table.

        Raises
        ------
        IndexError
            If the table has no row for ``species`` (and ``OS`` for ions).
        """
        if OS == 0:
            df = self.atom_df
            # NOTE(review): these positions assume the atom table is laid out
            # as species, a1..a4, b1..b4 - confirm against atom_dict.
            a_col, b_col = 1, 5
            mask = df["species"] == species
        else:
            df = self.ion_df
            # ion table: species, Z, oxidation_state, a1..a5, b1..b5
            a_col, b_col = 3, 8
            mask = (df["species"] == species) & (df["oxidation_state"] == OS)

        # Positional row numbers, because the values are read with .iat below.
        positions = np.flatnonzero(mask.to_numpy())
        if positions.size == 0:
            raise IndexError(
                f"No Gaussian parameters tabulated for species={species!r}, OS={OS!r}"
            )
        row = int(positions[0])

        # One Gaussian term per b* column; the a* columns are assumed to match.
        n_terms = df.loc[:, "b1":].shape[1]
        a_terms = [df.iat[row, a_col + k] for k in range(n_terms)]
        b_terms = [df.iat[row, b_col + k] for k in range(n_terms)]

        return (a_terms, b_terms)

    def electron_properties(self, keV):
        """Return the Lorentz factor and wavelength of a beam electron.

        Parameters
        ----------
        keV : float
            Accelerating voltage in kilovolts.

        Returns
        -------
        tuple of float
            ``(gamma, wavelength)``: the Lorentz factor of the electron and
            its relativistically corrected wavelength (angstroms).
        """
        volts = keV * 1000
        speed_of_light = 2.9979 * 10**8  # m/s
        rest_mass = 9.109 * 10**-31  # kg
        kinetic_energy = volts * (1.6 * 10 ** (-19))  # J, rounded e

        rest_energy = rest_mass * speed_of_light**2
        # Speed from gamma = 1 + E / (m0 c^2); gamma is then recovered from it.
        speed = speed_of_light * math.sqrt(
            1 - (1 / (1 + (kinetic_energy / rest_energy))) ** 2
        )
        gamma = 1 / math.sqrt(1 - (speed / speed_of_light) ** 2)

        wavelength = 12.2639 / math.sqrt(volts + 0.97845 * 10**-6 * volts**2)
        return gamma, wavelength

    # method to build the electron scattering factor f(q) from the Gaussian parameters selected from the database (no integration into a TCS is performed)
    def calc_TCS(self, a_terms, b_terms, qmin, qmax, oxi_state):
        """Build the electron scattering factor f(q) from Gaussian parameters.

        Despite its name, this returns the scattering-factor *function*
        rather than an integrated total cross section. No integration over
        ``[qmin, qmax]`` is performed here. The empirical rescaling towards
        the NIST electron cross-section database
        (https://srdata.nist.gov/srd64/) is also not applied.

        With ``s = q / (4*pi)`` the returned function evaluates::

            f(q) = sum_i a_i * exp(-b_i * s**2) + 0.023934 * oxi_state / s**2

        The second (Mott-Bethe) term is the contribution of the ion's net
        charge. It is added only when ``oxi_state`` is non-zero, so neutral
        atoms reduce to the plain sum of Gaussians.

        Parameters
        ----------
        a_terms, b_terms : sequence of float
            Gaussian amplitudes ``a_i`` and exponents ``b_i``, paired by
            position (e.g. as returned by ``select_terms``). Any number of
            terms is accepted.
        qmin, qmax : float
            Scattering-vector range. Currently unused; kept for API
            compatibility.
        oxi_state : int or float
            Net charge of the species (0 for a neutral atom).

        Returns
        -------
        callable
            ``f(q)``, accepting a scalar or a NumPy array ``q``. With a
            non-zero ``oxi_state`` the charge term diverges at ``q == 0``.
        """
        # Snapshot the parameters so later mutation of the caller's lists
        # cannot change the returned function.
        terms = tuple(zip(a_terms, b_terms))

        def ff_total(q):
            s_squared = (q / (4 * np.pi)) ** 2
            f = sum(a * np.exp(-b * s_squared) for a, b in terms)
            if oxi_state:
                # Divide by s**2 = (q / (4*pi))**2, not by q / (4*pi)**2.
                f = f + 0.023934 * oxi_state / s_squared
            return f

        return ff_total

    def interpolated_TCS(self, el, partial_charge, keV, qmin, qmax):
        """Estimate the scattering quantity of ``el`` at a fractional charge.

        The neutral-atom value and the value for every tabulated ion of
        ``el`` are placed on the axis ``x = (Z - oxidation_state) / Z``. This
        is the relative electron count, so ``x = 1`` for the neutral atom. A
        straight line is least-squares fitted to ``ln(value)`` versus ``x``
        and evaluated at ``x = (Z - partial_charge) / Z``.

        The neutral-atom value is returned unchanged, with a logged warning,
        when ``el`` has no tabulated ions or when the interpolation fails.

        NOTE(review): ``calc_TCS`` in this file returns a callable f(q), not a
        number. ``math.log`` of it raises TypeError, so charged species
        currently always fall back to the neutral-atom callable. Confirm the
        intended return type before relying on the interpolation path.

        Parameters
        ----------
        el : str
            Element symbol.
        partial_charge : float
            Charge to interpolate to; 0 returns the neutral-atom result.
        keV : float
            Beam energy. Not used by the current implementation; kept for API
            compatibility.
        qmin, qmax : float
            Passed through to ``calc_TCS``.

        Returns
        -------
        The neutral-atom result of ``calc_TCS`` when no interpolation is
        performed, otherwise the interpolated value as a float.
        """
        # Module logger rather than logging.warning(): the latter configures
        # the root logger (basicConfig) when it has no handlers.
        logger = logging.getLogger(__name__)

        a_neutral, b_neutral = self.select_terms(el, OS=0)
        tcs_neutral = self.calc_TCS(a_neutral, b_neutral, qmin, qmax, 0)

        if partial_charge == 0:
            return tcs_neutral

        # ion rows of the same element, re-indexed from 0 for positional access
        ions = self.ion_df[self.ion_df["species"] == el].reset_index(drop=True)
        if ions.empty:
            logger.warning(
                "No tabulated data for ionic cross-sections of %s are available. "
                "Applying cross section of neutral atom.",
                el,
            )
            return tcs_neutral

        try:
            Z = ions.iat[0, 1]
            # the neutral atom is always the x = 1 point
            rel_e_counts = [1]
            # log-linearise so a straight-line fit can be used
            ln_values = [math.log(tcs_neutral)]
            for i in range(ions.shape[0]):
                oxi_state = ions.iat[i, 2]
                a_ion, b_ion = self.select_terms(el, OS=oxi_state)
                tcs_ion = self.calc_TCS(a_ion, b_ion, qmin, qmax, oxi_state)
                rel_e_counts.append((Z - oxi_state) / Z)
                ln_values.append(math.log(tcs_ion))

            # polyfit returns scalars, so math.exp never sees a 1-element array
            slope, intercept = np.polyfit(rel_e_counts, ln_values, 1)
            rel_e_ion = (Z - partial_charge) / Z
            return math.exp(slope * rel_e_ion + intercept)
        except Exception as err:
            # Best-effort: keep the historical fallback to the neutral atom,
            # but report the real reason instead of claiming missing data.
            logger.warning(
                "Could not interpolate the ionic cross-section of %s at charge %s (%r). "
                "Applying cross section of neutral atom.",
                el,
                partial_charge,
                err,
            )
            return tcs_neutral

    def partial_charges(self):
        """Estimate a partial charge for every element in ``self.struct``.

        Each element's formal oxidation state (taken from its first site) is
        scaled by the fractional ionic character of the bond,
        ``1 - exp(-((chi1 - chi2) / 2) ** 2)`` (Pauling's empirical form). The
        electronegativities used are those of the first two distinct elements
        in the structure.

        A structure with fewer than two distinct elements has no bonding
        partner to transfer charge to. Every element then gets a charge of 0.

        NOTE(review): for structures with more than two elements only the
        first two electronegativities enter the ionic character.

        Returns
        -------
        dict
            Element symbol -> partial charge.
        """
        # symbol -> (oxidation state, electronegativity), first occurrence order.
        # Fetch the species list once instead of once per site.
        records = {}
        for specie, site in zip(self.struct.species, self.struct.sites):
            symbol = specie.symbol
            if symbol not in records:
                records[symbol] = (site.oxi_state, specie.X)

        if len(records) < 2:
            ionic_char = 0.0
        else:
            (_, chi1), (_, chi2) = list(records.values())[:2]
            ionic_char = 1 - math.exp(-(((chi1 - chi2) / 2) ** 2))

        return {
            symbol: oxi_state * ionic_char
            for symbol, (oxi_state, _) in records.items()
        }

    def database_ionic_radii(self, oxidation_states=range(-2, 4)):
        """Collect tabulated ionic radii for every element in ``self.struct``.

        For each distinct element, in order of first appearance, and each
        oxidation state in ``oxidation_states``, the method queries
        ``Species(el, OS).ionic_radius`` and records every truthy result.

        Parameters
        ----------
        oxidation_states : iterable of int, optional
            Oxidation states to query. The default is -2 through +3, which
            is the historical behaviour.

        Returns
        -------
        list of tuple
            ``(element, oxidation_state, ionic_radius, site_index)`` records,
            where ``site_index`` is the first site of that element.
        """
        # materialise once so a one-shot iterator works for every element
        oxidation_states = tuple(oxidation_states)

        el_radii_list = []
        seen = set()
        # Silence warnings only while querying radii, instead of permanently
        # replacing the process-wide warnings filter.
        # NOTE(review): presumably this hides pymatgen warnings for untabulated
        # radii - confirm.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for idx, specie in enumerate(self.struct.species):
                el = specie.symbol
                if el in seen:
                    continue
                seen.add(el)
                for OS in oxidation_states:
                    ionic_radius = Species(el, OS).ionic_radius
                    if ionic_radius:
                        el_radii_list.append((el, OS, ionic_radius, idx))
        return el_radii_list

    # make list of all types of atoms, then in get_radius just pull the relevant one
    def interpolated_ionic_radii(self, partial_charges, database_radii):
        """Interpolate an effective radius for each element at its partial charge.

        For every distinct element of ``self.struct`` (in order of first
        appearance), fit a least-squares straight line through the point
        ``(0, atomic_radius)`` and all ``(oxidation_state, ionic_radius)``
        records for that element in ``database_radii``. Then evaluate the
        line at the element's partial charge.

        Parameters
        ----------
        partial_charges : dict
            Element symbol -> partial charge. It must contain every element
            in the structure; a missing symbol raises KeyError.
        database_radii : iterable of tuple
            ``(element, oxidation_state, ionic_radius, ...)`` records.

        Returns
        -------
        list of tuple
            ``(element, interpolated_radius)`` pairs.
        """
        # element -> (partial charge, index of its first site)
        first_occurrence = {}
        for site_index in range(len(self.struct)):
            symbol = self.struct.species[site_index].symbol
            if symbol not in first_occurrence:
                first_occurrence[symbol] = (partial_charges[symbol], site_index)

        radii = []
        for symbol, (charge, site_index) in first_occurrence.items():
            # the neutral atom anchors the fit at oxidation state 0
            charge_points = [0]
            radius_points = [self.struct.sites[site_index].specie.atomic_radius]
            for record in database_radii:
                if record[0] == symbol:
                    charge_points.append(record[1])
                    radius_points.append(record[2])

            fit = LinearRegression().fit(
                np.array(charge_points).reshape((-1, 1)), np.array(radius_points)
            )
            estimate = fit.coef_ * charge + fit.intercept_
            radii.append((symbol, estimate[0]))

        return radii
