ViennaCL - The Vienna Computing Library  1.5.2
viennacl/device_specific/tree_parsing/evaluate_expression.hpp
Go to the documentation of this file.
00001 #ifndef VIENNACL_DEVICE_SPECIFIC_TREE_PARSING_ELEMENTWISE_EXPRESSION_HPP
00002 #define VIENNACL_DEVICE_SPECIFIC_TREE_PARSING_ELEMENTWISE_EXPRESSION_HPP
00003 
00004 /* =========================================================================
00005    Copyright (c) 2010-2013, Institute for Microelectronics,
00006                             Institute for Analysis and Scientific Computing,
00007                             TU Wien.
00008    Portions of this software are copyright by UChicago Argonne, LLC.
00009 
00010                             -----------------
00011                   ViennaCL - The Vienna Computing Library
00012                             -----------------
00013 
00014    Project Head:    Karl Rupp                   rupp@iue.tuwien.ac.at
00015 
00016    (A list of authors and contributors can be found in the PDF manual)
00017 
00018    License:         MIT (X11), see file LICENSE in the base directory
00019 ============================================================================= */
00020 
00021 
00026 #include <set>
00027 
00028 #include "CL/cl.h"
00029 
00030 #include "viennacl/forwards.h"
00031 
00032 #include "viennacl/scheduler/forwards.h"
00033 
00034 #include "viennacl/device_specific/forwards.h"
00035 #include "viennacl/device_specific/tree_parsing/traverse.hpp"
00036 
00037 namespace viennacl{
00038 
00039   namespace device_specific{
00040 
00041     namespace tree_parsing{
00042 
00044       inline const char * evaluate(scheduler::operation_node_type type){
00045         using namespace scheduler;
00046         // unary expression
00047         switch(type){
00048           //Function
00049           case OPERATION_UNARY_ABS_TYPE : return "abs";
00050           case OPERATION_UNARY_ACOS_TYPE : return "acos";
00051           case OPERATION_UNARY_ASIN_TYPE : return "asin";
00052           case OPERATION_UNARY_ATAN_TYPE : return "atan";
00053           case OPERATION_UNARY_CEIL_TYPE : return "ceil";
00054           case OPERATION_UNARY_COS_TYPE : return "cos";
00055           case OPERATION_UNARY_COSH_TYPE : return "cosh";
00056           case OPERATION_UNARY_EXP_TYPE : return "exp";
00057           case OPERATION_UNARY_FABS_TYPE : return "fabs";
00058           case OPERATION_UNARY_FLOOR_TYPE : return "floor";
00059           case OPERATION_UNARY_LOG_TYPE : return "log";
00060           case OPERATION_UNARY_LOG10_TYPE : return "log10";
00061           case OPERATION_UNARY_SIN_TYPE : return "sin";
00062           case OPERATION_UNARY_SINH_TYPE : return "sinh";
00063           case OPERATION_UNARY_SQRT_TYPE : return "sqrt";
00064           case OPERATION_UNARY_TAN_TYPE : return "tan";
00065           case OPERATION_UNARY_TANH_TYPE : return "tanh";
00066 
00067           case OPERATION_UNARY_CAST_CHAR_TYPE : return "(char)";
00068           case OPERATION_UNARY_CAST_UCHAR_TYPE : return "(uchar)";
00069           case OPERATION_UNARY_CAST_SHORT_TYPE : return "(short)";
00070           case OPERATION_UNARY_CAST_USHORT_TYPE : return "(ushort)";
00071           case OPERATION_UNARY_CAST_INT_TYPE : return "(int)";
00072           case OPERATION_UNARY_CAST_UINT_TYPE : return "(uint)";
00073           case OPERATION_UNARY_CAST_LONG_TYPE : return "(long)";
00074           case OPERATION_UNARY_CAST_ULONG_TYPE : return "(ulong)";
00075           case OPERATION_UNARY_CAST_HALF_TYPE : return "(half)";
00076           case OPERATION_UNARY_CAST_FLOAT_TYPE : return "(float)";
00077           case OPERATION_UNARY_CAST_DOUBLE_TYPE : return "(double)";
00078 
00079           case OPERATION_BINARY_ELEMENT_ARGFMAX_TYPE : return "argfmax";
00080           case OPERATION_BINARY_ELEMENT_ARGMAX_TYPE : return "argmax";
00081           case OPERATION_BINARY_ELEMENT_ARGFMIN_TYPE : return "argfmin";
00082           case OPERATION_BINARY_ELEMENT_ARGMIN_TYPE : return "argmin";
00083           case OPERATION_BINARY_ELEMENT_POW_TYPE : return "pow";
00084 
00085           //Arithmetic
00086           case OPERATION_UNARY_MINUS_TYPE : return "-";
00087           case OPERATION_BINARY_ASSIGN_TYPE : return "=";
00088           case OPERATION_BINARY_INPLACE_ADD_TYPE : return "+=";
00089           case OPERATION_BINARY_INPLACE_SUB_TYPE : return "-=";
00090           case OPERATION_BINARY_ADD_TYPE : return "+";
00091           case OPERATION_BINARY_SUB_TYPE : return "-";
00092           case OPERATION_BINARY_MULT_TYPE : return "*";
00093           case OPERATION_BINARY_ELEMENT_PROD_TYPE : return "*";
00094           case OPERATION_BINARY_DIV_TYPE : return "/";
00095           case OPERATION_BINARY_ELEMENT_DIV_TYPE : return "/";
00096           case OPERATION_BINARY_ACCESS_TYPE : return "[]";
00097 
00098           //Relational
00099           case OPERATION_BINARY_ELEMENT_EQ_TYPE : return "isequal";
00100           case OPERATION_BINARY_ELEMENT_NEQ_TYPE : return "isnotequal";
00101           case OPERATION_BINARY_ELEMENT_GREATER_TYPE : return "isgreater";
00102           case OPERATION_BINARY_ELEMENT_GEQ_TYPE : return "isgreaterequal";
00103           case OPERATION_BINARY_ELEMENT_LESS_TYPE : return "isless";
00104           case OPERATION_BINARY_ELEMENT_LEQ_TYPE : return "islessequal";
00105 
00106           case OPERATION_BINARY_ELEMENT_FMAX_TYPE : return "fmax";
00107           case OPERATION_BINARY_ELEMENT_FMIN_TYPE : return "fmin";
00108           case OPERATION_BINARY_ELEMENT_MAX_TYPE : return "max";
00109           case OPERATION_BINARY_ELEMENT_MIN_TYPE : return "min";
00110           //Unary
00111           case OPERATION_UNARY_TRANS_TYPE : return "trans";
00112 
00113           //Binary
00114           case OPERATION_BINARY_INNER_PROD_TYPE : return "iprod";
00115           case OPERATION_BINARY_MAT_MAT_PROD_TYPE : return "mmprod";
00116           case OPERATION_BINARY_MAT_VEC_PROD_TYPE : return "mvprod";
00117           case OPERATION_BINARY_VECTOR_DIAG_TYPE : return "vdiag";
00118           case OPERATION_BINARY_MATRIX_DIAG_TYPE : return "mdiag";
00119           case OPERATION_BINARY_MATRIX_ROW_TYPE : return "row";
00120           case OPERATION_BINARY_MATRIX_COLUMN_TYPE : return "col";
00121 
00122           default : throw generator_not_supported_exception("Unsupported operator");
00123         }
00124       }
00125 
00126       inline const char * evaluate_str(scheduler::operation_node_type type){
00127         using namespace scheduler;
00128         switch(type){
00129         case OPERATION_UNARY_CAST_CHAR_TYPE : return "char";
00130         case OPERATION_UNARY_CAST_UCHAR_TYPE : return "uchar";
00131         case OPERATION_UNARY_CAST_SHORT_TYPE : return "short";
00132         case OPERATION_UNARY_CAST_USHORT_TYPE : return "ushort";
00133         case OPERATION_UNARY_CAST_INT_TYPE : return "int";
00134         case OPERATION_UNARY_CAST_UINT_TYPE : return "uint";
00135         case OPERATION_UNARY_CAST_LONG_TYPE : return "long";
00136         case OPERATION_UNARY_CAST_ULONG_TYPE : return "ulong";
00137         case OPERATION_UNARY_CAST_HALF_TYPE : return "half";
00138         case OPERATION_UNARY_CAST_FLOAT_TYPE : return "float";
00139         case OPERATION_UNARY_CAST_DOUBLE_TYPE : return "double";
00140 
00141         case OPERATION_UNARY_MINUS_TYPE : return "mi";
00142         case OPERATION_BINARY_ASSIGN_TYPE : return "as";
00143         case OPERATION_BINARY_INPLACE_ADD_TYPE : return "iad";
00144         case OPERATION_BINARY_INPLACE_SUB_TYPE : return "isu";
00145         case OPERATION_BINARY_ADD_TYPE : return "ad";
00146         case OPERATION_BINARY_SUB_TYPE : return "su";
00147         case OPERATION_BINARY_MULT_TYPE : return "mu";
00148         case OPERATION_BINARY_ELEMENT_PROD_TYPE : return "epr";
00149         case OPERATION_BINARY_DIV_TYPE : return "di";
00150         case OPERATION_BINARY_ELEMENT_DIV_TYPE : return "edi";
00151         case OPERATION_BINARY_ACCESS_TYPE : return "ac";
00152           default : return evaluate(type);
00153         }
00154       }
00155 
00156 
00158       class evaluate_expression_traversal: public traversal_functor{
00159         private:
00160           index_tuple index_;
00161           int simd_element_;
00162           std::string & str_;
00163           mapping_type const & mapping_;
00164 
00165         public:
00166           evaluate_expression_traversal(index_tuple const & index, int simd_element, std::string & str, mapping_type const & mapping) : index_(index), simd_element_(simd_element), str_(str), mapping_(mapping){ }
00167 
00168           void call_before_expansion(scheduler::statement const & statement, unsigned int root_idx) const
00169           {
00170               scheduler::statement_node const & root_node = statement.array()[root_idx];
00171               if((root_node.op.type_family==scheduler::OPERATION_UNARY_TYPE_FAMILY || utils::elementwise_function(root_node.op))
00172                   && !utils::node_leaf(root_node.op))
00173                   str_+=evaluate(root_node.op.type);
00174               str_+="(";
00175 
00176           }
00177           void call_after_expansion(scheduler::statement const & /*statement*/, unsigned int /*root_idx*/) const
00178           {
00179             str_+=")";
00180           }
00181 
00182           void operator()(scheduler::statement const & statement, unsigned int root_idx, node_type leaf) const
00183           {
00184             scheduler::statement_node const & root_node = statement.array()[root_idx];
00185             mapping_type::key_type key = std::make_pair(root_idx, leaf);
00186             if(leaf==PARENT_NODE_TYPE)
00187             {
00188               if(utils::node_leaf(root_node.op))
00189                 str_ += mapping_.at(key)->evaluate(index_, simd_element_);
00190               else if(utils::elementwise_operator(root_node.op))
00191                 str_ += evaluate(root_node.op.type);
00192               else if(root_node.op.type_family!=scheduler::OPERATION_UNARY_TYPE_FAMILY && utils::elementwise_function(root_node.op))
00193                 str_ += ",";
00194             }
00195             else
00196             {
00197               if(leaf==LHS_NODE_TYPE)
00198               {
00199                 if(root_node.lhs.type_family!=scheduler::COMPOSITE_OPERATION_FAMILY)
00200                   str_ += mapping_.at(key)->evaluate(index_,simd_element_);
00201               }
00202 
00203               if(leaf==RHS_NODE_TYPE)
00204               {
00205                 if(root_node.rhs.type_family!=scheduler::COMPOSITE_OPERATION_FAMILY)
00206                   str_ += mapping_.at(key)->evaluate(index_,simd_element_);
00207               }
00208             }
00209           }
00210       };
00211 
00212       inline std::string evaluate_expression(scheduler::statement const & statement, unsigned int root_idx, index_tuple const & index,
00213                                              int simd_element, mapping_type const & mapping, node_type leaf)
00214       {
00215         std::string res;
00216         evaluate_expression_traversal traversal_functor(index, simd_element, res, mapping);
00217         scheduler::statement_node const & root_node = statement.array()[root_idx];
00218 
00219         if(leaf==RHS_NODE_TYPE)
00220         {
00221           if(root_node.rhs.type_family==scheduler::COMPOSITE_OPERATION_FAMILY)
00222             traverse(statement, root_node.rhs.node_index, traversal_functor, false);
00223           else
00224             traversal_functor(statement, root_idx, leaf);
00225         }
00226         else if(leaf==LHS_NODE_TYPE)
00227         {
00228           if(root_node.lhs.type_family==scheduler::COMPOSITE_OPERATION_FAMILY)
00229             traverse(statement, root_node.lhs.node_index, traversal_functor, false);
00230           else
00231             traversal_functor(statement, root_idx, leaf);
00232         }
00233         else
00234           traverse(statement, root_idx, traversal_functor, false);
00235 
00236         return res;
00237       }
00238 
00239 
00240 
00241 
00242     }
00243   }
00244 }
00245 #endif