import unittest
import os
from itp_interface.lean.tactic_parser import build_lean4_project, build_tactic_parser_if_needed

def _print_goals(goals):
    """Print every goal in *goals*: hypotheses, a '|- ' marker, the goal, a star rule."""
    for goal in goals:
        # Hypotheses are printed one per line; an empty list prints a blank line.
        print('\n'.join(goal.hypotheses))
        print('|- ', end='')
        print(goal.goal)
        print('*' * 30)


def pretty_print(s1, s2, proof_step, done):
    """Print a human-readable trace of one proof step to stdout.

    Args:
        s1: State before the step; its ``training_data_format.start_goals``
            is iterated and each goal's ``hypotheses`` / ``goal`` is printed.
        s2: State after the step, read the same way as ``s1``, or ``None``
            to leave the "Next Goal" section empty.
        proof_step: The tactic text shown on the "Action" line.
        done: Flag echoed on the "DONE" line.  When it is truthy and ``s2``
            is ``None`` a closing "proof finished" line is printed as well.
    """
    print("Current Goal:")
    print('-' * 30)
    _print_goals(s1.training_data_format.start_goals)
    print('=' * 30)
    print(f"Action: {proof_step}")
    print('=' * 30)
    print("Next Goal:")
    print('-' * 30)
    if s2 is not None:
        _print_goals(s2.training_data_format.start_goals)
    print('=' * 30)
    print(f"DONE: {done}")
    print('-' * 30)
    if s2 is None and done:
        print("No more goals. Proof Finished!")

class LeanHelper():
    """Small helper that prepares a Lean 4 project before a test uses it."""

    def build_lean4_project(self, project_folder):
        """Make sure the tactic parser and the project at *project_folder* are built.

        ``build_tactic_parser_if_needed()`` is always called.  The module-level
        ``build_lean4_project`` is called only when ``<project_folder>/.lake``
        does not exist yet; an existing ``.lake`` folder is taken to mean the
        project was built before (presumably it is Lake's output directory --
        TODO confirm).
        """
        build_tactic_parser_if_needed()
        lake_dir = os.path.join(project_folder, ".lake")
        if os.path.exists(lake_dir):
            # Nothing to do: a previous build left its folder behind.
            return
        build_lean4_project(project_folder)


