import matplotlib.pyplot as plt
import matplotlib.tri as tri
#from mpl_toolkits.mplot3d import axes3d, Axes3D
import numpy as np
from itertools import compress 
import nemesis
from mesh_plot import *

# # converts quad elements into tri elements
# def quads_to_tris(quads):
#     tris = [[None for j in range(3)] for i in range(2*len(quads))]
#     for i in range(len(quads)):
#         j = 2*i
#         n0 = quads[i][0]
#         n1 = quads[i][1]
#         n2 = quads[i][2]
#         n3 = quads[i][3]
#         tris[j][0] = n0
#         tris[j][1] = n1
#         tris[j][2] = n2
#         tris[j + 1][0] = n2
#         tris[j + 1][1] = n3
#         tris[j + 1][2] = n0
#     return tris

# # plots a finite element mesh
# def plot_fem_mesh(nodes_x, nodes_y, elements):
#     for element in elements:
#         x = [nodes_x[element[i]] for i in range(len(element))]
#         y = [nodes_y[element[i]] for i in range(len(element))]
#         plt.fill(x, y, edgecolor='black', fill=False)

# def plot_fem_mesh3D(nodes_x, nodes_y, elements, linecolor, linestyle):
#     for element in elements:
#         x = [nodes_x[element[i]] for i in range(len(element))]
#         y = [nodes_y[element[i]] for i in range(len(element))]
#         ax.plot(x, y, color=linecolor, linestyle=linestyle)


# def plot_vertex_dofs(nodes_x, nodes_y, nodes_z,elements):
#     for j in range(len(nodes_x)):
#         x = [nodes_x[j] for i in range(len(element))]
#         y = [nodes_y[j] for i in range(len(element))]
#         z = [nodes_z[j] for i in range(len(element))]
#         ax.plot(x, y, z, color=linecolor, linestyle=linestyle)


# def add_edge_ids(nodes_x, nodes_y, nodes_z,elements):
#     for idx,element in enumerate(elements):
#         x = [nodes_x[element[i]] for i in range(len(element))]
#         y = [nodes_y[element[i]] for i in range(len(element))]
#         z = [nodes_z[element[i]] for i in range(len(element))]
#         ax.text(np.mean(x),np.mean(y),np.mean(z),idx)

# def add_face_ids(nodes_x, nodes_y, nodes_z,elements,element_tag,indices):
#     for idx,element in enumerate(elements):
#         x = [nodes_x[element[i]] for i in range(len(element))]
#         y = [nodes_y[element[i]] for i in range(len(element))]
#         z = [nodes_z[element[i]] for i in range(len(element))]
#         tag = element_tag+str(indices[idx])
#         ax.text(np.mean(x),np.mean(y),tag)



# def add_dof_numbers(nodes_x,nodes_y,elements,dof_numbers):
#     for idx,element in enumerate(elements):
#         x    = [nodes_x[element[i]] for i in range(len(element))]
#         y    = [nodes_y[element[i]] for i in range(len(element))]
        
#         [xp,yp] = get_locations(x,y,[5,5])

#         #dofs = [dof_numbers[0] for i in range(len(element))]

#         xp = xp.flatten(); yp = yp.flatten()

#         for point in range(len(xp)):
#             ax.text(xp[point],yp[point],dof_numbers[idx][point],horizontalalignment='center',verticalalignment='center')

# def plot_face(nodes_x,nodes_y,elements):
#     for idx,element in enumerate(elements):
#         x = [nodes_x[element[i]] for i in range(len(element))]
#         y = [nodes_y[element[i]] for i in range(len(element))]
#         ax.fill(x,y,color='red',alpha=0.1)

# # Mesh data
# nodes_x      = [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 1.0, 1.5, 2.0, 1.0, 1.5, 2.0, 1.0, 1.5, 2.0]
# nodes_y      = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0]
# nodes_z      = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

# all_vertices    = [[0],[1],[2],[3],[4],[5],[6],[7],[8],[9],[10],[11],[12],[13],[14]]
# active_vertices = np.array( [1,1,0,1,1,0,0,1,1,0,1,1,0,1,1],dtype='bool' )

# nodal_values = np.zeros( np.size( nodes_x ) )

