# -*- coding: utf-8 -*-
# BioSTEAM: The Biorefinery Simulation and Techno-Economic Analysis Modules
# Copyright (C) 2020-2023, Yoel Cortes-Pena <yoelcortes@gmail.com>
# 
# This module is under the UIUC open-source license. See 
# github.com/BioSTEAMDevelopmentGroup/biosteam/blob/master/LICENSE.txt
# for license details.
"""
This module contains abstract classes for modeling stage-wise separations/reactions in unit operations.

"""
import thermosteam as tmo
from thermosteam.base.sparse import SparseVector, sum_sparse_vectors
from thermosteam import separations as sep
from numba import njit
import biosteam as bst
import flexsolve as flx
import numpy as np
import pandas as pd
from scipy.optimize import minimize, differential_evolution
from scipy.interpolate import LinearNDInterpolator, RBFInterpolator, PchipInterpolator, Akima1DInterpolator
from math import inf
from typing import Callable, Optional
from copy import copy
from .. import Unit
from .design_tools import MESH
from thermosteam import (
    equilibrium, VariableNode,
)

__all__ = (
    'SinglePhaseStage',
    'ReactivePhaseStage',
    'StageEquilibrium',
    'MultiStageEquilibrium',
    'PhasePartition',
)

# %% Utilities

@njit(cache=True)
def _vle_phi_K(vapor, liquid):
    """
    Return the vapor fraction and partition coefficients from phase flows.

    Parameters
    ----------
    vapor, liquid : 1d arrays
        Molar flows of each chemical in the vapor and liquid phases.

    Returns
    -------
    phi : float
        Molar vapor fraction, V / (V + L).
    K : 1d array
        Ratio of vapor to liquid mole fractions (y / x) for each chemical.

    """
    total_vapor = vapor.sum()
    total_liquid = liquid.sum()
    vapor_fraction = total_vapor / (total_vapor + total_liquid)
    y = vapor / total_vapor
    x = liquid / total_liquid
    K = y / x
    return vapor_fraction, K

def _get_specification(name, value):
    """
    Translate a named stage specification into a ``(B, Q, T, F)`` tuple.

    Only the entry matching `name` is assigned; the others are None.
    A 'Reflux' specification is stored as its reciprocal boil-up ratio, with
    a reflux of zero mapping to an infinite boil-up.

    Raises
    ------
    RuntimeError
        If `name` is not 'Duty', 'Reflux', 'Boilup', 'Temperature' or 'Flow'.

    """
    B = Q = T = F = None
    if name == 'Reflux':
        B = inf if value == 0 else 1 / value
    elif name == 'Boilup':
        B = value
    elif name == 'Duty':
        Q = value
    elif name == 'Temperature':
        T = value
    elif name == 'Flow':
        F = value
    else:
        raise RuntimeError(f"specification '{name}' not implemented for stage")
    return B, Q, T, F


# %% Single phase

class SinglePhaseStage(Unit):
    """
    Mix all inlets into one outlet of a single phase.

    The outlet is brought either to a specified temperature (`T`) or to the
    enthalpy of the mixed inlets plus a specified duty (`Q`); exactly one of
    the two must be given when the unit runs.

    Parameters
    ----------
    T : float, optional
        Outlet temperature [K].
    P : float, optional
        Outlet pressure [Pa]. If None, the pressure is left as set by mixing.
    Q : float, optional
        Duty [kJ/hr] added to the total inlet enthalpy.
    phase : str, optional
        Outlet phase. Defaults to the phase of the first inlet.

    """
    _N_ins = 2
    _N_outs = 1
    _ins_size_is_fixed = False
    
    @property
    def _energy_variable(self):
        # The outlet temperature is solved for only when it is not specified.
        if self.T is None: return 'T'
        else: return None
    
    def _init(self, T=None, P=None, Q=None, phase=None):
        self.T = T
        self.Q = Q
        self.P = P
        self.phase = phase
        
    def _mass_and_energy_balance_specifications(self):
        specs = [('phase', self.phase, '-')]
        if self.T is not None: 
            specs.append(
                ('T', self.T, 'K')
            )
        if self.Q is not None:
            specs.append(
                ('Q', self.Q, 'kJ/hr')
            )
        if self.P is not None:
            specs.append(
                ('P', self.P, 'Pa')
            )
        return self.line, specs
        
    def _run(self):
        outlet = self.outs[0]
        outlet.mix_from(self.ins, energy_balance=False)
        if self.P is not None: outlet.P = self.P
        if self.phase is None: 
            outlet.phase = self.ins[0].phase
        else:
            outlet.phase = self.phase
        if self.T is None:
            if self.Q is None:
                raise RuntimeError('must specify either Q or T')
            else:
                outlet.H = sum([i.H for i in self.ins], self.Q)
        elif self.Q is not None:
            raise RuntimeError('must specify either Q or T; not both')
        else:
            outlet.T = self.T

    def _get_energy_departure_coefficient(self, stream):
        # Downstream units see this stage's temperature departure through
        # the outlet heat capacity (only when T is being solved for).
        if self.T is None: return (self, -stream.C)
    
    def _create_energy_departure_equations(self):
        """
        Return the linearized energy balance for the outlet temperature.

        Left-hand side: C_out*dT_out - sum(C_in*dT_in), built from this
        stage's outlet heat capacity and the inlets' contributions.
        Right-hand side: Q + H_in - H_out, the same convention used by the
        other stages in this module.
        """
        if self.T is not None: return []
        # Ll: C1dT1 - Ce2*dT2 - Cr0*dT0 - hv2*L2*dB2 = Q1 - H_out + H_in
        # gl: hV1*L1*dB1 - hv2*L2*dB2 - Ce2*dT2 - Cr0*dT0 = Q1 + H_in - H_out
        outlet = self.outs[0]
        coeff = {self: outlet.C}
        for i in self.ins: i._update_energy_departure_coefficient(coeff)
        Q = 0. if self.Q is None else self.Q
        return [(coeff, Q + sum([i.H for i in self.ins]) - outlet.H)]
        
    def _create_material_balance_equations(self, composition_sensitive):
        """Return the overall (per-chemical) material balance equations."""
        outlet = self.outs[0]
        fresh_inlets, process_inlets, equations = self._begin_equations(composition_sensitive)
        ones = np.ones(self.chemicals.size)
        minus_ones = -ones
        zeros = np.zeros(self.chemicals.size)
        
        # Overall flows
        eq_overall = {outlet: ones}
        for i in process_inlets: 
            # A stream that is both outlet and inlet cancels out (1 - 1 = 0).
            if i in eq_overall:
                del eq_overall[i]
            else:
                eq_overall[i] = minus_ones
        equations.append(
            (eq_overall, sum([i.mol for i in fresh_inlets], zeros))
        )
        return equations

    def _update_energy_variable(self, departure):
        for i in self.outs: i.T += departure
        
    def _update_nonlinearities(self): pass
    
    @property
    def equation_node_names(self): 
        if self._energy_variable is None:
            return (
                'overall_material_balance_node', 
            )
        else:
            return (
                'overall_material_balance_node', 
                'energy_balance_node',
            )
    
    def initialize_overall_material_balance_node(self):
        self.overall_material_balance_node.set_equations(
            outputs=[i.F_node for i in self.outs],
            inputs=[j for i in self.ins if (j:=i.F_node)],
        )
        
    def initialize_energy_balance_node(self):
        self.energy_balance_node.set_equations(
            inputs=(
                self.T_node, 
                *[i.T_node for i in (*self.ins, *self.outs)],
                *[i.F_node for i in (*self.ins, *self.outs)],
                *[j for i in self.ins if (j:=i.E_node)],
            ),
            outputs=[
                j for i in self.outs if (j:=i.E_node)
            ],
        )
    
    @property
    def T_node(self):
        # Lazily created and cached on first access.
        if hasattr(self, '_T_node'): return self._T_node
        self._T_node = var = VariableNode(f"{self.node_tag}.T", lambda: self.T)
        return var 
        
    @property
    def E_node(self):
        if self._energy_variable is None:
            return None
        else:
            return self.T_node

    def _collect_edge_errors(self):
        """Return per-chemical overall material balance errors for each outlet edge."""
        equation_name = self.overall_material_balance_node.name
        outs = self.outs
        IDs = self.chemicals.IDs
        results = []
        error = sum([i.mol for i in outs]) - sum([i.mol for i in self.ins])
        for outlet in outs:
            for j, ID in enumerate(IDs):
                index = (equation_name, outlet.F_node.name, ID)
                results.append((index, error[j]))
        return results # list[tuple[tuple[equation_name, variable_name, chemical_name | '-'], value]]

    def _collect_equation_errors(self):
        """Return absolute residuals of the material and (if solved) energy balances."""
        equation_name = self.overall_material_balance_node.name
        outs = self.outs
        results = []
        error = np.abs(sum([i.mol for i in outs]) - sum([i.mol for i in self.ins])).sum()
        index = equation_name
        results.append((index, error))
        
        if self._energy_variable is not None:
            equation_name = self.energy_balance_node.name
            # The duty is part of the balance: H_out = H_in + Q (see _run).
            Q = 0. if self.Q is None else self.Q
            error = sum([i.H for i in outs]) - sum([i.H for i in self.ins]) - Q
            results.append((equation_name, np.abs(error)))
        
        return results # list[tuple[equation_name, value]]
    

class ReactivePhaseStage(bst.Unit): # Does not include VLE
    """
    Single-phase stage with a reaction (no phase equilibrium).

    The first inlet is copied to the outlet and reacted, either adiabatically
    with duty `Q` (when `T` is None) or at the fixed temperature `T`.

    Parameters
    ----------
    reaction :
        Reaction object; called on the outlet or used via
        `adiabatic_reaction`.
    T : float, optional
        Outlet temperature [K]. If None, the temperature is solved from the
        energy balance.
    P : float, optional
        Outlet pressure [Pa].
    Q : float, optional
        Duty [kJ/hr] used for the adiabatic reaction. Defaults to 0.
    phase : str, optional
        Outlet phase.

    """
    _N_outs = _N_ins = 1
    _ins_size_is_fixed = False
    
    @property
    def equation_node_names(self): 
        if self._energy_variable is None:
            return (
                'overall_material_balance_node',
                'reaction_phenomenode',
            )
        else:
            return (
                'overall_material_balance_node', 
                'reaction_phenomenode',
                'energy_balance_node',
            )
    
    @property 
    def _energy_variable(self):
        # The outlet temperature is solved for only when it is not specified.
        if self.T is None: return 'T'
        return None
    
    def _init(self, reaction, T=None, P=None, Q=0, phase=None):
        self.reaction = reaction
        self.T = T
        self.P = P
        self.Q = Q
        self.phase = phase
        
    _mass_and_energy_balance_specifications = SinglePhaseStage._mass_and_energy_balance_specifications
        
    def _run(self):
        feed = self.ins[0]
        outlet, = self.outs
        outlet.copy_like(feed)
        if self.P is not None: outlet.P = self.P
        if self.phase is not None: outlet.phase = self.phase
        if self.T is None: 
            self.reaction.adiabatic_reaction(outlet, Q=self.Q)
        else:
            self.reaction(outlet)
            outlet.T = self.T
        # Net molar change caused by the reaction.
        self.dmol = outlet.mol - feed.mol
        
    def _create_material_balance_equations(self, composition_sensitive=False):
        """Return the overall material balance, with reaction flows on the right-hand side."""
        product, = self.outs
        n = self.chemicals.size
        ones = np.ones(n)
        minus_ones = -ones
        fresh_inlets, process_inlets, equations = self._begin_equations(composition_sensitive)
        # Overall flows
        predetermined_flow = SparseVector.from_dict(sum_sparse_vectors([i.mol for i in fresh_inlets]), size=n)
        rhs = predetermined_flow + self.dmol
        eq_overall = {product: ones}
        for i in process_inlets:
            # A stream that is both outlet and inlet cancels out (1 - 1 = 0),
            # consistent with SinglePhaseStage and StageEquilibrium.
            if i in eq_overall:
                del eq_overall[i]
            else:
                eq_overall[i] = minus_ones
        equations.append(
            (eq_overall, rhs)
        )
        return equations
    
    def _update_energy_variable(self, departure):
        self.outs[0].T += departure
    
    def _get_energy_departure_coefficient(self, stream):
        if self.T is not None: return
        return (self, -stream.C)
    
    def _create_energy_departure_equations(self):
        """
        Return the linearized energy balance for the outlet temperature.

        The right-hand side is Q - Hnet, matching the reactive branch of
        StageEquilibrium and the duty used by `adiabatic_reaction` in `_run`.
        """
        if self.T is not None: return []
        # Ll: C1dT1 - Ce2*dT2 - Cr0*dT0 - hv2*L2*dB2 = Q1 - H_out + H_in
        # gl: hV1*L1*dB1 - hv2*L2*dB2 - Ce2*dT2 - Cr0*dT0 = Q1 + H_in - H_out
        outlet = self.outs[0]
        coeff = {self: outlet.C}
        for i in self.ins: i._update_energy_departure_coefficient(coeff)
        Q = self.Q or 0.
        return [(coeff, Q - self.Hnet)]
    
    def _update_nonlinearities(self):
        # Relax the reaction flows between iterations.
        f = PhasePartition.dmol_relaxation_factor
        old = self.dmol
        new = self.reaction.conversion(self.ins[0])
        self.dmol = f * old + (1 - f) * new
    
    def initialize_reaction_phenomenode(self):
        self.reaction_phenomenode.set_equations(
            inputs=[j for i in self.ins if (j:=i.F_node)],
            outputs=[self.R_node],
        )
    
    def initialize_overall_material_balance_node(self):
        self.overall_material_balance_node.set_equations(
            inputs=[j for i in self.ins if (j:=i.F_node)] + [self.R_node],
            outputs=[i.F_node for i in self.outs],
        )
            
    def initialize_energy_balance_node(self):
        if self.T is None:
            self.energy_balance_node.set_equations(
                inputs=(
                    self.T_node, 
                    *[i.T_node for i in (*self.ins, *self.outs)],
                    *[i.F_node for i in (*self.ins, *self.outs)],
                    *[j for i in self.ins if (j:=i.E_node)]
                ),
                outputs=[j for i in self.outs if (j:=i.E_node)],
            )
    
    @property
    def R_node(self):
        # Lazily created and cached variable node for the reaction flows.
        if hasattr(self, '_R_node'): return self._R_node
        self._R_node = var = VariableNode(f"{self.node_tag}.R", lambda: self.dmol)
        return var 
    
    @property
    def T_node(self):
        # None when the temperature is specified (not a variable).
        if hasattr(self, '_T_node'): return self._T_node
        if self.T is None: 
            var = VariableNode(f"{self.node_tag}.T", lambda: self.T)
        else:
            var = None
        self._T_node = var
        return var 
    
    def get_E_node(self, stream):
        if self.T is None:
            return self.E_node
        else:
            return None
    
    @property
    def E_node(self):
        if hasattr(self, '_E_node'): return self._E_node
        if self._energy_variable is None:
            var = None
        else:
            var = self.T_node
        self._E_node = var
        return var


# %% Two phases

