Actual source code: mcomposite.c

  1: #include <../src/mat/impls/shell/shell.h>

  3: const char *const MatCompositeMergeTypes[] = {"left", "right", "MatCompositeMergeType", "MAT_COMPOSITE_", NULL};

  5: typedef struct _Mat_CompositeLink *Mat_CompositeLink;
  6: struct _Mat_CompositeLink {
  7:   Mat               mat;
  8:   Vec               work;
  9:   Mat_CompositeLink next, prev;
 10: };

 12: typedef struct {
 13:   MatCompositeType      type;
 14:   Mat_CompositeLink     head, tail;
 15:   Vec                   work;
 16:   PetscInt              nmat;
 17:   PetscBool             merge;
 18:   MatCompositeMergeType mergetype;
 19:   MatStructure          structure;

 21:   PetscScalar *scalings;
 22:   PetscBool    merge_mvctx; /* Whether need to merge mvctx of component matrices */
 23:   Vec         *lvecs;       /* [nmat] Basically, they are Mvctx->lvec of each component matrix */
 24:   PetscScalar *larray;      /* [len] Data arrays of lvecs[] are stored consecutively in larray */
 25:   PetscInt     len;         /* Length of larray[] */
 26:   Vec          gvec;        /* Union of lvecs[] without duplicated entries */
 27:   PetscInt    *location;    /* A map that maps entries in garray[] to larray[] */
 28:   VecScatter   Mvctx;
 29: } Mat_Composite;

 31: static PetscErrorCode MatDestroy_Composite(Mat mat)
 32: {
 33:   Mat_Composite    *shell;
 34:   Mat_CompositeLink next, oldnext;
 35:   PetscInt          i;

 37:   PetscFunctionBegin;
 38:   PetscCall(MatShellGetContext(mat, &shell));
 39:   next = shell->head;
 40:   while (next) {
 41:     PetscCall(MatDestroy(&next->mat));
 42:     if (next->work && (!next->next || next->work != next->next->work)) PetscCall(VecDestroy(&next->work));
 43:     oldnext = next;
 44:     next    = next->next;
 45:     PetscCall(PetscFree(oldnext));
 46:   }
 47:   PetscCall(VecDestroy(&shell->work));

 49:   if (shell->Mvctx) {
 50:     for (i = 0; i < shell->nmat; i++) PetscCall(VecDestroy(&shell->lvecs[i]));
 51:     PetscCall(PetscFree3(shell->location, shell->larray, shell->lvecs));
 52:     PetscCall(PetscFree(shell->larray));
 53:     PetscCall(VecDestroy(&shell->gvec));
 54:     PetscCall(VecScatterDestroy(&shell->Mvctx));
 55:   }

 57:   PetscCall(PetscFree(shell->scalings));
 58:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeAddMat_C", NULL));
 59:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetType_C", NULL));
 60:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetType_C", NULL));
 61:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMergeType_C", NULL));
 62:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetMatStructure_C", NULL));
 63:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMatStructure_C", NULL));
 64:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeMerge_C", NULL));
 65:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetNumberMat_C", NULL));
 66:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeGetMat_C", NULL));
 67:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatCompositeSetScalings_C", NULL));
 68:   PetscCall(PetscFree(shell));
 69:   PetscCall(PetscObjectComposeFunction((PetscObject)mat, "MatShellSetContext_C", NULL)); // needed to avoid a call to MatShellSetContext_Immutable()
 70:   PetscFunctionReturn(PETSC_SUCCESS);
 71: }

 73: static PetscErrorCode MatMult_Composite_Multiplicative(Mat A, Vec x, Vec y)
 74: {
 75:   Mat_Composite    *shell;
 76:   Mat_CompositeLink next;
 77:   Vec               out;

 79:   PetscFunctionBegin;
 80:   PetscCall(MatShellGetContext(A, &shell));
 81:   next = shell->head;
 82:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
 83:   while (next->next) {
 84:     if (!next->work) { /* should reuse previous work if the same size */
 85:       PetscCall(MatCreateVecs(next->mat, NULL, &next->work));
 86:     }
 87:     out = next->work;
 88:     PetscCall(MatMult(next->mat, x, out));
 89:     x    = out;
 90:     next = next->next;
 91:   }
 92:   PetscCall(MatMult(next->mat, x, y));
 93:   if (shell->scalings) {
 94:     PetscScalar scale = 1.0;
 95:     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
 96:     PetscCall(VecScale(y, scale));
 97:   }
 98:   PetscFunctionReturn(PETSC_SUCCESS);
 99: }

