"""
This is the script from this tutorial:
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

It is then modified to run on torchax's 'jax' device. The training loop
itself still uses PyTorch autograd (loss.backward()) and a torch.optim
optimizer.
"""

import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
#from torch.utils.tensorboard import SummaryWriter
#from datetime import datetime

# NOTE: add these lines to make it run on TPUs!
import torchax

# Turn torchax on for the whole process. The 'jax' device used further down
# (model.to('jax'), tensor.to('jax')) presumably depends on this call --
# see the torchax documentation for the exact semantics.
torchax.enable_globally()

# Preprocessing applied to every image: convert to a tensor, then normalize
# with mean 0.5 and std 0.5 (single channel). matplotlib_imshow() undoes this
# with `img / 2 + 0.5`.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary.
# Both splits are stored under ./data and share the same transform.
training_set = torchvision.datasets.FashionMNIST(
    './data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST(
    './data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
# Both loaders use batches of 4 samples.
training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=4, shuffle=False)

# Class labels: human-readable names, indexed by the integer label that the
# dataset yields (see the `classes[labels[j]]` lookup in the preview below).
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
           'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

import matplotlib.pyplot as plt
import numpy as np


# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
  """Draw an image tensor on the current matplotlib axes.

  Args:
    img: Image tensor laid out channels-first. It is expected to have been
      normalized with mean 0.5 / std 0.5, which this function reverses.
    one_channel: When True, the channel dimension is averaged away and the
      result is drawn with the "Greys" colormap. When False, the channel
      axis is moved last and the image is drawn as-is.
  """
  if one_channel:
    # Collapse the channel axis so a 2-D greyscale image remains.
    img = img.mean(dim=0)

  # Undo the normalization, then hand the data to numpy for plotting.
  pixels = (img / 2 + 0.5).numpy()

  if not one_channel:
    # matplotlib wants (H, W, C); the tensor is (C, H, W).
    plt.imshow(np.transpose(pixels, (1, 2, 0)))
    return
  plt.imshow(pixels, cmap="Greys")