class StageEquilibrium(Unit):
    """
    Equilibrium stage: mixes its feeds, partitions the mixture between two
    phases, and optionally splits side draws off either outlet.

    Parameters
    ----------
    phases : tuple[str, str]
        Top and bottom phases, e.g. ('g', 'l') for vapor-liquid stages or
        ('L', 'l') for liquid-liquid stages.
    partition_data : dict, optional
        Passed to the internal :class:`PhasePartition`.
    top_split, bottom_split : float, optional
        Fraction of the top/bottom partition outlet sent to a side draw.
        Any truthy value adds one outlet for that side draw.
    B, Q, T, F, P : optional
        Stage specification (phase ratio, duty [kJ/hr], temperature [K],
        flow, pressure [Pa]). See `set_specification`; F is not implemented.
    top_chemical : str, optional
        Passed to the internal :class:`PhasePartition`.
    reaction : optional
        Reaction stored on the internal :class:`PhasePartition`.

    """
    _N_ins = 0
    _N_outs = 2
    _ins_size_is_fixed = False
    _outs_size_is_fixed = False
    auxiliary_unit_names = ('partition', 'mixer', 'splitters')
    
    def __init__(self, ID='', ins=None, outs=(), thermo=None, *, 
            phases, partition_data=None, top_split=0, bottom_split=0,
            B=None, Q=None, T=None, top_chemical=None, F=None, P=None, 
            reaction=None,
        ):
        # One extra outlet per side draw. Use truthiness (as the `if top_split`
        # checks below do) rather than int(), which would truncate fractional
        # splits such as 0.5 to 0 and leave no outlet for the side draw.
        self._N_outs = 2 + bool(top_split) + bool(bottom_split)
        self.phases = phases
        Unit.__init__(self, ID, ins, outs, thermo)
        mixer = self.auxiliary(
            'mixer', bst.Mixer, ins=self.ins, 
        )
        mixer.outs[0].phases = phases
        partition = self.auxiliary(
            'partition', PhasePartition, ins=mixer-0, phases=phases,
            partition_data=partition_data, top_chemical=top_chemical,
            outs=(
                None if top_split else self.outs[0],
                None if bottom_split else self.outs[1],
            ),
        )
        self.reaction = reaction
        self.top_split = top_split
        self.bottom_split = bottom_split
        self.splitters = []
        if top_split:
            self.auxiliary(
                'splitters', bst.Splitter, 
                partition-0, [self.outs[2], self.outs[0]],
                split=top_split,
            )
        if bottom_split:
            self.auxiliary(
                'splitters', bst.Splitter, 
                partition-1, [self.outs[-1], self.outs[1]],
                split=bottom_split, 
            )
        self.set_specification(B, Q, T, F, P)
    
    def _mass_and_energy_balance_specifications(self):
        specs = []
        if self.phases is not None:
            specs.append(
                ('Phases', self.phases, '-')
            )
        if self.T_specification is not None: 
            specs.append(
                ('T', self.T, 'K')
            )
        if self.B_specification is not None: 
            specs.append(
                ('Vapor to liquid ratio', self.B, 'by mol')
            )
        if self.Q is not None:
            specs.append(
                ('Q', self.Q, 'kJ/hr')
            )
        if self.P is not None:
            specs.append(
                ('P', self.P, 'Pa')
            )
        return self.line, specs
    
    @property
    def composition_sensitive(self):
        return self.phases == ('L', 'l')
    
    @property
    def Q(self):
        return self.partition.Q
    @Q.setter
    def Q(self, Q):
        self.partition.Q = Q
    
    @property
    def B(self):
        return self.partition.B
    @B.setter
    def B(self, B):
        self.partition.B = B
    
    @property
    def B_specification(self):
        return self.partition.B_specification
    @B_specification.setter
    def B_specification(self, B_specification):
        self.partition.B_specification = B_specification
    
    @property
    def T(self):
        return self.partition.T
    @T.setter
    def T(self, T):
        self.partition.T = T
        # Clearing the temperature must not assign None to the outlets.
        if T is not None:
            for i in self.partition.outs: i.T = T
    
    @property
    def P(self):
        return self.partition.P
    @P.setter
    def P(self, P):
        self.partition.P = P
        if P is not None:
            for i in self.partition.outs: i.P = P
    
    @property
    def T_specification(self):
        return self.partition.T_specification
    @T_specification.setter
    def T_specification(self, T):
        self.partition.T_specification = T
        if T is not None:
            for i in self.partition.outs: i.T = T
    
    @property
    def K(self):
        return self.partition.K
    @K.setter
    def K(self, K):
        self.partition.K = K
    
    @property
    def reaction(self):
        return self.partition.reaction
    @reaction.setter
    def reaction(self, reaction):
        self.partition.reaction = reaction
    
    def _update_auxiliaries(self):
        for i in self.splitters: i.ins[0].mix_from(i.outs, energy_balance=False)
        self.mixer.outs[0].mix_from(self.ins, energy_balance=False)
    
    def add_feed(self, stream):
        self.ins.append(stream)
        self.mixer.ins.append(
            self.auxlet(
                stream
            )
        )
        
    def set_specification(self, B, Q, T, F, P):
        """
        Set the stage specification and select the energy variable.

        With no B, Q, T or F given, the stage defaults to adiabatic (Q = 0).
        If B or T is given, the energy balance is not solved. Otherwise the
        phase ratio ('B') is solved for VLE stages, and the temperature ('T')
        for other stages.

        Raises
        ------
        NotImplementedError
            If a flow specification `F` is given.
        """
        if B is None and Q is None and T is None and F is None: Q = 0.
        partition = self.partition
        partition.B_specification = partition.B = B
        partition.T_specification = partition.T = T
        partition.P = P
        partition.F_specification = F
        if F is not None:
            raise NotImplementedError('F specification not implemented in BioSTEAM yet')
        if T is not None: 
            for i in partition.outs: i.T = T
        if P is not None: 
            for i in partition.outs: i.P = P
        partition.Q = Q
        if not (B is None and T is None): 
            self._energy_variable = None
        elif self.phases == ('g', 'l'):
            self._energy_variable = 'B'
        else:
            self._energy_variable = 'T'
    
    @property
    def extract(self):
        return self.outs[0]
    @property
    def raffinate(self):
        return self.outs[1]
    @property
    def extract_side_draw(self):
        if self.top_split: return self.outs[2]
    @property
    def raffinate_side_draw(self):
        if self.bottom_split: return self.outs[-1]
    
    @property
    def vapor(self):
        return self.outs[0]
    @property
    def liquid(self):
        return self.outs[1]
    @property
    def vapor_side_draw(self):
        if self.top_split: return self.outs[2]
    @property
    def liquid_side_draw(self):
        if self.bottom_split: return self.outs[-1]
    @property
    def top_side_draw(self):
        if self.top_split: return self.outs[2]
    @property
    def bottom_side_draw(self):
        if self.bottom_split: return self.outs[-1]
    
    def _run(self):
        if self.T_specification is None:
            self.mixer._run()
        else:
            mix = self.mixer.outs[0]
            mix.phase = 'l'
            mix.mol = sum([i.mol for i in self.ins])
            mix.T = self.T_specification
        self.partition._run()
        for i in self.splitters: i._run()
        self._update_separation_factors()
    
    def _vapor_molar_enthalpy(self):
        """
        Return the molar enthalpy of the partition's vapor outlet.

        If the vapor outlet is empty, the liquid outlet's composition is
        evaluated as a gas at the liquid's conditions instead.

        Raises
        ------
        RuntimeError
            If both partition outlets are empty.
        """
        vapor, liquid = self.partition.outs
        if not vapor.isempty(): return vapor.h
        if liquid.isempty(): raise RuntimeError('empty stage or tray')
        vapor_like = liquid.copy()
        vapor_like.phase = 'g'
        return vapor_like.h
    
    def _get_energy_departure_coefficient(self, stream):
        """
        Return ``(self, coefficient)`` linking `stream` to this stage's
        energy variable, or None if the stream does not depend on it.

        Raises
        ------
        ValueError
            If the energy variable is 'B' and a gas `stream` is not one of
            this stage's top outlets.
        """
        energy_variable = self._energy_variable
        if energy_variable is None: return None
        if energy_variable == 'B':
            if stream.phase != 'g': return None
            liquid = self.partition.outs[1]
            dHdB = self._vapor_molar_enthalpy() * liquid.F_mol
            if self.top_split:
                split = self.top_split
                if stream.imol is self.top_side_draw.imol:
                    return (self, -dHdB * split)
                elif stream.imol is self.outs[0].imol:
                    return (self, -dHdB * (1 - split))
                else:
                    raise ValueError('stream must be an outlet')
            elif stream.imol is self.outs[0].imol:
                return (self, -dHdB)
            else:
                raise ValueError('stream must be an outlet')
        else:
            return (self, -stream.C)
    
    def correct_mass_balance(self):
        """Scale outlet flows so the outlet mass matches the inlet mass."""
        F_out = sum([i.F_mass for i in self.outs])
        F_in = sum([i.F_mass for i in self.ins])
        if F_in == 0: 
            for i in self.outs: i.empty()
        else:
            f = F_in / F_out
            for i in self.outs: i.mol *= f
    
    def _create_energy_departure_equations(self):
        """
        Return the linearized energy balance for this stage's energy variable.

        The right-hand side is Q - Hnet for reactive stages and
        Q + H_in - H_out otherwise.
        """
        energy_variable = self._energy_variable
        if energy_variable is None: return []
        if energy_variable == 'B':
            liquid = self.partition.outs[1]
            coeff = {self: self._vapor_molar_enthalpy() * liquid.F_mol}
        else:
            coeff = {self: sum([i.C for i in self.partition.outs])}
        for i in self.ins: i._update_energy_departure_coefficient(coeff)
        if self.reaction:
            dH = self.Q - self.Hnet
        else:
            dH = self.Q + self.H_in - self.H_out
        return [(coeff, dH)]
    
    def _create_material_balance_equations(self, composition_sensitive):
        """
        Return the overall material balance, the top/bottom separation
        equations (via separation factors S), and any side-draw split
        equations.
        """
        partition = self.partition
        chemicals = self.chemicals
        pIDs = partition.IDs
        IDs = chemicals.IDs
        self._update_separation_factors()
        if pIDs != IDs and pIDs:
            # Expand separation factors from the partition's chemicals to all
            # chemicals; chemicals outside the partition use the current
            # outlet split (or the partition_data top/bottom assignments).
            partition.IDs = IDs
            S = np.ones(chemicals.size)
            index = [IDs.index(i) for i in pIDs]
            for i, j in zip(index, self.S): S[i] = j
            pIDs = set(pIDs)
            data = self.partition.partition_data
            if data:
                top = data.get('extract_chemicals') or data.get('top_chemicals', ())
                bottom = data.get('raffinate_chemicals') or data.get('bottom_chemicals', ())
                for i in top: S[chemicals.index(i)] = inf
                for i in bottom: S[chemicals.index(i)] = 0
                pIDs.update(top)
                pIDs.update(bottom)
            for index, ID in enumerate(IDs):
                if ID in pIDs: continue
                top = partition.outs[0].mol[index]
                bottom = partition.outs[1].mol[index]
                if top:
                    if bottom:
                        S[index] =  top / bottom
                    else:
                        S[index] =  inf
                else:
                    S[index] =  0
        else:
            S = self.S.copy()
        top_split = self.top_split
        bottom_split = self.bottom_split
        fresh_inlets, process_inlets, equations = self._begin_equations(composition_sensitive)
        top, bottom, *_ = self.outs
        top_side_draw = self.top_side_draw
        bottom_side_draw = self.bottom_side_draw
        N = self.chemicals.size
        ones = np.ones(N)
        minus_ones = -ones
        zeros = np.zeros(N)
        
        # Overall flows. A stream that is both outlet and process inlet
        # cancels out (1 - 1 = 0) in both the reactive and non-reactive cases.
        eq_overall = {i: ones for i in self.outs}
        for i in process_inlets:
            if i in eq_overall: del eq_overall[i]
            else: eq_overall[i] = minus_ones
        if self.reaction: # Reactive liquid
            predetermined_flow = SparseVector.from_dict(sum_sparse_vectors([i.mol for i in fresh_inlets]), size=N)
            rhs = predetermined_flow + self.partition.dmol
        else:
            rhs = sum([i.mol for i in fresh_inlets], zeros)
        equations.append(
            (eq_overall, rhs)
        )
        
        # Top to bottom flows
        eq_outs = {}
        infmask = ~np.isfinite(S)
        S[infmask] = 1
        if top_split == 1:
            if bottom_split == 1:
                eq_outs[top_side_draw] = -ones
                eq_outs[bottom_side_draw] = S
            else:
                eq_outs[top_side_draw] = -ones * (1 - bottom_split)
                eq_outs[bottom] = S
        elif bottom_split == 1:
            eq_outs[top] = coef = -ones
            eq_outs[bottom_side_draw] = S * (1 - top_split) 
            coef[infmask] = 0
        else:
            eq_outs[top] = coef = -ones * (1 - bottom_split)
            eq_outs[bottom] = S * (1 - top_split) 
            coef[infmask] = 0
        equations.append(
            (eq_outs, zeros)
        )
        # Top split flows
        if top_side_draw:
            if top_split == 1:
                eq_top_split = {
                    top: ones,
                }
            else:
                eq_top_split = {
                    top_side_draw: ones,
                    top: -top_split / (1 - top_split),
                }
            equations.append(
                (eq_top_split, zeros)
            )
        # Bottom split flows
        if bottom_side_draw:
            if bottom_split == 1:
                eq_bottom_split = {
                    bottom: ones,
                }
            else:
                eq_bottom_split = {
                    bottom_side_draw: ones,
                    bottom: -bottom_split / (1 - bottom_split),
                }
            equations.append(
                (eq_bottom_split, zeros)
            )
        return equations
    
    def _update_energy_variable(self, departure):
        phases = self.phases
        energy_variable = self._energy_variable
        if energy_variable == 'B':
            partition = self.partition
            top, bottom = partition.outs
            IDs = partition.IDs
            B = top.imol[IDs].sum() / bottom.imol[IDs].sum()
            self.B = B + departure
        elif phases == ('L', 'l'):
            self.T = T = self.T + departure
            for i in self.outs: i.T = T
        else:
            raise RuntimeError('invalid phases')
            
    def _update_composition_parameters(self):
        partition = self.partition
        data = partition.partition_data
        # Partition coefficients given explicitly are not recomputed.
        if data and 'K' in data: return
        partition._run_decoupled_Kgamma()
    
    def _update_net_flow_parameters(self):
        self.partition._run_decoupled_B()
    
    def _update_nonlinearities(self):
        self._update_equilibrium_variables()
        self._update_reaction_conversion()
    
    def _update_equilibrium_variables(self):
        phases = self.phases
        if phases == ('g', 'l'):
            partition = self.partition
            partition._run_decoupled_KTvle()
            T = partition.T
            for i in (*partition.outs, *self.outs): i.T = T
        elif phases == ('L', 'l'):
            pass
            # self.partition._run_lle(single_loop=True)
        else:
            raise NotImplementedError(f'K for phases {phases} is not yet implemented')
        
    def _update_reaction_conversion(self):    
        if self.reaction and self.phases == ('g', 'l'):
            self.partition._run_decoupled_reaction()
        
    def _init_separation_factors(self):
        B_spec = self.B_specification
        if B_spec == 0:
            self.S = 0 * self.K
        elif B_spec == inf:
            self.S = self.K.copy()
            self.S[:] = inf
        elif self.B is None:
            self.S = np.ones(self.chemicals.size)
        else:
            self.S = (self.B * self.K)
        
    def _update_separation_factors(self, f=None):
        """
        Update separation factors S = K * B.

        For non-LLE stages, S is relaxed geometrically toward its previous
        value using factor `f` (defaults to the partition's
        S_relaxation_factor; 0 means no relaxation).
        """
        if not hasattr(self, 'S'): self._init_separation_factors()
        if self.B is None or self.B == inf or self.B == 0: return
        K = self.K
        if K.size != self.S.size: 
            self._init_separation_factors()
            return
        S = K * self.B
        if self.phases == ('L', 'l'):
            self.S = S
        else:
            if f is None: f = self.partition.S_relaxation_factor
            if f == 0:
                self.S = S
            else:
                self.S = np.exp((1 - f) * np.log(S) + f * np.log(self.S))
    
    @property
    def equation_node_names(self): 
        material_balances = (
            'overall_material_balance_node', 
            'separation_material_balance_node',
        )
        if self.phases == ('g', 'l'):
            phenomenode = 'vle_phenomenode'
        else: # Assume LLE
            phenomenode = 'lle_phenomenode'
        if self._energy_variable is None:
            return (
                *material_balances,
                phenomenode,
            )
        else:
            return (
                *material_balances,
                'energy_balance_node',
                phenomenode,
            )
    
    @property
    def phenomenode(self):
        return self.vle_phenomenode if self.phases == ('g', 'l') else self.lle_phenomenode
    
    def initialize_overall_material_balance_node(self):
        self.overall_material_balance_node.set_equations(
            inputs=[j for i in self.ins if (j:=i.F_node)],
            outputs=[i.F_node for i in self.outs],
        )
    
    def initialize_separation_material_balance_node(self):
        self.separation_material_balance_node.set_equations(
            outputs=[i.F_node for i in self.outs],
            inputs=[self.K_node, self.Phi_node],
        )
        
    def initialize_lle_phenomenode(self):
        intermediates = [
            i.F_node for i in self.outs 
            if hasattr(i.sink, 'lle_phenomenode')
        ]
        self.lle_phenomenode.set_equations(
            # Skip inlets without a flow node, as the other initializers do.
            inputs=[self.T_node, *[j for i in self.ins if (j:=i.F_node)]],
            outputs=[self.K_node, self.Phi_node, *intermediates],
            tracked_outputs=[self.K_node, self.Phi_node],
        )
    
    def initialize_vle_phenomenode(self):
        if self.T_specification is not None:
            self.vle_phenomenode.set_equations(
                inputs=[self.T_node, *[i.F_node for i in self.outs]],
                outputs=[self.K_node, self.Phi_node],
            )
        else:
            self.vle_phenomenode.set_equations(
                inputs=[i.F_node for i in self.outs if i.phase == 'l'],
                outputs=[self.T_node, self.K_node],
            )
            
    def initialize_energy_balance_node(self):
        self.energy_balance_node.set_equations(
            inputs=(
                self.T_node, 
                *[i.T_node for i in (*self.ins, *self.outs)],
                *[i.F_node for i in (*self.ins, *self.outs)],
                *[j for i in self.ins if (j:=i.E_node)]
            ),
            outputs=[j for i in self.outs if (j:=i.E_node)],
        )
        
    @property
    def K_node(self):
        # None when K is fixed (B specified as 0/inf, or K given in partition_data).
        if hasattr(self, '_K_node'): return self._K_node
        partition_data = self.partition.partition_data
        if self.B_specification == 0 or self.B_specification == np.inf or (partition_data and 'K' in partition_data):
            var = None
        else:
            var = VariableNode(f"{self.node_tag}.K", lambda: self.K)
        self._K_node = var
        return var
    
    @property
    def T_node(self):
        if hasattr(self, '_T_node'): return self._T_node
        if self.T_specification is not None: 
            var = None
        else:
            var = VariableNode(f"{self.node_tag}.T", lambda: self.T)
        self._T_node = var
        return var 
    
    @property
    def Phi_node(self):
        if hasattr(self, '_Phi_node'): return self._Phi_node
        if self.phases == ('g', 'l'):
            if self.B_specification is not None:
                self._Phi_node = var = None
            else:
                self._Phi_node = var = VariableNode(f"{self.node_tag}.Phi", lambda: self.B)
        else:
            self._Phi_node = var = VariableNode(f"{self.node_tag}.Phi", lambda: self.B)
        return var
    
    def get_E_node(self, stream):
        if self.phases == ('g', 'l') and stream.phase != 'g':
            return None
        else:
            return self.E_node
    
    @property
    def E_node(self):
        if hasattr(self, '_E_node'): return self._E_node
        if self._energy_variable is None:
            var = None
        elif self.phases == ('g', 'l'):
            var = self.Phi_node
        else:
            var = self.T_node
        self._E_node = var
        return var
    
    def _collect_edge_errors(self):
        """Return the total overall material balance error for each outlet edge."""
        equation_name = self.overall_material_balance_node.name
        outs = self.outs
        results = []
        error = np.abs(sum([i.mol for i in outs]) - sum([i.mol for i in self.ins])).sum()
        for outlet in outs:
            index = (equation_name, outlet.F_node.name)
            results.append((index, error))
        return results # list[tuple[tuple[equation_name, variable_name], value]]

    def _collect_equation_errors(self):
        """
        Return residuals of the overall material balance, the separation
        balance, the phase-equilibrium equations and (if solved) the energy
        balance.
        """
        equation_name = self.overall_material_balance_node.name
        outs = self.outs
        results = []
        flows_out = sum([i.mol for i in outs])
        error = np.abs(flows_out - sum([i.mol for i in self.ins])).sum() / flows_out.sum()
        results.append((equation_name, error))
        
        equation_name = self.separation_material_balance_node.name
        S = (self.K * self.B)
        flows_by_phase = {i.phase: 0 for i in outs}
        for i in outs: flows_by_phase[i.phase] += i.mol
        top, bottom = flows_by_phase.values()
        expected = S * bottom
        actual = top
        error = np.abs(expected - actual).sum() / (top + bottom).sum()
        results.append((equation_name, error))
        
        ms = bst.MultiStream.sum(self.outs, conserve_phases=True)
        if self.phases == ('g', 'l'):
            equation_name = self.vle_phenomenode.name
            if self.T_specification is not None:
                ms.vle(T=self.T_specification, P=self.P)
                gas = ms.imol['g']
                liq = ms.imol['l']
                B = gas.sum() / liq.sum()
                K = gas / (liq * B)
                # Compare K on the same log1p scale on both sides.
                expected = np.array([*np.log1p(K), B])
                actual = np.array([*np.log1p(self.K), self.B])
            else:
                bp = ms['l'].bubble_point_at_P()
                expected = np.array([*np.log1p(bp.K), bp.T])
                actual = np.array([*np.log1p(self.K), self.T])
        else:
            equation_name = self.lle_phenomenode.name
            Gamma = self.thermo.Gamma(ms.lle_chemicals)
            IDs = [i.ID for i in ms.lle_chemicals]
            x_liquid = ms.imol['l', IDs]
            x_liquid /= x_liquid.sum()
            x_LIQUID = ms.imol['L', IDs]
            x_LIQUID /= x_LIQUID.sum()
            K = Gamma(x=x_liquid, T=self.T) / Gamma(x=x_LIQUID, T=self.T)
            z = ms.imol[IDs]
            z /= z.sum()
            phi = tmo.equilibrium.phase_fraction(z, K, 0.5)
            if phi == 1: phi = 1 - 1e-16
            expected = np.array([*np.log1p(K), phi / (1. - phi)])
            actual = np.array([*np.log1p(self.K), self.B])
        error = np.abs(expected - actual).sum()
        results.append((equation_name, error))
        
        if self._energy_variable is not None:
            equation_name = self.energy_balance_node.name
            # Same balance as _create_energy_departure_equations:
            # residual is -(Q - Hnet) or -(Q + H_in - H_out).
            Q = self.Q or 0.
            if self.reaction:
                residual = self.Hnet - Q
            else:
                residual = sum([i.H for i in outs]) - sum([i.H for i in self.ins]) - Q
            error = residual / sum([i.C for i in outs])
            results.append((equation_name, np.abs(error)))
            
        return results # list[tuple[equation_name, value]]
    
    