101: static PetscErrorCode MatMultTranspose_Composite_Multiplicative(Mat A, Vec x, Vec y)
102: {
103:   Mat_Composite    *shell;
104:   Mat_CompositeLink tail;
105:   Vec               out;

107:   PetscFunctionBegin;
108:   PetscCall(MatShellGetContext(A, &shell));
109:   tail = shell->tail;
110:   PetscCheck(tail, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
111:   while (tail->prev) {
112:     if (!tail->prev->work) { /* should reuse previous work if the same size */
113:       PetscCall(MatCreateVecs(tail->mat, NULL, &tail->prev->work));
114:     }
115:     out = tail->prev->work;
116:     PetscCall(MatMultTranspose(tail->mat, x, out));
117:     x    = out;
118:     tail = tail->prev;
119:   }
120:   PetscCall(MatMultTranspose(tail->mat, x, y));
121:   if (shell->scalings) {
122:     PetscScalar scale = 1.0;
123:     for (PetscInt i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
124:     PetscCall(VecScale(y, scale));
125:   }
126:   PetscFunctionReturn(PETSC_SUCCESS);
127: }

129: static PetscErrorCode MatMult_Composite(Mat mat, Vec x, Vec y)
130: {
131:   Mat_Composite     *shell;
132:   Mat_CompositeLink  cur;
133:   Vec                y2, xin;
134:   Mat                A, B;
135:   PetscInt           i, j, k, n, nuniq, lo, hi, mid, *gindices, *buf, *tmp, tot;
136:   const PetscScalar *vals;
137:   const PetscInt    *garray;
138:   IS                 ix, iy;
139:   PetscBool          match;

141:   PetscFunctionBegin;
142:   PetscCall(MatShellGetContext(mat, &shell));
143:   cur = shell->head;
144:   PetscCheck(cur, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");

146:   /* Try to merge Mvctx when instructed but not yet done. We did not do it in MatAssemblyEnd() since at that time
147:      we did not know whether mat is ADDITIVE or MULTIPLICATIVE. Only now we are assured mat is ADDITIVE and
148:      it is legal to merge Mvctx, because all component matrices have the same size.
149:    */
150:   if (shell->merge_mvctx && !shell->Mvctx) {
151:     /* Currently only implemented for MATMPIAIJ */
152:     for (cur = shell->head; cur; cur = cur->next) {
153:       PetscCall(PetscObjectTypeCompare((PetscObject)cur->mat, MATMPIAIJ, &match));
154:       if (!match) {
155:         shell->merge_mvctx = PETSC_FALSE;
156:         goto skip_merge_mvctx;
157:       }
158:     }

160:     /* Go through matrices first time to count total number of nonzero off-diag columns (may have dups) */
161:     tot = 0;
162:     for (cur = shell->head; cur; cur = cur->next) {
163:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, NULL));
164:       PetscCall(MatGetLocalSize(B, NULL, &n));
165:       tot += n;
166:     }
167:     PetscCall(PetscMalloc3(tot, &shell->location, tot, &shell->larray, shell->nmat, &shell->lvecs));
168:     shell->len = tot;

170:     /* Go through matrices second time to sort off-diag columns and remove dups */
171:     PetscCall(PetscMalloc1(tot, &gindices)); /* No Malloc2() since we will give one to petsc and free the other */
172:     PetscCall(PetscMalloc1(tot, &buf));
173:     nuniq = 0; /* Number of unique nonzero columns */
174:     for (cur = shell->head; cur; cur = cur->next) {
175:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
176:       PetscCall(MatGetLocalSize(B, NULL, &n));
177:       /* Merge pre-sorted garray[0,n) and gindices[0,nuniq) to buf[] */
178:       i = j = k = 0;
179:       while (i < n && j < nuniq) {
180:         if (garray[i] < gindices[j]) buf[k++] = garray[i++];
181:         else if (garray[i] > gindices[j]) buf[k++] = gindices[j++];
182:         else {
183:           buf[k++] = garray[i++];
184:           j++;
185:         }
186:       }
187:       /* Copy leftover in garray[] or gindices[] */
188:       if (i < n) {
189:         PetscCall(PetscArraycpy(buf + k, garray + i, n - i));
190:         nuniq = k + n - i;
191:       } else if (j < nuniq) {
192:         PetscCall(PetscArraycpy(buf + k, gindices + j, nuniq - j));
193:         nuniq = k + nuniq - j;
194:       } else nuniq = k;
195:       /* Swap gindices and buf to merge garray of the next matrix */
196:       tmp      = gindices;
197:       gindices = buf;
198:       buf      = tmp;
199:     }
200:     PetscCall(PetscFree(buf));

202:     /* Go through matrices third time to build a map from gindices[] to garray[] */
203:     tot = 0;
204:     for (cur = shell->head, j = 0; cur; cur = cur->next, j++) { /* j-th matrix */
205:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, NULL, &B, &garray));
206:       PetscCall(MatGetLocalSize(B, NULL, &n));
207:       PetscCall(VecCreateSeqWithArray(PETSC_COMM_SELF, 1, n, NULL, &shell->lvecs[j]));
208:       /* This is an optimized PetscFindInt(garray[i],nuniq,gindices,&shell->location[tot+i]), using the fact that garray[] is also sorted */
209:       lo = 0;
210:       for (i = 0; i < n; i++) {
211:         hi = nuniq;
212:         while (hi - lo > 1) {
213:           mid = lo + (hi - lo) / 2;
214:           if (garray[i] < gindices[mid]) hi = mid;
215:           else lo = mid;
216:         }
217:         shell->location[tot + i] = lo; /* gindices[lo] = garray[i] */
218:         lo++;                          /* Since garray[i+1] > garray[i], we can safely advance lo */
219:       }
220:       tot += n;
221:     }

223:     /* Build merged Mvctx */
224:     PetscCall(ISCreateGeneral(PETSC_COMM_SELF, nuniq, gindices, PETSC_OWN_POINTER, &ix));
225:     PetscCall(ISCreateStride(PETSC_COMM_SELF, nuniq, 0, 1, &iy));
226:     PetscCall(VecCreateMPIWithArray(PetscObjectComm((PetscObject)mat), 1, mat->cmap->n, mat->cmap->N, NULL, &xin));
227:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, nuniq, &shell->gvec));
228:     PetscCall(VecScatterCreate(xin, ix, shell->gvec, iy, &shell->Mvctx));
229:     PetscCall(VecDestroy(&xin));
230:     PetscCall(ISDestroy(&ix));
231:     PetscCall(ISDestroy(&iy));
232:   }

