ViennaCL - The Vienna Computing Library  1.5.2
viennacl/generator/utils.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_GENERATOR_UTILS_HPP
00002 #define VIENNACL_GENERATOR_UTILS_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2014, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include <sstream>
00027 
00028 #include "viennacl/ocl/forwards.h"
00029 
00030 #include "viennacl/traits/size.hpp"
00031 
00032 #include "viennacl/scheduler/forwards.h"
00033 
00034 namespace viennacl{
00035 
00036   namespace generator{
00037 
00038     namespace utils{
00039 
00040     template<class Fun>
00041     static typename Fun::result_type call_on_host_scalar(scheduler::lhs_rhs_element element, Fun const & fun){
00042         assert(element.type_family == scheduler::SCALAR_TYPE_FAMILY && bool("Must be called on a host scalar"));
00043         switch(element.numeric_type){
00044         case scheduler::FLOAT_TYPE :
00045             return fun(element.host_float);
00046         case scheduler::DOUBLE_TYPE :
00047             return fun(element.host_double);
00048         default :
00049             throw "not implemented";
00050         }
00051     }
00052 
00053     template<class Fun>
00054     static typename Fun::result_type call_on_scalar(scheduler::lhs_rhs_element element, Fun const & fun){
00055         assert(element.type_family == scheduler::SCALAR_TYPE_FAMILY && bool("Must be called on a scalar"));
00056         switch(element.numeric_type){
00057         case scheduler::FLOAT_TYPE :
00058             return fun(*element.scalar_float);
00059         case scheduler::DOUBLE_TYPE :
00060             return fun(*element.scalar_double);
00061         default :
00062             throw "not implemented";
00063         }
00064     }
00065 
00066     template<class Fun>
00067     static typename Fun::result_type call_on_vector(scheduler::lhs_rhs_element element, Fun const & fun){
00068         assert(element.type_family == scheduler::VECTOR_TYPE_FAMILY && bool("Must be called on a vector"));
00069         switch(element.numeric_type){
00070         case scheduler::FLOAT_TYPE :
00071             return fun(*element.vector_float);
00072         case scheduler::DOUBLE_TYPE :
00073             return fun(*element.vector_double);
00074         default :
00075             throw "not implemented";
00076         }
00077     }
00078 
00079     template<class Fun>
00080     static typename Fun::result_type call_on_implicit_vector(scheduler::lhs_rhs_element element, Fun const & fun){
00081         assert(element.type_family == scheduler::VECTOR_TYPE_FAMILY   && bool("Must be called on a implicit_vector"));
00082         assert(element.subtype     == scheduler::IMPLICIT_VECTOR_TYPE && bool("Must be called on a implicit_vector"));
00083         switch(element.numeric_type){
00084         case scheduler::FLOAT_TYPE :
00085             return fun(*element.implicit_vector_float);
00086         case scheduler::DOUBLE_TYPE :
00087             return fun(*element.implicit_vector_double);
00088         default :
00089             throw "not implemented";
00090         }
00091     }
00092 
00093     template<class Fun>
00094     static typename Fun::result_type call_on_matrix(scheduler::lhs_rhs_element element, Fun const & fun){
00095         assert(element.type_family == scheduler::MATRIX_TYPE_FAMILY && bool("Must be called on a matrix"));
00096         if (element.subtype == scheduler::DENSE_ROW_MATRIX_TYPE)
00097         {
00098             switch(element.numeric_type){
00099             case scheduler::FLOAT_TYPE :
00100                 return fun(*element.matrix_row_float);
00101             case scheduler::DOUBLE_TYPE :
00102                 return fun(*element.matrix_row_double);
00103             default :
00104                 throw "not implemented";
00105             }
00106         }
00107         else
00108         {
00109             switch(element.numeric_type){
00110             case scheduler::FLOAT_TYPE :
00111                 return fun(*element.matrix_col_float);
00112             case scheduler::DOUBLE_TYPE :
00113                 return fun(*element.matrix_col_double);
00114             default :
00115                 throw "not implemented";
00116             }
00117         }
00118     }
00119 
00120 
00121     template<class Fun>
00122     static typename Fun::result_type call_on_implicit_matrix(scheduler::lhs_rhs_element element, Fun const & fun){
00123         assert(element.type_family == scheduler::MATRIX_TYPE_FAMILY   && bool("Must be called on a matrix_vector"));
00124         assert(element.subtype     == scheduler::IMPLICIT_MATRIX_TYPE && bool("Must be called on a matrix_vector"));
00125         switch(element.numeric_type){
00126         case scheduler::FLOAT_TYPE :
00127             return fun(*element.implicit_matrix_float);
00128         case scheduler::DOUBLE_TYPE :
00129             return fun(*element.implicit_matrix_double);
00130         default :
00131             throw "not implemented";
00132         }
00133     }
00134 
00135       template<class Fun>
00136       static typename Fun::result_type call_on_element(scheduler::lhs_rhs_element const & element, Fun const & fun){
00137         switch(element.type_family){
00138           case scheduler::SCALAR_TYPE_FAMILY:
00139             if (element.subtype == scheduler::HOST_SCALAR_TYPE)
00140               return call_on_host_scalar(element, fun);
00141             else
00142               return call_on_scalar(element, fun);
00143           case scheduler::VECTOR_TYPE_FAMILY :
00144             if (element.subtype == scheduler::IMPLICIT_VECTOR_TYPE)
00145               return call_on_implicit_vector(element, fun);
00146             else
00147               return call_on_vector(element, fun);
00148           case scheduler::MATRIX_TYPE_FAMILY:
00149             if (element.subtype == scheduler::IMPLICIT_MATRIX_TYPE)
00150               return call_on_implicit_matrix(element, fun);
00151             else
00152               return call_on_matrix(element,fun);
00153           default:
00154             throw "not implemented";
00155         }
00156       }
00157 
00159       struct scalartype_size_fun{
00160           typedef vcl_size_t result_type;
00161           result_type operator()(float const &) const { return sizeof(float); }
00162           result_type operator()(double const &) const { return sizeof(double); }
00163           template<class T> result_type operator()(T const &) const { return sizeof(typename viennacl::result_of::cpu_value_type<T>::type); }
00164       };
00165 
00167       struct internal_size_fun{
00168           typedef vcl_size_t result_type;
00169           template<class T>
00170           result_type operator()(T const &t) const { return viennacl::traits::internal_size(t); }
00171       };
00172 
00174       struct handle_fun{
00175           typedef cl_mem result_type;
00176           template<class T>
00177           result_type operator()(T const &t) const { return t.handle().opencl_handle(); }
00178       };
00179 
00181       struct internal_size1_fun{
00182           typedef vcl_size_t result_type;
00183           template<class T>
00184           result_type operator()(T const &t) const { return viennacl::traits::internal_size1(t); }
00185       };
00186 
00188       struct internal_size2_fun{
00189           typedef vcl_size_t result_type;
00190           template<class T>
00191           result_type operator()(T const &t) const { return viennacl::traits::internal_size2(t); }
00192       };
00193 
/** @brief Compile-time type comparison: value is 0 in the general case of two distinct types. */
template<class T, class U>
struct is_same_type
{
  enum { value = 0 };
};

/** @brief Partial specialization selected when both arguments are the same type: value is 1. */
template<class T>
struct is_same_type<T,T>
{
  enum { value = 1 };
};
/** @brief Converts a value to its textual form by streaming it with operator<<.
 *
 * @tparam T   Any type that can be written to a std::ostream
 * @param t    The value to convert
 * @return     The characters produced by the stream insertion
 */
template <class T>
inline std::string to_string ( T const t )
{
  std::ostringstream formatter;
  formatter << t;
  return formatter.str();
}
00210 
/** @brief Maps a type to its name as a C string. Only the specializations below are defined. */
template<class T>
struct type_to_string;

/** @brief Specialization for float: value() yields "float". */
template<>
struct type_to_string<float>
{
  static const char * value() { return "float"; }
};

/** @brief Specialization for double: value() yields "double". */
template<>
struct type_to_string<double>
{
  static const char * value() { return "double"; }
};
00222       template<class T>
00223       struct first_letter_of_type;
00224 
00226       template<> struct first_letter_of_type<float> { static char value() { return 'f'; } };
00227       template<> struct first_letter_of_type<double> { static char value() { return 'd'; } };
00228       template<> struct first_letter_of_type<viennacl::row_major> { static char value() { return 'r'; } };
00229       template<> struct first_letter_of_type<viennacl::column_major> { static char value() { return 'c'; } };
/** @brief Output stream that accumulates text and prefixes it with the current indentation.
 *
 * Text written to the stream is held in an internal buffer until the stream is flushed
 * (e.g. by std::endl or std::flush). On each flush, four spaces per indentation level are
 * written, followed by the buffered text. The indentation is therefore applied once per
 * flush, not once per newline character. str() only shows text that has been flushed.
 */
class kernel_generation_stream : public std::ostream{
  private:

    /** @brief String buffer that moves its content into the target stream on every sync. */
    class kgenstream : public std::stringbuf{
      public:
        /** @brief Stores references only; neither argument is accessed until the first sync. */
        kgenstream(std::ostringstream& oss,unsigned int const & tab_count) : oss_(oss), tab_count_(tab_count){ }

        /** @brief Writes the indentation and the buffered text to the target stream and empties the buffer.
         *
         * @return 0 on success, -1 if the target stream is in a failed state. -1 is the value
         *         the stream buffer interface uses to report a failed synchronization.
         */
        int sync() {
          for(unsigned int i=0 ; i<tab_count_;++i)
            oss_ << "    ";
          oss_ << str();
          str("");
          return oss_.fail() ? -1 : 0;
        }

        /** @brief Pushes any text that is still buffered into the target stream. */
        ~kgenstream() {  pubsync(); }
      private:
        std::ostream& oss_;
        unsigned int const & tab_count_;
    };

  public:
    /** @brief Creates an empty stream with indentation level zero.
     *
     * The buffer receives references to members that are constructed after the std::ostream
     * base; this is fine because the buffer constructor only stores the references.
     */
    kernel_generation_stream() : std::ostream(new kgenstream(oss,tab_count_)), tab_count_(0){ }

    /** @brief Returns all text flushed so far. Text written since the last flush is not included. */
    std::string str() const { return oss.str(); }

    /** @brief Increases the indentation by one level (four spaces). */
    void inc_tab(){ ++tab_count_; }

    /** @brief Decreases the indentation by one level; has no effect at level zero.
     *
     * Without the guard the unsigned counter would wrap around and the next flush would
     * emit billions of indentation units.
     */
    void dec_tab(){ if (tab_count_ > 0) --tab_count_; }

    /** @brief Destroys the buffer created in the constructor.
     *
     * The buffer's destructor writes to the member stream, which is still alive here because
     * data members are destroyed only after this destructor body has finished.
     */
    ~kernel_generation_stream(){ delete rdbuf(); }

  private:
    unsigned int tab_count_;
    std::ostringstream oss;
};
00267 
00268 
00269     }
00270 
00271   }
00272 
00273 }
00274 #endif