#!/usr/bin/env python
# Copyright 2019-2025 DADoES, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License in the root directory in the "LICENSE" file or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import anatools
import argparse
import datetime
import json
try:
    import matplotlib.pyplot as plt
except ImportError:
    raise Exception("matplotlib is required to run this tool. Please install it with 'pip install matplotlib'.")
import os
import sys

# Command-line interface for the channel profiling tool. The description is
# shown verbatim (RawDescriptionHelpFormatter) so the usage examples keep
# their column alignment in --help output.
_DESCRIPTION = """
Profile a channel on the Rendered.ai Platform.
    profile with no options:            anaprofile
    profile with channelId:             anaprofile --channelId channelId
    profile in specific workspace:      anaprofile --workspaceId workspaceId
    profile with specific instances:    anaprofile --instances instance1,instance2

"""

parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description=_DESCRIPTION,
)

# Optional values that default to None when not supplied.
for _flag in ('--channelId', '--workspaceId', '--graphId',
              '--instances', '--environment', '--endpoint'):
    parser.add_argument(_flag, default=None)

# Credentials, explicitly typed as strings.
for _flag in ('--email', '--password'):
    parser.add_argument(_flag, type=str, default=None)

# Boolean switches.
for _flag in ('--local', '--verbose'):
    parser.add_argument(_flag, default=False, action='store_true')
parser.add_argument('--version', action='store_true')
# Parse the command line, handle --version, and build the anatools client.
args = parser.parse_args()
if args.version:
    print(f'anaprofile {anatools.__version__}')
    # Printing the version is a successful run, so exit with status 0
    # (a non-zero status would make callers and CI treat it as a failure).
    sys.exit(0)

# The client receives the string 'debug' when --verbose is set, otherwise False.
verbose = 'debug' if args.verbose else False

# Interactive mode is requested only when an email was supplied without a
# password. NOTE(review): presumably this lets the client prompt for the
# missing password -- confirm against anatools.client.
interactive = bool(args.email) and args.password is None

client = anatools.client(
    email=args.email,
    password=args.password,
    environment=args.environment,
    endpoint=args.endpoint,
    local=args.local,
    interactive=interactive,
    verbose=verbose)

def _prompt_for_index(count):
    """Ask on stdin until the user enters an integer in ``range(count)``.

    Non-numeric answers are treated like any other invalid choice and the
    user is asked again, instead of aborting the script with a ValueError.

    Args:
        count: Number of menu entries; valid answers are ``0 .. count-1``.

    Returns:
        The selected index as an int.
    """
    prompt = 'Enter your choice: '
    while True:
        answer = input(prompt)
        try:
            choice = int(answer)
        except ValueError:
            choice = -1  # forces the "invalid choice" path below
        if 0 <= choice < count:
            return choice
        prompt = f'Invalid choice, enter a value between 0 and {count-1}: '


# Resolve the channel to profile: from --channelId, or from a numbered menu.
if args.channelId is None:
    print('Please select a channel to profile:')
    channels = client.get_channels()
    if not channels:
        # With nothing to choose from, no answer could ever be valid.
        sys.exit('No channels are available to profile.')
    for val, channel in enumerate(channels):
        print(f"\t{f'[{val}]'.ljust(5)} {channel['name']} in {channel['organization']}")
    channelId = channels[_prompt_for_index(len(channels))]['channelId']
else: channelId = args.channelId

# Resolve the workspace to profile in: from --workspaceId, or from a menu.
if args.workspaceId is None:
    print('Please select a workspace to profile the channel in:')
    workspaces = client.get_workspaces()
    if not workspaces:
        sys.exit('No workspaces are available to profile the channel in.')
    for val, workspace in enumerate(workspaces):
        # The organization suffix is only shown when the workspace has one.
        print(f"\t{f'[{val}]'.ljust(5)} {workspace['name']}{'' if 'organization' not in workspace else ' in ' + workspace['organization']}")
    workspaceId = workspaces[_prompt_for_index(len(workspaces))]['workspaceId']
