////////////////////////////////////////////////////////////////////////////////////////
//
//  Copyright 2023 OVITO GmbH, Germany
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify it either under the
//  terms of the GNU General Public License version 3 as published by the Free Software
//  Foundation (the "GPL") or, at your option, under the terms of the MIT License.
//  If you do not alter this notice, a recipient may use your version of this
//  file under either the GPL or the MIT License.
//
//  You should have received a copy of the GPL along with this program in a
//  file LICENSE.GPL.txt.  You should have received a copy of the MIT License along
//  with this program in a file LICENSE.MIT.txt
//
//  This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
//  either express or implied. See the GPL or the MIT License for the specific language
//  governing rights and limitations.
//
////////////////////////////////////////////////////////////////////////////////////////

#include <ovito/crystalanalysis/CrystalAnalysis.h>
#include <ovito/crystalanalysis/objects/ClusterGraphObject.h>
#include <ovito/core/utilities/concurrent/ParallelFor.h>
#include <ovito/core/dataset/pipeline/ModificationNode.h>
#include <ovito/core/dataset/DataSet.h>
#include "ElasticStrainEngine.h"
#include "ElasticStrainModifier.h"

namespace Ovito {

/******************************************************************************
* Constructor.
******************************************************************************/
/******************************************************************************
* Constructor.
*
* Creates the StructureAnalysis helper for the requested input crystal
* structure and allocates the output property arrays:
*   - the volumetric strain property is always allocated,
*   - the elastic strain tensor property only if calculateStrainTensors is set,
*   - the elastic deformation gradient property only if
*     calculateDeformationGradients is set (otherwise these stay null).
*
* The lattice parameters are converted to the internal units used to build the
* ideal unit cell in perform(): for the cubic lattices (FCC, BCC, cubic
* diamond) the axial scaling factor is 1. For every other structure the lattice
* constant is multiplied by sqrt(2), and the axial scaling is the given c/a
* ratio divided by sqrt(8/3) (presumably the ideal hexagonal c/a ratio).
******************************************************************************/
ElasticStrainEngine::ElasticStrainEngine(
        const ModifierEvaluationRequest& request,
        ParticleOrderingFingerprint fingerprint,
        ConstPropertyPtr positions, const SimulationCell* simCell,
        int inputCrystalStructure, std::vector<Matrix3> preferredCrystalOrientations,
        bool calculateDeformationGradients, bool calculateStrainTensors,
        FloatType latticeConstant, FloatType caRatio, bool pushStrainTensorsForward) :
    StructureIdentificationModifier::StructureIdentificationEngine(request, std::move(fingerprint), positions, simCell, {}),
    _structureAnalysis(std::make_unique<StructureAnalysis>(positions, simCell, static_cast<StructureAnalysis::LatticeStructureType>(inputCrystalStructure), selection(), structures(), std::move(preferredCrystalOrientations))),
    _inputCrystalStructure(inputCrystalStructure),
    _latticeConstant(latticeConstant),
    _pushStrainTensorsForward(pushStrainTensorsForward),
    _volumetricStrains(Particles::OOClass().createUserProperty(DataBuffer::Uninitialized, positions->size(), DataBuffer::FloatDefault, 1, QStringLiteral("Volumetric Strain"))),
    _strainTensors(calculateStrainTensors ? Particles::OOClass().createStandardProperty(DataBuffer::Uninitialized, positions->size(), Particles::ElasticStrainTensorProperty) : nullptr),
    _deformationGradients(calculateDeformationGradients ? Particles::OOClass().createStandardProperty(DataBuffer::Uninitialized, positions->size(), Particles::ElasticDeformationGradientProperty) : nullptr)
{
    // The per-atom cluster assignment is owned by the structure analysis helper.
    setAtomClusters(_structureAnalysis->atomClusters());

    const bool isCubicLattice =
            inputCrystalStructure == StructureAnalysis::LATTICE_FCC
         || inputCrystalStructure == StructureAnalysis::LATTICE_BCC
         || inputCrystalStructure == StructureAnalysis::LATTICE_CUBIC_DIAMOND;

    if(isCubicLattice) {
        // Cubic crystal structures always have a c/a ratio of one.
        _axialScaling = 1;
    }
    else {
        // Convert to internal units.
        _latticeConstant *= sqrt(2.0);
        _axialScaling = caRatio / sqrt(8.0/3.0);
    }
}

/******************************************************************************
* Performs the actual analysis. This method is executed in a worker thread.
******************************************************************************/
void ElasticStrainEngine::perform()
{
    setProgressText(ElasticStrainModifier::tr("Calculating elastic strain tensors"));

    beginProgressSubStepsWithWeights({ 35, 6, 1, 1, 20 });
    if(!_structureAnalysis->identifyStructures())
        return;

    nextProgressSubStep();
    if(!_structureAnalysis->buildClusters())
        return;

    nextProgressSubStep();
    if(!_structureAnalysis->connectClusters())
        return;

    nextProgressSubStep();
    if(!_structureAnalysis->formSuperClusters())
        return;

    nextProgressSubStep();

    BufferReadAccess<Point3> positionsArray(positions());
    BufferWriteAccess<Matrix3, access_mode::discard_write> deformationGradientsArray(deformationGradients());
    BufferWriteAccess<SymmetricTensor2, access_mode::discard_write> strainTensorsArray(strainTensors());
    BufferWriteAccess<FloatType, access_mode::discard_write> volumetricStrainsArray(volumetricStrains());

    parallelForWithProgress(positions()->size(), [&](size_t particleIndex) {

        Cluster* localCluster = _structureAnalysis->atomCluster(particleIndex);
        if(localCluster->id != 0) {

            // The shape of the ideal unit cell.
            Matrix3 idealUnitCellTM(_latticeConstant, 0, 0,
                                    0, _latticeConstant, 0,
                                    0, 0, _latticeConstant * _axialScaling);

            // If the cluster is a defect (stacking fault), find the parent crystal cluster.
            Cluster* parentCluster = nullptr;
            if(localCluster->parentTransition != nullptr) {
                parentCluster = localCluster->parentTransition->cluster2;
                idealUnitCellTM = idealUnitCellTM * localCluster->parentTransition->tm;
            }
            else if(localCluster->structure == _inputCrystalStructure) {
                parentCluster = localCluster;
            }

            if(parentCluster != nullptr) {
                OVITO_ASSERT(parentCluster->structure == _inputCrystalStructure);

                // For calculating the cluster orientation.
                Matrix_3<double> orientationV = Matrix_3<double>::Zero();
                Matrix_3<double> orientationW = Matrix_3<double>::Zero();

                int numneigh = _structureAnalysis->numberOfNeighbors(particleIndex);
                for(int n = 0; n < numneigh; n++) {
                    int neighborAtomIndex = _structureAnalysis->getNeighbor(particleIndex, n);
                    // Add vector pair to matrices for computing the elastic deformation gradient.
                    Vector3 latticeVector = idealUnitCellTM * _structureAnalysis->neighborLatticeVector(particleIndex, n);
                    Vector3 spatialVector = positionsArray[neighborAtomIndex] - positionsArray[particleIndex];
                    if(cell()) spatialVector = cell()->wrapVector(spatialVector);
                    for(size_t i = 0; i < 3; i++) {
                        for(size_t j = 0; j < 3; j++) {
                            orientationV(i,j) += (double)(latticeVector[j] * latticeVector[i]);
                            orientationW(i,j) += (double)(latticeVector[j] * spatialVector[i]);
                        }
                    }
                }

                // Calculate deformation gradient tensor.
                Matrix_3<double> elasticF = orientationW * orientationV.inverse();
                if(deformationGradientsArray)
                    deformationGradientsArray[particleIndex] = elasticF.toDataType<FloatType>();

                // Calculate strain tensor.
                SymmetricTensor2T<double> elasticStrain;
                if(!_pushStrainTensorsForward) {
                    // Compute Green strain tensor in material frame.
                    elasticStrain = (Product_AtA(elasticF) - SymmetricTensor2T<double>::Identity()) * 0.5;
                }
                else {
                    // Compute Euler strain tensor in spatial frame.
                    Matrix_3<double> inverseF;
                    if(!elasticF.inverse(inverseF))
                        throw Exception(ElasticStrainModifier::tr("Cannot compute strain tensor in spatial reference frame, because the elastic deformation gradient at atom index %1 is singular.").arg(particleIndex+1));
                    elasticStrain = (SymmetricTensor2T<double>::Identity() - Product_AtA(inverseF)) * 0.5;
                }

                // Store strain tensor in output property.
                if(strainTensorsArray)
                    strainTensorsArray[particleIndex] = (SymmetricTensor2)elasticStrain;

                // Calculate volumetric strain component.
                double volumetricStrain = (elasticStrain(0,0) + elasticStrain(1,1) + elasticStrain(2,2)) / 3.0;
                OVITO_ASSERT(std::isfinite(volumetricStrain));
                volumetricStrainsArray[particleIndex] = static_cast<FloatType>(volumetricStrain);

                return;
            }
        }

        // Mark atom as invalid.
        volumetricStrainsArray[particleIndex] = 0;
        if(strainTensorsArray)
            strainTensorsArray[particleIndex] = SymmetricTensor2::Zero();
        if(deformationGradientsArray)
            deformationGradientsArray[particleIndex] = Matrix3::Zero();
    });

    endProgressSubSteps();

    // Release data that is no longer needed.
    releaseWorkingData();
    _structureAnalysis.reset();
}

/******************************************************************************
* Injects the computed results of the engine into the data pipeline.
******************************************************************************/
void ElasticStrainEngine::applyResults(const ModifierEvaluationRequest& request, PipelineFlowState& state)
{
    ElasticStrainModifier* modifier = static_object_cast<ElasticStrainModifier>(request.modifier());

    StructureIdentificationEngine::applyResults(request, state);

    // Output cluster graph.
    ClusterGraphObject* clusterGraphObj = state.createObject<ClusterGraphObject>(request.modificationNode());
    clusterGraphObj->setStorage(clusterGraph());

    // Output particle properties.
    Particles* particles = state.expectMutableObject<Particles>();
    particles->createProperty(atomClusters());
    if(modifier->calculateStrainTensors() && strainTensors())
        particles->createProperty(strainTensors());

    if(modifier->calculateDeformationGradients() && deformationGradients())
        particles->createProperty(deformationGradients());

    if(volumetricStrains())
        particles->createProperty(volumetricStrains());
}

}   // End of namespace
