__mltk_version__ = '0.6.0'
# NOTE(review): appears to record the MLTK release this model file was written
# for; kept byte-identical in case MLTK tooling parses this line — confirm before editing.

"""autoencoder_example
************************

- Source code: `autoencoder_example.py <https://github.com/siliconlabs/mltk/blob/master/mltk/models/examples/autoencoder_example.py>`_
- Pre-trained model: `autoencoder_example.mltk.zip <https://github.com/siliconlabs/mltk/blob/master/mltk/models/examples/autoencoder_example.mltk.zip>`_


This demonstrates how to build an autoencoder model.
This is based on the `TensorFlow: Anomaly detection <https://www.tensorflow.org/tutorials/generative/autoencoder#third_example_anomaly_detection>`_ tutorial.

In this example, you will train an autoencoder to detect anomalies in the `ECG5000 <http://www.timeseriesclassification.com/description.php?Dataset=ECG5000>`_ dataset.
This dataset contains 5,000 `electrocardiograms <https://en.wikipedia.org/wiki/Electrocardiography>`_, each with 140 data points. You will use a simplified version of the dataset,
where each example has been labeled either 0 (corresponding to an abnormal rhythm) or 1 (corresponding to a normal rhythm).
You are interested in identifying the abnormal rhythms.


Commands
--------------

.. code-block:: shell

   # Do a "dry run" test training of the model
   mltk train autoencoder_example-test

   # Train the model
   mltk train autoencoder_example

   # Evaluate the trained .tflite model
   # Also dump a comparison of the original image vs. the generated autoencoder image
   mltk evaluate autoencoder_example --tflite --dump

   # Profile the model in the MVP hardware accelerator simulator
   mltk profile autoencoder_example --accelerator MVP

   # Profile the model on a physical development board
   mltk profile autoencoder_example --accelerator MVP --device


Model Summary
--------------

.. code-block:: shell
    
    mltk summarize autoencoder_example --tflite
    


Model Diagram
------------------

.. code-block:: shell
   
   mltk view autoencoder_example --tflite

.. raw:: html

    <div class="model-diagram">
        <a href="../../../../_images/models/autoencoder_example.tflite.png" target="_blank">
            <img src="../../../../_images/models/autoencoder_example.tflite.png" />
            <p>Click to enlarge</p>
        </a>
    </div>


"""
from typing import List
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf


from mltk.core.model import (
    MltkModel,
    TrainMixin,
    DatasetMixin,
    EvaluateAutoEncoderMixin
)
from mltk.utils.path import create_user_dir
from mltk.utils.archive_downloader import download_url



