import datetime
import os
import platform
import re
import socket
import subprocess
import sys
from typing import Dict

import requests
import serial
from cestel_helpers.exceptions import ConfError
from cestel_helpers.siwim import Siwim

from exceptions import StopModule

# NOTE(review): this branch does nothing on any platform (its body is only `pass`);
# presumably a leftover of a removed Windows-only import. Safe to delete once confirmed.
if platform.system() == 'Windows':
    pass

from datetime import timedelta
from lxml import etree
from abstract.module import Module
from consts import SITE_NAME, SIWIM_E_VERSION, CONF_GLOBAL, GLOBAL_UPDATE_CHANNEL, STAT_CONF, STAT_HOST, STAT_INTERVAL, STAT_NTP, STAT_OVPN, STAT_PROTOCOL, STAT_SAVE, STAT_URL
from cestel_helpers.version import get_version
from cestel_helpers.i_conf_manager import SEC_DEFAULT, read_conf
import config
from pathlib import Path
from psutil import disk_usage

# Keys under which router information is collected by StatusModule.getRouterInfo()
# and copied into the shared status dict.
ROUTER = 'router'
IMEI = 'gsm_imei'
FREQ_BAND = 'gsm_frequency_band'
IMSI = 'gsm_imsi'
LAT = 'gps_lat'
LON = 'gps_lon'
RSRQ = 'gsm_signal_quality'


def get_total_mb(partition):
    """Return the total size of the filesystem holding *partition*, rounded to whole MiB."""
    total_bytes = disk_usage(partition).total
    return int(round(total_bytes / (1 << 20)))


def get_free_mb(partition):
    """Return the free space on the filesystem holding *partition*, rounded to whole MiB."""
    free_bytes = disk_usage(partition).free
    return int(round(free_bytes / (1 << 20)))


