import base64
import hashlib
import os
import socket
import subprocess
import threading

from lxml import etree

from generic_module import Module


class PictureModule(Module):
    """TCP server answering XML picture requests with JPEG data from the site tree.

    A client sends an ``<i-message type="pic_request">`` document containing a
    ``<picture>`` element; the module answers with an
    ``<i-message type="pic_response">`` document. With ``default``/``binary``
    encoding the raw JPEG bytes of the requested picture follow the XML; with
    ``base64`` encoding the requested picture and its neighbouring "stories"
    are embedded in the XML.
    """

    def __init__(self, args):
        # presumably Module.__init__ exposes these keys as attributes (self.port,
        # self.backlog, ...) - the methods below rely on that; confirm in generic_module.
        expected_params = {"port": 8171, "backlog": 5, "recv_buffer": 2048, "root_path": None, "old": 0,
                           "default_cam_types": [], "default_cam_names": [], "max_client_threads": 50}
        Module.__init__(self, args, expected_params)
        self.parser = etree.XMLParser(remove_blank_text=True)
        self.sock = None
        self.num_of_clients = 0  # number of active clientthread() calls, guarded by clients_lock
        self.clients_lock = threading.Lock()

    def set_end(self):
        """Ask :meth:`run` to shut down.

        Sets ``self.end`` and then opens a throw-away connection to our own port so
        that the blocking ``accept()`` in :meth:`run` returns and sees the flag.

        :raises NotImplementedError: if the connection is refused (nothing is
            listening). NOTE(review): kept for compatibility - callers may rely on
            this exception type; confirm before changing it.
        """
        self.end = True
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.connect(("127.0.0.1", self.port))
        except ConnectionRefusedError:
            self.logger.debug('Failed to reset picture, because connection was refused.')
            raise NotImplementedError
        except Exception:
            pass  # best effort: any other failure just means there is nothing to wake up
        finally:
            s.close()

    @staticmethod
    def _is_safe_path_component(name):
        """Return True if *name* is a string that cannot escape its parent directory when joined into a path."""
        return (isinstance(name, str)
                and name not in (".", "..")
                and "/" not in name
                and "\\" not in name
                and "\x00" not in name)

    def find_path(self, dt, lane, site, photo_root, story, camera_name, camera_type):
        """Locate the JPEG files for a vehicle and its neighbouring "stories".

        Pictures are searched under
        ``<sites_path>/<site>/ext/<photo_root>/<camera type>/<camera dir>/<D>-hh-mm-ss/<D>-<HH>-mm-ss/``
        (``D`` = first three ``-``-separated fields of *dt*, ``HH`` = the fourth)
        and must be named ``<dt>_<lane>_<p|n><NN>.jpg`` where ``NN`` is the
        two-digit absolute offset (``p`` for offsets >= 0, ``n`` for negative).

        :param dt: timestamp string with at least four ``-``-separated fields
        :param lane: lane identifier used in the file name
        :param site: site directory name (must not contain path separators)
        :param photo_root: sub-directory under ``ext`` (e.g. ``photo`` or ``anpr``)
        :param story: number of neighbouring pictures wanted on each side; capped at 99
        :param camera_name: restrict to this camera directory, or ``""`` for the defaults
        :param camera_type: restrict to this camera type directory, or ``""`` for the defaults
        :return: dict mapping offset (int in ``[-story, story]``) to file path; offsets
            without a file are absent, and unsafe path components yield an empty dict.
        :raises IndexError: if *dt* has fewer than four ``-``-separated fields.
        """
        for component in (dt, site, camera_type):
            if not self._is_safe_path_component(component):
                self.logger.warning('Rejected unsafe path component in picture request.')
                return {}

        dt_parts = dt.split("-")
        day = "-".join(dt_parts[:3])
        dt_dy = day + "-hh-mm-ss"
        dt_hr = day + "-" + dt_parts[3] + "-mm-ss"
        partial_path = self.sites_path + "/" + site + "/ext/" + photo_root

        # File names carry a two-digit offset, so offsets beyond +-99 cannot exist;
        # capping also keeps a huge client-supplied story from spinning this loop.
        max_offset = min(story, 99)
        story_dict = {}
        for offset in range(-max_offset, max_offset + 1):
            prefix = "n" if offset < 0 else "p"
            story_dict["{0}_{1}_{2}{3:02d}.jpg".format(dt, lane, prefix, abs(offset))] = offset

        cam_types = [camera_type] if camera_type != "" else self.default_cam_types
        cam_names = [camera_name] if camera_name != "" else self.default_cam_names

        paths_dict = {}
        for cam_type in cam_types:
            try:
                type_dir = partial_path + "/" + cam_type
                camera_dirs = os.listdir(type_dir)
            except (OSError, TypeError, ValueError):
                # missing/unreadable directory or a malformed configured camera type
                continue
            for camera_dir in camera_dirs:
                if cam_names and camera_dir not in cam_names:
                    continue
                path = type_dir + "/" + camera_dir + "/" + dt_dy + "/" + dt_hr
                try:
                    pictures = os.listdir(path)
                except (OSError, ValueError):
                    continue
                for picture in pictures:
                    offset = story_dict.get(picture)
                    if offset is None:
                        continue
                    paths_dict[offset] = path + "/" + picture
                    if len(paths_dict) == len(story_dict):
                        return paths_dict  # every requested offset found
        return paths_dict

    def md5(self, fname):
        """Return the hex MD5 digest of the file *fname*, read in 4 KiB chunks."""
        hash_md5 = hashlib.md5()
        with open(fname, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def md5_from_string(self, binary_string):
        """Return the hex MD5 digest of the bytes-like object *binary_string*."""
        return hashlib.md5(binary_string).hexdigest()

    def old_siwim_picture_getter(self, dt, site, story):
        """Legacy lookup: extract pictures from an old SiWIM ``.vehiclephotos`` file.

        NOTE: remove this ASAP. it's an ugly piece of heritage.

        The file is searched in the ``camera``, ``camera1`` and ``camera2`` live
        directories (first match wins) and converted by the external
        ``vehiclephotos2jpg.exe`` tool, whose XML output is read as ``<image>``
        elements with an ``index`` attribute and base64 text.

        :return: dict mapping image index -> (md5 hex digest, JPEG bytes) for the
            image with index 0 plus up to *story* images on each side of it in the
            converter's output. Empty if no file was found, a path component is
            unsafe, or no image has index 0.
        :raises subprocess.CalledProcessError: if the converter exits with an error.
        """
        binary_images = {}
        if not (self._is_safe_path_component(site) and self._is_safe_path_component(dt)):
            self.logger.warning('Rejected unsafe path component in picture request.')
            return binary_images

        relative = dt[0:10] + "-hh-mm-ss/" + dt[0:13] + "-mm-ss/" + dt
        the_fname = None
        for camera_dir in ("camera", "camera1", "camera2"):
            candidate = self.root_path + "/sites/" + site + "/live/" + camera_dir + "/" + relative
            if os.path.isfile(candidate + ".vehiclephotos"):
                the_fname = candidate
                break
        if the_fname is None:
            return binary_images

        out = subprocess.check_output(["vehiclephotos2jpg.exe", the_fname + ".vehiclephotos"])
        images = etree.fromstring(out).findall("image")
        zero_index = next((i for i, image in enumerate(images) if image.get("index") == "0"), None)
        if zero_index is None:
            return binary_images

        def store(image):
            # Decode one <image> and record it under its own index attribute.
            blob = base64.b64decode(image.text)
            binary_images[int(image.attrib["index"])] = (self.md5_from_string(blob), blob)

        store(images[zero_index])
        # Walk outwards from the index-0 image (before it, then after it). Positions
        # outside the list are skipped rather than wrapping via negative indexing;
        # offsets beyond len(images) can never be in range, so stop there.
        for offset in range(1, min(story, len(images)) + 1):
            for position in (zero_index - offset, zero_index + offset):
                if 0 <= position < len(images):
                    try:
                        store(images[position])
                    except (KeyError, TypeError, ValueError):
                        pass  # neighbour without index or with an undecodable payload
        return binary_images

    def form_message(self, info):
        """Build a serialized ``<i-message type="pic_response">`` with one child per *info* item.

        :param info: mapping of tag name -> text, emitted in iteration order
        :return: the XML document as ``str``, not pretty-printed
        """
        top_element = etree.Element("i-message", type="pic_response")
        for key, val in info.items():
            etree.SubElement(top_element, key).text = val
        return etree.tostring(top_element, pretty_print=False).decode()

    @staticmethod
    def _close_quietly(conn):
        """Close *conn*, ignoring socket errors."""
        try:
            conn.close()
        except OSError:
            pass

    def _send_error(self, conn, message):
        """Send an ``ERR`` pic_response carrying *message* over *conn*."""
        conn.sendall(self.form_message({"status": "ERR", "message": message}).encode())

    @staticmethod
    def _child_text(element, tag, default):
        """Return the text of child *tag* of *element*, or *default* if the child or its text is missing."""
        child = element.find(tag)
        if child is None or child.text is None:
            return default
        return child.text

    def _receive_request(self, conn, max_reads=5):
        """Read from *conn* until a complete ``<i-message ...>...</i-message>`` has arrived.

        Bytes are accumulated before any decoding so multi-byte characters split
        across ``recv`` calls are not corrupted.

        :return: the raw bytes received, or None if no complete message arrived
            within *max_reads* reads or the peer closed the connection first.
        """
        data = b""
        for _ in range(max_reads):
            chunk = conn.recv(self.recv_buffer)
            if not chunk:
                break  # peer closed the connection
            data += chunk
            if b"<i-message " in data and b"</i-message>" in data:
                return data
        return None

    def _handle_client(self, conn):
        """Serve one picture request on *conn*.

        Every path either sends a response or raises; closing the socket and
        updating the client counter are left to :meth:`clientthread`.
        """
        data = self._receive_request(conn)
        if data is None:
            self._send_error(conn, "Did not receive a valid picture request.")
            return
        try:
            i_message = etree.fromstring(data, self.parser)
        except (etree.XMLSyntaxError, ValueError):
            self.logger.info('Line: {0}'.format(data.decode(errors="replace")))
            self._send_error(conn, "Did not receive a valid picture request.")
            return
        if i_message.get("type") != "pic_request":
            self._send_error(conn, "Expected xml_request, got something else.")
            return

        picture = i_message.find("picture")
        ts = picture.find("ts").text
        lane = picture.find("lane").text
        site_name = picture.find("site_name").text
        self.logger.debug('Message: {}'.format(etree.tostring(picture)))
        try:
            story = max(0, int(picture.find("max_story").text))
        except (AttributeError, TypeError, ValueError):
            # missing, empty or non-numeric <max_story>: serve only the main picture
            story = 0
            self.logger.debug('No valid max_story for {0} on lane {1}; using 0.'.format(ts, lane))
        encoding = self._child_text(picture, "encoding", "default")
        photo_root = "anpr" if picture.get("type") == "anpr" else "photo"
        camera_name = self._child_text(picture, "camera_name", "")
        camera_type = self._child_text(picture, "camera_type", "")

        if self.old == 1:
            try:
                picture_data_dict = self.old_siwim_picture_getter(ts, site_name, story)
                md, pic_data = picture_data_dict[0]
            except Exception:
                self.logger.debug('Failed to get siwim picture.')
                self._send_error(conn, "Error during finding/converting .vehiclephotos file.")
                return
        else:
            paths_dict = self.find_path(ts, lane, site_name, photo_root, story, camera_name, camera_type)
            picture_data_dict = {}
            try:
                for key, path in paths_dict.items():
                    with open(path, "rb") as fd:
                        blob = fd.read()
                    # hash the bytes actually sent rather than re-reading the file
                    picture_data_dict[key] = (self.md5_from_string(blob), blob)
            except OSError:
                self.logger.debug('Failed to fill picture_data_dict.')
                self._send_error(conn, "Did not find the .jpeg file.")
                return
            if 0 not in picture_data_dict:
                self.logger.info('No picture data in picture_data_dict for {0}'.format(ts))
                self._send_error(conn, "Requested photo not (yet?) present on site.")
                return
            md, pic_data = picture_data_dict[0]

        dct = {"status": "OK", "size": str(len(pic_data)), "format": "jpg", "encoding": encoding}
        if encoding == "base64":
            dct["pictures"] = ""
            dct["md5s"] = ""
            msg = etree.fromstring(self.form_message(dct))
            pictures_element = msg.find("pictures")
            md5s_element = msg.find("md5s")
            for key, (digest, blob) in picture_data_dict.items():
                etree.SubElement(pictures_element, "picture", index=str(key)).text = \
                    base64.b64encode(blob).decode("ascii")
                etree.SubElement(md5s_element, "md5", index=str(key)).text = digest
            conn.sendall(etree.tostring(msg, pretty_print=False))
        elif encoding in ("default", "binary"):
            dct["md5"] = md
            conn.sendall(self.form_message(dct).encode())
            conn.sendall(pic_data)  # raw JPEG follows the XML header
        else:
            self._send_error(conn, "Unsupported encoding.")

    def clientthread(self, conn):
        """Thread target: serve one client connection, then close it.

        The client counter is incremented on entry and always decremented (and
        *conn* always closed) on exit, so :meth:`run` can wait for it at shutdown.
        """
        if self.end:
            self._close_quietly(conn)
            return
        with self.clients_lock:
            self.num_of_clients += 1
        try:
            self._handle_client(conn)
        except Exception:
            self.logger.exception('General error:')
            try:
                self._send_error(conn, "A general error. Connection or something else.")
            except Exception:
                self.logger.warning('Could not send error message to client because it reset connection.')
        finally:
            self._close_quietly(conn)
            with self.clients_lock:
                self.num_of_clients -= 1

    def _wait_for_clients(self):
        """Block until no clientthread() call is active, pausing via ``self.zzzzz(0.5)`` between checks."""
        while True:
            with self.clients_lock:
                remaining = self.num_of_clients
            self.logger.info(str(remaining))
            if remaining == 0:
                return
            self.zzzzz(0.5)

    def run(self):
        """Accept picture requests until shutdown is requested.

        Each accepted connection is handled in its own thread by
        :meth:`clientthread`. When ``self.end`` is set (by :meth:`set_end` or
        because the thread limit was hit), the next accepted connection triggers
        shutdown: wait for active clients, close the listening socket, clear
        ``self.alive``.
        """
        self.alive = True
        self.end = False
        for mod in self.downstream_modules_dict.values():
            mod.set_upstream_info(self.name, "photo_server_port", self.port)
        self.sock = None
        try:
            self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # SO_REUSEADDR only affects bind() calls made after it is set.
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.sock.bind(('', self.port))
            self.sock.listen(self.backlog)
        except Exception:
            self.logger.exception('Fatal error:')
            if self.sock is not None:
                self.sock.close()
            self.alive = False
            self.zzzzz(2)
            self.logger.debug('Thread closed correctly.')
            return
        try:
            while True:
                conn, _addr = self.sock.accept()
                if self.end:
                    self._close_quietly(conn)
                    self._wait_for_clients()
                    self.sock.close()
                    self.alive = False
                    self.logger.debug('Thread closed correctly.')
                    return
                with self.clients_lock:
                    limit_reached = self.num_of_clients >= self.max_client_threads
                if limit_reached:
                    # Refuse this client (closing rather than leaking its socket) and
                    # flag shutdown; the next accepted connection completes it.
                    self._close_quietly(conn)
                    self.end = True
                    self.logger.info('Thread limit reached.')
                else:
                    threading.Thread(target=self.clientthread, args=(conn,)).start()
        except Exception:
            self.logger.exception('Fatal error:')
            self.sock.close()
        self.alive = False