234: skip_merge_mvctx:
235:   PetscCall(VecSet(y, 0));
236:   if (!((Mat_Shell *)mat->data)->left_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)mat->data)->left_work)));
237:   y2 = ((Mat_Shell *)mat->data)->left_work;

239:   if (shell->Mvctx) { /* Have a merged Mvctx */
240:     /* Suppose we want to compute y = sMx, where s is the scaling factor and A, B are matrix M's diagonal/off-diagonal part. We could do
241:        in y = s(Ax1 + Bx2) or y = sAx1 + sBx2. The former incurs less FLOPS than the latter, but the latter provides an opportunity to
242:        overlap communication/computation since we can do sAx1 while communicating x2. Here, we use the former approach.
243:      */
244:     PetscCall(VecScatterBegin(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));
245:     PetscCall(VecScatterEnd(shell->Mvctx, x, shell->gvec, INSERT_VALUES, SCATTER_FORWARD));

247:     PetscCall(VecGetArrayRead(shell->gvec, &vals));
248:     for (i = 0; i < shell->len; i++) shell->larray[i] = vals[shell->location[i]];
249:     PetscCall(VecRestoreArrayRead(shell->gvec, &vals));

251:     for (cur = shell->head, tot = i = 0; cur; cur = cur->next, i++) { /* i-th matrix */
252:       PetscCall(MatMPIAIJGetSeqAIJ(cur->mat, &A, &B, NULL));
253:       PetscUseTypeMethod(A, mult, x, y2);
254:       PetscCall(MatGetLocalSize(B, NULL, &n));
255:       PetscCall(VecPlaceArray(shell->lvecs[i], &shell->larray[tot]));
256:       PetscUseTypeMethod(B, multadd, shell->lvecs[i], y2, y2);
257:       PetscCall(VecResetArray(shell->lvecs[i]));
258:       PetscCall(VecAXPY(y, (shell->scalings ? shell->scalings[i] : 1.0), y2));
259:       tot += n;
260:     }
261:   } else {
262:     if (shell->scalings) {
263:       for (cur = shell->head, i = 0; cur; cur = cur->next, i++) {
264:         PetscCall(MatMult(cur->mat, x, y2));
265:         PetscCall(VecAXPY(y, shell->scalings[i], y2));
266:       }
267:     } else {
268:       for (cur = shell->head; cur; cur = cur->next) PetscCall(MatMultAdd(cur->mat, x, y, y));
269:     }
270:   }
271:   PetscFunctionReturn(PETSC_SUCCESS);
272: }

274: static PetscErrorCode MatMultTranspose_Composite(Mat A, Vec x, Vec y)
275: {
276:   Mat_Composite    *shell;
277:   Mat_CompositeLink next;
278:   Vec               y2 = NULL;
279:   PetscInt          i;

281:   PetscFunctionBegin;
282:   PetscCall(MatShellGetContext(A, &shell));
283:   next = shell->head;
284:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");

286:   PetscCall(MatMultTranspose(next->mat, x, y));
287:   if (shell->scalings) {
288:     PetscCall(VecScale(y, shell->scalings[0]));
289:     if (!((Mat_Shell *)A->data)->right_work) PetscCall(VecDuplicate(y, &(((Mat_Shell *)A->data)->right_work)));
290:     y2 = ((Mat_Shell *)A->data)->right_work;
291:   }
292:   i = 1;
293:   while ((next = next->next)) {
294:     if (!shell->scalings) PetscCall(MatMultTransposeAdd(next->mat, x, y, y));
295:     else {
296:       PetscCall(MatMultTranspose(next->mat, x, y2));
297:       PetscCall(VecAXPY(y, shell->scalings[i++], y2));
298:     }
299:   }
300:   PetscFunctionReturn(PETSC_SUCCESS);
301: }

303: static PetscErrorCode MatGetDiagonal_Composite(Mat A, Vec v)
304: {
305:   Mat_Composite    *shell;
306:   Mat_CompositeLink next;
307:   PetscInt          i;

309:   PetscFunctionBegin;
310:   PetscCall(MatShellGetContext(A, &shell));
311:   next = shell->head;
312:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
313:   PetscCall(MatGetDiagonal(next->mat, v));
314:   if (shell->scalings) PetscCall(VecScale(v, shell->scalings[0]));

316:   if (next->next && !shell->work) PetscCall(VecDuplicate(v, &shell->work));
317:   i = 1;
318:   while ((next = next->next)) {
319:     PetscCall(MatGetDiagonal(next->mat, shell->work));
320:     PetscCall(VecAXPY(v, (shell->scalings ? shell->scalings[i++] : 1.0), shell->work));
321:   }
322:   PetscFunctionReturn(PETSC_SUCCESS);
323: }

325: static PetscErrorCode MatAssemblyEnd_Composite(Mat Y, MatAssemblyType t)
326: {
327:   Mat_Composite *shell;

329:   PetscFunctionBegin;
330:   PetscCall(MatShellGetContext(Y, &shell));
331:   if (shell->merge) PetscCall(MatCompositeMerge(Y));
332:   else PetscCall(MatAssemblyEnd_Shell(Y, t));
333:   PetscFunctionReturn(PETSC_SUCCESS);
334: }

