import numpy as np

"""
This module contains production functions and related calculations.
"""

class CobbDouglas:
    """
    The Cobb-Douglas production function

        Y = beta * X_1**alpha_1 * X_2**alpha_2 * ... * X_N**alpha_N

    ``output`` / ``utility`` accept any number of input factors (one per
    elasticity in ``self.alphas``). The remaining methods (``optimal_endowment``,
    ``mrts``, ``isoquant``, ``isocost``, ``expansion_path``) are two-factor
    formulas in capital K (factor 1, elasticity ``alpha``) and labor L
    (factor 2, elasticity ``1 - alpha``); they only read ``self.alpha`` and
    ignore any additional factors.
    """

    def __init__(self, beta, alpha):
        """Creates a Cobb-Douglas production function with multiple input factors.

        Args:
            beta (float): scale parameter
            alpha (float or list of float): either a single scalar elasticity for
                factor 1 (factor 2 then gets ``1 - alpha``), or a list of
                elasticity parameters, one per input factor, whose sum must be 1
                (within a tolerance of 0.01).

        Raises:
            AssertionError: if a list of elasticities does not sum up to 1.
        """
        self.beta = beta
        self.alphas = alpha

        # np.ndim(...) == 0 treats every scalar as a scalar, including NumPy
        # scalars (e.g. np.float32, np.int64) that are not subclasses of the
        # built-in float/int types.
        if np.ndim(alpha) == 0:
            self.alphas = [alpha, 1 - alpha]
            self.alpha = alpha
        else:
            # Explicit raise instead of `assert` so the check is not stripped
            # under `python -O`; AssertionError is kept for existing callers.
            # Written as `not (... < 0.01)` so a NaN sum is rejected as before.
            if not abs(np.sum(self.alphas) - 1) < 0.01:
                raise AssertionError("alpha_1 + alpha_2 + ... + alpha_N must sum up to 1")
            self.alpha = self.alphas[0]

    def utility(self, args):
        """Utility level of a bundle of goods; identical to ``output``."""
        return self.output(args)

    def output(self, args, min_input=0.001):
        """
        Output function given multiple input factors.

        Args:
            args (list of float): list of input factors (e.g., [K, L, X1, X2, ...]);
                elements may also be NumPy arrays of a common/broadcastable shape.
            min_input (float): minimum allowed input quantity for improved numerical stability
        Returns:
            float: production output (an array if array inputs are given)
        Raises:
            ValueError: if the number of inputs differs from the number of elasticities.
        """
        if len(args) != len(self.alphas):
            raise ValueError("The number of inputs must match the number of elasticity parameters")

        # Generalized output formula: beta * prod(x_i ** alpha_i)
        output_value = self.beta
        for input_value, alpha in zip(args, self.alphas):
            # Keep every input >= min_input so fractional powers of zero or
            # negative quantities cannot produce inf/NaN.
            input_value = np.clip(input_value, min_input, np.inf)
            # Rebind instead of `*=`: with an array-valued beta, `*=` would
            # modify self.beta in place.
            output_value = output_value * input_value ** alpha

        return output_value

    def optimal_endowment(self, Y, r, w):
        """Computes the cost-minimizing factor endowment of a two-factor
        Cobb-Douglas production function that yields output ``Y``, given the
        factor prices r for factor 1/capital (K) and w for factor 2/labor (L).

        Args:
            Y (float): output level
            r (float): factor price 1 (capital)
            w (float): factor price 2 (labor)

        Returns:
            K, L (tuple): the optimal point, capital first
        """
        a, b = self.alpha, (1 - self.alpha)
        Y0 = self.beta
        # From the tangency condition K/L = (a/b) * (w/r) substituted into
        # Y = Y0 * K**a * L**b.
        L = (Y / Y0) * ((b / a) ** a) * ((r / w) ** a)
        K = (Y / Y0) * ((b / a) ** (a - 1)) * ((r / w) ** (a - 1))
        return K, L

    def costs(self, K, L, r, w):
        """Computes the costs given capital and labor endowments,
        as well as corresponding factor prices.

        Args:
            K (float): capital (factor 1) usage
            L (float): labor (factor 2) usage
            r (float): cost of factor 1 (capital)
            w (float): cost of factor 2 (labor)

        Returns:
            float: total cost ``r*K + w*L``
        """
        return L * w + K * r

    def mrts(self, r, w):
        """Marginal rate of technical substitution at the cost-minimizing point.

        At the optimum the MRTS equals the factor price ratio, so this returns
        ``w / r``, i.e. the (absolute) slope dK/dL of the isocost line.

        Args:
            r (float): factor price 1 (capital)
            w (float): factor price 2 (labor)

        Returns:
            mrts (float): MRTS value
        """
        return w / r

    def indifference(self, L, Y):
        """The indifference curve; delegates to ``isoquant``.

        NOTE: despite its name, ``L`` is forwarded positionally as the first
        factor (``K``) of ``isoquant``, and the returned value is the second
        factor needed to reach level ``Y``.
        """
        return self.isoquant(L, Y)

    def isoquant(self, K, Y):
        """The isoquant is the curve of all (K, L) combinations producing the same output Y.

        This function returns the labor (second factor) required, given the
        level of capital (first factor) K.

        Args:
            K (float): capital (factor 1) used
            Y (float): desired output level (fixed)

        Returns:
            L: labor (factor 2) required

        Derivation: Y = beta * K**a * L**(1-a)
                    =>  L = ((Y/beta) * K**(-a)) ** (1/(1-a))
        """
        return ((Y / self.beta) * K ** (-self.alpha)) ** (1 / (1 - self.alpha))

    def isocost(self, L, Y, r, w):
        """The isocost line through the cost-minimizing point for output Y,
        i.e. the line of equal production costs tangent to the isoquant.

        Args:
            L (float): labor usage
            Y (float): level of output (fixed)
            r (float): cost of factor capital
            w (float): cost of factor labor

        Returns:
            K: capital (factor 1) level at which costs are equal
        """
        mrts = self.mrts(r, w)
        # optimal_endowment returns (K, L) -- capital first.
        K_opt, L_opt = self.optimal_endowment(Y, r, w)
        # Line through (L_opt, K_opt) with slope -w/r.
        return K_opt - mrts * (L - L_opt)

    def expansion_path(self, L, r, w):
        """Gives the point on the expansion path line (the locus of
        cost-minimizing (L, K) points as output varies at fixed prices).

        Args:
            L (float): labor (factor 2) employed
            r (float): cost of factor 1 (capital)
            w (float): cost of factor 2 (labor)

        Returns:
            K: expansion path capital
        """
        # optimal_endowment returns (K, L); the optimal K/L ratio does not
        # depend on Y, so the unit-output optimum gives the slope.
        K_unit, L_unit = self.optimal_endowment(1, r, w)
        slope = K_unit / L_unit
        return slope * L


