
import unittest
import sys
import os
from pathlib import Path
import datetime

from dselib.dse_tools import PathPattern, FileSelector

# Opt-in switch taken from the environment: holds the raw string value of
# the RUN_ALL variable when it is set, and False when it is not.
RUN_ALL = os.getenv('RUN_ALL', False)

# Default file-name pattern: a {variable} placeholder followed by a
# year_month_day stamp written with %-style date directives.
base_pattern = "/{variable}%Y_%m_%d.v3.nc"

class TestFile_Selection(unittest.TestCase):
    """Unit tests for ``PathPattern`` and ``FileSelector`` (``dselib.dse_tools``).

    Every fixture points at a hard-coded TAMSATv3 data tree whose root depends
    on ``sys.platform``: a ``P:`` drive on win32 and a ``/mnt`` mount elsewhere.
    Expected paths are built by the private helpers below so the platform
    switch lives in exactly one place.
    """

    # Platform-specific roots of the data tree.
    _POSIX_ROOT = "/mnt/ActiveDevelopmentProjects/LargeData/Data"
    _WIN32_ROOT = "P:\\ActiveDevelopmentProjects\\LargeData\\Data"
    # Pattern used when a fixture helper is called without an explicit one.
    _DEFAULT_PATTERN = "/{variable}%Y_%m_%d.v3.nc"
    # Date window shared by most of the sequence tests.
    _START = datetime.datetime(2000, 1, 1)
    _END = datetime.datetime(2000, 1, 15)

    # ------------------------------------------------------------------
    # Fixture helpers (names do not start with "test", so unittest does
    # not collect them).
    # ------------------------------------------------------------------
    def _root_dir(self):
        """Return the data root directory for the current platform."""
        if sys.platform == 'win32':
            return self._WIN32_ROOT
        return self._POSIX_ROOT

    def _make_path_pattern(self, relative_dir, pattern=None):
        """Build a ``PathPattern`` under the platform root.

        Args:
            relative_dir: directory below the root, passed through unchanged.
            pattern: file-name pattern; the class default is used when None.
        """
        if pattern is None:
            pattern = self._DEFAULT_PATTERN
        return PathPattern(pattern=pattern, root_dir=self._root_dir(),
                           relative_dir=relative_dir)

    def _expected_dir(self, leaf="raw"):
        """Return the expected ``.../TAMSATv3/unzippedDaily/<leaf>`` directory.

        The separator is chosen from ``sys.platform`` (not ``os.sep``) so the
        result matches the literal strings the tests have always compared to.
        """
        sep = "\\" if sys.platform == 'win32' else "/"
        return sep.join([self._root_dir(), "TAMSATv3", "unzippedDaily", leaf])

    def _expected_raw_file(self, file_name):
        """Return the expected absolute path of *file_name* in the raw directory."""
        sep = "\\" if sys.platform == 'win32' else "/"
        return self._expected_dir("raw") + sep + file_name

    def _rainfall_selector(self, target_variable='rfe'):
        """Return a ``FileSelector`` with one 'rainfall' target fed by 'rfe' input.

        The input pattern points at the raw directory and the target pattern
        at the cleaned directory.
        """
        in_path_pattern = self.basePathPattern()
        out_path_pattern = self.outPathPattern()
        selector = FileSelector(out_path_pattern, in_path_pattern)
        selector.add_target_set('rainfall', variable=target_variable)
        selector.add_input_set_to_target('rainfall', variable='rfe')
        return selector

    def _assert_number_pattern_keys(self, pattern, target):
        """Assert that *pattern* reports *target* pattern keys."""
        path_pattern = self.basePathPattern(pattern)
        received = path_pattern.get_number_pattern_keys()
        self.assertEqual(received, target)

    def _assert_target(self, files, index, expected_key, expected_count,
                       expected_files):
        """Check one entry of a ``get_sequence_files`` result.

        Args:
            files: mapping returned by ``get_sequence_files``.
            index: position of the entry in key order (may be negative).
            expected_key: expected target file name at that position.
            expected_count: expected number of input files for the target.
            expected_files: names of the leading input files, in order; each
                is compared against its full path in the raw directory.
        """
        keys = list(files.keys())
        self.assertEqual(keys[index], expected_key)
        file_list = files[keys[index]]
        self.assertEqual(len(file_list), expected_count)
        for position, file_name in enumerate(expected_files):
            self.assertEqual(file_list[position],
                             self._expected_raw_file(file_name))

    def basePathPattern(self, pattern = None):
        """Return a ``PathPattern`` rooted at the raw (input) directory."""
        return self._make_path_pattern("/TAMSATv3/unzippedDaily/raw", pattern)

    def outPathPattern(self,pattern = None):
        """Return a ``PathPattern`` rooted at the cleaned (output) directory."""
        return self._make_path_pattern("/TAMSATv3/unzippedDaily/cleaned", pattern)

    # ------------------------------------------------------------------
    # PathPattern tests
    # ------------------------------------------------------------------
    def testPathPattern(self):
        """The fixture helper yields a ``PathPattern`` instance."""
        path_pattern = self.basePathPattern()
        self.assertIsInstance(path_pattern, PathPattern)

    def testPathPattern_root_path(self):
        """``root_path`` is the root directory joined with the relative directory."""
        path_pattern = self.basePathPattern()
        self.assertEqual(path_pattern.root_path, self._expected_dir("raw"))

    def testPathPattern_root_path_rel_none(self):
        """Without a relative directory, ``root_path`` equals the given root."""
        root_dir = self._expected_dir("raw")
        path_pattern = PathPattern(pattern=base_pattern, root_dir=root_dir)
        self.assertEqual(path_pattern.root_path, root_dir)

    def testPathPattern_number_pattern_keys_1(self):
        """A single ``{variable}`` placeholder counts as one key."""
        self._assert_number_pattern_keys("/{variable}%Y_%m_%d.v3.nc", 1)

    def testPathPattern_number_pattern_keys_1_dub(self):
        """The same placeholder used twice still counts as one key."""
        self._assert_number_pattern_keys(
            "{variable}/%Y/{variable}%Y_%m_%d.v3.nc", 1)

    def testPathPattern_number_pattern_keys_2(self):
        """Two distinct placeholders count as two keys."""
        self._assert_number_pattern_keys(
            "/{variable}/{prefix}%Y_%m_%d.v3.nc", 2)

    def testPathPattern_number_pattern_keys_3(self):
        """Three distinct placeholders count as three keys."""
        self._assert_number_pattern_keys(
            "/{variable}/{prefix}%Y_%m_%d.v3.{suffix}", 3)

    def testPathPattern_number_pattern_keys_0(self):
        """A pattern without placeholders has zero keys."""
        self._assert_number_pattern_keys("/airtemp%Y_%m_%d.v3.nc", 0)

    def testPathPattern_get_path_pattern_1var(self):
        """Substituting ``variable`` fills the placeholder and drops the leading slash."""
        path_pattern = self.basePathPattern()
        pattern = path_pattern.get_path_pattern(variable='rfe')
        self.assertEqual(pattern, "rfe%Y_%m_%d.v3.nc")

    def testPathPattern_get_pattern_key_words_3(self):
        """Key words are reported in order of appearance in the pattern."""
        path_pattern = self.basePathPattern(
            "/{variable}/{prefix}%Y_%m_%d.v3.{suffix}")
        key_words = path_pattern.get_pattern_key_words()
        self.assertEqual(key_words, ["variable", "prefix", "suffix"])

    # ------------------------------------------------------------------
    # FileSelector tests
    # ------------------------------------------------------------------
    def testFileSelector(self):
        """A selector can be built from a target and an input pattern."""
        path_pattern = self.basePathPattern()
        myFiles = FileSelector(path_pattern, path_pattern)
        self.assertIsInstance(myFiles, FileSelector)

    def testFileSelector_none(self):
        """A selector can be built without any patterns."""
        myFiles = FileSelector()
        self.assertIsInstance(myFiles, FileSelector)

    def testSetTargetPathPattern(self):
        """Setting the target pattern exposes it as ``target_pattern``."""
        path_pattern = self.basePathPattern()
        myFiles = FileSelector()
        myFiles.set_target_path_pattern(path_pattern)
        self.assertEqual(myFiles.target_pattern.pattern,
                         "{variable}%Y_%m_%d.v3.nc")

    def testSetInputPathPattern(self):
        """Setting the input pattern exposes it as ``input_pattern``."""
        myFiles = FileSelector()
        path_pattern = self.basePathPattern()
        myFiles.set_input_path_pattern(path_pattern)
        self.assertEqual(myFiles.input_pattern.pattern,
                         "{variable}%Y_%m_%d.v3.nc")

    def testAddTargetDataSet(self):
        """Each added target set is stored under its name with its resolved pattern."""
        path_pattern = self.basePathPattern()
        myFiles = FileSelector(path_pattern, path_pattern)
        myFiles.add_target_set('airtemp', variable='airtemp')
        myFiles.add_target_set('rainfall', variable='rfe')
        sets = myFiles.get_target_sets()
        self.assertEqual(len(sets), 2)
        # assertIn reports the container contents on failure, unlike assertTrue.
        self.assertIn('airtemp', sets)
        self.assertIn('rainfall', sets)
        self.assertEqual(sets['airtemp'], "airtemp%Y_%m_%d.v3.nc")
        self.assertEqual(sets['rainfall'], "rfe%Y_%m_%d.v3.nc")

    def testAddInputSetToTarget(self):
        """An input set attached to a target is listed under the target's name."""
        myFiles = self._rainfall_selector(target_variable='rainfall')
        sets = myFiles.get_input_sets()
        self.assertEqual(len(sets), 1)
        self.assertIn('rainfall', sets)
        sample = sets['rainfall']
        self.assertEqual(sample[0], "rfe%Y_%m_%d.v3.nc")

    def testGetDailySequenceFiles_vertical(self):
        """
        'vertical' refers to getting a single parameter for a number of days to
        produce an output. for example daily files use the current day + next day to
        build a daily file for the current day (e.g. 2015001 rain requires 2015001 and
        2015002 downloaded rain files).
        """
        myFiles = self._rainfall_selector()
        files = myFiles.get_sequence_files(self._START, self._END, 'D')
        self.assertEqual(len(files), 15)
        self._assert_target(files, 0, "rfe2000_01_01.v3.nc", 1,
                            ["rfe2000_01_01.v3.nc"])

    def testGetDailySequenceFiles_vertical_2(self):
        """Daily grouping with a fourth argument of 1 yields two inputs per target."""
        myFiles = self._rainfall_selector()
        files = myFiles.get_sequence_files(self._START, self._END, 'D', 1)
        self.assertEqual(len(files), 15)
        self._assert_target(files, 0, "rfe2000_01_01.v3.nc", 2,
                            ["rfe2000_01_01.v3.nc", "rfe2000_01_02.v3.nc"])

    def testGetDailySequenceFiles_horizontal(self):
        """
        horizontal refers to using multiple parameters to build the output file. For example
        humidity is calculate using airtemp and dewpoint, so to calculate humidity for
        2015001 requires the daily airtemp for 2015001 and the daily dewpoint for 2015001
        """
        pattern = "/{variable}%Y_%m_%d.v3.nc"
        in_path_pattern = self.basePathPattern(pattern)
        out_path_pattern = self.basePathPattern(pattern)
        myFiles = FileSelector(out_path_pattern, in_path_pattern)
        myFiles.add_target_set('humidity', variable='humidity')
        myFiles.add_input_set_to_target('humidity', variable='airtemp')
        myFiles.add_input_set_to_target('humidity', variable='dewpoint')
        files = myFiles.get_sequence_files(self._START, self._END, 'D', 0)
        self.assertEqual(len(files), 15)
        self._assert_target(files, 0, "humidity2000_01_01.v3.nc", 2,
                            ["airtemp2000_01_01.v3.nc",
                             "dewpoint2000_01_01.v3.nc"])

    def testGetMonthlySequenceFiles(self):
        """'M' grouping maps the first target to 31 daily input files."""
        myFiles = self._rainfall_selector()
        files = myFiles.get_sequence_files(self._START, self._END, 'M')
        self._assert_target(files, 0, "rfe2000_01_01.v3.nc", 31,
                            ["rfe2000_01_01.v3.nc"])

    def testGetWeelkySequenceFiles(self):
        """'W' grouping: a 2-file first group, then a 7-file group from 2000-01-03."""
        myFiles = self._rainfall_selector()
        files = myFiles.get_sequence_files(self._START, self._END, 'W')
        self._assert_target(files, 0, "rfe2000_01_01.v3.nc", 2,
                            ["rfe2000_01_01.v3.nc"])
        self._assert_target(files, 1, "rfe2000_01_03.v3.nc", 7,
                            ["rfe2000_01_03.v3.nc"])

    def testGetWeelkyBeginEndSequenceFiles(self):
        """'W' grouping across a year boundary: 5 files first, 3 files last."""
        myFiles = self._rainfall_selector()
        start_date = datetime.datetime(2000,12,20)
        end_date = datetime.datetime(2001,1,3)
        files = myFiles.get_sequence_files(start_date, end_date, 'W')
        self._assert_target(files, 0, "rfe2000_12_20.v3.nc", 5,
                            ["rfe2000_12_20.v3.nc"])
        # Index -1 addresses the final target, whatever the group count is.
        self._assert_target(files, -1, "rfe2001_01_01.v3.nc", 3,
                            ["rfe2001_01_01.v3.nc"])

    def testGetNGoupSequenceFiles_5(self):
        """'N' grouping with ``group_size=5`` puts five files in the first group."""
        myFiles = self._rainfall_selector()
        files = myFiles.get_sequence_files(self._START, self._END, 'N',
                                           group_size = 5)
        self._assert_target(files, 0, "rfe2000_01_01.v3.nc", 5,
                            ["rfe2000_01_01.v3.nc"])

    def testGetQuartelySequenceFiles(self):
        """'Q' grouping maps the first target to 91 daily input files."""
        myFiles = self._rainfall_selector()
        files = myFiles.get_sequence_files(self._START, self._END, 'Q')
        self._assert_target(files, 0, "rfe2000_01_01.v3.nc", 91,
                            ["rfe2000_01_01.v3.nc"])

# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