336: static PetscErrorCode MatSetFromOptions_Composite(Mat A, PetscOptionItems *PetscOptionsObject)
337: {
338:   Mat_Composite *a;

340:   PetscFunctionBegin;
341:   PetscCall(MatShellGetContext(A, &a));
342:   PetscOptionsHeadBegin(PetscOptionsObject, "MATCOMPOSITE options");
343:   PetscCall(PetscOptionsBool("-mat_composite_merge", "Merge at MatAssemblyEnd", "MatCompositeMerge", a->merge, &a->merge, NULL));
344:   PetscCall(PetscOptionsEnum("-mat_composite_merge_type", "Set composite merge direction", "MatCompositeSetMergeType", MatCompositeMergeTypes, (PetscEnum)a->mergetype, (PetscEnum *)&a->mergetype, NULL));
345:   PetscCall(PetscOptionsBool("-mat_composite_merge_mvctx", "Merge MatMult() vecscat contexts", "MatCreateComposite", a->merge_mvctx, &a->merge_mvctx, NULL));
346:   PetscOptionsHeadEnd();
347:   PetscFunctionReturn(PETSC_SUCCESS);
348: }

350: /*@
351:   MatCreateComposite - Creates a matrix as the sum or product of one or more matrices

353:   Collective

355:   Input Parameters:
356: + comm - MPI communicator
357: . nmat - number of matrices to put in
358: - mats - the matrices

360:   Output Parameter:
361: . mat - the matrix

363:   Options Database Keys:
364: + -mat_composite_merge       - merge in `MatAssemblyEnd()`
365: . -mat_composite_merge_mvctx - merge Mvctx of component matrices to optimize communication in `MatMult()` for ADDITIVE matrices
366: - -mat_composite_merge_type  - set merge direction

368:   Level: advanced

370:   Note:
371:   Alternative construction
372: .vb
373:        MatCreate(comm,&mat);
374:        MatSetSizes(mat,m,n,M,N);
375:        MatSetType(mat,MATCOMPOSITE);
376:        MatCompositeAddMat(mat,mats[0]);
377:        ....
378:        MatCompositeAddMat(mat,mats[nmat-1]);
379:        MatAssemblyBegin(mat,MAT_FINAL_ASSEMBLY);
380:        MatAssemblyEnd(mat,MAT_FINAL_ASSEMBLY);
381: .ve

383:   For the multiplicative form the product is mat[nmat-1]*mat[nmat-2]*....*mat[0]

385: .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCompositeGetMat()`, `MatCompositeMerge()`, `MatCompositeSetType()`,
386:           `MATCOMPOSITE`, `MatCompositeType`
387: @*/
388: PetscErrorCode MatCreateComposite(MPI_Comm comm, PetscInt nmat, const Mat *mats, Mat *mat)
389: {
390:   PetscInt m, n, M, N, i;

392:   PetscFunctionBegin;
393:   PetscCheck(nmat >= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Must pass in at least one matrix");
394:   PetscAssertPointer(mat, 4);

396:   PetscCall(MatGetLocalSize(mats[0], PETSC_IGNORE, &n));
397:   PetscCall(MatGetLocalSize(mats[nmat - 1], &m, PETSC_IGNORE));
398:   PetscCall(MatGetSize(mats[0], PETSC_IGNORE, &N));
399:   PetscCall(MatGetSize(mats[nmat - 1], &M, PETSC_IGNORE));
400:   PetscCall(MatCreate(comm, mat));
401:   PetscCall(MatSetSizes(*mat, m, n, M, N));
402:   PetscCall(MatSetType(*mat, MATCOMPOSITE));
403:   for (i = 0; i < nmat; i++) PetscCall(MatCompositeAddMat(*mat, mats[i]));
404:   PetscCall(MatAssemblyBegin(*mat, MAT_FINAL_ASSEMBLY));
405:   PetscCall(MatAssemblyEnd(*mat, MAT_FINAL_ASSEMBLY));
406:   PetscFunctionReturn(PETSC_SUCCESS);
407: }

409: static PetscErrorCode MatCompositeAddMat_Composite(Mat mat, Mat smat)
410: {
411:   Mat_Composite    *shell;
412:   Mat_CompositeLink ilink, next;
413:   VecType           vtype_mat, vtype_smat;
414:   PetscBool         match;

416:   PetscFunctionBegin;
417:   PetscCall(MatShellGetContext(mat, &shell));
418:   next = shell->head;
419:   PetscCall(PetscNew(&ilink));
420:   ilink->next = NULL;
421:   PetscCall(PetscObjectReference((PetscObject)smat));
422:   ilink->mat = smat;

424:   if (!next) shell->head = ilink;
425:   else {
426:     while (next->next) next = next->next;
427:     next->next  = ilink;
428:     ilink->prev = next;
429:   }
430:   shell->tail = ilink;
431:   shell->nmat += 1;

433:   /* If all of the partial matrices have the same default vector type, then the composite matrix should also have this default type.
434:      Otherwise, the default type should be "standard". */
435:   PetscCall(MatGetVecType(smat, &vtype_smat));
436:   if (shell->nmat == 1) PetscCall(MatSetVecType(mat, vtype_smat));
437:   else {
438:     PetscCall(MatGetVecType(mat, &vtype_mat));
439:     PetscCall(PetscStrcmp(vtype_smat, vtype_mat, &match));
440:     if (!match) PetscCall(MatSetVecType(mat, VECSTANDARD));
441:   }

443:   /* Retain the old scalings (if any) and expand it with a 1.0 for the newly added matrix */
444:   if (shell->scalings) {
445:     PetscCall(PetscRealloc(sizeof(PetscScalar) * shell->nmat, &shell->scalings));
446:     shell->scalings[shell->nmat - 1] = 1.0;
447:   }
448:   PetscFunctionReturn(PETSC_SUCCESS);
449: }

451: /*@
452:   MatCompositeAddMat - Add another matrix to a composite matrix.

454:   Collective

456:   Input Parameters:
457: + mat  - the composite matrix
458: - smat - the partial matrix

460:   Level: advanced

462: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
463: @*/
464: PetscErrorCode MatCompositeAddMat(Mat mat, Mat smat)
465: {
466:   PetscFunctionBegin;
469:   PetscUseMethod(mat, "MatCompositeAddMat_C", (Mat, Mat), (mat, smat));
470:   PetscFunctionReturn(PETSC_SUCCESS);
471: }

473: static PetscErrorCode MatCompositeSetType_Composite(Mat mat, MatCompositeType type)
474: {
475:   Mat_Composite *b;

477:   PetscFunctionBegin;
478:   PetscCall(MatShellGetContext(mat, &b));
479:   b->type = type;
480:   if (type == MAT_COMPOSITE_MULTIPLICATIVE) {
481:     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, NULL));
482:     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite_Multiplicative));
483:     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite_Multiplicative));
484:     b->merge_mvctx = PETSC_FALSE;
485:   } else {
486:     PetscCall(MatShellSetOperation(mat, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite));
487:     PetscCall(MatShellSetOperation(mat, MATOP_MULT, (void (*)(void))MatMult_Composite));
488:     PetscCall(MatShellSetOperation(mat, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite));
489:   }
490:   PetscFunctionReturn(PETSC_SUCCESS);
491: }

