from datetime import datetime, timedelta
from pathlib import Path
from time import sleep
from typing import Dict, List, Set, Tuple

from cestel_helpers.aliases import Element
from lxml import etree

from abstract.postprocessing import PostprocessingModule
from consts import GAP_REMOVE, GAP_TIME, TS_FORMAT_STRING
from exceptions import StopModule


def calculate_time_gap(ts1: datetime, ts2: datetime, length: float, speed: float) -> float:
    """ Time gap in seconds between two timestamps, minus the time needed to travel `length` at `speed`.
    :param ts1: Earlier timestamp.
    :param ts2: Later timestamp.
    :param length: Distance to subtract, in the length unit that `speed` is expressed per second.
    :param speed: Speed in length units per second; must be non-zero.
    :returns: (ts2 - ts1) - length / speed, in seconds (negative when the interval is too short)."""
    elapsed = ts2 - ts1
    travel_time = timedelta(seconds=length / speed)
    return (elapsed - travel_time).total_seconds()


def calculate_distance_gap(ts1: datetime, ts2: datetime, length: float, speed: float) -> float:
    """ Distance covered at `speed` between two timestamps, minus `length`.
    :param ts1: Earlier timestamp.
    :param ts2: Later timestamp.
    :param length: Distance to subtract, in the same length unit as `speed` (per second).
    :param speed: Speed in length units per second.
    :returns: (ts2 - ts1) in seconds * speed - length (negative when the interval is too short)."""
    elapsed_seconds = (ts2 - ts1).total_seconds()
    return elapsed_seconds * speed - length


def sort_vehicles(vehicles: List[Element]) -> List[Element]:
    """ Sort vehicles chronologically by the timestamp in their wim/ts child element.
    :param vehicles: List of vehicle elements to sort; the list itself is not modified.
    :returns: New list with the same elements ordered by parsed wim/ts (stable for equal timestamps).
    :raises IndexError: If a vehicle has no (or empty) wim/ts text.
    :raises ValueError: If a wim/ts value does not match TS_FORMAT_STRING.
    """
    def timestamp(vehicle: Element) -> datetime:
        # xpath(...text()) yields a list of text nodes; the first one is the timestamp string.
        return datetime.strptime(vehicle.xpath('./wim/ts/text()')[0], TS_FORMAT_STRING)

    return sorted(vehicles, key=timestamp)


def update_vehicles(vehicles: List[Element], max_time_gap: float) -> Set[Element]:
    """ Annotate vehicles with the gap to the preceding vehicle in the same lane and collect duplicates.

    Vehicles are processed in list order (the caller passes them sorted by wim/ts). For a vehicle whose
    lane already had a vehicle, the time gap is computed from the predecessor's timestamp and wheelbase
    (wim/whlbse) and this vehicle's speed (wim/v). If that time gap is below `max_time_gap`, the distance
    gap is computed: a negative distance gap marks the vehicle as a duplicate, otherwise an
    <integration><gap>...</gap></integration> node is appended to the vehicle.
    Duplicates are always skipped here (no gap node, not used as predecessor); whether they are removed
    from the output is decided by the caller.

    :param vehicles: List of xml objects.
    :param max_time_gap: Time gap (seconds) at or above which no gap is recorded.
    :returns Set: A set of duplicate vehicles.
    """
    duplicates: Set[Element] = set()
    # Per lane: timestamp and wheelbase (wim/whlbse) of the last non-duplicate vehicle.
    previous: Dict[str, Tuple[datetime, float]] = {}
    for vehicle in vehicles:
        lane = vehicle.find('wim/lane').text
        ts = datetime.strptime(vehicle.find('wim/ts').text, TS_FORMAT_STRING)
        # NOTE(review): a speed of 0 makes calculate_time_gap raise ZeroDivisionError.
        v = float(vehicle.find('wim/v').text)
        if lane in previous:  # The first vehicle of a lane has nothing to compare against.
            prev_ts, prev_length = previous[lane]
            gap_seconds = calculate_time_gap(prev_ts, ts, prev_length, v)
            if gap_seconds < max_time_gap:  # Vehicles max_time_gap or more seconds apart get no gap.
                gap = calculate_distance_gap(prev_ts, ts, prev_length, v)
                if gap < 0:
                    # Overlapping vehicles: flag as duplicate and do not remember it, so the next
                    # vehicle in this lane is measured against the vehicle before it.
                    duplicates.add(vehicle)
                    continue
                node = etree.SubElement(vehicle, 'integration')
                etree.SubElement(node, 'gap').text = str(round(gap, 4))
        previous[lane] = (ts, float(vehicle.find('wim/whlbse').text))
    return duplicates


