"""Testing module morph."""

import numpy as np

from treem import SWC, DGram, Morph, Node, get_segdata


def test_node_str():
    """Check that a Node is rendered as the string form of its value."""
    assert str(Node(value=1)) == '1'


def test_node_stem_leaf():
    """Check stem and leaf detection on a single unbranched neurite."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    # Non-root nodes in walk order: the first one is expected to be the stem,
    # the last one a leaf.
    neurite = [node for node in morph.root.walk() if not node.is_root()]
    stem, leaf = neurite[0], neurite[-1]
    assert stem.is_stem()
    assert leaf.is_leaf()


def test_node_order():
    """Check the branch order reported by Node.order() in a bifurcating tree."""
    points = np.array([[1, 1, 0, 0, 0, 1, -1],
                       [2, 3, 1, 0, 0, 1, 1],
                       [3, 3, 2, 0, 0, 1, 2],
                       [4, 3, 1, 2, 0, 1, 2]])
    morph = Morph(data=points)
    # Nodes 3 and 4 share parent 2, so both are expected to have order 2.
    expected_orders = [0, 1, 2, 2]
    actual_orders = [node.order() for node in morph.root.walk()]
    assert actual_orders == expected_orders


def test_node_point():
    """Check that Node.point() returns columns 2-5 of the node's data row."""
    row = np.array([1, 1, 0, 0, 0, 1, -1])
    matches = Node(value=row).point() == np.array([0, 0, 0, 1])
    assert all(matches)


def test_node_dist():
    """Check the distance to the origin for a node placed at x = 1."""
    node = Node(value=np.array([1, 1, 1, 0, 0, 1, -1]))
    distance = node.dist()
    assert np.isclose(distance, 1.0)


def test_morph_node():
    """Check that Morph.node() looks up a node by its identifier."""
    single_point = np.array([[1, 1, 0, 0, 0, 1, -1]])
    morph = Morph(data=single_point)
    assert morph.node(1).ident() == 1


def test_morph_insert():
    """Check that inserting a node keeps identifiers contiguous and grows the tree."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 1, 0, 0, 0, 1, 1]]))
    # The inserted row carries a placeholder identifier of 0.
    inserted = Node(value=np.array([0, 1, 0, 0, 0, 1, 0]))
    morph.insert(inserted, morph.node(2))
    for ident in (2, 3):
        assert morph.node(ident).ident() == ident
    assert morph.root.size() == 3


def test_morph_delete():
    """Check that deleting a node renumbers and reattaches its child."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    morph.delete(morph.node(2))
    # The former node 3 becomes node 2 and now hangs off node 1.
    expected = [[1, 1, 0, 0, 0, 1, -1],
                [2, 3, 2, 0, 0, 1, 1]]
    assert morph.data.tolist() == expected


def test_morph_prune():
    """Check that pruning at a node removes that node and everything below it."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    morph.prune(morph.node(2))
    # Only the root row is expected to survive.
    remaining = morph.data.tolist()
    assert remaining == [[1, 1, 0, 0, 0, 1, -1]]

def test_move_node():
    """Check that Morph.move() shifts a leaf node by the given offset vector."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1]]))
    leaf = list(morph.root.leaves())[0]
    offset = [1, 1, 1]
    morph.move(offset, leaf)
    # Coordinates (1, 0, 0) become (2, 1, 1); the other columns are unchanged.
    expected = [[1, 1, 0, 0, 0, 1, -1],
                [2, 3, 2, 1, 1, 1, 1]]
    assert morph.data.tolist() == expected


def test_node_area():
    """Check Node.area() against the lateral cylinder area pi * diam * length."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1]]))
    node = list(morph.root.leaves())[0]
    expected = np.pi * node.diam() * node.length()
    # Compare with a tolerance rather than exact equality. area() may evaluate
    # a mathematically equivalent expression in a different floating-point
    # order, and exact comparison would then fail spuriously.
    assert np.isclose(node.area(), expected)


def test_node_volume():
    """Check Node.volume() against the cylinder volume pi * radius**2 * length."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1]]))
    node = list(morph.root.leaves())[0]
    expected = np.pi * node.radius()**2 * node.length()
    # Compare with a tolerance rather than exact equality. volume() may evaluate
    # a mathematically equivalent expression in a different floating-point
    # order, and exact comparison would then fail spuriously.
    assert np.isclose(node.volume(), expected)


def test_sec_radii():
    """Check Morph.radii() for the first section of the first stem."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 2, 2]]))
    first_stem = next(iter(morph.stems()))
    section = next(iter(first_stem.sections()))
    # One radius per section node, returned as a column.
    assert morph.radii(section).tolist() == [[1], [2]]


def test_sec_points():
    """Check Morph.points() for the first section of the first stem."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 2, 2]]))
    first_stem = next(iter(morph.stems()))
    section = next(iter(first_stem.sections()))
    # Each row holds columns 2-5 of the corresponding section node.
    expected = [[1, 0, 0, 1],
                [2, 0, 0, 2]]
    assert morph.points(section).tolist() == expected