493: /*@
494:   MatCompositeSetType - Indicates if the matrix is defined as the sum of a set of matrices or the product.

496:   Logically Collective

498:   Input Parameters:
499: + mat  - the composite matrix
500: - type - the `MatCompositeType` to use for the matrix

502:   Level: advanced

504: .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeGetType()`, `MATCOMPOSITE`,
505:           `MatCompositeType`
506: @*/
507: PetscErrorCode MatCompositeSetType(Mat mat, MatCompositeType type)
508: {
509:   PetscFunctionBegin;
512:   PetscUseMethod(mat, "MatCompositeSetType_C", (Mat, MatCompositeType), (mat, type));
513:   PetscFunctionReturn(PETSC_SUCCESS);
514: }

516: static PetscErrorCode MatCompositeGetType_Composite(Mat mat, MatCompositeType *type)
517: {
518:   Mat_Composite *shell;

520:   PetscFunctionBegin;
521:   PetscCall(MatShellGetContext(mat, &shell));
522:   *type = shell->type;
523:   PetscFunctionReturn(PETSC_SUCCESS);
524: }

526: /*@
527:   MatCompositeGetType - Returns type of composite.

529:   Not Collective

531:   Input Parameter:
532: . mat - the composite matrix

534:   Output Parameter:
535: . type - type of composite

537:   Level: advanced

539: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetType()`, `MATCOMPOSITE`, `MatCompositeType`
540: @*/
541: PetscErrorCode MatCompositeGetType(Mat mat, MatCompositeType *type)
542: {
543:   PetscFunctionBegin;
545:   PetscAssertPointer(type, 2);
546:   PetscUseMethod(mat, "MatCompositeGetType_C", (Mat, MatCompositeType *), (mat, type));
547:   PetscFunctionReturn(PETSC_SUCCESS);
548: }

550: static PetscErrorCode MatCompositeSetMatStructure_Composite(Mat mat, MatStructure str)
551: {
552:   Mat_Composite *shell;

554:   PetscFunctionBegin;
555:   PetscCall(MatShellGetContext(mat, &shell));
556:   shell->structure = str;
557:   PetscFunctionReturn(PETSC_SUCCESS);
558: }

560: /*@
561:   MatCompositeSetMatStructure - Indicates structure of matrices in the composite matrix.

563:   Not Collective

565:   Input Parameters:
566: + mat - the composite matrix
567: - str - either `SAME_NONZERO_PATTERN`, `DIFFERENT_NONZERO_PATTERN` (default) or `SUBSET_NONZERO_PATTERN`

569:   Level: advanced

571:   Note:
572:   Information about the matrices structure is used in `MatCompositeMerge()` for additive composite matrix.

574: .seealso: [](ch_matrices), `Mat`, `MatAXPY()`, `MatCreateComposite()`, `MatCompositeMerge()` `MatCompositeGetMatStructure()`, `MATCOMPOSITE`
575: @*/
576: PetscErrorCode MatCompositeSetMatStructure(Mat mat, MatStructure str)
577: {
578:   PetscFunctionBegin;
580:   PetscUseMethod(mat, "MatCompositeSetMatStructure_C", (Mat, MatStructure), (mat, str));
581:   PetscFunctionReturn(PETSC_SUCCESS);
582: }

