import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Set, Tuple

from cestel_helpers.aliases import Element
from lxml import etree

import config
from consts import TS_FORMAT_STRING
from exceptions import StopModule
from generic_module import Module


def calculate_time_gap(ts1: datetime, ts2: datetime, length: float, speed: float) -> float:
    """ Seconds between ts1 + (length / speed) and ts2.
    :param ts1: Timestamp of the earlier vehicle.
    :param ts2: Timestamp of the later vehicle.
    :param length: Length converted to a duration using `speed`.
    :param speed: Speed used for the conversion (must be non-zero, otherwise ZeroDivisionError).
    :returns: The remaining time gap in seconds (negative when ts2 comes before ts1 + length / speed)."""
    passing_time = timedelta(seconds=length / speed)
    elapsed = ts2 - ts1
    return (elapsed - passing_time).total_seconds()


def calculate_distance_gap(ts1: datetime, ts2: datetime, length: float, speed: float) -> float:
    """ Distance covered at `speed` between ts1 and ts2, minus `length`.
    :param ts1: Timestamp of the earlier vehicle.
    :param ts2: Timestamp of the later vehicle.
    :param length: Length subtracted from the travelled distance.
    :param speed: Speed used to turn the elapsed time into a distance.
    :returns: The distance gap (negative when the travelled distance is shorter than `length`)."""
    elapsed_seconds = (ts2 - ts1).total_seconds()
    travelled = elapsed_seconds * speed
    return travelled - length


def sort_vehicles(vehicles: List[Element]) -> List[Element]:
    """ Sort vehicles chronologically by their wim/ts child element.

    The sort is stable, so vehicles with identical timestamps keep their original relative order.
    :param vehicles: List of vehicle elements to sort; each must have a wim/ts text node.
    :returns: A new sorted list of the same elements (the input list is not modified).
    :raises IndexError: If a vehicle has no wim/ts text.
    :raises ValueError: If a timestamp does not match TS_FORMAT_STRING."""
    def timestamp(vehicle: Element) -> datetime:
        # xpath(...)[0] raises IndexError when wim/ts is missing or has no text.
        return datetime.strptime(vehicle.xpath('./wim/ts/text()')[0], TS_FORMAT_STRING)

    return sorted(vehicles, key=timestamp)


def update_vehicles(vehicles: List[Element], max_time_gap: float) -> Set[Element]:
    """ Annotate vehicles with the gap to the preceding vehicle in the same lane and collect duplicates.

    Vehicles are expected in chronological order (see `sort_vehicles`). When a vehicle has a predecessor in
    the same lane and their time gap is below `max_time_gap`, an `<integration><gap>...</gap></integration>`
    element is appended to it, holding the distance gap rounded to 4 decimals. A negative distance gap marks
    the vehicle as a duplicate. A duplicate gets no gap element and is not used as the predecessor for the
    next vehicle in its lane.
    :param vehicles: List of vehicle xml elements, sorted by wim/ts.
    :param max_time_gap: Time gaps (in seconds) at or above this value are not recorded.
    :returns Set: A set of duplicate vehicles (vehicles with a negative gap).
    """
    duplicates: Set[Element] = set()
    # Per lane: (timestamp, wheelbase) of the last non-duplicate vehicle seen in that lane.
    previous: Dict[str, Tuple[datetime, float]] = {}
    for vehicle in vehicles:
        lane = vehicle.find('wim/lane').text
        ts = datetime.strptime(vehicle.find('wim/ts').text, TS_FORMAT_STRING)
        speed = float(vehicle.find('wim/v').text)
        if lane in previous:  # The first vehicle of a lane has nothing to compare against.
            prev_ts, prev_length = previous[lane]
            # The previous vehicle's wheelbase is converted to time using *this* vehicle's speed.
            gap_seconds = calculate_time_gap(prev_ts, ts, prev_length, speed)
            if gap_seconds < max_time_gap:  # Vehicles too far apart in time get no gap element.
                gap = calculate_distance_gap(prev_ts, ts, prev_length, speed)
                if gap < 0:
                    # Overlapping vehicles: treat this one as a duplicate and do not make it the lane's
                    # predecessor. Whether it is dropped from the output is decided by the caller.
                    duplicates.add(vehicle)
                    continue
                integration = etree.SubElement(vehicle, 'integration')
                etree.SubElement(integration, 'gap').text = str(round(gap, 4))
        previous[lane] = (ts, float(vehicle.find('wim/whlbse').text))
    return duplicates


def save_file(path: Path, vehicles: List[Element], duplicates: Set[Element]) -> None:
    """ Write vehicles (excluding `duplicates`) to `path` as a swd XML document.

    The document is first written to a temporary sibling file and then moved over `path` with
    `Path.replace`. If anything fails while writing, the existing file is left untouched instead of
    being truncated, and the temporary file is removed.
    :param path: Destination file; overwritten on success.
    :param vehicles: Vehicle elements to write, in output order.
    :param duplicates: Vehicles to omit from the output.
    """
    from xml.sax.saxutils import escape  # Site name goes into element text and must be XML-escaped.

    tmp_path = path.with_name(path.name + '.tmp')
    try:
        with open(tmp_path, 'w', encoding='utf-8') as output_file:
            # Write header.
            output_file.write(f'<?xml version="1.0" ?>\n<swd version="1"><site><name>{escape(str(config.site_name))}</name><vehicles>\n')
            # Write vehicles with gaps.
            for vehicle in vehicles:
                if vehicle in duplicates:
                    continue
                output_file.write(etree.tostring(vehicle).decode().strip() + '\n')
            # Write footer.
            output_file.write('</vehicles></site></swd>')
        tmp_path.replace(path)  # Replace the original only after the new content is complete.
    except BaseException:
        # Best-effort cleanup of the partial temporary file; the original error is re-raised below.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


class GapModule(Module):
    """ Module that adds per-lane vehicle gaps to a day's XML file once that day has ended. """

    def __init__(self, args):
        # Defaults; presumably overridden by Module.__init__ from the optional keys in `args` — TODO confirm.
        self.remove_duplicates = False
        self.time_gap = 20.0
        self.xml_dir = config.sites_dir / config.site_name / 'ext'
        Module.__init__(self, args, mandatory_keys=tuple(), optional_keys=('remove_duplicates', 'time_gap'))
        self.type = 'gap'

    def add_gap(self, file_path: Path):
        """ Sort a day's vehicles by wim/ts, add gap information, optionally drop duplicates, and overwrite the file.
        :param file_path: Path of the XML file to process.
        :raises StopModule: If the file does not exist or sorting lost vehicles.
        """
        if not file_path.is_file():
            raise StopModule(f'File "{file_path.name}" does not exist.')
        # str() for consistency with run(), which also passes a string filename to etree.parse.
        root = etree.parse(str(file_path)).getroot()

        # Find all vehicle elements.
        vehicles = root.xpath('/swd/site/vehicles/vehicle')
        vehicles_sorted = sort_vehicles(vehicles)
        if len(vehicles) != len(vehicles_sorted):  # Sanity check; sorting must not lose elements.
            raise StopModule('Failed to sort vehicles!')

        # Detach vehicles from the parsed tree before they are serialized one by one in save_file().
        for vehicle in vehicles:
            parent = vehicle.getparent()
            if parent is not None:
                parent.remove(vehicle)

        # Update vehicles.
        duplicates = update_vehicles(vehicles_sorted, self.time_gap)
        if duplicates:
            self.logger.error(f'{len(duplicates)} vehicles with negative gap found. Removing {"enabled" if self.remove_duplicates else "disabled"}.')
            # Log in chronological order (iterating the set would be arbitrary). findtext() avoids
            # crashing the module when an optional child such as wim/ets is missing.
            for vehicle in vehicles_sorted:
                if vehicle in duplicates:
                    self.logger.info(f'Vehicle {vehicle.findtext("wim/ts")} from event {vehicle.findtext("wim/ets")} has negative gap.')

        # Overwrite the file.
        self.logger.info(f'Overwriting {file_path.name} ...')
        save_file(file_path, vehicles_sorted, duplicates if self.remove_duplicates else set())
        self.logger.info(f'{file_path.name} written.')

    def _first_pending_date(self):
        """ Decide which date the main loop treats as "not yet processed".

        :returns: Yesterday's date if yesterday's file exists, parses, and contains no gap tag yet, so the
            main loop processes it right away. Otherwise today's date, so only days that end while the
            module runs are processed.
        """
        try:
            self.logger.info('Checking if yesterday\'s file needs to be updated ...')
            yesterday = datetime.now().date() - timedelta(days=1)
            root = etree.parse(str(self.xml_dir / f'{yesterday.isoformat()}.xml')).getroot()
        except OSError:
            self.logger.warning('Could not find yesterday\'s file.')
            return datetime.now().date()
        except etree.XMLSyntaxError:
            self.logger.critical('Yesterday\'s file is corrupted.')
            return datetime.now().date()
        # If any gap tag already exists, the file has been handled already.
        if root.xpath('/swd/site/vehicles/vehicle/integration/gap'):
            self.logger.info('Yesterday\'s file does not need to be updated.')
            return datetime.now().date()
        return yesterday

    def run(self) -> None:
        """ Main loop: every 10 s, check whether the tracked date has passed and, if so, process its file.

        Stops when `self.end` becomes true (presumably set from outside to request shutdown) or on any error.
        """
        self.alive = True
        self.end = False
        self.logger.debug('Started.')
        try:
            # Convert yesterday's file, if needed.
            date = self._first_pending_date()
            try:
                while True:
                    if self.end:
                        raise StopModule('Shutdown flag detected.')
                    if date != datetime.now().date():
                        time.sleep(30)  # There's no rush to convert, so give the writer time to release the file.
                        file_name = f'{date.isoformat()}.xml'
                        self.logger.debug(f'Updating {file_name} ...')
                        self.add_gap(self.xml_dir / file_name)
                        self.logger.debug(f'{file_name} updated.')
                        date = datetime.now().date()
                    time.sleep(10)
            except StopModule:
                raise
            except Exception:  # Not a bare except: let KeyboardInterrupt/SystemExit propagate.
                self.logger.exception('Unexpected exception occurred:')
                raise StopModule('Fatal error!')
        except StopModule as e:
            self.log_stop_module(e)
        finally:
            self.logger.debug('Stopped.')
