import datetime
import socket
import threading
import time
from collections import deque
from pathlib import Path

from cestel_helpers.i_conf_manager import SEC_DEFAULT, read_conf
from lxml import etree

import config
from abstract.module import Module
from consts import CONF_GLOBAL, OUT_BACKLOG, OUT_HEARTBEAT, OUT_MAX_THREADS, OUT_PORT, OUT_SAVE, ROOT_OUTPUT, SAVE_FOOTER, SAVE_HEADER, SAVE_PATH, SAVE_SUFFIX, TS_FORMAT_STRING
from exceptions import StopModule


class OutputModule(Module):
    """TCP server output module that streams vehicle XML records to connected clients.

    Each accepted connection is served by :meth:`clientthread_xml` on its own thread and
    owns a ``deque`` in ``vehicles_dict``. :meth:`add_vehicle` appends every incoming
    vehicle to all of those queues (and optionally saves it via ``write_data``). When a
    client has not been sent anything for ``heartbeat_interval`` seconds (0 disables
    this), a heartbeat message built by ``form_message`` is sent instead.
    """

    def __init__(self, args):
        # Defaults; presumably overridden by Module.__init__ when the optional
        # OUT_MAX_THREADS / OUT_BACKLOG keys are configured -- TODO confirm.
        self.max_clients = 50
        self.backlog = 5
        # Initialize the module, passing it a list of keys that must exist for it to work properly and a list of optional keys that are recognized by it.
        Module.__init__(self, args, mandatory_keys=(OUT_HEARTBEAT, OUT_PORT, OUT_SAVE), optional_keys=(OUT_BACKLOG, OUT_MAX_THREADS))
        self.type = 'siwim_server'
        self.response_type = 'output_heartbeat'
        # remove_blank_text drops insignificant whitespace when re-parsing outgoing XML.
        self.parser = etree.XMLParser(remove_blank_text=True)
        # Client id -> deque of vehicle elements still waiting to be sent to that client.
        self.vehicles_dict = {}
        # Guards vehicles_dict and the deques stored in it.
        self.clients_lock = threading.Lock()
        # Serializes write_data() calls made from concurrent add_vehicle() callers.
        self.writing_lock = threading.Lock()
        # Listening socket; presumably assigned by bind_tcp() in run().
        self.s = None

    def set_end(self):
        """Signal the server and all client threads to stop.

        Closing the listening socket makes the blocking accept() in run() fail with
        OSError, which run() turns into a StopModule.
        """
        self.end = True
        # The socket does not exist yet if run() has not bound it.
        if self.s is not None:
            self.s.close()

    def find_lowest_unused_id(self):
        """Return the smallest non-negative integer not currently used as a client id."""
        n = 0
        with self.clients_lock:
            while True:
                if n not in self.vehicles_dict:
                    return n
                n += 1

    def add_vehicle(self, vehicle, module_name="irrelevant"):
        """Queue a vehicle for every connected client and, if enabled, save it.

        :param vehicle: lxml element describing one vehicle; ignored when it has no children.
        :param module_name: not used by this module (presumably part of a shared
            output-module interface).
        """
        # len() of an lxml element is its number of children (getchildren() is deprecated).
        if len(vehicle) == 0:
            return
        if self.save_data:
            save_info = {
                SAVE_PATH: 'output',
                SAVE_HEADER: f'<?xml version="1.0" ?>\n<swd version="1"><site><name>{config.site_name}</name><vehicles>',
                SAVE_FOOTER: '</vehicles></site></swd>',
                SAVE_SUFFIX: 'xml'
            }
            with self.writing_lock:
                self.write_data(save_info, etree.tostring(vehicle, encoding='utf-8').decode())
        with self.clients_lock:
            for pending in self.vehicles_dict.values():
                pending.append(vehicle)

    def clientthread_xml(self, conn, idx):
        """Serve one client until shutdown or disconnect.

        Sends queued vehicles wrapped in an ``<swd>`` document, plus heartbeats while idle.
        On exit, the client's queue is removed from ``vehicles_dict`` and ``conn`` is closed.

        :param conn: accepted client socket.
        :param idx: client id, i.e. the key of this client's queue in ``vehicles_dict``.
        """
        try:
            last_send_at = datetime.datetime.now()
            while True:
                self.throttle()
                if self.end:
                    # Queue removal is done once, in the finally block below.
                    self.logger.debug(f'Ending client {idx}.')
                    try:
                        conn.shutdown(socket.SHUT_WR)
                    except OSError as e:
                        self.logger.info(f'Client {idx}: {e}')
                    conn.close()
                    self.logger.info(f'Client {idx} ended.')
                    return
                # Check and pop under the same lock acquisition.
                with self.clients_lock:
                    pending = self.vehicles_dict[idx]
                    vehicle = pending.popleft() if pending else None
                if vehicle is not None:
                    try:
                        complete_xml = '<swd version="1"><site><name>' + config.site_name + '</name><vehicles>' + etree.tostring(vehicle, encoding='utf-8').decode() + '</vehicles></site></swd>'
                        # Re-parse with self.parser to validate and strip blank text.
                        to_send = etree.fromstring(complete_xml, self.parser)
                    except Exception:
                        self.logger.exception('Failed to form a string from xml:')
                        time.sleep(0.5)
                        continue
                    ts_element = vehicle.find('wim/ts')
                    data = ts_element.text if ts_element is not None else 'unknown'
                    self.logger.debug(f'Sending {data} to {idx}')
                    # sendall(): send() may transmit only part of the buffer.
                    conn.sendall(etree.tostring(to_send, encoding='utf-8', pretty_print=False))
                    last_send_at = datetime.datetime.now()
                if self.heartbeat_interval != 0 and (datetime.datetime.now() - last_send_at).total_seconds() > self.heartbeat_interval:
                    conn.sendall(self.form_message({'ts': datetime.datetime.now().strftime(TS_FORMAT_STRING)[:-3]}))
                    last_send_at = datetime.datetime.now()
                time.sleep(0.5)
        except (BrokenPipeError, ConnectionAbortedError, ConnectionResetError) as e:
            self.logger.info(f'Client {idx} closed: {e}')
        except Exception:
            self.logger.exception(f'Client {idx} terminated:')
        finally:
            # Single place where the client's queue is dropped.
            with self.clients_lock:
                removed = self.vehicles_dict.pop(idx, None)
                if removed is None:
                    self.logger.warning(f'Client {idx} does not exist in {self.vehicles_dict}.')
                else:
                    self.logger.debug(f'Client {idx} removed from active thread list.')
            try:
                conn.close()  # Idempotent; the shutdown path may already have closed it.
            except OSError as e:
                self.logger.warning(f'Client {idx}; Cleanup failed: {e}')

    def run(self):
        """Publish upstream info, bind the listening socket and accept clients until stopped.

        Each accepted connection is handed to :meth:`clientthread_xml` on a new thread.
        Connections beyond ``max_clients`` are closed and the server keeps running.
        Ends when set_end() closes the socket or on a fatal error; ``alive`` is reset on exit.
        """
        self.alive = True
        self.end = False
        # Check which output module is defined to save to ext, if any.
        if read_conf(config.conf_dir / CONF_GLOBAL).get(SEC_DEFAULT, {}).get(ROOT_OUTPUT) == self.name:
            out_folder = 'ext'
        else:
            out_folder = Path('ext', 'output', self.type, self.name)
        self.set_upstream_info(self.name, "data_server_port", self.port)
        self.set_upstream_info(self.name, 'rsync_event_path', Path(config.sites_dir, config.site_name, out_folder))
        try:
            try:
                if not self.bind_tcp(timeout=None):
                    raise StopModule(f'Binding to port {self.port} failed!', conn=self.s)
                self.s.listen(self.backlog)
                self.logger.debug('Server started.')
            except StopModule:
                # Let the descriptive StopModule above through instead of the generic handler.
                raise
            except OSError as e:
                raise StopModule(msg=f'Failed to bind socket: {e}') from e
            except Exception:
                self.logger.exception('Fatal error:')
                self.alive = False
                self.logger.debug('Thread closed correctly.')
                raise StopModule
            while True:
                try:
                    self.logger.debug('Waiting for connections ...')
                    conn, addr = self.accept_tcp(timeout=None)
                    self.logger.debug(f'Connection from {addr} accepted.')
                except OSError as e:
                    # set_end() closes the listening socket, which makes accept fail here.
                    raise StopModule(msg=f'Shutdown flag detected: {e}') from e
                if self.end:
                    raise StopModule(conn=conn, msg='Shutdown flag detected.')
                idx = self.find_lowest_unused_id()
                self.logger.debug(f'Lowest unused thread ID: {idx}. Limit is {self.max_clients}')
                if idx >= self.max_clients:
                    # Refuse only this connection so existing clients keep being served.
                    self.logger.warning(f'Thread limit reached; rejecting connection from {addr}.')
                    conn.close()
                    continue
                with self.clients_lock:
                    self.vehicles_dict[idx] = deque()
                self.logger.info(f'Starting client thread to {addr} with ID {idx}')
                threading.Thread(target=self.clientthread_xml, args=(conn, idx)).start()
        except StopModule as e:
            self.log_stop_module(e)
        except Exception:
            self.logger.exception('Fatal error:')
        finally:
            if self.s is not None:
                self.s.close()
            self.alive = False