#torchax.env.config.debug_print_each_op = True
#torchax.env.config.debug_mixed_tensor = True
# Pull a single batch from the training loader so it can be previewed.
dataiter = iter(training_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
# Print one class name per sample. The count is taken from the batch itself
# rather than hard-coded, so this keeps working if the loader's batch_size
# is changed.
print('  '.join(classes[labels[j]] for j in range(len(labels))))

import torch.nn as nn
import torch.nn.functional as F


# PyTorch models inherit from torch.nn.Module
class GarmentClassifier(nn.Module):
  """Three-layer fully connected classifier with ReLU activations.

  With the default arguments this is the original network: 28*28 inputs,
  hidden layers of 120 and 84 units, and 10 output logits.

  Args:
    in_features: Number of values per sample after flattening.
    hidden1: Width of the first hidden layer.
    hidden2: Width of the second hidden layer.
    num_classes: Number of output logits.
  """

  def __init__(self, in_features=28 * 28, hidden1=120, hidden2=84,
               num_classes=10):
    super().__init__()
    # Attribute names (fc1/fc2/fc3) are kept so state_dict keys are unchanged.
    self.fc1 = nn.Linear(in_features, hidden1)
    self.fc2 = nn.Linear(hidden1, hidden2)
    self.fc3 = nn.Linear(hidden2, num_classes)

  def forward(self, x):
    """Return raw (unnormalized) class scores for a batch of inputs.

    Args:
      x: Tensor whose total element count is a multiple of ``in_features``;
        it is flattened to ``(-1, in_features)`` before the first layer.

    Returns:
      Tensor of shape ``(batch, num_classes)``. No softmax is applied.
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. the
    # result of a transpose, copying only when it has to.
    x = x.reshape(-1, self.fc1.in_features)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x


# Instantiate the classifier and move its parameters to the 'jax' device.
model = GarmentClassifier().to('jax')

# Cross-entropy loss; it is applied to the model's raw outputs during
# training and validation below.
loss_fn = torch.nn.CrossEntropyLoss()

# Quick sanity check of the loss function on random data, created directly
# on the 'jax' device.
# NB: Loss functions expect data in batches, so we're creating batches of 4
# Represents the model's confidence in each of the 10 classes for a given input
dummy_outputs = torch.rand(4, 10, device='jax')
# Represents the correct class among the 10 being tested
dummy_labels = torch.tensor([1, 5, 3, 7], device='jax')

print(dummy_outputs)
print(dummy_labels)

# `loss` here is only the sanity-check value; train_one_epoch() uses its own
# local `loss`.
loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))

# Optimizers specified in the torch.optim package
# SGD with momentum over all model parameters (already on the 'jax' device).
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


def train_one_epoch(epoch_index, tb_writer=None):
  """Run one pass over ``training_loader`` and update ``model`` in place.

  Relies on the module-level ``model``, ``loss_fn``, ``optimizer`` and
  ``training_loader``.

  Args:
    epoch_index: Zero-based index of the current epoch. It is only used to
      compute the global step passed to ``tb_writer``.
    tb_writer: Optional object exposing ``add_scalar(tag, value, step)``
      (e.g. a TensorBoard ``SummaryWriter``). When ``None`` (the default),
      progress is only printed to the console.

  Returns:
    The mean per-batch loss of the last completed 1000-batch window. If the
    loader yields fewer than 1000 batches, the mean over the batches that
    were seen instead (0.0 for an empty loader).
  """
  report_every = 1000  # number of batches averaged into each report
  running_loss = 0.
  last_loss = 0.
  batches_seen = 0
  reported = False

  # Here, we use enumerate(training_loader) instead of
  # iter(training_loader) so that we can track the batch
  # index and do some intra-epoch reporting
  for i, data in enumerate(training_loader):
    # Every data instance is an input + label pair. Move the batch (not the
    # model, which already lives there) to the 'jax' device.
    inputs, labels = data
    inputs = inputs.to('jax')
    labels = labels.to('jax')

    # Zero your gradients for every batch!
    optimizer.zero_grad()

    # Make predictions for this batch
    outputs = model(inputs)

    # Compute the loss and its gradients
    loss = loss_fn(outputs, labels)
    loss.backward()

    # Adjust learning weights
    optimizer.step()

    # Gather data and report
    running_loss += loss.item()
    batches_seen = i + 1
    if i % report_every == report_every - 1:
      last_loss = running_loss / report_every  # loss per batch
      print('  batch {} loss: {}'.format(i + 1, last_loss))
      if tb_writer is not None:
        # Global step: number of batches processed since the start of the run.
        tb_x = epoch_index * len(training_loader) + i + 1
        tb_writer.add_scalar('Loss/train', last_loss, tb_x)
      running_loss = 0.
      reported = True

  # With fewer batches than one reporting window, nothing was ever averaged;
  # return the mean over what was seen instead of a misleading 0.0.
  if not reported and batches_seen:
    last_loss = running_loss / batches_seen

  return last_loss


# Initializing in a separate cell so we can easily add more epochs to the same run
#timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
#writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
# epoch_number lives outside the loop so that re-running the loop continues
# the epoch count of the same run.
epoch_number = 0
EPOCHS = 2
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
  print('EPOCH {}:'.format(epoch_number + 1))

  # Put the model in training mode and do a pass over the training data.
  model.train(True)

  avg_loss = train_one_epoch(epoch_number)

  running_vloss = 0.0
  num_vbatches = 0
  # Set the model to evaluation mode, disabling dropout and using population
  # statistics for batch normalization. (GarmentClassifier has neither, so
  # this only matters if the model is changed.)
  model.eval()

  # Disable gradient computation and reduce memory consumption.
  with torch.no_grad():
    for vinputs, vlabels in validation_loader:
      vinputs = vinputs.to('jax')
      vlabels = vlabels.to('jax')
      voutputs = model(vinputs)  # call model's forward
      vloss = loss_fn(voutputs, vlabels)
      # Accumulate a plain float so the reported value matches avg_loss
      # (a float) instead of being printed as a tensor.
      running_vloss += vloss.item()
      num_vbatches += 1

  # Count batches explicitly rather than relying on a leaked loop index,
  # which would be undefined if the validation loader were empty.
  avg_vloss = running_vloss / num_vbatches if num_vbatches else float('nan')
  print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

  # TensorBoard logging of the per-batch averages (training and validation)
  # is disabled in this variant of the script.

  # # Track best performance, and save the model's state
  # if avg_vloss < best_vloss:
  #     best_vloss = avg_vloss
  #     model_path = 'model_{}_{}'.format(timestamp, epoch_number)
  #     torch.save(model.state_dict(), model_path)

  epoch_number += 1
