# Implementation of data acquisition from Comark sensors.
import copy
import datetime
import socket
import time
from typing import Dict, Optional

from lxml import etree

from abstract.module import Module
from consts import MOD_TYPE, RCV_HOST, RCV_OFFSET, RCV_PORT, RCV_SAVE, SAVE_PATH
from exceptions import NoData, StopModule


def hash_djb2(s: bytes) -> bytes:
    """Build the ``<HSH|XXXXXXXX>`` reply for a received message.

    Computes the 32-bit djb2 hash (``h = h * 33 + c``, seeded with 5381) over the
    characters of the UTF-8-decoded message. The hash is formatted as eight
    upper-case hexadecimal digits.

    :param s: Raw received message.
    :return: ``b'<HSH|XXXXXXXX>'`` with the zero-padded hash.
    :raises UnicodeDecodeError: If ``s`` is not valid UTF-8.
    """
    hsh = 5381
    for char in s.decode():
        # Reduce modulo 2**32 on every step. The result is identical to masking once at
        # the end, but the integer no longer grows by ~5 bits per character (which made
        # hashing long messages quadratic).
        hsh = (((hsh << 5) + hsh) + ord(char)) & 0xFFFFFFFF
    return f'<HSH|{hsh:08X}>'.encode()


class AcquisitionComark(Module):
    """Acquisition module that receives Comark ``<sensor>`` XML messages over TCP.

    The module acts as a TCP server. Each received ``<sensor>...</sensor>`` message that
    contains a ``transit_end`` tag is converted into a SiWIM ``<vehicle>`` element. A
    copy of that element is handed to every downstream module.
    """

    # Seconds without any received message after which run() stops the module.
    HEARTBEAT_TIMEOUT_S = 50

    def __init__(self, args) -> None:
        # Initialize the module, passing it a list of keys that must exist for it to work properly.
        Module.__init__(self, args, mandatory_keys=(MOD_TYPE, RCV_HOST, RCV_OFFSET, RCV_PORT, RCV_SAVE))
        self.type = 'comark'
        # Child tags a message must contain to be converted into a vehicle.
        self.mandatory_keys: Dict[str, Optional[str]] = {
            'transit_end': None
        }
        # Maps attributes of <transit_end> to SiWIM element names.
        # None means the attribute is known but not forwarded; attributes missing from
        # this dict are logged as unhandled.
        self.known_attributes: Dict[str, Optional[str]] = {
            'id': 'id',
            'lane': 'lane',
            'lane_id': None,
            'time_iso': 'ts',
            'time_iso_ms': 'ts',
            'speed': 'v',
            'height': 'height',
            'width': 'width',
            'length': 'length',
            'refl_idx': None,
            'refl_pos': None,
            'gap': 'gap',
            'headway': None,
            'occupancy': None,
            'class_id': 'cls',
            'position': None,
            'direction': None
        }

    def parse_data(self, message: etree.Element) -> float:
        """Convert one received message into a SiWIM vehicle and send it downstream.

        Due to the sending protocol only one message can be received at a time.
        Messages lacking a mandatory child tag (see ``self.mandatory_keys``) are logged
        and otherwise ignored.

        :param message: Received message (the parsed ``<sensor>`` element).
        :return: Current time; ``run`` uses it as the time of the last received message.
        """
        for key in self.mandatory_keys:
            if message.find(key) is None:
                if key == 'transit_end':  # It's normal for this key to be missing, since messages have different tags.
                    self.logger.debug('Key transit_end missing in message.')
                else:
                    self.logger.warning(f'Mandatory key {key} not found in message!')
                return time.time()

        vehicle = etree.Element('vehicle')
        node = etree.Element('lwh')
        node.set('source', self.name)
        node.set('type', 'comark')

        # Transform data into SiWIM format, making changes where necessary and keeping the rest as is.
        for key, val in message.find('transit_end').attrib.items():
            self.logger.debug(f'Handling {key}: {val}.')
            if key not in self.known_attributes:
                self.logger.warning(f'Key {key} not handled')
                continue
            if self.known_attributes[key] is None:  # Skip keys that map to None.
                continue
            if key == 'time_iso_ms':
                # ISO timestamp -> YYYY-mm-dd-HH-MM-SS-mmm (microseconds truncated to milliseconds).
                val = datetime.datetime.strptime(val, '%Y-%m-%dT%H:%M:%S.%f').strftime('%Y-%m-%d-%H-%M-%S-%f')[:-3]
            elif key == 'time_iso':
                # Same output format; milliseconds are always 000.
                val = datetime.datetime.strptime(val, '%Y-%m-%dT%H:%M:%S').strftime('%Y-%m-%d-%H-%M-%S-%f')[:-3]
            elif key == 'speed':
                # Divide by 3.6 -- presumably km/h -> m/s; confirm against the device spec.
                val = str(float(val) / 3.6)
            elif key == 'class_id':
                # Add one element per cypher, translating the class id through that cypher.
                # NOTE(review): an unknown class id raises KeyError here, which run() does not catch.
                for cypher_key, cypher in self.cyphers.items():
                    key_el = etree.Element(cypher_key)
                    key_el.text = cypher[val]
                    node.append(key_el)
            el = etree.Element(self.known_attributes[key])
            el.text = val
            node.append(el)
        el = etree.Element('siwim_i_add_to_offset')
        el.text = '0'
        node.append(el)
        vehicle.append(node)

        # Each downstream module gets its own copy so it can modify it freely.
        for mod in self.downstream_modules_dict.values():
            mod.add_vehicle(copy.deepcopy(vehicle), self.name)
        # findtext() instead of find().text so a message without ts/lane cannot raise
        # AttributeError after the vehicle has already been sent downstream.
        self.logger.info(f'Sent {node.findtext("ts", "?")}_{node.findtext("lane", "?")} downstream.')
        return time.time()

    def run(self) -> None:
        """Serve one TCP connection and process incoming messages until stopped.

        Every exit goes through ``StopModule``, which is passed to ``log_stop_module``.
        The module stops when:

        - binding fails;
        - the end flag is set;
        - no message arrives for ``HEARTBEAT_TIMEOUT_S`` seconds;
        - a socket error occurs.
        """
        self.alive = True
        self.end = False

        try:
            self.am_i_a_server = True
            if not self.bind_tcp():
                raise StopModule('Connection failed.')
            self.s.listen(3)
            conn = None
            while True:
                self.throttle()
                if self.end:  # Checks for end flag while connection is being established.
                    raise StopModule('Shutdown flag detected.', conn=conn)
                try:
                    conn, _addr = self.accept_tcp()
                except TimeoutError as e:
                    self.logger.debug(e)
                    continue
                last_heartbeat = time.time()
                buffer = b''
                while True:
                    if self.end:
                        raise StopModule('Shutdown flag detected.', conn=conn)
                    try:
                        received_xml, buffer = self.acquire_data(conn, buffer, b'<sensor ', b'</sensor>', 'xml', save_info={SAVE_PATH: 'lwh'}, exit_on_timeout=True, response_handler=hash_djb2)
                    except NoData:
                        if time.time() - last_heartbeat > self.HEARTBEAT_TIMEOUT_S:
                            raise StopModule('No data received.', conn=conn)
                        continue
                    except socket.error as e:
                        # Pass the connection like every other exit path does, so it is handed
                        # over for cleanup instead of being leaked.
                        raise StopModule(f'SocketError: {e}', conn=conn)
                    last_heartbeat = self.parse_data(received_xml)
        except StopModule as e:
            self.log_stop_module(e)
