# Don't trust me with cryptography.

"""
Implementation of https://gist.github.com/RubenSomsen/be7a4760dd4596d06963d67baf140406

Bob (Mint):
A = a*G
return A

Alice (Client):
Y = hash_to_curve(secret_message)
r = random blinding factor
B'= Y + r*G
return B'

Bob:
C' = a*B'
  (= a*Y + a*r*G)
return C'

Alice:
C = C' - r*A
 (= C' - a*r*G)
 (= a*Y)
return C, secret_message

Bob:
Y = hash_to_curve(secret_message)
C == a*Y
If true, C must have originated from Bob


# DLEQ Proof

(These steps occur once Bob returns C')

Bob:
p = random nonce (distinct from Alice's blinding factor r)
R1 = p*G
R2 = p*B'
e = hash(R1,R2,A,C')
s = p + e*a
return e, s

Alice:
R1 = s*G - e*A
R2 = s*B' - e*C'
e == hash(R1,R2,A,C')

If true, a in A = a*G must be equal to a in C' = a*B'
"""

import hashlib
from typing import Optional, Tuple

from secp256k1 import PrivateKey, PublicKey

DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_"


def hash_to_curve(message: bytes) -> PublicKey:
    """Deterministically map ``message`` to a point on secp256k1.

    The message is prefixed with ``DOMAIN_SEPARATOR`` and hashed once with
    SHA-256 to obtain a seed. For counter = 0, 1, 2, ... the seed is hashed
    again together with the counter encoded as a 4-byte little-endian uint32.
    The resulting digest is tried as the x-coordinate of a compressed point
    with prefix 0x02. The first candidate that lies on the curve is returned.

    Each attempt succeeds with a chance of about 50%. At most 2**16 counters
    are tried, after which a ValueError is raised (this should never happen
    in practice).

    The domain separator is b"Secp256k1_HashToCurve_Cashu_" or
    bytes.fromhex("536563703235366b315f48617368546f43757276655f43617368755f").

    Raises:
        ValueError: if none of the 2**16 candidates is a valid curve point.
    """
    seed = hashlib.sha256(DOMAIN_SEPARATOR + message).digest()
    for attempt in range(2**16):
        candidate = hashlib.sha256(seed + attempt.to_bytes(4, "little")).digest()
        try:
            # PublicKey raises when the candidate x-coordinate is not on the curve
            return PublicKey(b"\x02" + candidate, raw=True)
        except Exception:
            continue
    # every candidate was rejected; not expected to happen in practice
    raise ValueError("No valid point found")


