Actual source code: submat.c

  1: #define PETSCMAT_DLL

 3:  #include private/matimpl.h

  5: typedef struct {
  6:   IS isrow,iscol;               /* rows and columns in submatrix, only used to check consistency */
  7:   Vec left,right;               /* optional scaling */
  8:   Vec olwork,orwork;            /* work vectors outside the scatters, only touched by PreScale and only created if needed*/
  9:   Vec lwork,rwork;              /* work vectors inside the scatters */
 10:   VecScatter lrestrict,rprolong;
 11:   Mat A;
 12:   PetscScalar scale;
 13: } Mat_SubMatrix;

 17: static PetscErrorCode PreScaleLeft(Mat N,Vec x,Vec *xx)
 18: {
 19:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 23:   if (!Na->left) {
 24:     *xx = x;
 25:   } else {
 26:     if (!Na->olwork) {
 27:       VecDuplicate(Na->left,&Na->olwork);
 28:     }
 29:     VecPointwiseMult(Na->olwork,x,Na->left);
 30:     *xx = Na->olwork;
 31:   }
 32:   return(0);
 33: }

 37: static PetscErrorCode PreScaleRight(Mat N,Vec x,Vec *xx)
 38: {
 39:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 43:   if (!Na->right) {
 44:     *xx = x;
 45:   } else {
 46:     if (!Na->orwork) {
 47:       VecDuplicate(Na->right,&Na->orwork);
 48:     }
 49:     VecPointwiseMult(Na->orwork,x,Na->right);
 50:     *xx = Na->orwork;
 51:   }
 52:   return(0);
 53: }

 57: static PetscErrorCode PostScaleLeft(Mat N,Vec x)
 58: {
 59:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 63:   if (Na->left) {
 64:     VecPointwiseMult(x,x,Na->left);
 65:   }
 66:   return(0);
 67: }

 71: static PetscErrorCode PostScaleRight(Mat N,Vec x)
 72: {
 73:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 77:   if (Na->right) {
 78:     VecPointwiseMult(x,x,Na->right);
 79:   }
 80:   return(0);
 81: }

 85: static PetscErrorCode MatScale_SubMatrix(Mat N,PetscScalar scale)
 86: {
 87:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

 90:   Na->scale *= scale;
 91:   return(0);
 92: }

 96: static PetscErrorCode MatDiagonalScale_SubMatrix(Mat N,Vec left,Vec right)
 97: {
 98:   Mat_SubMatrix *Na = (Mat_SubMatrix*)N->data;

102:   if (left) {
103:     if (!Na->left) {
104:       VecDuplicate(left,&Na->left);
105:       VecCopy(left,Na->left);
106:     } else {
107:       VecPointwiseMult(Na->left,left,Na->left);
108:     }
109:   }
110:   if (right) {
111:     if (!Na->right) {
112:       VecDuplicate(right,&Na->right);
113:       VecCopy(right,Na->right);
114:     } else {
115:       VecPointwiseMult(Na->right,right,Na->right);
116:     }
117:   }
118:   return(0);
119: }

