# %%
"""
SetSpeedTrainSim over a simple, hypothetical corridor
"""

import time
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import altrios as alt

# Apply seaborn's default theme to every matplotlib figure created below
sns.set_theme()

# Whether to display the figures at the end of the script
SHOW_PLOTS = alt.utils.show_plots()

# Save interval handed to the consist and to the train sim; presumably a value
# of 1 records history on every time step -- confirm against the altrios docs.
SAVE_INTERVAL = 1

# Build the train config: load both rolling-stock definitions (loaded first,
# then empty) from the packaged resources.
rail_vehicle_loaded, rail_vehicle_empty = (
    alt.RailVehicle.from_file(alt.resources_root() / rel_path)
    for rel_path in (
        "rolling_stock/Manifest_Loaded.yaml",
        "rolling_stock/Manifest_Empty.yaml",
    )
)

# https://docs.rs/altrios-core/latest/altrios_core/train/struct.TrainConfig.html
# Length and mass are left as None; presumably altrios derives them from the
# rail vehicles and car counts -- confirm in the TrainConfig docs above.
train_config = alt.TrainConfig(
    rail_vehicles=[rail_vehicle_loaded, rail_vehicle_empty],
    n_cars_by_type={"Manifest_Loaded": 50, "Manifest_Empty": 50},
    train_length_meters=None,
    train_mass_kilograms=None,
)

# Build the locomotive consist model
bel: alt.Locomotive = alt.Locomotive.default_battery_electric_loco()
hel: alt.Locomotive = alt.Locomotive.default_hybrid_electric_loco()
# Construct a vector of one battery-electric loco (BEL), seven conventional
# locos, and one hybrid-electric loco (HEL), in that order. The conventional
# locos are created in a comprehension so each entry is a distinct object;
# `[loco] * 7` would instead repeat a single shared reference seven times.
loco_vec = (
    [bel]
    + [alt.Locomotive.default() for _ in range(7)]
    + [hel.copy()]
)
# Instantiate the consist, forwarding the script-wide save interval
loco_con = alt.Consist(
    loco_vec,
    SAVE_INTERVAL,
)

# Instantiate the intermediate `TrainSimBuilder`
tsb = alt.TrainSimBuilder(
    train_id="0",
    train_config=train_config,
    loco_con=loco_con,
)

# Load the network and the link path through the network.
network = alt.Network.from_file(
    alt.resources_root() / "networks/simple_corridor_network.yaml"
)
# The data in this file were generated by running
# ```python
# [lp.link_idx.idx for lp in sim0.path_tpc.link_points]
# ```
# in sim_manager_demo.py.
link_path = alt.LinkPath.from_csv_file(
    alt.resources_root() / "demo_data/link_points_idx_simple_corridor.csv"
)


# Load the prescribed speed trace that the train will follow
speed_trace = alt.SpeedTrace.from_csv_file(
    alt.resources_root() / "demo_data/speed_trace_simple_corridor.csv"
)

train_sim: alt.SetSpeedTrainSim = tsb.make_set_speed_train_sim(
    network=network,
    link_path=link_path,
    speed_trace=speed_trace,
    save_interval=SAVE_INTERVAL,
)

# Use the shared constant, not a literal, so that editing SAVE_INTERVAL at the
# top of the script is not silently overridden here. (Presumably this call
# re-applies the interval to nested components such as the consist -- confirm
# whether make_set_speed_train_sim already does so.)
train_sim.set_save_interval(SAVE_INTERVAL)

# Time the simulation run; perf_counter differences are in seconds.
t0 = time.perf_counter()
train_sim.walk()
t1 = time.perf_counter()
print(f"Time to simulate: {t1 - t0:.5g}")

ts_dict = train_sim.to_pydict()

# Time axis shared by every subplot, converted from seconds to hours
time_hr = np.array(ts_dict["history"]["time_seconds"]) / 3_600

# Pull out the first locomotive of the solved consist for plotting
# convenience. `to_pydict()` yields plain Python containers, so this is a
# dict (indexed by key below), not an `alt.Locomotive` instance.
loco0: dict = ts_dict["loco_con"]["loco_vec"][0]

fig, ax = plt.subplots(4, 1, sharex=True, figsize=(8, 6))

# --- Tractive power at the wheels ---
ax[0].plot(
    time_hr,
    np.array(ts_dict["history"]["pwr_whl_out_watts"]) / 1e6,
    label="tract pwr",
)
ax[0].set_ylabel("Power [MW]")
# Stretch the x-axis by 20% to make room for the legends. All subplots share
# this x-axis (sharex=True), so the limits are read from ax[0] itself.
x_lo, x_hi = ax[0].get_xlim()
ax[0].set_xlim([x_lo, x_lo + (x_hi - x_lo) * 1.2])
ax[0].legend()

# --- Resistance force components, one line per component ---
for history_key, label in (
    ("res_aero_newtons", "aero"),
    ("res_rolling_newtons", "rolling"),
    ("res_curve_newtons", "curve"),
    ("res_bearing_newtons", "bearing"),
    ("res_grade_newtons", "grade"),
):
    ax[1].plot(time_hr, ts_dict["history"][history_key], label=label)
ax[1].set_ylabel("Force [N]")
ax[1].legend(loc="right")

# --- State of charge of the first locomotive's `res` component ---
# NOTE(review): `loco_type` is presumably a single-entry mapping from the
# locomotive variant name to its data, so the first value is taken -- confirm.
ax[2].plot(
    time_hr,
    np.array(next(iter(loco0["loco_type"].values()))["res"]["history"]["soc"]),
)
ax[2].set_ylabel("SOC")

# --- Prescribed speed trace ---
ax[-1].plot(
    time_hr,
    ts_dict["speed_trace"]["speed_meters_per_second"],
)
ax[-1].set_xlabel("Time [hr]")
ax[-1].set_ylabel("Speed [m/s]")

plt.suptitle("Set Speed Train Sim Demo")
plt.tight_layout()

print("SHOW_PLOTS: ", SHOW_PLOTS)
if SHOW_PLOTS:
    plt.show()

# %%
