# -*- coding: utf-8 -*-
"""
Created on Wed Jan 23 12:23:53 2019

@author: Domen
"""
import base64
import datetime
import os
import platform
import socket
import threading
import time
from datetime import date
from pathlib import Path
from shutil import copy2
from zipfile import ZipFile

import cestel_helpers.i_conf_manager as i_conf
import tomlkit.exceptions
from cestel_helpers.exceptions import ConfError
from cestel_helpers.i_conf_manager import get_id
from cestel_helpers.log import init_logger
from cestel_helpers.version import get_version
from lxml import etree

import config
from abstract.module import Module
from consts import COMM_PORT, LOG_NO_DATE
from exceptions import RestartI, StopModule


class CommunicationModule(Module):
    """TCP server through which remote clients query and control SiWIM-I.

    ``run`` listens on the configured port and serves every accepted connection
    with ``clientthread`` in a separate thread. A client sends one XML document
    of the form ``<i-message type="...">...</i-message>``; the supported types
    are ``i_alive``, ``conf_request``, ``conf``, ``i_reset``, ``download_log``,
    ``download_swmstat`` and ``reboot``.

    NOTE(review): no authentication is performed here, so anyone who can reach
    the port can upload configurations, restart SiWIM-I or reboot the host.
    Confirm that access is restricted at the network level.
    """

    # Tags delimiting a request on the wire; a request is complete once both have been received.
    _MSG_START = b'<i-message '
    _MSG_END = b'</i-message>'

    def __init__(self, args):
        self.buffer = 2048  # Maximum number of bytes read by a single recv() call.
        self.backlog = 5  # Backlog passed to listen().
        args[LOG_NO_DATE] = True
        Module.__init__(self, args, mandatory_keys=(COMM_PORT,))
        self.response_type = 'comm_response'
        self.parser = etree.XMLParser(remove_blank_text=True)
        self.num_of_clients = 0  # Number of running client threads; guarded by clients_lock.
        self.clients_lock = threading.Lock()
        self.reset_i = False  # Set once a client requested a reset or uploaded a new configuration.

        # This logger is meant exclusively for logging interactions with system (such as restarts of SiWIM-I, uploading of new configurations and rebooting).
        self.acc_logger = init_logger('access', folder=config.log_dir, level=10, console_level=20, to_console=True, no_date=True)

    def are_we_resetting_1(self):
        """Raise RestartI if a client has requested a restart."""
        if self.reset_i:
            raise RestartI

    def set_end(self):
        """Stop accepting clients: flag the module as ending and close the listening socket."""
        self.end = True
        try:
            # On Linux, close() alone may not wake a thread blocked in accept(); shutdown() does.
            # Shutting down a listening socket can fail on some platforms, which is harmless here.
            self.s.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        self.s.close()

    def clientthread(self, conn, addr):
        """Serve a single client connection and close it afterwards.

        Reads one ``i-message`` request from *conn*, dispatches it on its
        ``type`` attribute and sends the response (for request types that
        have one). The connection is closed on every exit path.

        :param conn: Accepted client socket.
        :param addr: Client address, used for logging only.
        """
        if self.end or self.reset_i:
            conn.close()  # Shutting down: don't leak the already accepted socket.
            return
        with self.clients_lock:
            self.num_of_clients += 1
        try:
            data = self._recv_i_message(conn)
            if data is None:
                self.logger.debug(f'{addr} closed the connection before sending a complete message.')
                self._reply_to_client(conn, {'status': 'ERR'})
                return
            try:
                # Parse the raw bytes so lxml honours the document's own encoding declaration.
                original_xml = etree.fromstring(data, self.parser)
            except Exception:
                self.logger.info(f'Line: {data.decode("utf-8", errors="replace")}.')
                self._reply_to_client(conn, {'status': 'ERR'})
                return
            msg_type = original_xml.get('type')  # None (not KeyError) when the attribute is missing.
            if msg_type == 'i_alive':
                self.acc_logger.debug(f'{addr} requested alive check.')
                self._reply_to_client(conn, {'status': 'OK', 'version': get_version(), 'site': config.site_name, 'device_id': get_id()})
            elif msg_type == 'conf_request':
                self._handle_conf_request(conn, addr)
            elif msg_type == 'conf':
                self._handle_conf_upload(original_xml, addr)
            elif msg_type == 'i_reset':
                self.reset_i = True
                self.acc_logger.info(f'{addr} requested reset.')
                self.set_end()
            elif msg_type == 'download_log':
                self._handle_download_log(conn, addr, original_xml)
            elif msg_type == 'download_swmstat':
                self._handle_download_swmstat(conn, addr)
            elif msg_type == 'reboot':
                self._handle_reboot(conn, addr)
            else:
                self.acc_logger.warning(f'{addr} sent unsupported message type {msg_type!r}.')
                self._reply_to_client(conn, {'status': 'ERR', 'message': f'Unsupported message type {msg_type!r}.'})
        except ConnectionError as e:  # Reset, aborted or broken pipe: the client went away.
            self.logger.debug(f'Connection closed: {e}')
        except Exception:
            self.logger.exception('Fatal error:')
        finally:  # Executes regardless of how the block exits.
            conn.close()
            with self.clients_lock:
                self.num_of_clients -= 1

    def _reply_to_client(self, conn, content):
        """Wrap *content* with form_message() and send all of it on *conn*."""
        # sendall(), unlike send(), does not return until every byte has been sent.
        conn.sendall(self.form_message(content))

    def _recv_i_message(self, conn):
        """Read from *conn* until both the opening and the closing ``i-message`` tag have arrived.

        Bytes are accumulated undecoded, so multi-byte UTF-8 characters split
        across recv() calls are not corrupted.

        :return: The received bytes, or None if the peer closed the connection
            before a complete message arrived.
        """
        data = bytearray()
        while self._MSG_START not in data or self._MSG_END not in data:
            chunk = conn.recv(self.buffer)
            if not chunk:  # EOF; without this check the loop would spin forever.
                return None
            data.extend(chunk)
        return bytes(data)

    def _handle_conf_request(self, conn, addr):
        """Send the version and every configuration file read from config.conf_dir to the client."""
        if not config.conf_dir.exists():
            self.acc_logger.error(f'Configuration not found at {config.conf_dir}.')
            self._reply_to_client(conn, {'status': 'ERR'})
            return
        self.acc_logger.info(f'{addr} requested configuration.')
        msg_node = etree.fromstring(self.form_message({'status': 'OK'}))
        etree.SubElement(msg_node, 'version').text = get_version()
        confs_xml = etree.SubElement(msg_node, 'conf')
        for conf_name, toml_data in i_conf.read_confs(config.conf_dir, downstream=True, general=True).items():
            etree.SubElement(confs_xml, conf_name).text = i_conf.serialize_conf(toml_data, sort=True)
        module_str = etree.tostring(msg_node, encoding='utf-8')
        self.logger_dump.debug(f'Sent: {module_str}')
        conn.sendall(module_str)
        self.acc_logger.debug('Configuration sent.')

    def _handle_conf_upload(self, original_xml, addr):
        """Back up the current configuration, apply the uploaded one and end the module.

        The request may contain a ``<delete>`` element and/or a ``<confs>``
        element. No response is sent to the client. After applying the
        changes, reset_i is set and set_end() is called.
        """
        curr_ts = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
        self.acc_logger.info(f'{addr} sent new configuration.')
        self.logger_dump.debug(f'Configuration {curr_ts}: {etree.tostring(original_xml)}')
        delete_tag = original_xml.find('delete')
        confs_xml = original_xml.find('confs')
        if confs_xml is None and delete_tag is None:
            self.acc_logger.warning(f'{addr} uploaded empty configuration.')
            return
        # Backup the old conf files.
        backup_path = Path(config.conf_backup_dir, f'siwim_i_{curr_ts}.zip')
        with ZipFile(backup_path, 'w') as backup:
            for conf in config.conf_dir.glob('*.toml'):
                backup.write(conf, conf.name)
        self.acc_logger.info(f'Previous configuration saved with ID {curr_ts}.')
        # Create a copy of those files for convenience.
        copy2(backup_path, Path(config.conf_backup_dir, 'siwim_i_latest.zip'))
        if delete_tag is not None:
            self._delete_conf_entries(delete_tag)
        if confs_xml is not None:
            self._save_received_confs(confs_xml, addr, curr_ts)
        self.reset_i = True
        self.set_end()

    def _delete_conf_entries(self, delete_tag):
        """Remove the entries listed in *delete_tag* from their conf files.

        Each child's tag selects ``<conf_dir>/<tag>.toml`` and its text is the
        key popped from that file. Failures are logged and do not stop the
        remaining deletions.
        """
        self.acc_logger.debug(f'Delete tag: {etree.tostring(delete_tag)}.')
        for mod in delete_tag:
            try:
                conf_path = Path(config.conf_dir, f'{mod.tag}.toml')
                conf = i_conf.read_conf(conf_path)
                conf.pop(mod.text)
                i_conf.write_conf(conf, conf_path)
                self.acc_logger.debug(f'Removed {mod.tag}:{mod.text}.')
            except tomlkit.exceptions.NonExistentKey:
                self.acc_logger.warning(f'Failed to remove {mod.tag}:{mod.text} because it does not exist.')
            except Exception:
                self.acc_logger.exception(f'Failed to remove {mod.tag}:{mod.text}.')

    def _save_received_confs(self, confs_xml, addr, curr_ts):
        """Merge each uploaded conf (one child of *confs_xml* per file) into the existing file and write them all."""
        self.logger_dump.info(f'Received confs:\n{confs_xml.text}')
        confs = {}  # This is a Dict of TOMLDocuments.
        for module in confs_xml:
            received_conf = i_conf.deserialize_conf(module.text)
            conf_path = Path(config.conf_dir, f'{module.tag}.toml')
            try:  # Preserve defaults for undefined keys.
                if module.tag != 'downstream':  # Downstream confs should not be merged
                    conf = i_conf.read_conf(conf_path)
                    conf.update(received_conf)
                else:
                    conf = received_conf
            except ConfError as e:  # If the conf doesn't exist on the system, we don't know what the defaults for any undefined keys are.
                self.acc_logger.error(f'Something went wrong updating {conf_path.name}: {e}. Old file discarded.')
                conf = received_conf
            confs[module.tag] = conf
        try:
            i_conf.write_confs(confs, path=config.conf_dir, sort=True)
        except ConfError as e:
            self.acc_logger.error(f'An error occurred while saving configuration. You can find previous configuration in {config.conf_backup_dir}.')
            self.logger.error(f'One or more confs could not be written: {e}')
        else:  # Only report the update when the files were actually written.
            self.acc_logger.info(f'{addr} has updated confs: {",".join(confs.keys())} ({curr_ts}).')

    def _handle_download_log(self, conn, addr, original_xml):
        """Send today's log (or the general log) of the requested module, base64-encoded."""
        module_name_xml = original_xml.find('module_name')
        if module_name_xml is None:
            self.acc_logger.error(f'{addr} request missing module_name tag.')
            self._reply_to_client(conn, {'status': 'ERR', 'message': 'Missing module_name tag.'})
            return
        module_name = module_name_xml.text
        self.acc_logger.info(f'{addr} requested log for {module_name}.')
        # module_name comes from the network: accept only a bare file name so the
        # resulting path cannot point outside config.log_dir (absolute paths, separators).
        if module_name is None or Path(module_name).name != module_name:
            self.acc_logger.warning(f'{addr} requested log with invalid module name {module_name!r}.')
            self._reply_to_client(conn, {'status': 'ERR', 'message': f'Invalid module name {module_name!r}.'})
            return
        today = date.today().strftime('%Y-%m-%d')  # strftime already zero-pads month and day.
        try:
            log_path = Path(config.log_dir, f'{today}_{module_name}.log')  # Set the path to daily log path.
            if not log_path.exists():  # If it doesn't exist, check if the general log exists.
                log_path = Path(config.log_dir, f'{module_name}.log')
            if not log_path.exists():
                raise FileNotFoundError(f'Log file {log_path} missing.')
            with open(log_path, encoding='utf-8') as fd:
                log_txt = fd.read()
            encoded = base64.b64encode(log_txt.encode())
            self._reply_to_client(conn, {'status': 'OK', 'content': encoded})
            self.acc_logger.debug('Log sent.')
        except FileNotFoundError as e:
            self.acc_logger.warning(e)
            self._reply_to_client(conn, {'status': 'ERR', 'message': str(e)})
        except Exception as e:
            self.acc_logger.error(f'Unhandled error: {e}.')
            self._reply_to_client(conn, {'status': 'ERR', 'message': 'Unidentified error.'})

    def _handle_download_swmstat(self, conn, addr):
        """Send ``swmstat.xml`` from config.log_dir, base64-encoded."""
        try:
            self.acc_logger.info(f'{addr} requested swmstat.')
            with open(Path(config.log_dir, 'swmstat.xml'), encoding='utf-8') as fd:
                log_txt = fd.read()
            encoded = base64.b64encode(log_txt.encode())
            self._reply_to_client(conn, {'status': 'OK', 'content': encoded})
        except Exception as e:
            self.acc_logger.exception(f'Failed to send swmstat: {e}.')
            self._reply_to_client(conn, {'status': 'ERR'})

    def _handle_reboot(self, conn, addr):
        """Acknowledge the request, then reboot the host (Windows and Linux only)."""
        self.acc_logger.info(f'{addr} requested reboot.')
        system = platform.system()
        if system == 'Windows':
            command = 'shutdown -t 0 -r -f'  # Restart immediately and force all applications to close. TODO Change to subprocess for Python3
        elif system == 'Linux':
            command = 'reboot now'
        else:
            self._reply_to_client(conn, {'status': 'ERR', 'message': f'Command not supported for {system}.'})
            return
        self._reply_to_client(conn, {'status': 'OK'})
        os.system(command)

    def run(self):
        """Accept clients until set_end() is called, then wait for running client threads to finish."""
        self.alive = True
        self.end = False
        try:
            if not self.bind_tcp(timeout=None):
                raise StopModule(f'Binding to port {self.port} failed!')
            self.s.listen(self.backlog)
            while not self.end:
                try:
                    conn, addr = self.accept_tcp(timeout=None)
                except OSError:  # e.g. the listening socket was shut down/closed by set_end().
                    break
                threading.Thread(target=self.clientthread, args=(conn, addr[0])).start()
            if self.end:
                # NOTE(review): unless accept_tcp() sets a timeout on client sockets, a client that
                # connects and never sends a complete message keeps this loop waiting indefinitely.
                while True:
                    with self.clients_lock:
                        if self.num_of_clients == 0:
                            self.alive = False
                            raise StopModule('Thread closed correctly.')
                    time.sleep(0.5)
        except StopModule as e:
            self.log_stop_module(e)
        except Exception:
            self.logger.exception('Fatal error:')
        finally:
            self.s.close()
            self.alive = False
