Actual source code: fsolvebaij.F

  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
 7:  #include finclude/petscdef.h
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*),b(0:*)
 14:       PetscInt    n,ai(0:*)
 15:       PetscInt    aj(0:*),adiag(0:*)

 17:       PetscInt    i,j,jstart,jend,idx,ax,jdx
 18:       PetscScalar s1,s2,s3,s4
 19:       PetscScalar x1,x2,x3,x4
 20: !
 21: !     Forward Solve
 22: !

 24:       x(0) = b(0)
 25:       x(1) = b(1)
 26:       x(2) = b(2)
 27:       x(3) = b(3)
 28:       idx  = 0
 29:       do 20 i=1,n-1
 30:          jstart = ai(i)
 31:          jend   = adiag(i) - 1
 32:          ax    = 16*jstart
 33:          idx    = idx + 4
 34:          s1     = b(idx)
 35:          s2     = b(idx+1)
 36:          s3     = b(idx+2)
 37:          s4     = b(idx+3)
 38:          do 30 j=jstart,jend
 39:            jdx   = 4*aj(j)
 40: 
 41:            x1    = x(jdx)
 42:            x2    = x(jdx+1)
 43:            x3    = x(jdx+2)
 44:            x4    = x(jdx+3)
 45:            s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 46:            s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 47:            s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 48:            s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 49:            ax = ax + 16
 50:  30      continue
 51:          x(idx)   = s1
 52:          x(idx+1) = s2
 53:          x(idx+2) = s3
 54:          x(idx+3) = s4
 55:  20   continue
 56: 
 57: !
 58: !     Backward solve the upper triangular
 59: !
 60:       do 40 i=n-1,0,-1
 61:          jstart  = adiag(i) + 1
 62:          jend    = ai(i+1) - 1
 63:          ax     = 16*jstart
 64:          s1      = x(idx)
 65:          s2      = x(idx+1)
 66:          s3      = x(idx+2)
 67:          s4      = x(idx+3)
 68:          do 50 j=jstart,jend
 69:            jdx   = 4*aj(j)
 70:            x1    = x(jdx)
 71:            x2    = x(jdx+1)
 72:            x3    = x(jdx+2)
 73:            x4    = x(jdx+3)
 74:            s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 75:            s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 76:            s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 77:            s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 78:            ax = ax + 16
 79:  50      continue
 80:          ax      = 16*adiag(i)
 81:          x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 82:          x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 83:          x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 84:          x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 85:          idx      = idx - 4
 86:  40   continue
 87:       return
 88:       end
 89: 
 90: !
 91: !   version that calls BLAS 2 operation for each row block
 92: !
 93:       subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
 94:       implicit none
 95:       MatScalar   a(0:*),w(0:*)
 96:       PetscScalar x(0:*),b(0:*)
 97:       PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)

 99:       PetscInt i,j,jstart,jend,idx,ax,jdx,kdx
100:       MatScalar   s(0:3)
101: !
102: !     Forward Solve
103: !

105:       x(0) = b(0)
106:       x(1) = b(1)
107:       x(2) = b(2)
108:       x(3) = b(3)
109:       idx  = 0
110:       do 20 i=1,n-1
111: !
112: !        Pack required part of vector into work array
113: !
114:          kdx    = 0
115:          jstart = ai(i)
116:          jend   = adiag(i) - 1
117:          if (jend - jstart .ge. 500) then
118:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
119:          endif
120:          do 30 j=jstart,jend
121: 
122:            jdx       = 4*aj(j)
123: 
124:            w(kdx)    = x(jdx)
125:            w(kdx+1)  = x(jdx+1)
126:            w(kdx+2)  = x(jdx+2)
127:            w(kdx+3)  = x(jdx+3)
128:            kdx       = kdx + 4
129:  30      continue

131:          ax      = 16*jstart
132:          idx      = idx + 4
133:          s(0)     = b(idx)
134:          s(1)     = b(idx+1)
135:          s(2)     = b(idx+2)
136:          s(3)     = b(idx+3)
137: !
138: !    s = s - a(ax:)*w
139: !
140:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)
141: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)

143:          x(idx)   = s(0)
144:          x(idx+1) = s(1)
145:          x(idx+2) = s(2)
146:          x(idx+3) = s(3)
147:  20   continue
148: 
149: !
150: !     Backward solve the upper triangular
151: !
152:       do 40 i=n-1,0,-1
153:          jstart    = adiag(i) + 1
154:          jend      = ai(i+1) - 1
155:          ax       = 16*jstart
156:          s(0)      = x(idx)
157:          s(1)      = x(idx+1)
158:          s(2)      = x(idx+2)
159:          s(3)      = x(idx+3)
160: !
161: !   Pack each chunk of vector needed
162: !
163:          kdx = 0
164:          if (jend - jstart .ge. 500) then
165:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
166:          endif
167:          do 50 j=jstart,jend
168:            jdx      = 4*aj(j)
169:            w(kdx)   = x(jdx)
170:            w(kdx+1) = x(jdx+1)
171:            w(kdx+2) = x(jdx+2)
172:            w(kdx+3) = x(jdx+3)
173:            kdx      = kdx + 4
174:  50      continue
175: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)
176:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)

