import base64
import hashlib
import re
import socket
import threading
import time
from pathlib import Path

from lxml import etree

import config
from abstract.module import Module
from consts import PIC_BACKLOG, PIC_CAM_NAMES, PIC_CAM_TYPES, PIC_PORT, PIC_THREADS
from exceptions import NoData, NoPhoto, StopModule


def md5(fname) -> str:
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is read in binary mode in 4096-byte pieces, so large files are
    hashed without loading them into memory at once.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as handle:
        while True:
            block = handle.read(4096)
            if not block:  # b'' signals end of file
                break
            digest.update(block)
    return digest.hexdigest()


class PhotoSiwim(Module):
    """TCP server module that answers ``pic_request`` messages with stored photos.

    :meth:`run` accepts connections and hands each one to :meth:`clientthread` in its own
    thread. That thread parses an ``<i-message type="pic_request">`` request, locates the
    matching JPEG files via :meth:`find_path` and sends them back in one of two forms:
    base64-encoded inside an XML reply, or as raw bytes following an XML header.
    """

    # Largest story index that fits the two-digit ``pNN``/``nNN`` file-name suffix built in find_path.
    MAX_STORY = 99

    def __init__(self, args):
        # Define defaults that should be overwritten by the conf.
        self.backlog = 5
        self.max_clients = 50
        # Initialize the module, passing it a list of keys that must exist for it to work properly and a list of optional keys that are recognized by it.
        Module.__init__(self, args, mandatory_keys=(PIC_PORT, PIC_CAM_TYPES, PIC_CAM_NAMES), optional_keys=(PIC_BACKLOG, PIC_THREADS))
        self.type = 'siwim'
        self.response_type = 'pic_response'
        self.parser = etree.XMLParser(remove_blank_text=True)
        self.s = None  # listening socket; assigned when run() binds (assumed to be done by bind_tcp)
        self.num_of_clients = 0  # number of running clientthread() calls, guarded by clients_lock
        self.clients_lock = threading.Lock()

    def set_end(self):
        """Ask :meth:`run` to shut down.

        Sets ``self.end`` and then opens a throw-away connection to our own port. This makes
        the blocking accept in :meth:`run` return so that it notices the flag.
        """
        self.end = True
        # this will initiate shutdown procedure; a bit hacky, but this is how we do things around here
        try:
            # The context manager closes the socket even when connect() fails.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", self.port))
        except ConnectionRefusedError:  # This will be raised when module is shutdown by Ctrl+C.
            self.logger.debug('Failed to reset picture, because connection was refused.')
        except Exception:
            # Best effort only: failing to send the wake-up connection must not break shutdown.
            self.logger.debug('Failed to reset picture.', exc_info=True)

    def find_path(self, dt, lane, site, photo_root, story, camera_name, camera_type):
        """Locate the stored photos of one event.

        Photos are searched in ``<sites_dir>/<site>/ext/<photo_root>/<cam_type>/<camera>/
        <YYYY-MM-DD>-hh-mm-ss/<YYYY-MM-DD-hh>-mm-ss/``. Files are expected to be named
        ``<dt>_<lane>_<p|n><NN>.jpg``, where ``NN`` is the zero-padded story index.

        :param dt: timestamp string; its first ``YYYY-MM-DD-hh`` match selects the folders.
        :param lane: lane identifier used in the file name.
        :param site: site name; a mismatch with ``config.site_name`` is only logged.
        :param photo_root: sub-folder under ``ext`` (e.g. ``photo``, ``lpr``, ``adr``).
        :param story: number of photos wanted on each side of the main one; clamped to ``0..MAX_STORY``.
        :param camera_name: camera folder to restrict to, or falsy to use ``self.default_cam_names``.
        :param camera_type: camera-type folder to restrict to, or falsy to use ``self.default_cam_types``.
        :return: dict mapping story index (0 = main photo, negative = before, positive = after)
            to the photo's ``Path``. It contains only the photos that were found and may be empty.
        :raises NoPhoto: if ``dt`` contains no ``YYYY-MM-DD-hh`` timestamp, or if ``camera_type``
            is not a plain folder name.
        """
        hours = re.findall(r'\d{4}-\d{2}-\d{2}-\d{2}', dt)
        if not hours:
            raise NoPhoto(f'Malformed timestamp: {dt}')
        ts_hour = hours[0]  # Matches the part of timestamp string up to the hour.
        ts_day = ts_hour.rsplit('-', 1)[0]  # Strips 'hh' part.
        dt_dy = f'{ts_day}-hh-mm-ss'
        dt_hr = f'{ts_hour}-mm-ss'
        if config.site_name != site:
            self.logger.warning(f'Requested photo from site {site}, but active site is {config.site_name}.')
        partial_path = Path(config.sites_dir, site, 'ext', photo_root)

        # Only two digits are available for the index in file names. Clamping also keeps a huge
        # client-supplied story from making the loop below run for a very long time.
        story = min(max(story, 0), self.MAX_STORY)
        story_dict = {}
        for i in range(-story, story + 1):
            sign = 'n' if i < 0 else 'p'
            story_dict[f'{dt}_{lane}_{sign}{abs(i):02d}.jpg'] = i

        # If camera type/name have been defined in picture request, use those.
        cam_types = [camera_type] if camera_type else self.default_cam_types
        cam_names = [camera_name] if camera_name else self.default_cam_names
        self.logger.debug(f'Camera types: {cam_types}. Camera names: {cam_names}.')
        if camera_type and (camera_type == '..' or Path(camera_type).name != camera_type):
            # camera_type comes from the client and is joined into a path: refuse traversal / absolute paths.
            raise NoPhoto(f'Invalid camera type: {camera_type}')

        # NOTE(review): paths_dict is shared across cameras. A partial set from one camera can
        # therefore be combined with photos from another camera; confirm this is intended.
        paths_dict = {}
        for cam_type in cam_types:
            try:
                camera_dirs = [d for d in Path(partial_path, cam_type).iterdir() if d.is_dir()]
            except FileNotFoundError as e:
                self.logger.warning(f'Folder missing from camera structure. This often occurs when a camera module exists for a non-existent/non-functional camera: {e}.')
                continue
            for camera_dir in camera_dirs:
                if cam_names and camera_dir.name not in cam_names:
                    continue
                path = Path(camera_dir, dt_dy, dt_hr)
                self.logger.debug(f'Looking for {story_dict} at {path}.')
                try:
                    pictures = list(path.iterdir())
                except FileNotFoundError as e:
                    # A missing folder for one camera must not stop the search in its siblings.
                    self.logger.warning(f'Folder missing from camera structure. This often occurs when a camera module exists for a non-existent/non-functional camera: {e}.')
                    continue
                for picture in pictures:
                    index = story_dict.get(picture.name)
                    if index is not None:
                        paths_dict[index] = picture
                        if len(paths_dict) == len(story_dict):
                            return paths_dict
                self.logger.debug(f'Only {len(paths_dict)} out of {len(story_dict)} photos found for {camera_dir}.')
        return paths_dict

    def clientthread(self, conn):
        """Serve the picture requests arriving on one accepted client connection.

        Reads ``<i-message>`` requests from ``conn``, using a 2 s timeout while reading, and
        answers each one. If a read returns leftover data, the next request is processed as well.
        Errors are reported to the client as ``status="ERR"`` messages where possible.

        :param conn: accepted client socket; it is always closed before this method returns.
        """
        if self.end:
            # Shutdown in progress: do not serve, but do not leak the accepted socket either.
            conn.close()
            return
        with self.clients_lock:
            self.num_of_clients += 1
        try:
            msg_start = b'<i-message'
            msg_end = b'</i-message>'
            buffer = b''
            while True:
                self.throttle()
                conn.settimeout(2)  # Messages that we receive are relatively short, so we set a timeout.
                # Pass on whatever followed the previous message so pipelined requests are not lost.
                i_message, buffer = self.acquire_data(conn, buffer, msg_start, msg_end, 'xml', exit_on_timeout=True)

                conn.setblocking(True)  # When sending the photos back we don't want the connection to time out.
                if i_message.get('type') != 'pic_request':
                    raise NoPhoto('Unrecognised request.')
                self._serve_pic_request(conn, i_message.find("picture"))

                if buffer == b'':  # Generally the whole loop should only happen once, but if more than one request was received, we want to process them all.
                    break
                self.logger.info('More than one photo request received at once. Processing next one ...')
        except (NoData, NoPhoto) as e:
            self._send_error(conn, str(e))
        except Exception:
            self.logger.exception('General error:')
            self._send_error(conn, 'A general error. Connection or something else.')
        finally:
            with self.clients_lock:
                self.num_of_clients -= 1
            conn.close()

    def _serve_pic_request(self, conn, picture):
        """Answer one ``<picture>`` request element on ``conn`` with the requested photo(s).

        :raises NoPhoto: if the request has no ``<picture>`` element, the encoding is unsupported,
            a photo cannot be read, or the main photo (index 0) was not found.
        """
        if picture is None:
            raise NoPhoto('Unrecognised request.')
        ts = picture.find("ts").text
        lane = picture.find("lane").text
        self.logger.debug(f'Message: {etree.tostring(picture)}')
        try:
            story = max(0, int(picture.findtext('max_story')))
        except (ValueError, TypeError):  # missing (None), empty or non-numeric max_story
            story = 0
        # A missing or empty <encoding> falls back to the default binary reply.
        encoding = picture.findtext('encoding') or 'default'
        if encoding not in ('base64', 'default', 'binary'):
            # Previously such requests got no reply at all; tell the client instead.
            raise NoPhoto(f'Unsupported encoding: {encoding}.')
        pic_type = picture.get('type')
        if pic_type in ('anpr', 'lpr'):  # Legacy
            photo_root = 'lpr'
        elif pic_type == 'adr':
            photo_root = 'adr'
        else:
            photo_root = 'photo'
        camera_name = picture.findtext("camera_name") or None
        camera_type = picture.findtext("camera_type") or None

        paths_dict = self.find_path(ts, lane, config.site_name, photo_root, story, camera_name, camera_type)
        picture_data_dict = self._read_pictures(paths_dict)
        if 0 not in picture_data_dict:
            self.logger.info(f'No picture data in picture_data_dict for {ts}')
            raise NoPhoto('No picture data.')
        md, pic_data = picture_data_dict[0]

        dct = {"status": "OK", "size": str(len(pic_data)), "format": "jpg", "encoding": encoding}
        if encoding == 'base64':
            dct['pictures'] = ''
            dct['md5s'] = ''
            msg = etree.fromstring(self.form_message(dct))
            pictures_el = msg.find("pictures")
            md5s_el = msg.find("md5s")
            for key, (digest, data) in picture_data_dict.items():
                etree.SubElement(pictures_el, "picture", index=str(key)).text = base64.b64encode(data).decode('ascii')
                etree.SubElement(md5s_el, "md5", index=str(key)).text = digest
            # sendall: a single send() may transmit only part of a large payload.
            conn.sendall(etree.tostring(msg, pretty_print=False))
        else:  # 'default' or 'binary': XML header followed by the raw main photo
            dct["md5"] = md
            conn.sendall(self.form_message(dct))
            conn.sendall(pic_data)

    def _read_pictures(self, paths_dict):
        """Read every photo in ``paths_dict`` and return ``{index: (md5 hex digest, bytes)}``.

        :raises NoPhoto: if any photo file cannot be read.
        """
        picture_data_dict = {}
        for key, path in paths_dict.items():
            try:
                with open(path, 'rb') as fd:
                    pic_data = fd.read()
            except OSError as e:
                self.logger.debug('Failed to fill picture_data_dict.')
                raise NoPhoto('Did not find the .jpeg file.') from e
            # Hash the bytes that are actually sent rather than re-reading the file, so the digest
            # always matches the payload even if the file changes in between.
            picture_data_dict[key] = (hashlib.md5(pic_data).hexdigest(), pic_data)
        return picture_data_dict

    def _send_error(self, conn, message):
        """Send a best-effort ``status="ERR"`` reply. A client that already disconnected is only logged."""
        try:
            conn.sendall(self.form_message({'status': 'ERR', 'message': message}))
        except OSError:
            self.logger.warning('Could not send error message to client because it reset connection.')

    def run(self):
        """Accept client connections until :meth:`set_end` is called.

        Each accepted connection is served in a new :meth:`clientthread` thread. Once ``self.end``
        is set, the method waits until all client threads have finished and then stops the module.
        """
        self.alive = True
        self.end = False
        self.set_upstream_info(self.name, 'photo_server_port', self.port)
        try:
            if not self.bind_tcp(timeout=None):
                raise StopModule(f'Binding to port {self.port} failed!')
            self.s.listen(self.backlog)

            while True:
                conn, _ = self.accept_tcp()
                if self.end:
                    # Woken up by set_end(): wait for running client threads before stopping.
                    while True:
                        with self.clients_lock:
                            self.logger.info(f'Number of clients: {self.num_of_clients}.')
                            if self.num_of_clients == 0:
                                self.alive = False
                                raise StopModule('No more clients connected.', conn=conn)
                        time.sleep(0.5)
                # NOTE(review): reaching the limit stops the whole module rather than rejecting one client.
                if self.num_of_clients < self.max_clients:
                    threading.Thread(target=self.clientthread, args=(conn,)).start()
                else:
                    raise StopModule('Thread limit reached.', conn=conn)
        except StopModule as e:
            self.log_stop_module(e)
        except BaseException:  # deliberately broad: this is the module's top-level boundary
            self.logger.exception('Fatal error:')
        finally:
            if self.s is not None:  # still None if binding failed before a socket was created
                self.s.close()
        self.alive = False
