# encoding: utf-8

import datetime
from hashlib import sha256
from hashlib import sha1
import hmac
import logging
from email.utils import formatdate
from urllib.parse import urlencode
from base64 import encodebytes

# strftime layout of the Signature V4 timestamp (x-amz-date header and the
# first 8 characters used in the credential scope), e.g. '20200828T064113Z'.
SIGV4_TIMESTAMP = '%Y%m%dT%H%M%SZ'
logger = logging.getLogger(__name__)

# ISO 8601 strftime layout (not referenced by the signers in this module).
ISO8601 = '%Y-%m-%dT%H:%M:%SZ'
# Alternate public name for the same layout as SIGV4_TIMESTAMP.
SIGN_V4_TIMESTAMP = SIGV4_TIMESTAMP
# Payload hash sent in x-amz-content-sha256 when the body is not signed.
UNSIGNED_PAYLOAD = 'UNSIGNED-PAYLOAD'
# Fixed timestamp substituted for the current time when USE_TEST_TIME is
# true, so that signatures are reproducible.
Test_TIME = '20200828T064113Z'
USE_TEST_TIME = False
# When true, intermediate signing values are printed to stdout.
DEBUG_SIGN = False


def oos_url_encode(input_dict, path=False):
    """Percent-encode query parameters for the OOS signers.

    Starts from :func:`urllib.parse.urlencode` (form-encoding rules) and
    then adjusts the result toward RFC 3986 style:
    a space becomes ``%20`` instead of ``+``, ``*`` is ``%2A``, and ``~``
    stays literal.

    :param input_dict: mapping (or sequence of pairs) of names to values.
        ``None`` or an empty container yields ``''``.
    :param path: when true, ``/`` is left literal (``%2F`` is turned back
        into ``/``).
    :return: the ``key=value&...`` string, in the iteration order of
        *input_dict*.
    """
    if not input_dict:
        return ''
    encoded = urlencode(input_dict, encoding='utf-8')
    if DEBUG_SIGN:
        print("urlencode:", encoded)
    # urlencode() quotes with quote_plus, which renders a space as '+'
    # (a literal '+' is already '%2B'), so '+' can safely become '%20'.
    encoded = encoded.replace("+", "%20").replace("*", "%2A").replace("%7E", "~")
    if path:
        # str.replace returns a new string; it must be reassigned, otherwise
        # path=True silently has no effect.
        encoded = encoded.replace("%2F", "/")
    if DEBUG_SIGN:
        print("replace:", encoded)
    return encoded


class SignV2Auth:
    """Sign requests with the HMAC-SHA1 ``AWS <access_key>:<signature>`` scheme.

    *credentials* must provide ``access_key`` and ``secret_key`` strings.
    """

    def __init__(self, credentials):
        self._credentials = credentials

    def sign_string(self, string_to_sign):
        """Return the base64 HMAC-SHA1 of *string_to_sign*, keyed by the secret key."""
        new_hmac = hmac.new(self._credentials.secret_key.encode('utf-8'),
                            digestmod=sha1)
        new_hmac.update(string_to_sign.encode('utf-8'))
        # encodebytes() terminates its output with a newline; strip it.
        return encodebytes(new_hmac.digest()).strip().decode('utf-8')

    def canonical_standard_headers(self, headers):
        """Return the Content-MD5, Content-Type and Date lines of the StringToSign.

        Side effect: every existing Date header (in any letter case) is
        removed from *headers*, and ``Date`` is set to the current GMT time,
        so the signed value is the value that gets sent.

        Exactly one line is produced per header, in the order content-md5,
        content-type, date. A header that is absent or ``None`` contributes
        an empty line. If several case variants of one name are present, the
        first non-``None`` one is used.
        """
        # Remove all case variants, not just 'Date': with a plain dict a stale
        # 'date' key would otherwise survive and be signed as an extra line.
        for stale in [key for key in headers if key.lower() == 'date']:
            del headers[stale]
        headers['Date'] = formatdate(usegmt=True)

        lines = []
        for wanted in ('content-md5', 'content-type', 'date'):
            value = ''
            for key in headers:
                if key.lower() == wanted and headers[key] is not None:
                    value = headers[key].strip()
                    break
            lines.append(value)
        return '\n'.join(lines)

    def canonical_custom_headers(self, headers):
        """Return the ``x-amz-*`` headers as sorted ``name:value`` lines.

        Names are lower-cased, and headers whose value is ``None`` are skipped.
        Returns ``''`` when there are none.
        """
        custom_headers = {}
        for key in headers:
            value = headers[key]
            lowered = key.lower()
            if value is not None and lowered.startswith('x-amz-'):
                custom_headers[lowered] = value
        return '\n'.join('%s:%s' % (name, custom_headers[name])
                         for name in sorted(custom_headers))

    def add_auth(self, request):
        """Sign *request* and set its Authorization header.

        *request* needs ``method`` and a mutable ``headers`` mapping. Date and
        Authorization are set on the headers in place.
        """
        logger.debug("Calculating signature using hmacv1 auth.")
        string_to_sign = request.method.upper() + '\n'
        string_to_sign += self.canonical_standard_headers(request.headers) + '\n'
        custom_headers = self.canonical_custom_headers(request.headers)
        if custom_headers:
            string_to_sign += custom_headers + '\n'

        # NOTE(review): the canonical resource is hard-coded to '/', so no
        # bucket/key path is signed — confirm this matches the service.
        string_to_sign += "/"

        logger.debug('StringToSign:\n%s', string_to_sign)

        signature = self.sign_string(string_to_sign)

        self._inject_signature(request, signature)

    def _inject_signature(self, request, signature):
        """Replace any Authorization header with ``AWS <access_key>:<signature>``."""
        if 'Authorization' in request.headers:
            del request.headers['Authorization']
        request.headers['Authorization'] = (
                "AWS %s:%s" % (self._credentials.access_key, signature))


