from .Framework import SocketService
import ssl


class SMTPSocket(SocketService):
    """
    SocketService that runs one short SMTP conversation and returns its transcript.

    The conversation is: read the server greeting, send ``EHLO localhost``,
    optionally upgrade the connection with ``STARTTLS`` (followed by a second
    ``EHLO``), and finally send ``QUIT``.

    TODO: check the status code of each reply instead of matching text.
    For example, the initial connection returns 220 on each line, EHLO does 250,
    and QUIT does 221.
    """

    def __init__(self, port=None, tls=False, starttls=False):
        """
        :param port: TCP port to use. When falsy a default is chosen: 465 when
            ``tls`` is set without ``starttls``, 25 otherwise.
        :param tls: stored as ``self.tls`` (forced to False when ``starttls`` is
            set). NOTE(review): presumably read by the base class's ``_connect``
            to open an implicit-TLS connection; confirm in Framework.
        :param starttls: upgrade the plaintext session with the STARTTLS command
            inside ``_get_raw_data``.
        """
        SocketService.__init__(self)
        if port:
            self.port = port
        elif tls and not starttls:
            self.port = 465
        else:
            self.port = 25

        self.request = 'EHLO localhost\r\n'.encode()
        # STARTTLS takes precedence: implicit TLS is only enabled when STARTTLS is off.
        self.tls = tls and not starttls
        self.starttls = starttls

    def _get_data(self):
        """
        Run one SMTP conversation and reduce it to a 0/1 flag.

        :return: 1 when the transcript contains both ``'Hello'`` and
            ``'closing connection'``, otherwise 0 (also when nothing was read).
        """
        # NOTE(review): the Zabbix triggers reportedly use "0 is good, 1 is bad",
        # but this returns 1 when the expected greeting/QUIT text IS present, and
        # check() returns bool(self.get_data()). If get_data() wraps this method,
        # a healthy server therefore makes check() succeed. Confirm which
        # convention the triggers expect before changing either side.
        # NOTE(review): 'Hello' and 'closing connection' are server-specific reply
        # texts; matching reply codes (250 for EHLO, 221 for QUIT) would be more
        # portable across MTAs.
        data = self._get_raw_data()
        if data and 'Hello' in data and 'closing connection' in data:
            return 1
        return 0

    def check(self):
        """Return True when ``get_data()`` yields a truthy result."""
        data = self.get_data()
        return bool(data)

    def _get_raw_data(self, raw=False):
        """
        Run a complete SMTP session and return the concatenated server replies.

        Sequence: read the greeting, ``EHLO localhost``, optionally ``STARTTLS``
        plus a TLS handshake and a second ``EHLO``, then ``QUIT``. A reply is
        appended only when the corresponding command was sent successfully.

        The socket is always closed and reset to ``None`` afterwards, because
        ``QUIT`` ends the SMTP session and the next call must reconnect.

        :param raw: set `True` to return bytes
        :type raw: bool
        :return: decoded data (str) or raw data (bytes); None when no
            connection could be established
        :rtype: str/bytes
        """
        if self._sock is None:
            self._connect()
            if self._sock is None:
                return None

        try:
            # Initial server greeting.
            data = self._receive(raw)

            # request is reset on every call because the previous call left it at QUIT.
            self.request = 'EHLO localhost\r\n'.encode()
            if self._send():
                data += self._receive(raw)

            if self.starttls:
                self.request = 'STARTTLS\r\n'.encode()
                if not self._send():
                    # The upgrade request never reached the server, so a TLS
                    # handshake on this socket cannot succeed; report what we have.
                    return data
                data += self._receive(raw)

                self._sock = self._wrap_tls(self._sock)

                # RFC 3207: after the TLS upgrade the client re-issues EHLO.
                self.request = 'EHLO localhost\r\n'.encode()
                if self._send():
                    data += self._receive(raw)

            self.request = 'QUIT\r\n'.encode()
            if self._send():
                data += self._receive(raw)

            return data
        finally:
            # Also runs if the TLS handshake raises, so a half-upgraded socket
            # is never reused on the next poll.
            self._close_session()

    def _wrap_tls(self, sock):
        """
        Upgrade ``sock`` to a client-side TLS socket without certificate checks.

        This replaces ``ssl.wrap_socket`` (deprecated, removed in Python 3.12)
        with an ``SSLContext`` configured the same way as the old call:
        ``CERT_NONE``, no hostname verification, and an optional client
        certificate from ``self.cert`` / ``self.key``.
        """
        context = ssl.create_default_context()
        # check_hostname must be disabled before verify_mode can be set to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        if self.cert:
            context.load_cert_chain(certfile=self.cert, keyfile=self.key)
        return context.wrap_socket(sock, server_side=False)

    def _close_session(self):
        """Close the current socket, if any, and clear it so the next call reconnects."""
        # NOTE(review): if the base class provides its own disconnect helper,
        # prefer calling it here.
        if self._sock is not None:
            self._sock.close()
            self._sock = None