584: static PetscErrorCode MatCompositeGetMatStructure_Composite(Mat mat, MatStructure *str)
585: {
586:   Mat_Composite *shell;

588:   PetscFunctionBegin;
589:   PetscCall(MatShellGetContext(mat, &shell));
590:   *str = shell->structure;
591:   PetscFunctionReturn(PETSC_SUCCESS);
592: }

594: /*@
595:   MatCompositeGetMatStructure - Returns the structure of matrices in the composite matrix.

597:   Not Collective

599:   Input Parameter:
600: . mat - the composite matrix

602:   Output Parameter:
603: . str - structure of the matrices

605:   Level: advanced

607: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MATCOMPOSITE`
608: @*/
609: PetscErrorCode MatCompositeGetMatStructure(Mat mat, MatStructure *str)
610: {
611:   PetscFunctionBegin;
613:   PetscAssertPointer(str, 2);
614:   PetscUseMethod(mat, "MatCompositeGetMatStructure_C", (Mat, MatStructure *), (mat, str));
615:   PetscFunctionReturn(PETSC_SUCCESS);
616: }

618: static PetscErrorCode MatCompositeSetMergeType_Composite(Mat mat, MatCompositeMergeType type)
619: {
620:   Mat_Composite *shell;

622:   PetscFunctionBegin;
623:   PetscCall(MatShellGetContext(mat, &shell));
624:   shell->mergetype = type;
625:   PetscFunctionReturn(PETSC_SUCCESS);
626: }

628: /*@
629:   MatCompositeSetMergeType - Sets order of `MatCompositeMerge()`.

631:   Logically Collective

633:   Input Parameters:
634: + mat  - the composite matrix
635: - type - `MAT_COMPOSITE_MERGE RIGHT` (default) to start merge from right with the first added matrix (mat[0]),
636:           `MAT_COMPOSITE_MERGE_LEFT` to start merge from left with the last added matrix (mat[nmat-1])

638:   Level: advanced

640:   Note:
641:   The resulting matrix is the same regardless of the `MatCompositeMergeType`. Only the order of operation is changed.
642:   If set to `MAT_COMPOSITE_MERGE_RIGHT` the order of the merge is mat[nmat-1]*(mat[nmat-2]*(...*(mat[1]*mat[0])))
643:   otherwise the order is (((mat[nmat-1]*mat[nmat-2])*mat[nmat-3])*...)*mat[0].

645: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeMerge()`, `MATCOMPOSITE`
646: @*/
647: PetscErrorCode MatCompositeSetMergeType(Mat mat, MatCompositeMergeType type)
648: {
649:   PetscFunctionBegin;
652:   PetscUseMethod(mat, "MatCompositeSetMergeType_C", (Mat, MatCompositeMergeType), (mat, type));
653:   PetscFunctionReturn(PETSC_SUCCESS);
654: }

656: static PetscErrorCode MatCompositeMerge_Composite(Mat mat)
657: {
658:   Mat_Composite    *shell;
659:   Mat_CompositeLink next, prev;
660:   Mat               tmat, newmat;
661:   Vec               left, right, dshift;
662:   PetscScalar       scale, shift;
663:   PetscInt          i;

665:   PetscFunctionBegin;
666:   PetscCall(MatShellGetContext(mat, &shell));
667:   next = shell->head;
668:   prev = shell->tail;
669:   PetscCheck(next, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Must provide at least one matrix with MatCompositeAddMat()");
670:   PetscCheck(!((Mat_Shell *)mat->data)->zrows && !((Mat_Shell *)mat->data)->zcols, PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "Cannot call MatCompositeMerge() if MatZeroRows() or MatZeroRowsColumns() has been called on the input Mat"); // TODO FIXME: lift this limitation by calling MatZeroRows()/MatZeroRowsColumns() after the merge
671:   PetscCheck(!((Mat_Shell *)mat->data)->axpy, PetscObjectComm((PetscObject)mat), PETSC_ERR_SUP, "Cannot call MatCompositeMerge() if MatAXPY() has been called on the input Mat"); // TODO FIXME: lift this limitation by calling MatAXPY() after the merge
672:   scale = ((Mat_Shell *)mat->data)->vscale;
673:   shift = ((Mat_Shell *)mat->data)->vshift;
674:   if (shell->type == MAT_COMPOSITE_ADDITIVE) {
675:     if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
676:       i = 0;
677:       PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
678:       if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i++]));
679:       while ((next = next->next)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i++] : 1.0), next->mat, shell->structure));
680:     } else {
681:       i = shell->nmat - 1;
682:       PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
683:       if (shell->scalings) PetscCall(MatScale(tmat, shell->scalings[i--]));
684:       while ((prev = prev->prev)) PetscCall(MatAXPY(tmat, (shell->scalings ? shell->scalings[i--] : 1.0), prev->mat, shell->structure));
685:     }
686:   } else {
687:     if (shell->mergetype == MAT_COMPOSITE_MERGE_RIGHT) {
688:       PetscCall(MatDuplicate(next->mat, MAT_COPY_VALUES, &tmat));
689:       while ((next = next->next)) {
690:         PetscCall(MatMatMult(next->mat, tmat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat));
691:         PetscCall(MatDestroy(&tmat));
692:         tmat = newmat;
693:       }
694:     } else {
695:       PetscCall(MatDuplicate(prev->mat, MAT_COPY_VALUES, &tmat));
696:       while ((prev = prev->prev)) {
697:         PetscCall(MatMatMult(tmat, prev->mat, MAT_INITIAL_MATRIX, PETSC_DECIDE, &newmat));
698:         PetscCall(MatDestroy(&tmat));
699:         tmat = newmat;
700:       }
701:     }
702:     if (shell->scalings) {
703:       for (i = 0; i < shell->nmat; i++) scale *= shell->scalings[i];
704:     }
705:   }

