import os
import random

import pytest

from nasbenchapi import NASBench101, NASBench201, NASBench301, Arch101


def test_nb201_query_returns_numeric_on_test_or_val():
    """Query a few sampled NB201 architectures on CIFAR-10.

    Five architectures are sampled (seed 123). Each one is queried on the
    test split at budget 199. When that query reports no metric, the
    validation split is queried instead and its metric must be a float or
    None. Finally, ``'arch_str'`` must appear in the ``info`` of the
    test-split result, or else in the ``info`` of a default query.
    """
    bench = NASBench201(os.getenv('NASBENCH201_PATH'))
    arch_strings = bench.random_sample(n=5, seed=123)
    assert isinstance(arch_strings, list)
    assert len(arch_strings) > 0

    for arch_str in arch_strings:
        # The NB201 wrapper is expected to hand back plain architecture strings.
        assert isinstance(arch_str, str)

        test_result = bench.query(arch_str, dataset='cifar10', split='test', budget=199)
        assert isinstance(test_result, dict)

        # A missing test metric is tolerated; probe the validation split instead.
        if test_result['metric'] is None:
            val_result = bench.query(arch_str, dataset='cifar10', split='val', budget=199)
            assert isinstance(val_result['metric'], (float, type(None)))

        # The default query is only issued when the test-split info lacks
        # 'arch_str' (short-circuiting `or`).
        assert (
            'arch_str' in test_result['info']
            or 'arch_str' in bench.query(arch_str)['info']
        )


def test_nb101_roundtrip_and_query_shape():
    """Round-trip sampled NB101 architectures and check query result shapes.

    Three architectures are sampled (seed 0). For each one:

    * ``encode`` must yield a dict, and ``decode`` of that dict must yield an
      object exposing ``adjacency`` and ``operations``;
    * the raw query on CIFAR-10/val returns an ``(info, metrics)`` pair of
      dicts, metric keys are a subset of ``{4, 12, 36, 108}``, and the first
      budget (if any) holds a non-empty list whose first run carries
      ``final_validation_accuracy`` and ``halfway_train_accuracy``;
    * the averaged query returns a dict whose first value (if any) is a dict
      containing ``final_test_accuracy``;
    * the summary query returns exactly the five standard result keys.
    """
    allowed_budgets = {4, 12, 36, 108}
    required_run_keys = ('final_validation_accuracy', 'halfway_train_accuracy')
    summary_keys = {'metric', 'metric_name', 'cost', 'std', 'info'}

    bench = NASBench101(os.getenv('NASBENCH101_PATH'))
    for sampled in bench.random_sample(n=3, seed=0):
        encoded = bench.encode(sampled)
        decoded = bench.decode(encoded)
        assert isinstance(encoded, dict)
        assert hasattr(decoded, 'adjacency')
        assert hasattr(decoded, 'operations')

        # Raw (per-run) query.
        info, per_budget = bench.query(decoded, dataset='cifar10', split='val')
        assert isinstance(info, dict)
        assert isinstance(per_budget, dict)
        assert set(per_budget.keys()) <= allowed_budgets
        if per_budget:
            # Spot-check only the first budget: a non-empty list of runs
            # (per the original note, up to three).
            first_runs = next(iter(per_budget.values()))
            assert isinstance(first_runs, list)
            assert first_runs
            first_run = first_runs[0]
            for key in required_run_keys:
                assert key in first_run

        # Averaged query; its info half is not inspected.
        _, averaged = bench.query(decoded, dataset='cifar10', split='val', average=True)
        assert isinstance(averaged, dict)
        if averaged:
            first_avg = next(iter(averaged.values()))
            assert isinstance(first_avg, dict)
            assert 'final_test_accuracy' in first_avg

        # Summary query returns the flat five-key result dict.
        summary = bench.query(decoded, dataset='cifar10', split='val', summary=True)
        assert set(summary.keys()) == summary_keys


def test_nb301_query_placeholder_shape():
    """Check that NB301 queries return the standard five-key result dict.

    Two architectures are sampled (seed 7) and each is queried on the
    CIFAR-10 validation split; only the set of returned keys is checked.
    """
    expected_keys = {'metric', 'metric_name', 'cost', 'std', 'info'}
    bench = NASBench301(os.getenv('NASBENCH301_PATH'))
    for candidate in bench.random_sample(n=2, seed=7):
        result = bench.query(candidate, dataset='cifar10', split='val')
        assert set(result.keys()) == expected_keys
