mgplot.postcovid_plot
Plot the pre-COVID trajectory against the current trend.
"""Plot the pre-COVID trajectory against the current trend."""

from typing import NotRequired, Unpack, cast

from matplotlib.axes import Axes
from numpy import arange, polyfit
from pandas import DataFrame, Period, PeriodIndex, Series

from mgplot.keyword_checking import (
    report_kwargs,
    validate_kwargs,
)
from mgplot.line_plot import LineKwargs, line_plot
from mgplot.settings import DataT, get_setting
from mgplot.utilities import check_clean_timeseries

# --- constants
ME = "postcovid_plot"

# Default regression periods by frequency; keyed on the first letter of the
# index's frequency string (e.g. "Q-DEC" -> "Q").
DEFAULT_PERIODS = {
    "Q": {"start": "2014Q4", "end": "2019Q4"},
    "M": {"start": "2015-01", "end": "2020-01"},
    "D": {"start": "2015-01-01", "end": "2020-01-01"},
}


class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot."""

    start_r: NotRequired[Period]  # start of regression period
    end_r: NotRequired[Period]  # end of regression period


# --- functions
def get_projection(original: Series, to_period: Period) -> Series:
    """Create a linear projection based on pre-COVID data.

    A straight line is fitted by least squares to the observations from the
    start of original up to and including to_period. The fitted line is then
    evaluated at every position of original.

    Assumes the start of the data has been trimmed to the period before COVID.

    Args:
        original: Series - the original series with a PeriodIndex.
        to_period: Period - the last period included in the regression.

    Returns:
        Series: A pandas Series with linear projection values using the same index as original.

    Raises:
        ValueError: If to_period is not in the original series index, or if
            fewer than two non-missing observations are available for the fit.

    """
    if to_period not in original.index:
        raise ValueError(f"Regression end period {to_period} not found in series index")

    y_regress = original[original.index <= to_period]
    x_regress = arange(len(y_regress))

    # polyfit() cannot cope with missing values, so fit on the observed points
    # only. Each point keeps its positional x value, so gaps do not distort
    # the slope.
    observed = y_regress.notna().to_numpy()
    if observed.sum() < 2:
        raise ValueError(
            "At least two non-missing observations are required to fit the projection"
        )
    m, b = polyfit(x_regress[observed], y_regress[observed].to_numpy(dtype=float), 1)

    x_complete = arange(len(original))
    return Series((x_complete * m) + b, index=original.index)


def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series with a PeriodIndex, including a post-COVID projection.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments.

    Returns:
        Axes: the matplotlib Axes returned by line_plot().

    Raises:
        TypeError if series is not a pandas Series
        TypeError if series does not have a PeriodIndex
        ValueError if series does not have a D, M or Q frequency
        ValueError if regression start is not before regression end
        ValueError if the regression start or end period is not in the series

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- check the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    series_index = PeriodIndex(data.index)
    freq_str = series_index.freqstr
    if not freq_str or freq_str[0] not in ("Q", "M", "D"):
        raise ValueError("The series index must have a D, M or Q frequency")

    freq_key = freq_str[0]

    # rely on line_plot() to validate kwargs
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        del kwargs["plot_from"]

    # --- plot COVID counterfactual
    default_periods = DEFAULT_PERIODS[freq_key]
    start_regression = Period(default_periods["start"], freq=freq_str)
    end_regression = Period(default_periods["end"], freq=freq_str)

    # Override defaults with user-provided periods if specified
    user_start = kwargs.pop("start_r", None)
    user_end = kwargs.pop("end_r", None)

    if user_start is not None:
        start_regression = Period(user_start, freq=freq_str)
    if user_end is not None:
        end_regression = Period(user_end, freq=freq_str)

    # Validate regression period
    if start_regression >= end_regression:
        raise ValueError("Start period must be before end period")

    if start_regression not in data.index:
        raise ValueError(f"Regression start period {start_regression} not found in series")
    if end_regression not in data.index:
        raise ValueError(f"Regression end period {end_regression} not found in series")

    # --- combine data and projection
    recent_data = data[data.index >= start_regression].copy()
    recent_data.name = "Series"
    projection_data = get_projection(recent_data, end_regression)
    projection_data.name = "Pre-COVID projection"

    # Create DataFrame with proper column alignment
    combined_data = DataFrame(
        {
            projection_data.name: projection_data,
            recent_data.name: recent_data,
        }
    )

    # --- activate plot settings (caller-supplied values take precedence)
    kwargs.setdefault(
        "width",
        (get_setting("line_normal"), get_setting("line_wide")),
    )  # series line is thicker than projection
    kwargs.setdefault("style", ("--", "-"))  # dashed regression line
    kwargs.setdefault("label_series", True)
    kwargs.setdefault("annotate", (False, True))  # annotate series only
    kwargs.setdefault("color", ("darkblue", "#dd0000"))

    return line_plot(
        combined_data,
        **cast("LineKwargs", kwargs),
    )
class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot.

    Extends LineKwargs with two optional keys that set the regression window
    used by postcovid_plot().
    """

    # first period included in the regression
    start_r: NotRequired[Period]
    # last period included in the regression
    end_r: NotRequired[Period]
