Actual source code: baijsolvtran3.c

  1: #include <../src/mat/impls/baij/seq/baij.h>
  2: #include <petsc/private/kernels/blockinvert.h>

  4: PetscErrorCode MatSolveTranspose_SeqBAIJ_3_inplace(Mat A, Vec bb, Vec xx)
  5: {
  6:   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
  7:   IS                 iscol = a->col, isrow = a->row;
  8:   const PetscInt    *r, *c, *rout, *cout;
  9:   const PetscInt    *diag = a->diag, n = a->mbs, *vi, *ai = a->i, *aj = a->j;
 10:   PetscInt           i, nz, idx, idt, ii, ic, ir, oidx;
 11:   const MatScalar   *aa = a->a, *v;
 12:   PetscScalar        s1, s2, s3, x1, x2, x3, *x, *t;
 13:   const PetscScalar *b;

 15:   PetscFunctionBegin;
 16:   PetscCall(VecGetArrayRead(bb, &b));
 17:   PetscCall(VecGetArray(xx, &x));
 18:   t = a->solve_work;

 20:   PetscCall(ISGetIndices(isrow, &rout));
 21:   r = rout;
 22:   PetscCall(ISGetIndices(iscol, &cout));
 23:   c = cout;

 25:   /* copy the b into temp work space according to permutation */
 26:   ii = 0;
 27:   for (i = 0; i < n; i++) {
 28:     ic        = 3 * c[i];
 29:     t[ii]     = b[ic];
 30:     t[ii + 1] = b[ic + 1];
 31:     t[ii + 2] = b[ic + 2];
 32:     ii += 3;
 33:   }

 35:   /* forward solve the U^T */
 36:   idx = 0;
 37:   for (i = 0; i < n; i++) {
 38:     v = aa + 9 * diag[i];
 39:     /* multiply by the inverse of the block diagonal */
 40:     x1 = t[idx];
 41:     x2 = t[1 + idx];
 42:     x3 = t[2 + idx];
 43:     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3;
 44:     s2 = v[3] * x1 + v[4] * x2 + v[5] * x3;
 45:     s3 = v[6] * x1 + v[7] * x2 + v[8] * x3;
 46:     v += 9;

 48:     vi = aj + diag[i] + 1;
 49:     nz = ai[i + 1] - diag[i] - 1;
 50:     while (nz--) {
 51:       oidx = 3 * (*vi++);
 52:       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
 53:       t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
 54:       t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
 55:       v += 9;
 56:     }
 57:     t[idx]     = s1;
 58:     t[1 + idx] = s2;
 59:     t[2 + idx] = s3;
 60:     idx += 3;
 61:   }
 62:   /* backward solve the L^T */
 63:   for (i = n - 1; i >= 0; i--) {
 64:     v   = aa + 9 * diag[i] - 9;
 65:     vi  = aj + diag[i] - 1;
 66:     nz  = diag[i] - ai[i];
 67:     idt = 3 * i;
 68:     s1  = t[idt];
 69:     s2  = t[1 + idt];
 70:     s3  = t[2 + idt];
 71:     while (nz--) {
 72:       idx = 3 * (*vi--);
 73:       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
 74:       t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
 75:       t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
 76:       v -= 9;
 77:     }
 78:   }

 80:   /* copy t into x according to permutation */
 81:   ii = 0;
 82:   for (i = 0; i < n; i++) {
 83:     ir        = 3 * r[i];
 84:     x[ir]     = t[ii];
 85:     x[ir + 1] = t[ii + 1];
 86:     x[ir + 2] = t[ii + 2];
 87:     ii += 3;
 88:   }

 90:   PetscCall(ISRestoreIndices(isrow, &rout));
 91:   PetscCall(ISRestoreIndices(iscol, &cout));
 92:   PetscCall(VecRestoreArrayRead(bb, &b));
 93:   PetscCall(VecRestoreArray(xx, &x));
 94:   PetscCall(PetscLogFlops(2.0 * 9 * (a->nz) - 3.0 * A->cmap->n));
 95:   PetscFunctionReturn(PETSC_SUCCESS);
 96: }

 98: PetscErrorCode MatSolveTranspose_SeqBAIJ_3(Mat A, Vec bb, Vec xx)
 99: {
100:   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
101:   IS                 iscol = a->col, isrow = a->row;
102:   const PetscInt     n = a->mbs, *vi, *ai = a->i, *aj = a->j, *diag = a->diag;
103:   const PetscInt    *r, *c, *rout, *cout;
104:   PetscInt           nz, idx, idt, j, i, oidx, ii, ic, ir;
105:   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
106:   const MatScalar   *aa = a->a, *v;
107:   PetscScalar        s1, s2, s3, x1, x2, x3, *x, *t;
108:   const PetscScalar *b;

110:   PetscFunctionBegin;
111:   PetscCall(VecGetArrayRead(bb, &b));
112:   PetscCall(VecGetArray(xx, &x));
113:   t = a->solve_work;

115:   PetscCall(ISGetIndices(isrow, &rout));
116:   r = rout;
117:   PetscCall(ISGetIndices(iscol, &cout));
118:   c = cout;

120:   /* copy b into temp work space according to permutation */
121:   for (i = 0; i < n; i++) {
122:     ii        = bs * i;
123:     ic        = bs * c[i];
124:     t[ii]     = b[ic];
125:     t[ii + 1] = b[ic + 1];
126:     t[ii + 2] = b[ic + 2];
127:   }

129:   /* forward solve the U^T */
130:   idx = 0;
131:   for (i = 0; i < n; i++) {
132:     v = aa + bs2 * diag[i];
133:     /* multiply by the inverse of the block diagonal */
134:     x1 = t[idx];
135:     x2 = t[1 + idx];
136:     x3 = t[2 + idx];
137:     s1 = v[0] * x1 + v[1] * x2 + v[2] * x3;
138:     s2 = v[3] * x1 + v[4] * x2 + v[5] * x3;
139:     s3 = v[6] * x1 + v[7] * x2 + v[8] * x3;
140:     v -= bs2;

142:     vi = aj + diag[i] - 1;
143:     nz = diag[i] - diag[i + 1] - 1;
144:     for (j = 0; j > -nz; j--) {
145:       oidx = bs * vi[j];
146:       t[oidx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
147:       t[oidx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
148:       t[oidx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
149:       v -= bs2;
150:     }
151:     t[idx]     = s1;
152:     t[1 + idx] = s2;
153:     t[2 + idx] = s3;
154:     idx += bs;
155:   }
156:   /* backward solve the L^T */
157:   for (i = n - 1; i >= 0; i--) {
158:     v   = aa + bs2 * ai[i];
159:     vi  = aj + ai[i];
160:     nz  = ai[i + 1] - ai[i];
161:     idt = bs * i;
162:     s1  = t[idt];
163:     s2  = t[1 + idt];
164:     s3  = t[2 + idt];
165:     for (j = 0; j < nz; j++) {
166:       idx = bs * vi[j];
167:       t[idx] -= v[0] * s1 + v[1] * s2 + v[2] * s3;
168:       t[idx + 1] -= v[3] * s1 + v[4] * s2 + v[5] * s3;
169:       t[idx + 2] -= v[6] * s1 + v[7] * s2 + v[8] * s3;
170:       v += bs2;
171:     }
172:   }

174:   /* copy t into x according to permutation */
175:   for (i = 0; i < n; i++) {
176:     ii        = bs * i;
177:     ir        = bs * r[i];
178:     x[ir]     = t[ii];
179:     x[ir + 1] = t[ii + 1];
180:     x[ir + 2] = t[ii + 2];
181:   }

183:   PetscCall(ISRestoreIndices(isrow, &rout));
184:   PetscCall(ISRestoreIndices(iscol, &cout));
185:   PetscCall(VecRestoreArrayRead(bb, &b));
186:   PetscCall(VecRestoreArray(xx, &x));
187:   PetscCall(PetscLogFlops(2.0 * bs2 * (a->nz) - bs * A->cmap->n));
188:   PetscFunctionReturn(PETSC_SUCCESS);
189: }