#
#               MuPIF: Multi-Physics Integration Framework 
#                   Copyright (C) 2010 Borek Patzak
#
#       Czech Technical University, Faculty of Civil Engineering,
#       Department of Mechanics, 166 29 Prague, Czech Republic
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

import oofemlib
from  mupif import mesh
from  mupif import field
from  mupif import vertex
from  mupif import cell
from  mupif.api import api
from numpy import array, arange, random, zeros

# Debug flag: when set to a true value, OOFEM_API.registerField prints the
# number and coordinates of every vertex it projects.
debug = 0

class OofemUnstructuredMesh(mesh.UnstructuredMesh):
    """
    This class implements mupif.UnstructuredMesh interface to oofem.Domain class.
    The oofem domain object is kept as a (private) attribute of this class;
    vertices and cells are copied from it once, at construction time.
    """

    def __init__(self, oofemlibDomain):
        """
        Build the mupif vertex and cell lists from the given oofem domain.

        Dof managers and elements are requested from the domain with numbers
        1..n, while the corresponding mupif vertices and cells are numbered
        0..n-1; element connectivity is shifted by -1 accordingly.

        :param oofemlibDomain: oofemlib domain object to wrap
        :raises TypeError: if an element has a geometry type other than
            EGT_quad_1, EGT_triangle_1 or EGT_tetra_1
        """
        # call __init__ of parent
        super(OofemUnstructuredMesh, self).__init__()

        self.__domain = oofemlibDomain

        for i in range(self.__domain.giveNumberOfDofManagers()):
            _dman = self.__domain.giveDofManager(i + 1)
            if _dman.hasCoordinates():
                coords = (_dman.giveCoordinate(1),
                          _dman.giveCoordinate(2),
                          _dman.giveCoordinate(3))
            else:
                # dof manager without geometry: keep the vertex (so that the
                # numbering stays aligned with the domain) with empty coords
                coords = ()
            # always pass (number, label, coords), regardless of the branch
            self.vertexList.append(vertex.Vertex(i, _dman.giveLabel(), coords))

        for i in range(self.__domain.giveNumberOfElements()):
            _elem = self.__domain.giveElement(i + 1)
            _gtype = _elem.giveGeometryType()
            # select the mupif cell class and its number of vertices
            if _gtype == oofemlib.Element_Geometry_Type.EGT_quad_1:
                cellClass, nnodes = cell.Quad_2d_lin, 4
            elif _gtype == oofemlib.Element_Geometry_Type.EGT_triangle_1:
                cellClass, nnodes = cell.Triangle_2d_lin, 3
            elif _gtype == oofemlib.Element_Geometry_Type.EGT_tetra_1:
                cellClass, nnodes = cell.Tetrahedron_3d_lin, 4
            else:
                raise TypeError("Unsupported EGT Type encountered")

            # connectivity expressed in mupif vertex numbers (domain number - 1)
            nodes = tuple(_elem.giveDofManagerNumber(k) - 1
                          for k in range(1, nnodes + 1))
            self.cellList.append(cellClass(self, i, _elem.giveLabel(), nodes))

        