class StatusModule(Module):
    """Collect SiWIM system status and periodically post it upstream as XML.

    Each cycle of :meth:`run` refreshes ``config.status_dict`` (:meth:`swmstat`), serialises it into a
    ``<swmstat>`` document, optionally writes it to ``<log_dir>/swmstat.xml`` and POSTs it to
    ``{protocol}://{host}{url}``.

    On MK4 systems (Windows 10) the RTS and DTR lines of COM1 are driven as indicators: RTS is raised while the
    VPN host answers pings, DTR while ``siwim_mcp.exe`` is running.
    """

    # Ports of the router web interface (at 192.168.4.100) that are probed, in order.
    _ROUTER_PORTS = ('80', '9191')
    # Substrings searched for in the router's entry page, in this order (first match wins).
    # Teltonika has a redirect, which holds the actual data.
    _ROUTER_MARKERS = {'Sierra': 'Sierra', 'Teltonika': '/cgi-bin/luci', 'Teltonika_fw07': 'Teltonika'}
    # NMEA GGA sentence fields: time, latitude, N/S, longitude, E/W.
    _GPGGA_RE = re.compile(r'GPGGA,([.\d]*),([.\d]*),([NS]),([.\d]*),([EW])')

    def __init__(self, args):
        # FIXME Remove this once conversion from SiWIM-I v4 to v5 is done. Key was renamed and isn't important enough
        #  to handle cases where it might be missing.
        self.interval = 600
        # Initialize the module, passing it the keys that must exist for it to work properly.
        Module.__init__(self, args, mandatory_keys=(STAT_CONF, STAT_HOST, STAT_INTERVAL, STAT_NTP, STAT_OVPN, STAT_PROTOCOL, STAT_SAVE, STAT_URL))
        with config.lock:
            config.status_dict['python_version'] = platform.python_version()
            config.status_dict['siwim_version'] = '4' if platform.system() == 'Windows' else '5'  # MK5 should be running on Linux.
            config.status_dict['siwim_i_version'] = get_version()
        if self._is_mk4():
            try:
                self.mk4_com_serial = serial.Serial('COM1')
                # Start with both indicator lines low.
                self.mk4_com_serial.dtr = False
                self.mk4_com_serial.rts = False
            except serial.serialutil.SerialException as e:
                self.logger.debug(e)

    @staticmethod
    def _is_mk4() -> bool:
        """Return True on Windows 10, which is how this module recognises an MK4 system."""
        return platform.system() == 'Windows' and platform.release() == '10'

    def _mk4_port(self):
        """Return the MK4 indicator serial port, or None if this instance does not own one.

        The attribute only exists if opening COM1 in ``__init__`` succeeded; when it did not (e.g. because another
        Status module already holds the port) the indicators are driven elsewhere.
        """
        return getattr(self, 'mk4_com_serial', None)

    def _close_mk4_serial(self):
        """Close the MK4 indicator port if this instance owns one."""
        port = self._mk4_port()
        if port is None:
            return
        try:
            port.close()
        except OSError:
            self.logger.debug('Failed to close MK4 serial port.', exc_info=True)

    def VPNstatus(self):
        """Ping the VPN host once and report connectivity.

        Returns:
            "Connected" if the ping got a reply, "Not connected" otherwise.
        """
        if platform.system() == "Windows":
            # Windows ping: -w is the reply timeout in milliseconds.
            cmd = ["ping", "-n", "1", "-w", "2000", self.vpn_host]
        else:
            # Linux ping: -w is the overall deadline in seconds.
            cmd = ["ping", "-c", "1", "-w", "2", self.vpn_host]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE)
        output = proc.stdout.decode(errors='replace')
        # A lost ping on Linux ("100% packet loss") contains none of the keywords, so the exit code must be checked;
        # Windows ping may exit with 0 on "Destination host unreachable", so the keywords are checked as well.
        failed = proc.returncode != 0 or any(word in output for word in ('unreachable', 'timed', 'failure'))
        return "Not connected" if failed else "Connected"

    def w32ovpn_version(self):
        """Return the second space-separated token of ``<ovpn_path> --version`` (presumably the version number).

        Raises:
            IndexError: if the output contains no space.
        """
        output = subprocess.run([self.ovpn_path, "--version"], stdout=subprocess.PIPE).stdout
        # The output is bytes, so split on a bytes space and decode only the token we need.
        return output.split(b" ")[1].decode()

    def getTeltonika_fw07(self, ip: str, data: Dict[str, str]) -> Dict[str, str]:
        """Collect router name, mobile and GPS data from a Teltonika router with FW 07.x.

        Args:
            ip: Base URL of the router web interface, e.g. ``http://192.168.4.100:80``.
            data: Dict filled in place with whichever of ROUTER, IMSI, IMEI, RSRQ, FREQ_BAND, LAT and LON could be
                obtained.

        Returns:
            The same ``data`` dict.
        """
        login = f'{ip}/api/login'
        url = f'{ip}/ubus'

        # A session is required to obtain data, since data is served based on "ubus_rpc_session".
        with requests.Session() as session:
            # Log in to the router and obtain session ID.
            # NOTE(review): credentials are hard-coded; consider moving them to the module configuration.
            credentials = {
                'username': 'admin',
                'password': 'Siwim323'
            }
            res = session.post(login, json=credentials, timeout=2)
            if not res.ok:
                self.logger.warning(f'{res.status_code}: Failed to get session id from router: {res.text}.')
                return data
            token = res.json()
            try:
                session_id = token['ubus_rpc_session']
            except KeyError:  # RUT9_R_00.07.06.1 or higher.
                session_id = token['data']['token']

            # Obtain device name.
            payload = {
                'jsonrpc': '2.0',
                'method': 'call',
                'params': [session_id, 'uci', 'get', {'config': 'system'}]
            }
            res = session.post(url, json=payload, timeout=2)
            if res.ok:
                try:
                    data[ROUTER] = res.json()['result'][1]['values']['system']['devicename']
                except (KeyError, IndexError, TypeError):  # Missing result or a reply without a payload.
                    self.logger.warning(f'Router name missing from response: {res.text}.')
            else:
                self.logger.warning(f'{res.status_code}: Failed to obtain router name: {res.text}.')

            # Obtain mobile data.
            header = {
                'Authorization': f'Bearer {session_id}'
            }

            def request_mobile(endpoint: str, old: bool = True) -> bool:
                """Fetch mobile data from *endpoint* into ``data``; return True on success.

                Exists to easily handle the different URLs serving mobile data before and after FW RUT9_R_00.07.06.
                """
                response = requests.get(endpoint, headers=header, timeout=2)
                if response.ok:
                    result = response.json()
                    if result.get('success'):
                        try:
                            modem = result['data']
                            data[IMSI] = modem['imsi']
                            data[IMEI] = modem['imei']
                            data[RSRQ] = modem['rsrq']
                            data[FREQ_BAND] = modem['band']
                            return True
                        except (KeyError, TypeError):
                            self.logger.warning(f'Unexpected mobile data format: {response.text}.')
                    else:
                        self.logger.warning(f'{response.status_code}: Mobile data not obtained (FW {"<" if old else ">="} RUT9_R_00.07.06).: {response.text}.')
                if not old:  # If we made it this far for the newer FW version, data is missing.
                    self.logger.warning(f'{response.status_code}: Could not obtain mobile data: {response.text}.')
                return False

            # Try the old url first and the new one if that doesn't work. A warning is logged if the new one also fails.
            if not request_mobile(f'{ip}/api/mobile/modems/status_full/1-1.4'):
                request_mobile(f'{ip}/api/modems/status/1-1.4', old=False)  # RUT9_R_00.07.06.1 or higher.

            # Obtain GPS coordinates.
            payload = {
                'jsonrpc': '2.0',
                'method': 'call',
                'params': [session_id, 'gpsd', 'position', {}]
            }
            res = session.post(url, json=payload, timeout=2)
            if res.ok:
                result = res.json().get('result')
                try:
                    data[LAT] = result[1]['latitude']
                    data[LON] = result[1]['longitude']
                except (KeyError, IndexError, TypeError):  # Empty/missing result or a reply without a payload.
                    self.logger.warning(f'No GPS info in response: {res.json()}.')
            else:
                self.logger.warning(f'{res.status_code}: Failed to obtain GPS info for {data.get(ROUTER)}: {res.text}.')
            return data

    def getRouterInfo(self) -> Dict[str, str]:
        """Detect the router at 192.168.4.100 and collect its identification, mobile and GPS data.

        Ports in ``_ROUTER_PORTS`` are probed in order; the first one answering with a successful HTTP status is
        used to identify the router by searching its entry page for the markers in ``_ROUTER_MARKERS``.

        Returns:
            Dict with whichever of ROUTER, IMEI, FREQ_BAND, IMSI, LAT, LON and RSRQ could be obtained (may be empty).
        """
        markers = self._ROUTER_MARKERS
        data = {}
        for port in self._ROUTER_PORTS:
            ip = f'http://192.168.4.100:{port}'  # IP as a variable so that code can be tested anywhere without modifications.
            try:
                req = requests.get(ip, timeout=10)
            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):  # Continue with the next port.
                self.logger.debug(f'Could not establish connection to router at {ip}.')
                continue
            if not req.ok:
                self.logger.info(f'Router not found at {ip}: Error code {req.status_code}. Trying next IP in {", ".join(self._ROUTER_PORTS)}')
                continue

            # If a marker exists in the entry page, we assume it's the router of that company.
            router = next((marker for marker in markers.values() if marker in req.text), '')
            if not router:
                self.logger.info('Legacy router identification failed!')
                req = requests.get(f'{ip}/api/unauthorized/status', timeout=10)
                if req.ok:
                    info = req.json().get('data') or {}
                    self.logger.info(f'Detected {info.get("device_name")} using API {info.get("api_version")}!')
                    router = 'Teltonika'
                else:
                    self.logger.warning('Router identification failed!')
                    self.logger_dump.info(req.text)
                    return data

            self.logger.debug(f'Detected router {router}.')

            if router == markers['Sierra']:
                self._get_sierra_info(ip, req.text, data)
            elif router == markers['Teltonika']:
                self._get_teltonika_legacy_info(ip, data)
            elif router == markers['Teltonika_fw07']:
                data = self.getTeltonika_fw07(ip, data)
            else:  # Only reachable if a marker is added without a matching handler.
                self.logger.warning('Unrecognised router (could not find a match), IMSI and GPS information not obtained.')
            break  # We've gotten our router information, so break the loop.
        self.logger.debug(f'Router data: {data}')
        return data

    def _get_sierra_info(self, ip: str, page: str, data: Dict[str, str]) -> None:
        """Fill *data* from a Sierra router whose entry page text is *page*, served at base URL *ip*.

        ALEOS firmware is queried through its web interface. A page without an ALEOS version falls back to the
        legacy GPS (:meth:`gpsLoc`) and IMSI (:meth:`SierraImsi`) interfaces.
        """
        router = 'Sierra'
        aleos = re.search(r'ALEOS Version (?:\d+\.?)+', page)
        if aleos is None:  # The old version of Sierra does not use ALEOS and thus the regex match fails.
            self.logger.info('Failed to obtain router information for a Sierra router. Falling back to "old" way.')
            data[ROUTER] = f'{router} {data[ROUTER] if ROUTER in data else str()}'  # Add the router name before the type.
            data[LAT], data[LON] = self.gpsLoc()
            data[IMSI] = self.SierraImsi('12345')
            return

        data[ROUTER] = f'{router}: {aleos.group(0)}'
        # Obtain the data we're interested in.
        with requests.Session() as s:
            login_data = '<request xmlns="urn:acemanager"><connect><login>user</login><password><![CDATA[12345]]></password></connect></request>'
            req = s.post(f'{ip}/xml/Connect.xml', data=login_data, timeout=10)
            if not req.ok:
                self.logger.warning(f'Failed to log in to Sierra: {req}')
                return
            # 7=Device model, 10=IMEI, 671=active frequency band, 785=IMSI, 902=latitude, 903=longitude, 10209=signal quality
            req = s.post(f'{ip}/cgi-bin/Embedded_Ace_Get_Task.cgi', data='7,10,671,785,902,903,10209', timeout=10)
            if not req.ok:
                self.logger.warning(f'Failed to obtain information from Sierra: {req}')
                return
        # Entries are "<id>=<value>" terminated by "!", assumed to be in the requested order. Since the text ends
        # with a "!", the last element of the split is empty.
        for entry, key in zip(req.text.split('!')[:-1], (ROUTER, IMEI, FREQ_BAND, IMSI, LAT, LON, RSRQ)):
            _, sep, value = entry.partition('=')
            if sep:
                data[key] = value
            else:
                self.logger.warning(f'Unexpected entry in Sierra response: {entry!r}.')
        try:
            self.logger.info(f'lat: {data[LAT]}, lon: {data[LON]}')
            # Insert the decimal point before the last five digits.
            data[LAT] = f'{data[LAT][:-5]}.{data[LAT][-5:]}'
            data[LON] = f'{data[LON][:-5]}.{data[LON][-5:]}'
        except KeyError:
            self.logger.warning(f'Could not convert latitude ({data.get(LAT)}) and/or longitude ({data.get(LON)}) to decimal.')

    def _get_teltonika_legacy_info(self, ip: str, data: Dict[str, str]) -> None:
        """Fill *data* from a Teltonika router with legacy (LuCI based) firmware served at base URL *ip*."""
        # This request will return status code 403, but still enough information to identify the router.
        req = requests.get(f'{ip}{self._ROUTER_MARKERS["Teltonika"]}', timeout=10)
        # Double check that it really is Teltonika: the entry page has no unique identification, so the type was only
        # assumed from the redirect path.
        if req.text.find('Teltonika') == -1:
            self.logger.warning('Unrecognised router (assumed Teltonika), IMSI and GPS information not obtained.')
            return
        # NOTE(review): credentials are hard-coded; consider moving them to the module configuration.
        payload = {
            'username': 'admin',
            'password': 'Siwim323'
        }
        data[ROUTER] = 'Teltonika'
        mobile = requests.post(f'{ip}/cgi-bin/luci/admin/status/netinfo/mobile?status=1', data=payload, timeout=10).json()
        data[IMEI] = mobile['gimei']
        data[FREQ_BAND] = mobile['bands']
        data[IMSI] = mobile['gimsi']
        data[RSRQ] = mobile['grsrq']
        gps = requests.post(f'{ip}/cgi-bin/luci/admin/services/gps/get_cord', data=payload, timeout=10).json()
        data[LAT] = gps['latitude']
        data[LON] = gps['longitude']

    def gpsLoc(self):
        """Read the legacy Sierra GPS stream (TCP port 9494) and return the position in decimal degrees.

        Returns:
            ``(lat, lon)`` as floats (negated for S/W), or ``(0, 0)`` if no usable GPGGA sentence was read.
        """
        data = 'No GPS data'
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:  # Closed even if connect/recv fails.
                s.settimeout(5.0)
                s.connect(('192.168.4.100', 9494))
                data = s.recv(1024).decode(errors='replace')
        except OSError:
            self.logger.debug('Obtaining legacy SierraGpsLoc failed', exc_info=True)

        match = self._GPGGA_RE.search(data)
        if match is None:
            return (0, 0)
        lat_field, north_south, lon_field, east_west = match.group(2, 3, 4, 5)
        try:
            # Latitude is ddmm.mmmm and longitude dddmm.mmmm: whole degrees followed by decimal minutes.
            lat = float(lat_field[:2]) + float(lat_field[2:]) / 60
            lon = float(lon_field[:3]) + float(lon_field[3:]) / 60
        except ValueError:  # Empty or truncated fields.
            self.logger.debug(f'Could not parse GPGGA sentence: {match.group(0)}')
            return (0, 0)
        if north_south != 'N':
            lat = -lat
        if east_west != 'E':
            lon = -lon
        return (lat, lon)

    def _recvfind(self, sock, n, to_sec=5):
        """Receive from *sock* until *n* appears in the received text, the peer closes, or *to_sec* seconds pass.

        Returns:
            The received text with every character except ASCII letters, digits, spaces and newlines removed.
        """
        received = ''
        deadline = datetime.datetime.now() + timedelta(seconds=to_sec)
        while n not in received and datetime.datetime.now() < deadline:
            chunk = sock.recv(256)
            if not chunk:  # Peer closed the connection; nothing more will arrive.
                break
            received += chunk.decode(errors='replace')
        return re.sub(r'[^a-zA-Z0-9 \n]', '', received)

    def SierraImsi(self, pwd):
        """Obtain the IMSI from a legacy Sierra router via its AT command port (TCP 2332).

        Args:
            pwd: Password sent after the (optional) login prompt.

        Returns:
            The IMSI string, or 'No IMSI' if it could not be obtained.
        """
        imsi = 'No IMSI'
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:  # Closed even if any step fails.
                s.settimeout(5.0)
                s.connect(('192.168.4.100', 2332))
                if 'login' in self._recvfind(s, ':'):
                    s.sendall(b'user\r\n')
                    self._recvfind(s, ':')
                s.sendall((pwd + '\r\n').encode())
                self._recvfind(s, 'OK')
                s.sendall(b'at+cimi?\r\n')
                tokens = self._recvfind(s, 'OK').split()
                # The IMSI is taken to be the token right before the trailing "OK".
                if len(tokens) > 1:
                    imsi = tokens[-2]
        except OSError:
            self.logger.debug('Obtaining legacy SierraImsi failed', exc_info=True)
        return imsi

    def ntpsvr(self, loc):
        """Return the first NTP server configured in the ntp.conf file at *loc*.

        ``server`` entries take precedence over ``pool`` entries.

        Returns:
            The server/pool address, 'missing' if the file does not exist, or 'unknown' if no entry was found.
        """
        if not os.path.exists(loc):
            return 'missing'
        with open(loc, encoding='utf-8') as f:
            entries = [line.split() for line in f]
        for keyword in ('server', 'pool'):
            for tokens in entries:
                if len(tokens) > 1 and tokens[0].startswith(keyword):
                    return tokens[1]
        return 'unknown'

    def _update_mcp_indicator(self):
        """Set DTR of the MK4 port to whether siwim_mcp.exe is running (MK4 / Windows only)."""
        port = self._mk4_port()
        if port is None:  # The port does not exist if there's more than one status module.
            return
        try:
            output = subprocess.check_output(['TASKLIST', '/FI', 'imagename eq siwim_mcp.exe']).decode()
            last_line = output.strip().split('\r\n')[-1]
            port.dtr = last_line.lower().startswith('siwim_mcp.exe')
        except (subprocess.SubprocessError, OSError, UnicodeDecodeError):
            self.logger.warning('Failed to get info on SiWIM-E.')
            try:
                port.dtr = False
            except OSError:  # pyserial's SerialException derives from IOError/OSError.
                self.logger.debug('Could not reset DTR.', exc_info=True)

    def swmstat(self):
        """Refresh ``config.status_dict`` with the current system status.

        Items are collected independently and on a best-effort basis: a failure is logged and the remaining items
        are still collected. Only an unreadable global configuration aborts the refresh. The whole refresh runs while
        holding ``config.lock``; the insertion order of keys determines the element order of the posted XML.
        """
        siwim = Siwim(siwim_root=config.sites_dir.parent)
        with config.lock:
            status = config.status_dict
            try:
                status[GLOBAL_UPDATE_CHANNEL] = read_conf(Path(config.conf_dir, CONF_GLOBAL))[SEC_DEFAULT].get(GLOBAL_UPDATE_CHANNEL)
            except ConfError as e:
                self.logger.critical(f'{e} It is recommended you restore latest.zip from sites/<site>/usr folder.')
                return
            try:
                # subprocess.DEVNULL instead of open(os.devnull), which leaked a file handle on every call.
                status['usb_info'] = subprocess.check_output([sys.executable, 'get_usb_info.py'], stderr=subprocess.DEVNULL).decode()
            except Exception:
                self.logger.warning('USB info not saved.')
            status['timestamp'] = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
            try:
                status['computer_name'] = subprocess.check_output(['hostname']).strip().decode().upper()
            except Exception:
                self.logger.warning('Computer name not saved.')
            # Map status keys to partitions depending on platform; shared by the free and total space loops below.
            if platform.system() == 'Windows':
                partitions = {'disk_c': 'c:', 'disk_d': 'd:', 'disk_e': 'e:'}
            else:
                partitions = {'disk_c': '/', 'disk_d': str(config.sites_dir.parent), 'disk_e': '/'}
            for key, partition in partitions.items():
                try:
                    status[f'{key}_free'] = get_free_mb(partition)
                except Exception:
                    self.logger.warning(f'Device {partition} not found.')
            try:
                status[SITE_NAME] = config.site_name
            except Exception:
                self.logger.warning('Site name not saved.')
            try:
                status['vpn_status'] = self.VPNstatus()
                port = self._mk4_port()
                if port is not None:  # Without a port, RTS is being set in another Status module, so we don't care.
                    port.rts = status['vpn_status'] == 'Connected'
            except Exception:
                self.logger.warning('VPN status not saved.')
            try:
                status[SIWIM_E_VERSION] = siwim.version_engine
            except Exception:
                self.logger.warning('E version not saved.')
            for key, partition in partitions.items():
                try:
                    status[f'{key}_total'] = get_total_mb(partition)
                except Exception:
                    self.logger.warning(f'Total MB not saved for {partition}.')
            try:
                status.update(self.getRouterInfo())
            except Exception as e:
                self.logger.warning(f'Unhandled error when obtaining router info: {e}.')
            try:
                if self.ntp_conf_path is not None:
                    ntp_conf = self.ntp_conf_path
                elif platform.system() == 'Windows':
                    ntp_conf = "c:/windows/system32/drivers/etc/ntp.conf"
                else:
                    ntp_conf = "/etc/ntp.conf"
                status['ntp_server'] = self.ntpsvr(ntp_conf)
            except Exception:
                self.logger.warning('NTP server not saved')
            try:
                status['openvpn_version'] = self.w32ovpn_version()
            except Exception:
                self.logger.warning('OpenVPN not saved.')
            if self._is_mk4():
                # Check if task siwim_mcp.exe is running and reflect it on DTR.
                self._update_mcp_indicator()
            # Add all keys that have been sent upstream.
            status.update(self.upstream_info)

    def _build_status_xml(self, status: dict):
        """Build the ``<swmstat>`` element from a snapshot of the status dict.

        Dict values become ``<{key[1]}>`` elements carrying the dict's items as attributes, grouped under a
        ``<{key[1]}s>`` container. Other values become ``<name>str(value)</name>``, where ``name`` is ``key[1]`` for
        tuple keys (module name omitted) and ``key`` otherwise. ``<siwim_system_voltage>-1`` is added if missing.
        """
        root = etree.Element("swmstat")
        voltage_present = False
        for key, val in status.items():
            if isinstance(val, dict):
                # NOTE(review): key[1] assumes dict values are stored under (module, name) tuple keys.
                tag = key[1]
                nest = root.find(tag + "s")
                if nest is None:
                    nest = etree.Element(tag + "s")
                    root.append(nest)
                node = etree.Element(tag)
                nest.append(node)
                for attr, attr_val in val.items():
                    try:
                        node.attrib[attr] = attr_val
                    except TypeError:
                        self.logger.error(f'Failed to process {attr}:{attr_val}')
            else:
                # If the key is a tuple, we're dealing with data from a module where module name should be omitted.
                name = key[1] if isinstance(key, tuple) else key
                node = etree.Element(name)
                node.text = str(val)
                root.append(node)
                if name == "siwim_system_voltage":
                    voltage_present = True
        # For non-standard systems that don't yield voltage info, despite it being mandatory for a valid register
        # event (alive). If this piece of info ever stops being mandatory, remove this hack.
        if not voltage_present:
            node = etree.Element("siwim_system_voltage")
            node.text = "-1"
            root.append(node)
        return root

    def _report_status(self):
        """Run one cycle: refresh the status, serialise it, optionally save it to disk and POST it upstream."""
        self.swmstat()
        # A copy is made, so that any changes during processing don't interrupt.
        root = self._build_status_xml(config.status_dict.copy())
        xml = etree.tostring(root, encoding='utf-8').decode()
        if self.save_curr_to_disk == 1:
            self.logger.info('Writing status to disk ...')
            with open(Path(config.log_dir, 'swmstat.xml'), 'w', encoding='utf-8') as fd:
                fd.write(xml)
        with config.lock:
            computer_name = config.status_dict['computer_name'].lower()
        # NOTE(review): verify=False disables TLS certificate verification for this upload; confirm it is intended.
        resp = requests.post(f'{self.protocol}://{self.host}{self.url}', {'sys': computer_name, 'msg': xml}, timeout=10, verify=False)
        if not resp.ok:
            self.logger.warning(f'Response from {self.host}: {resp.text}')

    def run(self):
        """Main loop: report status, then wait ``self.interval`` via ``zzzzz``, until ``self.end`` is set or the module stops."""
        self.alive = True
        self.end = False
        try:
            while True:
                if self.end:
                    self.alive = False
                    self.logger.debug('Thread closed correctly.')
                    return
                try:
                    self._report_status()
                except (socket.error, requests.ConnectTimeout) as e:
                    self.logger.warning(f'Connection attempt failed ({e}).')
                except Exception:
                    self.logger.exception('General error:')
                self.zzzzz(self.interval)
        except StopModule as e:
            self.log_stop_module(e)
        except Exception:
            self.logger.exception('Fatal error:')
            self.alive = False
        finally:
            # Release the MK4 indicator port however the loop ends.
            self._close_mk4_serial()
