import matplotlib.pyplot as plt
import matplotlib.tri as tri
from mpl_toolkits.mplot3d import axes3d, Axes3D
import numpy as np
from itertools import compress 
from mesh_plot import *

# 3D plotting of the geometric element numbering.
#
# The figure has one row per element dimension (vertices, edges, faces). The
# left column shows the mesh before refinement and the right column the mesh
# after refinement; every entity of the row's dimension is labelled with its id.
#
# Mesh data (nodes_x/nodes_y/nodes_z, all_vertices, all_edges, all_faces,
# node_connections) and the drawing helpers (plot_vertices, plot_fem_mesh3D,
# plot_face_3D, add_face_ids) come from `mesh_plot`.


def _label_row(ax, text):
    """Draw *text* as a boxed, vertical label centred on the left edge of *ax*."""
    ax.annotate(text, xy=(0, 0.5), xycoords=ax.transAxes,
                va="center", ha="center", rotation=90,
                bbox=dict(boxstyle="round", fc="w"))


def _format_axes(ax):
    """Apply the ticks, axis labels and z-range shared by every subplot."""
    ax.set_xticks([0, 0.5, 1, 1.5, 2.0])
    ax.set_yticks([0, 0.5, 1])
    # The z-range is fixed to [0, 1] below, so ticks above 1 would never show.
    ax.set_zticks([0, 1])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('refinement level')
    ax.set_zlim([0, 1])


# Masks over the full entity lists selecting what exists before refinement.
vertices_before = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=bool)
# len() counts edges; np.size() would count every scalar of a nested edge list.
edges_before = np.arange(len(all_edges)) < 7  # only the first 7 edges
faces_before = np.array([1, 1, 0, 0, 0, 0], dtype=bool)

# Masks for the refined mesh: True entities are drawn solid, False ones are
# drawn hollow (vertices) or dashed (edges).
vertices_after = np.array([1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=bool)
edges_after = np.ones(len(all_edges), dtype=bool)
edges_after[[1, 3, 6, 13, 14]] = False
faces_after = np.array([1, 0, 1, 1, 1, 1], dtype=bool)

# One entry per row (dimension 0, 1, 2): label, entity list, the mask that
# selects the pre-refinement entities, and the ids printed before / after.
ROW_LABELS = ("Vertex ids", "Edge ids", "Face ids")
ROW_ENTITIES = (all_vertices, all_edges, all_faces)
ROW_MASKS_BEFORE = (vertices_before, edges_before, faces_before)
ROW_IDS_BEFORE = (
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 6, 7],
    [1, 2],
)
ROW_IDS_AFTER = (
    [1, 2, 4, 6, 7, 9, 3, 11, 5, 13, 15, 14, 8, 12, 10],
    [1, 2, 5, 6, 9, 10, 13, 3, 4, 16, 17, 7, 8, 11, 12, 18, 19, 14, 15],
    [1, 2, 3, 4, 5, 6],
)

fig = plt.figure(figsize=(12, 15))

with plt.xkcd():
    fig.suptitle('Geometric element numbering')
    # Layout is identical for every subplot, so set it once.
    fig.subplots_adjust(top=0.925, bottom=0.05, left=0.05, right=0.95,
                        hspace=0.1, wspace=0.1)

    rows = zip(ROW_LABELS, ROW_ENTITIES, ROW_MASKS_BEFORE,
               ROW_IDS_BEFORE, ROW_IDS_AFTER)
    for dimension, (label, entities, mask_before, ids_before, ids_after) in enumerate(rows):
        # ---- left column: the original mesh ----
        ax = fig.add_subplot(3, 2, 2 * dimension + 1, projection="3d")
        plot_vertices(ax, nodes_x, nodes_y, nodes_z,
                      list(compress(all_vertices, vertices_before)), 'r', 'r')
        plot_fem_mesh3D(ax, nodes_x, nodes_y, nodes_z,
                        list(compress(all_edges, edges_before)), 'red', '-')

        if dimension == 0:
            ax.set_title('Before refinement')
        _label_row(ax, label)
        add_face_ids(ax, nodes_x, nodes_y, nodes_z,
                     list(compress(entities, mask_before)), '', ids_before)

        plot_face_3D(ax, nodes_x, nodes_y, nodes_z,
                     list(compress(all_faces, faces_before)))
        _format_axes(ax)

        # ---- right column: the refined mesh ----
        ax = fig.add_subplot(3, 2, 2 * dimension + 2, projection="3d")
        plot_vertices(ax, nodes_x, nodes_y, nodes_z,
                      list(compress(all_vertices, vertices_after)), 'r', 'r')
        plot_vertices(ax, nodes_x, nodes_y, nodes_z,
                      list(compress(all_vertices, ~vertices_after)), 'r', 'white')

        plot_fem_mesh3D(ax, nodes_x, nodes_y, nodes_z,
                        list(compress(all_edges, edges_after)), 'red', '-')
        plot_fem_mesh3D(ax, nodes_x, nodes_y, nodes_z,
                        list(compress(all_edges, ~edges_after)), 'red', '--')
        plot_fem_mesh3D(ax, nodes_x, nodes_y, nodes_z, node_connections, 'black', '-.')

        if dimension == 0:
            ax.set_title('After refinement')
        add_face_ids(ax, nodes_x, nodes_y, nodes_z, entities, '', ids_after)

        plot_face_3D(ax, nodes_x, nodes_y, nodes_z,
                     list(compress(all_faces, faces_after)))
        _format_axes(ax)


fig.canvas.draw()
plt.show()