Keyword arguments for the post-COVID plot.
def get_projection(original: Series, to_period: Period) -> Series:
    """Create a linear projection based on pre-COVID data.

    A straight line is fitted by least squares to the observations from the
    start of original up to and including to_period. The fitted line is then
    evaluated at every position of original.

    Assumes the start of the data has been trimmed to the period before COVID.

    Args:
        original: Series - the original series with a PeriodIndex.
        to_period: Period - the last period included in the regression.

    Returns:
        Series: A pandas Series with linear projection values using the same index as original.

    Raises:
        ValueError: If to_period is not in the original series index, or if
            fewer than two non-missing observations are available for the fit.

    """
    if to_period not in original.index:
        raise ValueError(f"Regression end period {to_period} not found in series index")

    y_regress = original[original.index <= to_period]
    x_regress = arange(len(y_regress))

    # polyfit() cannot cope with missing values, so fit on the observed points
    # only. Each point keeps its positional x value, so gaps do not distort
    # the slope.
    observed = y_regress.notna().to_numpy()
    if observed.sum() < 2:
        raise ValueError(
            "At least two non-missing observations are required to fit the projection"
        )
    m, b = polyfit(x_regress[observed], y_regress[observed].to_numpy(dtype=float), 1)

    x_complete = arange(len(original))
    return Series((x_complete * m) + b, index=original.index)
Create a linear projection based on pre-COVID data.
Assumes the start of the data has been trimmed to the period before COVID.
Args: original: Series - the original series with a PeriodIndex. to_period: Period - the period to which the projection should extend.
Returns: Series: A pandas Series with linear projection values using the same index as original.
Raises: ValueError: If to_period is not present in the original series index.
def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series against a linear projection of its pre-COVID trend.

    The regression window defaults to the DEFAULT_PERIODS entry for the
    frequency of the series index, and can be overridden with the start_r
    and end_r keyword arguments. Data before the regression start is not
    plotted. A plot_from argument is ignored, with a printed warning.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments.

    Returns:
        Axes: the matplotlib Axes returned by line_plot().

    Raises:
        TypeError if the cleaned data is not a pandas Series
        ValueError if the index does not have a D, M or Q frequency
        ValueError if the regression start is not before the regression end
        ValueError if the regression start or end is not in the series index

    """
    # --- vet the keyword arguments
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- vet the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    freq_str = PeriodIndex(data.index).freqstr
    if not freq_str or freq_str[0] not in ("Q", "M", "D"):
        raise ValueError("The series index must have a D, M or Q frequency")
    freq_key = freq_str[0]

    # plot_from is not supported here; everything else is left for line_plot()
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        kwargs.pop("plot_from")

    # --- settle the regression window: defaults first, then caller overrides
    window = DEFAULT_PERIODS[freq_key]
    default_start = Period(window["start"], freq=freq_str)
    default_end = Period(window["end"], freq=freq_str)

    requested_start = kwargs.pop("start_r", None)
    requested_end = kwargs.pop("end_r", None)

    start_regression = (
        default_start if requested_start is None else Period(requested_start, freq=freq_str)
    )
    end_regression = (
        default_end if requested_end is None else Period(requested_end, freq=freq_str)
    )

    if start_regression >= end_regression:
        raise ValueError("Start period must be before end period")
    if start_regression not in data.index:
        raise ValueError(f"Regression start period {start_regression} not found in series")
    if end_regression not in data.index:
        raise ValueError(f"Regression end period {end_regression} not found in series")

    # --- build the two columns: projection first, then the observed series
    observed = data[data.index >= start_regression].copy().rename("Series")
    counterfactual = get_projection(observed, end_regression).rename("Pre-COVID projection")
    combined_data = DataFrame(
        {
            counterfactual.name: counterfactual,
            observed.name: observed,
        }
    )

    # --- plot defaults; values supplied by the caller take precedence
    default_width = (get_setting("line_normal"), get_setting("line_wide"))
    kwargs.setdefault("width", default_width)  # series line is thicker than projection
    kwargs.setdefault("style", ("--", "-"))  # dashed regression line
    kwargs.setdefault("label_series", True)
    kwargs.setdefault("annotate", (False, True))  # annotate series only
    kwargs.setdefault("color", ("darkblue", "#dd0000"))

    return line_plot(combined_data, **cast("LineKwargs", kwargs))
Plot a series with a PeriodIndex, including a post-COVID projection.
Args: data: Series - the series to be plotted. kwargs: PostcovidKwargs - plotting arguments.
Raises: TypeError if series is not a pandas Series; TypeError if series does not have a PeriodIndex; ValueError if series does not have a D, M or Q frequency; ValueError if regression start is not before regression end; ValueError if the regression start or end period is not found in the series.