# Instantiate the MltkModel object with the following 'mixins':
# - TrainMixin               - Provides model training operations and settings
# - DatasetMixin             - Provides general dataset operations and settings
# - EvaluateAutoEncoderMixin - Provides autoencoder evaluation operations and settings
# @mltk_model # NOTE: This tag is required for this model to be discoverable
class MyModel(
    MltkModel,
    TrainMixin,
    DatasetMixin,
    EvaluateAutoEncoderMixin
):
    """Autoencoder model trained to reconstruct "normal" ECG rhythms.

    Only normal samples are used for training, so abnormal rhythms are
    expected to be reconstructed poorly, which is how anomalies are detected.
    """

    def load_dataset(
        self,
        subset: str,
        classes: List[str] = None,
        **kwargs
    ):
        """Download (if needed), split and normalize the ECG dataset.

        Each CSV row holds the signal values followed by a label column
        (1 = normal rhythm, 0 = abnormal rhythm). The data is split 80/20
        into train/test sets and min/max-scaled using statistics from the
        training split only.

        Args:
            subset: Dataset subset being loaded. For ``'evaluation'`` only the
                test samples of the requested class are stored in ``self.x``.
                Otherwise the normal training samples are stored in both
                ``self.x`` and ``self.y``, and the full test split is used as
                ``self.validation_data``.
            classes: Only used for ``'evaluation'``. ``classes[0] == 'normal'``
                selects the normal test samples; any other value selects the
                abnormal ones. When ``None``/empty, ``'normal'`` is assumed.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        super().load_dataset(subset)

        # Download the dataset (if necessary).
        # HTTPS protects the download against tampering in transit.
        dataset_path = f'{create_user_dir()}/datasets/ecg500.csv'
        download_url(
            'https://storage.googleapis.com/download.tensorflow.org/data/ecg.csv',
            dataset_path
        )

        # Load the dataset into a numpy array; the last column contains the labels
        dataset = np.genfromtxt(dataset_path, delimiter=',', dtype=np.float32)
        labels = dataset[:, -1]
        data = dataset[:, :-1]

        # Split the data into training and test data
        self.validation_split = 0.2
        train_data, test_data, train_labels, test_labels = train_test_split(
            data, labels, test_size=self.validation_split, random_state=21
        )

        # Min/max scale both splits using the *training* statistics only, so
        # nothing about the test split leaks into preprocessing. NumPy keeps the
        # data as plain float32 ndarrays (tf.reduce_* would turn them into tensors).
        min_val = np.min(train_data)
        max_val = np.max(train_data)
        value_range = max_val - min_val
        train_data = (train_data - min_val) / value_range
        test_data = (test_data - min_val) / value_range

        # Label 1 marks a normal rhythm, so the boolean mask selects normal samples
        normal_train_mask = train_labels.astype(bool)
        normal_test_mask = test_labels.astype(bool)

        normal_train_data = train_data[normal_train_mask]
        normal_test_data = test_data[normal_test_mask]

        anomalous_train_data = train_data[~normal_train_mask]
        anomalous_test_data = test_data[~normal_test_mask]

        # Counts reported by summarize_dataset()
        self._normal_train_count = len(normal_train_data)
        self._normal_test_count = len(normal_test_data)
        self._abnormal_train_count = len(anomalous_train_data)
        self._abnormal_test_count = len(anomalous_test_data)

        if subset == 'evaluation':
            # When evaluating, just return the "normal" or "abnormal" test samples.
            # NOTE: The y value is not required in this case.
            # Guard against the None default instead of crashing on classes[0].
            eval_class = classes[0] if classes else 'normal'
            if eval_class == 'normal':
                self.x = normal_test_data
            else:
                self.x = anomalous_test_data
        else:
            # For training, only the "normal" data is used.
            # x and y are the same data because the whole point of an
            # autoencoder is to reconstruct its input.
            self.x = normal_train_data
            self.y = normal_train_data
            # Validation uses the full test split (normal and abnormal samples)
            self.validation_data = (test_data, test_data)

    def summarize_dataset(self) -> str:
        """Return the normal/abnormal sample counts of the train and validation splits.

        Reads the counts recorded by ``load_dataset()``, so that method must run first.
        """
        s = f'Train dataset: Found {self._normal_train_count} "normal", {self._abnormal_train_count} "abnormal" samples\n'
        s += f'Validation dataset: Found {self._normal_test_count} "normal", {self._abnormal_test_count} "abnormal" samples'
        return s




my_model = MyModel()


#################################################
# General Settings
#
my_model.version = 1
my_model.description = 'Autoencoder example to detect anomalies in ECG dataset'

# Each sample is a single ECG recording of 140 data points
my_model.input_shape = (140,)

#################################################
# Training Settings
my_model.epochs = 20
my_model.batch_size = 512
my_model.optimizer = 'adam'
my_model.metrics = ['mae']
# The loss is the mean absolute error between the input and its reconstruction
my_model.loss = 'mae'

#################################################
# Training callback Settings

# Generate a training weights .h5 checkpoint whenever
# the val_loss improves
my_model.checkpoint['monitor'] = 'val_loss'
my_model.checkpoint['mode'] = 'auto'


#################################################
# TF-Lite converter settings
# Restrict the converted model to int8 builtin ops (full-integer quantization)
# while keeping float32 model input and output
my_model.tflite_converter['optimizations'] = ['DEFAULT']
my_model.tflite_converter['supported_ops'] = ['TFLITE_BUILTINS_INT8']
my_model.tflite_converter['inference_input_type'] = tf.float32
my_model.tflite_converter['inference_output_type'] = tf.float32
# Generate a representative dataset from the validation data
my_model.tflite_converter['representative_dataset'] = 'generate'




#################################################
# Build the ML Model
def my_model_builder(model: MyModel):
    """Build and compile the dense autoencoder.

    The encoder compresses each input vector to 8 features (32 -> 16 -> 8),
    and the decoder expands it back to the input's width (16 -> 32 -> N).
    The final sigmoid keeps outputs in (0, 1), matching the min/max-scaled
    training data produced by ``MyModel.load_dataset``.

    Args:
        model: The MLTK model supplying ``input_shape``, ``loss``,
            ``optimizer`` and ``metrics``.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    model_input = tf.keras.layers.Input(shape=model.input_shape)
    # The reconstruction must have exactly as many values as the input, so
    # derive the output width from input_shape rather than hard-coding it.
    n_features = model.input_shape[-1]

    encoder = tf.keras.Sequential([
        model_input,
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(8, activation="relu"),
    ])

    decoder = tf.keras.Sequential([
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(n_features, activation="sigmoid"),
    ])

    autoencoder = tf.keras.models.Model(model_input, decoder(encoder(model_input)))
    autoencoder.compile(
        loss=model.loss,
        optimizer=model.optimizer,
        metrics=model.metrics
    )

    return autoencoder


# Register the builder so MLTK uses it to construct the Keras model
my_model.build_model_function = my_model_builder