def step1_alice(
    secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
    """Blind ``secret_msg`` before sending it to the mint (Alice's step 1).

    Computes Y = hash_to_curve(secret_msg) and B_ = Y + r*G.

    Args:
        secret_msg: the secret to blind; UTF-8 encoded before hashing.
        blinding_factor: optional caller-supplied r. A fresh random
            PrivateKey is generated when it is not given.

    Returns:
        ``(B_, r)``: the blinded message and the blinding factor. Alice needs
        r later to unblind the mint's response (see ``step3_alice``).
    """
    y_point: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
    r = blinding_factor if blinding_factor else PrivateKey()
    return y_point + r.pubkey, r  # type: ignore


def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
    """Sign a blinded message with the mint's private key (Bob's step 2).

    Args:
        B_: the blinded message produced by ``step1_alice``.
        a: the mint's private key.

    Returns:
        ``(C_, e, s)``:
        - ``C_ = a*B_`` is the blind signature.
        - ``(e, s)`` is a DLEQ proof from ``step2_bob_dleq`` showing that the
          same ``a`` underlies both ``A = a*G`` and ``C_``.
    """
    blind_signature: PublicKey = B_.mult(a)  # type: ignore
    proof = step2_bob_dleq(B_, a)
    e, s = proof
    return blind_signature, e, s


def step3_alice(C_: PublicKey, r: PrivateKey, A: PublicKey) -> PublicKey:
    """Unblind the mint's signature (Alice's step 3).

    Returns C = C_ - r*A. With C_ = a*(Y + r*G) and A = a*G, this equals a*Y.
    """
    blinding_term: PublicKey = A.mult(r)  # type: ignore
    return C_ - blinding_term  # type: ignore


def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
    """Check that ``C`` is the mint's signature on ``secret_msg`` under key ``a``.

    The primary check is C == a*hash_to_curve(secret_msg). If it fails, the
    check is retried with the deprecated hash_to_curve via
    ``verify_deprecated``, for backwards compatibility with versions < 0.15.1.
    """
    y_point: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
    is_valid = C == y_point.mult(a)  # type: ignore
    # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
    # the deprecated check only runs when the current one fails
    return is_valid or verify_deprecated(a, C, secret_msg)
    # END: BACKWARDS COMPATIBILITY < 0.15.1


def hash_e(*publickeys: PublicKey) -> bytes:
    """Compute the DLEQ challenge ``e`` from a sequence of public keys.

    Each key is serialized uncompressed and hex-encoded. The hex strings are
    concatenated in argument order, and the SHA-256 digest of that UTF-8
    encoded text is returned.

    Note: the *hex text* is hashed, not the raw serialized bytes. Both the
    prover (``step2_bob_dleq``) and the verifier (``alice_verify_dleq``) rely
    on this exact encoding, so it must not be changed.
    """
    joined_hex = "".join(p.serialize(compressed=False).hex() for p in publickeys)
    return hashlib.sha256(joined_hex.encode("utf-8")).digest()


def step2_bob_dleq(
    B_: PublicKey, a: PrivateKey, p_bytes: bytes = b""
) -> Tuple[PrivateKey, PrivateKey]:
    """Produce a DLEQ proof that C_ = a*B_ uses the same ``a`` as A = a*G.

    With nonce p:
        R1 = p*G
        R2 = p*B_
        e  = hash_e(R1, R2, A, C_)
        s  = p + e*a

    Args:
        B_: the blinded message received from Alice.
        a: the mint's private key.
        p_bytes: optional raw nonce bytes, giving deterministic output for
            testing. When empty, a fresh random nonce is generated.

    Returns:
        ``(e, s)``, each wrapped as a PrivateKey so it can be used as a scalar.
    """
    if p_bytes:
        # deterministic p for testing
        p = PrivateKey(privkey=p_bytes, raw=True)
    else:
        # normally, we generate a random p
        p = PrivateKey()

    # R1 = p*G
    R1 = p.pubkey
    assert R1  # narrows an Optional pubkey for type checkers
    # R2 = p*B_
    R2: PublicKey = B_.mult(p)  # type: ignore
    # C_ = a*B_
    C_: PublicKey = B_.mult(a)  # type: ignore
    # A = a*G
    A = a.pubkey
    assert A  # narrows an Optional pubkey for type checkers
    # e = hash(R1, R2, A, C_)
    e = hash_e(R1, R2, A, C_)
    # s = p + e*a; tweak_mul/tweak_add return raw scalar bytes, wrapped below
    s = p.tweak_add(a.tweak_mul(e))
    spk = PrivateKey(s, raw=True)
    epk = PrivateKey(e, raw=True)
    return epk, spk


def alice_verify_dleq(
    B_: PublicKey, C_: PublicKey, e: PrivateKey, s: PrivateKey, A: PublicKey
) -> bool:
    """Verify the mint's DLEQ proof ``(e, s)`` for the blinded message ``B_``.

    Reconstructs
        R1 = s*G  - e*A
        R2 = s*B_ - e*C_
    and returns True iff hash_e(R1, R2, A, C_) equals the bytes of ``e``.
    """
    r1 = s.pubkey - A.mult(e)  # type: ignore
    r2 = B_.mult(s) - C_.mult(e)  # type: ignore
    challenge = e.private_key
    recomputed = hash_e(r1, r2, A, C_)
    return challenge == recomputed


def carol_verify_dleq(
    secret_msg: str,
    r: PrivateKey,
    C: PublicKey,
    e: PrivateKey,
    s: PrivateKey,
    A: PublicKey,
) -> bool:
    """Verify a DLEQ proof as a third party ("Carol") holding the unblinded token.

    Carol re-derives the values Alice exchanged with the mint:
        Y  = hash_to_curve(secret_msg)
        C_ = C + r*A
        B_ = Y + r*G
    and then checks the proof exactly as ``alice_verify_dleq`` does.

    For backwards compatibility with versions < 0.15.1, a failed check is
    retried with the deprecated hash_to_curve.
    """
    y_point: PublicKey = hash_to_curve(secret_msg.encode("utf-8"))
    blind_sig: PublicKey = C + A.mult(r)  # type: ignore
    blinded_msg: PublicKey = y_point + r.pubkey  # type: ignore
    if alice_verify_dleq(blinded_msg, blind_sig, e, s, A):
        return True
    # BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
    return carol_verify_dleq_deprecated(secret_msg, r, C, e, s, A)
    # END: BACKWARDS COMPATIBILITY < 0.15.1


# -------- Deprecated hash_to_curve before 0.15.0 --------


def hash_to_curve_deprecated(message: bytes) -> PublicKey:
    """Legacy hash-to-curve, kept only to verify secrets from older versions.

    Hashes ``message`` with SHA-256 and tries the digest as the x-coordinate
    of a 0x02-prefixed compressed point. If that is not on the curve, the
    digest itself is re-hashed and tried again.

    Unlike ``hash_to_curve``, there is no domain separator and no cap on the
    number of attempts.
    """
    digest = hashlib.sha256(message).digest()
    while True:
        try:
            # PublicKey raises when the candidate is not a valid curve point
            return PublicKey(b"\x02" + digest, raw=True)
        except Exception:
            digest = hashlib.sha256(digest).digest()


def step1_alice_deprecated(
    secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
    """Blind ``secret_msg`` using the deprecated hash_to_curve.

    Same as ``step1_alice``, except that Y = hash_to_curve_deprecated(secret_msg).

    Returns:
        ``(B_, r)`` with B_ = Y + r*G. r is ``blinding_factor`` when given;
        otherwise it is a fresh random PrivateKey.
    """
    y_point: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
    r = blinding_factor if blinding_factor else PrivateKey()
    return y_point + r.pubkey, r  # type: ignore


def verify_deprecated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
    """Check C == a*hash_to_curve_deprecated(secret_msg), with no fallback."""
    y_point: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
    expected: PublicKey = y_point.mult(a)  # type: ignore
    return C == expected


def carol_verify_dleq_deprecated(
    secret_msg: str,
    r: PrivateKey,
    C: PublicKey,
    e: PrivateKey,
    s: PrivateKey,
    A: PublicKey,
) -> bool:
    """Third-party DLEQ check using the deprecated hash_to_curve.

    Same as ``carol_verify_dleq``, except that Y = hash_to_curve_deprecated(...)
    and there is no further fallback.
    """
    y_point: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
    blind_sig: PublicKey = C + A.mult(r)  # type: ignore
    blinded_msg: PublicKey = y_point + r.pubkey  # type: ignore
    return alice_verify_dleq(blinded_msg, blind_sig, e, s, A)


# Below is a test of a simple positive and negative case

# # Bob's (mint) keys
# a = PrivateKey()
# A = a.pubkey
# secret_msg = "test"
# B_, r = step1_alice(secret_msg)
# C_, e, s = step2_bob(B_, a)
# C = step3_alice(C_, r, A)
# print("C:{}, secret_msg:{}".format(C, secret_msg))
# assert verify(a, C, secret_msg)
# assert verify(a, C + C, secret_msg) == False  # adding C twice shouldn't pass
# assert verify(a, A, secret_msg) == False  # A shouldn't pass

# # Test operations
# b = PrivateKey()
# B = b.pubkey
# assert -A -A + A == -A  # neg
# assert B.mult(a) == A.mult(b)  # a*B = A*b