class GapModule(PostprocessingModule):
    """ Post-processing module that adds per-lane vehicle gaps to daily SWD XML files and optionally
    drops vehicles with a negative gap (duplicates)."""

    def __init__(self, args):
        # Defaults; presumably overridden by PostprocessingModule.__init__ from the optional
        # GAP_REMOVE / GAP_TIME keys — TODO confirm against the base class.
        self.remove_duplicates = False
        self.time_gap = 20.0
        PostprocessingModule.__init__(self, args, mandatory_keys=tuple(), optional_keys=(GAP_REMOVE, GAP_TIME))

    def add_gap(self, file_path: Path) -> Tuple[List[Element], Set[Element]]:
        """ Read an SWD XML file, sort its vehicles by wim/ts and annotate them with gaps.
        The file is not written here: vehicles are detached from their parent elements and returned.
        :param file_path: Path of the SWD XML file.
        :returns: Tuple of (vehicles sorted by wim/ts, set of vehicles with a negative gap).
        :raises StopModule: If the file does not exist.
        """
        if not file_path.is_file():
            raise StopModule(f'File "{file_path.name}" does not exist.')
        root = etree.parse(file_path).getroot()

        vehicles = root.xpath('/swd/site/vehicles/vehicle')
        vehicles_sorted = sort_vehicles(vehicles)
        # Defensive check only: sorted() never changes the number of elements.
        if len(vehicles) != len(vehicles_sorted):
            raise StopModule('Failed to sort vehicles!')

        # Detach all vehicles from the tree; they are returned as a standalone list.
        for vehicle in vehicles:
            parent = vehicle.getparent()
            if parent is not None:
                parent.remove(vehicle)

        duplicates = update_vehicles(vehicles_sorted, self.time_gap)
        if duplicates:
            self.logger.error(f'{len(duplicates)} vehicles with negative gap found. Removing {"enabled" if self.remove_duplicates else "disabled"}.')
        for vehicle in duplicates:
            self.logger.info(f'Vehicle {vehicle.find("wim/ts").text} from event {vehicle.find("wim/ets").text} has negative gap.')

        return vehicles_sorted, duplicates

    def run(self) -> None:
        """ Main loop: add gaps to each day's SWD file once that day's grace period has passed.
        Starts with the oldest file within lookback_days that still needs processing (today if none),
        then advances one day at a time, sleeping while the next day is not yet due. Stops when the
        shutdown flag is set or an unexpected error occurs."""
        self.alive = True
        self.end = False
        self.logger.debug(f'Started using SWD location {self.swd_dir}.')
        try:
            today = datetime.now().date()
            # Find the oldest day within the lookback window that still needs to be updated.
            for day in range(self.lookback_days, 0, -1):
                date = today - timedelta(days=day)
                self.logger.debug(f'Checking if {date}.xml needs to be updated ...')
                if self.swd_to_process(self.swd_dir / f'{date}.xml'):
                    self.logger.info(f'Updating starting with {date}.xml.')
                    break
            else:  # No past file needs processing: start with today.
                date = today
                self.logger.debug(f'No valid SWD files found within {self.lookback_days} days.')

            try:
                while True:
                    self.logger.debug('Checking if date changed ...')
                    if self.end:
                        raise StopModule('Shutdown flag detected.')
                    # Days on or after this date are still within the grace period.
                    pending_date = (datetime.now() - timedelta(seconds=self.grace_period)).date()
                    # Strictly-less, not `!=`: a start date ahead of pending_date (started shortly after
                    # midnight, inside the grace period) must wait, not be processed and skipped past.
                    if date < pending_date:
                        self.logger.debug(f'Processing {date}.')
                        xml_path = self.swd_dir / f'{date}.xml'
                        if self.swd_to_process(xml_path):
                            vehicles, duplicates = self.add_gap(xml_path)
                            if vehicles:
                                self.logger.info(f'Overwriting {xml_path.name} ...')
                                self.save_swd_file(xml_path, vehicles, duplicates if self.remove_duplicates else set())
                                self.logger.info(f'{xml_path.name} written.')
                        date += timedelta(days=1)
                    else:
                        sleep(10)
            except StopModule:
                raise
            except Exception as exc:
                self.logger.exception('Unexpected exception occurred:')
                raise StopModule('Fatal error!') from exc
        except StopModule as e:
            self.log_stop_module(e)