123: static PetscErrorCode MatMult_SubMatrix(Mat N,Vec x,Vec y)
124: {
125:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
126:   Vec             xx=0;
127:   PetscErrorCode  ierr;

130:   PreScaleRight(N,x,&xx);
131:   VecZeroEntries(Na->rwork);
132:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
133:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
134:   MatMult(Na->A,Na->rwork,Na->lwork);
135:   VecScatterBegin(Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
136:   VecScatterEnd  (Na->lrestrict,Na->lwork,y,INSERT_VALUES,SCATTER_FORWARD);
137:   PostScaleLeft(N,y);
138:   VecScale(y,Na->scale);
139:   return(0);
140: }

144: static PetscErrorCode MatMultAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
145: {
146:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
147:   Vec             xx=0;
148:   PetscErrorCode  ierr;

151:   PreScaleRight(N,v1,&xx);
152:   VecZeroEntries(Na->rwork);
153:   VecScatterBegin(Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
154:   VecScatterEnd  (Na->rprolong,xx,Na->rwork,INSERT_VALUES,SCATTER_FORWARD);
155:   MatMult(Na->A,Na->rwork,Na->lwork);
156:   if (v2 == v3) {
157:     if (Na->scale == 1.0 && !Na->left) {
158:       VecScatterBegin(Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);
159:       VecScatterEnd  (Na->lrestrict,Na->lwork,v3,ADD_VALUES,SCATTER_FORWARD);
160:     } else {
161:       if (!Na->olwork) {VecDuplicate(v3,&Na->olwork);}
162:       VecScatterBegin(Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
163:       VecScatterEnd  (Na->lrestrict,Na->lwork,Na->olwork,INSERT_VALUES,SCATTER_FORWARD);
164:       PostScaleLeft(N,Na->olwork);
165:       VecAXPY(v3,Na->scale,Na->olwork);
166:     }
167:   } else {
168:     VecScatterBegin(Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
169:     VecScatterEnd  (Na->lrestrict,Na->lwork,v3,INSERT_VALUES,SCATTER_FORWARD);
170:     PostScaleLeft(N,v3);
171:     VecAYPX(v3,Na->scale,v2);
172:   }
173:   return(0);
174: }

178: static PetscErrorCode MatMultTranspose_SubMatrix(Mat N,Vec x,Vec y)
179: {
180:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
181:   Vec             xx=0;

185:   PreScaleLeft(N,x,&xx);
186:   VecZeroEntries(Na->lwork);
187:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
188:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
189:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
190:   VecScatterBegin(Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
191:   VecScatterEnd  (Na->rprolong,Na->rwork,y,INSERT_VALUES,SCATTER_REVERSE);
192:   PostScaleRight(N,y);
193:   VecScale(y,Na->scale);
194:   return(0);
195: }

199: static PetscErrorCode MatMultTransposeAdd_SubMatrix(Mat N,Vec v1,Vec v2,Vec v3)
200: {
201:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;
202:   Vec             xx =0;

206:   PreScaleLeft(N,v1,&xx);
207:   VecZeroEntries(Na->lwork);
208:   VecScatterBegin(Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
209:   VecScatterEnd  (Na->lrestrict,xx,Na->lwork,INSERT_VALUES,SCATTER_REVERSE);
210:   MatMultTranspose(Na->A,Na->lwork,Na->rwork);
211:   if (v2 == v3) {
212:     if (Na->scale == 1.0 && !Na->right) {
213:       VecScatterBegin(Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);
214:       VecScatterEnd  (Na->rprolong,Na->rwork,v3,ADD_VALUES,SCATTER_REVERSE);
215:     } else {
216:       if (!Na->orwork) {VecDuplicate(v3,&Na->orwork);}
217:       VecScatterBegin(Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
218:       VecScatterEnd  (Na->rprolong,Na->rwork,Na->orwork,INSERT_VALUES,SCATTER_REVERSE);
219:       PostScaleRight(N,Na->orwork);
220:       VecAXPY(v3,Na->scale,Na->orwork);
221:     }
222:   } else {
223:     VecScatterBegin(Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
224:     VecScatterEnd  (Na->rprolong,Na->rwork,v3,INSERT_VALUES,SCATTER_REVERSE);
225:     PostScaleRight(N,v3);
226:     VecAYPX(v3,Na->scale,v2);
227:   }
228:   return(0);
229: }

233: static PetscErrorCode MatDestroy_SubMatrix(Mat N)
234: {
235:   Mat_SubMatrix  *Na = (Mat_SubMatrix*)N->data;

239:   ISDestroy(Na->isrow);
240:   ISDestroy(Na->iscol);
241:   if (Na->left) {VecDestroy(Na->left);}
242:   if (Na->right) {VecDestroy(Na->right);}
243:   if (Na->olwork) {VecDestroy(Na->olwork);}
244:   if (Na->orwork) {VecDestroy(Na->orwork);}
245:   VecDestroy(Na->lwork);
246:   VecDestroy(Na->rwork);
247:   VecScatterDestroy(Na->lrestrict);
248:   VecScatterDestroy(Na->rprolong);
249:   MatDestroy(Na->A);
250:   PetscFree(Na);
251:   return(0);
252: }

256: /*@
257:    MatCreateSubMatrix - Creates a composite matrix that acts as a submatrix

259:    Collective on Mat

261:    Input Parameters:
262: +  A - matrix that we will extract a submatrix of
263: .  isrow - rows to be present in the submatrix
264: -  iscol - columns to be present in the submatrix

266:    Output Parameters:
267: .  newmat - new matrix

269:    Level: developer

271:    Notes:
272:    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.

274: .seealso: MatGetSubMatrix(), MatSubMatrixUpdate()
275: @*/
276: PetscErrorCode  MatCreateSubMatrix(Mat A,IS isrow,IS iscol,Mat *newmat)
277: {
278:   Vec            left,right;
279:   PetscInt       m,n;
280:   Mat            N;
281:   Mat_SubMatrix *Na;

289:   *newmat = 0;

291:   MatCreate(((PetscObject)A)->comm,&N);
292:   ISGetLocalSize(isrow,&m);
293:   ISGetLocalSize(iscol,&n);
294:   MatSetSizes(N,m,n,PETSC_DETERMINE,PETSC_DETERMINE);
295:   PetscObjectChangeTypeName((PetscObject)N,MATSUBMATRIX);

297:   PetscNewLog(N,Mat_SubMatrix,&Na);
298:   N->data   = (void*)Na;
299:   PetscObjectReference((PetscObject)A);
300:   PetscObjectReference((PetscObject)isrow);
301:   PetscObjectReference((PetscObject)iscol);
302:   Na->A     = A;
303:   Na->isrow = isrow;
304:   Na->iscol = iscol;
305:   Na->scale = 1.0;

307:   N->ops->destroy          = MatDestroy_SubMatrix;
308:   N->ops->mult             = MatMult_SubMatrix;
309:   N->ops->multadd          = MatMultAdd_SubMatrix;
310:   N->ops->multtranspose    = MatMultTranspose_SubMatrix;
311:   N->ops->multtransposeadd = MatMultTransposeAdd_SubMatrix;
312:   N->ops->scale            = MatScale_SubMatrix;
313:   N->ops->diagonalscale    = MatDiagonalScale_SubMatrix;

315:   N->assembled = PETSC_TRUE;

317:   PetscLayoutSetBlockSize(N->rmap,A->rmap->bs);
318:   PetscLayoutSetBlockSize(N->cmap,A->cmap->bs);
319:   PetscLayoutSetUp(N->rmap);
320:   PetscLayoutSetUp(N->cmap);

322:   MatGetVecs(A,&Na->rwork,&Na->lwork);
323:   VecCreate(((PetscObject)isrow)->comm,&left);
324:   VecCreate(((PetscObject)iscol)->comm,&right);
325:   VecSetSizes(left,m,PETSC_DETERMINE);
326:   VecSetSizes(right,n,PETSC_DETERMINE);
327:   VecSetUp(left);
328:   VecSetUp(right);
329:   VecScatterCreate(Na->lwork,isrow,left,PETSC_NULL,&Na->lrestrict);
330:   VecScatterCreate(right,PETSC_NULL,Na->rwork,iscol,&Na->rprolong);
331:   VecDestroy(left);
332:   VecDestroy(right);

334:   *newmat = N;
335:   return(0);
336: }


341: /*@
342:    MatSubMatrixUpdate - Updates a submatrix

344:    Collective on Mat

346:    Input Parameters:
347: +  N - submatrix to update
348: .  A - full matrix in the submatrix
349: .  isrow - rows in the update (same as the first time the submatrix was created)
350: -  iscol - columns in the update (same as the first time the submatrix was created)

352:    Level: developer

354:    Notes:
355:    Most will use MatGetSubMatrix which provides a more efficient representation if it is available.

357: .seealso: MatGetSubMatrix(), MatCreateSubMatrix()
358: @*/
359: PetscErrorCode  MatSubMatrixUpdate(Mat N,Mat A,IS isrow,IS iscol)
360: {
361:   PetscErrorCode  ierr;
362:   PetscTruth      flg;
363:   Mat_SubMatrix  *Na;

370:   PetscTypeCompare((PetscObject)N,MATSUBMATRIX,&flg);
371:   if (!flg) SETERRQ(PETSC_ERR_ARG_WRONG,"Matrix has wrong type");

373:   Na = (Mat_SubMatrix*)N->data;
374:   ISEqual(isrow,Na->isrow,&flg);
375:   if (!flg) SETERRQ(PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different row indices");
376:   ISEqual(iscol,Na->iscol,&flg);
377:   if (!flg) SETERRQ(PETSC_ERR_ARG_INCOMP,"Cannot update submatrix with different column indices");

379:   PetscObjectReference((PetscObject)A);
380:   MatDestroy(Na->A);
381:   Na->A = A;

383:   Na->scale = 1.0;
384:   if (Na->left) {VecDestroy(Na->left);}
385:   if (Na->right) {VecDestroy(Na->right);}
386:   return(0);
387: }