import json
import os
import re
import socket
import time
import warnings
from abc import ABC, abstractmethod
from datetime import datetime
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import tomlkit
from cestel_helpers.aliases import PathLike
from cestel_helpers.log import init_logger
from lxml import etree

import config
from consts import GLOBAL_NAME, LOGGER_MAIN, LOG_DIR, LOG_FMT_DUMP, LOG_LEVEL, LOG_LEVEL_CON, LOG_LEVEL_DUMP, LOG_ROTATING, RCV_CYPHER, SAVE_CLEAR_WHITESPACE, SAVE_FOOTER, SAVE_HEADER, SAVE_PATH, \
    SAVE_SEPARATOR, SAVE_SUFFIX
from exceptions import NoData, StopModule


class Module(ABC):  # Module is an abstract class that should never be instantiated directly.
    def __init__(self, module_data: Dict[str, Any], mandatory_keys: Tuple[str, ...] = tuple(), optional_keys: Tuple[str, ...] = tuple()):
        """ Each module is initialized from a dictionary of values read from conf.
        :param module_data: Dictionary containing everything defined in configuration file of the module.
        :param mandatory_keys: A tuple of keys that must exist in the toml file for the module to work correctly.
        :param optional_keys: A tuple of keys that are recognized by the module, but don't need to be defined in the toml file.
        """
        self.offset = tomlkit.inline_table()
        # General information.
        self.end = False
        self.alive = True
        self.restart_on_fail = True
        self.downstream_modules_dict = {}
        self.sites_path = str(config.sites_dir)
        self.data_dirs: Dict[str, Path] = {}
        self.cyphers = {}
        self.type: Optional[str] = None
        self.response_type: Optional[str] = None
        # Logging.
        log_level = module_data[LOG_LEVEL] if LOG_LEVEL in module_data else 20
        log_level_console = module_data[LOG_LEVEL_CON] if LOG_LEVEL_CON in module_data else 40
        log_to_console = bool(log_level_console)
        log_rotating = module_data[LOG_ROTATING] if LOG_ROTATING in module_data else False
        # Connections.
        self.s: Optional[socket.socket] = None
        self.am_i_a_server: bool = False
        # Contains mapping of mandatory keys to saved name.
        self.mandatory_keys = {}
        # various bits of upstream information
        self.upstream_info = {}
        # force this here so logger will initialize
        self.name = module_data[GLOBAL_NAME]
        # each module gets it's own logger
        self.logger = init_logger(self.name, folder=Path(self.sites_path, config.site_name, LOG_DIR), level=log_level, console_level=log_level_console, to_console=log_to_console, add_line_number=True if log_level_console == 10 else False,
                                  no_date=log_rotating)
        self.logger.debug(f'Logging to file with level {log_level} and to console with level {log_level_console}. Using rotating logger: {log_rotating}', stacklevel=2)
        self.logger.debug(f'Module data: {module_data}', stacklevel=2)
        self.logger_dump = init_logger(f'{self.name}_dump', folder=config.log_dir, level=module_data.get(LOG_LEVEL_DUMP, 50), fmt=LOG_FMT_DUMP, no_date=True)
        # Store all keys from configuration file as values of either type int, float, time or string.
        unknown_keys = []
        for arg, val in module_data.items():
            if arg == RCV_CYPHER:
                try:
                    self.init_cyphers(val)
                except StopModule as e:
                    self.logger.critical(e)
                    self.restart_on_fail = False  # This signals the starter to not even run the module.
                continue
            if arg not in mandatory_keys + optional_keys + (GLOBAL_NAME, LOG_LEVEL, LOG_LEVEL_CON, LOG_LEVEL_DUMP, LOG_ROTATING, 'label'):
                unknown_keys.append(arg)
            setattr(self, arg, val)
        if len(unknown_keys) > 0:
            self.logger.warning(f'Unrecognized keys: {",".join(unknown_keys)}. They should be removed.', stacklevel=2)
        # Check if all expected parameters actually exist in dictionary and print an error if any are missing.
        if mandatory_keys:
            missing = [k for k in mandatory_keys if k not in module_data]
            if len(missing) > 0:
                self.logger.critical(f'Missing keys: {",".join(missing)}. Module may not work correctly.', stacklevel=2)

    def init_cyphers(self, cypher_file):
        # Parse cypher located in i_conf directory.
        cypher_path = Path(config.conf_dir, cypher_file)
        try:
            with open(cypher_path, encoding='utf-8') as fd:
                lines = fd.readlines()
        except FileNotFoundError:
            raise StopModule(f'Could not find cypher file at: {cypher_path}.')
        header_parts = lines[0].split(";")
        for i in range(1, len(header_parts)):
            self.cyphers[header_parts[i].strip()] = {}
            for j in range(1, len(lines)):
                line_parts = lines[j].split(";")
                self.cyphers[header_parts[i].strip()][line_parts[0].strip()] = line_parts[i].strip()

    def throttle(self, seconds: float = 0.01):
        time.sleep(seconds)

    def accept_tcp(self, timeout: Optional[int] = 2) -> Tuple[socket.socket, str]:
        """ This method should be called after bind_tcp and socket.listen to establish connection. """
        try:
            soc, addr = self.s.accept()
            soc.settimeout(timeout)
            self.logger.info(f'Accepted connection from {addr[0]}:{self.port}.', stacklevel=2)
            return soc, addr
        except socket.timeout:
            raise TimeoutError('Connection timed out.')

    def connect_tcp(self, timeout: Optional[int] = 2) -> bool:
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.s.connect((self.host, self.port))
            self.s.settimeout(timeout)
            return True
        except socket.error as e:
            self.logger.warning(f'Connection to {self.host}:{self.port} refused ({e})', stacklevel=2)
            return False
        except:
            self.logger.exception('Failed to connect:')
            return False

    def bind_udp(self) -> bool:
        self.s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            self.s.bind(("", self.port))
            # self.s.settimeout(2)
            return True
        except:
            self.logger.exception('Failed to bind:')
            return False

    def bind_tcp(self, timeout: Optional[int] = 2) -> bool:
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            self.s.bind(('', self.port))
            self.s.settimeout(timeout)  # Without a timeout, a thread trying to establish connection to a server that doesn't exist will hang.
            self.logger.debug(f'Successfully bound to port {self.port}.', stacklevel=2)
            return True
        except (ConnectionAbortedError, OSError) as e:
            self.logger.error(f'Failed to bind socket to port {self.port}: {e}.', stacklevel=2)
            return False
        except:
            self.logger.exception('Failed to bind:', stacklevel=2)
            return False

    def acquire_data(self, soc: socket.socket, recv_data: bytes, start_bytes: bytes, end_bytes: bytes, parsing_method: str, block_buffer: int = 2048, max_buffer: int = 1048576, save_info: Optional[Dict[str, str]] = None,
                     exit_on_timeout: bool = False, offset_end: int = 0, encoding: str = 'utf-8', response_handler: Optional[Callable[[bytes], bytes]] = None) -> Tuple[Union[Dict[str, Any], etree.Element], bytes]:
        """ Function acquires a data block from the socket and parses it according to parsing_method.
        :param soc: Socket on which we're receiving data.
        :param recv_data: Data that has been received previously.
        :param start_bytes: Text used to identify start of message.
        :param end_bytes: Text used to identify end of message.
        :param parsing_method: Currently supported methods are json and xml.
        :param block_buffer: Block of data in bytes that should be read at a time.
        :param max_buffer: Maximum amount of data in bytes that is allowed in the buffer before it is discarded.
        :param save_info: Dictionary containing saving related parameters. Must contain path, but can also contain header and footer where applicable.
        :param exit_on_timeout: Decides what should be done when timeout is reached. Set to True, if data is acquired on demand and reaching the timeout indicates no response from server.
        :param offset_end: Set if end of message is not the same as end of valid data.
        :param encoding: Encoding that should be used to decode received data.
        :param response_handler: If a response is required, define the function that should be used to generate it. The function takes received message as argument and must return bytes.
        :return: A tuple containing resulting data based on parsing_method and any leftover data.
        """
        if save_info and SAVE_PATH in save_info:
            error_path = Path(config.sites_dir, config.site_name, 'ext', save_info[SAVE_PATH], self.type, self.name)
        else:  # If the module has nothing to save, create a directory based on type and name for errors.
            error_path = Path(config.sites_dir, config.site_name, 'ext', self.type, self.name)
            error_path.mkdir(parents=True, exist_ok=True)
        # First obtain the data block we're working with.
        start_pos = -1
        end_pos = -1
        while (start_pos == -1 or end_pos == -1) and len(recv_data) < max_buffer:  # Keep iterating until both start_pos and end_pos are found or you reach max_buffer.
            if self.end:
                raise StopModule('Shutdown signal received.', conn=soc)
            try:
                new_data = soc.recv(block_buffer)
                if new_data == b'':  # Connection was closed by client.
                    raise NoData('The other end signalled that there is no more data to read.')
                recv_data += new_data
            except ConnectionResetError:
                raise NoData('Connection reset by the other side.')
            except (socket.timeout, TimeoutError):  # FIXME socket.timeout is deprecated in Python 3.10. It becomes an alias for TimeoutError.
                self.logger_dump.debug(f'Data received before timeout: {recv_data}.', stacklevel=2)
                if exit_on_timeout:
                    raise NoData(f'Timeout ({soc.gettimeout()}) has been reached while waiting for message!')
                else:
                    self.logger.debug('Connection timed out.', stacklevel=2)
                    time.sleep(1)
                    continue
            start_pos = recv_data.find(start_bytes)
            end_pos = recv_data.find(end_bytes, start_pos)
            self.logger.debug(f'Start: {start_pos}, End: {end_pos}, Length: {len(recv_data)}, Max: {max_buffer}.', stacklevel=2)
        if start_pos == -1 or end_pos == -1:
            self.logger_dump.error(f'Couldn\'t find block with start {start_bytes} and end {end_bytes} in {recv_data}.', stacklevel=2)
            raise NoData(f'Failed to find data block! Maximum buffer is set to {max_buffer} bytes. If messages for this module can be longer, open an issue and include this message.')

        relevant_data = recv_data[start_pos:end_pos + len(end_bytes) + offset_end]
        # The purpose of this is to catch all errors originating from incorrect encoding. If encoding is utf-8 this does nothing as long as data is valid. If encoding is not utf-8, data is encoded as utf-8 for further processing.
        try:
            relevant_data = relevant_data.decode(encoding=encoding).encode()
        except UnicodeDecodeError:
            if SAVE_PATH in save_info:
                self.generate_data_dir(save_info[SAVE_PATH])
                with open(Path(error_path, f'error_{encoding}.txt'), 'ab') as fd:
                    fd.write(relevant_data + b'\n')
            raise NoData(f'Failed to decode data. Raw data appended to failed.txt in module save directory. Used encoding is {encoding}.')

        # Save received data to disk, if option is set.
        if save_info is not None and self.type != 'siwim' and self.type != 'ffgroup':  # We save SiWIM after processing for consistency. ffgroup is saved after because photo data should be removed first.
            self.write_data(save_info, relevant_data.decode())

        # Now parse said data based on parsing_method.
        if parsing_method == 'json':  # TODO Make this into a function.
            try:
                result = json.loads(relevant_data)
            except json.decoder.JSONDecodeError as e:
                with open(Path(error_path, 'error_utf-8.txt'), 'ab') as fd:
                    fd.write(relevant_data + b'\n')
                raise NoData(f'Parsing failed: {e} in data starting with "{relevant_data[:30]}" and ending with "{relevant_data[-15:]}".')
            # Save data if it's ffgroup. FIXME There should be a way to define where to save module data.
            if save_info and self.type == 'ffgroup':
                try:
                    result.pop('ImageArray')  # Remove the image array key from the dictionary to save disk space.
                except:
                    pass
                self.write_data(save_info, json.dumps(result))
            # Validate keys.
            for key in self.mandatory_keys.keys():
                if key not in result:
                    self.logger.error(f'Key {key} (mandatory) is missing from message!', stacklevel=2)
                    self.logger_dump.error(f'Missing key: {key} in {result}.', stacklevel=2)
        elif parsing_method == 'xml':
            try:
                result = etree.fromstring(relevant_data)
            except etree.XMLSyntaxError as e:
                self.logger_dump.error(f'Couldn\'t parse: {relevant_data}.', stacklevel=2)
                raise NoData(f'Parsing failed: {e}.')
            # Save data if it's SiWIM
            if save_info and self.type == 'siwim':
                self.write_data(save_info, [etree.tostring(vehicle, encoding='utf-8').decode() for vehicle in result.xpath('site/vehicles/vehicle')])
            # Validate keys.
            # TODO
        else:
            raise NotImplementedError(f'Parsing method {parsing_method} is not supported.')

        # Handle sending the response, if the device expects it. Response is binary.
        if response_handler:
            response = response_handler(relevant_data)
            soc.sendall(response)

        return result, recv_data[start_pos + len(relevant_data):]

    def acquire_json_with_size(self, soc: socket.socket, recv_data: bytes, start_bytes: bytes, msg_length_key: bytes, block_buffer: int = 2048, max_buffer: int = 1048576, save_info: Optional[Dict[str, str]] = None, exit_on_timeout: bool = False,
                               encoding='utf-8') -> Tuple[Union[Dict[str, Any], etree.Element], bytes]:
        """ Function acquires a data block from the socket and parses it according to parsing_method.
                :param soc: Socket on which we're receiving data.
                :param recv_data: Data that has been received previously.
                :param start_bytes: Text used to identify start of message.
                :param msg_length_key: Name of the key that contains message length and needs to be found to determine how much data to acquire.
                :param block_buffer: Block of data in bytes that should be read at a time.
                :param max_buffer: Maximum amount of data in bytes that is allowed in the buffer before it is discarded.
                :param save_info: Dictionary containing saving related parameters. Must contain path, but can also contain header and footer where applicable.
                :param exit_on_timeout: Decides what should be done when timeout is reached. Set to True, if data is acquired on demand and reaching the timeout indicates no response from server.
                :param encoding: Encoding that should be used to decode received data.
                :return: A tuple containing resulting data based on parsing_method and any leftover data.
                """
        # First obtain the data block we're working with.
        start_pos = -1
        end_pos = -1
        pattern = b'"' + msg_length_key + b'":"(\d+)"'
        while (start_pos == -1 or end_pos == -1 or end_pos > start_pos + len(recv_data)) and len(recv_data) < max_buffer:  # Keep iterating until both start_pos and end_pos are found or you reach max_buffer.
            if self.end:
                raise StopModule('Shutdown signal received.', conn=soc)
            try:
                recv_data += soc.recv(block_buffer)
            except socket.timeout:
                self.logger_dump.debug(f'Data received before timeout: {recv_data}.', stacklevel=2)
                if exit_on_timeout:
                    raise NoData(f'Timeout has been reached while waiting for response!')
                else:
                    self.logger.debug('Connection timed out.', stacklevel=2)
                    continue
                # raise NoData(f'Timeout was reached before a complete message was received. Make sure your messages start with "{start_bytes.decode()}" and end with "{end_bytes.decode()}".')
            start_pos = recv_data.find(start_bytes)
            match = re.search(pattern, recv_data[start_pos:])
            if match:  # If match is found, set end_pos to the correct value.
                end_pos = start_pos + int(match.group(1)) - 1
        if start_pos == -1 or end_pos == -1:
            self.logger_dump.error(f'Couldn\'t find block with start {start_bytes} and length in {msg_length_key} in {recv_data}.', stacklevel=2)
            raise NoData('Failed to find data block!')

        # Perform a sanity check just in case something is off with the logic.
        self.logger.debug(f'Start: {start_pos}, End: {end_pos}, Length: {len(recv_data)}')
        if start_pos + end_pos > len(recv_data):
            raise NoData('Something is wrong. Incomplete data was marked as complete.')
        # Now parse said data based on parsing_method.
        relevant_data = recv_data[start_pos:end_pos]
        if SAVE_PATH in save_info:
            self.write_data(save_info, relevant_data.decode(encoding=encoding))

        try:
            result = json.loads(relevant_data)
        except json.decoder.JSONDecodeError as e:
            self.logger_dump.error(f'Couldn\'t parse: {relevant_data}.', stacklevel=2)
            raise NoData(f'Parsing failed (data starts with {recv_data[:10]} and ends with {recv_data[-10:]} : {e}.')
        # Validate keys.
        for key in self.mandatory_keys.keys():
            if key not in result:
                self.logger.error(f'Key {key} (mandatory) is missing from message!', stacklevel=2)
                self.logger_dump.error(f'Missing key: {key} in {result}.', stacklevel=2)

        return result, recv_data[start_pos + len(relevant_data):]

    def generate_data_dir(self, folder: PathLike, ts: Optional[str] = None, update: bool = True) -> Path:
        """ Generated directory structure required by module. Due to the fact that site name can change at any time, this should be called whenever we want to write.
        :param folder: Folder in ext to which data should be stored. Also serves as key when identifying the path.
        :param ts: If this is passed, sub-folders with partial timestamps are generated inside the folder.
        :param update: Set to False, if you just want to generate the folder and not update data_dirs.
        :returns Path: Path of the created directory.
        """
        if self.type is None:
            raise NotImplementedError(f'Generating data directories is not enabled for modules of type {type(self).__name__}.')
        full_path = Path(self.sites_path, config.site_name, 'ext', folder, self.type, self.name)
        if ts:
            ts_hour = re.findall(r'\d{4}-\d{2}-\d{2}-\d{2}', ts)[0]  # Matches the part of timestamp string up to the hour.
            ts_day = ts_hour.rsplit('-', 1)[0]
            full_path = Path(full_path, f'{ts_day}-hh-mm-ss', f'{ts_hour}-mm-ss')
        if self.data_dirs.get(folder) != full_path:
            full_path.mkdir(parents=True, exist_ok=True)  # Create the full folder structure.
            if update:
                self.data_dirs[folder] = full_path
                self.logger.debug(f'Data directory for {folder} set to {self.data_dirs[folder]}.', stacklevel=2)
        return full_path

    def write_data(self, save_info: Dict[str, str], data: Union[str, Iterable[str]]) -> None:
        """ Method used for saving data to disk.
        :param save_info: Dictionary containing saving related parameters. Must contain path, but can also contain header and footer where applicable.
        :param data: Line(s) that should be written. If suffix is tsv, all data is written as a single line prepended with timestamp and separated with tabs.
        """
        if SAVE_PATH not in save_info:
            raise NotImplementedError('Data not written because SAVE_PATH is missing from save_info.')
        if isinstance(data, bytes):
            raise NotImplementedError('Saving is not implemented for bytes type objects. Decode to string first.')
        date = datetime.now().strftime('%Y-%m-%d')
        suffix = save_info.get(SAVE_SUFFIX, 'lst')
        if config.root_output == self.name:  # Handle saving to ext folder.
            file = Path(config.sites_dir, config.site_name, 'ext', f'{date}.{suffix}')
        else:  # Otherwise generate the directory and save to default path.
            self.generate_data_dir(save_info[SAVE_PATH])
            file = Path(self.data_dirs[save_info[SAVE_PATH]], f'{date}.{suffix}')
        try:
            with open(file, 'a+' if SAVE_FOOTER in save_info else 'a', encoding='utf-8') as fd:
                separator = save_info.get(SAVE_SEPARATOR, '\n')
                if file.stat().st_size == 0 and SAVE_HEADER in save_info:  # If header is defined, every file should start with it.
                    fd.write(save_info[SAVE_HEADER] + separator)
                if file.stat().st_size != 0 and SAVE_FOOTER in save_info:
                    fd.seek(0, os.SEEK_END)
                    fd.truncate(fd.tell() - len(save_info[SAVE_FOOTER]))  # Remove the footer.
                # There's two ways to format data. If data is an Iterable, it should be prepended with timestamp and joined with tabs. Otherwise, it's written verbatim.
                if suffix == 'tsv':  # These types of files should have their data joined by tabs and prepended with timestamp.
                    fd.write(datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + '\t' + (data if isinstance(data, str) else '\t'.join(data)) + separator)
                elif isinstance(data, str):  # If it's a string, write it directly.
                    if save_info.get(SAVE_CLEAR_WHITESPACE):
                        data = re.sub(r'\s+', '', data)
                    fd.write(data + separator)
                elif isinstance(data, Iterable):  # If it's a list, write each entry separated by separator.
                    if save_info.get(SAVE_CLEAR_WHITESPACE):
                        data = [re.sub(r'\s+', '', d) for d in data]
                    fd.write(separator.join(data))
                else:
                    raise NotImplementedError(f'Writing not implemented for type {type(data)}.')
                if SAVE_FOOTER in save_info:  # If footer exists, append it to the end of file.
                    fd.write(save_info[SAVE_FOOTER])
        except OSError as e:
            self.logger.error(f'Could not write data: {e}', stacklevel=2)

    def set_sites_path(self, sites_path):
        warnings.warn('Function set_sites_path is deprecated. Do not use!', stacklevel=2)
        self.sites_path = sites_path

    def set_end(self):
        self.end = True

    # Methods that "return True" are somewhat special and should be reimplemented in certain modules.
    def add_downstream_module(self, module):
        self.downstream_modules_dict[module.get_name()] = module

    def set_upstream_info(self, module_name, info_type, info_value):
        with config.lock:
            if isinstance(info_value, Path):  # Convert Path objects to posix path for consistency.
                info_value = info_value.as_posix().replace('D:/siwim_mkiii', 'siwim') + '/'  # Replacing does nothing if the string is not found. Rsync requires ending dash.
            config.status_dict[(module_name, info_type)] = info_value

    def get_name(self):
        return self.name

    # TODO: generalize this into "add_data" or something; this is a method for forwarding data to downstream modules
    def add_vehicle(self, sub_vehicle, module_name):
        self.logger.critical(f'Method add_vehicle not implemented for this module. Make sure that it is not downstream of any other module.', stacklevel=2)
        time.sleep(1)

    def is_alive(self):
        return self.alive

    def zzzzz(self, seconds):
        cnt = 0
        while cnt < seconds * 2:
            if self.end:
                self.alive = False
                return
            try:
                time.sleep(0.5)
            except:
                continue
            cnt += 1
        return

    def form_message(self, info: Dict[str, str]) -> str:
        """ Provides consistent formatting for outgoing messages.
        :param info: A dictionary containing data that should be converted to xml.
        :return unformatted xml.
        """
        if not self.response_type:
            raise NotImplementedError(f'This module does not have a defined message format.')
        top_element = etree.Element('i-message', type=self.response_type)
        for key, val in info.items():
            tag = etree.Element(key)
            tag.text = val
            top_element.append(tag)
        return etree.tostring(top_element, encoding='utf-8', pretty_print=False)

    @abstractmethod
    def run(self):
        pass

    def log_stop_module(self, exception: Exception) -> None:
        """ This is a convenience method that handles logging done when a module stops due to a handled cause.
        :param exception: Exception object that should be logged.
        """
        self.logger.warning(f'Stopped: {exception}', stacklevel=2)
        if self.restart_on_fail:
            getLogger(LOGGER_MAIN).debug(f'Stopped: {exception}', stacklevel=2)
        else:
            getLogger(LOGGER_MAIN).error(f'Could not be started: {exception}', stacklevel=2)
        time.sleep(1)