178:          ax      = 16*adiag(i)
179:          x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
180:          x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
181:          x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
182:          x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
183:          idx     = idx - 4
184:  40   continue
185:       return
186:       end
187: 

!
!   FortranSolveBAIJ4 - triangular solve x = U\(L\b) for a 4x4 block
!   (BAIJ) LU factorization in the natural ordering.  The off-diagonal
!   blocks of each row are applied by gathering the needed x entries
!   into w and multiplying with plain loops (no BLAS call).
!
!   n     - number of block rows; x and b hold 4*n entries
!   x     - solution (output); already-computed 4-vectors of x are
!           read back while later block rows are solved
!   ai    - 0-based block-row pointers: block row i owns stored
!           blocks ai(i) .. ai(i+1)-1
!   aj    - 0-based block-column index of each stored block
!   adiag - index of the diagonal block of each block row; blocks
!           ai(i)..adiag(i)-1 are strictly lower (L) and
!           adiag(i)+1..ai(i+1)-1 strictly upper (U)
!   a     - block values, 16 per stored block; entry (r,c) of
!           block k is a(16*k + 4*c + r), i.e. column-major blocks
!   b     - right-hand side (input), 4*n entries
!   w     - caller-supplied gather buffer.  At most 4*maxblk (= 2000)
!           entries are written, however many blocks a row has;
!           assumes the caller allocates at least that many - TODO
!           confirm against the caller.
!
!   Long rows are split into chunks of maxblk blocks so w is never
!   overrun.  The backward sweep MULTIPLIES by the block stored at
!   adiag(i), so that block presumably holds the inverted diagonal
!   block of U.
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*),w(0:*)
      PetscInt n,ai(0:*),aj(0:*),adiag(0:*)
      PetscInt ii,jj,i,j

      PetscInt jstart,jend,jb,je,idx,ax,jdx,kdx
!     Maximum number of blocks gathered into w at a time.
      PetscInt maxblk
      parameter (maxblk = 500)
      PetscScalar s(0:3)
!
!     An empty system has nothing to solve; without this guard the
!     first block of x and b would be accessed out of bounds.
!
      if (n .le. 0) return
!
!     Forward solve.  Block row 0 has no strictly lower blocks, so
!     x(0:3) = b(0:3); L's diagonal blocks are identities.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do 20 i=1,n-1
         jstart = ai(i)
         jend   = adiag(i) - 1
         idx    = idx + 4
         s(0)   = b(idx)
         s(1)   = b(idx+1)
         s(2)   = b(idx+2)
         s(3)   = b(idx+3)
!
!        s = s - L(i,:)*x in chunks of at most maxblk blocks.  Blocks
!        jb..je are contiguous in a, i.e. a column-major 4 x kdx
!        matrix at a(16*jb); the row index ii is kept innermost
!        because it is the contiguous one.  Each s(ii) still sums its
!        terms in increasing jj order.
!
         do 30 jb=jstart,jend,maxblk
            je  = min(jb+maxblk-1,jend)
            kdx = 0
            do 31 j=jb,je
               jdx      = 4*aj(j)
               w(kdx)   = x(jdx)
               w(kdx+1) = x(jdx+1)
               w(kdx+2) = x(jdx+2)
               w(kdx+3) = x(jdx+3)
               kdx      = kdx + 4
 31         continue
            ax = 16*jb
            do 32 jj=0,kdx-1
               do 33 ii=0,3
                  s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
 33            continue
 32         continue
 30      continue
         x(idx)   = s(0)
         x(idx+1) = s(1)
         x(idx+2) = s(2)
         x(idx+3) = s(3)
 20   continue
!
!     Backward solve with U.  On entry idx addresses the last block
!     row, 4*(n-1).
!
      do 40 i=n-1,0,-1
         jstart = adiag(i) + 1
         jend   = ai(i+1) - 1
         s(0)   = x(idx)
         s(1)   = x(idx+1)
         s(2)   = x(idx+2)
         s(3)   = x(idx+3)
!
!        s = s - U(i,:)*x, chunked the same way as the forward sweep
!
         do 50 jb=jstart,jend,maxblk
            je  = min(jb+maxblk-1,jend)
            kdx = 0
            do 51 j=jb,je
               jdx      = 4*aj(j)
               w(kdx)   = x(jdx)
               w(kdx+1) = x(jdx+1)
               w(kdx+2) = x(jdx+2)
               w(kdx+3) = x(jdx+3)
               kdx      = kdx + 4
 51         continue
            ax = 16*jb
            do 52 jj=0,kdx-1
               do 53 ii=0,3
                  s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
 53            continue
 52         continue
 50      continue
!
!        x(block i) = (block stored at adiag(i)) * s
!
         ax      = 16*adiag(i)
         x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
         x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
         x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
         x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
         idx     = idx - 4
 40   continue
      return
      end
295: