import copy
import datetime
import queue
import time
from typing import Dict, List, Optional, Tuple

import cestel_helpers.aliases
from cestel_helpers.i_conf_manager import SEC_DEFAULT, read_conf
from lxml import etree

import config
from consts import AGGR_FUZZINESS, AGGR_SAVE, AGGR_TIMEOUT, CONF_DOWNSTREAM, RCV_LANES, RCV_OFFSET, TS_FORMAT_STRING
from exceptions import NoData, StopModule
from generic_module import Module


def ts_to_datetime(ts: str) -> datetime.datetime:
    """Parse a timestamp string using the project-wide ``TS_FORMAT_STRING`` layout.

    :param ts: Timestamp text laid out according to ``TS_FORMAT_STRING``.
    :return: ``datetime.datetime`` produced by ``strptime``.
    :raises ValueError: Raised by ``strptime`` if ``ts`` does not match ``TS_FORMAT_STRING``.
    """
    parsed = datetime.datetime.strptime(ts, TS_FORMAT_STRING)
    return parsed


def remove_unwanted_tags(module: etree.Element, tags: Tuple[str, ...]) -> None:
    """ Function checks xml for keys that should not be saved and removes them.

    Every direct child of ``module`` whose tag is listed in ``tags`` is removed, including repeated occurrences of the same tag. The element is modified in place and tags that do not occur are silently ignored.

    :param module: Element that we're editing.
    :param tags: Tuple of keys to check.
    """
    # FIXME This is potentially risky, since a valid key may be removed from an unrelated module.
    for tag_name in tags:
        # findall() returns a list, so removing children while looping over it is safe.
        for tag in module.findall(tag_name):
            module.remove(tag)


