mgplot.postcovid_plot

Plot the pre-COVID trajectory against the current trend.

  1"""Plot the pre-COVID trajectory against the current trend."""
  2
  3from typing import NotRequired, Unpack, cast
  4
  5from matplotlib.axes import Axes
  6from numpy import arange, polyfit
  7from pandas import DataFrame, Period, PeriodIndex, Series
  8
  9from mgplot.keyword_checking import (
 10    report_kwargs,
 11    validate_kwargs,
 12)
 13from mgplot.line_plot import LineKwargs, line_plot
 14from mgplot.settings import DataT, get_setting
 15from mgplot.utilities import check_clean_timeseries
 16
 17# --- constants
 18ME = "postcovid_plot"
 19
 20# Default regression periods by frequency
 21DEFAULT_PERIODS = {
 22    "Q": {"start": "2014Q4", "end": "2019Q4"},
 23    "M": {"start": "2015-01", "end": "2020-01"},
 24    "D": {"start": "2015-01-01", "end": "2020-01-01"},
 25}
 26
 27
 28class PostcovidKwargs(LineKwargs):
 29    """Keyword arguments for the post-COVID plot."""
 30
 31    start_r: NotRequired[Period]  # start of regression period
 32    end_r: NotRequired[Period]  # end of regression period
 33
 34
 35# --- functions
 36def get_projection(original: Series, to_period: Period) -> Series:
 37    """Create a linear projection based on pre-COVID data.
 38
 39    Assumes the start of the data has been trimmed to the period before COVID.
 40
 41    Args:
 42        original: Series - the original series with a PeriodIndex.
 43        to_period: Period - the period to which the projection should extend.
 44
 45    Returns:
 46        Series: A pandas Series with linear projection values using the same index as original.
 47
 48    Raises:
 49        ValueError: If to_period is not within the original series index range.
 50
 51    """
 52    if to_period not in original.index:
 53        raise ValueError(f"Regression end period {to_period} not found in series index")
 54    y_regress = original[original.index <= to_period].copy()
 55    x_regress = arange(len(y_regress))
 56    m, b = polyfit(x_regress, y_regress, 1)
 57
 58    x_complete = arange(len(original))
 59    return Series((x_complete * m) + b, index=original.index)
 60
 61
 62def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
 63    """Plot a series with a PeriodIndex, including a post-COVID projection.
 64
 65    Args:
 66        data: Series - the series to be plotted.
 67        kwargs: PostcovidKwargs - plotting arguments.
 68
 69    Raises:
 70        TypeError if series is not a pandas Series
 71        TypeError if series does not have a PeriodIndex
 72        ValueError if series does not have a D, M or Q frequency
 73        ValueError if regression start is after regression end
 74
 75    """
 76    # --- check the kwargs
 77    report_kwargs(caller=ME, **kwargs)
 78    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)
 79
 80    # --- check the data
 81    data = check_clean_timeseries(data, ME)
 82    if not isinstance(data, Series):
 83        raise TypeError("The series argument must be a pandas Series")
 84
 85    series_index = PeriodIndex(data.index)
 86    freq_str = series_index.freqstr
 87    if not freq_str or freq_str[0] not in ("Q", "M", "D"):
 88        raise ValueError("The series index must have a D, M or Q frequency")
 89
 90    freq_key = freq_str[0]
 91
 92    # rely on line_plot() to validate kwargs
 93    if "plot_from" in kwargs:
 94        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
 95        del kwargs["plot_from"]
 96
 97    # --- plot COVID counterfactual
 98    default_periods = DEFAULT_PERIODS[freq_key]
 99    start_regression = Period(default_periods["start"], freq=freq_str)