# %%

class PhasePartition(Unit):
    """
    Split a single feed into two phases in equilibrium.

    The outlets are ``[top, bottom]``; their phases are set from `phases`
    (e.g. ``('g', 'l')`` for vapor-liquid or ``('L', 'l')`` for liquid-liquid
    equilibrium). Equilibrium variables are stored as attributes so that
    stage-wise algorithms can update each phenomenon separately:

    * `IDs` -- chemicals for which per-chemical arrays are stored.
    * `K` -- partition coefficients (top over bottom composition) for `IDs`.
    * `B` -- ratio of top to bottom molar flow.
    * `T` -- equilibrium temperature.
    * `gamma_y`, `fgas` -- top-phase activity/fugacity coefficients, when
      computed.
    * `dmol` -- molar change due to `reaction`, when one is given.

    Class attributes named ``<variable>_relaxation_factor`` (default 0) weight
    the previous value where a relaxed update is used:
    ``(1 - f) * computed + f * previous``.
    """
    _N_ins = 1
    _N_outs = 2
    #: Forwarded to `sep.partition` when partitioning with constant K values.
    strict_infeasibility_check = False
    dmol_relaxation_factor = 0
    S_relaxation_factor = 0
    B_relaxation_factor = 0
    K_relaxation_factor = 0
    T_relaxation_factor = 0
    F_relaxation_factor = 0
    gamma_y_relaxation_factor = 0
    fgas_relaxation_factor = 0
    
    def _init(self, phases, partition_data, top_chemical=None, reaction=None):
        """
        Parameters
        ----------
        phases : tuple[str, str]
            Phases of the top and bottom outlets, respectively.
        partition_data : dict or None
            If it contains 'K' (aligned with 'IDs'), these constant partition
            coefficients are used instead of computing LLE. Optional
            'extract_chemicals'/'top_chemicals' and
            'raffinate_chemicals'/'bottom_chemicals' entries are forwarded
            to `sep.partition`.
        top_chemical : str, optional
            Forwarded to the LLE solver.
        reaction : optional
            Object providing ``conversion(stream)`` and
            ``conversion_handle(stream)``.
        """
        self.partition_data = partition_data
        self.phases = phases
        self.top_chemical = top_chemical
        self.reaction = reaction
        self.gamma_y = None
        self.fgas = None
        self.IDs = None
        self.K = None
        self.B = None
        self.T = None
        self.P = None
        self.Q = 0.
        self.B_specification = self.T_specification = None
        self.B_fallback = 1
        self.dmol = SparseVector.from_size(self.chemicals.size)
        for i, j in zip(self.outs, self.phases): i.phase = j 
        
    @property
    def x(self):
        """Normalized bottom-phase molar composition for `IDs` (None if it
        cannot be computed)."""
        try:
            x = self.outs[1].imol[self.IDs]
            xsum = x.sum()
            return x / xsum if xsum else x
        except Exception:
            return None
        
    @property
    def y(self):
        """Top-phase composition estimated as ``x * K`` (None if it cannot
        be computed)."""
        try:
            return self.x * self.K
        except Exception:
            return None
    
    @staticmethod
    def _normalize(values):
        """Return `values` divided by their sum, or a uniform composition
        when the sum is zero."""
        total = values.sum()
        if total: return values / total
        return np.ones(values.size) / values.size
        
    def _get_mixture(self, linked=True):
        """
        Return a two-phase MultiStream representing the outlets.

        With `linked`, the MultiStream built from the outlets themselves is
        returned (created and cached on first use). Otherwise a cached copy
        of it is returned, refreshed from the linked one. In both cases the
        temperature is set to `T_specification` when given.
        """
        if linked:
            try:
                ms = self._linked_multistream 
            except AttributeError:
                outs = self.outs
                for i, j in zip(outs, self.phases): i.phase = j 
                self._linked_multistream = ms = tmo.MultiStream.from_streams(outs)
        else:
            try:
                ms = self._unlinked_multistream
                ms.copy_like(self._get_mixture())
            except Exception:
                self._unlinked_multistream = ms = self._get_mixture().copy()
        if self.T_specification is not None: ms.T = self.T_specification
        return ms
    
    def _get_arrays(self):
        """Return the stored per-chemical arrays by name."""
        if self.gamma_y is None:
            return {'K': self.K}
        else:
            return {'K': self.K, 'gamma_y': self.gamma_y}
    
    def _set_arrays(self, IDs, **kwargs):
        """
        Store per-chemical arrays (e.g. ``K=...``, ``gamma_y=...``,
        ``fgas=...``) computed for the chemicals in `IDs`.

        * If no IDs were stored yet, or they are unchanged, the arrays are
          stored as given.
        * If `IDs` is a strict subset of the stored IDs (and all previous
          arrays match the stored IDs), the stored arrays are updated in
          place at the matching positions; other chemicals keep their
          previous values.
        * Otherwise `IDs` replaces the stored IDs, and values for chemicals
          shared with the previous IDs are relaxed against previous values.

        Relaxation uses ``f = <name>_relaxation_factor``:
        ``(1 - f) * new + f * previous`` (skipped when ``f == 0``).
        """
        IDs_last = self.IDs
        IDs = tuple(IDs)
        if not IDs_last or IDs_last == IDs:
            for name, array in kwargs.items(): setattr(self, name, array)
            self.IDs = IDs
            return
        N_last = len(IDs_last)
        position_last = {ID: n for n, ID in enumerate(IDs_last)}
        lasts = {name: getattr(self, name) for name in kwargs}
        # Previous arrays can only be reused if they are aligned with IDs_last.
        usable = {
            name: last is not None and len(last) == N_last
            for name, last in lasts.items()
        }
        if (N_last > len(IDs)
            and all(ID in position_last for ID in IDs)
            and all(usable.values())):
            # Keep the larger set of IDs; update overlapping entries in place.
            index = [position_last[ID] for ID in IDs]
            for name, array in kwargs.items():
                last = lasts[name]
                f = getattr(self, name + '_relaxation_factor')
                for i, j in enumerate(index):
                    last[j] = (1. - f) * array[i] + f * last[j] if f else array[i]
        else:
            self.IDs = IDs
            position_new = {ID: n for n, ID in enumerate(IDs)}
            # (previous position, new position) of chemicals in both sets.
            shared = [
                (i, position_new[ID]) for i, ID in enumerate(IDs_last) 
                if ID in position_new
            ]
            for name, array in kwargs.items():
                new = np.array(array, dtype=float)
                setattr(self, name, new)
                f = getattr(self, name + '_relaxation_factor')
                if not (f and usable[name]): continue
                last = lasts[name]
                for i, j in shared:
                    new[j] = (1. - f) * new[j] + f * last[i]
    
    def _get_activity_model(self):
        """Return the activity coefficient model, IDs and chemical indices
        for LLE chemicals present in the inlets."""
        chemicals = self.chemicals
        index = chemicals.get_lle_indices(sum([i.mol for i in self.ins]).nonzero_keys())
        chemicals = chemicals.tuple
        lle_chemicals = [chemicals[i] for i in index]
        return self.thermo.Gamma(lle_chemicals), [i.ID for i in lle_chemicals], index
    
    def _get_fugacity_models(self):
        """Return gas and liquid fugacity models, IDs and chemical indices
        for VLE chemicals present in the inlets."""
        chemicals = self.chemicals
        index = chemicals.get_vle_indices(sum([i.mol for i in self.ins]).nonzero_keys())
        chemicals = chemicals.tuple
        vle_chemicals = [chemicals[i] for i in index]
        return (equilibrium.GasFugacities(vle_chemicals, thermo=self.thermo),
                equilibrium.LiquidFugacities(vle_chemicals, thermo=self.thermo), 
                [i.ID for i in vle_chemicals],
                index)
    
    def _update_fgas(self, P=None):
        """Recompute gas fugacity coefficients `fgas` from the top outlet
        composition at `T` and `P` (defaults to `self.P`)."""
        F_gas, F_liq, IDs, index = self._get_fugacity_models()
        top, bottom = self.outs
        y = self._normalize(top.mol[index])
        if P is None: P = self.P
        self.fgas = F_gas.coefficient(y, self.T, P)
    
    def _run_decoupled_Kfgas(self, P=None):
        """Update `K` (and `fgas`) from liquid and gas fugacity coefficients
        at fixed `T`, using one successive-substitution step on y."""
        top, bottom = self.outs
        F_gas, F_liq, IDs, index = self._get_fugacity_models()
        if P is None: P = self.P
        T = self.T
        x = self._normalize(bottom.mol[index])
        try:
            fgas = self.fgas
            init = fgas is None or fgas.size != len(index)
        except AttributeError:
            init = True
        if init:
            self.fgas = fgas = F_gas.coefficient(self._normalize(top.mol[index]), T, P)
        fliq = F_liq.coefficient(x, T, P)
        K = fliq / fgas 
        y = K * x
        y /= y.sum()
        fgas = F_gas.coefficient(y, T, P)
        K = fliq / fgas
        # Drop chemicals absent from both phases.
        good = (x != 0) | (y != 0)
        if not good.all():
            IDs = [ID for ID, keep in zip(IDs, good) if keep]
            fgas = np.asarray(fgas)[good]
            K = K[good]
        self._set_arrays(IDs, fgas=fgas, K=K)
    
    def _run_decoupled_Kgamma(self, P=None): # Pseudo-equilibrium
        """Update `K` (and `gamma_y`) as the ratio of bottom to top activity
        coefficients at fixed `T`, using one successive-substitution step."""
        top, bottom = self.outs
        f_gamma, IDs, index = self._get_activity_model()
        T = self.T
        x = self._normalize(bottom.mol[index])
        gamma_x = f_gamma(x, T)
        gamma_y = self.gamma_y
        try:
            init_gamma = gamma_y is None or gamma_y.size != len(index)
        except AttributeError:
            init_gamma = True
        if init_gamma:
            self.gamma_y = gamma_y = f_gamma(self._normalize(top.mol[index]), T)
        K = gamma_x / gamma_y 
        y = K * x
        y /= y.sum()
        gamma_y = f_gamma(y, T)
        K = gamma_x / gamma_y
        # Drop chemicals absent from both phases.
        good = (x != 0) | (y != 0)
        if not good.all():
            IDs = [ID for ID, keep in zip(IDs, good) if keep]
            gamma_y = np.asarray(gamma_y)[good]
            K = K[good]
        self._set_arrays(IDs, gamma_y=gamma_y, K=K)
        
    def _run_decoupled_B(self, stacklevel=1): # Flash Rachford-Rice
        """
        Update `B` by partitioning a copy of the feed with the current `K`
        (or the constant K in `partition_data`).

        Best effort: `B` is left unchanged if partitioning fails or the
        phase fraction is not strictly between 0 and 1.
        """
        ms = self.feed.copy()
        ms.phases = self.phases
        top, bottom = ms
        data = self.partition_data
        try:
            if data and 'K' in data:
                # Constant K values are aligned with the IDs given alongside them.
                phi = sep.partition(
                    ms, top, bottom, data.get('IDs', self.IDs), data['K'], 0.5, 
                    data.get('extract_chemicals') or data.get('top_chemicals'),
                    data.get('raffinate_chemicals') or data.get('bottom_chemicals'),
                    self.strict_infeasibility_check, stacklevel+1
                )
            else:
                phi = sep.partition(
                    ms, top, bottom, self.IDs, self.K, 0.5, 
                    None, None, self.strict_infeasibility_check,
                    stacklevel+1
                )
        except Exception: 
            return
        if phi <= 0 or phi >= 1: return
        self.B = phi / (1 - phi)
        # TODO: set S using relaxation factor and use different separation factors for lle and vle
    
    def _run_decoupled_KTvle(self, P=None, 
                             T_relaxation_factor=None,
                             K_relaxation_factor=None): # Bubble point
        """
        Update `T` and `K` from a bubble point of the bottom outlet (or a
        dew point of the top outlet if the bottom is empty).

        If `T_specification` is set, a VLE at that temperature is run
        instead. `T_relaxation_factor` overrides the class default;
        `K_relaxation_factor` is currently unused.
        """
        top, bottom = self.outs
        if P is not None: top.P = bottom.P = P
        if self.T_specification:
            self._run_vle(update=False)
            for i in self.outs: i.T = self.T_specification
            return
        if bottom.isempty():
            if top.isempty(): return
            p = top.dew_point_at_P(P)
        else:
            p = bottom.bubble_point_at_P(P)
        # TODO: Note that solution decomposition method is bubble point
        # Work on a copy so the stored equilibrium point keeps its true x.
        x = np.array(p.x, dtype=float)
        x[x == 0] = 1. # Avoid division by zero
        K_new = p.y / x
        IDs = p.IDs
        top.imol[IDs] = p.y * top.imol[IDs].sum()
        f = self.T_relaxation_factor if T_relaxation_factor is None else T_relaxation_factor
        if self.T:
            self.T = f * self.T + (1 - f) * p.T
        else:
            self.T = p.T
        self._equilibrium_point = p
        self._set_arrays(IDs, K=K_new)
    
    def _run_decoupled_reaction(self, P=None, relaxation_factor=None):
        """Update `dmol` from the reaction conversion of the bottom outlet,
        relaxed against the previous value."""
        top, bottom = self.outs
        f = self.dmol_relaxation_factor if relaxation_factor is None else relaxation_factor
        old = self.dmol
        new = self.reaction.conversion(bottom)
        self.dmol = f * old + (1 - f) * new
    
    def _run_lle(self, P=None, update=True, top_chemical=None, single_loop=False):
        """
        Solve liquid-liquid equilibrium and update `B`, `T`, `K` (and
        `gamma_y`).

        With constant K values in `partition_data`, the mixture is
        partitioned directly; otherwise the LLE solver of the mixture is
        used. `update` selects the linked (outlet) mixture or an unlinked
        copy.
        """
        if top_chemical is None: top_chemical = self.top_chemical
        else: self.top_chemical = top_chemical
        ms = self._get_mixture(update)
        eq = ms.lle
        data = self.partition_data
        if data and 'K' in data:
            ms.phases = self.phases
            top, bottom = ms
            IDs = data['IDs']
            K = data['K']
            phi = sep.partition(
                ms, top, bottom, IDs, K, 0.5, 
                data.get('extract_chemicals') or data.get('top_chemicals'),
                data.get('raffinate_chemicals') or data.get('bottom_chemicals'),
                self.strict_infeasibility_check, 1
            )
            if phi == 1:
                self.B = np.inf
            else:
                self.B = phi / (1 - phi)
            self.K = K
            self.T = ms.T
            self.IDs = IDs
        else:
            if update:
                eq(T=ms.T, P=P or self.P, top_chemical=top_chemical, update=update, single_loop=single_loop)
                lle_chemicals, K_new, gamma_y, phi = eq._lle_chemicals, eq._K, eq._gamma_y, eq._phi
            else:
                lle_chemicals, K_new, gamma_y, phi = eq(T=ms.T, P=P or self.P, top_chemical=top_chemical, update=update)
            if phi == 1 or phi is None:
                self.B = np.inf
                self.T = ms.T
                return
            else:
                self.B = phi / (1 - phi)
            self.T = ms.T
            IDs = tuple([i.ID for i in lle_chemicals])
            self._set_arrays(IDs, K=K_new, gamma_y=gamma_y)
    
    def _run_vle(self, P=None, update=True):
        """
        Solve vapor-liquid equilibrium and update `B` (unless specified),
        `T`, `K` and, with a reaction, `dmol`.

        The flash is specified by `T_specification` if set; otherwise by
        `B_specification` (converted to a vapor fraction); otherwise by
        enthalpy (mixture enthalpy plus duty `Q`, plus heat of formation
        when there is a reaction). `update` selects the linked (outlet)
        mixture or an unlinked copy.
        """
        ms = self._get_mixture(update)
        B = self.B_specification
        T = self.T_specification
        Q = self.Q
        kwargs = {'P': P or self.P or ms.P}
        if self.reaction:
            kwargs['liquid_conversion'] = self.reaction.conversion_handle(self.outs[1])
        if T is None:
            if B is None: 
                if self.reaction: Q += ms.Hf
                kwargs['H'] = ms.H + Q
            elif B == np.inf:
                kwargs['V'] = 1.
            else:
                # From B = V / (1 - V): V = B / (1 + B)
                kwargs['V'] = B / (1 + B)
        else:
            kwargs['T'] = T
        ms.vle(**kwargs)
        index = ms.vle._index
        if self.reaction: self.dmol = ms.mol - self.feed.mol
        IDs = ms.chemicals.IDs
        IDs = tuple([IDs[i] for i in index])
        L_mol = ms.imol['l', IDs]
        L_total = L_mol.sum()
        if L_total: 
            x_mol = L_mol / L_total
            x_mol[x_mol == 0] = 1e-9 # Avoid division by zero
        else:
            x_mol = 1
        V_mol = ms.imol['g', IDs]
        V_total = V_mol.sum()
        if V_total: 
            y_mol = V_mol / V_total
            K_new = y_mol / x_mol
        else:
            K_new = np.ones(len(index)) * 1e-16
        if B is None: 
            if not L_total:
                self.B = inf
            else:
                self.B = V_total / L_total
        self.T = ms.T
        self._set_arrays(IDs, K=K_new)
        # TODO: Add option to set S and T using relaxation factor
    
    def _simulation_error(self):
        """
        Return ``Unit._simulation_error()`` without altering this partition.

        Equilibrium variables and cached mixtures are saved before the call
        and restored afterwards, even if the call raises.
        """
        missing = object()
        S = getattr(self, 'S', missing)
        cache = (
            self.T, self.B, copy(self.K), copy(self.dmol), 
            S if S is missing else copy(S), copy(self.gamma_y)
        )
        um = getattr(self, '_unlinked_multistream', None)
        m = getattr(self, '_linked_multistream', None)
        if um is not None: self._unlinked_multistream = copy(um)
        if m is not None: self._linked_multistream = copy(m)
        try:
            error = super()._simulation_error()
        finally:
            self.T, self.B, self.K, self.dmol, S, self.gamma_y = cache
            if S is missing:
                try: del self.S
                except AttributeError: pass
            else:
                self.S = S
            if um is not None: self._unlinked_multistream = um
            if m is not None: self._linked_multistream = m
        return error
    
    def _run(self):
        """Copy the feed into the outlets and solve VLE or LLE depending on
        `phases`."""
        mixture = self._get_mixture()
        mixture.copy_like(self.feed)
        if self.phases == ('g', 'l'):
            self._run_vle()
        else:
            self._run_lle()
            
    def lle_gibbs(self):
        """Return the LLE objective function (as used by the LLE solver)
        evaluated at the current outlet flows and `T`."""
        gamma, IDs, index = self._get_activity_model()
        f = gamma.f
        args = gamma.args
        mol_L, mol = [i.mol for i in self.outs]
        return tmo.equilibrium.lle.lle_objective_function(mol_L.to_array(), mol.to_array(), self.T, f, args)