class AggregationModule(Module):
    """Module that matches vehicles from one primary acquisition input with vehicles from auxiliary inputs.

    Vehicles are queued per lane by :meth:`add_vehicle`. :meth:`run` takes them from the primary queues, uses :meth:`vehicles_match` to look for a matching vehicle in every auxiliary queue of the same lane and passes the combined ``vehicle`` element to all downstream modules.
    """

    def __init__(self, args):
        """Read the configuration and create the per-lane queues for the primary and the auxiliary inputs.

        The primary input is the first module in ``acquisition.toml`` that defines ``RCV_LANES``. Auxiliary inputs are modules listed in ``CONF_DOWNSTREAM`` that have this module downstream; their per-lane offsets are read from ``RCV_OFFSET``.

        :param args: Passed unchanged to ``Module.__init__``.
        :raises StopModule: If no acquisition module defines ``RCV_LANES``.
        """
        # NOTE(review): these look like defaults that Module.__init__ may override from the configuration - confirm.
        self.data_ttl: float = 5  # Amount of time unmatched data is kept in memory before it's discarded.
        self.fuzziness: float = 1  # Maximum difference between timestamps after applying offsets.
        self.matching_timeout: float = 0  # Amount of time the module will wait for data.
        self.queue_warning: int = 20  # Number of items in Queue at which a warning should be logged.
        Module.__init__(self, args, mandatory_keys=(AGGR_SAVE, AGGR_TIMEOUT, AGGR_FUZZINESS))

        if self.matching_timeout == 0:
            self.logger.info('Matching timeout was not set. Unmatched data will never be discarded.')

        # Initialize primary input container.
        acq_modules = read_conf(config.conf_dir / 'acquisition.toml')  # Only acquisition modules make sense to be upstream of aggregation.
        for mod, data in acq_modules.items():
            monitored_lanes: Optional[List[int]] = data.get(RCV_LANES)  # Primary input has defined lanes rather than offset.
            if monitored_lanes:
                # These variables are intentionally not given a dummy value above, because the module stops executing if they're not set here.
                self.expected_auxiliary_per_lane: Dict[int, List[str]] = {lane: [] for lane in monitored_lanes}  # Names of the auxiliary modules registered for each monitored lane.
                self.primary_input: Dict[int, queue.Queue[cestel_helpers.aliases.Element]] = {lane: queue.Queue() for lane in monitored_lanes}  # Data from primary input.
                self.primary_input_name: str = mod
                self.auxiliary_inputs: Dict[int, Dict[str, Tuple[queue.Queue[cestel_helpers.aliases.Element], List[float]]]] = {}  # Keys are lane and module.
                break
        else:  # Raise error if primary_input was not found.
            message = f'Key "{RCV_LANES}" missing from primary acquisition module. Add line "{RCV_LANES} = [x, y, z, ...]" where x, y and z represent lanes on which primary input sends data.'
            self.logger.critical(message)
            raise StopModule(message)

        # Initialize auxiliary input containers.
        for mod, downstream in read_conf(config.conf_dir / CONF_DOWNSTREAM)[SEC_DEFAULT].items():
            if self.name in downstream and mod != self.primary_input_name:  # If this module is downstream of selected module and isn't the primary one.
                try:
                    offsets = acq_modules[mod].get(RCV_OFFSET, {})
                    # Iterate over items so the offsets are taken with the original key instead of looking them up again through str(int(key)).
                    for lane_key, lane_offsets in offsets.items():
                        lane = int(lane_key)
                        # Each entry holds the queue.Queue for the module and the offsets which should be applied to it.
                        self.auxiliary_inputs.setdefault(lane, {})[mod] = (queue.Queue(), lane_offsets)
                        # NOTE(review): a lane that the primary input does not monitor raises KeyError here, which skips the remaining lanes of this module - confirm this is intended.
                        self.expected_auxiliary_per_lane[lane].append(mod)
                except KeyError as e:
                    self.logger.error(f'Upstream module, {mod}, is not a valid acquisition module: Missing key {e}.')
        self.logger.info(f'Primary input: {self.primary_input_name}. Auxiliary lanes: {list(self.auxiliary_inputs.keys())}.')

    def add_vehicle(self, vehicle, recv_module_name) -> None:
        """Put the first child of ``vehicle`` on the queue that belongs to the sending module and the vehicle's lane.

        Data for a lane (or a module/lane combination) that has no queue is not stored; a warning is logged instead.

        :param vehicle: Element whose first child carries the ``ts`` and ``lane`` tags.
        :param recv_module_name: Name of the module that sent the data. Decides between the primary and the auxiliary queues.
        :raises NotImplementedError: If the first child has no ``ts`` or no ``lane`` tag.
        """
        sub_vehicle: cestel_helpers.aliases.Element = vehicle[0]
        try:
            ts = sub_vehicle.find('ts').text
            lane = int(sub_vehicle.find('lane').text)
        except AttributeError:  # find() returned None, so the tag does not exist.
            self.logger_dump.critical(f'Auxiliary missing ts or lane: {etree.tostring(sub_vehicle)}.')
            raise NotImplementedError(f'Modules without ts or lane can not be aggregated. Problematic module: {vehicle.tag}: {vehicle.attrib}.')
        try:
            if recv_module_name == self.primary_input_name:
                self.primary_input[lane].put(sub_vehicle)
            else:
                self.auxiliary_inputs[lane][recv_module_name][0].put(sub_vehicle)
            self.logger.debug(f'Received {ts} from {recv_module_name}:{lane}.')
        except KeyError:
            mod_name = sub_vehicle.attrib.get('source')
            mod_type = sub_vehicle.attrib.get('type')
            # Report the parsed lane rather than the KeyError argument, since the missing key may also be the module name.
            self.logger.warning(f'Module {sub_vehicle.tag}-{mod_name} ({mod_type}) sent data for lane {lane}, which is not captured.')

    def vehicles_match(self, primary: cestel_helpers.aliases.Element, auxiliary: cestel_helpers.aliases.Element, lane_offsets: List[float], fuzziness: float) -> bool:
        """Checks if the vehicle from auxiliary module can be matched to primary module.

        The expected time difference is ``-lane_offsets[mp] / v``, where ``mp`` is taken from the last character of the primary ``admpsec`` tag and ``v`` from the primary ``v`` tag. Vehicles match if the difference between primary and auxiliary timestamps lies within ``fuzziness`` of that value.

        :param primary: Lxml element containing data for the primary module (usually SiWIM).
        :param auxiliary: Lxml element containing data for the auxiliary module.
        :param lane_offsets: List of offsets for this lane.
        :param fuzziness: Maximum difference between offset timestamp and primary to still be considered a match.
        :return: `True` if vehicles can be matched, `False` otherwise.
        :raises NotImplementedError: Raised in situations where matching logic is not implemented.
        :raises NoData: If ``lane_offsets`` has no entry for the measuring point of the primary vehicle.
        """
        lane = int(primary.find('lane').text)
        if lane != int(auxiliary.find('lane').text):
            raise NotImplementedError('Can not compare vehicles on different lanes!')
        auxiliary_name = auxiliary.attrib['source']
        ts_primary = primary.find('ts').text
        ts_auxiliary = auxiliary.find('ts').text
        v = float(primary.find('v').text)
        # NOTE(review): a trailing "0" yields mp == -1, which silently selects the last offset instead of raising IndexError - confirm that measuring points start at 1.
        mp = int(primary.find('admpsec').text[-1]) - 1  # Format of admpsec is axleXY where X is lane and Y is measuring point.
        try:
            t_offset = -lane_offsets[mp] / v  # Calculate base offset and change its sign so that results are consistent with old implementation.
        except IndexError:
            raise NoData(f'Could not match {ts_primary} and {ts_auxiliary} ({auxiliary_name}) on lane {lane} because mp {mp} is undefined!')

        # This indicates that the measuring point triggers after the vehicle is over it rather than when it enters measuring area. Because of this we need to adjust offset by vehicle length.
        vehicle_end = auxiliary.find('siwim_i_add_to_offset')  # TODO Refactor this to take data directly from the conf or something.
        if vehicle_end is not None:
            if vehicle_end.text == '0':
                t_offset -= (float(primary.find('whlbse').text) + float(primary.find('append').text)) / v
            else:
                raise NotImplementedError(f'Module {auxiliary_name} has a non-zero value for "siwim_i_add_to_offset", which is not supported.')
        self.logger.debug(f'Matching timestamp for {ts_auxiliary}_{lane} ({auxiliary_name}): {(ts_to_datetime(ts_auxiliary) + datetime.timedelta(seconds=t_offset)).strftime(TS_FORMAT_STRING)}.')
        # The interval is closed on the lower and open on the upper bound.
        return t_offset - fuzziness <= (ts_to_datetime(ts_primary) - ts_to_datetime(ts_auxiliary)).total_seconds() < t_offset + fuzziness

    def run(self) -> None:
        """Main loop: match queued primary vehicles with auxiliary vehicles and send the result downstream.

        For every primary vehicle the auxiliary queues of its lane are searched for a match. Auxiliary vehicles that did not match are put back on their queue unless they are older than ``matching_timeout + data_ttl`` seconds. The primary vehicle is sent downstream together with whatever was matched once everything is matched, the timeout has passed or, with a timeout of 0, after a single pass.

        The loop ends with ``StopModule`` once ``self.end`` is set; any other exception is logged as fatal and also ends the loop.
        """
        self.alive = True
        self.end = False
        if self.matching_timeout == 0:  # If matching timeout is 0, we're postprocessing and should wait for all data to arrive before proceeding.
            time.sleep(5)
        self.logger.debug('Aggregation started')
        try:
            while True:
                self.throttle()
                if self.end:
                    self.alive = False
                    raise StopModule('Thread closed correctly.')

                for lane, primary_queue in self.primary_input.items():  # Process vehicles on a lane by lane basis.
                    while not primary_queue.empty():  # Get all vehicles from the queue.
                        # Sanity check to quickly see if the queue is too slow for realtime operations.
                        queue_size = primary_queue.qsize()
                        if queue_size > self.queue_warning:
                            self.logger.warning(f'Primary queue has a backlog of {queue_size} items. Report this warning immediately!')

                        # Get the next vehicle from the queue.
                        primary: cestel_helpers.aliases.Element = primary_queue.get()
                        primary_ts: str = primary.find('ts').text  # This indicates timestamp in siwim format for logging purposes.
                        matched_auxiliary: Dict[str, cestel_helpers.aliases.Element] = {}

                        # If there are no auxiliary modules defined for the lane, just send primary input downstream.
                        if lane not in self.auxiliary_inputs:
                            self.logger.info(f'Lane {lane} has no auxiliary inputs. Sending {primary_ts} downstream.')
                            unmatched_vehicle: cestel_helpers.aliases.Element = etree.Element('vehicle')
                            unmatched_vehicle.append(primary)
                            for downstream_module in self.downstream_modules_dict.values():
                                downstream_module.add_vehicle(copy.deepcopy(unmatched_vehicle), self.name)  # Every module gets its own copy.
                            continue

                        lane_inputs = self.auxiliary_inputs[lane]
                        # Keep looping until all auxiliary are matched or timeout is reached. If timeout is set to 0, only a single pass is made.
                        while True:
                            self.throttle(seconds=0.1)  # Without this, matching will use all available CPU time.
                            # Iterate over all the queues that acquire events on this lane.
                            for mod, (aux_queue, lane_offsets) in lane_inputs.items():
                                tmp_items: List[cestel_helpers.aliases.Element] = []  # List holds items that weren't a match.
                                if mod not in matched_auxiliary:
                                    self.logger.debug(f'Looking for {mod} for {primary_ts}_{lane}.')
                                    try:
                                        while not aux_queue.empty():  # Obtain vehicles until a match is found or the queue is exhausted.
                                            # Sanity check to quickly see if the queue is too slow for realtime operations. Must look at the auxiliary queue, not the primary one.
                                            queue_size = aux_queue.qsize()
                                            if queue_size > self.queue_warning:
                                                self.logger.warning(f'Queue for {mod} has a backlog of {queue_size} items. Report this warning immediately!')

                                            # Get first item from queue.
                                            auxiliary: cestel_helpers.aliases.Element = aux_queue.get()
                                            if self.vehicles_match(primary, auxiliary, lane_offsets, self.fuzziness):
                                                # Remove tags that should not be saved.
                                                remove_unwanted_tags(auxiliary, ('photo', 'siwim_i_add_to_offset'))  # FIXME This is a really bad way to go around removing unwanted tags. They shouldn't even be added.
                                                self.logger.info(f'Successfully matched {primary_ts} and {auxiliary.find("ts").text} ({auxiliary.attrib["source"]}) on lane {lane}.')
                                                matched_auxiliary[mod] = auxiliary
                                                break
                                            tmp_items.append(auxiliary)
                                        else:  # Queue was emptied without finding a match.
                                            self.logger.debug(f'No {mod} for {primary_ts}_{lane}.')
                                    except NoData as e:
                                        self.logger.error(e)
                                        tmp_items.append(auxiliary)  # noqa Technically can be undefined, but this error can only happen in self.vehicles_match.
                                else:
                                    self.logger.debug(f'{mod} for {primary_ts}_{lane} skipped because it\'s already matched.')

                                for item in tmp_items:  # Put items that might still be matched back in the queue. This is inefficient, but complexity of the operation is O(1) and it shouldn't happen enough to matter.
                                    ts_item = item.find('ts').text
                                    if self.matching_timeout == 0 or datetime.datetime.now() - datetime.timedelta(seconds=self.matching_timeout + self.data_ttl) < ts_to_datetime(ts_item):
                                        aux_queue.put(item)
                                        self.logger.debug(f'Item {ts_item}_{lane} from {mod} added back to queue.')
                            self.logger.debug(f'Matched {len(matched_auxiliary)} out of {len(lane_inputs)} so far for {primary_ts}_{lane}.')
                            # We want to go through the loop at least once, even if the conditions are met. This is so that when a delay is introduced by previous data, this does not affect matching for any subsequent data.
                            if len(matched_auxiliary) == len(lane_inputs) or datetime.datetime.now() - datetime.timedelta(seconds=self.matching_timeout) > ts_to_datetime(primary_ts) or self.matching_timeout == 0:
                                break
                        if len(matched_auxiliary) != len(lane_inputs):
                            self.logger.warning(f'Did not find {[aux for aux in lane_inputs if aux not in matched_auxiliary]} for {primary_ts}_{lane}. If data exists, consider increasing matching_timeout.')

                        vehicle: cestel_helpers.aliases.Element = etree.Element('vehicle')
                        vehicle.append(primary)
                        for matched in matched_auxiliary.values():
                            vehicle.append(matched)

                        for downstream_module in self.downstream_modules_dict.values():
                            downstream_module.add_vehicle(copy.deepcopy(vehicle), self.name)  # Every module gets its own copy.
                        self.logger.info(f'Sent {vehicle.xpath("wim/ts")[0].text}_{vehicle.xpath("wim/lane")[0].text} downstream with tags {[el.tag for el in vehicle]}.')
        except StopModule as e:
            self.log_stop_module(e)
        except Exception:  # Top-level boundary of the module: log everything unexpected, but let KeyboardInterrupt/SystemExit propagate.
            self.logger.exception('Fatal error:')
            self.alive = False
        self.logger.debug('Aggregation stopped.')
