#
#               MuPIF: Multi-Physics Integration Framework 
#                   Copyright (C) 2010 Borek Patzak
#
#       Czech Technical University, Faculty of Civil Engineering,
#       Department of Mechanics, 166 29 Prague, Czech Republic
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

import bbox
import cell
from numpy import array, arange, random, zeros
import copy

class FieldType:
    """
    Enumerates where field values are attached on the mesh:
    vertex_based -- one record per mesh vertex,
    cell_based   -- one record per mesh cell.
    """
    vertex_based, cell_based = 1, 2

class FieldValueType:
    """
    Enumerates the supported kinds of field values.
    The Field constructor allocates 1, 3 and 9 components per record for
    scalar, vector and tensor values respectively.
    """
    scalar, vector, tensor = 1, 2, 3

class FieldID:
    """
    Determines the physical meaning of field values.
    """
    (FID_Displacement,
     FID_StrainTensor,
     FID_StressTensor,
     FID_Temperature,
     FID_Humidity) = range(1, 6)

class Field:
    """
    Representation of a field. A field is generally a mapping that assigns
    a value to points in space.
    The field values are stored in self.values (a numpy array when allocated by
    the constructor). The index is either the vertex or the cell id.
    In case of cell-based fields, the value may be a list (or tuple) containing
    the values of individual IPs. Then a dictionary with mapping from
    (cell, IP) to array index needs to be established.
    """

    def __init__(self, mesh, field_id, field_type, value_type, values=None):
        """
        Creates a field on the given mesh.

        :param mesh: underlying mesh; giveNumberOfVertices()/giveNumberOfCells()
                     are used to size the value array when values is None
        :param field_id: physical meaning of the values (a FieldID constant)
        :param field_type: FieldType.vertex_based or FieldType.cell_based
        :param value_type: FieldValueType.scalar, vector or tensor
        :param values: initial values; when None, a zero array of shape
                       (number of vertices or cells, record size) is allocated,
                       with record size 1, 3 or 9 for scalar, vector, tensor
        :raises TypeError: if values is None and value_type is unknown
        """
        self.mesh = mesh
        self.field_id = field_id
        self.field_type = field_type
        self.value_type = value_type
        self.uri = None   # pyro uri; used in distributed setting
        # Identity test on purpose: `values == None` on a numpy array compares
        # element-wise and the resulting array has no unambiguous truth value.
        if values is not None:
            self.values = values
            return

        if self.field_type == FieldType.vertex_based:
            ncomponents = mesh.giveNumberOfVertices()
        else:
            ncomponents = mesh.giveNumberOfCells()
        if value_type == FieldValueType.scalar:
            recsize = 1
        elif value_type == FieldValueType.vector:
            recsize = 3
        elif value_type == FieldValueType.tensor:
            recsize = 9
        else:
            raise TypeError("Unknown FieldValueType")
        self.values = zeros((ncomponents, recsize))

    def giveValueType(self):
        """Returns the value type (FieldValueType) of the receiver."""
        return self.value_type

    def giveMesh(self):
        """Returns the mesh the receiver is defined on."""
        return self.mesh

    def giveFieldType(self):
        """Returns the field type (FieldType) of the receiver."""
        return self.field_type

    def giveFieldID(self):
        """Returns the physical meaning (FieldID) of the receiver."""
        return self.field_id

    def asArray(self):
        """Returns the stored values container itself (not a copy)."""
        return self.values

    def evaluate(self, position, eps=0.001):
        """
        Evaluates the field at a given position.

        Candidate cells are taken from the mesh cell localizer for a box of
        half-size eps around position; the first candidate that contains the
        point is used for interpolation.

        :param position: point coordinates (sequence of numbers)
        :param float eps: half-size of the search box around position
        :return: the value interpolated by the containing cell
        :raises ValueError: if no candidate cell contains position
        """
        msg = "Field evaluate - no source cell found for position %s" % (position,)
        searchBox = bbox.BBox([c - eps for c in position], [c + eps for c in position])
        cells = self.mesh.giveCellLocalizer().giveItemsInBBox(searchBox)
        if len(cells) == 0:
            # no source cell found
            print(msg)
            raise ValueError(msg)

        for icell in cells:
            try:
                if icell.containsPoint(position):
                    if self.field_type == FieldType.vertex_based:
                        return icell.interpolate(position, [self.values[i] for i in icell.giveVertices()])
                    # Previously this referenced the undefined name IPValues.
                    # NOTE(review): assumes icell.number is the index of the
                    # cell's record in self.values (as the class docstring
                    # suggests) -- confirm against the mesh/cell API.
                    return icell.interpolateIP(position, self.values[icell.number])
            except ZeroDivisionError:
                # Diagnostic output for degenerate geometry. Note: this switches
                # the cell module's debug flag on and leaves it on.
                print("%s %s" % (icell.number, position))
                cell.debug = 1
                print("%s %s" % (icell.containsPoint(position), icell.glob2loc(position)))

        # Candidates existed but none contained the point: dump diagnostics.
        print(msg)
        for icell in cells:
            print("%s %s %s" % (icell.number, icell.containsPoint(position), icell.glob2loc(position)))
        raise ValueError(msg)

    def evaluateAtVertex(self, vertexID):
        """
        Returns the value stored at the given vertex for vertex-based fields;
        returns None for other field types.
        """
        if self.field_type == FieldType.vertex_based:
            return self.giveValue(vertexID)
        return None

    def giveValue(self, componentID):
        """
        Returns the value associated to given component (vertex or cell IP).
        The component is a tuple: (vertexID,) or (CellID, IPID)
        """
        return self.values[componentID]

    def setValue(self, componentID, value):
        """
        Sets the value associated to given component (vertex or cell IP).
        The componentID is a tuple: (vertexID,) or (CellID, IPID)

        ToDo:
        If mesh has mapping attached (it is a mesh view) then we have to remember value locally and record change.
        The source field values are updated after commit() method is invoked.
        """
        self.values[componentID] = value

    def commit(self):
        """
        Commits the recorded changes (via setValue method) to primary field.
        Currently a no-op, since setValue writes directly into self.values.
        """

    def merge(self, field):
        """
        Merges the receiver with given field together. The both fields should be on different parts of the domain (can also overlap),
        but should refer to same underlying discretization, otherwise unpredictable results can occur.

        Values are placed by vertex/cell label; where the fields overlap, the
        values of `field` win. After merging, self.values is a plain list.

        :raises TypeError: if the field types of the receiver and `field` differ
        """
        # Type check first, before the (potentially expensive) mesh copy/merge.
        if self.field_type != field.field_type:
            raise TypeError("Field::merge: field_type of receiver and parameter is different")

        # Merge meshes on a copy so the receiver stays untouched until the end.
        mesh = copy.deepcopy(self.mesh)
        mesh.merge(field.mesh)

        if self.field_type == FieldType.vertex_based:
            values = [0] * mesh.giveNumberOfVertices()
            for src in (self, field):
                for v in range(src.mesh.giveNumberOfVertices()):
                    values[mesh.vertexLabel2Number(src.mesh.giveVertex(v).label)] = src.values[v]
        else:
            values = [0] * mesh.giveNumberOfCells()
            for src in (self, field):
                for c in range(src.mesh.giveNumberOfCells()):
                    values[mesh.cellLabel2Number(src.mesh.giveCell(c).label)] = src.values[c]

        self.mesh = mesh
        self.values = values
            
                


            
        

        
