ViennaCL - The Vienna Computing Library
1.5.2
|
00001 #ifndef VIENNACL_DEVICE_SPECIFIC_TREE_PARSING_SET_ARGUMENTS_HPP 00002 #define VIENNACL_DEVICE_SPECIFIC_TREE_PARSING_SET_ARGUMENTS_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2013, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00021 00026 #include <set> 00027 00028 #include "viennacl/forwards.h" 00029 #include "viennacl/scheduler/forwards.h" 00030 #include "viennacl/device_specific/forwards.h" 00031 00032 #include "viennacl/meta/result_of.hpp" 00033 00034 #include "viennacl/tools/shared_ptr.hpp" 00035 00036 00037 #include "viennacl/device_specific/tree_parsing/traverse.hpp" 00038 #include "viennacl/device_specific/utils.hpp" 00039 #include "viennacl/device_specific/mapped_objects.hpp" 00040 00041 00042 namespace viennacl{ 00043 00044 namespace device_specific{ 00045 00046 namespace tree_parsing{ 00047 00048 class set_arguments_functor : public traversal_functor{ 00049 public: 00050 typedef void result_type; 00051 00052 set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, viennacl::ocl::kernel & kernel) : binder_(binder), current_arg_(current_arg), kernel_(kernel){ } 00053 00054 template<class ScalarType> 00055 result_type operator()(ScalarType const & scal) const { 00056 typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype; 00057 kernel_.arg(current_arg_++, cl_scalartype(scal)); 00058 } 00059 00061 template<class ScalarType> 00062 result_type operator()(scalar<ScalarType> const & scal) const { 00063 if(binder_.bind(&viennacl::traits::handle(scal))) 00064 kernel_.arg(current_arg_++, scal.handle().opencl_handle()); 00065 } 00066 00068 template<class ScalarType> 00069 result_type operator()(vector_base<ScalarType> const & vec) const { 00070 if(binder_.bind(&viennacl::traits::handle(vec))) 00071 { 00072 kernel_.arg(current_arg_++, vec.handle().opencl_handle()); 00073 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start(vec))); 00074 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride(vec))); 00075 } 00076 } 00077 00079 template<class ScalarType> 00080 result_type operator()(implicit_vector_base<ScalarType> const & vec) const { 00081 typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype; 00082 kernel_.arg(current_arg_++, cl_scalartype(vec.value())); 00083 if(vec.has_index()) 00084 kernel_.arg(current_arg_++, cl_uint(vec.index())); 00085 } 00086 00088 template<class ScalarType> 00089 result_type operator()(matrix_base<ScalarType> const & mat) const { 00090 if(binder_.bind(&viennacl::traits::handle(mat))) 00091 { 00092 kernel_.arg(current_arg_++, mat.handle().opencl_handle()); 00093 if(mat.row_major()){ 00094 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::internal_size2(mat))); 00095 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat))); 00096 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat))); 00097 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat))); 00098 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat))); 00099 } 00100 else{ 00101 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::internal_size1(mat))); 00102 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat))); 00103 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat))); 00104 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat))); 00105 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat))); 00106 } 00107 } 00108 } 00109 00111 template<class ScalarType> 00112 result_type operator()(implicit_matrix_base<ScalarType> const & mat) const { 00113 kernel_.arg(current_arg_++, mat.value()); 00114 } 00115 00117 void operator()(scheduler::statement const & statement, unsigned int root_idx, node_type node_type) const { 00118 scheduler::statement_node const & root_node = statement.array()[root_idx]; 00119 if(node_type==LHS_NODE_TYPE && root_node.lhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY) 00120 utils::call_on_element(root_node.lhs, *this); 00121 else if(node_type==RHS_NODE_TYPE && root_node.rhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY) 00122 utils::call_on_element(root_node.rhs, *this); 00123 } 00124 00125 private: 00126 symbolic_binder & binder_; 00127 unsigned int & current_arg_; 00128 viennacl::ocl::kernel & kernel_; 00129 }; 00130 00131 } 00132 00133 } 00134 00135 } 00136 #endif