Actual source code: fsolvebaij.F

petsc-3.3-p2 2012-07-13
  1: !
  2: !
  3: !    Fortran kernel for sparse triangular solve in the BAIJ matrix format
  4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
  5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
  6: !
  7: #include <finclude/petscsysdef.h>
  8: !

 10:       subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
 11:       implicit none
 12:       MatScalar   a(0:*)
 13:       PetscScalar x(0:*),b(0:*)
 14:       PetscInt    n,ai(0:*),aj(0:*)
 15:       PetscInt    adiag(0:*)

 17:       PetscInt    i,j,jstart,jend,idx,ax,jdx
 18:       PetscScalar s1,s2,s3,s4
 19:       PetscScalar x1,x2,x3,x4
 20: !
 21: !     Forward Solve
 22: !
 23:       PETSC_AssertAlignx(16,a(1))
 24:       PETSC_AssertAlignx(16,x(1))
 25:       PETSC_AssertAlignx(16,b(1))
 26:       PETSC_AssertAlignx(16,ai(1))
 27:       PETSC_AssertAlignx(16,aj(1))
 28:       PETSC_AssertAlignx(16,adiag(1))
 29: 
 30:          x(0) = b(0)
 31:          x(1) = b(1)
 32:          x(2) = b(2)
 33:          x(3) = b(3)
 34:          idx  = 0
 35:          do 20 i=1,n-1
 36:             jstart = ai(i)
 37:             jend   = adiag(i) - 1
 38:             ax    = 16*jstart
 39:             idx    = idx + 4
 40:             s1     = b(idx)
 41:             s2     = b(idx+1)
 42:             s3     = b(idx+2)
 43:             s4     = b(idx+3)
 44:             do 30 j=jstart,jend
 45:               jdx   = 4*aj(j)
 46: 
 47:               x1    = x(jdx)
 48:               x2    = x(jdx+1)
 49:               x3    = x(jdx+2)
 50:               x4    = x(jdx+3)
 51:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 52:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 53:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 54:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 55:               ax = ax + 16
 56:  30         continue
 57:             x(idx)   = s1
 58:             x(idx+1) = s2
 59:             x(idx+2) = s3
 60:             x(idx+3) = s4
 61:  20      continue

 63: 
 64: !
 65: !     Backward solve the upper triangular
 66: !
 67:          do 40 i=n-1,0,-1
 68:             jstart  = adiag(i) + 1
 69:             jend    = ai(i+1) - 1
 70:             ax     = 16*jstart
 71:             s1      = x(idx)
 72:             s2      = x(idx+1)
 73:             s3      = x(idx+2)
 74:             s4      = x(idx+3)
 75:             do 50 j=jstart,jend
 76:               jdx   = 4*aj(j)
 77:               x1    = x(jdx)
 78:               x2    = x(jdx+1)
 79:               x3    = x(jdx+2)
 80:               x4    = x(jdx+3)
 81:               s1 = s1-(a(ax)*x1  +a(ax+4)*x2+a(ax+8)*x3 +a(ax+12)*x4)
 82:               s2 = s2-(a(ax+1)*x1+a(ax+5)*x2+a(ax+9)*x3 +a(ax+13)*x4)
 83:               s3 = s3-(a(ax+2)*x1+a(ax+6)*x2+a(ax+10)*x3+a(ax+14)*x4)
 84:               s4 = s4-(a(ax+3)*x1+a(ax+7)*x2+a(ax+11)*x3+a(ax+15)*x4)
 85:               ax = ax + 16
 86:  50         continue
 87:             ax      = 16*adiag(i)
 88:             x(idx)   = a(ax)*s1  +a(ax+4)*s2+a(ax+8)*s3 +a(ax+12)*s4
 89:             x(idx+1) = a(ax+1)*s1+a(ax+5)*s2+a(ax+9)*s3 +a(ax+13)*s4
 90:             x(idx+2) = a(ax+2)*s1+a(ax+6)*s2+a(ax+10)*s3+a(ax+14)*s4
 91:             x(idx+3) = a(ax+3)*s1+a(ax+7)*s2+a(ax+11)*s3+a(ax+15)*s4
 92:             idx      = idx - 4
 93:  40      continue
 94:       return
 95:       end
 96: 
 97: !
 98: !   version that calls BLAS 2 operation for each row block
 99: !
100:       subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
101:       implicit none
102:       MatScalar   a(0:*),w(0:*)
103:       PetscScalar x(0:*),b(0:*)
104:       PetscInt n,ai(0:*),aj(0:*),adiag(0:*)

