Actual source code: mpimattransposematmult.c

/*
  Defines matrix-matrix product routines for an MPIAIJ matrix A and an MPIDense matrix B
          C = A^T * B
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
  6: #include <../src/mat/impls/aij/seq/aij.h>
  7: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  8: #include <../src/mat/impls/dense/mpi/mpidense.h>

 10: static PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
 11: {
 12:   Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;

 14:   PetscFunctionBegin;
 15:   PetscCall(MatDestroy(&atb->mA));
 16:   PetscCall(VecDestroy(&atb->bt));
 17:   PetscCall(VecDestroy(&atb->ct));
 18:   PetscCall(PetscFree(atb));
 19:   PetscFunctionReturn(PETSC_SUCCESS);
 20: }

 22: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);

 24: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
 25: {
 26:   Mat_MatTransMatMult *atb;
 27:   PetscBool            cisdense;

 29:   PetscFunctionBegin;
 30:   MatCheckProduct(C, 4);
 31:   PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");

 33:   /* create output dense matrix C = A^T*B */
 34:   PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
 35:   PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
 36:   if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
 37:   PetscCall(MatSetUp(C));

 39:   /* create additional data structure for the product */
 40:   PetscCall(PetscNew(&atb));
 41:   if (B->cmap->N) {
 42:     PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
 43:     if (!atb->mA->assembled) {
 44:       PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
 45:       PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
 46:     }
 47:     PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
 48:   }
 49:   C->product->data    = atb;
 50:   C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;

 52:   C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
 53:   PetscFunctionReturn(PETSC_SUCCESS);
 54: }

 56: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
 57: {
 58:   const PetscScalar   *Barray, *ctarray;
 59:   PetscScalar         *Carray, *btarray;
 60:   PetscInt             i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
 61:   Mat_MatTransMatMult *atb;
 62:   Vec                  bt, ct;

 64:   PetscFunctionBegin;
 65:   MatCheckProduct(C, 3);
 66:   atb = (Mat_MatTransMatMult *)C->product->data;
 67:   PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
 68:   if (!BN) {
 69:     PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
 70:     PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
 71:     PetscFunctionReturn(PETSC_SUCCESS);
 72:   }
 73:   bt = atb->bt;
 74:   ct = atb->ct;

 76:   /* transpose local array of B, then copy it to vector bt */
 77:   PetscCall(MatDenseGetArrayRead(B, &Barray));
 78:   PetscCall(MatDenseGetLDA(B, &ldb));
 79:   PetscCall(VecGetArray(bt, &btarray));
 80:   for (j = 0; j < BN; j++)
 81:     for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
 82:   PetscCall(VecRestoreArray(bt, &btarray));
 83:   PetscCall(MatDenseRestoreArrayRead(B, &Barray));

 85:   /* compute ct = mA^T * cb */
 86:   PetscCall(MatMultTranspose(atb->mA, bt, ct));

 88:   /* transpose local array of ct to matrix C */
 89:   PetscCall(MatDenseGetArray(C, &Carray));
 90:   PetscCall(MatDenseGetLDA(C, &ldc));
 91:   PetscCall(VecGetArrayRead(ct, &ctarray));
 92:   for (j = 0; j < BN; j++)
 93:     for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
 94:   PetscCall(VecRestoreArrayRead(ct, &ctarray));
 95:   PetscCall(MatDenseRestoreArray(C, &Carray));
 96:   PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
 97:   PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
 98:   PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
 99:   PetscFunctionReturn(PETSC_SUCCESS);
100: }