"""Retrieve and process data from WRDS CRSP Daily Stock File"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/01_wrds/02_crspd.ipynb.

# %% ../../nbs/01_wrds/02_crspd.ipynb 3
from __future__ import annotations
from typing import List

import pandas as pd
import numpy as np

import pandasmore as pdm
from . import wrds_api

# %% auto 0
# Names exported by `from <module> import *`. This file is autogenerated from a notebook,
# so changes to the public API belong in the notebook rather than in this list.
__all__ = ['PROVIDER', 'URL', 'LIBRARY', 'TABLE', 'NAMES_TABLE', 'DELIST_TABLE', 'FREQ', 'MIN_YEAR', 'MAX_YEAR',
           'ENTITY_ID_IN_RAW_DSET', 'ENTITY_ID_IN_CLEAN_DSET', 'TIME_VAR_IN_RAW_DSET', 'TIME_VAR_IN_CLEAN_DSET',
           'list_all_vars', 'default_raw_vars', 'parse_varlist', 'get_raw_data', 'process_raw_data', 'delist_adj_ret',
           'features']

# %% ../../nbs/01_wrds/02_crspd.ipynb 4
# Human-readable name of the data provider and the web page describing the dataset.
PROVIDER = 'Wharton Research Data Services (WRDS)'
URL = 'https://wrds-www.wharton.upenn.edu/pages/get-data/center-research-security-prices-crsp/annual-update/stock-security-files/daily-stock-file/'
# WRDS library and tables queried by this module. In the SQL built by `get_raw_data`,
# TABLE is aliased `a`, NAMES_TABLE is left-joined as `b` and DELIST_TABLE as `c`.
LIBRARY = 'crsp'
TABLE = 'dsf'
NAMES_TABLE = 'dsenames'
DELIST_TABLE = 'dsedelist'
# Frequency code handed to `pdm.setup_panel` (as `freq=`) in `process_raw_data`.
FREQ = 'D'
# NOTE(review): MIN_YEAR / MAX_YEAR are not referenced by the code visible in this file;
# presumably they describe data coverage, with None meaning "no upper bound" -- confirm.
MIN_YEAR = 1925
MAX_YEAR = None
# Entity identifier column. The raw name is passed as `panel_ids` to `pdm.setup_panel`.
ENTITY_ID_IN_RAW_DSET = 'permno'
ENTITY_ID_IN_CLEAN_DSET = 'permno'
# Time column. The raw name is passed as `time_var` to `pdm.setup_panel`.
# The clean name evaluates to 'Ddate'; presumably the column created by `pdm.setup_panel` -- confirm.
TIME_VAR_IN_RAW_DSET = 'date'
TIME_VAR_IN_CLEAN_DSET = f'{FREQ}date'

# %% ../../nbs/01_wrds/02_crspd.ipynb 5
def list_all_vars() -> pd.DataFrame:
    """Collects names of all available variables from WRDS `{LIBRARY}.{TABLE}`, `{LIBRARY}.{NAMES_TABLE}` and `{LIBRARY}.{DELIST_TABLE}`.

    Returns a DataFrame with columns `name`, `type`, `wrds_library` and `wrds_table`,
    one row per (table, variable). Rows are stacked in the order TABLE, NAMES_TABLE,
    DELIST_TABLE, so a variable present in several tables appears once per table.
    """

    # Open the connection *before* the `try`: if it fails there is nothing to close, and
    # the real error must not be hidden by an UnboundLocalError raised from `finally`.
    db = wrds_api.Connection()
    try:
        described = [
            db.describe_table(LIBRARY, table).assign(wrds_library=LIBRARY, wrds_table=table)
            for table in (TABLE, NAMES_TABLE, DELIST_TABLE)
        ]
    finally:
        # Always release the connection, even if one of the describe calls fails
        db.close()

    return pd.concat(described)[['name','type','wrds_library','wrds_table']].copy()

# %% ../../nbs/01_wrds/02_crspd.ipynb 9
def default_raw_vars():
    """Returns the list of variable names that `get_raw_data` downloads when no `vars` are given."""

    identifiers = ['permno', 'permco', 'date']
    returns_and_prices = ['ret', 'retx', 'shrout', 'prc']
    classification_codes = ['shrcd', 'exchcd']
    adjustment_factors = ['cfacpr', 'cfacshr']
    delisting_info = ['dlret', 'dlstcd', 'dlstdt']
    # Not part of the defaults: 'siccd', 'naics', 'cusip', 'ncusip'

    # Concatenation builds a brand-new list on every call, so callers may modify the result
    return (identifiers
            + returns_and_prices
            + classification_codes
            + adjustment_factors
            + delisting_info)

# %% ../../nbs/01_wrds/02_crspd.ipynb 11
def parse_varlist(vars: List[str]=None,
                  required_vars = None,
                  ) -> str:
    """Figures out which of `vars` come from `{LIBRARY}.{TABLE}`, `{LIBRARY}.{NAMES_TABLE}` or `{LIBRARY}.{DELIST_TABLE}` and adds the a., b. or c. prefix to each variable name to feed into an SQL query.

    `vars`: variable names to request; None means `default_raw_vars()`, and '*' means
    every available variable. A single name may also be given as a plain string.
    `required_vars`: names that always come first in the result (None means no required names).

    Returns a comma-separated string of prefixed column names: `required_vars` first,
    then the remaining `vars` in the order given, each name appearing once.
    Raises ValueError if any requested name is not found in the three tables.
    """

    # Get all available variables and the table-alias prefix each one needs in the SQL query.
    # A name found in several tables is attributed to the first table listed by `list_all_vars`.
    prefix_mapping = {TABLE: 'a.', NAMES_TABLE: 'b.', DELIST_TABLE: 'c.'}
    all_avail_vars = list_all_vars().drop_duplicates(subset='name', keep='first')
    prefixed_name = {
        name: prefix_mapping[table] + name
        for name, table in zip(all_avail_vars['name'], all_avail_vars['wrds_table'])
    }

    if isinstance(vars, str):
        if vars == '*': return ','.join(prefixed_name.values())
        # A bare string is one variable name, not a sequence of characters
        vars = [vars]

    # Add required vars to requested vars
    if vars is None: vars = default_raw_vars()
    required_vars = [] if required_vars is None else list(required_vars)
    # dict.fromkeys removes duplicates while keeping first-seen order, so the generated
    # column order is the same on every run (a set would reorder names between runs)
    vars_to_get = list(dict.fromkeys([*required_vars, *vars]))

    # Validate variables to be downloaded (make sure that they are in the target database)
    invalid_vars = [v for v in vars_to_get if v not in prefixed_name]
    if invalid_vars: raise ValueError(f"These vars are not in the database: {invalid_vars}")

    return ','.join(prefixed_name[v] for v in vars_to_get)

# %% ../../nbs/01_wrds/02_crspd.ipynb 14
def get_raw_data(
        vars: List[str]=None, # If None, downloads `default_raw_vars`; use '*' to get all available variables
        required_vars = ['permno','date'], # Variables that are always downloaded, regardless of the `vars` argument
        nrows: int=None,  # Number of rows to download. If None, full dataset will be downloaded
        start_date: str="01/01/1950",  # Start date in MM/DD/YYYY format
        end_date: str=None,            # End date in MM/DD/YYYY format
        shrcd_exchcd_filters: bool=True, # If True, keep only observations with shrcd in [10,11] and exchcd in [1,2,3]
) -> pd.DataFrame:
    """Downloads `vars` from `start_date` to `end_date` from WRDS {LIBRARY}.{TABLE}, {LIBRARY}.{NAMES_TABLE} and {LIBRARY}.{DELIST_TABLE}.

    Every row of {TABLE} (alias `a`) is kept. The {NAMES_TABLE} record (alias `b`) whose
    [namedt, nameendt] range contains the row's date is attached, and the {DELIST_TABLE}
    record (alias `c`) is attached only to the row dated on the delisting date `dlstdt`.
    Returns whatever `wrds_api.download` returns for the generated query.
    """

    wrds_api.validate_dates([start_date, end_date])
    varlist_string = parse_varlist(vars, required_vars=required_vars)

    # The delisting table is matched on the exact day. Matching on calendar month (which
    # suits a monthly file) would copy the delisting return onto every daily row of the
    # delisting month, and it would then be compounded into each of those daily returns.
    sql_string = f"""SELECT {varlist_string}
                        FROM {LIBRARY}.{TABLE} AS a 
                        LEFT JOIN {LIBRARY}.{NAMES_TABLE} AS b
                            ON a.permno=b.permno AND b.namedt<=a.date AND a.date<=b.nameendt 
                        LEFT JOIN {LIBRARY}.{DELIST_TABLE} as c
                            ON a.permno=c.permno AND a.date=c.dlstdt
                            """
    # `WHERE 1=1` lets every optional filter below be appended uniformly as `AND ...`
    sql_string += "WHERE 1=1 "
    if shrcd_exchcd_filters: sql_string += "AND shrcd IN (10,11) AND exchcd IN (1,2,3) "
    # Date filters are qualified with the main-table alias so they cannot become ambiguous
    if start_date is not None: sql_string += r"AND a.date >= %(start_date)s "
    if end_date is not None: sql_string += r"AND a.date <= %(end_date)s "
    if nrows is not None: sql_string += r" LIMIT %(nrows)s"

    # Values are passed as bound parameters rather than being formatted into the SQL text
    df = wrds_api.download(sql_string,
                             params={'start_date':start_date, 'end_date':end_date, 'nrows':nrows})

    return df

# %% ../../nbs/01_wrds/02_crspd.ipynb 16
def process_raw_data(
        df: pd.DataFrame=None,  # Must contain `permno` and `date` columns         
        clean_kwargs: dict={},  # Params to pass to `pdm.setup_panel` other than `panel_ids`, `time_var`, and `freq`
) -> pd.DataFrame:
    """Converts code-like columns of `df` to categoricals, then applies `pandasmore.setup_panel` to it.

    The categorical conversion is written back into the `df` that was passed in.
    """

    def _int_category(column):
        # integer codes -> nullable integers -> categories
        return column.astype('Int64').astype('category')

    def _string_category(column):
        # identifiers kept as text -> categories
        return column.astype('string').astype('category')

    def _sic_category(column):
        # SIC codes become text left-padded with zeros to four characters
        return column.astype('Int64').astype('string').str.zfill(4).astype('category')

    conversions = (
        (('shrcd', 'exchcd'), _int_category),
        (('naics', 'cusip', 'ncusip'), _string_category),
        (('siccd',), _sic_category),
    )

    # Only columns actually present in `df` are converted
    for column_names, convert in conversions:
        for column_name in column_names:
            if column_name not in df.columns:
                continue
            df[column_name] = convert(df[column_name])

    # Set up panel structure
    panel = pdm.setup_panel(
        df,
        panel_ids=ENTITY_ID_IN_RAW_DSET,
        time_var=TIME_VAR_IN_RAW_DSET,
        freq=FREQ,
        panel_ids_toint=False,
        **clean_kwargs,
    )
    return panel

# %% ../../nbs/01_wrds/02_crspd.ipynb 19
def delist_adj_ret(
        df: pd.DataFrame, # Requires `ret`,`exchcd`,`dlret`,`dlstcd` variables
        adj_ret_var: str='ret_adj' # Name of the adjusted return variable created by this function
) -> pd.DataFrame:
    """Adjusts returns for delisting using Shumway and Warther (1999) and Johnson and Zhao (2007).

    Returns a copy of `df` in which `dlret` has been filled in and a new column
    `adj_ret_var` holds the delisting-adjusted return. The input `df` is left untouched.

    Steps applied to `dlret`:
    - missing, with `dlstcd` equal to 500 or in 520-584, and `exchcd` 1 or 2: set to -0.35
    - missing, with the same `dlstcd` codes, and `exchcd` 3: set to -0.55
    - values below -1 are raised to -1
    - anything still missing becomes 0
    The adjusted return is `(1 + ret) * (1 + dlret) - 1`; where that is missing
    but `dlret` is non-zero, `dlret` alone is used.
    """

    # Work on a copy so the caller's frame keeps its original `dlret` and gains no extra columns
    out = df.copy()

    # Delisting codes that receive a replacement return when `dlret` is missing
    npdelist = (out['dlstcd']==500) | out['dlstcd'].between(520,584)
    out['dlret'] = np.where(out['dlret'].isna() & npdelist & out['exchcd'].isin([1,2]), -0.35, out['dlret'])
    out['dlret'] = np.where(out['dlret'].isna() & npdelist & out['exchcd'].isin([3]), -0.55, out['dlret'])
    # Parentheses are required: `&` binds tighter than `<`, so without them the
    # comparison is applied to the result of `&` and nothing is ever clipped
    out['dlret'] = np.where(out['dlret'].notna() & (out['dlret'] < -1), -1, out['dlret'])
    out['dlret'] = out['dlret'].fillna(0)

    out[adj_ret_var] = (1 + out['ret']) * (1 + out['dlret']) - 1
    # When `ret` is missing the product is missing too; fall back on the delisting return alone
    out[adj_ret_var] = np.where(out[adj_ret_var].isna() & (out['dlret']!=0), out['dlret'], out[adj_ret_var])
    return out

# %% ../../nbs/01_wrds/02_crspd.ipynb 21
def features(
        df: pd.DataFrame,
) -> pd.DataFrame:
    """Builds a feature table on the same index as `df`; at present it holds only `ret_adj`, the delisting-adjusted return from `delist_adj_ret`."""

    adjusted = delist_adj_ret(df, adj_ret_var='ret_adj')

    out = pd.DataFrame(index=df.index)
    out['ret_adj'] = adjusted[['ret_adj']].copy()

    # Rolling features are left out for now. The windows below would count calendar days,
    # not trading days, and the rrolling method is not feasible here because it creates
    # all 30 lags at once, which blows up the dataset:
    #out['lbhret12'] = pdm.rrolling(1+df['ret'], window=30, func='prod', skipna=True) - 1
    #out['retvol12'] = pdm.rrolling(df['ret'], window=30, func='std', skipna=True)

    return out