# edges_L0     = [[0,1],[1,2],[3,4],[4,5],[0,3],[1,4],[2,5]]
# edges_L1     = [[6,7],[7,8],[9,10],[10,11],[12,13],[13,14],[6,9],[9,12],[7,10],[10,13],[8,11],[11,14]]
# all_edges    = edges_L0 + edges_L1
# active_edges = np.ones(np.size(all_edges),dtype='bool' )
# active_edges[1] = False
# active_edges[3] = False
# active_edges[6] = False
# active_edges[13] = False
# active_edges[14] = False

# # faces_L0  = [[0, 1, 4, 3], [1, 2, 5, 4]]
# faces_L1  = [[6, 7,10, 9], [7,8,11,10], [9,10,13,12], [10,11,14,13]]
# all_faces = faces_L0 + faces_L1

# elements_all_tris = quads_to_tris(all_faces)

# # create an unstructured triangular grid instance
# triangulation = tri.Triangulation(nodes_x, nodes_y, elements_all_tris)

# 2D plotting: degree-of-freedom numbering before and after refinement.
# NOTE(review): nodes_x, nodes_y, all_vertices, all_edges, all_faces,
# plot_face_2D, plot_fem_mesh_2D, plot_vertices_2D and add_dof_numbers are not
# defined in this file; they presumably come from `from mesh_plot import *`.


def _select(items, mask):
    """Return the entries of *items* whose matching *mask* entry is truthy."""
    return list(compress(items, mask))


def _edge_mask(*index_groups):
    """Return a boolean mask with exactly one entry per edge in ``all_edges``.

    Every position selected by any of *index_groups* (anything NumPy accepts
    as an index: a list of ints, a ``slice``, ...) is True; all others are
    False. Sizing by ``len(all_edges)`` -- rather than ``np.size``, which
    counts every node id of every edge -- keeps the mask aligned with the
    edge list.
    """
    mask = np.zeros(len(all_edges), dtype=bool)
    for group in index_groups:
        mask[group] = True
    return mask


def _format_axes(ax, ylim, xlim=None):
    """Apply the ticks, axis labels and axis limits shared by every panel."""
    ax.set_xticks([0, 0.5, 1, 1.5, 2.0])
    ax.set_yticks([0, 0.5, 1])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    if xlim is not None:
        ax.set_xlim(xlim)
    ax.set_ylim(ylim)


def _label_corners(ax, y_below, y_above):
    """Annotate A=(1, y_below), B=(2, y_below), C=(2, y_above), D=(1, y_above)."""
    corners = (("A", (1, y_below)), ("B", (2, y_below)),
               ("C", (2, y_above)), ("D", (1, y_above)))
    for text, xy in corners:
        ax.annotate(text, xy=xy, va="center", ha="center")


def _label_level(ax, text):
    """Draw *text* rotated 90 degrees in a rounded box just right of *ax*."""
    ax.annotate(text, xy=(1.1, 0.5), xycoords=ax.transAxes,
                va="center", ha="center", rotation=90,
                bbox=dict(boxstyle="round", fc="w"))


# DOF numbers of the two level-0 faces (one row of 25 numbers per face),
# shared by the "Before refinement" and "After refinement / Level 0" panels.
DOFS_L0 = [[1, 5, 6, 7, 2, 14, 17, 18, 19, 8, 15, 20, 21, 22, 9, 16, 23, 24, 25, 10, 4, 11, 12, 13, 3],
           [2, 28, 29, 30, 26, 8, 37, 38, 39, 31, 9, 40, 41, 42, 32, 10, 43, 44, 45, 33, 3, 34, 35, 36, 27]]

# DOF numbers of the four level-1 faces (one row of 25 numbers per face).
DOFS_L1 = [[2, 48, 49, 50, 46, 8, 57, 58, 59, 51, 9, 60, 61, 62, 52, 10, 63, 64, 65, 53, 3, 54, 55, 56, 47],
           [46, 67, 68, 69, 26, 51, 76, 77, 78, 70, 52, 79, 80, 81, 71, 53, 82, 83, 84, 72, 47, 73, 74, 75, 66],
           [2, 54, 55, 56, 47, 8, 92, 93, 94, 86, 9, 95, 96, 97, 87, 10, 98, 99, 100, 88, 3, 89, 90, 91, 85],
           [47, 73, 74, 75, 66, 86, 107, 108, 109, 101, 87, 110, 111, 112, 102, 88, 113, 114, 115, 103, 85, 104, 105, 106, 27]]

