Actual source code: mpimattransposematmult.c
/*
  Defines matrix-matrix product routines for C = A^T * B, where A is an MPIAIJ
  matrix and B is a parallel dense (MPIDENSE) matrix.
  The routines are slightly modified from MatTransposeMatMultxxx_SeqAIJ_SeqDense().
*/
6: #include <../src/mat/impls/aij/seq/aij.h>
7: #include <../src/mat/impls/aij/mpi/mpiaij.h>
8: #include <../src/mat/impls/dense/mpi/mpidense.h>
10: static PetscErrorCode MatDestroy_MPIDense_MatTransMatMult(void *data)
11: {
12: Mat_MatTransMatMult *atb = (Mat_MatTransMatMult *)data;
14: PetscFunctionBegin;
15: PetscCall(MatDestroy(&atb->mA));
16: PetscCall(VecDestroy(&atb->bt));
17: PetscCall(VecDestroy(&atb->ct));
18: PetscCall(PetscFree(atb));
19: PetscFunctionReturn(PETSC_SUCCESS);
20: }
22: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat, Mat, Mat);
24: PETSC_INTERN PetscErrorCode MatTransposeMatMultSymbolic_MPIAIJ_MPIDense(Mat A, Mat B, PetscReal fill, Mat C)
25: {
26: Mat_MatTransMatMult *atb;
27: PetscBool cisdense;
29: PetscFunctionBegin;
30: MatCheckProduct(C, 4);
31: PetscCheck(!C->product->data, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Extra product struct not empty");
33: /* create output dense matrix C = A^T*B */
34: PetscCall(MatSetSizes(C, A->cmap->n, B->cmap->n, A->cmap->N, B->cmap->N));
35: PetscCall(PetscObjectTypeCompareAny((PetscObject)C, &cisdense, MATMPIDENSE, MATMPIDENSECUDA, ""));
36: if (!cisdense) PetscCall(MatSetType(C, ((PetscObject)B)->type_name));
37: PetscCall(MatSetUp(C));
39: /* create additional data structure for the product */
40: PetscCall(PetscNew(&atb));
41: if (B->cmap->N) {
42: PetscCall(MatCreateMAIJ(A, B->cmap->N, &atb->mA));
43: if (!atb->mA->assembled) {
44: PetscCall(MatAssemblyBegin(atb->mA, MAT_FINAL_ASSEMBLY));
45: PetscCall(MatAssemblyEnd(atb->mA, MAT_FINAL_ASSEMBLY));
46: }
47: PetscCall(MatCreateVecs(atb->mA, &atb->ct, &atb->bt));
48: }
49: C->product->data = atb;
50: C->product->destroy = MatDestroy_MPIDense_MatTransMatMult;
52: C->ops->transposematmultnumeric = MatTransposeMatMultNumeric_MPIAIJ_MPIDense;
53: PetscFunctionReturn(PETSC_SUCCESS);
54: }
56: static PetscErrorCode MatTransposeMatMultNumeric_MPIAIJ_MPIDense(Mat A, Mat B, Mat C)
57: {
58: const PetscScalar *Barray, *ctarray;
59: PetscScalar *Carray, *btarray;
60: PetscInt i, j, m = A->rmap->n, n = A->cmap->n, ldb, BN = B->cmap->N, ldc;
61: Mat_MatTransMatMult *atb;
62: Vec bt, ct;
64: PetscFunctionBegin;
65: MatCheckProduct(C, 3);
66: atb = (Mat_MatTransMatMult *)C->product->data;
67: PetscCheck(atb, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing product struct");
68: if (!BN) {
69: PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
70: PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
71: PetscFunctionReturn(PETSC_SUCCESS);
72: }
73: bt = atb->bt;
74: ct = atb->ct;
76: /* transpose local array of B, then copy it to vector bt */
77: PetscCall(MatDenseGetArrayRead(B, &Barray));
78: PetscCall(MatDenseGetLDA(B, &ldb));
79: PetscCall(VecGetArray(bt, &btarray));
80: for (j = 0; j < BN; j++)
81: for (i = 0; i < m; i++) btarray[i * BN + j] = Barray[j * ldb + i];
82: PetscCall(VecRestoreArray(bt, &btarray));
83: PetscCall(MatDenseRestoreArrayRead(B, &Barray));
85: /* compute ct = mA^T * cb */
86: PetscCall(MatMultTranspose(atb->mA, bt, ct));
88: /* transpose local array of ct to matrix C */
89: PetscCall(MatDenseGetArray(C, &Carray));
90: PetscCall(MatDenseGetLDA(C, &ldc));
91: PetscCall(VecGetArrayRead(ct, &ctarray));
92: for (j = 0; j < BN; j++)
93: for (i = 0; i < n; i++) Carray[j * ldc + i] = ctarray[i * BN + j];
94: PetscCall(VecRestoreArrayRead(ct, &ctarray));
95: PetscCall(MatDenseRestoreArray(C, &Carray));
96: PetscCall(MatSetOption(C, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
97: PetscCall(MatAssemblyBegin(C, MAT_FINAL_ASSEMBLY));
98: PetscCall(MatAssemblyEnd(C, MAT_FINAL_ASSEMBLY));
99: PetscFunctionReturn(PETSC_SUCCESS);
100: }