else: workspaceId = args.workspaceId
client.set_workspace(workspaceId)

# Optional graph override; None is passed through to profile_channel.
graphId = args.graphId
if graphId is None:
    print('Using default graph to profile the channel. If you would like to use a different graph, please specify it with the --graphId option.')

# Optional comma-separated list of instance types; None is passed through.
instances = None if args.instances is None else args.instances.split(',')
if instances is None:
    print('Using default instances to profile the channel. If you would like to use a different instance, please specify them with the --instances option.')

# Run the profiling jobs. Interactive mode is switched on for this call and
# stays on for the rest of the script.
client.interactive = True
jobs = client.profile_channel(
    channelId=channelId,
    graphId=graphId,
    workspaceId=workspaceId,
    instances=instances,
)

# generate report
def _job_duration(job):
    """Return a profiling job's run time (endTime - startTime) as a timedelta.

    Both values are read from ``job['metrics']`` as ISO-8601 strings. Any
    ``Z`` is removed before parsing, so the resulting datetimes are naive.
    """
    window = job['metrics']
    started = datetime.datetime.fromisoformat(window['startTime'].replace('Z', ''))
    ended = datetime.datetime.fromisoformat(window['endTime'].replace('Z', ''))
    return ended - started


def _usage_summary(metrics, name, scale=1, precision=2, unit="%"):
    """Format "max / average" of the metric called *name* for the summary table.

    Args:
        metrics: List of metric dicts, each with a ``name`` key.
        name: Metric to summarize; the first entry with this name is used.
        scale: Divisor applied to both figures (1e6 for networkIO).
        precision: Number of decimals to print.
        unit: Suffix appended to both figures.

    Returns:
        The formatted string, or ``"n/a"`` when the metric is missing or has
        no samples (an empty list would otherwise break max() and the mean).
    """
    matches = [metric for metric in metrics if metric['name'] == name]
    values = matches[0].get('values') if matches else None
    if not values:
        return "n/a"
    peak = max(values) / scale
    mean = sum(values) / len(values) / scale
    return f"{peak:.{precision}f}{unit} / {mean:.{precision}f}{unit}"


def _plot_series(metric):
    """Return ``(elapsed, values)`` ready for plotting one metric.

    Timestamps are shifted so that the first sample is at 0. networkIO
    values are divided by 1e6 to match the Mb/s axis labels. A metric with
    no samples yields two empty lists so that the figure is still produced.
    """
    values = metric.get('values') or []
    stamps = metric.get('timestamp') or []
    if not values or not stamps:
        return [], []
    origin = stamps[0]
    elapsed = [stamp - origin for stamp in stamps]
    if metric['name'] == 'networkIO':
        values = [value / 1e6 for value in values]
    return elapsed, values


if client.interactive: print("Generating report...", end='', flush=True)
report = "profile.md"
channel = client.get_channels(channelId=channelId)[0]
if instances is None: instances = jobs.keys()

# Fastest job first; each job's metrics JSON is decoded once and reused.
sorted_instances = sorted(instances, key=lambda name: _job_duration(jobs[name]))
job_metrics = {name: json.loads(jobs[name]['metrics']['metrics']) for name in sorted_instances}

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs('metrics', exist_ok=True)