fig = plt.figure(figsize=(12, 12))

with plt.xkcd():
    plt.suptitle('Degree of freedom numbering (C0 basis)')

    # --- Panel 1: the original (unrefined) mesh ---------------------------
    ax = plt.subplot(2, 2, 1)
    ax.axis('equal')
    plt.subplots_adjust(top=0.90, bottom=0.075, left=0.075, right=0.925,
                        hspace=0.15, wspace=0.15)

    active_vertices = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='bool')
    # First seven edges only (presumably the level-0 edges -- confirm in mesh_plot).
    active_edges = _edge_mask(slice(0, 7))
    active_faces = np.array([1, 1, 0, 0, 0, 0], dtype='bool')

    plot_face_2D(ax, nodes_x, nodes_y, _select(all_faces, active_faces))
    plot_fem_mesh_2D(ax, nodes_x, nodes_y, _select(all_edges, active_edges), 'red', '-')
    plot_vertices_2D(ax, nodes_x, nodes_y, _select(all_vertices, active_vertices), 'r', 'r')

    ax.set_title('Before refinement')
    add_dof_numbers(ax, nodes_x, nodes_y, _select(all_faces, active_faces), DOFS_L0)
    _format_axes(ax, ylim=[-0.15, 1.15])

    # --- Panel 2: the refined mesh, level 0 -------------------------------
    ax = plt.subplot(2, 2, 2)
    ax.axis('equal')
    ax.set_title('After refinement')

    active_vertices = np.array([1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='bool')
    inactive_vertices = np.array([0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype='bool')
    full_edges = _edge_mask([0, 2, 4, 5])
    dashed_edges = _edge_mask([1, 3, 6])
    active_faces = np.array([1, 0, 0, 0, 0, 0], dtype='bool')

    plot_face_2D(ax, nodes_x, nodes_y, _select(all_faces, active_faces))
    plot_fem_mesh_2D(ax, nodes_x, nodes_y, _select(all_edges, full_edges), 'red', '-')
    plot_fem_mesh_2D(ax, nodes_x, nodes_y, _select(all_edges, dashed_edges), 'red', '--')
    plot_vertices_2D(ax, nodes_x, nodes_y, _select(all_vertices, active_vertices), 'r', 'r')
    plot_vertices_2D(ax, nodes_x, nodes_y, _select(all_vertices, inactive_vertices), 'r', 'white')

    # DOF numbers are drawn on both level-0 faces, including the second face,
    # which is not filled in this panel.
    dof_faces = np.array([1, 1, 0, 0, 0, 0], dtype='bool')
    add_dof_numbers(ax, nodes_x, nodes_y, _select(all_faces, dof_faces), DOFS_L0)
    _format_axes(ax, ylim=[-0.15, 1.15])

    _label_level(ax, "Level 0")
    _label_corners(ax, y_below=-0.1, y_above=1.1)

    # --- Panel 4: the refined mesh, level 1 -------------------------------
    ax = plt.subplot(2, 2, 4)
    ax.axis('equal')

    active_vertices = np.array([0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype='bool')
    inactive_vertices = np.array([0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0], dtype='bool')
    full_edges = _edge_mask(slice(7, 13), slice(15, None))
    dashed_edges = _edge_mask(slice(13, 15))
    active_faces = np.array([0, 0, 1, 1, 1, 1], dtype='bool')

    plot_face_2D(ax, nodes_x, nodes_y, _select(all_faces, active_faces))
    plot_fem_mesh_2D(ax, nodes_x, nodes_y, _select(all_edges, full_edges), 'red', '-')
    plot_fem_mesh_2D(ax, nodes_x, nodes_y, _select(all_edges, dashed_edges), 'red', '--')
    plot_vertices_2D(ax, nodes_x, nodes_y, _select(all_vertices, active_vertices), 'r', 'r')
    plot_vertices_2D(ax, nodes_x, nodes_y, _select(all_vertices, inactive_vertices), 'r', 'white')

    _label_corners(ax, y_below=-0.05, y_above=1.05)

    add_dof_numbers(ax, nodes_x, nodes_y, _select(all_faces, active_faces), DOFS_L1)
    _format_axes(ax, ylim=[-0.2, 1.2], xlim=[0.8, 2.2])

    _label_level(ax, "Level 1")


fig.canvas.draw()
plt.show()