707:   if ((left = ((Mat_Shell *)mat->data)->left)) PetscCall(PetscObjectReference((PetscObject)left));
708:   if ((right = ((Mat_Shell *)mat->data)->right)) PetscCall(PetscObjectReference((PetscObject)right));
709:   if ((dshift = ((Mat_Shell *)mat->data)->dshift)) PetscCall(PetscObjectReference((PetscObject)dshift));

711:   PetscCall(MatHeaderReplace(mat, &tmat));

713:   PetscCall(MatDiagonalScale(mat, left, right));
714:   PetscCall(MatScale(mat, scale));
715:   PetscCall(MatShift(mat, shift));
716:   PetscCall(VecDestroy(&left));
717:   PetscCall(VecDestroy(&right));
718:   if (dshift) {
719:     PetscCall(MatDiagonalSet(mat, dshift, ADD_VALUES));
720:     PetscCall(VecDestroy(&dshift));
721:   }
722:   PetscFunctionReturn(PETSC_SUCCESS);
723: }

725: /*@
726:   MatCompositeMerge - Given a composite matrix, replaces it with a "regular" matrix
727:   by summing or computing the product of all the matrices inside the composite matrix.

729:   Collective

731:   Input Parameter:
732: . mat - the composite matrix

734:   Options Database Keys:
735: + -mat_composite_merge      - merge in `MatAssemblyEnd()`
736: - -mat_composite_merge_type - set merge direction

738:   Level: advanced

740:   Note:
741:   The `MatType` of the resulting matrix will be the same as the `MatType` of the FIRST matrix in the composite matrix.

743: .seealso: [](ch_matrices), `Mat`, `MatDestroy()`, `MatMult()`, `MatCompositeAddMat()`, `MatCreateComposite()`, `MatCompositeSetMatStructure()`, `MatCompositeSetMergeType()`, `MATCOMPOSITE`
744: @*/
745: PetscErrorCode MatCompositeMerge(Mat mat)
746: {
747:   PetscFunctionBegin;
749:   PetscUseMethod(mat, "MatCompositeMerge_C", (Mat), (mat));
750:   PetscFunctionReturn(PETSC_SUCCESS);
751: }

753: static PetscErrorCode MatCompositeGetNumberMat_Composite(Mat mat, PetscInt *nmat)
754: {
755:   Mat_Composite *shell;

757:   PetscFunctionBegin;
758:   PetscCall(MatShellGetContext(mat, &shell));
759:   *nmat = shell->nmat;
760:   PetscFunctionReturn(PETSC_SUCCESS);
761: }

763: /*@
764:   MatCompositeGetNumberMat - Returns the number of matrices in the composite matrix.

766:   Not Collective

768:   Input Parameter:
769: . mat - the composite matrix

771:   Output Parameter:
772: . nmat - number of matrices in the composite matrix

774:   Level: advanced

776: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetMat()`, `MATCOMPOSITE`
777: @*/
778: PetscErrorCode MatCompositeGetNumberMat(Mat mat, PetscInt *nmat)
779: {
780:   PetscFunctionBegin;
782:   PetscAssertPointer(nmat, 2);
783:   PetscUseMethod(mat, "MatCompositeGetNumberMat_C", (Mat, PetscInt *), (mat, nmat));
784:   PetscFunctionReturn(PETSC_SUCCESS);
785: }

787: static PetscErrorCode MatCompositeGetMat_Composite(Mat mat, PetscInt i, Mat *Ai)
788: {
789:   Mat_Composite    *shell;
790:   Mat_CompositeLink ilink;
791:   PetscInt          k;

793:   PetscFunctionBegin;
794:   PetscCall(MatShellGetContext(mat, &shell));
795:   PetscCheck(i < shell->nmat, PetscObjectComm((PetscObject)mat), PETSC_ERR_ARG_OUTOFRANGE, "index out of range: %" PetscInt_FMT " >= %" PetscInt_FMT, i, shell->nmat);
796:   ilink = shell->head;
797:   for (k = 0; k < i; k++) ilink = ilink->next;
798:   *Ai = ilink->mat;
799:   PetscFunctionReturn(PETSC_SUCCESS);
800: }

802: /*@
803:   MatCompositeGetMat - Returns the ith matrix from the composite matrix.

805:   Logically Collective

807:   Input Parameters:
808: + mat - the composite matrix
809: - i   - the number of requested matrix

811:   Output Parameter:
812: . Ai - ith matrix in composite

814:   Level: advanced

816: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeGetNumberMat()`, `MatCompositeAddMat()`, `MATCOMPOSITE`
817: @*/
818: PetscErrorCode MatCompositeGetMat(Mat mat, PetscInt i, Mat *Ai)
819: {
820:   PetscFunctionBegin;
823:   PetscAssertPointer(Ai, 3);
824:   PetscUseMethod(mat, "MatCompositeGetMat_C", (Mat, PetscInt, Mat *), (mat, i, Ai));
825:   PetscFunctionReturn(PETSC_SUCCESS);
826: }