def test_sec_area():
    """Check that a section's area matches the area reported by a leaf node."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    leaf = list(morph.root.leaves())[0]
    section = list(leaf.sections())[0]
    section_area = morph.area(section)
    assert np.isclose(section_area, leaf.area())


def test_sec_volume():
    """Check that a section's volume matches the volume reported by a leaf node."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    leaf = list(morph.root.leaves())[0]
    section = list(leaf.sections())[0]
    section_volume = morph.volume(section)
    assert np.isclose(section_volume, leaf.volume())


def test_save(tmp_path):
    """Tests saving morphology to SWC file.

    Args:
        tmp_path: pytest-provided temporary directory for the output file.
    """
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    # Use an .swc name to match what the test claims to exercise; the former
    # '.json' name contradicted the docstring.
    path = tmp_path / 'test_treem.swc'
    morph.save(path)
    # Without these checks the test only proved that save() does not raise.
    assert path.is_file()
    assert path.stat().st_size > 0


def test_segdata():
    """Check get_segdata() output for an unbranched three-point morphology."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2]]))
    # Each row is the original 7-column data row followed by the extra
    # per-segment values produced by get_segdata().
    expected = [
        [1, 1, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 1, 2],
        [2, 3, 1, 0, 0, 1, 1, 1, 1, 1, 0.5, 1, 1, 1, 1, 1],
        [3, 3, 2, 0, 0, 1, 2, 1, 2, 2, 1, 2, 0, 1, 1, 0],
    ]
    segdata = get_segdata(morph)
    assert np.allclose(segdata, expected)


def test_segdata_branching():
    """Check get_segdata() output for a morphology with one bifurcation."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2],
                                 [4, 3, 1, 2, 0, 1, 2]]))
    # Each row is the original 7-column data row followed by the extra
    # per-segment values produced by get_segdata(). 2.23606798 is sqrt(5).
    expected = [
        [1, 1, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 2, 4],
        [2, 3, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 3],
        [3, 3, 2, 0, 0, 1, 2, 1, 2, 1, 1, 2, 0, 2, 1, 0],
        [4, 3, 1, 2, 0, 1, 2, 2, 3, 2, 1, 2.23606798, 0, 2, 1, 0],
    ]
    segdata = get_segdata(morph)
    assert np.allclose(segdata, expected)


def test_dgram_init():
    """Check the DGram constructor with no input, a Morph, and a raw array."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2],
                                 [4, 3, 1, 2, 0, 1, 2]]))
    # A dendrogram created without input holds no data.
    empty = DGram()
    assert empty.data is None
    expected = [[1, 1, 0, 2, 0, 1, -1],
                [2, 3, 0, 2, 0, 1, 1],
                [3, 3, 0, 1, 0, 1, 2],
                [4, 3, 0, 3, 0, 1, 2]]
    # Building from a copy of the Morph leaves `morph` available for the
    # second construction below.
    from_morph = DGram(morph.copy())
    assert np.allclose(from_morph.data.tolist(), expected)
    from_array = DGram(data=morph.data)
    assert np.allclose(from_array.data.tolist(), expected)


def test_dgram():
    """Check the dendrogram layout generated from a bifurcating Morph."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2],
                                 [4, 3, 1, 2, 0, 1, 2]]))
    expected = [[1, 1, 0, 2, 0, 1, -1],
                [2, 3, 0, 2, 0, 1, 1],
                [3, 3, 0, 1, 0, 1, 2],
                [4, 3, 0, 3, 0, 1, 2]]
    layout = DGram(morph).data.tolist()
    assert np.allclose(layout, expected)


def test_dgram_prune():
    """Check that DGram keeps only the requested neurite types."""
    morph = Morph(data=np.array([[1, 1, 0, 0, 0, 1, -1],
                                 [2, 3, 1, 0, 0, 1, 1],
                                 [3, 3, 2, 0, 0, 1, 2],
                                 [4, 3, 1, 2, 0, 1, 2],
                                 [5, 2, 0, 1, 0, 1, 1]]))
    # Row 5 (type 2) is not among the requested types and should be dropped.
    kept_types = [SWC.SOMA, SWC.DEND]
    dgram = DGram(morph, types=kept_types)
    expected = [[1, 1, 0, 1, 0, 1, -1],
                [2, 3, 0, 1, 0, 1, 1],
                [3, 3, 0, 1, 0, 1, 2],
                [4, 3, 0, 2, 0, 1, 2]]
    assert np.allclose(dgram.data.tolist(), expected)
