Actual source code: baijsolvtrann.c
1: #include <../src/mat/impls/baij/seq/baij.h>
2: #include <petsc/private/kernels/blockinvert.h>
4: PetscErrorCode MatSolveTranspose_SeqBAIJ_N_inplace(Mat A, Vec bb, Vec xx)
5: {
6: Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data;
7: IS iscol = a->col, isrow = a->row;
8: const PetscInt *r, *c, *rout, *cout, *ai = a->i, *aj = a->j, *vi;
9: PetscInt i, nz, j;
10: const PetscInt n = a->mbs, bs = A->rmap->bs, bs2 = a->bs2;
11: const MatScalar *aa = a->a, *v;
12: PetscScalar *x, *t, *ls;
13: const PetscScalar *b;
15: PetscFunctionBegin;
16: PetscCheck(bs > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Expected bs %" PetscInt_FMT " > 0", bs);
17: PetscCall(VecGetArrayRead(bb, &b));
18: PetscCall(VecGetArray(xx, &x));
19: t = a->solve_work;
21: PetscCall(ISGetIndices(isrow, &rout));
22: r = rout;
23: PetscCall(ISGetIndices(iscol, &cout));
24: c = cout;
26: /* copy the b into temp work space according to permutation */
27: for (i = 0; i < n; i++) {
28: for (j = 0; j < bs; j++) t[i * bs + j] = b[c[i] * bs + j];
29: }
31: /* forward solve the upper triangular transpose */
32: ls = a->solve_work + A->cmap->n;
33: for (i = 0; i < n; i++) {
34: PetscCall(PetscArraycpy(ls, t + i * bs, bs));
35: PetscKernel_w_gets_transA_times_v(bs, ls, aa + bs2 * a->diag[i], t + i * bs);
36: v = aa + bs2 * (a->diag[i] + 1);
37: vi = aj + a->diag[i] + 1;
38: nz = ai[i + 1] - a->diag[i] - 1;
39: while (nz--) {
40: PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (*vi++), v, t + i * bs);
41: v += bs2;
42: }
43: }
45: /* backward solve the lower triangular transpose */
46: for (i = n - 1; i >= 0; i--) {
47: v = aa + bs2 * ai[i];
48: vi = aj + ai[i];
49: nz = a->diag[i] - ai[i];
50: while (nz--) {
51: PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (*vi++), v, t + i * bs);
52: v += bs2;
53: }
54: }
56: /* copy t into x according to permutation */
57: for (i = 0; i < n; i++) {
58: for (j = 0; j < bs; j++) x[bs * r[i] + j] = t[bs * i + j];
59: }
61: PetscCall(ISRestoreIndices(isrow, &rout));
62: PetscCall(ISRestoreIndices(iscol, &cout));
63: PetscCall(VecRestoreArrayRead(bb, &b));
64: PetscCall(VecRestoreArray(xx, &x));
65: PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
66: PetscFunctionReturn(PETSC_SUCCESS);
67: }
69: PetscErrorCode MatSolveTranspose_SeqBAIJ_N(Mat A, Vec bb, Vec xx)
70: {
71: Mat_SeqBAIJ *a = (Mat_SeqBAIJ *)A->data;
72: IS iscol = a->col, isrow = a->row;
73: const PetscInt *r, *c, *rout, *cout;
74: const PetscInt n = a->mbs, *ai = a->i, *aj = a->j, *vi, *diag = a->diag;
75: PetscInt i, j, nz;
76: const PetscInt bs = A->rmap->bs, bs2 = a->bs2;
77: const MatScalar *aa = a->a, *v;
78: PetscScalar *x, *t, *ls;
79: const PetscScalar *b;
81: PetscFunctionBegin;
82: PetscCheck(bs > 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Expected bs %" PetscInt_FMT " > 0", bs);
83: PetscCall(VecGetArrayRead(bb, &b));
84: PetscCall(VecGetArray(xx, &x));
85: t = a->solve_work;
87: PetscCall(ISGetIndices(isrow, &rout));
88: r = rout;
89: PetscCall(ISGetIndices(iscol, &cout));
90: c = cout;
92: /* copy the b into temp work space according to permutation */
93: for (i = 0; i < n; i++) {
94: for (j = 0; j < bs; j++) t[i * bs + j] = b[c[i] * bs + j];
95: }
97: /* forward solve the upper triangular transpose */
98: ls = a->solve_work + A->cmap->n;
99: for (i = 0; i < n; i++) {
100: PetscCall(PetscArraycpy(ls, t + i * bs, bs));
101: PetscKernel_w_gets_transA_times_v(bs, ls, aa + bs2 * diag[i], t + i * bs);
102: v = aa + bs2 * (diag[i] - 1);
103: vi = aj + diag[i] - 1;
104: nz = diag[i] - diag[i + 1] - 1;
105: for (j = 0; j > -nz; j--) {
106: PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (vi[j]), v, t + i * bs);
107: v -= bs2;
108: }
109: }
111: /* backward solve the lower triangular transpose */
112: for (i = n - 1; i >= 0; i--) {
113: v = aa + bs2 * ai[i];
114: vi = aj + ai[i];
115: nz = ai[i + 1] - ai[i];
116: for (j = 0; j < nz; j++) {
117: PetscKernel_v_gets_v_minus_transA_times_w(bs, t + bs * (vi[j]), v, t + i * bs);
118: v += bs2;
119: }
120: }
122: /* copy t into x according to permutation */
123: for (i = 0; i < n; i++) {
124: for (j = 0; j < bs; j++) x[bs * r[i] + j] = t[bs * i + j];
125: }
127: PetscCall(ISRestoreIndices(isrow, &rout));
128: PetscCall(ISRestoreIndices(iscol, &cout));
129: PetscCall(VecRestoreArrayRead(bb, &b));
130: PetscCall(VecRestoreArray(xx, &x));
131: PetscCall(PetscLogFlops(2.0 * (a->bs2) * (a->nz) - A->rmap->bs * A->cmap->n));
132: PetscFunctionReturn(PETSC_SUCCESS);
133: }