Actual source code: submat.c

petsc-3.11.3 2019-06-26
Report Typos and Errors

  2:  #include <petsc/private/matimpl.h>

  4: typedef struct {
  5:   IS          isrow,iscol;      /* rows and columns in submatrix, only used to check consistency */
  6:   Vec         left,right;       /* optional scaling */
  7:   Vec         olwork,orwork;    /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
  8:   Vec         lwork,rwork;      /* work vectors inside the scatters */
  9:   Vec         dshift;
 10:   VecScatter  lrestrict,rprolong;
 11:   Mat         A;
 12:   PetscScalar vscale, axpy_vscale;
 13:   PetscScalar vshift, axpy_vshift;
 14: } Mat_SubVirtual;

 16: static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
 17: {
 18:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 22:   if (!Na->left) {
 23:     *xx = x;
 24:   } else {
 25:     if (!Na->olwork) {
 26:       VecDuplicate(Na->left,&Na->olwork);
 27:     }
 28:     VecPointwiseMult(Na->olwork,x,Na->left);
 29:     *xx  = Na->olwork;
 30:   }
 31:   return(0);
 32: }

 34: static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
 35: {
 36:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 40:   if (!Na->right) {
 41:     *xx = x;
 42:   } else {
 43:     if (!Na->orwork) {
 44:       VecDuplicate(Na->right,&Na->orwork);
 45:     }
 46:     VecPointwiseMult(Na->orwork,x,Na->right);
 47:     *xx  = Na->orwork;
 48:   }
 49:   return(0);
 50: }

 52: static PetscErrorCode PostScaleLeft(Mat N,Vec x)
 53: {
 54:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 58:   if (Na->left) {
 59:     VecPointwiseMult(x,x,Na->left);
 60:   }
 61:   return(0);
 62: }

 64: static PetscErrorCode PostScaleRight(Mat N,Vec x)
 65: {
 66:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

 70:   if (Na->right) {
 71:     VecPointwiseMult(x,x,Na->right);
 72:   }
 73:   return(0);
 74: }

 76: /*
 77:          Y = vscale*Y + diag(dshift)*X + vshift*X

 79:          On input Y already contains A*x
 80: */
 81: static PetscErrorCode MatSubmatShiftAndScale(Mat A,Vec X,Vec Y)
 82: {
 83:   Mat_SubVirtual *Na = (Mat_SubVirtual*)A->data;

 87:   if (Na->dshift) {          /* get arrays because there is no VecPointwiseMultAdd() */
 88:     PetscInt          i,m;
 89:     const PetscScalar *x,*d;
 90:     PetscScalar       *y;
 91:     VecGetLocalSize(X,&m);
 92:     VecGetArrayRead(Na->dshift,&d);
 93:     VecGetArrayRead(X,&x);
 94:     VecGetArray(Y,&y);
 95:     for (i=0; i<m; i++) y[i] = Na->vscale*y[i] + d[i]*x[i];
 96:     VecRestoreArrayRead(Na->dshift,&d);
 97:     VecRestoreArrayRead(X,&x);
 98:     VecRestoreArray(Y,&y);
 99:   } else {
100:     VecScale(Y,Na->vscale);
101:   }
102:   if (Na->vshift != 0.0) {VecAXPY(Y,Na->vshift,X);} /* if test is for non-square matrices */
103:   return(0);
104: }

106: static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar a)
107: {
108:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

112:   Na->vscale *= a;
113:   Na->vshift *= a;
114:   if (Na->dshift) {
115:     VecScale(Na->dshift,a);
116:   }
117:   Na->axpy_vscale *= a;
118:   return(0);
119: }

121: static PetscErrorCode MatShift_SubMatrix(Mat N,PetscScalar a)
122: {
123:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

127:   if (Na->left || Na->right) {
128:     if (!Na->dshift) {
129:       VecDuplicate(Na->left ? Na->left : Na->right, &Na->dshift);
130:       VecSet(Na->dshift,a);
131:     } else {
132:       if (Na->left)  {VecPointwiseMult(Na->dshift,Na->dshift,Na->left);}
133:       if (Na->right) {VecPointwiseMult(Na->dshift,Na->dshift,Na->right);}
134:       VecShift(Na->dshift,a);
135:     }
136:     if (Na->left)  {VecPointwiseDivide(Na->dshift,Na->dshift,Na->left);}
137:     if (Na->right) {VecPointwiseDivide(Na->dshift,Na->dshift,Na->right);}
138:   } else Na->vshift += a;
139:   return(0);
140: }