class FieldView(Field):
    """
    Field view represents a subset of a master field, defined by a mesh view and its mapping context.

    In mirror mode the view keeps its own local copy of the values (refreshed
    by update()); otherwise every query and modification is forwarded to the
    master field through the mesh view's mapping.
    """
    def __init__(self, meshView, masterField, mirror=False):
        """
        :param meshView: mesh view providing giveMapping() to the master mesh
        :param masterField: the field this view refers to
        :param bool mirror: if True, keep a local copy of values (initially None
                            entries; call update() to fill them)
        """
        self.mesh = meshView
        self.master = masterField
        self.field_id = masterField.giveFieldID()
        self.field_type = masterField.giveFieldType()
        self.value_type = masterField.giveValueType()
        self.uri = None   # pyro uri; initialised for consistency with Field
        self.mirrorFlag = mirror
        if self.mirrorFlag:
            self.values = [None] * self._localSize()
        else:
            self.values = None

    def _localSize(self):
        """Returns the number of local components (vertices or cells) of the mesh view."""
        if self.field_type == FieldType.vertex_based:
            return self.mesh.giveNumberOfVertices()
        return self.mesh.giveNumberOfCells()

    def update(self):
        """Updates receiver values from the master field (if mirroring)."""
        if self.mirrorFlag:
            for i in range(self._localSize()):
                self.values[i] = self.master.giveValue(self.localComponentID2MasterID(i))

    def evaluate(self, position, eps=0.001):
        """
        Evaluates the field at position; from the local copy when mirroring,
        otherwise by delegating to the master field.
        """
        if self.mirrorFlag:
            return Field.evaluate(self, position, eps)
        return self.master.evaluate(position, eps)

    def evaluateAtVertex(self, vertexID):
        """Evaluates the field at the given local vertex number."""
        if self.mirrorFlag:
            return Field.evaluateAtVertex(self, vertexID)
        # vertexID is the local number; translate it the same way as
        # localComponentID2MasterID does for vertex-based fields.
        masterVertexID = self.mesh.giveMapping().giveMasterVertexNumber(vertexID)[0]
        return self.master.evaluateAtVertex(masterVertexID)

    def giveValue(self, componentID):
        """Returns the value of the given local component."""
        if self.mirrorFlag:
            return Field.giveValue(self, componentID)
        # Delegate to the master (calling self.giveValue here recursed forever).
        return self.master.giveValue(self.localComponentID2MasterID(componentID))

    def setValue(self, componentID, value):
        """Sets the value of the given local component."""
        if self.mirrorFlag:
            Field.setValue(self, componentID, value)
        else:
            # Delegate to the master (calling self.setValue here recursed forever).
            self.master.setValue(self.localComponentID2MasterID(componentID), value)

    def commit(self):
        """Commits changes of the master field; a no-op when mirroring."""
        if not self.mirrorFlag:
            self.master.commit()

    def localComponentID2MasterID(self, componentID):
        """Translates a local vertex/cell number into the master field's number."""
        mapping = self.mesh.giveMapping()
        if self.field_type == FieldType.vertex_based:
            masterID = mapping.giveMasterVertexNumber(componentID)[0]
        else:
            masterID = mapping.giveMasterCellNumber(componentID)[0]
        return masterID
      
