ViennaCL - The Vienna Computing Library
1.5.2
|
00001 #ifndef VIENNACL_GENERATOR_ENQUEUE_TREE_HPP 00002 #define VIENNACL_GENERATOR_ENQUEUE_TREE_HPP 00003 00004 /* ========================================================================= 00005 Copyright (c) 2010-2014, Institute for Microelectronics, 00006 Institute for Analysis and Scientific Computing, 00007 TU Wien. 00008 Portions of this software are copyright by UChicago Argonne, LLC. 00009 00010 ----------------- 00011 ViennaCL - The Vienna Computing Library 00012 ----------------- 00013 00014 Project Head: Karl Rupp rupp@iue.tuwien.ac.at 00015 00016 (A list of authors and contributors can be found in the PDF manual) 00017 00018 License: MIT (X11), see file LICENSE in the base directory 00019 ============================================================================= */ 00020 00021 00026 #include <set> 00027 00028 #include "viennacl/matrix.hpp" 00029 #include "viennacl/vector.hpp" 00030 00031 #include "viennacl/forwards.h" 00032 #include "viennacl/scheduler/forwards.h" 00033 #include "viennacl/generator/forwards.h" 00034 00035 #include "viennacl/meta/result_of.hpp" 00036 00037 #include "viennacl/tools/shared_ptr.hpp" 00038 00039 #include "viennacl/ocl/kernel.hpp" 00040 00041 #include "viennacl/generator/helpers.hpp" 00042 #include "viennacl/generator/utils.hpp" 00043 #include "viennacl/generator/mapped_objects.hpp" 00044 00045 00046 namespace viennacl{ 00047 00048 namespace generator{ 00049 00050 namespace detail{ 00051 00053 class set_arguments_functor : public traversal_functor{ 00054 public: 00055 typedef void result_type; 00056 00057 set_arguments_functor(std::set<void *> & memory, unsigned int & current_arg, viennacl::ocl::kernel & kernel) : memory_(memory), current_arg_(current_arg), kernel_(kernel){ } 00058 00059 template<class ScalarType> 00060 result_type operator()(ScalarType const & scal) const { 00061 typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype; 00062 kernel_.arg(current_arg_++, cl_scalartype(scal)); 00063 } 00064 00066 template<class ScalarType> 00067 result_type operator()(scalar<ScalarType> const & scal) const { 00068 if(memory_.insert((void*)&scal).second) 00069 kernel_.arg(current_arg_++, scal.handle().opencl_handle()); 00070 } 00071 00073 template<class ScalarType> 00074 result_type operator()(vector_base<ScalarType> const & vec) const { 00075 if(memory_.insert((void*)&vec).second){ 00076 kernel_.arg(current_arg_++, vec.handle().opencl_handle()); 00077 if(viennacl::traits::start(vec)>0) 00078 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start(vec))); 00079 if(vec.stride()>1) 00080 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride(vec))); 00081 } 00082 } 00083 00085 template<class ScalarType> 00086 result_type operator()(implicit_vector_base<ScalarType> const & vec) const { 00087 typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype; 00088 if(memory_.insert((void*)&vec).second){ 00089 if(vec.is_value_static()==false) 00090 kernel_.arg(current_arg_++, cl_scalartype(vec.value())); 00091 if(vec.has_index()) 00092 kernel_.arg(current_arg_++, cl_uint(vec.index())); 00093 } 00094 } 00095 00097 template<class ScalarType, class Layout> 00098 result_type operator()(matrix_base<ScalarType, Layout> const & mat) const { 00099 //typedef typename matrix_base<ScalarType, Layout>::size_type size_type; 00100 if(memory_.insert((void*)&mat).second){ 00101 kernel_.arg(current_arg_++, mat.handle().opencl_handle()); 00102 if(viennacl::traits::start1(mat)>0) 00103 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat))); 00104 if(viennacl::traits::stride1(mat)>1) 00105 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat))); 00106 if(viennacl::traits::start2(mat)>0) 00107 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat))); 00108 if(viennacl::traits::stride2(mat)>1) 00109 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat))); 00110 } 00111 } 00112 00114 template<class ScalarType> 00115 result_type operator()(implicit_matrix_base<ScalarType> const & mat) const { 00116 if(mat.is_value_static()==false) 00117 kernel_.arg(current_arg_++, mat.value()); 00118 } 00119 00121 void operator()(scheduler::statement const * /*statement*/, scheduler::statement_node const * root_node, detail::node_type node_type) const { 00122 if(node_type==LHS_NODE_TYPE && root_node->lhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY) 00123 utils::call_on_element(root_node->lhs, *this); 00124 else if(node_type==RHS_NODE_TYPE && root_node->rhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY) 00125 utils::call_on_element(root_node->rhs, *this); 00126 } 00127 00128 private: 00129 std::set<void *> & memory_; 00130 unsigned int & current_arg_; 00131 viennacl::ocl::kernel & kernel_; 00132 }; 00133 00134 } 00135 00136 } 00137 00138 } 00139 #endif