class MultiStageEquilibrium(Unit):
    """
    Create a MultiStageEquilibrium object that models counter-current 
    equilibrium stages.
    
    Parameters
    ----------
    N_stages : int
        Number of stages.
    feed_stages : tuple[int]
        Respective stage where feeds enter. Defaults to (0, -1).
    partition_data : {'IDs': tuple[str], 'K': 1d array}, optional
        IDs of chemicals in equilibrium and partition coefficients (molar 
        composition ratio of the extract over the raffinate or vapor over liquid). If given,
        the stages will be modeled with these constants. Otherwise,
        partition coefficients are computed based on temperature and composition.
    top_chemical : str
        Name of main chemical in the solvent.
        
    Examples
    --------
    Simulate 2-stage extraction of methanol from water using octanol:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Methanol', 'Octanol'], cache=True)
    >>> feed = bst.Stream('feed', Water=500, Methanol=50)
    >>> solvent = bst.Stream('solvent', Octanol=500)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=2, ins=[feed, solvent], phases=('L', 'l'))
    >>> MSE.simulate()
    >>> extract, raffinate = MSE.outs
    >>> extract.imol['Methanol'] / feed.imol['Methanol'] # Recovery
    0.83
    >>> extract.imol['Octanol'] / solvent.imol['Octanol'] # Solvent stays in extract
    0.99
    >>> raffinate.imol['Water'] / feed.imol['Water'] # Carrier remains in raffinate
    0.82
    
    Simulate 10-stage extraction with user defined partition coefficients:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Methanol', 'Octanol'], cache=True)
    >>> import numpy as np
    >>> feed = bst.Stream('feed', Water=5000, Methanol=500)
    >>> solvent = bst.Stream('solvent', Octanol=5000)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=10, ins=[feed, solvent], phases=('L', 'l'),
    ...     partition_data={
    ...         'K': np.array([1.451e-01, 1.380e+00, 2.958e+03]),
    ...         'IDs': ('Water', 'Methanol', 'Octanol'),
    ...         'phi': 0.5899728891780545, # Initial phase fraction guess. This is optional.
    ...     }
    ... )
    >>> extract, raffinate = MSE.outs
    >>> MSE.simulate()
    >>> extract.imol['Methanol'] / feed.imol['Methanol'] # Recovery
    0.99
    >>> extract.imol['Octanol'] / solvent.imol['Octanol'] # Solvent stays in extract
    0.99
    >>> raffinate.imol['Water'] / feed.imol['Water'] # Carrier remains in raffinate
    0.82
    
    Because octanol and water do not mix well, it may be a good idea to assume
    that these solvents do not mix at all:
        
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Methanol', 'Octanol'], cache=True)
    >>> import numpy as np
    >>> feed = bst.Stream('feed', Water=5000, Methanol=500)
    >>> solvent = bst.Stream('solvent', Octanol=5000)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=20, ins=[feed, solvent], phases=('L', 'l'),
    ...     partition_data={
    ...         'K': np.array([1.38]),
    ...         'IDs': ('Methanol',),
    ...         'raffinate_chemicals': ('Water',),
    ...         'extract_chemicals': ('Octanol',),
    ...     }
    ... )
    >>> MSE.simulate()
    >>> extract, raffinate = MSE.outs
    >>> extract.imol['Methanol'] / feed.imol['Methanol'] # Recovery
    0.99
    >>> extract.imol['Octanol'] / solvent.imol['Octanol'] # Solvent stays in extract
    1.0
    >>> raffinate.imol['Water'] / feed.imol['Water'] # Carrier remains in raffinate
    1.0
       
    Simulate with a feed at the 4th stage:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Methanol', 'Octanol'], cache=True)
    >>> import numpy as np
    >>> feed = bst.Stream('feed', Water=5000, Methanol=500)
    >>> solvent = bst.Stream('solvent', Octanol=5000)
    >>> dilute_feed = bst.Stream('dilute_feed', Water=100, Methanol=2)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=5, ins=[feed, dilute_feed, solvent], 
    ...     feed_stages=[0, 3, -1],
    ...     phases=('L', 'l'),
    ...     partition_data={
    ...         'K': np.array([1.38]),
    ...         'IDs': ('Methanol',),
    ...         'raffinate_chemicals': ('Water',),
    ...         'extract_chemicals': ('Octanol',),
    ...     }
    ... )
    >>> MSE.simulate()
    >>> extract, raffinate = MSE.outs
    >>> extract.imol['Methanol'] / (feed.imol['Methanol'] + dilute_feed.imol['Methanol']) # Recovery
    0.93
    
    Simulate with a 60% extract side draw at the 2nd stage:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Methanol', 'Octanol'], cache=True)
    >>> import numpy as np
    >>> feed = bst.Stream('feed', Water=5000, Methanol=500)
    >>> solvent = bst.Stream('solvent', Octanol=5000)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=5, ins=[feed, solvent],                         
    ...     top_side_draws={1: 0.6},
    ...     phases=('L', 'l'),
    ...     partition_data={
    ...         'K': np.array([1.38]),
    ...         'IDs': ('Methanol',),
    ...         'raffinate_chemicals': ('Water',),
    ...         'extract_chemicals': ('Octanol',),
    ...     }
    ... )
    >>> MSE.simulate()
    >>> extract, raffinate, extract_side_draw, *raffinate_side_draws = MSE.outs
    >>> (extract.imol['Methanol'] + extract_side_draw.imol['Methanol']) / feed.imol['Methanol'] # Recovery
    0.92
    
    Simulate a stripping column with 2 stages:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['AceticAcid', 'EthylAcetate', 'Water', 'MTBE'], cache=True)
    >>> feed = bst.Stream('feed', Water=75, AceticAcid=5, MTBE=20, T=320)
    >>> steam = bst.Stream('steam', Water=100, phase='g', T=390)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=2, ins=[feed, steam], feed_stages=[0, -1],
    ...     outs=['vapor', 'liquid'],
    ...     phases=('g', 'l'),
    ... )
    >>> MSE.simulate()
    >>> vapor, liquid = MSE.outs
    >>> vapor.imol['MTBE'] / feed.imol['MTBE']
    0.99
    >>> vapor.imol['Water'] / (feed.imol['Water'] + steam.imol['Water'])
    0.42
    >>> vapor.imol['AceticAcid'] / feed.imol['AceticAcid']
    0.74
    
    Simulate distillation column with 5 stages, a 0.673 reflux ratio, 
    2.57 boilup ratio, and feed at stage 2:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Ethanol'], cache=True)
    >>> feed = bst.Stream('feed', Ethanol=80, Water=100, T=80.215 + 273.15)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=5, ins=[feed], feed_stages=[2],
    ...     outs=['vapor', 'liquid'],
    ...     stage_specifications={0: ('Reflux', 0.673), -1: ('Boilup', 2.57)},
    ...     phases=('g', 'l'),
    ... )
    >>> MSE.simulate()
    >>> vapor, liquid = MSE.outs
    >>> vapor.imol['Ethanol'] / feed.imol['Ethanol']
    0.96
    >>> vapor.imol['Ethanol'] / vapor.F_mol
    0.69
    
    Simulate the same distillation column with a full condenser, 5 stages, a 0.673 reflux ratio, 
    2.57 boilup ratio, and feed at stage 2:
    
    >>> import biosteam as bst
    >>> bst.settings.set_thermo(['Water', 'Ethanol'], cache=True)
    >>> feed = bst.Stream('feed', Ethanol=80, Water=100, T=80.215 + 273.15)
    >>> MSE = bst.MultiStageEquilibrium(N_stages=5, ins=[feed], feed_stages=[2],
    ...     outs=['vapor', 'liquid', 'distillate'],
    ...     stage_specifications={0: ('Reflux', float('inf')), -1: ('Boilup', 2.57)},
    ...     bottom_side_draws={0: 0.673 / (1 + 0.673)},
    ...     max_attempts=10,
    ... )
    >>> MSE.simulate()
    >>> vapor, liquid, distillate = MSE.outs
    >>> distillate.imol['Ethanol'] / feed.imol['Ethanol']
    0.81
    >>> distillate.imol['Ethanol'] / distillate.F_mol
    0.70
    
    """
    _N_ins = 2
    _N_outs = 2
    inside_maxiter = 100
    default_max_attempts = 5
    default_maxiter = 40
    default_fallback = None
    default_S_tolerance = 1e-6
    default_relative_S_tolerance = 1e-6
    default_algorithm = 'phenomena'
    detault_inside_out = False
    default_inner_loop_algorthm = None
    decomposition_algorithms = {
        'phenomena', 'sequential modular',
    }
    available_algorithms = {
        *decomposition_algorithms, 
        'simultaneous correction',
    }
    default_methods = {
        'phenomena': 'wegstein',
        'simultaneous correction': NotImplemented,
    }
    method_options = {
        'fixed-point': {},
        'wegstein': {'lb': 1, 'ub': 4, 'exp': 0.5}
    }
    auxiliary_unit_names = (
        'stages',
    )
    _side_draw_names = ('top_side_draws', 'bottom_side_draws')
    
    # Inside-out surrogate model
    TSurrogate = RBFInterpolator
    KSurrogate = Akima1DInterpolator 
    hSurrogate = Akima1DInterpolator 
    T_surrogate_options = {}
    K_surrogate_options = {'method': 'makima', 'extrapolate': True}
    h_surrogate_options = {'method': 'makima', 'extrapolate': True}
    
    def __init_subclass__(cls, *args, **kwargs):
        """
        Hook run when a subclass is created.

        If the subclass defines its own ``_side_draw_names`` (a pair of
        attribute names), properties with those names are added that read
        and write ``top_side_draws`` and ``bottom_side_draws``, respectively.
        """
        # NOTE(review): `super().__init_subclass__` is already bound to `cls`,
        # so `cls` is passed again here as the first positional argument of
        # the parent hook. Confirm this is intended; changing it may alter how
        # the parent class treats subclasses.
        super().__init_subclass__(cls, *args, **kwargs)
        if '_side_draw_names' in cls.__dict__:
            # NOTE(review): if a subclass reused the names 'top_side_draws' /
            # 'bottom_side_draws', these properties would refer to themselves
            # and recurse infinitely.
            top, bottom = cls._side_draw_names
            setattr(
                cls, top, 
                property(
                    lambda self: self.top_side_draws,
                    lambda self, value: setattr(self, 'top_side_draws', value)
                )
            )
            setattr(
                cls, bottom, 
                property(
                    lambda self: self.bottom_side_draws,
                    lambda self, value: setattr(self, 'bottom_side_draws', value)
                )
            )
    
    def __init__(self,  ID='', ins=None, outs=(), thermo=None, stages=None, **kwargs):
        """
        Create the unit either from scratch (``stages=None``) or by wrapping
        an existing, already connected sequence of equilibrium stages.

        Parameters
        ----------
        stages : list, optional
            Stages to wrap. Stage inlets whose source is not one of the stages
            become inlets of this unit (recording their feed stage), stage
            outlets whose sink is not one of the stages become outlets, and
            side draws, specifications and reactions are read from each
            stage. Remaining keyword arguments are forwarded to
            ``Unit.__init__``.
        """
        if stages is None:
            if 'feed_stages' in kwargs: self._N_ins = len(kwargs['feed_stages'])
            top_side_draws, bottom_side_draws = self._side_draw_names
            # Top and bottom outlets plus one outlet per side draw.
            N_outs = 2
            if top_side_draws in kwargs: N_outs += len(kwargs[top_side_draws]) 
            if bottom_side_draws in kwargs: N_outs += len(kwargs[bottom_side_draws]) 
            self._N_outs = N_outs
            Unit.__init__(self, ID, ins, outs, thermo, **kwargs)
            return
        ins = []
        outs = []
        top_side_draws_outs = []
        bottom_side_draws_outs = []
        stages_set = set(stages)
        top_side_draws = {}
        bottom_side_draws = {}
        feed_stages = []
        first_stage = stages[0]
        phases = first_stage.phases
        stage_specifications = {}
        stage_reactions = {}
        # Default to the stages' thermo and pass the same object to
        # Unit.__init__, so it is not replaced by the default thermo.
        thermo = thermo or first_stage.thermo
        self._load_thermo(thermo)
        for n, stage in enumerate(stages):
            for s in stage.ins:
                if s.source not in stages_set: 
                    # External feed: the proxy keeps the original source.
                    sp = s.proxy()
                    sp._source = s._source
                    ins.append(sp)
                    feed_stages.append(n)
            top, bottom, *other = stage.outs
            if stage.top_split:
                s = other[0]
                sp = s.proxy()
                sp._sink = s._sink
                top_side_draws_outs.append(sp)
                top_side_draws[n] = stage.top_split
            if stage.bottom_split:
                s = other[-1]
                sp = s.proxy()
                sp._sink = s._sink
                bottom_side_draws_outs.append(sp)
                bottom_side_draws[n] = stage.bottom_split
            if top.sink not in stages_set: 
                sp = top.proxy()
                sp._sink = top._sink
                outs.append(sp)
            if bottom.sink not in stages_set: 
                sp = bottom.proxy()
                sp._sink = bottom._sink
                outs.append(sp)
            if stage.B_specification is not None: 
                stage_specifications[n] = ('Boilup', stage.B_specification)
            elif stage.T_specification is not None:
                stage_specifications[n] = ('Temperature', stage.T_specification)
            elif stage.Q != 0:
                stage_specifications[n] = ('Duty', stage.Q)
            elif stage.F_specification is not None:
                stage_specifications[n] = ('Flow', stage.F_specification)
            if stage.reaction is not None:
                stage_reactions[n] = stage.reaction
        outs = [*outs, *top_side_draws_outs, *bottom_side_draws_outs]
        self._N_ins = len(ins)
        self._N_outs = len(outs)
        Unit.__init__(self, ID, ins, outs, thermo, 
            stage_specifications=stage_specifications,
            stage_reactions=stage_reactions,
            feed_stages=feed_stages,
            bottom_side_draws=bottom_side_draws,
            top_side_draws=top_side_draws,
            stages=stages,
            phases=phases,
            **kwargs
        )
    
    def _init(self,
            N_stages=None, 
            stages=None,
            top_side_draws=None,
            bottom_side_draws=None, 
            feed_stages=None, 
            phases=None, 
            P=101325, 
            T=None,
            stage_specifications=None, 
            stage_reactions=None,
            partition_data=None, 
            top_chemical=None, 
            use_cache=None,
            collapsed_init=False,
            algorithm=None,
            method=None,
            maxiter=None,
            max_attempts=None,
            inside_out=None,
            vle_decomposition=None,
        ):
        """
        Set up the stages and solver settings.

        Parameters
        ----------
        N_stages : int, optional
            Number of stages; defaults to ``len(stages)``.
        stages : list, optional
            Existing stages. If not given, `N_stages` new ``StageEquilibrium``
            units are created and connected counter-currently.
        top_side_draws, bottom_side_draws : dict[int, float], optional
            Split of the top/bottom outlet withdrawn as a side draw, by stage
            index.
        feed_stages : sequence[int], optional
            Stage of each inlet; defaults to ``(0, -1)``.
        phases : tuple[str, str], optional
            Phases in equilibrium; defaults to ``('g', 'l')``.
        P : float, optional
            Operating pressure.
        T : float, optional
            If given, every stage without another specification gets a
            'Temperature' specification at `T`.
        stage_specifications : dict[int, tuple[str, float]], optional
            Specification name ('Duty', 'Reflux', 'Boilup', 'Temperature' or
            'Flow') and value by stage index; applied to newly created stages.
        stage_reactions : dict[int, object], optional
            Reaction by stage index; applied to newly created stages.
        partition_data : dict, optional
            Passed to each newly created stage.
        Other parameters set solver options (see attribute comments below).
        """
        # TODO: For VLE look for best published algorithm (don't try simple methods that fail often)
        if N_stages is None: N_stages = len(stages)
        if phases is None: phases = ('g', 'l')
        if feed_stages is None: feed_stages = (0, -1)
        # Always copy: Temperature specifications may be added below and the
        # caller's dictionary must not be mutated.
        stage_specifications = {} if stage_specifications is None else dict(stage_specifications)
        if T is not None: 
            # Normalize negative indices (e.g. -1) so stages that already have
            # a specification do not get a second one.
            specified = {i % N_stages for i in stage_specifications}
            for i in range(N_stages):
                if i not in specified: stage_specifications[i] = ('Temperature', T)
        if stage_reactions is None: stage_reactions = {}
        elif not isinstance(stage_reactions, dict): stage_reactions = dict(stage_reactions)
        if top_side_draws is None: top_side_draws = {}
        elif not isinstance(top_side_draws, dict): top_side_draws = dict(top_side_draws)
        if bottom_side_draws is None: bottom_side_draws = {}
        elif not isinstance(bottom_side_draws, dict): bottom_side_draws = dict(bottom_side_draws)
        if partition_data is None: partition_data = {}
        self.multi_stream = tmo.MultiStream(None, P=P, phases=phases, thermo=self.thermo)
        self.N_stages = N_stages
        self.P = P
        self.T = T
        self.phases = phases = self.multi_stream.phases # Corrected order
        self._has_vle = 'g' in phases
        self._has_lle = 'L' in phases
        self._top_split = top_splits = np.zeros(N_stages)
        self._bottom_split = bottom_splits = np.zeros(N_stages)
        if stages is None:
            # Outlets: [top, bottom, *top side draws, *bottom side draws]
            top_mark = 2 + len(top_side_draws)
            tsd_iter = iter(self.outs[2:top_mark])
            bsd_iter = iter(self.outs[top_mark:])
            
            def next_side_draw(iterator, kind):
                # Fail clearly if there are fewer side-draw outlets than side draws.
                try:
                    return next(iterator)
                except StopIteration:
                    raise RuntimeError(
                        f"{self!r} has fewer outlets than {kind} side draws"
                    ) from None
            
            last_stage = None
            self.stages = stages = []
            for i in range(N_stages):
                # Each stage is fed by the bottom outlet of the stage above.
                feed = () if last_stage is None else last_stage-1
                outs = [
                    self-0 if i == 0 else None, # extract or vapor
                    self-1 if i == N_stages - 1 else None, # raffinate or liquid
                ]
                if i in top_side_draws:
                    outs.append(next_side_draw(tsd_iter, 'top'))
                    top_split = top_splits[i] = top_side_draws[i]
                else: 
                    top_split = 0
                if i in bottom_side_draws:
                    outs.append(next_side_draw(bsd_iter, 'bottom'))
                    bottom_split = bottom_splits[i] = bottom_side_draws[i]
                else: 
                    bottom_split = 0
                new_stage = self.auxiliary(
                    'stages', StageEquilibrium, phases=phases,
                    ins=feed,
                    outs=outs,
                    partition_data=partition_data,
                    top_split=top_split,
                    bottom_split=bottom_split,
                )
                if last_stage is not None:
                    # The top outlet of this stage feeds the stage above.
                    last_stage.add_feed(new_stage-0)
                last_stage = new_stage
            for feed, stage in zip(self.ins, feed_stages):
                stages[stage].add_feed(self.auxlet(feed))  
            #: dict[int, tuple(str, float)] Specifications for VLE by stage
            self.stage_specifications = stage_specifications
            for i, (name, value) in stage_specifications.items():
                B, Q, T_stage, F = _get_specification(name, value)
                stages[i].set_specification(B=B, Q=Q, T=T_stage, P=P, F=F)
            self.stage_reactions = stage_reactions
            for i, reaction in stage_reactions.items():
                stages[i].reaction = reaction
        else:
            self.stage_specifications = stage_specifications
            self.stage_reactions = stage_reactions
            self.stages = stages
            # Fill the arrays already bound to self._top_split/_bottom_split.
            for i, j in top_side_draws.items(): top_splits[i] = j
            for i, j in bottom_side_draws.items(): bottom_splits[i] = j
        self._asplit_left = 1 - top_splits
        self._bsplit_left = 1 - bottom_splits
        self._asplit_1 = top_splits - 1
        self._bsplit_1 = bottom_splits - 1
        self.partitions = [i.partition for i in stages]
        self.top_chemical = top_chemical
        self.partition_data = partition_data
        self.feed_stages = feed_stages
        self.top_side_draws = top_side_draws
        self.bottom_side_draws = bottom_side_draws
            
        #: [int] Maximum number of iterations.
        self.maxiter = self.default_maxiter if maxiter is None else maxiter
        
        #: [int] Maximum number of attempts.
        self.max_attempts = self.default_max_attempts if max_attempts is None else max_attempts
        
        #: tuple[str, str] Fallback algorithm and method.
        self.fallback = self.default_fallback

        #: [float] Separation factor tolerance
        self.S_tolerance = self.default_S_tolerance

        #: [float] Relative separation factor tolerance
        self.relative_S_tolerance = self.default_relative_S_tolerance
        
        self.use_cache = bool(use_cache)
        
        self.collapsed_init = collapsed_init
        
        self.algorithm = self.default_algorithm if algorithm is None else algorithm
        
        self.method = self.default_methods[self.algorithm] if method is None else method
        
        self.inside_out = self.detault_inside_out if inside_out is None else inside_out
        
        self.vle_decomposition = vle_decomposition
    
    @property
    def composition_sensitive(self):
        """Whether the equations of this unit are composition sensitive.

        Simply reports the ``_has_lle`` flag of this unit.
        """
        has_lle = self._has_lle
        return has_lle
    
    @property
    def aggregated_stages(self):
        """Group stages for decoupled phenomena simulation.

        If no partition carries a phase-ratio (``B_specification``) or
        temperature (``T_specification``) specification and there are no
        side draws, the whole unit is returned as one aggregate
        (``[self]``). Otherwise, every stage with a stage specification or
        a side draw is kept as a single stage. Each run of two or more
        consecutive remaining stages is wrapped in a new
        ``MultiStageEquilibrium`` whose ``parent`` is ``self``. A run of
        exactly one stage is kept as-is.

        Side effects: sets ``self.aggregated``; in the fully aggregated case
        also sets ``self.use_cache = True``.
        """
        # A specification value of 0 (e.g. a zero boilup) is still a
        # specification, so compare against None rather than truthiness.
        specified = any([
            i.B_specification is not None or i.T_specification is not None
            for i in self.partitions
        ])
        if not (specified or self.top_side_draws or self.bottom_side_draws):
            self.aggregated = True
            self.use_cache = True
            return [self]
        self.aggregated = False
        N_stages = self.N_stages
        # Stage indices may be negative (counted from the last stage).
        singles = {
            (i if i >= 0 else N_stages + i)
            for i in (*self.stage_specifications,
                      *self.top_side_draws,
                      *self.bottom_side_draws)
        }
        aggregated = []

        def flush(run):
            # Emit a run of consecutive unspecified stages: a lone stage is
            # kept as-is, longer runs become a sub-column of this unit.
            if len(run) == 1:
                aggregated.append(run[0])
            elif len(run) > 1:
                group = MultiStageEquilibrium(
                    None, stages=run, P=self.P, use_cache=True,
                    method=self.method, maxiter=self.maxiter,
                    algorithm=self.algorithm,
                    top_chemical=self.top_chemical,
                    collapsed_init=False,
                    inside_out=self.inside_out,
                )
                group._N_chemicals = self._N_chemicals
                group._system = self._system
                group.aggregated = True
                group.parent = self
                aggregated.append(group)

        run = []
        for i, stage in enumerate(self.stages):
            if i in singles:
                flush(run)
                aggregated.append(stage)
                run = []
            else:
                run.append(stage)
        flush(run)
        return aggregated
    

    # %% Decoupled phenomena equation oriented simulation
    
    def _get_energy_departure_coefficient(self, stream):
        """Return ``(self, coeff)``: the coefficient of this unit's energy
        departure variable associated with ``stream``.

        Only valid for aggregated units (asserted). For VLE units the
        coefficient is computed for the vapor outlet from the vapor molar
        enthalpy and the liquid molar flow; otherwise it is ``-stream.C``.
        """
        assert self.aggregated
        if self._has_vle:
            vapor, liquid = self.outs
            # ``stream`` is matched to the vapor outlet by identity of its
            # molar-flow data.
            if stream.imol is vapor.imol:
                if vapor.isempty():
                    # No vapor: evaluate the liquid's enthalpy as a gas.
                    # NOTE(review): unlike the non-empty branch below, this
                    # value is not negated — confirm the intended sign.
                    with liquid.temporary_phase('g'): coeff = liquid.H
                else:
                    coeff = -vapor.h * liquid.F_mol
            # NOTE(review): when ``stream`` is not the vapor outlet (e.g. the
            # liquid outlet), ``coeff`` is never assigned and the return
            # below raises UnboundLocalError — confirm what the coefficient
            # should be in that case.
        else:
            coeff = -stream.C
        return (self, coeff)
    
    def _create_energy_departure_equations(self):
        # Ll: C1dT1 - Ce2*dT2 - Cr0*dT0 - hv2*L2*dB2 = Q1 - H_out + H_in
        # gl: hV1*L1*dB1 - hv2*L2*dB2 - Ce2*dT2 - Cr0*dT0 = Q1 + H_in - H_out
        phases = self.phases
        if phases == ('g', 'l'):
            vapor, liquid = self.outs
            coeff = {}
            if vapor.isempty():
                with liquid.temporary_phase('g'): coeff[self] = liquid.H
            else:
                coeff[self] = vapor.h * liquid.F_mol
        elif phases == ('L', 'l'):
            coeff = {self: sum([i.C for i in self.outs])}
        else:
            raise RuntimeError('invalid phases')
        for i in self.ins: i._update_energy_departure_coefficient(coeff)
        if self.stage_reactions:
            return [(coeff, sum([i.Q for i in self.stages]) - self.Hnet)]
        else:
            return [(coeff, self.H_in - self.H_out + sum([(i.Hnet if i.Q is None else i.Q) for i in self.stages]))]
    
    def _create_material_balance_equations(self, composition_sensitive):
        top, bottom = self.outs
        try:
            B = self.B
            K = self.K
        except:
            if bottom.isempty():
                self.B = B = np.inf
                self.K = K = 1e16 * np.ones(self.chemicals.size)
            elif top.isempty():
                self.K = K = np.zeros(self.chemicals.size)
                self.B = B = 0
            else:
                top_mol = top.mol.to_array()
                bottom_mol = bottom.mol.to_array()
                F_top = top_mol.sum()
                F_bottom = bottom_mol.sum()
                y = top_mol / F_top
                x = bottom_mol / F_bottom
                x[x <= 0] = 1e-16
                self.K = K = y / x
                self.B = B = F_top / F_bottom
        
        fresh_inlets, process_inlets, equations = self._begin_equations(composition_sensitive)
        top, bottom, *_ = self.outs
        ones = np.ones(self.chemicals.size)
        minus_ones = -ones
        zeros = np.zeros(self.chemicals.size)
        
        # Overall flows
        eq_overall = {}
        for i in self.outs: eq_overall[i] = ones
        for i in process_inlets: 
            if i in eq_overall:
                del eq_overall[i]
            else:
                eq_overall[i] = minus_ones
        if self.stage_reactions:
            partitions = self.partitions
            flows = [i.mol for i in fresh_inlets] + [partitions[i].dmol for i in self.stage_reactions]
            equations.append(
                (eq_overall, sum(flows, zeros))
            )
        else:
            equations.append(
                (eq_overall, sum([i.mol for i in fresh_inlets], zeros))
            )
        
        # Top to bottom flows
        eq_outs = {}
        if B == np.inf:
            eq_outs[bottom] = ones
        elif B == 0:
            eq_outs[top] = ones
        else:
            eq_outs[top] = ones
            eq_outs[bottom] = -(K * B)
        equations.append(
            (eq_outs, zeros)
        )
        return equations
    
    def _update_auxiliaries(self):
        for i in self.stages: i._update_auxiliaries()
    
    def _update_composition_parameters(self):
        for i in self.partitions: 
            if 'K' in i.partition_data: continue
            i._run_decoupled_Kgamma()
    
    def _update_net_flow_parameters(self):
        for i in self.partitions: i._run_decoupled_B()
    
    def _update_nonlinearities(self):
        if self._has_vle:
            for i in self.stages: i._update_nonlinearities()
        elif self._has_lle:
            pass
            # self.update_pseudo_lle()
    
    def _update_energy_variable(self, departure):
        phases = self.phases
        if phases == ('g', 'l'):
            if not hasattr(self, 'B'):
                top, bottom = self.outs
                if bottom.isempty():
                    self.B = np.inf
                    self.K = 1e16 * np.ones(self.chemicals.size)
                elif top.isempty():
                    self.K = np.zeros(self.chemicals.size)
                    self.B = 0
                else:
                    top_mol = top.mol.to_array()
                    bottom_mol = bottom.mol.to_array()
                    F_top = top_mol.sum()
                    F_bottom = bottom_mol.sum()
                    y = top_mol / F_top
                    x = bottom_mol / F_bottom
                    x[x <= 0] = 1e-16
                    self.K = y / x
                    self.B = F_top / F_bottom
            self.B += departure
        elif phases == ('L', 'l'):
            for i in self.outs: i.T += departure
        else:
            raise RuntimeError('invalid phases')
    
    @property
    def outlet_stages(self):
        """Map each stage outlet stream to the stage that produces it.

        Streams reachable through an outlet's ``port`` chain are mapped to
        the same stage. If this unit has a ``parent``, the parent's mapping
        is returned. Otherwise the mapping is built once and cached in
        ``self._outlet_stages``.
        """
        if hasattr(self, 'parent'): return self.parent.outlet_stages
        try:
            return self._outlet_stages
        except AttributeError:
            pass  # not cached yet; build it below
        outlet_stages = {}
        for stage in self.stages:
            for stream in stage.outs:
                outlet_stages[stream] = stage
                # Follow port links so connected streams map to this stage.
                while hasattr(stream, 'port'):
                    stream = stream.port.get_stream()
                    outlet_stages[stream] = stage
        self._outlet_stages = outlet_stages
        return outlet_stages
    
    def correct_overall_mass_balance(self):
        """Rescale outlet flows to close the overall material balance.

        Every outlet ``mol`` is multiplied element-wise by::

            (sum of inlet mol + stage-reaction dmol) / (sum of outlet mol)

        If computing that ratio raises, the outlets are left unchanged.
        """
        total_out = sum([s.mol for s in self.outs])
        total_in = sum([s.mol for s in self.ins])
        reacting = self.stage_reactions
        if reacting:
            partitions = self.partitions
            total_in += sum([partitions[n].dmol for n in reacting])
        try:
            factor = total_in / total_out
        except:  # best-effort: skip the correction when the ratio fails
            return
        for s in self.outs:
            s.mol *= factor
    
    def material_errors(self):
        """Tabulate the material-balance residual of every stage.

        Each row is ``inlets - outlets + partition.dmol`` for one stage. The
        columns are the chemical IDs of ``self.multi_stream``.

        Returns
        -------
        pandas.DataFrame
        """
        IDs = self.multi_stream.chemicals.IDs
        rows = [
            sum([s.mol for s in stage.ins],
                -sum([s.mol for s in stage.outs], -stage.partition.dmol))
            for stage in self.stages
        ]
        return pd.DataFrame(rows, columns=IDs)
    
    def _feed_flows_and_conversion(self):
        feed_flows = self.feed_flows.copy()
        partition = self.partitions
        index = self._update_index
        for i in self.stage_reactions: 
            p = partition[i]
            dmol = p.dmol
            for n, j in enumerate(index): feed_flows[i, n] += dmol[j]
        return feed_flows
    
    def set_flow_rates(self, bottom_flows, update_B=True):
        """Set stage outlet flows from bottom-phase flow rates.

        Parameters
        ----------
        bottom_flows : 2d array
            Bottom-phase flows by stage (rows) for the ``_update_index``
            chemicals (columns).
        update_B : bool, optional
            Whether to reset the phase ratio ``B`` of each stage without a
            ``B_specification`` to (top total)/(bottom total). Defaults to
            True.

        Notes
        -----
        * When ``PhasePartition.F_relaxation_factor`` is set and previous
          bottom flows exist, the new bottom flows are relaxed against them.
        * Top flows come from ``MESH.top_flows_mass_balance``.
        * Negative flows are clipped to zero. When the original total is
          positive, the row is rescaled to keep that total.
        """
        stages = self.stages
        index = self._update_index
        if self.stage_reactions:
            feed_flows = self._feed_flows_and_conversion()
        else:
            feed_flows = self.feed_flows
        f = PhasePartition.F_relaxation_factor
        if f and self.bottom_flows is not None:
            bottom_flows = f * self.bottom_flows + (1 - f) * bottom_flows
        self.bottom_flows = bottom_flows
        top_flows = MESH.top_flows_mass_balance(
            bottom_flows, feed_flows, self._asplit_left, self._bsplit_left,
            self.N_stages
        )

        def clip_negative(flows):
            # Zero negative entries in place and return the pre-clipping
            # total. Rescaling is only done when both totals are positive:
            # otherwise it would produce NaN (0/0) or turn the remaining
            # flows negative.
            total = flows.sum()
            mask = flows < 0
            if mask.any():
                flows[mask] = 0
                clipped_total = flows.sum()
                if total > 0 and clipped_total > 0:
                    flows *= total / clipped_total
            return total

        for i in range(self.N_stages):
            stage = stages[i]
            s_top, s_bot = stage.partition.outs
            t = top_flows[i]
            b = bottom_flows[i]
            bulk_t = clip_negative(t)
            bulk_b = clip_negative(b)
            s_top.mol[index] = t
            s_bot.mol[index] = b
            for splitter in stage.splitters: splitter._run()
            if update_B and stage.B_specification is None:
                stage.B = bulk_t / bulk_b
        
    def default_vle_decomposition(self):
        """Choose ``self.vle_decomposition`` from a phase-stability check.

        The stage-averaged ``K`` and the overall feed composition are passed
        to ``equilibrium.stable_phase``. The result is 'sum rates' when it
        returns true and 'bubble point' otherwise.
        """
        K_mean = np.mean([stage.K for stage in self.stages], axis=0)
        total_feed = self.feed_flows.sum(axis=0)
        z = total_feed / total_feed.sum()
        stable = equilibrium.stable_phase(K_mean, z)
        self.vle_decomposition = 'sum rates' if stable else 'bubble point'
        
    def _run(self):
        """Simulate the column.

        If every inlet is empty, all outlets are emptied and nothing else is
        done. Otherwise, separation factors start from ``self.hot_start()``.
        For decomposition algorithms they are then converged with the
        flexsolve solver selected by ``method``:

        * The solve is retried up to ``max_attempts`` times.
        * If the last run stops at ``maxiter``, the ``fallback``
          algorithm/method is tried once.

        Any exception triggers one retry with ``use_cache`` disabled. If
        ``use_cache`` is already disabled, the exception is re-raised.

        Raises
        ------
        ValueError
            If ``method`` is neither 'fixed-point' nor 'wegstein'.
        NotImplementedError
            For the 'simultaneous correction' algorithm.
        RuntimeError
            For any other unknown algorithm.
        """
        if all([i.isempty() for i in self.ins]): 
            for i in self.outs: i.empty()
            return
        try:
            separation_factors = self.hot_start()
            algorithm = self.algorithm
            method = self.method
            if algorithm in self.decomposition_algorithms:
                options = dict(
                    maxiter=self.maxiter,
                    xtol=self.S_tolerance,
                    rtol=self.relative_S_tolerance,
                    **self.method_options[method],
                )
                if method == 'fixed-point':
                    solver = flx.fixed_point
                elif method == 'wegstein':
                    solver = flx.wegstein
                else:
                    raise ValueError(f'invalid method {method!r}')
                # Iteration function: inside-out or plain decomposition.
                f = self._inside_out_iter if self.inside_out else self._iter
                self.attempt = 0
                self.bottom_flows = None
                last = self.max_attempts - 1
                if self.vle_decomposition is None:
                    self.default_vle_decomposition()
                for n in range(self.max_attempts):
                    self.attempt = n
                    self.iter = 0
                    try:
                        separation_factors = solver(f, separation_factors, **options)
                    except:
                        # Before the next attempt, rerun every stage top-down
                        # and then bottom-up, and restart from their S.
                        # NOTE(review): a failure on the last attempt is
                        # swallowed here, leaving separation_factors unchanged.
                        if n != last:
                            for i in self.stages: i._run()
                            for i in reversed(self.stages): i._run()
                            separation_factors = np.array([i.S for i in self._S_stages])
                    else:
                        break
                # Hit maxiter: retry once with the fallback settings, then
                # restore the original settings.
                # NOTE(review): ``fallback`` is unpacked into three values
                # (algorithm, method, maxiter) — confirm its length.
                if self.iter == self.maxiter and self.fallback and self.fallback[0] != self.algorithm:
                    original = self.algorithm, self.method, self.maxiter, self.max_attempts
                    self.algorithm, self.method, self.maxiter = self.fallback
                    self.max_attempts = 1
                    try:
                        self._run()
                    finally:
                        self.algorithm, self.method, self.maxiter, self.max_attempts = original
            elif algorithm == 'simultaneous correction':
                raise NotImplementedError(f'{algorithm!r} not implemented in BioSTEAM (yet)')
            else:
                raise RuntimeError(
                    f'invalid algorithm {algorithm!r}, only {self.available_algorithms} are allowed'
                )
            # self.correct_overall_mass_balance()
        except Exception as e:
            # Retry once without reusing cached stage variables.
            if self.use_cache:
                self.use_cache = False
                try:
                    self._run()
                finally:
                    self.use_cache = True
            else:
                raise e
    
    def _split_objective(self, splits):
        """Objective function in terms of stage split fractions.

        Steps:

        1. Convert ``splits`` to separation factors ``S = (1 - splits) /
           splits`` with shape (number of ``_S_stages``, ``_N_chemicals``).
        2. Update flow rates and rerun each stage's equilibrium.
        3. Recompute splits as ``1 / (K*B + 1)``.

        Returns ``sqrt(total_error / N_stages)``, where ``total_error`` is:

        * the summed absolute change in splits, plus
        * for VLE only, the summed energy imbalance of stages with a
          specified ``Q``, each divided by that stage's total outlet heat
          capacity.
        """
        self.iter += 1
        S = ((1 - splits) / splits).reshape([self._NS_stages, self._N_chemicals])
        # NOTE(review): ``update_flow_rates`` is not defined in this part of
        # the file (``set_flow_rates`` is) — confirm it exists.
        self.update_flow_rates(S)
        P = self.P
        stages = self.stages
        if self._has_vle: 
            for i in stages:
                mixer = i.mixer
                partition = i.partition
                # Re-mix stage inlets by mass only (no energy balance).
                mixer.outs[0].mix_from(
                    mixer.ins, energy_balance=False,
                )
                mixer.outs[0].P = P
                partition._run_decoupled_KTvle(P=P)
                partition._run_decoupled_B()
                T = partition.T
                # NOTE: this loop rebinds ``i``; ``i.outs`` is evaluated
                # before the rebinding happens.
                for i in (partition.outs + i.outs): i.T = T
            energy_error = lambda stage: abs(
                (sum([i.H for i in stage.outs]) - sum([i.H for i in stage.ins], stage.Q)) / sum([i.C for i in stage.outs])
            )
            total_energy_error = sum([energy_error(stage) for stage in stages if stage.Q is not None])
        else:
            # Up to five energy-balance temperature updates, stopping early
            # once the summed |dT| is below 1e-12.
            for i in range(5):
                dTs = self.update_energy_balance_temperatures()
                if np.abs(dTs).sum() < 1e-12: break
            for i in self.stages: 
                mixer = i.mixer
                partition = i.partition
                mixer.outs[0].mix_from(
                    mixer.ins, energy_balance=False,
                )
                partition._run_lle(update=False, P=P)
        S_new = np.array([(i.K * i.B) for i in self._S_stages]).flatten()
        splits_new = 1 / (S_new + 1)
        diff = splits_new - splits
        total_split_error = np.abs(diff).sum()
        if self._has_vle:
            total_error = total_split_error + total_energy_error
        else:
            total_error = total_split_error
        err = np.sqrt(total_error / self.N_stages)
        return err
    
    def _hot_start_phase_ratios_iter(self, 
            top_flow_rates, *args
        ):
        """One fixed-point step on hot-start top flow rates.

        Bottom flow rates are computed from ``top_flow_rates``, then new top
        flow rates are computed from those bottom flow rates. ``args`` is
        forwarded unchanged to both MESH helpers.
        """
        bottom = MESH.hot_start_bottom_flow_rates(top_flow_rates, *args)
        return MESH.hot_start_top_flow_rates(bottom, *args)
        
    def hot_start_phase_ratios(self):
        """Estimate stage phase ratios (top/bottom total flow) for hot start.

        Stages listed in ``stage_specifications`` whose partition has a
        ``B_specification`` are held at that phase ratio. Feed flows are
        split into top (vapor) and bottom (non-vapor) contributions, and the
        top flow rates are converged with Wegstein iteration starting from
        zero.

        Returns
        -------
        1d array
            Top-to-bottom total flow ratio by stage. Zero bottom totals are
            replaced by 1e-32 to avoid division by zero.
        """
        stages = self.stages
        stage_index = []
        phase_ratios = []
        for i in self.stage_specifications:
            B = stages[i].partition.B_specification
            if B is None: continue 
            stage_index.append(i)
            phase_ratios.append(B)
        stage_index = np.array(stage_index, dtype=int)
        phase_ratios = np.array(phase_ratios, dtype=float)
        index = self._update_index
        top_feed_flows = 0 * self.feed_flows
        bottom_feed_flows = top_feed_flows.copy()
        top_flow_rates = top_feed_flows.copy()  # initial guess: all zeros
        for feed, stage in zip(self.ins, self.feed_stages):
            phases = feed.phases
            if len(phases) > 1 and 'g' in phases:
                # Vapor goes to the top and only the remainder to the bottom,
                # so the vapor is not counted twice.
                vapor = feed['g'].mol[index]
                top_feed_flows[stage, :] += vapor
                bottom_feed_flows[stage, :] += feed.mol[index] - vapor
            elif len(phases) > 1:
                bottom_feed_flows[stage, :] += feed['l'].mol[index]
            elif feed.phase == 'g':
                top_feed_flows[stage, :] += feed.mol[index]
            else:
                bottom_feed_flows[stage, :] += feed.mol[index]
        _, asplit_1, bsplit_1, N_stages = self._iter_args
        args = (
            phase_ratios, stage_index, top_feed_flows,
            bottom_feed_flows, asplit_1, bsplit_1, N_stages
        )
        top_flow_rates = flx.wegstein(
            self._hot_start_phase_ratios_iter,
            top_flow_rates, args=args, xtol=self.relative_S_tolerance,
            checkiter=False,
        )
        bottom_flow_rates = MESH.hot_start_bottom_flow_rates(
            top_flow_rates, *args
        )
        bf = bottom_flow_rates.sum(axis=1)
        bf[bf == 0] = 1e-32
        return top_flow_rates.sum(axis=1) / bf
    
    def hot_start_collapsed_stages(self,
            all_stages, feed_stages, stage_specifications,
            top_side_draws, bottom_side_draws,
        ):
        """Initialize stage variables from a reduced ("collapsed") column.

        A new ``MultiStageEquilibrium`` is built with only the stages in
        ``all_stages``. Feed stages, stage specifications and side draws are
        remapped to the collapsed numbering. The collapsed column is
        simulated, and its partition variables are copied onto the matching
        stages of this unit: temperature (including partition and stage
        outlet temperatures), ``B``, ``K``, ``gamma_y`` and ``fgas``. The
        remaining stages are filled by
        ``self.interpolate_missing_variables()``.

        Parameters
        ----------
        all_stages : set[int]
            Non-negative indices of the stages to keep.
        feed_stages : list[int]
        stage_specifications, top_side_draws, bottom_side_draws : dict
            Keyed by non-negative stage index in this unit's numbering.
        """
        # NOTE(review): no intermediate stages are inserted between
        # non-adjacent key stages; the collapsed column holds exactly the
        # stages in ``all_stages``.
        collapsed_indices = sorted(all_stages)
        N_stages = len(collapsed_indices)
        stage_map = {j: i for i, j in enumerate(collapsed_indices)}
        feed_stages = [stage_map[i] for i in feed_stages]
        stage_specifications = {stage_map[i]: j for i, j in stage_specifications.items()}
        top_side_draws = {stage_map[i]: j for i, j in top_side_draws.items()}
        bottom_side_draws = {stage_map[i]: j for i, j in bottom_side_draws.items()}
        self.collapsed = collapsed = MultiStageEquilibrium(
            '.collapsed', 
            ins=[i.copy() for i in self.ins],
            outs=[i.copy() for i in self.outs],
            N_stages=N_stages,
            feed_stages=feed_stages,
            stage_specifications=stage_specifications,
            phases=self.multi_stream.phases,
            top_side_draws=top_side_draws,
            bottom_side_draws=bottom_side_draws,  
            P=self.P, 
            partition_data=self.partition_data,
            top_chemical=self.top_chemical, 
            use_cache=self.use_cache,
            thermo=self.thermo
        )
        collapsed._run()
        collapsed_stages = collapsed.stages
        partitions = self.partitions
        stages = self.stages
        for i in range(self.N_stages):
            if i not in all_stages: continue
            collapsed_partition = collapsed_stages[stage_map[i]].partition
            partition = partitions[i]
            partition.T = collapsed_partition.T
            partition.B = collapsed_partition.B
            T = collapsed_partition.T
            for stream in partition.outs + stages[i].outs: stream.T = T
            partition.K = collapsed_partition.K
            partition.gamma_y = collapsed_partition.gamma_y
            partition.fgas = collapsed_partition.fgas
        self.interpolate_missing_variables()
                
    def hot_start(self):
        """Initialize stage variables and return initial separation factors.

        Setup:

        * All feeds are mixed into ``self.multi_stream``.
        * The chemicals to solve for are chosen (``self._IDs``); chemicals
          involved in stage reactions are added.
        * Feed flows and enthalpies by stage are built.

        Each partition's temperature, phase ratio ``B`` and partition
        coefficients ``K`` are then initialized by one of:

        * keeping the last values, when ``use_cache`` is set and every
          partition already has the same ``IDs``;
        * simulating a collapsed column of the key stages, when
          ``collapsed_init`` is set and not every stage is a key stage;
        * user-given ``'K'`` in ``partition_data``;
        * an LLE flash of the mixed feed;
        * VLE: linear temperature/K profiles between bubble and dew points
          when there are stage specifications, otherwise a single flash.

        Finally, outlet flows of the chemicals assigned wholly to the top
        phase (extract/vapor/light) or the bottom phase
        (raffinate/liquid/heavy) are set with ``MESH.solve_RBDMA`` /
        ``MESH.solve_LBDMA``.

        Returns
        -------
        1d array
            Separation factors ``S`` of the stages in ``self._S_stages``.
        """
        ms = self.multi_stream
        feeds = self.ins
        feed_stages = self.feed_stages
        stages = self.stages
        partitions = self.partitions
        N_stages = self.N_stages
        chemicals = self.chemicals
        top_phase, bottom_phase = ms.phases
        eq = 'vle' if top_phase == 'g' else 'lle'
        ms.mix_from(feeds)
        ms.P = self.P
        if eq == 'lle':
            self.top_chemical = top_chemical = self.top_chemical or feeds[1].main_chemical
            for i in partitions: i.top_chemical = top_chemical
        # ``partition_data`` may be None; use an empty mapping so the
        # ``'IDs' in data`` checks below do not raise TypeError.
        data = self.partition_data or {}
        if data:
            # NOTE(review): these lists come straight from partition_data and
            # are extended in place, so the caller's data is modified.
            top_chemicals = data.get('extract_chemicals') or data.get('vapor_chemicals', [])
            bottom_chemicals = data.get('raffinate_chemicals') or data.get('liquid_chemicals', [])
            for i in chemicals.light_chemicals:
                i = i.ID
                if i in top_chemicals or i in bottom_chemicals: continue
                top_chemicals.append(i)
            for i in chemicals.heavy_chemicals:
                i = i.ID
                if i in top_chemicals or i in bottom_chemicals: continue
                bottom_chemicals.append(i)
        else:
            top_chemicals = [i.ID for i in chemicals.light_chemicals]
            bottom_chemicals = [i.ID for i in chemicals.heavy_chemicals]
        if eq == 'lle':
            IDs = data['IDs'] if 'IDs' in data else [i.ID for i in ms.lle_chemicals]
        else:
            IDs = data['IDs'] if 'IDs' in data else [i.ID for i in ms.vle_chemicals]
        if self.stage_reactions:
            # Chemicals taking part in stage reactions must be solved for.
            nonzero = set()
            for rxn in self.stage_reactions.values():
                nonzero.update(rxn.stoichiometry.nonzero_keys())
            all_IDs = set(IDs)
            for i in nonzero:
                ID = chemicals.IDs[i]
                if ID not in all_IDs:
                    IDs.append(ID)
        self._IDs = IDs = tuple(IDs)
        self._N_chemicals = N_chemicals = len(IDs)
        # Stages with a zero phase-ratio specification are excluded from the
        # separation-factor variables.
        self._S_stages = [i for i in stages if i.B_specification != 0]
        self._NS_stages = len(self._S_stages)
        self._update_index = index = ms.chemicals.get_index(IDs)
        self.feed_flows = feed_flows = np.zeros([N_stages, N_chemicals])
        self.feed_enthalpies = feed_enthalpies = np.zeros(N_stages)
        for feed, stage in zip(feeds, feed_stages):
            feed_flows[stage, :] += feed.mol[index]
            feed_enthalpies[stage] += feed.H
        self.total_feed_flows = feed_flows.sum(axis=1)
        self._iter_args = (feed_flows, self._asplit_1, self._bsplit_1, self.N_stages)
        # Normalize negative (from-the-end) stage indices.
        feed_stages = [(i if i >= 0 else N_stages + i) for i in self.feed_stages]
        stage_specifications = {(i if i >= 0 else N_stages + i): j for i, j in self.stage_specifications.items()}
        top_side_draws = {(i if i >= 0 else N_stages + i): j for i, j in self.top_side_draws.items()}
        bottom_side_draws = {(i if i >= 0 else N_stages + i): j for i, j in self.bottom_side_draws.items()}
        self.key_stages = key_stages = set([*feed_stages, *stage_specifications, *top_side_draws, *bottom_side_draws])
        if (self.use_cache 
            and all([i.IDs == IDs for i in partitions])): # Use last set of data
            pass
        elif self.collapsed_init and len(key_stages) != self.N_stages:
            self.hot_start_collapsed_stages(
                key_stages, feed_stages, stage_specifications,
                top_side_draws, bottom_side_draws,
            )
        else:
            if data and 'K' in data: 
                top, bottom = ms
                K = data['K']
                phi = data.get('phi', 0.5)
                if K.ndim == 2:
                    data['phi'] = phi = sep.partition(
                        ms, top, bottom, IDs, K.mean(axis=0), phi,
                        top_chemicals, bottom_chemicals
                    )
                    B = inf if phi == 1 else phi / (1 - phi)
                    T = ms.T
                    for i, Ki in zip(partitions, K): 
                        if i.B_specification is None: i.B = B
                        i.T = T
                        i.K = Ki
                else:
                    data['phi'] = phi = sep.partition(ms, top, bottom, IDs, K, phi,
                                                      top_chemicals, bottom_chemicals)
                    B = inf if phi == 1 else phi / (1 - phi)
                    T = ms.T
                    for i in partitions: 
                        if i.B_specification is None: i.B = B
                        i.T = T
                        i.K = K
            elif eq == 'lle':
                lle = ms.lle
                T = ms.T
                lle(T, top_chemical=top_chemical)
                K = lle._K
                phi = lle._phi
                B = inf if phi == 1 else phi / (1 - phi)
                y = ms.imol['L', IDs]
                y /= y.sum()
                f_gamma = self.thermo.Gamma([chemicals[i] for i in IDs])
                gamma_y = f_gamma(y, T)
                for i in partitions: 
                    i.B = B
                    i.T = T
                    i.K = K
                    i.gamma_y = gamma_y
                    for j in i.outs: j.T = T
            else:
                P = self.P
                if self.stage_specifications:
                    dp = ms.dew_point_at_P(P=P, IDs=IDs)
                    T_bot = dp.T
                    bp = ms.bubble_point_at_P(P=P, IDs=IDs)
                    T_top = bp.T
                    dT_stage = (T_bot - T_top) / N_stages
                    phase_ratios = self.hot_start_phase_ratios()
                    # Replace zeros in place to avoid division by zero in
                    # the K estimates below.
                    z = bp.z
                    z[z == 0] = 1.
                    x = dp.x
                    x[x == 0] = 1.
                    K_dew = dp.z / dp.x
                    K_bubble = bp.y / bp.z
                    dK_stage = (K_bubble - K_dew) / N_stages
                    # NOTE(review): T runs from the bubble point at the top
                    # to the dew point at the bottom, while K runs from K_dew
                    # at the top toward K_bubble — confirm the K direction.
                    for i, B in enumerate(phase_ratios):
                        partition = partitions[i]
                        if partition.B_specification is None: partition.B = B
                        partition.T = T = T_top + i * dT_stage
                        partition.K = K_dew + i * dK_stage
                        for s in partition.outs: s.T = T
                else:
                    vle = ms.vle
                    vle(H=ms.H, P=P)
                    L_mol = ms.imol['l', IDs]
                    L_mol_net = L_mol.sum()
                    if L_mol_net: x_mol = L_mol / L_mol_net
                    else: x_mol = np.ones(N_chemicals, float) / N_chemicals
                    V_mol = ms.imol['g', IDs]
                    V_mol_net = V_mol.sum()
                    # All-liquid flash: avoid 0/0 in the vapor composition.
                    if V_mol_net: y_mol = V_mol / V_mol_net
                    else: y_mol = np.ones(N_chemicals, float) / N_chemicals
                    K = y_mol / x_mol
                    phi = ms.V
                    # All-vapor flash: phase ratio is infinite, not 1/0.
                    B = inf if phi == 1 else phi / (1 - phi)
                    T = ms.T
                    for partition in partitions:
                        partition.T = T
                        partition.B = B
                        for i in partition.outs: i.T = T
                        partition.K = K
                        partition.fgas = P * y_mol
                        for s in partition.outs: s.empty()
        if top_chemicals:
            n = len(top_chemicals)
            b = np.ones([N_stages, n])
            c = self._asplit_1[1:]
            d = np.zeros([N_stages, n])
            for feed, stage in zip(feeds, feed_stages):
                d[stage] += feed.imol[top_chemicals]
            top_flow_rates = MESH.solve_RBDMA(b, c, d)
            for partition, flows in zip(partitions, top_flow_rates):
                partition.outs[0].imol[top_chemicals] = flows
        if bottom_chemicals:
            a = self._bsplit_1[:-1]
            n = len(bottom_chemicals)
            b = np.ones([N_stages, n])
            d = np.zeros([N_stages, n])
            for feed, stage in zip(feeds, feed_stages):
                d[stage] += feed.imol[bottom_chemicals]
            bottom_flow_rates = MESH.solve_LBDMA(a, b, d)
            for partition, flows in zip(partitions, bottom_flow_rates):
                partition.outs[1].imol[bottom_chemicals] = flows
        if top_chemicals or bottom_chemicals:
            for i in stages:
                for s in i.splitters: s._run()
        for i in partitions: i.IDs = IDs
        self.interpolate_missing_variables()
        for i in self.stages: i._init_separation_factors()
        return np.array([i.S for i in self._S_stages])
    
    def get_energy_balance_temperature_departures(self):
        """Return stage temperature departures that satisfy energy balances.

        Each partition's top/bottom outlet enthalpies (``H``) and heat
        capacities (``C``) are passed to ``MESH.temperature_departures``.

        * No temperature specifications: all stages are solved together.
        * Otherwise: the column is split into segments. Each segment ends at
          (and includes) a temperature-specified stage, and any stages after
          the last specified stage form a final segment. Specified stages get
          a departure of zero because their temperature is fixed.

        Returns
        -------
        1d array
            Temperature departure by stage.
        """
        partitions = self.partitions
        N_stages = self.N_stages
        Cl = np.zeros(N_stages)
        Cv = Cl.copy()
        Hv = Cl.copy()
        Hl = Cl.copy()
        if all([i.T_specification is None for i in partitions]):
            for i, j in enumerate(partitions):
                top, bottom = j.outs
                Hl[i] = bottom.H
                Hv[i] = top.H
                Cl[i] = bottom.C
                Cv[i] = top.C
            return MESH.temperature_departures(
                Cv, Cl, Hv, Hl, self._asplit_left, self._bsplit_left,
                N_stages, self.feed_enthalpies
            )
        dTs = np.zeros(N_stages)

        def solve_segment(start, end):
            # Solve the energy balances of stages start..end-1 together.
            index = slice(start, end)
            dTs[index] = MESH.temperature_departures(
                Cv[index], Cl[index], Hv[index], Hl[index],
                self._asplit_left[index],
                self._bsplit_left[index],
                end - start, self.feed_enthalpies[index],
            )

        start = 0
        for i, p in enumerate(partitions):
            if p.T_specification is None:
                top, bottom = p.outs
                Hl[i] = bottom.H
                Hv[i] = top.H
                Cl[i] = bottom.C
                Cv[i] = top.C
            else:
                end = i + 1
                solve_segment(start, end)
                dTs[i] = 0.  # temperature fixed by specification
                start = end
        if start < N_stages:
            # Stages after the last temperature-specified stage.
            solve_segment(start, N_stages)
        return dTs
    
    def get_energy_balance_phase_ratio_departures(self):
        """
        Return the phase-ratio (B) departure of each stage that closes the
        stage energy balances, as computed by ``MESH.phase_ratio_departures``.

        The balance is of the form::

            hV1*L1*dB1 - hv2*L2*dB2 = Q1 + H_in - H_out

        Molar enthalpies come from each partition's outlets. When one outlet
        is empty, the other outlet is temporarily switched to the missing
        phase to estimate it. Stages with both outlets empty get their
        enthalpies filled from neighboring stages. For stages with reactions,
        the feed enthalpy is corrected by ``ins[0].Hf`` minus the outlets'
        ``Hf``.

        """
        partitions = self.partitions
        N_stages = self.N_stages
        L = np.zeros(N_stages)
        V = np.zeros(N_stages)
        hv = np.zeros(N_stages)
        hl = np.zeros(N_stages)
        specified = []
        empty = []
        for n, partition in enumerate(partitions):
            vapor, liquid = partition.outs
            L_n = liquid.F_mol
            V_n = vapor.F_mol
            L[n] = L_n
            V[n] = V_n
            is_specified = (
                partition.B_specification is not None
                or partition.T_specification is not None
            )
            if V_n == 0 and L_n == 0:
                # Nothing to evaluate; filled from neighbors below.
                hv[n] = np.nan
                hl[n] = np.nan
                if is_specified: specified.append(n)
                empty.append(n)
                continue
            if V_n == 0:
                liquid.phase = 'g'
                hv[n] = liquid.h
                liquid.phase = 'l'
            else:
                hv[n] = vapor.h
            if L_n == 0:
                vapor.phase = 'l'
                hl[n] = vapor.h
                vapor.phase = 'g'
            else:
                hl[n] = liquid.h
            if is_specified: specified.append(n)
        if empty:
            neighbors = MESH.get_neighbors(missing=empty, size=N_stages)
            hv = MESH.fillmissing(neighbors, hv)
            hl = MESH.fillmissing(neighbors, hl)
        feed_enthalpies = self.feed_enthalpies
        if self.stage_reactions:
            # Correct feed enthalpies by the formation-enthalpy change.
            feed_enthalpies = feed_enthalpies.copy()
            for n in self.stage_reactions:
                partition = partitions[n]
                feed_enthalpies[n] += partition.ins[0].Hf - sum([s.Hf for s in partition.outs])
        return MESH.phase_ratio_departures(
            L, V, hl, hv,
            self._asplit_1,
            self._asplit_left,
            self._bsplit_left,
            N_stages,
            np.array(specified, dtype=int),
            feed_enthalpies,
        )
        
    def update_energy_balance_phase_ratio_departures(self):
        """
        Update the phase ratio (B) of every partition without a B
        specification, using the departures returned by
        ``get_energy_balance_phase_ratio_departures``.

        All departures are scaled by a common factor ``f <= 1``. It is chosen
        so that, for each updated partition with a negative departure, the
        full step covers at most 99% of the distance from its current B to
        zero. Partitions with a B specification are never updated, so they do
        not restrict the step. Each applied step is further multiplied by
        ``1 - B_relaxation_factor``.

        """
        dBs = self.get_energy_balance_phase_ratio_departures()
        partitions = self.partitions
        f = 1
        for partition, dB in zip(partitions, dBs):
            if dB >= 0 or partition.B_specification is not None: continue
            f = min(-0.99 * partition.B / dB, f)
        dBs *= f
        for partition, dB in zip(partitions, dBs):
            if partition.B_specification is None:
                partition.B += (1 - partition.B_relaxation_factor) * dB
        if getattr(self, 'tracking', False):
            self._collect_variables('energy')
    
    def update_energy_balance_temperatures(self):
        """
        Shift the temperature of every partition without a temperature
        specification, and of its outlets, by its relaxed energy-balance
        departure.

        Returns
        -------
        numpy.ndarray
            The unrelaxed departures from
            ``get_energy_balance_temperature_departures``.

        """
        departures = self.get_energy_balance_temperature_departures()
        for partition, departure in zip(self.partitions, departures):
            if partition.T_specification is not None: continue
            step = (1 - partition.T_relaxation_factor) * departure
            partition.T += step
            for stream in partition.outs:
                stream.T += step
        if getattr(self, 'tracking', False):
            self._collect_variables('energy')
        return departures
       
    def run_mass_balance(self):
        """
        Return the bottom flow rates from ``MESH.bottom_flow_rates`` given the
        current separation factors (S) of all stages.

        With stage reactions, the feed flows are taken from
        ``_feed_flows_and_conversion`` instead of ``_iter_args``.

        """
        separation_factors = np.array([stage.S for stage in self.stages])
        feeds, *remaining_args = self._iter_args
        if self.stage_reactions:
            feeds = self._feed_flows_and_conversion()
        return MESH.bottom_flow_rates(separation_factors, feeds, *remaining_args)
       
    def update_mass_balance(self):
        """Solve the mass balances and store the resulting flow rates."""
        flow_rates = self.run_mass_balance()
        self.set_flow_rates(flow_rates)
        
    def interpolate_missing_variables(self):
        """
        Fill in missing equilibrium variables of the stage partitions.

        A stage is complete when its partition has a phase ratio ``B`` and
        partition coefficients ``K`` with one entry per chemical. If some
        stages are incomplete, every stage is reset with new values of
        ``T``, ``B`` (unless B is specified), ``K`` and, for pseudo-LLE
        without user-given K, ``gamma_y``:

        - with two or more complete stages, the values are interpolated via
          ``MESH.fillmissing``;
        - otherwise they are copied from the single complete stage.

        Outlet temperatures are set to the stage temperature. Nothing changes
        when all stages are complete.

        Raises
        ------
        RuntimeError
            If no stage is complete.

        """
        stages = self.stages
        lle = self._has_lle and 'K' not in self.partition_data
        partitions = [stage.partition for stage in stages]
        N_stages = self.N_stages
        N_chemicals = self._N_chemicals
        complete = []
        Bs = []
        Ks = []
        Ts = []
        gammas = []
        for n in range(N_stages):
            partition = partitions[n]
            B = partition.B
            T = partition.T
            K = partition.K
            if B is None or K is None or K.size != N_chemicals: continue
            complete.append(n)
            Bs.append(B)
            Ks.append(K)
            Ts.append(T)
            if lle: gammas.append(partition.gamma_y)
        N_complete = len(complete)
        if N_complete == N_stages: return
        if N_complete == 0: raise RuntimeError('no phase equilibrium')
        if N_complete == 1:
            Bs = np.array(N_stages * Bs)
            Ks = np.array(N_stages * Ks)
            Ts = np.array(N_stages * Ts)
            if lle: gammas = np.array(N_stages * gammas)
        else:
            neighbors = MESH.get_neighbors(complete, size=N_stages)

            def fill(values):
                # Expand complete-stage values to all stages, then fill gaps.
                return MESH.fillmissing(
                    neighbors, MESH.expand(values, complete, N_stages)
                )

            Bs = fill(Bs)
            Ts = fill(Ts)
            filled_Ks = np.zeros([N_stages, N_chemicals])
            if lle: filled_gammas = filled_Ks.copy()
            for k in range(N_chemicals):
                filled_Ks[:, k] = fill([K[k] for K in Ks])
                if lle: filled_gammas[:, k] = fill([g[k] for g in gammas])
            if lle: gammas = filled_gammas
            Ks = filled_Ks
        for n, stage in enumerate(stages):
            partition = stage.partition
            T = Ts[n]
            partition.T = T
            for stream in partition.outs: stream.T = T
            if partition.B_specification is None: partition.B = Bs[n]
            partition.K = Ks[n]
            if lle: partition.gamma_y = gammas[n]
    
    def _lnS_objective(self, lnS):
        self.iter += 1
        S_original = np.exp(lnS)
        S = S_original.reshape([self._NS_stages, self._N_chemicals])
        self.update_flow_rates(S)
        stages = self.stages
        P = self.P
        if self._has_vle: 
            for i in stages:
                mixer = i.mixer
                partition = i.partition
                mixer.outs[0].mix_from(
                    mixer.ins, energy_balance=False,
                )
                mixer.outs[0].P = P
                partition._run_decoupled_KTvle(P=P)
                if (partition.T_specification is None
                    and partition.B_specification is None):
                    partition._run_decoupled_B()
                T = partition.T
                for i in (partition.outs + i.outs): i.T = T
            energy_error = lambda stage: abs(
                (sum([i.H for i in stage.outs]) - sum([i.H for i in stage.ins], stage.Q)) / sum([i.C for i in stage.outs])
            )
            total_energy_error = sum([energy_error(stage) for stage in stages if stage.Q is not None])
        else:
            for i in range(5):
                dTs = self.update_energy_balance_temperatures()
                if np.abs(dTs).sum() < 1e-12: break
            for i in self.stages: 
                mixer = i.mixer
                partition = i.partition
                mixer.outs[0].mix_from(
                    mixer.ins, energy_balance=False,
                )
                partition._run_lle(update=False, P=P)
        S_new = np.array([(i.K * i.B) for i in self._S_stages]).flatten()
        splits_new = 1 / (S_new + 1)
        splits = 1 / (S_original + 1)
        diff = splits_new - splits
        total_split_error = np.abs(diff).sum()
        if self._has_vle:
            total_error = total_split_error + total_energy_error
        else:
            total_error = total_split_error
        total_error /= self.N_stages
        err = np.sqrt(total_error)
        return err
    
    def update_bubble_point(self):
        """
        Recompute K and the bubble-point temperature of every stage, and
        copy that temperature to the partition and stage outlets.

        With stage reactions, the liquid holdup is updated first and
        partitions with a reaction also run their decoupled reaction.

        """
        stages = self.stages
        P = self.P
        reactive = bool(self.stage_reactions)
        if reactive:
            self.update_liquid_holdup() # Finds liquid volume at each stage
        for stage in stages:
            partition = stage.partition
            partition._run_decoupled_KTvle(P=P)
            T = partition.T
            for stream in partition.outs + stage.outs:
                stream.T = T
            if reactive and partition.reaction:
                partition._run_decoupled_reaction(P=P)
        if getattr(self, 'tracking', False):
            self._collect_variables('vle_phenomena')
    
    def lle_gibbs(self):
        """Return the sum of ``lle_gibbs()`` over all stage partitions."""
        total = 0
        for stage in self.stages:
            total += stage.partition.lle_gibbs()
        return total
    
    def update_pseudo_lle(self, separation_factors=None):
        """
        Update stages for liquid-liquid equilibrium.

        If partition coefficients are given in ``partition_data`` (key
        ``'K'``), the flow rates are only updated from the given separation
        factors, if any. Otherwise, the separation factors of ``_S_stages``
        are converged by fixed-point iteration over flow-rate and K/gamma
        updates, and the final flow rates are then set with B updated.

        Afterwards, each affected stage re-mixes its feeds (without energy
        balance), recomputes its phase ratio and refreshes its separation
        factors.

        """
        stages = self.stages
        P = self.P
        if 'K' not in self.partition_data:
            stages = self._S_stages
            if separation_factors is None:
                separation_factors = np.array([stage.S for stage in stages])

            def pseudo_equilibrium(S):
                self.update_flow_rates(S, update_B=False)
                for stage in stages:
                    stage.partition._run_decoupled_Kgamma(P=P)
                    stage._update_separation_factors()
                return np.array([stage.S for stage in stages])

            separation_factors = flx.fixed_point(
                pseudo_equilibrium, separation_factors,
                xtol=self.S_tolerance,
                rtol=self.relative_S_tolerance,
                checkiter=False,
                checkconvergence=False,
            )
            self.update_flow_rates(separation_factors, update_B=True)
            for stage in stages: stage._update_separation_factors()
        elif separation_factors is not None:
            self.update_flow_rates(separation_factors, update_B=False)
            for stage in stages: stage._update_separation_factors()
            if getattr(self, 'tracking', False):
                self._collect_variables('material')
        for stage in stages:
            mixer = stage.mixer
            mixer.outs[0].mix_from(mixer.ins, energy_balance=False)
            stage.partition._run_decoupled_B()
            stage._update_separation_factors()
        if getattr(self, 'tracking', False):
            self._collect_variables('lle_phenomena')
    
    def update_flow_rates(self, separation_factors, update_B=True):
        """
        Set the separation factors (S) of ``_S_stages``, solve the mass
        balances and store the resulting flow rates.

        If any separation factor is negative, the new factors are blended
        with the current ones as ``S_old + x * (S - S_old)``. Here ``x`` is
        10% of the smallest step in (0, 1) at which some entry would reach
        zero.

        Parameters
        ----------
        separation_factors : numpy.ndarray
            New separation factors, one row per stage in ``_S_stages``.
        update_B : bool, optional
            Passed through to ``set_flow_rates``.

        Raises
        ------
        ValueError
            If negative separation factors are given but no blending step in
            (0, 1) exists.

        """
        if separation_factors.min() < 0:
            S = separation_factors
            S_old = np.array([i.S for i in self._S_stages])
            # S_old + x * (S - S_old) = 0  ->  x = S_old / (S_old - S).
            # Entries with S == S_old have no zero crossing and would emit
            # divide-by-zero/invalid warnings; the (0, 1) filter drops them.
            with np.errstate(divide='ignore', invalid='ignore'):
                x = S_old / (S_old - S)
                x = x[(x < 1) & (x > 0)]
            if x.size == 0:
                raise ValueError(
                    'negative separation factors cannot be blended with the '
                    'current separation factors to remain non-negative'
                )
            x = 0.1 * x.min()
            separation_factors = S * x + (1 - x) * S_old
        for stage, S in zip(self._S_stages, separation_factors): stage.S = S
        flows = self.run_mass_balance()
        self.set_flow_rates(flows, update_B)
        if getattr(self, 'tracking', False):
            self._collect_variables('material')
    
    def update_pseudo_vle(self, separation_factors):
        """
        Update stages for vapor-liquid equilibrium (sum-rates decomposition).

        If partition coefficients are given in ``partition_data`` (key
        ``'K'``), only the flow rates are updated, and only when separation
        factors are given. Otherwise, the flow rates are set (defaulting to
        the current factors of ``_S_stages``). Each stage then recomputes K
        and refreshes its separation factors.

        Raises
        ------
        NotImplementedError
            If ``tracking`` is enabled.

        """
        P = self.P
        if 'K' not in self.partition_data:
            if separation_factors is None:
                separation_factors = np.array([stage.S for stage in self._S_stages])
            self.update_flow_rates(separation_factors, update_B=True)
            for stage in self.stages:
                stage.partition._run_decoupled_Kfgas(P=P)
                stage._update_separation_factors()
        elif separation_factors is not None:
            self.update_flow_rates(separation_factors)
            if getattr(self, 'tracking', False):
                self._collect_variables('material')
        if getattr(self, 'tracking', False):
            raise NotImplementedError('tracking sum rates decomposition not implemented')
    
    def _phenomena_iter(self, separation_factors):
        if self._has_vle:
            decomp = self.vle_decomposition
            if decomp == 'bubble point':
                self.update_flow_rates(separation_factors, update_B=True)
                self.update_bubble_point()
                self.update_energy_balance_phase_ratio_departures()
                for i in self.stages: i._update_separation_factors()
            elif decomp == 'sum rates':
                self.update_pseudo_vle(separation_factors)
                self.update_energy_balance_temperatures()
            else:
                raise NotImplementedError(f'{decomp!r} decomposition not implemented')
        elif self._has_lle:
            self.update_pseudo_lle(separation_factors)
            self.update_energy_balance_temperatures()
        else:
            raise RuntimeError('unknown equilibrium phenomena')
        if not hasattr(self, 'deltas'): self.deltas = []
        x = lambda stream: np.array([*stream.mol, stream.H])
        self.deltas.append(
            [x(i.outs[0]) - sum([x(i) for i in i.ins if i.phase=='g']) for i in self.partitions]
        )
        return np.array([i.S for i in self._S_stages])
    
    def _inside_out_iter(self, separation_factors):
        self.iter += 1
        self.inside_iter = 0
        separation_factors = self._iter(separation_factors)
        if self._update_surrogate_models():
            return flx.fixed_point(
                self._surrogate_iter, 
                separation_factors, 
                maxiter=self.inside_maxiter,
                xtol=self.S_tolerance,
                rtol=self.relative_S_tolerance, 
                checkconvergence=False,
                checkiter=False,
            )
        else:
            return separation_factors
    
    def _update_surrogate_models(self):
        partitions = self.partitions
        mixture = self.mixture
        range_stages = range(self.N_stages)
        P = self.P
        Ts = np.array([i.T for i in partitions])
        for i in range(self.N_stages-1):
            if Ts[i+1] - Ts[i] < 0: return False
        xs = np.array([i.x for i in partitions])
        ys = np.array([i.y for i in partitions])
        Ks = np.array([i.K for i in partitions])
        hVs = np.array([
            mixture.H(mol=ys[i], phase='g', T=Ts[i], P=P)
            for i in range_stages
        ])
        hLs = np.array([
            mixture.H(mol=xs[i], phase='l', T=Ts[i], P=P)
            for i in range_stages
        ])
        self.Tlb = Ts.min()
        self.Tub = Ts.max()
        self.T_surrogate = self.TSurrogate(xs[:, :-1], Ts, **self.T_surrogate_options)
        self.K_surrogate = self.KSurrogate(Ts, Ks, **self.K_surrogate_options)
        self.hV_surrogate = self.hSurrogate(Ts, hVs, **self.h_surrogate_options)
        self.hL_surrogate = self.hSurrogate(Ts, hLs, **self.h_surrogate_options)
        # import matplotlib.pyplot as plt
        # xs = np.linspace(self.Tlb, self.Tub, num=100)
        # fig, ax = plt.subplots()
        # ax.plot(Ts, hVs, "o", label="data")
        # ax.plot(xs, self.hV_surrogate(xs), label='surrogate')
        # ax.legend()
        # fig.show()
        return True
    
    def _sequential_iter(self, separation_factors):
        self.update_flow_rates(separation_factors)
        for i in self.stages: i._run()
        for i in reversed(self.stages): i._run()
        return np.array([i.S for i in self._S_stages])
    
    def _phenomena_surrogate_iter(self, separation_factors):
        """
        Run one inner iteration of the inside-out algorithm on the surrogate
        models built by ``_update_surrogate_models``.

        Steps:

        1. Set the separation factors of ``_S_stages`` and solve the mass
           balances for bottom and top flow rates.
        2. Zero negative component flows while keeping each row's total flow.
        3. Estimate temperatures from the surrogates, clipped to
           [``Tlb``, ``Tub``], or use the stage temperature when T is
           specified. Also estimate K and molar enthalpies.
        4. Take a damped step (10% of the departure from
           ``MESH.phase_ratio_departures``) on the phase ratios of stages
           without B or T specifications.

        Returns
        -------
        numpy.ndarray
            New separation factors (B * K) for stages with nonzero B. Each
            row is also stored on the matching stage in ``_S_stages``.

        """
        feed_flows, *args = self._iter_args
        stages = self.stages
        S_stages = self._S_stages
        for stage, S_stage in zip(S_stages, separation_factors): stage.S = S_stage
        separation_factors = np.array([i.S for i in stages])
        bottom_flows = MESH.bottom_flow_rates(separation_factors, feed_flows, *args)
        N_stages = self.N_stages
        top_flows = MESH.top_flows_mass_balance(
            bottom_flows, feed_flows, self._asplit_left, self._bsplit_left, N_stages
        )

        def remove_negative_flows(flows):
            # Zero negative entries in place and rescale the rest so the
            # row's total flow is unchanged (when the remainder is nonzero).
            mask = flows < 0
            if mask.any():
                total = flows.sum()
                flows[mask] = 0
                positive_total = flows.sum()
                if positive_total: flows *= total / positive_total

        for i in range(N_stages):
            remove_negative_flows(top_flows[i])
            remove_negative_flows(bottom_flows[i])
        V = top_flows.sum(axis=1)
        L = bottom_flows.sum(axis=1, keepdims=True)
        L[L == 0] = 1 # Avoid division by zero in liquid compositions
        xs = bottom_flows / L
        L = L[:, 0]
        Ts = self.T_surrogate(xs[:, :-1])
        Ts[Ts < self.Tlb] = self.Tlb
        Ts[Ts > self.Tub] = self.Tub
        for i, j in enumerate(stages):
            if j.T_specification: Ts[i] = j.T
        Ks = self.K_surrogate(Ts)
        hV = self.hV_surrogate(Ts)
        hL = self.hL_surrogate(Ts)
        specification_index = np.array([
            i for i, j in enumerate(stages)
            if j.B_specification is not None or j.T_specification is not None
        ], dtype=int)
        dB = MESH.phase_ratio_departures(
            L, V, hL, hV,
            self._asplit_1,
            self._asplit_left,
            self._bsplit_left,
            N_stages,
            specification_index,
            self.feed_enthalpies,
        )
        Bs = np.array([i.B for i in stages])
        Bs_spec = Bs[specification_index]
        Bs += dB * 0.1
        Bs[specification_index] = Bs_spec # Specified stages keep their B
        for i, j in zip(stages, Bs): i.B = j
        mask = Bs != 0
        S = Bs[mask][:, None] * Ks[mask]
        # Each stage receives its own row of separation factors.
        for stage, S_stage in zip(S_stages, S): stage.S = S_stage
        return S
    
    def _surrogate_iter(self, separation_factors=None):
        self.inside_iter += 1
        algorithm = self.algorithm
        if algorithm == 'phenomena':
            separation_factors = self._phenomena_surrogate_iter(separation_factors)
        else:
            raise RuntimeError(f'invalid algorithm {algorithm!r} for surrogate model')
        return separation_factors
    
    def _iter(self, separation_factors=None):
        self.iter += 1
        algorithm = self.algorithm
        if algorithm == 'phenomena':
            separation_factors = self._phenomena_iter(separation_factors)
        elif algorithm == 'sequential modular':
            separation_factors = self._sequential_iter(separation_factors)
        else:
            raise RuntimeError(f'invalid algorithm {algorithm!r}')
        return separation_factors
        # if self.inside_out and self._has_vle:
        #     raise NotImplementedError('inside-out algorithm not implemented in BioSTEAM (yet)')
        #     self.update_mass_balance()
        #     N_stages = self.N_stages
        #     N_chemicals = self._N_chemicals
        #     T = np.zeros(N_stages)
        #     B = np.zeros(N_stages)
        #     K = np.zeros([N_stages, N_chemicals])
        #     hv = T.copy()
        #     hl = T.copy()
        #     specification_index = []
        #     for i, j in enumerate(self.partitions):
        #         top, bottom = j.outs
        #         T[i] = j.T
        #         B[i] = j.B
        #         K[i] = j.K
        #         if bottom.isempty():
        #             top.phase = 'l'
        #             hl[i] = top.h
        #             top.phase = 'g'
        #         else:
        #             hl[i] = bottom.h
        #         if top.isempty():
        #             bottom.phase = 'g'
        #             hv[i] = bottom.h
        #             bottom.phase = 'l'
        #         else:
        #             hv[i] = top.h
        #         if j.B_specification is not None or j.T_specification: specification_index.append(i)
        #     S, safe = MESH.bottoms_stripping_factors_safe(B, K)
        #     top_flows = MESH.estimate_top_flow_rates(
        #         S, 
        #         self.feed_flows,
        #         self._asplit_1,
        #         self._bsplit_1,
        #         N_stages,
        #         safe,
        #     )
        #     if safe:
        #         S_index = None
        #         lnS = np.exp(S).flatten()
        #     else:
        #         S_index = [
        #             i for i, j in enumerate(self.partitions)
        #             if not (j.B_specification == 0 or j.B_specification == np.inf)
        #         ]
        #         S_init = S[S_index]
        #         lnS = np.exp(S_init).flatten()
        #     result = MESH.solve_inside_loop(
        #         lnS, S, S_index, top_flows, K, B, T, hv, hl, self.feed_flows,
        #         self._asplit_1, self._bsplit_1, 
        #         self._asplit_left, self._bsplit_left,
        #         N_stages, np.array(specification_index, int),
        #         N_chemicals,
        #         self.feed_enthalpies
        #     )
        #     if safe:
        #         S = np.exp(result.x).reshape([N_stages, N_chemicals])
        #     else:
        #         S[S_index] = np.exp(result.x).reshape([len(S_index), N_chemicals])