class SignV4Auth:
    """
    Sign a request with Signature V4 (AWS4-HMAC-SHA256).

    The request passed to :meth:`add_auth` is expected to expose ``method``,
    a mutable ``headers`` mapping, a ``params`` mapping of query parameters
    and, when ``sign_payload`` is true, the body as ``data``.
    """
    REQUIRES_REGION = True

    def __init__(self, credentials, service_name, region_name, request, sign_payload=False):
        """
        :param credentials: object with ``access_key`` and ``secret_key`` strings.
        :param service_name: service component of the credential scope.
        :param region_name: region component of the credential scope.
        :param request: stored as ``self.request``. :meth:`add_auth` signs the
            request passed to it, not this attribute.
        :param sign_payload: hash the body into the signature instead of
            using the unsigned-payload marker.
        """
        # A single timestamp is captured here. It is reused for the
        # x-amz-date header, the credential scope and the signing key, so all
        # three always agree.
        datetime_now = datetime.datetime.now(datetime.timezone.utc)
        self._timestamp = datetime_now.strftime(SIGV4_TIMESTAMP)
        if USE_TEST_TIME:
            self._timestamp = Test_TIME
        self.credentials = credentials
        self._region_name = region_name
        self._service_name = service_name
        self.request = request
        self._sign_payload = sign_payload

    def _sign(self, key, msg, hex=False):
        """HMAC-SHA256 *msg* (str, UTF-8 encoded) with *key* (bytes).

        Returns the hex digest when *hex* is true, otherwise the raw digest bytes.
        """
        mac = hmac.new(key, msg.encode('utf-8'), sha256)
        return mac.hexdigest() if hex else mac.digest()

    def headers_to_sign(self, request):
        """Return ``{lower-case name: value}`` for the headers covered by the signature.

        Always includes content-type, host and user-agent (read directly, so
        a missing one raises, e.g. ``KeyError`` for dict headers), plus every
        ``x-amz-*`` header.
        """
        header_map = {'content-type': request.headers['Content-Type'], 'host': request.headers['Host'],
                      'user-agent': request.headers['User-Agent']}
        for name, value in request.headers.items():
            lname = name.lower()
            if lname.startswith('x-amz-'):
                header_map[lname] = value

        return header_map

    def canonical_headers(self, headers_to_sign):
        """
        Return the headers that need to be included in the canonical request:
        ``name:value`` pairs sorted by name and joined with newlines.
        Keys are used as given (headers_to_sign already lower-cases them).
        """
        return '\n'.join('%s:%s' % (key, headers_to_sign.get(key))
                         for key in sorted(set(headers_to_sign)))

    def canonical_query_string_org(self, request):
        """Return sorted ``key=value`` query pairs joined by ``&``, without any encoding.

        Older variant of :meth:`canonical_query_string`.
        """
        queries = []
        for key in sorted(set(request.params)):
            queries.append('%s=%s' % (key, request.params.get(key)))
        return '&'.join(queries)

    def canonical_query_string(self, request):
        """Return the query parameters sorted by name and encoded by ``oos_url_encode``."""
        query = {key: request.params.get(key) for key in sorted(set(request.params))}
        return oos_url_encode(query)

    def signed_headers(self, headers_to_sign):
        """Return the sorted, lower-cased header names joined by ``;``."""
        names = sorted('%s' % n.lower().strip() for n in set(headers_to_sign))
        return ';'.join(names)

    def calc_body_checksum(self, request):
        """Return the hex SHA-256 of ``request.data``.

        Bytes-like bodies are hashed as-is, and ``str`` bodies are UTF-8
        encoded first. ``None`` or any other empty body (e.g. ``[]``) hashes
        as the empty payload.

        :raises TypeError: for a non-empty body of any other type.
        """
        data = request.data
        if data is None:
            body = b''
        elif isinstance(data, str):
            body = data.encode('utf-8')
        elif isinstance(data, (bytes, bytearray, memoryview)):
            # Hash the raw bytes: decoding would reject non-UTF-8 payloads.
            body = data
        elif not data:
            body = b''
        else:
            raise TypeError('cannot compute payload checksum for body of type %s'
                            % type(data).__name__)
        checksum = sha256(body).hexdigest()
        if DEBUG_SIGN:
            print("checksum:", checksum)
        return checksum

    def canonical_request(self, request):
        """Build the canonical request string.

        Lines: method, URI (always ``/``), canonical query string,
        canonical headers followed by a blank line, signed header names, and
        the payload hash.

        Side effect: sets ``x-amz-content-sha256`` on ``request.headers``
        before collecting the signed headers, so that header is signed too.
        """
        cr = [request.method.upper()]
        uri = "/"
        query_string = self.canonical_query_string(request)
        if not self._sign_payload:
            body_checksum = UNSIGNED_PAYLOAD
        else:
            body_checksum = self.calc_body_checksum(request)
        request.headers['x-amz-content-sha256'] = body_checksum
        cr.append(uri)
        cr.append(query_string)

        # Signed headers must include content-type, host and the x-amz-* headers.
        headers_to_sign = self.headers_to_sign(request)
        cr.append(self.canonical_headers(headers_to_sign) + '\n')
        cr.append(self.signed_headers(headers_to_sign))
        cr.append(body_checksum)

        return '\n'.join(cr)

    def scope(self, request):
        """Return ``<access_key>/<credential scope>`` for the ``Credential=`` field."""
        return '%s/%s' % (self.credentials.access_key, self.credential_scope(request))

    def credential_scope(self, request):
        """Return ``<yyyymmdd>/<region>/<service>/aws4_request``."""
        scope = [self._timestamp[0:8], self._region_name, self._service_name, 'aws4_request']
        return '/'.join(scope)

    def string_to_sign(self, request, canonical_request):
        """
        Return the StringToSign: the algorithm name, timestamp, credential
        scope and hex SHA-256 of *canonical_request*, joined by newlines.
        """
        sts = ['AWS4-HMAC-SHA256', self._timestamp, self.credential_scope(request),
               sha256(canonical_request.encode('utf-8')).hexdigest()]
        return '\n'.join(sts)

    def signature(self, string_to_sign, request):
        """Derive the date/region/service signing key and return the hex signature."""
        key = self.credentials.secret_key
        k_date = self._sign(('AWS4' + key).encode('utf-8'),
                            self._timestamp[0:8])
        k_region = self._sign(k_date, self._region_name)
        k_service = self._sign(k_region, self._service_name)
        k_signing = self._sign(k_service, 'aws4_request')
        return self._sign(k_signing, string_to_sign, hex=True)

    def add_auth(self, request):
        """Sign *request* in place and return it.

        Sets ``x-amz-date``, ``x-amz-content-sha256`` and ``Authorization``
        on ``request.headers``.
        """
        # Use exactly the timestamp the signature is computed from (the
        # USE_TEST_TIME substitution already happened in __init__). Overriding
        # the header separately could make header and signature disagree.
        request.headers['x-amz-date'] = self._timestamp
        canonical_request = self.canonical_request(request)
        if DEBUG_SIGN:
            print("Calculating signature using v4 auth.")
            print('CanonicalRequest:\n', canonical_request)

        string_to_sign = self.string_to_sign(request, canonical_request)
        if DEBUG_SIGN:
            print('StringToSign:\n', string_to_sign)
        signature = self.signature(string_to_sign, request)
        if DEBUG_SIGN:
            print('Signature:\n', signature)

        headers_to_sign = self.headers_to_sign(request)
        request.headers['Authorization'] = ', '.join([
            'AWS4-HMAC-SHA256 Credential=%s' % self.scope(request),
            'SignedHeaders=%s' % self.signed_headers(headers_to_sign),
            'Signature=%s' % signature,
        ])
        if DEBUG_SIGN:
            print(request.headers['Authorization'])
        return request
