1 #ifndef VIENNACL_LINALG_OPENCL_KERNELS_SVD_HPP
2 #define VIENNACL_LINALG_OPENCL_KERNELS_SVD_HPP
// Generator fragment that appends the OpenCL source of kernel `bidiag_pack` to `source`,
// inserting `numeric_string` as the element type of A, D and S.
// NOTE(review): the enclosing function signature and braces are not visible in this listing;
// presumably this is generate_svd_bidiag_pack(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: with size = min(size1, size2), work-item 0 sets S[0] = 0, and every i < size
// copies the diagonal entry A[i*stride + i] to D[i] and one neighbouring off-diagonal entry
// (or 0 when i + 1 >= size2) to S[i + 1].
// NOTE(review): S[i + 1] is written for i up to size - 1, so S must hold size + 1 entries - confirm.
20 template <
typename StringType>
23 source.append(
"__kernel void bidiag_pack(__global "); source.append(numeric_string); source.append(
"* A, \n");
24 source.append(
" __global "); source.append(numeric_string); source.append(
"* D, \n");
25 source.append(
" __global "); source.append(numeric_string); source.append(
"* S, \n");
26 source.append(
" uint size1, \n");
27 source.append(
" uint size2, \n");
28 source.append(
" uint stride \n");
29 source.append(
") { \n");
30 source.append(
" uint size = min(size1, size2); \n");
32 source.append(
" if(get_global_id(0) == 0) \n");
33 source.append(
" S[0] = 0; \n");
// First loop variant: off-diagonal read as A[i*stride + (i + 1)].
// NOTE(review): the two loop variants differ only in this index; presumably one is selected by
// is_row_major in lines not shown here - confirm.
36 source.append(
" for(uint i = get_global_id(0); i < size ; i += get_global_size(0)) { \n");
37 source.append(
" D[i] = A[i*stride + i]; \n");
38 source.append(
" S[i + 1] = (i + 1 < size2) ? A[i*stride + (i + 1)] : 0; \n");
// Second loop variant: off-diagonal read as A[i + (i + 1) * stride].
42 source.append(
" for(uint i = get_global_id(0); i < size ; i += get_global_size(0)) { \n");
43 source.append(
" D[i] = A[i*stride + i]; \n");
44 source.append(
" S[i + 1] = (i + 1 < size2) ? A[i + (i + 1) * stride] : 0; \n");
46 source.append(
" } \n");
47 source.append(
"} \n");
// Generator fragment that appends the OpenCL helper function `col_reduce_lcl_array`
// (not a __kernel) to `source`; `numeric_string` is the element type of the __local array.
// NOTE(review): signature not visible here; presumably generate_svd_col_reduce_lcl_array - confirm.
//
// Emitted function: halves `step` starting from lcl_sz >> 1 and, for lcl_id < step, adds
// sums[lcl_id + step] into sums[lcl_id]; a local-memory barrier follows every halving step,
// so the total ends up in sums[0].
// NOTE(review): with this halving scheme some entries are never accumulated unless lcl_sz is a
// power of two (e.g. lcl_sz == 6 drops sums[2]) - callers appear to rely on that; confirm.
50 template<
typename StringT>
54 source.append(
"void col_reduce_lcl_array(__local "); source.append(numeric_string); source.append(
"* sums, uint lcl_id, uint lcl_sz) { \n");
55 source.append(
" uint step = lcl_sz >> 1; \n");
57 source.append(
" while (step > 0) { \n");
58 source.append(
" if (lcl_id < step) { \n");
59 source.append(
" sums[lcl_id] += sums[lcl_id + step]; \n");
60 source.append(
" } \n");
61 source.append(
" step >>= 1; \n");
62 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
63 source.append(
" } \n");
64 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `copy_col` to `source`.
// NOTE(review): signature not visible here; presumably
// generate_svd_copy_col(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: for i = row_start + global id, stepping by the global size while i < size,
// copies one element of A addressed by (i, col_start) into V[i - row_start].
// `size` is therefore an exclusive upper bound on i, not an element count.
67 template <
typename StringType>
71 source.append(
"__kernel void copy_col(__global "); source.append(numeric_string); source.append(
"* A, \n");
72 source.append(
" __global "); source.append(numeric_string); source.append(
"* V, \n");
73 source.append(
" uint row_start, \n");
74 source.append(
" uint col_start, \n");
75 source.append(
" uint size, \n");
76 source.append(
" uint stride \n");
77 source.append(
" ) { \n");
78 source.append(
" uint glb_id = get_global_id(0); \n");
79 source.append(
" uint glb_sz = get_global_size(0); \n");
// First loop variant: element addressed as A[i * stride + col_start].
// NOTE(review): presumably chosen by is_row_major in lines not shown - confirm.
82 source.append(
" for(uint i = row_start + glb_id; i < size; i += glb_sz) { \n");
83 source.append(
" V[i - row_start] = A[i * stride + col_start]; \n");
84 source.append(
" } \n");
// Second loop variant: element addressed as A[i + col_start * stride].
88 source.append(
" for(uint i = row_start + glb_id; i < size; i += glb_sz) { \n");
89 source.append(
" V[i - row_start] = A[i + col_start * stride]; \n");
90 source.append(
" } \n");
93 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `copy_row` to `source`.
// NOTE(review): signature not visible here; presumably
// generate_svd_copy_row(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: for i = col_start + global id, stepping by the global size while i < size,
// copies one element of A addressed by (row_start, i) into V[i - col_start].
// `size` is an exclusive upper bound on i, not an element count.
96 template <
typename StringType>
100 source.append(
"__kernel void copy_row(__global "); source.append(numeric_string); source.append(
"* A, \n");
101 source.append(
" __global "); source.append(numeric_string); source.append(
"* V, \n");
102 source.append(
" uint row_start, \n");
103 source.append(
" uint col_start, \n");
104 source.append(
" uint size, \n");
105 source.append(
" uint stride \n");
106 source.append(
" ) { \n");
107 source.append(
" uint glb_id = get_global_id(0); \n");
108 source.append(
" uint glb_sz = get_global_size(0); \n");
// First loop variant: element addressed as A[row_start * stride + i].
// NOTE(review): presumably chosen by is_row_major in lines not shown - confirm.
111 source.append(
" for(uint i = col_start + glb_id; i < size; i += glb_sz) { \n");
112 source.append(
" V[i - col_start] = A[row_start * stride + i]; \n");
113 source.append(
" } \n");
// Second loop variant: element addressed as A[row_start + i * stride].
117 source.append(
" for(uint i = col_start + glb_id; i < size; i += glb_sz) { \n");
118 source.append(
" V[i - col_start] = A[row_start + i * stride]; \n");
119 source.append(
" } \n");
122 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `final_iter_update` to `source`;
// `numeric_string` is used for the buffer element type, the scalars q and p, and the temporaries.
// NOTE(review): signature not visible here; presumably generate_svd_final_iter_update - confirm.
//
// Emitted kernel: for every px < last_n (strided over the global size) it reads
// v_in = A[n * stride + px] and z = A[(n - 1) * stride + px], then stores
//   A[(n - 1) * stride + px] = q * z + p * v_in
//   A[n * stride + px]       = q * v_in - p * z
// i.e. a 2x2 linear combination of the entries at offsets (n - 1) * stride and n * stride.
125 template<
typename StringT>
128 source.append(
"__kernel void final_iter_update(__global "); source.append(numeric_string); source.append(
"* A, \n");
129 source.append(
" uint stride, \n");
130 source.append(
" uint n, \n");
131 source.append(
" uint last_n, \n");
132 source.append(
" "); source.append(numeric_string); source.append(
" q, \n");
133 source.append(
" "); source.append(numeric_string); source.append(
" p \n");
134 source.append(
" ) \n");
135 source.append(
"{ \n");
136 source.append(
" uint glb_id = get_global_id(0); \n");
137 source.append(
" uint glb_sz = get_global_size(0); \n");
139 source.append(
" for (uint px = glb_id; px < last_n; px += glb_sz) \n");
140 source.append(
" { \n");
141 source.append(
" "); source.append(numeric_string); source.append(
" v_in = A[n * stride + px]; \n");
142 source.append(
" "); source.append(numeric_string); source.append(
" z = A[(n - 1) * stride + px]; \n");
143 source.append(
" A[(n - 1) * stride + px] = q * z + p * v_in; \n");
144 source.append(
" A[n * stride + px] = q * v_in - p * z; \n");
145 source.append(
" } \n");
146 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `givens_next` to `source`.
// NOTE(review): signature not visible here; presumably
// generate_svd_givens_next(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: each work-item owns index j = global id (active only when j < size) and walks
// i from end_i down to start_i. The coefficient arrays cs/ss are staged through the __local
// caches cs_lcl/ss_lcl in blocks of lcl_sz entries, with a barrier after loading and after use.
// Per step: z = matr(i, j); matr(i + 1, j) = x * cs + z * ss; x = -x * ss + z * cs.
// The running value x starts as matr(end_i + 1, j) and is finally stored at matr(start_i, j).
// The kernel body appears twice below with different matr indexing (see the notes at each copy).
// NOTE(review): cs_lcl/ss_lcl hold 256 entries and are indexed by lcl_id < lcl_sz, so the
// work-group size must not exceed 256 - confirm at the launch site.
149 template <
typename StringType>
152 source.append(
"__kernel void givens_next(__global "); source.append(numeric_string); source.append(
"* matr, \n");
153 source.append(
" __global "); source.append(numeric_string); source.append(
"* cs, \n");
154 source.append(
" __global "); source.append(numeric_string); source.append(
"* ss, \n");
155 source.append(
" uint size, \n");
156 source.append(
" uint stride, \n");
157 source.append(
" uint start_i, \n");
158 source.append(
" uint end_i \n");
159 source.append(
" ) \n");
160 source.append(
"{ \n");
161 source.append(
" uint glb_id = get_global_id(0); \n");
162 source.append(
" uint glb_sz = get_global_size(0); \n");
164 source.append(
" uint lcl_id = get_local_id(0); \n");
165 source.append(
" uint lcl_sz = get_local_size(0); \n");
167 source.append(
" uint j = glb_id; \n");
169 source.append(
" __local "); source.append(numeric_string); source.append(
" cs_lcl[256]; \n");
170 source.append(
" __local "); source.append(numeric_string); source.append(
" ss_lcl[256]; \n");
// First body variant: matr is addressed as matr[i + j * stride].
// NOTE(review): presumably selected by is_row_major in lines not shown - confirm.
174 source.append(
" "); source.append(numeric_string); source.append(
" x = (j < size) ? matr[(end_i + 1) + j * stride] : 0; \n");
176 source.append(
" uint elems_num = end_i - start_i + 1; \n");
177 source.append(
" uint block_num = (elems_num + lcl_sz - 1) / lcl_sz; \n");
179 source.append(
" for(uint block_id = 0; block_id < block_num; block_id++) \n");
180 source.append(
" { \n");
181 source.append(
" uint to = min(elems_num - block_id * lcl_sz, lcl_sz); \n");
183 source.append(
" if(lcl_id < to) \n");
184 source.append(
" { \n");
185 source.append(
" cs_lcl[lcl_id] = cs[end_i - (lcl_id + block_id * lcl_sz)]; \n");
186 source.append(
" ss_lcl[lcl_id] = ss[end_i - (lcl_id + block_id * lcl_sz)]; \n");
187 source.append(
" } \n");
189 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
191 source.append(
" if(j < size) \n");
192 source.append(
" { \n");
193 source.append(
" for(uint ind = 0; ind < to; ind++) \n");
194 source.append(
" { \n");
195 source.append(
" uint i = end_i - (ind + block_id * lcl_sz); \n");
197 source.append(
" "); source.append(numeric_string); source.append(
" z = matr[i + j * stride]; \n");
199 source.append(
" "); source.append(numeric_string); source.append(
" cs_val = cs_lcl[ind]; \n");
200 source.append(
" "); source.append(numeric_string); source.append(
" ss_val = ss_lcl[ind]; \n");
202 source.append(
" matr[(i + 1) + j * stride] = x * cs_val + z * ss_val; \n");
203 source.append(
" x = -x * ss_val + z * cs_val; \n");
204 source.append(
" } \n");
205 source.append(
" } \n");
206 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
207 source.append(
" } \n");
208 source.append(
" if(j < size) \n");
209 source.append(
" matr[(start_i) + j * stride] = x; \n");
// Second body variant: identical logic with matr addressed as matr[i * stride + j].
214 source.append(
" "); source.append(numeric_string); source.append(
" x = (j < size) ? matr[(end_i + 1) * stride + j] : 0; \n");
216 source.append(
" uint elems_num = end_i - start_i + 1; \n");
217 source.append(
" uint block_num = (elems_num + lcl_sz - 1) / lcl_sz; \n");
219 source.append(
" for(uint block_id = 0; block_id < block_num; block_id++) \n");
220 source.append(
" { \n");
221 source.append(
" uint to = min(elems_num - block_id * lcl_sz, lcl_sz); \n");
223 source.append(
" if(lcl_id < to) \n");
224 source.append(
" { \n");
225 source.append(
" cs_lcl[lcl_id] = cs[end_i - (lcl_id + block_id * lcl_sz)]; \n");
226 source.append(
" ss_lcl[lcl_id] = ss[end_i - (lcl_id + block_id * lcl_sz)]; \n");
227 source.append(
" } \n");
229 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
231 source.append(
" if(j < size) \n");
232 source.append(
" { \n");
233 source.append(
" for(uint ind = 0; ind < to; ind++) \n");
234 source.append(
" { \n");
235 source.append(
" uint i = end_i - (ind + block_id * lcl_sz); \n");
237 source.append(
" "); source.append(numeric_string); source.append(
" z = matr[i * stride + j]; \n");
239 source.append(
" "); source.append(numeric_string); source.append(
" cs_val = cs_lcl[ind]; \n");
240 source.append(
" "); source.append(numeric_string); source.append(
" ss_val = ss_lcl[ind]; \n");
242 source.append(
" matr[(i + 1) * stride + j] = x * cs_val + z * ss_val; \n");
243 source.append(
" x = -x * ss_val + z * cs_val; \n");
244 source.append(
" } \n");
245 source.append(
" } \n");
246 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
247 source.append(
" } \n");
248 source.append(
" if(j < size) \n");
249 source.append(
" matr[(start_i) * stride + j] = x; \n");
251 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `givens_prev` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_givens_prev - confirm.
//
// Emitted kernel: counterpart of the descending sweep, walking i upwards from start_i over
// elems_num = end_i - start_i entries. Each work-item owns j = global id (active when j < size).
// cs/ss are staged through 256-entry __local caches in blocks of lcl_sz with barriers around use.
// Per step: z = matr[i * stride + j]; matr[(i - 1) * stride + j] = x * cs + z * ss;
// x = -x * ss + z * cs. x starts as matr[(start_i - 1) * stride + j] and is finally stored at
// matr[(end_i - 1) * stride + j]. Only one indexing variant is visible in this listing.
// NOTE(review): start_i is unsigned, so (start_i - 1) wraps when start_i == 0 - confirm callers
// never pass 0. The 256-entry caches also cap the work-group size at 256 - confirm.
254 template<
typename StringT>
257 source.append(
"__kernel void givens_prev(__global "); source.append(numeric_string); source.append(
"* matr, \n");
258 source.append(
" __global "); source.append(numeric_string); source.append(
"* cs, \n");
259 source.append(
" __global "); source.append(numeric_string); source.append(
"* ss, \n");
260 source.append(
" uint size, \n");
261 source.append(
" uint stride, \n");
262 source.append(
" uint start_i, \n");
263 source.append(
" uint end_i \n");
264 source.append(
" ) \n");
265 source.append(
"{ \n");
266 source.append(
" uint glb_id = get_global_id(0); \n");
267 source.append(
" uint glb_sz = get_global_size(0); \n");
269 source.append(
" uint lcl_id = get_local_id(0); \n");
270 source.append(
" uint lcl_sz = get_local_size(0); \n");
272 source.append(
" uint j = glb_id; \n");
274 source.append(
" __local "); source.append(numeric_string); source.append(
" cs_lcl[256]; \n");
275 source.append(
" __local "); source.append(numeric_string); source.append(
" ss_lcl[256]; \n");
277 source.append(
" "); source.append(numeric_string); source.append(
" x = (j < size) ? matr[(start_i - 1) * stride + j] : 0; \n");
279 source.append(
" uint elems_num = end_i - start_i; \n");
280 source.append(
" uint block_num = (elems_num + lcl_sz - 1) / lcl_sz; \n");
282 source.append(
" for (uint block_id = 0; block_id < block_num; block_id++) \n");
283 source.append(
" { \n");
284 source.append(
" uint to = min(elems_num - block_id * lcl_sz, lcl_sz); \n");
286 source.append(
" if (lcl_id < to) \n");
287 source.append(
" { \n");
288 source.append(
" cs_lcl[lcl_id] = cs[lcl_id + start_i + block_id * lcl_sz]; \n");
289 source.append(
" ss_lcl[lcl_id] = ss[lcl_id + start_i + block_id * lcl_sz]; \n");
290 source.append(
" } \n");
292 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
294 source.append(
" if (j < size) \n");
295 source.append(
" { \n");
296 source.append(
" for (uint ind = 0; ind < to; ind++) \n");
297 source.append(
" { \n");
298 source.append(
" uint i = ind + start_i + block_id * lcl_sz; \n");
300 source.append(
" "); source.append(numeric_string); source.append(
" z = matr[i * stride + j]; \n");
302 source.append(
" "); source.append(numeric_string); source.append(
" cs_val = cs_lcl[ind];//cs[i]; \n");
303 source.append(
" "); source.append(numeric_string); source.append(
" ss_val = ss_lcl[ind];//ss[i]; \n");
305 source.append(
" matr[(i - 1) * stride + j] = x * cs_val + z * ss_val; \n");
306 source.append(
" x = -x * ss_val + z * cs_val; \n");
307 source.append(
" } \n");
308 source.append(
" } \n");
309 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
310 source.append(
" } \n");
311 source.append(
" if (j < size) \n");
312 source.append(
" matr[(end_i - 1) * stride + j] = x; \n");
313 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `house_update_A_left` to `source`.
// NOTE(review): signature not visible here; presumably
// generate_svd_house_update_A_left(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: V is read from __constant memory. Each work-item takes indices
// i = global id + col_start (strided by the global size, i < size2) and, over j in
// [row_start, size1), computes ss = sum_j V[j] * A(j, i) and then updates
// A(j, i) -= 2 * V[j] * ss.
// NOTE(review): the __local `sums` argument and the grp_id/grp_nm/lcl_id/lcl_sz locals are not
// used by any statement visible in this kernel - confirm whether they are intentional leftovers.
316 template <
typename StringType>
319 source.append(
"__kernel void house_update_A_left( \n");
320 source.append(
" __global "); source.append(numeric_string); source.append(
"* A, \n");
321 source.append(
" __constant "); source.append(numeric_string); source.append(
"* V, \n");
322 source.append(
" uint row_start, \n");
323 source.append(
" uint col_start, \n");
324 source.append(
" uint size1, \n");
325 source.append(
" uint size2, \n");
326 source.append(
" uint stride, \n");
327 source.append(
" __local "); source.append(numeric_string); source.append(
"* sums \n");
328 source.append(
" ) { \n");
329 source.append(
" uint glb_id = get_global_id(0); \n");
330 source.append(
" uint glb_sz = get_global_size(0); \n");
332 source.append(
" uint grp_id = get_group_id(0); \n");
333 source.append(
" uint grp_nm = get_num_groups(0); \n");
335 source.append(
" uint lcl_id = get_local_id(0); \n");
336 source.append(
" uint lcl_sz = get_local_size(0); \n");
338 source.append(
" "); source.append(numeric_string); source.append(
" ss = 0; \n");
// First loop variant: A addressed as A[j * stride + i].
// NOTE(review): presumably chosen by is_row_major in lines not shown - confirm.
343 source.append(
" for(uint i = glb_id + col_start; i < size2; i += glb_sz) { \n");
344 source.append(
" ss = 0; \n");
345 source.append(
" for(uint j = row_start; j < size1; j++) ss = ss + (V[j] * A[j * stride + i]); \n");
347 source.append(
" for(uint j = row_start; j < size1; j++) \n");
348 source.append(
" A[j * stride + i] = A[j * stride + i] - (2 * V[j] * ss); \n");
349 source.append(
" } \n");
// Second loop variant: A addressed as A[j + i * stride].
353 source.append(
" for(uint i = glb_id + col_start; i < size2; i += glb_sz) { \n");
354 source.append(
" ss = 0; \n");
355 source.append(
" for(uint j = row_start; j < size1; j++) ss = ss + (V[j] * A[j + i * stride]); \n");
357 source.append(
" for(uint j = row_start; j < size1; j++) \n");
358 source.append(
" A[j + i * stride] = A[j + i * stride] - (2 * V[j] * ss); \n");
359 source.append(
" } \n");
361 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `house_update_A_right` to `source`.
// NOTE(review): signature not visible here; presumably
// generate_svd_house_update_A_right(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: work-groups take indices i = group id + row_start (strided by the number of
// groups, i < size1). Within a group, work-items accumulate partial sums of V[j] * A(i, j) for
// j = lcl_id, lcl_id + lcl_sz, ... < size2 into sums[lcl_id]; col_reduce_lcl_array() combines
// them into sums[0]; every work-item then applies A(i, j) -= 2 * V[j] * sums[0].
// The emitted code calls col_reduce_lcl_array, so that helper must precede it in the program.
// NOTE(review): col_start and glb_id are not used by any visible statement.
// NOTE(review): there is no barrier between reading sums[0] (sum_Av) and the write to
// sums[lcl_id] in the next i-iteration; when a group handles more than one i this looks like
// a race on sums[0] - confirm.
364 template <
typename StringType>
368 source.append(
"__kernel void house_update_A_right( \n");
369 source.append(
" __global "); source.append(numeric_string); source.append(
"* A, \n");
370 source.append(
" __global "); source.append(numeric_string); source.append(
"* V, \n");
371 source.append(
" uint row_start, \n");
372 source.append(
" uint col_start, \n");
373 source.append(
" uint size1, \n");
374 source.append(
" uint size2, \n");
375 source.append(
" uint stride, \n");
376 source.append(
" __local "); source.append(numeric_string); source.append(
"* sums \n");
377 source.append(
" ) { \n");
379 source.append(
" uint glb_id = get_global_id(0); \n");
381 source.append(
" uint grp_id = get_group_id(0); \n");
382 source.append(
" uint grp_nm = get_num_groups(0); \n");
384 source.append(
" uint lcl_id = get_local_id(0); \n");
385 source.append(
" uint lcl_sz = get_local_size(0); \n");
387 source.append(
" "); source.append(numeric_string); source.append(
" ss = 0; \n");
// First loop variant: A addressed as A[i * stride + j].
// NOTE(review): presumably chosen by is_row_major in lines not shown - confirm.
392 source.append(
" for(uint i = grp_id + row_start; i < size1; i += grp_nm) { \n");
393 source.append(
" ss = 0; \n");
395 source.append(
" for(uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * A[i * stride + j]); \n");
396 source.append(
" sums[lcl_id] = ss; \n");
398 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
399 source.append(
" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
400 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
402 source.append(
" "); source.append(numeric_string); source.append(
" sum_Av = sums[0]; \n");
404 source.append(
" for(uint j = lcl_id; j < size2; j += lcl_sz) \n");
405 source.append(
" A[i * stride + j] = A[i * stride + j] - (2 * V[j] * sum_Av); \n");
406 source.append(
" } \n");
// Second loop variant: A addressed as A[i + j * stride].
410 source.append(
" for(uint i = grp_id + row_start; i < size1; i += grp_nm) { \n");
411 source.append(
" ss = 0; \n");
413 source.append(
" for(uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * A[i + j * stride]); \n");
414 source.append(
" sums[lcl_id] = ss; \n");
416 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
417 source.append(
" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
418 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
420 source.append(
" "); source.append(numeric_string); source.append(
" sum_Av = sums[0]; \n");
422 source.append(
" for(uint j = lcl_id; j < size2; j += lcl_sz) \n");
423 source.append(
" A[i + j * stride] = A[i + j * stride] - (2 * V[j] * sum_Av); \n");
424 source.append(
" } \n");
427 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `house_update_QL` to `source`.
// NOTE(review): signature not visible here; presumably
// generate_svd_house_update_QL(source, numeric_string, is_row_major) - confirm.
//
// Emitted kernel: V is read from __constant memory. Work-groups take indices i = group id
// (strided by the number of groups, i < size1). Work-items accumulate V[j] * QL(i, j) for
// j = lcl_id, lcl_id + lcl_sz, ... < size1 into sums[lcl_id], col_reduce_lcl_array() reduces
// into sums[0], and every work-item then applies QL(i, j) -= 2 * V[j] * sums[0].
// The emitted code calls col_reduce_lcl_array, so that helper must precede it in the program.
// NOTE(review): glb_id and glb_sz are not used by any visible statement.
// NOTE(review): no barrier separates the read of sums[0] (sum_Qv) from the next iteration's
// write to sums[lcl_id]; looks like a race when a group handles more than one i - confirm.
431 template <
typename StringType>
434 source.append(
"__kernel void house_update_QL(\n");
435 source.append(
" __global "); source.append(numeric_string); source.append(
"* QL, \n");
436 source.append(
" __constant "); source.append(numeric_string); source.append(
"* V, \n");
437 source.append(
" uint size1, \n");
438 source.append(
" uint strideQ, \n");
439 source.append(
" __local "); source.append(numeric_string); source.append(
"* sums \n");
440 source.append(
" ) { \n");
441 source.append(
" uint glb_id = get_global_id(0); \n");
442 source.append(
" uint glb_sz = get_global_size(0); \n");
444 source.append(
" uint grp_id = get_group_id(0); \n");
445 source.append(
" uint grp_nm = get_num_groups(0); \n");
447 source.append(
" uint lcl_id = get_local_id(0); \n");
448 source.append(
" uint lcl_sz = get_local_size(0); \n");
450 source.append(
" "); source.append(numeric_string); source.append(
" ss = 0; \n");
// First loop variant: QL addressed as QL[i * strideQ + j].
// NOTE(review): presumably chosen by is_row_major in lines not shown - confirm.
454 source.append(
" for(uint i = grp_id; i < size1; i += grp_nm) { \n");
455 source.append(
" ss = 0; \n");
456 source.append(
" for(uint j = lcl_id; j < size1; j += lcl_sz) ss = ss + (V[j] * QL[i * strideQ + j]); \n");
457 source.append(
" sums[lcl_id] = ss; \n");
459 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
460 source.append(
" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
461 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
463 source.append(
" "); source.append(numeric_string); source.append(
" sum_Qv = sums[0]; \n");
465 source.append(
" for(uint j = lcl_id; j < size1; j += lcl_sz) \n");
466 source.append(
" QL[i * strideQ + j] = QL[i * strideQ + j] - (2 * V[j] * sum_Qv); \n");
467 source.append(
" } \n");
// Second loop variant: QL addressed as QL[i + j * strideQ].
471 source.append(
" for(uint i = grp_id; i < size1; i += grp_nm) { \n");
472 source.append(
" ss = 0; \n");
473 source.append(
" for(uint j = lcl_id; j < size1; j += lcl_sz) ss = ss + (V[j] * QL[i + j * strideQ]); \n");
474 source.append(
" sums[lcl_id] = ss; \n");
476 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
477 source.append(
" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
478 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
480 source.append(
" "); source.append(numeric_string); source.append(
" sum_Qv = sums[0]; \n");
482 source.append(
" for(uint j = lcl_id; j < size1; j += lcl_sz) \n");
483 source.append(
" QL[i + j * strideQ] = QL[i + j * strideQ] - (2 * V[j] * sum_Qv); \n");
484 source.append(
" } \n");
486 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `house_update_QR` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_house_update_QR - confirm.
//
// Emitted kernel: work-groups take indices i = group id (strided by the number of groups,
// i < size2). Work-items accumulate V[j] * QR[i * strideQ + j] for j = lcl_id, lcl_id + lcl_sz,
// ... < size2 into sums[lcl_id], col_reduce_lcl_array() reduces into sums[0], and every
// work-item then applies QR[i * strideQ + j] -= 2 * V[j] * sums[0].
// Only this one indexing variant is visible. The emitted code calls col_reduce_lcl_array, so
// that helper must precede it in the program.
// NOTE(review): size1 and glb_id are not used by any visible statement; both loops are bounded
// by size2.
// NOTE(review): no barrier separates the read of sums[0] (sum_Qv) from the next iteration's
// write to sums[lcl_id]; looks like a race when a group handles more than one i - confirm.
490 template<
typename StringT>
493 source.append(
"__kernel void house_update_QR( \n");
494 source.append(
" __global "); source.append(numeric_string); source.append(
"* QR, \n");
495 source.append(
" __global "); source.append(numeric_string); source.append(
"* V, \n");
496 source.append(
" uint size1, \n");
497 source.append(
" uint size2, \n");
498 source.append(
" uint strideQ, \n");
499 source.append(
" __local "); source.append(numeric_string); source.append(
"* sums \n");
500 source.append(
" ) { \n");
502 source.append(
" uint glb_id = get_global_id(0); \n");
504 source.append(
" uint grp_id = get_group_id(0); \n");
505 source.append(
" uint grp_nm = get_num_groups(0); \n");
507 source.append(
" uint lcl_id = get_local_id(0); \n");
508 source.append(
" uint lcl_sz = get_local_size(0); \n");
510 source.append(
" "); source.append(numeric_string); source.append(
" ss = 0; \n");
515 source.append(
" for (uint i = grp_id; i < size2; i += grp_nm) { \n");
516 source.append(
" ss = 0; \n");
517 source.append(
" for (uint j = lcl_id; j < size2; j += lcl_sz) ss = ss + (V[j] * QR[i * strideQ + j]); \n");
518 source.append(
" sums[lcl_id] = ss; \n");
520 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
521 source.append(
" col_reduce_lcl_array(sums, lcl_id, lcl_sz); \n");
522 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
524 source.append(
" "); source.append(numeric_string); source.append(
" sum_Qv = sums[0]; \n");
525 source.append(
" for (uint j = lcl_id; j < size2; j += lcl_sz) \n");
526 source.append(
" QR[i * strideQ + j] = QR[i * strideQ + j] - (2 * V[j] * sum_Qv); \n");
527 source.append(
" } \n");
528 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `inverse_signs` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_inverse_signs - confirm.
//
// Emitted kernel: uses a two-dimensional index space (global ids 0 and 1). For every pair
// (x, y) with x < size and y < size it multiplies v[x * stride + y] by signs[x] in place.
// There is no striding loop, so the launch must supply at least size work-items per dimension
// for all entries to be processed.
531 template<
typename StringT>
534 source.append(
"__kernel void inverse_signs(__global "); source.append(numeric_string); source.append(
"* v, \n");
535 source.append(
" __global "); source.append(numeric_string); source.append(
"* signs, \n");
536 source.append(
" uint size, \n");
537 source.append(
" uint stride \n");
538 source.append(
" ) \n");
539 source.append(
"{ \n");
540 source.append(
" uint glb_id_x = get_global_id(0); \n");
541 source.append(
" uint glb_id_y = get_global_id(1); \n");
543 source.append(
" if ((glb_id_x < size) && (glb_id_y < size)) \n");
544 source.append(
" v[glb_id_x * stride + glb_id_y] *= signs[glb_id_x]; \n");
545 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `transpose_inplace` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_transpose_inplace - confirm.
//
// Emitted kernel: treats `input` as row_num * col_num densely packed entries. For each linear
// index i (strided over the global size) it derives row = i / col_num, col = i - row * col_num
// and new_pos = col * row_num + row, and swaps input[i] with input[new_pos] when i < new_pos,
// so each unordered pair is swapped by one work-item only.
// NOTE(review): a single pairwise swap yields the transpose only when i -> new_pos is its own
// inverse (e.g. row_num == col_num); for other shapes the result would not be a transpose -
// confirm that callers only pass square matrices.
549 template<
typename StringT>
553 source.append(
"__kernel void transpose_inplace(__global "); source.append(numeric_string); source.append(
"* input, \n");
554 source.append(
" unsigned int row_num, \n");
555 source.append(
" unsigned int col_num) { \n");
556 source.append(
" unsigned int size = row_num * col_num; \n");
557 source.append(
" for (unsigned int i = get_global_id(0); i < size; i+= get_global_size(0)) { \n");
558 source.append(
" unsigned int row = i / col_num; \n");
559 source.append(
" unsigned int col = i - row*col_num; \n");
561 source.append(
" unsigned int new_pos = col * row_num + row; \n");
566 source.append(
" if (i < new_pos) { \n");
567 source.append(
" "); source.append(numeric_string); source.append(
" val = input[i]; \n");
568 source.append(
" input[i] = input[new_pos]; \n");
569 source.append(
" input[new_pos] = val; \n");
570 source.append(
" } \n");
571 source.append(
" } \n");
572 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `update_qr_column` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_update_qr_column - confirm.
//
// Emitted kernel: each work-item handles indices i < last_n (strided by the global size) and
// sweeps k from m to n - 1 over the entries A[k * stride + i], keeping a sliding window
// (a_ik, a_ik_1, a_ik_2) of three consecutive k-entries. For every k it reads five coefficients
// buf[5 * k .. 5 * k + 4], forms p from the window and the first three coefficients, stores
// A[k * stride + i] = a_ik - p, and adjusts the other window entries with coefficients 3 and 4.
// The third window entry is only loaded and updated when k is not the last step (k != n - 1).
// After the sweep the remaining a_ik is stored to A[n * stride + i].
// NOTE(review): loop index i is a signed int compared against last_n but initialised from and
// incremented by unsigned values (glb_id, glb_sz) - mixed signedness; confirm ranges fit in int.
576 template<
typename StringT>
579 source.append(
"__kernel void update_qr_column(__global "); source.append(numeric_string); source.append(
"* A, \n");
580 source.append(
" uint stride, \n");
581 source.append(
" __global "); source.append(numeric_string); source.append(
"* buf, \n");
582 source.append(
" int m, \n");
583 source.append(
" int n, \n");
584 source.append(
" int last_n) \n");
585 source.append(
"{ \n");
586 source.append(
" uint glb_id = get_global_id(0); \n");
587 source.append(
" uint glb_sz = get_global_size(0); \n");
589 source.append(
" for (int i = glb_id; i < last_n; i += glb_sz) \n");
590 source.append(
" { \n");
591 source.append(
" "); source.append(numeric_string); source.append(
" a_ik = A[m * stride + i], a_ik_1, a_ik_2; \n");
593 source.append(
" a_ik_1 = A[(m + 1) * stride + i]; \n");
595 source.append(
" for (int k = m; k < n; k++) \n");
596 source.append(
" { \n");
597 source.append(
" bool notlast = (k != n - 1); \n");
599 source.append(
" "); source.append(numeric_string); source.append(
" p = buf[5 * k] * a_ik + buf[5 * k + 1] * a_ik_1; \n");
601 source.append(
" if (notlast) \n");
602 source.append(
" { \n");
603 source.append(
" a_ik_2 = A[(k + 2) * stride + i]; \n");
604 source.append(
" p = p + buf[5 * k + 2] * a_ik_2; \n");
605 source.append(
" a_ik_2 = a_ik_2 - p * buf[5 * k + 4]; \n");
606 source.append(
" } \n");
608 source.append(
" A[k * stride + i] = a_ik - p; \n");
609 source.append(
" a_ik_1 = a_ik_1 - p * buf[5 * k + 3]; \n");
// Slide the window forward by one k for the next iteration.
611 source.append(
" a_ik = a_ik_1; \n");
612 source.append(
" a_ik_1 = a_ik_2; \n");
613 source.append(
" } \n");
615 source.append(
" A[n * stride + i] = a_ik; \n");
616 source.append(
" } \n");
618 source.append(
"} \n");
// Generator fragment that appends `#define SECTION_SIZE 256` and the OpenCL source of kernel
// `inclusive_scan_1` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_inclusive_scan_kernel_1 -
// confirm. Later scan kernels in this file use SECTION_SIZE without defining it, so they depend
// on this fragment being appended first.
//
// Emitted kernel: each work-group loads one element per work-item from X (index
// glb_id * incX + startX) into the __local array XY, runs an up-sweep followed by a down-sweep
// over XY with barriers between steps, writes XY[lcl_id] to Y (index glb_id * incY + startY),
// and work-item 0 stores XY[SECTION_SIZE - 1] to S[grp_id * incS + startS].
// NOTE(review): the down-sweep starts at SECTION_SIZE / 4 and S reads XY[SECTION_SIZE - 1], while
// the up-sweep is bounded by lcl_sz; this only lines up when the work-group size equals
// SECTION_SIZE (256) - confirm at the launch site.
// NOTE(review): XY[lcl_id] is left unwritten for glb_id >= InputSize, and the store to Y is not
// guarded by InputSize - confirm buffers are padded or the launch size matches.
623 template <
typename StringType>
626 source.append(
"#define SECTION_SIZE 256\n");
627 source.append(
"__kernel void inclusive_scan_1(__global "); source.append(numeric_string); source.append(
"* X, \n");
628 source.append(
" uint startX, \n");
629 source.append(
" uint incX, \n");
630 source.append(
" uint InputSize, \n");
632 source.append(
" __global "); source.append(numeric_string); source.append(
"* Y, \n");
633 source.append(
" uint startY, \n");
634 source.append(
" uint incY, \n");
636 source.append(
" __global "); source.append(numeric_string); source.append(
"* S, \n");
637 source.append(
" uint startS, \n");
638 source.append(
" uint incS) \n");
640 source.append(
"{ \n");
641 source.append(
" uint glb_id = get_global_id(0); \n");
643 source.append(
" uint grp_id = get_group_id(0); \n");
644 source.append(
" uint grp_nm = get_num_groups(0); \n");
646 source.append(
" uint lcl_id = get_local_id(0); \n");
647 source.append(
" uint lcl_sz = get_local_size(0); \n");
648 source.append(
" __local "); source.append(numeric_string); source.append(
" XY[SECTION_SIZE]; \n");
650 source.append(
" if(glb_id < InputSize) \n");
651 source.append(
" XY[lcl_id] = X[glb_id * incX + startX]; \n");
652 source.append(
" \n");
// Up-sweep: strides 1, 2, 4, ... < lcl_sz; XY[index] += XY[index - stride].
654 source.append(
" for(uint stride = 1; stride < lcl_sz; stride *= 2) \n");
655 source.append(
" { \n");
656 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
657 source.append(
" int index = (lcl_id + 1) * 2 * stride - 1; \n");
658 source.append(
" if(index < lcl_sz) \n");
659 source.append(
" XY[index] += XY[index - stride]; \n");
660 source.append(
" } \n");
// Down-sweep: strides SECTION_SIZE / 4, ..., 1; XY[index + stride] += XY[index].
662 source.append(
" for(int stride = SECTION_SIZE / 4; stride > 0; stride /= 2) \n");
663 source.append(
" { \n");
664 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
665 source.append(
" int index = (lcl_id + 1) * 2 * stride - 1; \n");
666 source.append(
" if(index + stride < lcl_sz) \n");
667 source.append(
" XY[index + stride] += XY[index]; \n");
668 source.append(
" } \n");
670 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
671 source.append(
" Y[glb_id * incY + startY] = XY[lcl_id]; \n");
672 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
673 source.append(
" if(lcl_id == 0) \n");
674 source.append(
" { \n");
675 source.append(
" S[grp_id * incS + startS] = XY[SECTION_SIZE - 1]; \n");
676 source.append(
" } \n");
677 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `exclusive_scan_1` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_exclusive_scan_kernel_1 -
// confirm. The emitted code uses SECTION_SIZE, which this fragment does not define.
//
// Emitted kernel: same up-sweep / down-sweep over the __local array XY as the inclusive variant,
// but the input is shifted by one position: work-item 0 (global) seeds XY[0] = 0 and work-items
// with 0 < glb_id < InputSize + 1 load X[(glb_id - 1) * incX + startX]. The scanned value
// XY[lcl_id] is written to Y[glb_id * incY + startY], and work-item 0 of each group stores
// XY[SECTION_SIZE - 1] to S[grp_id * incS + startS].
// NOTE(review): the down-sweep start (SECTION_SIZE / 4) and the read of XY[SECTION_SIZE - 1]
// only line up with the lcl_sz-bounded up-sweep when the work-group size is 256 - confirm.
// NOTE(review): XY[lcl_id] is left unwritten for glb_id > InputSize, and the store to Y is not
// guarded - confirm buffer sizes at the launch site.
680 template <
typename StringType>
683 source.append(
"__kernel void exclusive_scan_1(__global "); source.append(numeric_string); source.append(
"* X, \n");
684 source.append(
" uint startX, \n");
685 source.append(
" uint incX, \n");
686 source.append(
" uint InputSize, \n");
688 source.append(
" __global "); source.append(numeric_string); source.append(
"* Y, \n");
689 source.append(
" uint startY, \n");
690 source.append(
" uint incY, \n");
692 source.append(
" __global "); source.append(numeric_string); source.append(
"* S, \n");
693 source.append(
" uint startS, \n");
694 source.append(
" uint incS) \n");
696 source.append(
"{ \n");
697 source.append(
" uint glb_id = get_global_id(0); \n");
699 source.append(
" uint grp_id = get_group_id(0); \n");
700 source.append(
" uint grp_nm = get_num_groups(0); \n");
702 source.append(
" uint lcl_id = get_local_id(0); \n");
703 source.append(
" uint lcl_sz = get_local_size(0); \n");
704 source.append(
" __local "); source.append(numeric_string); source.append(
" XY[SECTION_SIZE]; \n");
// Shifted load: element glb_id - 1 of X goes to slot lcl_id; global work-item 0 seeds a zero.
706 source.append(
" if(glb_id < InputSize + 1 && glb_id != 0) \n");
707 source.append(
" XY[lcl_id] = X[(glb_id - 1) * incX + startX]; \n");
708 source.append(
" if(glb_id == 0) \n");
709 source.append(
" XY[0] = 0; \n");
710 source.append(
" \n");
// Up-sweep: strides 1, 2, 4, ... < lcl_sz.
712 source.append(
" for(uint stride = 1; stride < lcl_sz; stride *= 2) \n");
713 source.append(
" { \n");
714 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
715 source.append(
" int index = (lcl_id + 1) * 2 * stride - 1; \n");
716 source.append(
" if(index < lcl_sz) \n");
717 source.append(
" XY[index] += XY[index - stride]; \n");
718 source.append(
" } \n");
// Down-sweep: strides SECTION_SIZE / 4, ..., 1.
720 source.append(
" for(int stride = SECTION_SIZE / 4; stride > 0; stride /= 2) \n");
721 source.append(
" { \n");
722 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
723 source.append(
" int index = (lcl_id + 1) * 2 * stride - 1; \n");
724 source.append(
" if(index + stride < lcl_sz) \n");
725 source.append(
" XY[index + stride] += XY[index]; \n");
726 source.append(
" } \n");
727 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
729 source.append(
" Y[glb_id * incY + startY] = XY[lcl_id]; \n");
730 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
731 source.append(
" if(lcl_id == 0) \n");
732 source.append(
" { \n");
733 source.append(
" S[grp_id * incS + startS] = XY[SECTION_SIZE - 1]; \n");
734 source.append(
" } \n");
735 source.append(
"} \n");
// Generator fragment that appends the OpenCL source of kernel `scan_kernel_2` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_scan_kernel_2 - confirm.
// The emitted code uses SECTION_SIZE, which this fragment does not define.
//
// Emitted kernel: loads S[glb_id * incS + startS] (for glb_id < InputSize) into the __local
// array XY, runs the same up-sweep / down-sweep as the first-stage scan kernels, and then, for
// glb_id < InputSize, writes the scanned value both back to S and to
// S_ref[glb_id * incS_ref + startS_ref].
// NOTE(review): XY[lcl_id] is left unwritten for glb_id >= InputSize yet still takes part in the
// sweeps; the down-sweep start of SECTION_SIZE / 4 assumes a work-group size of 256 - confirm.
738 template <
typename StringType>
740 { source.append(
"__kernel void scan_kernel_2(__global "); source.append(numeric_string); source.append(
"* S_ref, \n");
741 source.append(
" uint startS_ref, \n");
742 source.append(
" uint incS_ref, \n");
744 source.append(
" __global "); source.append(numeric_string); source.append(
"* S, \n");
745 source.append(
" uint startS, \n");
746 source.append(
" uint incS, \n");
747 source.append(
" uint InputSize) \n");
749 source.append(
" { \n");
750 source.append(
" uint glb_id = get_global_id(0); \n");
752 source.append(
" uint grp_id = get_group_id(0); \n");
753 source.append(
" uint grp_nm = get_num_groups(0); \n");
755 source.append(
" uint lcl_id = get_local_id(0); \n");
756 source.append(
" uint lcl_sz = get_local_size(0); \n");
757 source.append(
" __local "); source.append(numeric_string); source.append(
" XY[SECTION_SIZE]; \n");
759 source.append(
" if(glb_id < InputSize) \n");
760 source.append(
" XY[lcl_id] = S[glb_id * incS + startS]; \n");
// Up-sweep: strides 1, 2, 4, ... < lcl_sz.
762 source.append(
" for(uint stride = 1; stride < lcl_sz; stride *= 2) \n");
763 source.append(
" { \n");
764 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
765 source.append(
" int index = (lcl_id + 1) * 2 * stride - 1; \n");
766 source.append(
" if(index < lcl_sz) \n");
767 source.append(
" XY[index] += XY[index - stride]; \n");
768 source.append(
" } \n");
// Down-sweep: strides SECTION_SIZE / 4, ..., 1.
770 source.append(
" for(int stride = SECTION_SIZE / 4; stride > 0; stride /= 2) \n");
771 source.append(
" { \n");
772 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
773 source.append(
" int index = (lcl_id + 1) * 2 * stride - 1; \n");
774 source.append(
" if(index + stride < lcl_sz) \n");
775 source.append(
" XY[index + stride] += XY[index]; \n");
776 source.append(
" } \n");
778 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
779 source.append(
" if(glb_id < InputSize) \n");
780 source.append(
" { \n");
781 source.append(
" S[glb_id * incS + startS] = XY[lcl_id]; \n");
782 source.append(
" S_ref[glb_id * incS_ref + startS_ref] = XY[lcl_id]; \n");
783 source.append(
" } \n");
784 source.append(
" } \n");
// Generator fragment that appends the OpenCL source of kernel `scan_kernel_3` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_scan_kernel_3 - confirm.
//
// Emitted kernel: every work-item adds to its own entry S[glb_id * incS + startS] the values
// S_ref[(j * lcl_sz - 1) * incS_ref + startS_ref] for j = 1 .. grp_id, i.e. the last S_ref
// entry of each preceding work-group. Work-items of group 0 add nothing.
// NOTE(review): there is no size argument and no bounds guard on S or S_ref, so the launch size
// must match the buffer lengths - confirm. grp_nm and lcl_id are not used by visible statements.
// NOTE(review): loop counter j is a signed int compared against the unsigned grp_id.
788 template <
typename StringType>
790 { source.append(
"__kernel void scan_kernel_3(__global "); source.append(numeric_string); source.append(
"* S_ref, \n");
791 source.append(
" uint startS_ref, \n");
792 source.append(
" uint incS_ref, \n");
794 source.append(
" __global "); source.append(numeric_string); source.append(
"* S, \n");
795 source.append(
" uint startS, \n");
796 source.append(
" uint incS) \n");
798 source.append(
" { \n");
799 source.append(
" uint glb_id = get_global_id(0); \n");
801 source.append(
" uint grp_id = get_group_id(0); \n");
802 source.append(
" uint grp_nm = get_num_groups(0); \n");
804 source.append(
" uint lcl_id = get_local_id(0); \n");
805 source.append(
" uint lcl_sz = get_local_size(0); \n");
808 source.append(
" for(int j = 1; j <= grp_id; j++) \n");
809 source.append(
" S[glb_id * incS + startS] += S_ref[(j * lcl_sz - 1) * incS_ref + startS_ref]; \n");
810 source.append(
" } \n");
// Generator fragment that appends the OpenCL source of kernel `scan_kernel_4` to `source`.
// NOTE(review): signature not visible here; presumably generate_svd_scan_kernel_4 - confirm.
//
// Emitted kernel: each work-item computes var = (grp_id + 1) * lcl_sz + lcl_id, i.e. its
// position shifted forward by one work-group, and, when var < OutputSize, adds the per-group
// value S[grp_id * incS + startS] to Y[var * incY + startY]. The first lcl_sz entries of Y are
// therefore never modified by this kernel.
// NOTE(review): the barrier at kernel entry precedes any memory access and no __local memory is
// used in the visible code, so it looks redundant - confirm before removing.
// NOTE(review): glb_id and grp_nm are not used by any visible statement.
813 template <
typename StringType>
815 { source.append(
"__kernel void scan_kernel_4(__global "); source.append(numeric_string); source.append(
"* S, \n");
816 source.append(
" uint startS, \n");
817 source.append(
" uint incS, \n");
819 source.append(
" __global "); source.append(numeric_string); source.append(
"* Y, \n");
820 source.append(
" uint startY, \n");
821 source.append(
" uint incY, \n");
822 source.append(
" uint OutputSize) \n");
824 source.append(
" { \n");
825 source.append(
" barrier(CLK_LOCAL_MEM_FENCE); \n");
826 source.append(
" uint glb_id = get_global_id(0); \n");
828 source.append(
" uint grp_id = get_group_id(0); \n");
829 source.append(
" uint grp_nm = get_num_groups(0); \n");
831 source.append(
" uint lcl_id = get_local_id(0); \n");
832 source.append(
" uint lcl_sz = get_local_size(0); \n");
835 source.append(
" uint var = (grp_id + 1) * lcl_sz + lcl_id; \n");
836 source.append(
" if(var < OutputSize) \n");
837 source.append(
" Y[var * incY + startY] += S[grp_id * incS + startS]; \n");
838 source.append(
" } \n");
// Fragment of the kernel-program class templated on the numeric type and matrix layout
// (row_major by default); most of its lines, including the class header and the calls to the
// individual generators, are not visible in this listing.
// What the visible statements do, in order:
//  - keep a static map from cl_context to bool (`init_done`),
//  - reserve 1024 characters in `source`,
//  - let viennacl::ocl::append_double_precision_pragma<NumericT> add to `source` for `ctx`,
//  - restrict something to numeric_string being "float" or "double" (the guarded body is not shown),
//  - when VIENNACL_BUILD_INFO is defined, print "Creating program <prog_name>" to std::cout,
//  - register the assembled source under prog_name via ctx.add_program(), and
//  - set init_done for the context's raw handle to true.
// NOTE(review): presumably init_done is checked earlier so each context is compiled only once,
// but that check is not visible here - confirm.
// NOTE(review): reserve(1024) is far smaller than the text appended by the generators in this
// file, so `source` will presumably reallocate repeatedly - harmless, but a larger reservation
// may be worth considering.
846 template<
typename NumericT,
typename MatrixLayout = row_major>
857 static std::map<cl_context, bool> init_done;
865 source.reserve(1024);
867 viennacl::ocl::append_double_precision_pragma<NumericT>(ctx, source);
870 if (numeric_string ==
"float" || numeric_string ==
"double")
897 #ifdef VIENNACL_BUILD_INFO
898 std::cout <<
"Creating program " << prog_name << std::endl;
900 ctx.add_program(source, prog_name);
901 init_done[ctx.handle().get()] =
true;
void generate_svd_inclusive_scan_kernel_1(StringType &source, std::string const &numeric_string)
void generate_svd_copy_row(StringType &source, std::string const &numeric_string, bool is_row_major)
Helper class for checking whether a matrix has a row-major layout.
void generate_svd_scan_kernel_3(StringType &source, std::string const &numeric_string)
void generate_svd_bidiag_pack(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_final_iter_update(StringT &source, std::string const &numeric_string)
Manages an OpenCL context and provides the respective convenience functions for creating buffers...
void generate_svd_update_qr_column(StringT &source, std::string const &numeric_string)
Provides OpenCL-related utilities.
static std::string program_name()
void generate_svd_inverse_signs(StringT &source, std::string const &numeric_string)
static void init(viennacl::ocl::context &ctx)
const viennacl::ocl::handle< cl_context > & handle() const
Returns the context handle.
void generate_svd_house_update_QR(StringT &source, std::string const &numeric_string)
void generate_svd_transpose_inplace(StringT &source, std::string const &numeric_string)
Main namespace in ViennaCL. Holds all the basic types such as vector, matrix, etc. and defines operations upon them.
void generate_svd_givens_prev(StringT &source, std::string const &numeric_string)
static void apply(viennacl::ocl::context const &)
const OCL_TYPE & get() const
Main kernel class for generating OpenCL kernels for singular value decomposition of dense matrices...
void generate_svd_scan_kernel_2(StringType &source, std::string const &numeric_string)
void generate_svd_house_update_A_right(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_exclusive_scan_kernel_1(StringType &source, std::string const &numeric_string)
Representation of an OpenCL kernel in ViennaCL.
void generate_svd_col_reduce_lcl_array(StringT &source, std::string const &numeric_string)
void generate_svd_copy_col(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_house_update_QL(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_givens_next(StringType &source, std::string const &numeric_string, bool is_row_major)
Helper class for converting a type to its string representation.
void generate_svd_house_update_A_left(StringType &source, std::string const &numeric_string, bool is_row_major)
void generate_svd_scan_kernel_4(StringType &source, std::string const &numeric_string)