import datetime
import os
import socket
import threading
import time
from collections import deque

from lxml import etree

from generic_module import Module


class OutputModule(Module):
    def __init__(self, args):
        """Initialise the module's parameters and per-client bookkeeping.

        ``args`` is handed unchanged to ``Module.__init__`` together with the
        dictionary of expected parameters below (presumably name -> default
        value; confirm against ``Module``).
        """
        expected_params = {
            "port": 8170,
            "backlog": 5,
            "root_path": None,
            "max_client_threads": 50,
            "filter_rules": [
                ("anpr", "not_sibling_of", "wim"),
                ("lwh", "not_sibling_of", "wim"),
            ],
            "heartbeat_interval": 0,
            "save_data": 1,
        }
        Module.__init__(self, args, expected_params)
        # Parser configured to discard blank text between elements.
        self.parser = etree.XMLParser(remove_blank_text=True)
        # Maps a client thread's index to the deque of vehicles queued for it.
        self.vehicles_dict = {}
        # Held by the methods below when they add, remove or pop queue entries.
        self.clients_lock = threading.Lock()
        # Listening socket; created in run().
        self.sock = None

    def form_message(self, info, typ):
        """Serialize ``info`` as an ``<i-message>`` element.

        The root element carries ``typ`` in its ``type`` attribute; every key
        of ``info`` becomes a child element whose text is the matching value.

        Returns:
            The result of ``etree.tostring`` on the root element, without
            pretty-printing.
        """
        root = etree.Element("i-message", type=typ)
        for name, text in info.items():
            # SubElement creates the child and attaches it to root in one step.
            etree.SubElement(root, name).text = text
        return etree.tostring(root, pretty_print=False)

    def set_end(self):
        self.end = True
        # this will initiate shutdown procedure; a bit hacky, but this is how we do things around here
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.connect(("127.0.0.1", self.port))
        try:
            s.close()
        except:
            pass

    def find_lowest_unused_id(self):
        n = 0
        with self.clients_lock:
            while True:
                if n not in self.vehicles_dict:
                    return n
                n += 1

    def save_to_file(self, str_vehicle):
        # these two ports are used by kalmia who coincidentally has the most restrictions regarding xml
        # so saving only "the best" should be the way for now --> for at least as long as we all play by the rules
        if self.port == 8170 or self.port == 50001:
            out_folder = "/ext/"
        else:
            out_folder = "/ext/" + self.name + "/"
        date = datetime.datetime.now().strftime('%Y-%m-%d')
        if not os.path.exists(self.sites_path + "/" + self.sname + out_folder):
            os.makedirs(self.sites_path + "/" + self.sname + out_folder)
        fd = None
        if os.path.exists(self.sites_path + "/" + self.sname + out_folder + date + ".xml"):
            fd = open(self.sites_path + "/" + self.sname + out_folder + date + ".xml", "a+")
        else:
            first_line = '<?xml version="1.0" ?>\n<swd version=' + '"' + "1" + '"><site><name>' + self.sname + '</name><vehicles>'
            fd = open(self.sites_path + "/" + self.sname + out_folder + date + ".xml", "a+")
            fd.write(first_line)
        if fd == None:
            return
        fd.seek(0, os.SEEK_END)
        pos = fd.tell() - 1
        # find the last vehicle closing tag
        try:
            while pos > 0 and fd.read(10) != "</vehicle>":
                pos -= 1
                fd.seek(pos, os.SEEK_SET)
                # delete everything from the end of last vehicle closing tag onwards
        except IOError as e:
            self.logger.critical('Failsafe meant for removing incorrectly formatted data failed: {}. Daily file is likely corrupted.'.format(str(e)))
        if pos > 0:
            fd.seek(pos + 10, os.SEEK_SET)
            fd.truncate()
        fd.write("\n" + str_vehicle)
        last_line = "</vehicles></site></swd>"
        fd.write("\n" + last_line)
        fd.close()

    def add_vehicle(self, vehicle, module_name="irrelevant"):
        """Filter a vehicle element, optionally save it, and queue it for every client.

        Each entry of ``self.filter_rules`` is a tuple whose first item is a
        path of one or two ``/``-separated levels; rules with deeper paths
        have no effect. Every element matched by the path is removed from its
        parent when

        * the rule has exactly one item, or
        * the rule is ``(path, "is_sibling_of", other)`` and the parent has an
          ``other`` child, or
        * the rule is ``(path, "not_sibling_of", other)`` and the parent has
          no ``other`` child.

        A vehicle left without children is dropped. Otherwise it is written
        to disk when ``self.save_data == 1`` and appended to every deque in
        ``self.vehicles_dict`` under ``clients_lock``.

        ``vehicle`` is modified in place. ``module_name`` is not used.
        """
        for rule in self.filter_rules:
            path_parts = rule[0].split("/")
            first_level = vehicle.findall(path_parts[0])
            # simple ugly way since we're only anticipating two levels
            if len(path_parts) == 1:
                candidates = iter(first_level)
            elif len(path_parts) == 2:
                # Generator: children of each parent are looked up only when
                # that parent is reached.
                candidates = (
                    child
                    for node in first_level
                    for child in node.findall(path_parts[1])
                )
            else:
                continue
            relation = rule[1] if len(rule) == 3 else None
            for target in candidates:
                parent = target.getparent()
                if relation == "is_sibling_of":
                    drop = parent.find(rule[2]) is not None
                elif relation == "not_sibling_of":
                    drop = parent.find(rule[2]) is None
                else:
                    drop = len(rule) == 1
                if drop:
                    parent.remove(target)
        if not vehicle.getchildren():
            return
        if self.save_data == 1:
            self.save_to_file(etree.tostring(vehicle).decode())
        with self.clients_lock:
            for queue in self.vehicles_dict.values():
                queue.append(vehicle)

    def is_doing_something(self):
        return len(self.vehicles_dict) > 0

    def clientthread_xml(self, conn, idx):
        try:
            last_send_at = datetime.datetime.now()
            while True:
                if self.end:
                    with self.clients_lock:
                        del self.vehicles_dict[idx]
                    self.logger.debug('Ending client thread: {}'.format(idx))
                    try:
                        conn.shutdown(socket.SHUT_WR)
                        conn.close()
                    except:  # ConnectionResetError
                        pass
                    self.logger.info('Client thread {} succesfully ended.'.format(idx))
                    return
                if len(self.vehicles_dict[idx]) > 0:
                    try:
                        with self.clients_lock:
                            vehicle = self.vehicles_dict[idx].popleft()
                        complete_xml = '<swd version="1"><site><name>' + self.sname + '</name><vehicles>' + etree.tostring(vehicle).decode() + '</vehicles></site></swd>'
                        to_send = etree.fromstring(complete_xml, self.parser)
                    except:
                        self.logger.exception('Failed form a sendable string from xml:')
                        """
                        try:
                          self.clients_lock.release()
                        except Exception:
                          pass
                        """
                        time.sleep(0.5)
                        continue
                    # print id
                    self.logger.debug('Sending data on thread {}'.format(idx))
                    conn.send(etree.tostring(to_send, pretty_print=False))
                    last_send_at = datetime.datetime.now()
                if self.heartbeat_interval != 0 and (datetime.datetime.now() - last_send_at).total_seconds() > self.heartbeat_interval:
                    conn.send(self.form_message({"ts": datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")[:-3]}, "output_heartbeat"))
                    last_send_at = datetime.datetime.now()
                time.sleep(0.5)
        # except BrokenPipeError as e:  # Python38 only
        #     self.logger.warning('Client thread with id {} closed due to: {}'.format(idx, e))
        except:
            self.logger.exception('Clientthread detrimental:')
            self.logger.warning('Closing client thread: ' + str(idx))
            """
            try:
              self.clients_lock.release()
            except Exception:
              pass
            """
            # print "Ending thread: " + str(len(self.vehicles_dict)) + "\n"
            with self.clients_lock:
                del self.vehicles_dict[idx]
            conn.close()
            return

    def run(self):
        """Listen on ``self.port`` and start one client thread per connection.

        Steps:

        1. Publish ``data_server_port`` and ``rsync_event_path`` to every
           downstream module via ``set_upstream_info``.
        2. Create the listening socket. On failure the error is logged, the
           socket is closed, ``self.alive`` is cleared and the method returns.
        3. Accept connections. Each one gets the lowest free index, a fresh
           deque in ``self.vehicles_dict`` and a ``clientthread_xml`` thread.
        4. When a connection arrives while ``self.end`` is set, wait until
           all client queues are gone, close everything and return.

        Reaching ``max_client_threads`` sets ``self.end``: the running client
        threads stop on their next check and this loop shuts down on the
        next incoming connection.
        """
        self.alive = True
        self.end = False
        if self.port == 8170 or self.port == 50001:
            out_folder = "/ext/"
        else:
            out_folder = "/ext/" + self.name + "/"
        event_path = self.sites_path + "/" + self.sname + out_folder
        for mod in self.downstream_modules_dict.values():
            mod.set_upstream_info(self.get_name(), 'data_server_port', self.port)
            mod.set_upstream_info(self.get_name(), 'rsync_event_path', event_path)
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # SO_REUSEADDR must be set before bind() to influence the bind.
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.sock.bind(('', self.port))
            self.sock.listen(self.backlog)
        except Exception:
            self.logger.exception('Fatal error:')
            # Do not leak the half-initialised socket.
            if self.sock is not None:
                try:
                    self.sock.close()
                except OSError:
                    pass
            self.alive = False
            self.logger.debug('Thread closed correctly.')
            return
        try:
            while True:
                conn, addr = self.sock.accept()
                # when there are no child threads, end the "main" thread
                if self.end:
                    while True:
                        with self.clients_lock:
                            if len(self.vehicles_dict) == 0:
                                try:
                                    conn.shutdown(socket.SHUT_WR)
                                except Exception:
                                    self.logger.debug('Connection close failed:', exc_info=True)
                                finally:
                                    # close even when shutdown() failed
                                    try:
                                        conn.close()
                                    except OSError:
                                        self.logger.debug('Connection close failed:', exc_info=True)
                                self.sock.close()
                                self.alive = False
                                self.logger.debug('Thread closed correctly.')
                                return
                        self.zzzzz(0.5)
                idx = self.find_lowest_unused_id()
                if idx < self.max_client_threads:
                    with self.clients_lock:
                        self.vehicles_dict[idx] = deque()
                    self.logger.info('Starting client thread to {} with ID {}'.format(addr, idx))
                    threading.Thread(target=self.clientthread_xml, args=(conn, idx)).start()
                else:
                    self.end = True
                    self.logger.info('Thread limit reached.')
                    # Nobody will serve this connection; release it right away.
                    try:
                        conn.close()
                    except OSError:
                        pass
        except Exception:
            self.logger.exception('Fatal error:')
            self.sock.close()
        self.alive = False
