#
#               MuPIF: Multi-Physics Integration Framework 
#                   Copyright (C) 2010 Borek Patzak
#
#       Czech Technical University, Faculty of Civil Engineering,
#       Department of Mechanics, 166 29 Prague, Czech Republic
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

import sys
import mefel 

from  mupif import mesh
from  mupif import field
from  mupif import vertex
from  mupif import cell
from  mupif.api import api
from numpy import array, arange, random, zeros

# Module-level debug switch: when truthy, MEFEL_API.__giveMesh prints the
# assembled vertex and cell lists before setting up the mesh.
debug = 0

class MEFEL_API(api.API):
    """
    MuPIF API implementation wrapping the MEFEL solver (SIFEL suite).

    The solver is driven through the ``mefel`` extension module: it is
    initialized from an input file, advanced one step at a time by solve(),
    exposes nodal results through giveField() and accepts a remote
    temperature field through registerField().
    """

    def __init__(self, inputfile, kwd):
        """
        Initialize MEFEL and its solver state.

        :param inputfile: MEFEL input file, passed to mefel.mefel_init_3
        :param kwd: passed unchanged to mefel.mefel_init_3
        """
        self.sd = mefel.stochdriver()
        self.mt = mefel.mt_glob_vec()
        mefel.mefel_init_3(sys.argv[0], inputfile, kwd, self.sd)
        # first argument is the load case number
        mefel.visco_solver_init(0, self.mt)
        # Built lazily by giveMesh() on first use, then cached.
        self.mesh = None

    def __del__(self):
        # Intentionally a no-op.
        # NOTE(review): a call to mefel.print_close() was disabled here;
        # confirm whether MEFEL output needs explicit closing.
        return

    def giveMesh(self, tstep):
        """
        Return the solver mesh, building and caching it on the first call.

        :param tstep: time step, forwarded to the mesh builder
        :return: mesh.UnstructuredMesh
        """
        # Test against None rather than truthiness, so that a mesh object
        # which happens to evaluate as false is not rebuilt on every call.
        if self.mesh is None:
            self.mesh = self.__giveMesh(tstep)
        return self.mesh

    def solve(self, tstep):
        """
        Advance MEFEL by one time step.

        :param tstep: time step object providing ``time``, ``dt`` and ``number``
        """
        # args: load case number, time, time increment, step number
        mefel.one_step(0, tstep.time, tstep.dt, tstep.number, self.mt)

    def giveField(self, fieldID, tstep):
        """
        Return a vertex-based field of nodal results.

        Supported IDs:
          - FID_Displacement: 3 components per node (FieldValueType.vector)
          - FID_StrainTensor / FID_StressTensor: full 9-component tensor per
            node (FieldValueType.tensor)

        :param fieldID: field.FieldID value
        :param tstep: time step, used to obtain the mesh
        :return: field.Field on the solver mesh
        :raises TypeError: for an unsupported fieldID, or when MEFEL returns
            a tensor that does not have 9 components
        """
        m = self.giveMesh(tstep)
        if fieldID == field.FieldID.FID_Displacement:
            fvaluetype = field.FieldValueType.vector
            vec = mefel.vector(3)
            nnodes = mefel.giveNumberOfNodes()
            fieldVals = zeros((nnodes, 3))
            for inode in range(nnodes):
                mefel.give_noddispl(inode, vec)
                fieldVals[inode] = array((vec[0], vec[1], vec[2]))
        elif (fieldID == field.FieldID.FID_StrainTensor) or (fieldID == field.FieldID.FID_StressTensor):
            fvaluetype = field.FieldValueType.tensor
            # Pick the nodal getter once instead of re-testing per node.
            if fieldID == field.FieldID.FID_StrainTensor:
                give_nodal = mefel.give_nodstrain
            else:
                give_nodal = mefel.give_nodstress
            vec = mefel.vector(9)  # full tensor expected
            nnodes = mefel.giveNumberOfNodes()
            fieldVals = zeros((nnodes, 9))
            for inode in range(nnodes):
                give_nodal(inode, vec)
                if vec.size() != 9:
                    raise TypeError("Unknown stress/strain mode")
                fieldVals[inode] = array([vec[i] for i in range(9)])
        else:
            raise TypeError("Unknown fieldID")

        return field.Field(m, fieldID, field.FieldType.vertex_based, fvaluetype, fieldVals)

    def registerField(self, remoteField, tstep):
        """
        Map a remote field onto the mesh vertices and pass it to MEFEL.

        Only FID_Temperature is supported. The remote field is evaluated at
        each vertex and its first component is taken (scalar field assumed).
        The mapped values are passed to mefel.initIPValues as
        nonmechquant.temperature. A second, freshly allocated vector is
        passed as nonmechquant.initial_temperature (the reference
        temperature).

        :param remoteField: field providing giveFieldID() and evaluate(coords)
        :param tstep: time step, used to obtain the mesh
        :raises TypeError: for any field ID other than FID_Temperature
        """
        mymesh = self.giveMesh(tstep)
        fieldid = remoteField.giveFieldID()
        if fieldid == field.FieldID.FID_Temperature:
            nmq = mefel.nonmechquant.temperature
        else:
            raise TypeError("Unsupported field ID encountered")

        nvert = mymesh.giveNumberOfVertices()
        val = mefel.vector(nvert)
        # NOTE(review): assumed to be zero-initialized by mefel.vector; confirm.
        zeroval = mefel.vector(nvert)
        # 'v' rather than 'vertex' so the imported vertex module is not shadowed.
        for v in mymesh.vertices():
            ans = remoteField.evaluate(v.coords)  # tuple
            val[v.number] = ans[0]  # first component only (scalar field assumed)
        mefel.initIPValues(val, nmq, 1.0)
        # set up the initial (reference) temperature
        mefel.initIPValues(zeroval, mefel.nonmechquant.initial_temperature, 1.0)

    def __giveMesh(self, tstep):
        """
        Build a MuPIF UnstructuredMesh from MEFEL nodes and elements.

        Every MEFEL node becomes a vertex whose label and number are both the
        node index. Only elements of type mefel.elemtype.planeelementlq are
        converted (to cell.Quad_2d_lin); other element types are skipped.

        :param tstep: unused
        :return: mesh.UnstructuredMesh
        """
        m = mesh.UnstructuredMesh()

        coords = mefel.vector(3)
        vlist = []
        for inode in range(mefel.giveNumberOfNodes()):
            mefel.give_nodal_coord(inode, coords)
            nc = (coords[0], coords[1], coords[2])
            vlist.append(vertex.Vertex(inode, inode, nc))

        # Query the element count once, outside the node loop. It does not
        # depend on any node, and it must be defined even for a node-less model.
        nelem = mefel.giveNumberOfElements()
        clist = []
        for ielem in range(nelem):
            etype = mefel.give_elem_type(ielem)
            if etype == mefel.elemtype.planeelementlq:
                vertices = mefel.ivector(4)
                mefel.giveElementConnectivity(ielem, vertices)
                vl = (vertices[0], vertices[1], vertices[2], vertices[3])
                clist.append(cell.Quad_2d_lin(m, ielem, ielem, vl))
            # NOTE(review): other element types are silently dropped here.

        if debug:
            print(vlist)
            print(clist)
        m.setup(vlist, clist)
        return m

    def giveCriticalTimeStep(self):
        """Not implemented for MEFEL; always returns None."""
        return

    def fieldID2FieldType(self, fieldID):
        """
        Translate a MuPIF field ID into a solver field type.

        :raises ValueError: for field IDs other than FID_Displacement and
            FID_Temperature
        """
        # NOTE(review): 'oofemlib' is not imported in this module (this looks
        # copied from the OOFEM adapter), so the two supported branches raise
        # NameError at runtime. Replace with the MEFEL equivalent once known.
        if fieldID == field.FieldID.FID_Displacement:
            return oofemlib.FieldType.FT_Displacements
        elif fieldID == field.FieldID.FID_Temperature:
            return oofemlib.FieldType.FT_Temperature
        else:
            raise ValueError