class OOFEM_API(api.API):
    """
    This class implements the mupif API on top of an oofem problem
    accessed through the oofemlib python bindings.
    """
    def __init__(self, inputfile):
        """
        Instantiate the oofem problem described by the given input file
        and check its consistency.

        :param inputfile: input passed to oofemlib.OOFEMTXTDataReader
        """
        dr = oofemlib.OOFEMTXTDataReader(inputfile)
        self.problem = oofemlib.InstanciateProblem(dr, oofemlib.problemMode._processor, 0)
        self.problem.checkProblemConsistency()

    def giveMesh (self, tstep):
        """
        Return a newly created OofemUnstructuredMesh wrapping domain 1 of the
        problem. The tstep argument is not used.
        """
        return OofemUnstructuredMesh(self.problem.giveDomain(1))

    def giveField(self, fieldID, tstep):
        """
        Return a vertex based field with the values of the requested quantity,
        defined on a newly created mesh of domain 1. The tstep argument is not
        used; values are requested for the current step of the problem.

        Displacements are returned as 3-component vectors (components beyond
        the number of spatial dimensions are zero), temperature as a
        1-component value.

        :param fieldID: field.FieldID.FID_Displacement or
            field.FieldID.FID_Temperature
        :param tstep: not used
        :return: field.Field instance
        :raises TypeError: if fieldID is not supported
        """
        d = self.problem.giveDomain(1)
        mymesh = OofemUnstructuredMesh(d)
        if (fieldID == field.FieldID.FID_Displacement):
            f = field.Field(mymesh, fieldID, field.FieldType.vertex_based, field.FieldValueType.vector)
            nsd = d.giveNumberOfSpatialDimensions()
            # one displacement dof id per spatial dimension (1, 2 or 3)
            mask = oofemlib.IntArray(nsd)
            if 1 <= nsd <= 3:
                dofIDs = (oofemlib.DofIDItem.D_u, oofemlib.DofIDItem.D_v, oofemlib.DofIDItem.D_w)
                for k in range(nsd):
                    mask[k] = dofIDs[k]

            #set field
            dl = oofemlib.FloatArray(nsd)    # values as returned by the dof manager
            dg = oofemlib.FloatArray(nsd)    # values after the transformation t
            t = oofemlib.FloatMatrix(nsd, nsd)
            for node in mymesh.vertices():
                # mesh vertex numbers are 0-based, the domain is queried with number+1
                dman = d.giveDofManager(node.number + 1)
                dman.giveUnknownVector(dl, mask, oofemlib.EquationID.EID_MomentumBalance, oofemlib.ValueModeType.VM_Total, self.problem.giveCurrentStep())
                if dman.computeL2GTransformation(t, mask):
                    dg.beProductOf(t, dl)
                    u = dg
                else:
                    # no transformation reported for this dof manager: take
                    # the values directly instead of a stale dg left over
                    # from a previous vertex
                    # NOTE(review): assumes a false return means dl is already
                    # expressed in the global system - confirm against oofem
                    u = dl
                size = u.giveSize()
                if size == 1:
                    val = array((u[0], 0.0, 0.0))
                elif size == 2:
                    val = array((u[0], u[1], 0.0))
                elif size == 3:
                    val = array((u[0], u[1], u[2]))
                else:
                    val = array((0, 0, 0.0))
                f.setValue(node.number, val)

        elif (fieldID == field.FieldID.FID_Temperature):
            f = field.Field(mymesh, fieldID, field.FieldType.vertex_based, field.FieldValueType.scalar)
            mask = oofemlib.IntArray(1)
            mask[0] = oofemlib.DofIDItem.T_f
            dl = oofemlib.FloatArray(1)
            for node in mymesh.vertices():
                dman = d.giveDofManager(node.number + 1)
                dman.giveUnknownVector(dl, mask, oofemlib.EquationID.EID_ConservationEquation, oofemlib.ValueModeType.VM_Total, self.problem.giveCurrentStep())
                val = array((dl[0],))
                f.setValue(node.number, val)
        else:
            raise TypeError("Unsupported field ID encountered")

        # return the created field; errors are reported by raising exceptions
        return f

    def updateField(self, field, tstep):
        """
        Not implemented; does nothing and returns None.
        """
        return

    def registerField(self, remoteField, tstep):
        """
        Project the given remote field onto the vertices of domain 1 and
        register the result with the field manager of the problem context.
        The tstep argument is not used.

        :param remoteField: object providing giveFieldID() and
            evaluate(coords); evaluate is expected to return an indexable
            sequence of values
        :param tstep: not used
        :raises ValueError: if the field ID of remoteField is not supported
        """
        d = self.problem.giveDomain(1)
        # resolve the oofem field type once; raises ValueError when unsupported
        ftype = self.fieldID2FieldType(remoteField.giveFieldID())
        # create empty oofem.dofmanvalfield defined on problem domain
        target = oofemlib.createDofManValueFieldPtr(ftype, d)
        print("Created target field ... %s" % (target,))
        mymesh = OofemUnstructuredMesh(d)
        #project given field into target
        for node in mymesh.vertices():
            if debug:
                print("Projecting vertex  %s coords  %s" % (node.number, node.coords))
            ans = remoteField.evaluate(node.coords) #tuple
            #transform tuple to FloatArray
            val = oofemlib.FloatArray(len(ans))
            for i, component in enumerate(ans):
                val[i] = component
            target.setDofManValue(node.number + 1, val)
        print("Registering target field ...")
        #self.problem.giveContext().giveFieldManager().registerField(target,ftype, True)
        oofemlib.FieldManager_registerField(self.problem.giveContext().giveFieldManager(), target, ftype, True)
        return

    def solve(self, tstep):
        """
        Solve the next step of the oofem problem.

        The time increment, target time and intrinsic time of the step
        returned by the problem are overwritten by tstep.dt and tstep.time;
        the problem is then solved, updated and terminated for that step.

        :param tstep: object providing dt and time attributes
        """
        #self.problem.solveYourself()
        oofemtstep = self.problem.giveNextStep()
        #overwrite tstep parameters
        oofemtstep.setTimeIncrement(tstep.dt)
        oofemtstep.setTargetTime(tstep.time)
        oofemtstep.setIntrinsicTime(tstep.time)
        self.problem.solveYourselfAt(oofemtstep)
        self.problem.updateYourself(oofemtstep)
        self.problem.terminate(oofemtstep)

    def giveCriticalTimeStep(self):
        """
        Not implemented; does nothing and returns None.
        """
        return

    def fieldID2FieldType (self, fieldID):
        """
        Translate a mupif field ID into the corresponding oofemlib.FieldType.

        :param fieldID: field.FieldID.FID_Displacement or
            field.FieldID.FID_Temperature
        :return: oofemlib.FieldType.FT_Displacements or
            oofemlib.FieldType.FT_Temperature
        :raises ValueError: if fieldID is not supported
        """
        if fieldID == field.FieldID.FID_Displacement:
            return oofemlib.FieldType.FT_Displacements
        elif fieldID == field.FieldID.FID_Temperature:
            return oofemlib.FieldType.FT_Temperature
        else:
            raise ValueError("Unsupported field ID encountered: %s" % (fieldID,))