142: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
143: {
144:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

148:   if (left) {
149:     if (!Na->left) {
150:       VecDuplicate(left,&Na->left);
151:       VecCopy(left,Na->left);
152:     } else {
153:       VecPointwiseMult(Na->left,left,Na->left);
154:     }
155:   }
156:   if (right) {
157:     if (!Na->right) {
158:       VecDuplicate(right,&Na->right);
159:       VecCopy(right,Na->right);
160:     } else {
161:       VecPointwiseMult(Na->right,right,Na->right);
162:     }
163:   }
164:   return(0);
165: }

167: static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
168: {
169:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
170:   Vec            xx  = 0;

174:   PreScaleRight(N,x,&xx);
175:   VecZeroEntries(Na->rwork);
176:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
177:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
178:   MatMult(Na->A,Na->rwork,Na->lwork);
179:   VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
180:   VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
181:   MatSubmatShiftAndScale(N,xx,y);
182:   PostScaleLeft(N,y);
183:   return(0);
184: }

186: static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
187: {
188:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
189:   Vec            xx  = 0;

193:   PreScaleRight(N,v1,&xx);
194:   VecZeroEntries(Na->rwork);
195:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
196:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
197:   MatMult(Na->A,Na->rwork,Na->lwork);
198:   if (v2 == v3) {
199:     if (!Na->olwork) {VecDuplicate(v3,&Na->olwork);}
200:     VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
201:     VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
202:     MatSubmatShiftAndScale(N,xx,Na->olwork);
203:     PostScaleLeft(N,Na->olwork);
204:     VecAXPY(v3,1.0,Na->olwork);
205:   } else {
206:     VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
207:     VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
208:     MatSubmatShiftAndScale(N,xx,v3);
209:     PostScaleLeft(N,v3);
210:     VecAXPY(v3,1.0,v2);
211:   }
212:   return(0);
213: }

215: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
216: {
217:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
218:   Vec            xx  = 0;

222:   PreScaleLeft(N,x,&xx);
223:   VecZeroEntries(Na->lwork);
224:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
225:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
226:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
227:   VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
228:   VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
229:   MatSubmatShiftAndScale(N,xx,y);
230:   PostScaleRight(N,y);
231:   return(0);
232: }

234: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
235: {
236:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;
237:   Vec            xx  = 0;

241:   PreScaleLeft(N,v1,&xx);
242:   VecZeroEntries(Na->lwork);
243:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
244:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
245:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
246:   if (v2 == v3) {
247:     if (!Na->orwork) {VecDuplicate(v3,&Na->orwork);}
248:     VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
249:     VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
250:     MatSubmatShiftAndScale(N,xx,Na->orwork);
251:     PostScaleRight(N,Na->orwork);
252:     VecAXPY(v3,1.0,Na->orwork);
253:   } else {
254:     VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
255:     VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
256:     MatSubmatShiftAndScale(N,xx,v3);
257:     PostScaleRight(N,v3);
258:     VecAXPY(v3,1.0,v2);
259:   }
260:   return(0);
261: }

263: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
264: {
265:   Mat_SubVirtual *Na = (Mat_SubVirtual*)N->data;

269:   ISDestroy(&Na->isrow);
270:   ISDestroy(&Na->iscol);
271:   VecDestroy(&Na->left);
272:   VecDestroy(&Na->right);
273:   VecDestroy(&Na->olwork);
274:   VecDestroy(&Na->orwork);
275:   VecDestroy(&Na->lwork);
276:   VecDestroy(&Na->rwork);
277:   VecDestroy(&Na->dshift);
278:   VecScatterDestroy(&Na->lrestrict);
279:   VecScatterDestroy(&Na->rprolong);
280:   MatDestroy(&Na->A);
281:   PetscFree(N->data);
282:   return(0);
283: }