100    end_regression = Period(default_periods["end"], freq=freq_str)
101
102    # Override defaults with user-provided periods if specified
103    user_start = kwargs.pop("start_r", None)
104    user_end = kwargs.pop("end_r", None)
105
106    if user_start is not None:
107        start_regression = Period(user_start, freq=freq_str)
108    if user_end is not None:
109        end_regression = Period(user_end, freq=freq_str)
110
111    # Validate regression period
112    if start_regression >= end_regression:
113        raise ValueError("Start period must be before end period")
114
115    if start_regression not in data.index:
116        raise ValueError(f"Regression start period {start_regression} not found in series")
117    if end_regression not in data.index:
118        raise ValueError(f"Regression end period {end_regression} not found in series")
119
120    # --- combine data and projection
121    recent_data = data[data.index >= start_regression].copy()
122    recent_data.name = "Series"
123    projection_data = get_projection(recent_data, end_regression)
124    projection_data.name = "Pre-COVID projection"
125
126    # Create DataFrame with proper column alignment
127    combined_data = DataFrame(
128        {
129            projection_data.name: projection_data,
130            recent_data.name: recent_data,
131        }
132    )
133
134    # --- activate plot settings
135    kwargs["width"] = kwargs.pop(
136        "width",
137        (get_setting("line_normal"), get_setting("line_wide")),
138    )  # series line is thicker than projection
139    kwargs["style"] = kwargs.pop("style", ("--", "-"))  # dashed regression line
140    kwargs["label_series"] = kwargs.pop("label_series", True)
141    kwargs["annotate"] = kwargs.pop("annotate", (False, True))  # annotate series only
142    kwargs["color"] = kwargs.pop("color", ("darkblue", "#dd0000"))
143
144    return line_plot(
145        combined_data,
146        **cast("LineKwargs", kwargs),
147    )
# Caller name passed to the kwargs and data checks in postcovid_plot().
ME = "postcovid_plot"
# Default pre-COVID regression window per frequency (Q, M, D).
DEFAULT_PERIODS = {
    "Q": {"start": "2014Q4", "end": "2019Q4"},
    "M": {"start": "2015-01", "end": "2020-01"},
    "D": {"start": "2015-01-01", "end": "2020-01-01"},
}
class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot.

    Extends the line-plot keyword arguments with two optional bounds for the
    pre-COVID regression window.
    """

    # first period of the regression window (overrides the frequency default)
    start_r: NotRequired[Period]
    # last period of the regression window (overrides the frequency default)
    end_r: NotRequired[Period]
def get_projection(original: Series, to_period: Period) -> Series:
    """Create a linear projection based on pre-COVID data.

    A straight line is fitted (least squares, via numpy.polyfit) to the
    observations from the start of ``original`` up to and including
    ``to_period``. It is then evaluated over the whole of ``original``'s
    index. Missing (NaN) observations inside the regression window are
    skipped rather than allowed to break the fit.

    Assumes the start of the data has been trimmed to the start of the
    regression period, and that the index is sorted without gaps: the fit
    uses row positions (not calendar distance) as the x-variable.

    Args:
        original: Series - the original series with a PeriodIndex.
        to_period: Period - the last period included in the regression.

    Returns:
        Series: A pandas Series with linear projection values using the
        same index as original.

    Raises:
        ValueError: If to_period is not found in the original series index,
            or if fewer than two non-missing observations fall within the
            regression window.

    """
    if to_period not in original.index:
        raise ValueError(f"Regression end period {to_period} not found in series index")

    in_window = original.index <= to_period
    window = original[in_window]

    # polyfit cannot handle NaN, so only fit the observed points.
    observed = window.notna().to_numpy()
    if observed.sum() < 2:
        raise ValueError(
            "At least two non-missing observations are required in the regression period"
        )

    # Row positions are the time variable. The fit and the projection share
    # them, so a skipped NaN does not shift later points along the x-axis.
    positions = arange(len(original))
    slope, intercept = polyfit(
        positions[in_window][observed],
        window.to_numpy(dtype=float)[observed],
        1,
    )
    return Series(positions * slope + intercept, index=original.index)

Create a linear projection based on pre-COVID data.

Assumes the start of the data has been trimmed to the period before COVID.

Args: original: Series - the original series with a PeriodIndex. to_period: Period - the last period included in the regression; the projection itself spans the whole index of original.

Returns: Series: A pandas Series with linear projection values using the same index as original.

Raises: ValueError: If to_period is not found in the original series index.

def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series with a PeriodIndex, including a post-COVID projection.

    A linear trend is fitted over the pre-COVID regression window and
    projected over the rest of the series. The data and the projection are
    drawn together with line_plot(). By default the projection is a thinner
    dashed line, and only the series is annotated.

    The regression window defaults to the frequency-specific entry in
    DEFAULT_PERIODS and can be overridden with start_r / end_r. Any
    'plot_from' argument is ignored (with a printed warning), because the
    plot always starts at the regression start.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments; everything other than
            start_r / end_r / plot_from is passed through to line_plot().

    Returns:
        Axes: the axes returned by line_plot().

    Raises:
        TypeError if series is not a pandas Series
        TypeError if series does not have a PeriodIndex
        ValueError if series does not have a D, M or Q frequency
        ValueError if the regression start is not strictly before the
            regression end
        ValueError if the regression start or end period is not in the
            series index

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- check the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    freq_str = PeriodIndex(data.index).freqstr
    # DEFAULT_PERIODS is the single source of truth for supported frequencies.
    if not freq_str or freq_str[0] not in DEFAULT_PERIODS:
        raise ValueError("The series index must have a D, M or Q frequency")

    # The plot always starts at the regression start, so plot_from is
    # discarded here rather than passed on to line_plot().
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        del kwargs["plot_from"]

    # --- resolve the regression window: user overrides, else frequency defaults
    defaults = DEFAULT_PERIODS[freq_str[0]]
    start_r = kwargs.pop("start_r", None)
    end_r = kwargs.pop("end_r", None)
    start_regression = Period(defaults["start"] if start_r is None else start_r, freq=freq_str)
    end_regression = Period(defaults["end"] if end_r is None else end_r, freq=freq_str)

    if start_regression >= end_regression:
        raise ValueError("Start period must be before end period")
    if start_regression not in data.index:
        raise ValueError(f"Regression start period {start_regression} not found in series")
    if end_regression not in data.index:
        raise ValueError(f"Regression end period {end_regression} not found in series")

    # --- combine data and projection
    recent_data = data[data.index >= start_regression].copy()
    recent_data.name = "Series"
    projection_data = get_projection(recent_data, end_regression)
    projection_data.name = "Pre-COVID projection"

    # Column order matters: the two-element style tuples below list the
    # projection first and the series second.
    combined_data = DataFrame(
        {
            projection_data.name: projection_data,
            recent_data.name: recent_data,
        }
    )

    # --- default styling; any user-supplied value wins
    # series line is thicker than the projection line
    kwargs.setdefault("width", (get_setting("line_normal"), get_setting("line_wide")))
    kwargs.setdefault("style", ("--", "-"))  # dashed projection line
    kwargs.setdefault("label_series", True)
    kwargs.setdefault("annotate", (False, True))  # annotate the series only
    kwargs.setdefault("color", ("darkblue", "#dd0000"))

    return line_plot(combined_data, **cast("LineKwargs", kwargs))

Plot a series with a PeriodIndex, including a post-COVID projection.

Args: data: Series - the series to be plotted. kwargs: PostcovidKwargs - plotting arguments.

Raises: TypeError if series is not a pandas Series TypeError if series does not have a PeriodIndex ValueError if series does not have a D, M or Q frequency ValueError if regression start is not before regression end ValueError if the regression start or end period is not in the series index