Actual source code: baijsolvtrann.c

  1: #include <../src/mat/impls/baij/seq/baij.h>
  2: #include <petsc/private/kernels/blockinvert.h>

  4: PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A, Vec bb, Vec xx)
  5: {
  6:   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
  7:   IS                 iscol = a->col, isrow = a->row;
  8:   const PetscInt    *r, *c, *rout, *cout, *ai = a->i, *aj = a->j, *vi;
  9:   PetscInt           i, nz, j;
 10:   const PetscInt     n = a->mbs, bs = A->rmap->bs, bs2 = a->bs2;
 11:   const MatScalar   *aa = a->a, *v;
 12:   PetscScalar       *x, *t, *ls;
 13:   const PetscScalar *b;

 15:   PetscFunctionBegin;
 16:   PetscCheck(bs > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Expected bs %" PetscInt_FMT " > 0", bs);
 17:   PetscCall(VecGetArrayRead(bb, &b));
 18:   PetscCall(VecGetArray(xx, &x));
 19:   t = a->solve_work;

 21:   PetscCall(ISGetIndices(isrow, &rout));
 22:   r = rout;
 23:   PetscCall(ISGetIndices(iscol, &cout));
 24:   c = cout;

 26:   /* copy the b into temp work space according to permutation */
 27:   for (i = 0; i < n; i++) {
 28:     for (j = 0; j < bs; j++) t[i * bs + j] = b[c[i] * bs + j];
 29:   }

 31:   /* forward solve the upper triangular transpose */
 32:   ls = a->solve_work + A->cmap->n;
 33:   for (i = 0; i < n; i++) {
 34:     PetscCall(PetscArraycpy(ls, t + i * bs, bs));
 35:     PetscKernel_w_gets_transA_times_v(bs, ls, aa + bs2 * a->diag[i], t + i * bs);
 36:     v  = aa + bs2 * (a->diag[i] + 1);
 37:     vi = aj + a->diag[i] + 1;
 38:     nz = ai[i + 1] - a->diag[i] - 1;
 39:     while (nz--) {
 40:       PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (*vi++), v, t + i * bs);
 41:       v += bs2;
 42:     }
 43:   }

 45:   /* backward solve the lower triangular transpose */
 46:   for (i = n - 1; i >= 0; i--) {
 47:     v  = aa + bs2 * ai[i];
 48:     vi = aj + ai[i];
 49:     nz = a->diag[i] - ai[i];
 50:     while (nz--) {
 51:       PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (*vi++), v, t + i * bs);
 52:       v += bs2;
 53:     }
 54:   }

 56:   /* copy t into x according to permutation */
 57:   for (i = 0; i < n; i++) {
 58:     for (j = 0; j < bs; j++) x[bs * r[i] + j] = t[bs * i + j];
 59:   }

 61:   PetscCall(ISRestoreIndices(isrow, &rout));
 62:   PetscCall(ISRestoreIndices(iscol, &cout));
 63:   PetscCall(VecRestoreArrayRead(bb, &b));
 64:   PetscCall(VecRestoreArray(xx, &x));
 65:   PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
 66:   PetscFunctionReturn(PETSC_SUCCESS);
 67: }

 69: PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A, Vec bb, Vec xx)
 70: {
 71:   Mat_SeqBAIJ       *a     = (Mat_SeqBAIJ *)A->data;
 72:   IS                 iscol = a->col, isrow = a->row;
 73:   const PetscInt    *r, *c, *rout, *cout;
 74:   const PetscInt     n = a->mbs, *ai = a->i, *aj = a->j, *vi, *diag = a->diag;
 75:   PetscInt           i, j, nz;
 76:   const PetscInt     bs = A->rmap->bs, bs2 = a->bs2;
 77:   const MatScalar   *aa = a->a, *v;
 78:   PetscScalar       *x, *t, *ls;
 79:   const PetscScalar *b;

 81:   PetscFunctionBegin;
 82:   PetscCheck(bs > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Expected bs %" PetscInt_FMT " > 0", bs);
 83:   PetscCall(VecGetArrayRead(bb, &b));
 84:   PetscCall(VecGetArray(xx, &x));
 85:   t = a->solve_work;

 87:   PetscCall(ISGetIndices(isrow, &rout));
 88:   r = rout;
 89:   PetscCall(ISGetIndices(iscol, &cout));
 90:   c = cout;

 92:   /* copy the b into temp work space according to permutation */
 93:   for (i = 0; i < n; i++) {
 94:     for (j = 0; j < bs; j++) t[i * bs + j] = b[c[i] * bs + j];
 95:   }

 97:   /* forward solve the upper triangular transpose */
 98:   ls = a->solve_work + A->cmap->n;
 99:   for (i = 0; i < n; i++) {
100:     PetscCall(PetscArraycpy(ls, t + i * bs, bs));
101:     PetscKernel_w_gets_transA_times_v(bs, ls, aa + bs2 * diag[i], t + i * bs);
102:     v  = aa + bs2 * (diag[i] - 1);
103:     vi = aj + diag[i] - 1;
104:     nz = diag[i] - diag[i + 1] - 1;
105:     for (j = 0; j > -nz; j--) {
106:       PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (vi[j]), v, t + i * bs);
107:       v -= bs2;
108:     }
109:   }

111:   /* backward solve the lower triangular transpose */
112:   for (i = n - 1; i >= 0; i--) {
113:     v  = aa + bs2 * ai[i];
114:     vi = aj + ai[i];
115:     nz = ai[i + 1] - ai[i];
116:     for (j = 0; j < nz; j++) {
117:       PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (vi[j]), v, t + i * bs);
118:       v += bs2;
119:     }
120:   }

122:   /* copy t into x according to permutation */
123:   for (i = 0; i < n; i++) {
124:     for (j = 0; j < bs; j++) x[bs * r[i] + j] = t[bs * i + j];
125:   }

127:   PetscCall(ISRestoreIndices(isrow, &rout));
128:   PetscCall(ISRestoreIndices(iscol, &cout));
129:   PetscCall(VecRestoreArrayRead(bb, &b));
130:   PetscCall(VecRestoreArray(xx, &x));
131:   PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
132:   PetscFunctionReturn(PETSC_SUCCESS);
133: }