828: static PetscErrorCode MatCompositeSetScalings_Composite(Mat mat, const PetscScalar *scalings)
829: {
830:   Mat_Composite *shell;
831:   PetscInt       nmat;

833:   PetscFunctionBegin;
834:   PetscCall(MatShellGetContext(mat, &shell));
835:   PetscCall(MatCompositeGetNumberMat(mat, &nmat));
836:   if (!shell->scalings) PetscCall(PetscMalloc1(nmat, &shell->scalings));
837:   PetscCall(PetscArraycpy(shell->scalings, scalings, nmat));
838:   PetscFunctionReturn(PETSC_SUCCESS);
839: }

841: /*@
842:   MatCompositeSetScalings - Sets separate scaling factors for component matrices.

844:   Logically Collective

846:   Input Parameters:
847: + mat      - the composite matrix
848: - scalings - array of scaling factors with scalings[i] being factor of i-th matrix, for i in [0, nmat)

850:   Level: advanced

852: .seealso: [](ch_matrices), `Mat`, `MatScale()`, `MatDiagonalScale()`, `MATCOMPOSITE`
853: @*/
854: PetscErrorCode MatCompositeSetScalings(Mat mat, const PetscScalar *scalings)
855: {
856:   PetscFunctionBegin;
858:   PetscAssertPointer(scalings, 2);
860:   PetscUseMethod(mat, "MatCompositeSetScalings_C", (Mat, const PetscScalar *), (mat, scalings));
861:   PetscFunctionReturn(PETSC_SUCCESS);
862: }

864: /*MC
865:    MATCOMPOSITE - A matrix defined by the sum (or product) of one or more matrices.
    The matrices must have compatible sizes and parallel layouts for the sum or product to be valid.

868:   Level: advanced

870:    Note:
   To use the product of the matrices instead of the sum, call `MatCompositeSetType`(mat, `MAT_COMPOSITE_MULTIPLICATIVE`).

873:   Developer Notes:
874:   This is implemented on top of `MATSHELL` to get support for scaling and shifting without requiring duplicate code

  Users cannot call `MatShellSetOperation()` on this class; there is some error checking to catch that incorrect usage.

878: .seealso: [](ch_matrices), `Mat`, `MatCreateComposite()`, `MatCompositeSetScalings()`, `MatCompositeAddMat()`, `MatSetType()`, `MatCompositeSetType()`, `MatCompositeGetType()`,
879:           `MatCompositeSetMatStructure()`, `MatCompositeGetMatStructure()`, `MatCompositeMerge()`, `MatCompositeSetMergeType()`, `MatCompositeGetNumberMat()`, `MatCompositeGetMat()`
880: M*/

882: PETSC_EXTERN PetscErrorCode MatCreate_Composite(Mat A)
883: {
884:   Mat_Composite *b;

886:   PetscFunctionBegin;
887:   PetscCall(PetscNew(&b));

889:   b->type        = MAT_COMPOSITE_ADDITIVE;
890:   b->nmat        = 0;
891:   b->merge       = PETSC_FALSE;
892:   b->mergetype   = MAT_COMPOSITE_MERGE_RIGHT;
893:   b->structure   = DIFFERENT_NONZERO_PATTERN;
894:   b->merge_mvctx = PETSC_TRUE;

896:   PetscCall(MatSetType(A, MATSHELL));
897:   PetscCall(MatShellSetContext(A, b));
898:   PetscCall(MatShellSetOperation(A, MATOP_DESTROY, (void (*)(void))MatDestroy_Composite));
899:   PetscCall(MatShellSetOperation(A, MATOP_MULT, (void (*)(void))MatMult_Composite));
900:   PetscCall(MatShellSetOperation(A, MATOP_MULT_TRANSPOSE, (void (*)(void))MatMultTranspose_Composite));
901:   PetscCall(MatShellSetOperation(A, MATOP_GET_DIAGONAL, (void (*)(void))MatGetDiagonal_Composite));
902:   PetscCall(MatShellSetOperation(A, MATOP_ASSEMBLY_END, (void (*)(void))MatAssemblyEnd_Composite));
903:   PetscCall(MatShellSetOperation(A, MATOP_SET_FROM_OPTIONS, (void (*)(void))MatSetFromOptions_Composite));
904:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeAddMat_C", MatCompositeAddMat_Composite));
905:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetType_C", MatCompositeSetType_Composite));
906:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetType_C", MatCompositeGetType_Composite));
907:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMergeType_C", MatCompositeSetMergeType_Composite));
908:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetMatStructure_C", MatCompositeSetMatStructure_Composite));
909:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMatStructure_C", MatCompositeGetMatStructure_Composite));
910:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeMerge_C", MatCompositeMerge_Composite));
911:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetNumberMat_C", MatCompositeGetNumberMat_Composite));
912:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeGetMat_C", MatCompositeGetMat_Composite));
913:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatCompositeSetScalings_C", MatCompositeSetScalings_Composite));
914:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContext_C", MatShellSetContext_Immutable));
915:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetContextDestroy_C", MatShellSetContextDestroy_Immutable));
916:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatShellSetManageScalingShifts_C", MatShellSetManageScalingShifts_Immutable));
917:   PetscCall(PetscObjectChangeTypeName((PetscObject)A, MATCOMPOSITE));
918:   PetscFunctionReturn(PETSC_SUCCESS);
919: }