285: /*@
286:    MatCreateSubMatrixVirtual - Creates a virtual matrix that acts as a submatrix

288:    Collective on Mat

290:    Input Parameters:
291: +  A - matrix that we will extract a submatrix of
292: .  isrow - rows to be present in the submatrix
293: -  iscol - columns to be present in the submatrix

295:    Output Parameters:
296: .  newmat - new matrix

298:    Level: developer

300:    Notes:
301:    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.

303: .seealso: MatCreateSubMatrix(), MatSubMatrixVirtualUpdate()
304: @*/
305: PetscErrorCode MatCreateSubMatrixVirtual(Mat A,IS isrow,IS iscol,Mat *newmat)
306: {
307:   Vec            left,right;
308:   PetscInt       m,n;
309:   Mat            N;
310:   Mat_SubVirtual *Na;

318:   *newmat = 0;

320:   MatCreate(PetscObjectComm((PetscObject)A),&N);
321:   ISGetLocalSize(isrow,&m);
322:   ISGetLocalSize(iscol,&n);
323:   MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);
324:   PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);

326:   PetscNewLog(N,&Na);
327:   N->data   = (void*)Na;
328:   PetscObjectReference((PetscObject)A);
329:   PetscObjectReference((PetscObject)isrow);
330:   PetscObjectReference((PetscObject)iscol);
331:   Na->A     = A;
332:   Na->isrow = isrow;
333:   Na->iscol = iscol;
334:   Na->vscale = 1.0;
335:   Na->vshift = 0.0;

337:   N->ops->destroy          = MatDestroy_SubMatrix;
338:   N->ops->mult             = MatMult_SubMatrix;
339:   N->ops->multadd          = MatMultAdd_SubMatrix;
340:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
341:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
342:   N->ops->scale            = MatScale_SubMatrix;
343:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;
344:   N->ops->shift            = MatShift_SubMatrix;

346:   MatSetBlockSizesFromMats(N,A,A);
347:   PetscLayoutSetUp(N->rmap);
348:   PetscLayoutSetUp(N->cmap);

350:   MatCreateVecs(A,&Na->rwork,&Na->lwork);
351:   VecCreate(PetscObjectComm((PetscObject)isrow),&left);
352:   VecCreate(PetscObjectComm((PetscObject)iscol),&right);
353:   VecSetSizes(left,m,PETSC_DETERMINE);
354:   VecSetSizes(right,n,PETSC_DETERMINE);
355:   VecSetUp(left);
356:   VecSetUp(right);
357:   VecScatterCreate(Na->lwork,isrow,left,NULL,&Na->lrestrict);
358:   VecScatterCreate(right,NULL,Na->rwork,iscol,&Na->rprolong);
359:   VecDestroy(&left);
360:   VecDestroy(&right);

362:   N->assembled = PETSC_TRUE;

364:   MatSetUp(N);

366:   *newmat      = N;
367:   return(0);
368: }


371: /*@
372:    MatSubMatrixVirtualUpdate - Updates a submatrix

374:    Collective on Mat

376:    Input Parameters:
377: +  N - submatrix to update
378: .  A - full matrix in the submatrix
379: .  isrow - rows in the update (same as the first time the submatrix was created)
380: -  iscol - columns in the update (same as the first time the submatrix was created)

382:    Level: developer

384:    Notes:
385:    Most will use MatCreateSubMatrix which provides a more efficient representation if it is available.

387: .seealso: MatCreateSubMatrixVirtual()
388: @*/
389: PetscErrorCode  MatSubMatrixVirtualUpdate(Mat N,Mat A,IS isrow,IS iscol)
390: {
392:   PetscBool      flg;
393:   Mat_SubVirtual *Na;

400:   PetscObjectTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);
401:   if (!flg) SETERRQ(PetscObjectComm((PetscObject)A),PETSC_ERR_ARG_WRONG,"Matrix has wrong type");

403:   Na   = (Mat_SubVirtual*)N->data;
404:   ISEqual(isrow,Na->isrow,&flg);
405:   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
406:   ISEqual(iscol,Na->iscol,&flg);
407:   if (!flg) SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");

409:   PetscObjectReference((PetscObject)A);
410:   MatDestroy(&Na->A);
411:   Na->A = A;

413:   Na->vshift = 0.0;
414:   Na->vscale = 1.0;
415:   VecDestroy(&Na->left);
416:   VecDestroy(&Na->right);
417:   return(0);
418: }