class Lean4Test(unittest.TestCase):
    def test_simple_lean4(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
        # theorem test3 (p q : Prop) (hp : p) (hq : q)
        # : p ∧ q ∧ p :=
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
            '-- TODO',
            'apply And.intro',
            'exact hp',
            'apply And.intro',
            '--TODO',
            '-- This is just some extra comment',
            'exact hq',
            'exact hp'
        ]
        with env:
            proof_was_finished = False
            for proof_step in proof_steps:
                state, _, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(info.progress)
                print('-'*30)
                if done:
                    print("Proof Finished!!")
                    proof_was_finished = True
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)
            assert proof_was_finished, "Proof was not finished"

    def test_lean4_backtracking(self):
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        import random
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
        # theorem test3 (p q : Prop) (hp : p) (hq : q)
        # : p ∧ q ∧ p :=
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
            'apply And.intro',
            'exact hp',
            'apply And.intro',
            'exact hq',
            'exact hp'
        ]
        with env:
            prev_state = env.state
            proof_was_finished = False
            for idx, proof_step in enumerate(proof_steps):
                if idx > 0 and random.random() <= 0.5:
                    print(f"Backtracking at step {idx + 1} i.e. {proof_step}")
                    state, _, next_state, _, done, info = env.step(
                    ProofAction(
                        ProofAction.ActionType.BACKTRACK,
                        language))
                    assert next_state == prev_state, "Backtracking failed"
                    # Replay the last action
                    last_proof_step = proof_steps[idx-1]
                    state, _, next_state, _, done, info = env.step(
                        ProofAction(
                            ProofAction.ActionType.RUN_TACTIC,
                            language,
                            tactics=[last_proof_step]))
                state, _, next_state, _, done, info = env.step(
                ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                prev_state = state
                if done:
                    print("Proof Finished!!")
                    proof_was_finished = True
            assert proof_was_finished, "Proof was not finished"

    def test_simple_lean_calc(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
        # theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
"""calc
  _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
  _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
  _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
  _ = n*n + n + n + 1 := by rw [Nat.pow_two]
  _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
  _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
  _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
  _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""",
"_ = (n + 1)*(n + 1) := by \n   rw [Nat.right_distrib n 1 (n + 1)]"
]
        with env:
            env.set_max_proof_step_length(10000)
            proof_was_finished = False
            for proof_step in proof_steps:
                state, _, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(f"DONE: {done}")
                print(info.progress)
                print('-'*30)
                if done:
                    s1 : ProofState = state
                    pretty_print(s1, None, proof_step, done)
                    proof_was_finished = True
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)
            assert proof_was_finished, "Proof was not finished"

    def test_simple_lean_calc_with_validation(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
        # theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
"""calc
_ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
_ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
_ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
_ = n*n + n + n + 1 := by rw [Nat.pow_two]
_ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
_ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
_ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
_ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""",
"_ = (n + 1)*(n + 1) := by \n   rw [Nat.right_distrib n 1 (n + 1)]"
]
        with env:
            env.set_max_proof_step_length(10000)
            proof_was_finished = False
            for proof_step in proof_steps:
                state, _, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(f"DONE: {done}")
                print(info.progress)
                print('-'*30)
                if done:
                    s1 : ProofState = state
                    pretty_print(s1, None, proof_step, done)
                    proof_was_finished = True
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)
            assert proof_was_finished, "Proof was not finished"
            # Run the validation
            val_result = env.validate_proof_completion(timeout_in_secs=60, keep_validation_file=False)
            print("Validation Result:")
            print(val_result)
            assert val_result.get('success', False), f"Proof validation failed:\n{val_result.get('error_message', '')}"
            assert val_result.get('compilation_ok', False), f"Proof validation failed:\n{val_result.get('error_message', '')}"

    def test_simple_lean_enforce_done_test(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
        # theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True,
            enforce_qed=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
"""calc
  _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
  _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
  _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
  _ = n*n + n + n + 1 := by rw [Nat.pow_two]
  _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
  _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
  _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
  _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""",
"_ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)]",
"done"
]
        with env:
            env.set_max_proof_step_length(10000)
            proof_finished = False
            for proof_step in proof_steps:
                state, _, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(f"DONE: {done}")
                print(info.progress)
                print('-'*30)
                if done:
                    assert proof_step == "done", "Proof can only finish with done"
                    s1 : ProofState = state
                    pretty_print(s1, None, proof_step, done)
                    proof_finished = True
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)
            assert proof_finished, "Proof was not finished"

    def test_simple_lean4_done_test(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
        # theorem test3 (p q : Prop) (hp : p) (hq : q)
        # : p ∧ q ∧ p :=
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True,
            enforce_qed=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
            'apply And.intro',
            'exact hp',
            'apply And.intro',
            'exact hq',
            'done'
        ]
        with env:
            for proof_step in proof_steps:
                state, _, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(info.progress)
                print('-'*30)
                if done:
                    raise Exception("Proof should not have finished")
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)

    def test_simple_lean4_have_test(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"imo_1959_p1\"}'
        # theorem test3 (p q : Prop) (hp : p) (hq : q)
        # : p ∧ q ∧ p :=
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True,
            enforce_qed=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
'rw [Nat.gcd_rec]',
'rw [Nat.mod_eq_of_lt (by linarith)]',
'rw [Nat.gcd_rec]',
'rw [Nat.gcd_rec]',
'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
'  have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
'  rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
'  have h₂ : 7 * n + 1 < 14 * n + 3 := by', 'linarith',
'  rw [Nat.mod_eq_of_lt]',
'  rw [Nat.mod_eq_of_lt]',
'  exact h₂',
'  rw [Nat.mod_eq_of_lt]',
'  exact h₂',
'  exact h₂',
'rw [eq₂]'
        ]
        with env:
            env.set_max_proof_step_length(10000)
            for proof_step in proof_steps:
                state, m_action, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                if proof_step == 'linarith' and m_action is not None and isinstance(m_action, ProofAction) and m_action.kwargs.get('modified', False):
                    print("Modified action detected:")
                    print(m_action)
                    modified_tac = m_action.kwargs['tactics'][0]
                    assert modified_tac.lstrip() == 'linarith'
                    assert len(modified_tac) - len('linarith') == 4
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(info.progress)
                print('-'*30)
                if done:
                    raise Exception("Proof should not have finished")
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)

    def test_simple_lean4_with_error(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
        # theorem test3 (p q : Prop) (hp : p) (hq : q)
        # : p ∧ q ∧ p :=
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
            'apply And.intro',
            'exact hpx', # Error here
            'exact hp', # This should automatically work
            'apply And.intro',
            'exact hq',
            'exact hp'
        ]
        proof_finished = False
        with env:
            for i, proof_step in enumerate(proof_steps):
                state, _, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                    print(f"Proof step {i + 1} failed")
                if i == 1:
                    assert info.error_message is not None, "Error was expected at step 2"
                else:
                    assert info.error_message is None, f"Error was not expected at step {i + 1}"
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(info.progress)
                print('-'*30)
                if done:
                    print("Proof Finished!!")
                    proof_finished = True
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)
            assert proof_finished, "Proof was not finished"

    def test_simple_lean4_multiline_multigoal(self):
        from itp_interface.rl.proof_state import ProofState
        from itp_interface.rl.proof_action import ProofAction
        from itp_interface.rl.simple_proof_env import ProofEnv
        from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
        from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
        project_folder = "src/data/test/lean4_proj"
        file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
        # Build the project
        # cd src/data/test/lean4_proj && lake build
        helper = LeanHelper()
        helper.build_lean4_project(project_folder)
        language = ProofAction.Language.LEAN4
        theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"complicated_have\"}'
        # theorem complicated_have
        #   (a b c d e f : ℕ)
        #   (h1 : a + b = c)
        #   (h2 : d + e = f) :
        #   a + b + d + e = c + f
        #   ∧ a + d + b + e = c + f := by
        proof_exec_callback = ProofExecutorCallback(
            project_folder=project_folder,
            file_path=file_path,
            language=language,
            always_use_retrieval=False,
            keep_local_context=True
        )
        always_retrieve_thms = False
        retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
        env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
        proof_steps = [
            'apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;',
            'exact h3 ; grind'
        ]
        with env:
            proof_was_finished = False
            for proof_step in proof_steps:
                state, action, next_state, _, done, info = env.step(ProofAction(
                    ProofAction.ActionType.RUN_TACTIC,
                    language,
                    tactics=[proof_step]))
                proof_step = action.kwargs.get('tactics', ['INVALID'])[0]
                if info.error_message is not None:
                    print(f"Error: {info.error_message}")
                # This prints StateChanged, StateUnchanged, Failed, or Done
                print(info.progress)
                print('-'*30)
                if done:
                    pretty_print(next_state, None, proof_step, done)
                    proof_was_finished = True
                else:
                    s1 : ProofState = state
                    s2 : ProofState = next_state
                    pretty_print(s1, s2, proof_step, done)
            assert proof_was_finished, "Proof was not finished"


def main():
    """Entry point: hand control to ``unittest.main()`` to run this module's tests."""
    unittest.main()


# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    main()