106:       PetscInt i,j,jstart,jend,idx,ax,jdx,kdx
107:       MatScalar s(0:3)
108:       integer   align7
109: !
110: !     Forward Solve
111: !


114:       PETSC_AssertAlignx(16,a(1))
115:       PETSC_AssertAlignx(16,w(1))
116:       PETSC_AssertAlignx(16,x(1))
117:       PETSC_AssertAlignx(16,b(1))
118:       PETSC_AssertAlignx(16,ai(1))
119:       PETSC_AssertAlignx(16,aj(1))
120:       PETSC_AssertAlignx(16,adiag(1))

122:       x(0) = b(0)
123:       x(1) = b(1)
124:       x(2) = b(2)
125:       x(3) = b(3)
126:       idx  = 0
127:       do 20 i=1,n-1
128: !
129: !        Pack required part of vector into work array
130: !
131:          kdx    = 0
132:          jstart = ai(i)
133:          jend   = adiag(i) - 1
134:          if (jend - jstart .ge. 500) then
135:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
136:          endif
137:          do 30 j=jstart,jend
138: 
139:            jdx       = 4*aj(j)
140: 
141:            w(kdx)    = x(jdx)
142:            w(kdx+1)  = x(jdx+1)
143:            w(kdx+2)  = x(jdx+2)
144:            w(kdx+3)  = x(jdx+3)
145:            kdx       = kdx + 4
146:  30      continue

148:          ax      = 16*jstart
149:          idx      = idx + 4
150:          s(0)     = b(idx)
151:          s(1)     = b(idx+1)
152:          s(2)     = b(idx+2)
153:          s(3)     = b(idx+3)
154: !
155: !    s = s - a(ax:)*w
156: !
157:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)
158: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)

160:          x(idx)   = s(0)
161:          x(idx+1) = s(1)
162:          x(idx+2) = s(2)
163:          x(idx+3) = s(3)
164:  20   continue
165: 
166: !
167: !     Backward solve the upper triangular
168: !
169:       do 40 i=n-1,0,-1
170:          jstart    = adiag(i) + 1
171:          jend      = ai(i+1) - 1
172:          ax       = 16*jstart
173:          s(0)      = x(idx)
174:          s(1)      = x(idx+1)
175:          s(2)      = x(idx+2)
176:          s(3)      = x(idx+3)
177: !
178: !   Pack each chunk of vector needed
179: !
180:          kdx = 0
181:          if (jend - jstart .ge. 500) then
182:            write(6,*) 'Overflowing vector FortranSolveBAIJ4BLAS()'
183:          endif
184:          do 50 j=jstart,jend
185:            jdx      = 4*aj(j)
186:            w(kdx)   = x(jdx)
187:            w(kdx+1) = x(jdx+1)
188:            w(kdx+2) = x(jdx+2)
189:            w(kdx+3) = x(jdx+3)
190:            kdx      = kdx + 4
191:  50      continue
192: !         call sgemv('n',4,4*(jend-jstart+1),-1.0,a(ax),4,w,1,1.0,s,1)
193:          call dgemv('n',4,4*(jend-jstart+1),-1.d0,a(ax),4,w,1,1.d0,s,1)

195:          ax      = 16*adiag(i)
196:          x(idx)  = a(ax)*s(0)  +a(ax+4)*s(1)+a(ax+8)*s(2) +a(ax+12)*s(3)
197:          x(idx+1)= a(ax+1)*s(0)+a(ax+5)*s(1)+a(ax+9)*s(2) +a(ax+13)*s(3)
198:          x(idx+2)= a(ax+2)*s(0)+a(ax+6)*s(1)+a(ax+10)*s(2)+a(ax+14)*s(3)
199:          x(idx+3)= a(ax+3)*s(0)+a(ax+7)*s(1)+a(ax+11)*s(2)+a(ax+15)*s(3)
200:          idx     = idx - 4
201:  40   continue

203:       return
204:       end
205: 

