ViennaCL - The Vienna Computing Library  1.5.2
viennacl/device_specific/tree_parsing/set_arguments.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_DEVICE_SPECIFIC_TREE_PARSING_SET_ARGUMENTS_HPP
00002 #define VIENNACL_DEVICE_SPECIFIC_TREE_PARSING_SET_ARGUMENTS_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include <set>
00027 
00028 #include "viennacl/forwards.h"
00029 #include "viennacl/scheduler/forwards.h"
00030 #include "viennacl/device_specific/forwards.h"
00031 
00032 #include "viennacl/meta/result_of.hpp"
00033 
00034 #include "viennacl/tools/shared_ptr.hpp"
00035 
00036 
00037 #include "viennacl/device_specific/tree_parsing/traverse.hpp"
00038 #include "viennacl/device_specific/utils.hpp"
00039 #include "viennacl/device_specific/mapped_objects.hpp"
00040 
00041 
00042 namespace viennacl{
00043 
00044   namespace device_specific{
00045 
00046     namespace tree_parsing{
00047 
00048       class set_arguments_functor : public traversal_functor{
00049         public:
00050           typedef void result_type;
00051 
00052           set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, viennacl::ocl::kernel & kernel) : binder_(binder), current_arg_(current_arg), kernel_(kernel){ }
00053 
00054           template<class ScalarType>
00055           result_type operator()(ScalarType const & scal) const {
00056             typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype;
00057             kernel_.arg(current_arg_++, cl_scalartype(scal));
00058           }
00059 
00061           template<class ScalarType>
00062           result_type operator()(scalar<ScalarType> const & scal) const {
00063             if(binder_.bind(&viennacl::traits::handle(scal)))
00064               kernel_.arg(current_arg_++, scal.handle().opencl_handle());
00065           }
00066 
00068           template<class ScalarType>
00069           result_type operator()(vector_base<ScalarType> const & vec) const {
00070             if(binder_.bind(&viennacl::traits::handle(vec)))
00071             {
00072               kernel_.arg(current_arg_++, vec.handle().opencl_handle());
00073               kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start(vec)));
00074               kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride(vec)));
00075             }
00076           }
00077 
00079           template<class ScalarType>
00080           result_type operator()(implicit_vector_base<ScalarType> const & vec) const {
00081             typedef typename viennacl::result_of::cl_type<ScalarType>::type cl_scalartype;
00082             kernel_.arg(current_arg_++, cl_scalartype(vec.value()));
00083             if(vec.has_index())
00084               kernel_.arg(current_arg_++, cl_uint(vec.index()));
00085           }
00086 
00088           template<class ScalarType>
00089           result_type operator()(matrix_base<ScalarType> const & mat) const {
00090             if(binder_.bind(&viennacl::traits::handle(mat)))
00091             {
00092               kernel_.arg(current_arg_++, mat.handle().opencl_handle());
00093               if(mat.row_major()){
00094                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::internal_size2(mat)));
00095                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat)));
00096                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat)));
00097                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat)));
00098                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat)));
00099               }
00100               else{
00101                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::internal_size1(mat)));
00102                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat)));
00103                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat)));
00104                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat)));
00105                 kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat)));
00106               }
00107             }
00108           }
00109 
00111           template<class ScalarType>
00112           result_type operator()(implicit_matrix_base<ScalarType> const & mat) const {
00113             kernel_.arg(current_arg_++, mat.value());
00114           }
00115 
00117           void operator()(scheduler::statement const & statement, unsigned int root_idx, node_type node_type) const {
00118             scheduler::statement_node const & root_node = statement.array()[root_idx];
00119             if(node_type==LHS_NODE_TYPE && root_node.lhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY)
00120               utils::call_on_element(root_node.lhs, *this);
00121             else if(node_type==RHS_NODE_TYPE && root_node.rhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY)
00122               utils::call_on_element(root_node.rhs, *this);
00123           }
00124 
00125         private:
00126           symbolic_binder & binder_;
00127           unsigned int & current_arg_;
00128           viennacl::ocl::kernel & kernel_;
00129       };
00130 
00131     }
00132 
00133   }
00134 
00135 }
00136 #endif