# Explicit UTF-8 so names with non-ASCII characters can be written regardless
# of the platform's default encoding.
with open(report, 'w', encoding='utf-8') as f:
    f.write(f"# Profiling {channel['name']}\n")
    f.write(f"Report generated on {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}<br/>\n")
    f.write(f"Channel ID: {channel['channelId']}<br/>\n")
    f.write(f"Graph ID: {graphId}<br/>\n")
    f.write(f"Workspace ID: {workspaceId}<br/>\n")
    f.write("Note: Metrics are gathered every 5 seconds for profiling job runs.<br/>\n")
    f.write("\n")

    # summary table: one row per instance type
    f.write("## Metrics\n")
    f.write("\n")
    f.write("| Instance Type | Status | Time | Cost | CPU (max/avg) | Memory (max/avg) | Disk IO (max/avg) | Network IO (max/avg) |\n")
    f.write("| --- | --- | --- | --- | --- | --- | --- | --- |\n")
    for instance in sorted_instances:
        metrics = job_metrics[instance]
        cpudata = _usage_summary(metrics, 'cpuUsage')
        memorydata = _usage_summary(metrics, 'memoryUsage')
        diskdata = _usage_summary(metrics, 'diskIO')
        networkdata = _usage_summary(metrics, 'networkIO', scale=1e6, precision=3, unit="Mb/s")
        f.write(
            f"| {instance} "
            f"| {jobs[instance]['status']} "
            f"| {_job_duration(jobs[instance])} "
            f"| {jobs[instance]['metrics']['cost']} "
            f"| {cpudata} "
            f"| {memorydata} "
            f"| {diskdata} "
            f"| {networkdata} |\n")
    f.write("\n")

    # generate summary plots: one figure per metric, one line per instance.
    # The metric list of the first (fastest) instance decides which figures
    # exist; nothing is plotted when there are no jobs at all.
    if sorted_instances:
        f.write("|||||\n| --- | --- | --- | --- |\n|")
        for reference in job_metrics[sorted_instances[0]]:
            name = reference['name']
            plt.figure()
            for instance in sorted_instances:
                match = next((m for m in job_metrics[instance] if m['name'] == name), None)
                if match is None:
                    continue  # this instance did not report the metric
                ts, ys = _plot_series(match)
                plt.plot(ts, ys, label=instance)
            if name == 'networkIO': plt.title(f"{name} (Mb/s)")
            else: plt.title(f"{name} (%)")
            plt.xlabel("Time (s)")
            plt.ylabel(name)
            plt.legend()
            plt.savefig(f"metrics/summary_{name}.png")
            plt.close()  # every figure opened here is closed again
            f.write(f" ![{name}](./metrics/summary_{name}.png) |")

    # generate plots for each metric for each instance
    for instance in sorted_instances:
        f.write(f"\n## {instance} Metrics\n")
        f.write("|||||\n| --- | --- | --- | --- |\n|")
        for metric in job_metrics[instance]:
            name = metric['name']
            ts, ys = _plot_series(metric)
            plt.figure()
            plt.plot(ts, ys, label=name)
            plt.title(f"{name} for {instance}")
            plt.xlabel("Time (s)")
            plt.ylabel(name + " (Mb/s)" if name == 'networkIO' else name + " (%)")
            plt.savefig(f"metrics/{instance}_{name}.png")
            plt.close()
            f.write(f" ![{name} for {instance}](./metrics/{instance}_{name}.png) |")
    

# The progress message was printed without a newline, so return to the start
# of the line before overwriting it; \033[K then clears any leftover text.
if client.interactive: print("\rGenerating report...done!\033[K", flush=True)

# Optionally delete the datasets and graphs referenced by the profiling jobs.
if client.interactive:
    # Normalize once so validation and the decision see the same value
    # (an answer such as " y" is accepted and acted upon).
    answer = input("Would you like to clean up data? [y/n]: ").strip().lower()
    while answer not in ('y', 'n'):
        answer = input("Invalid input, must be one of [y/n]: ").strip().lower()
    if answer == 'y':
        for instance in jobs:
            # Cleanup is best-effort: a failure for one instance is reported
            # and the remaining instances are still processed.
            try:
                datasetId = jobs[instance]['datasetId']
                dataset = client.get_datasets(datasetId=datasetId, workspaceId=workspaceId)[0]
                client.delete_graph(graphId=dataset['graphId'], workspaceId=workspaceId)
                client.delete_dataset(datasetId=datasetId, workspaceId=workspaceId)
            except Exception as error:
                print(f"Warning: could not clean up data for {instance}: {error}")