!
!   version that does not call BLAS 2 operation for each row block
!
!   FortranSolveBAIJ4: solves (L U) x = b for a factored BAIJ matrix
!   with 4x4 blocks.  For every block row the needed pieces of x are
!   gathered into the work array w, and the off-diagonal contribution
!   s = s - A_row * w is formed by a plain double loop over the row of
!   column-major 4x4 blocks.
!
!   n      - number of block rows
!   x      - solution vector, 4*n entries (0-based)
!   ai     - block-row pointers into aj and a
!   aj     - block-column index of each stored block
!   adiag  - position of the diagonal block of each block row
!   a      - block values, 16 per block, each block column-major
!   b      - right-hand side, 4*n entries (only read)
!   w      - scratch space; 4 entries per off-diagonal block of the
!            block row being processed
!
!   NOTE(review): the 500-block limit presumably matches the size the
!   caller allocates for w (TODO confirm).  When it is exceeded only a
!   warning is printed and the packing loop still runs.
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*),w(0:*)
      PetscInt  n,ai(0:*),aj(0:*),adiag(0:*)
      PetscInt  ii,jj,i,j

      PetscInt  jstart,jend,idx,ax,jdx,kdx,nn
      PetscScalar s(0:3)

      PETSC_AssertAlignx(16,a(1))
      PETSC_AssertAlignx(16,w(1))
      PETSC_AssertAlignx(16,x(1))
      PETSC_AssertAlignx(16,b(1))
      PETSC_AssertAlignx(16,ai(1))
      PETSC_AssertAlignx(16,aj(1))
      PETSC_AssertAlignx(16,adiag(1))
!
!     An empty system has nothing to solve; without this guard the
!     copy of block row 0 below would touch x(0:3) and b(0:3).
!
      if (n .le. 0) return
!
!     Forward solve (no diagonal block is applied here)
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)
      idx  = 0
      do 20 i=1,n-1
         kdx    = 0
         jstart = ai(i)
         jend   = adiag(i) - 1
!
!        Pack the needed parts of x into the work array
!
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
         endif
         do 30 j=jstart,jend
            jdx       = 4*aj(j)
            w(kdx)    = x(jdx)
            w(kdx+1)  = x(jdx+1)
            w(kdx+2)  = x(jdx+2)
            w(kdx+3)  = x(jdx+3)
            kdx       = kdx + 4
 30      continue

         ax       = 16*jstart
         idx      = idx + 4
         s(0)     = b(idx)
         s(1)     = b(idx+1)
         s(2)     = b(idx+2)
         s(3)     = b(idx+3)
!
!        s = s - a(ax:)*w, walking a column by column (jj outer) so it
!        is read contiguously; every s(ii) still accumulates its terms
!        in increasing jj order, so the rounding is unchanged
!
         nn = 4*(jend - jstart + 1) - 1
         do 110 jj=0,nn
            do 100 ii=0,3
               s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
 100        continue
 110     continue

         x(idx)   = s(0)
         x(idx+1) = s(1)
         x(idx+2) = s(2)
         x(idx+3) = s(3)
 20   continue
!
!     Backward solve, last block row first; the forward sweep left
!     idx = 4*(n-1), the first entry of block row n-1.
!
      do 40 i=n-1,0,-1
         jstart    = adiag(i) + 1
         jend      = ai(i+1) - 1
         ax        = 16*jstart
         s(0)      = x(idx)
         s(1)      = x(idx+1)
         s(2)      = x(idx+2)
         s(3)      = x(idx+3)
!
!        Pack the needed parts of x into the work array
!
         kdx = 0
         if (jend - jstart .ge. 500) then
           write(6,*) 'Overflowing vector FortranSolveBAIJ4()'
         endif
         do 50 j=jstart,jend
            jdx      = 4*aj(j)
            w(kdx)   = x(jdx)
            w(kdx+1) = x(jdx+1)
            w(kdx+2) = x(jdx+2)
            w(kdx+3) = x(jdx+3)
            kdx      = kdx + 4
 50      continue
!
!        s = s - a(ax:)*w, same column-by-column traversal as above
!
         nn = 4*(jend - jstart + 1) - 1
         do 210 jj=0,nn
            do 200 ii=0,3
               s(ii) = s(ii) - a(ax+4*jj+ii)*w(jj)
 200        continue
 210     continue
!
!        x(block row i) = (stored diagonal block) * s
!
         ax      = 16*adiag(i)
         x(idx)   = a(ax)*s(0)   + a(ax+4)*s(1)
     &            + a(ax+8)*s(2) + a(ax+12)*s(3)
         x(idx+1) = a(ax+1)*s(0) + a(ax+5)*s(1)
     &            + a(ax+9)*s(2) + a(ax+13)*s(3)
         x(idx+2) = a(ax+2)*s(0) + a(ax+6)*s(1)
     &            + a(ax+10)*s(2) + a(ax+14)*s(3)
         x(idx+3) = a(ax+3)*s(0) + a(ax+7)*s(1)
     &            + a(ax+11)*s(2) + a(ax+15)*s(3)
         idx     = idx - 4
 40   continue

      return
      end
323: