[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/numpy_array_traits.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*       Copyright 2009 by Ullrich Koethe and Hans Meine                */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_NUMPY_ARRAY_TRAITS_HXX
00037 #define VIGRA_NUMPY_ARRAY_TRAITS_HXX
00038 
00039 #include "numerictraits.hxx"
00040 #include "multi_array.hxx"
00041 #include "numpy_array_taggedshape.hxx"
00042 
00043 namespace vigra {
00044 
00045 /********************************************************/
00046 /*                                                      */
00047 /*              Singleband and Multiband                */
00048 /*                                                      */
00049 /********************************************************/
00050 
/** Tag type that selects a NumpyArray without an explicit channel axis.

    Wrapping the element type in Singleband states that the number of
    channels is implicitly one, so the resulting NumpyArray carries no
    dedicated channel dimension.
*/
template <typename PixelType>
struct Singleband
{
    /// The wrapped element type.
    typedef PixelType value_type;
};
00057 
/** Tag type that selects a NumpyArray whose last axis is the channel axis.

    Wrapping the element type in Multiband states that the last dimension
    of the resulting NumpyArray is explicitly designated as channel axis.
*/
template <typename PixelType>
struct Multiband
{
    /// The wrapped element type.
    typedef PixelType value_type;
};
00063 
00064 template<class T>
00065 struct NumericTraits<Singleband<T> >
00066 : public NumericTraits<T>
00067 {};
00068 
00069 template<class T>
00070 struct NumericTraits<Multiband<T> >
00071 {
00072     typedef Multiband<T> Type;
00073 /*
00074     typedef int Promote;
00075     typedef unsigned int UnsignedPromote;
00076     typedef double RealPromote;
00077     typedef std::complex<RealPromote> ComplexPromote;
00078 */
00079     typedef Type ValueType;
00080 
00081     typedef typename NumericTraits<T>::isIntegral isIntegral;
00082     typedef VigraFalseType isScalar;
00083     typedef typename NumericTraits<T>::isSigned isSigned;
00084     typedef typename NumericTraits<T>::isSigned isOrdered;
00085     typedef typename NumericTraits<T>::isSigned isComplex;
00086 /*
00087     static signed char zero() { return 0; }
00088     static signed char one() { return 1; }
00089     static signed char nonZero() { return 1; }
00090     static signed char min() { return SCHAR_MIN; }
00091     static signed char max() { return SCHAR_MAX; }
00092 
00093 #ifdef NO_INLINE_STATIC_CONST_DEFINITION
00094     enum { minConst = SCHAR_MIN, maxConst = SCHAR_MIN };
00095 #else
00096     static const signed char minConst = SCHAR_MIN;
00097     static const signed char maxConst = SCHAR_MIN;
00098 #endif
00099 
00100     static Promote toPromote(signed char v) { return v; }
00101     static RealPromote toRealPromote(signed char v) { return v; }
00102     static signed char fromPromote(Promote v) {
00103         return ((v < SCHAR_MIN) ? SCHAR_MIN : (v > SCHAR_MAX) ? SCHAR_MAX : v);
00104     }
00105     static signed char fromRealPromote(RealPromote v) {
00106         return ((v < 0.0)
00107                    ? ((v < (RealPromote)SCHAR_MIN)
00108                        ? SCHAR_MIN
00109                        : static_cast<signed char>(v - 0.5))
00110                    : (v > (RealPromote)SCHAR_MAX)
00111                        ? SCHAR_MAX
00112                        : static_cast<signed char>(v + 0.5));
00113     }
00114 */
00115 };
00116 
00117 /********************************************************/
00118 /*                                                      */
00119 /*               NumpyArrayValuetypeTraits              */
00120 /*                                                      */
00121 /********************************************************/
00122 
/** Deliberately empty marker type.

    The primary NumpyArrayValuetypeTraits template uses it in places where a
    bool or an NPY_TYPES value is expected, so that using an unsupported
    ValueType fails to compile and this type's name (including ValueType)
    shows up in the compiler diagnostic.
*/
template <typename ValueType>
struct ERROR_NumpyArrayValuetypeTraits_not_specialized_for_
{};
00125 
/** Primary template of the element-type traits used by NumpyArray.

    Supported element types get explicit specializations generated by the
    VIGRA_NUMPY_VALUETYPE_TRAITS macro. This primary template only exists to
    reject everything else at compile time:
      - isValuetypeCompatible() tries to return an empty
        ERROR_NumpyArrayValuetypeTraits_not_specialized_for_<ValueType>
        object as bool, which does not compile once the function is
        instantiated;
      - typeCode has that error type instead of NPY_TYPES, so code that
        uses it as an NPY_TYPES constant (as NumpyArrayTraits does) fails to
        compile as well.
*/
template<class ValueType>
struct NumpyArrayValuetypeTraits
{
    // Not usable: instantiation is a compile-time error (see above).
    static bool isValuetypeCompatible(PyArrayObject const * obj)
    {
        return ERROR_NumpyArrayValuetypeTraits_not_specialized_for_<ValueType>();
    }

    // Intentionally of the error type rather than NPY_TYPES.
    static ERROR_NumpyArrayValuetypeTraits_not_specialized_for_<ValueType> typeCode;

    // Returns an error text instead of a numpy dtype name.
    static std::string typeName()
    {
        return std::string("ERROR: NumpyArrayValuetypeTraits not specialized for this case");
    }

    // Returns an error text instead of an impex type name.
    static std::string typeNameImpex()
    {
        return std::string("ERROR: NumpyArrayValuetypeTraits not specialized for this case");
    }

    // No numpy type object exists for an unsupported type: returns a null pointer.
    static PyObject * typeObject()
    {
        return (PyObject *)0;
    }
};

// Out-of-class definition of the static member typeCode declared above.
template<class ValueType>
ERROR_NumpyArrayValuetypeTraits_not_specialized_for_<ValueType> NumpyArrayValuetypeTraits<ValueType>::typeCode;
00154 
/* VIGRA_NUMPY_VALUETYPE_TRAITS(type, typeID, numpyTypeName, impexTypeName)

   Generates the explicit specialization NumpyArrayValuetypeTraits<type>.
     type          - C++ element type to be mapped
     typeID        - numpy type number, stored as the NPY_TYPES constant typeCode
     numpyTypeName - numpy dtype name; stringized (#numpyTypeName) by typeName()
     impexTypeName - string literal returned by typeNameImpex()
                     (presumably a VIGRA impex pixel-type name -- confirm)

   isValuetypeCompatible() accepts an array whose dtype number is equivalent
   to typeID according to PyArray_EquivTypenums() AND whose item size equals
   sizeof(type). The pointer must not be NULL.
   typeObject() returns PyArray_TypeObjectFromType(typeID).
*/
#define VIGRA_NUMPY_VALUETYPE_TRAITS(type, typeID, numpyTypeName, impexTypeName) \
template <> \
struct NumpyArrayValuetypeTraits<type > \
{ \
    static bool isValuetypeCompatible(PyArrayObject const * obj) /* obj must not be NULL */ \
    { \
        return PyArray_EquivTypenums(typeID, PyArray_DESCR((PyObject *)obj)->type_num) && \
               PyArray_ITEMSIZE((PyObject *)obj) == sizeof(type); \
    } \
    \
    static NPY_TYPES const typeCode = typeID; \
    \
    static std::string typeName() \
    { \
        return #numpyTypeName; \
    } \
    \
    static std::string typeNameImpex() \
    { \
        return impexTypeName; \
    } \
    \
    static PyObject * typeObject() \
    { \
        return PyArray_TypeObjectFromType(typeID); \
    } \
};
00182 
// Element-type specializations supported by NumpyArray.
// Arguments: C++ type, numpy type number, numpy dtype name, impex type name.
// Where the impex name does not match the element type (bool -> "UINT8",
// signed char -> "INT16", 64-bit integers -> "DOUBLE") or is empty (long
// double, complex), presumably the impex layer has no exact counterpart
// -- TODO confirm against the list of impex pixel types.
VIGRA_NUMPY_VALUETYPE_TRAITS(bool,           NPY_BOOL, bool, "UINT8")
VIGRA_NUMPY_VALUETYPE_TRAITS(signed char,    NPY_INT8, int8, "INT16")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned char,  NPY_UINT8, uint8, "UINT8")
VIGRA_NUMPY_VALUETYPE_TRAITS(short,          NPY_INT16, int16, "INT16")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned short, NPY_UINT16, uint16, "UINT16")

// 'long' is mapped to the numpy integer type of the same width.
#if VIGRA_BITSOF_LONG == 32
VIGRA_NUMPY_VALUETYPE_TRAITS(long,           NPY_INT32, int32, "INT32")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned long,  NPY_UINT32, uint32, "UINT32")
#elif VIGRA_BITSOF_LONG == 64
VIGRA_NUMPY_VALUETYPE_TRAITS(long,           NPY_INT64, int64, "DOUBLE")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned long,  NPY_UINT64, uint64, "DOUBLE")
#endif

// 'int' is mapped to the numpy integer type of the same width.
#if VIGRA_BITSOF_INT == 32
VIGRA_NUMPY_VALUETYPE_TRAITS(int,            NPY_INT32, int32, "INT32")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned int,   NPY_UINT32, uint32, "UINT32")
#elif VIGRA_BITSOF_INT == 64
VIGRA_NUMPY_VALUETYPE_TRAITS(int,            NPY_INT64, int64, "DOUBLE")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned int,   NPY_UINT64, uint64, "DOUBLE")
#endif

// 'long long' is only mapped when PY_LONG_LONG is defined.
#ifdef PY_LONG_LONG
# if VIGRA_BITSOF_LONG_LONG == 32
VIGRA_NUMPY_VALUETYPE_TRAITS(long long,            NPY_INT32, int32, "INT32")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned long long,   NPY_UINT32, uint32, "UINT32")
# elif VIGRA_BITSOF_LONG_LONG == 64
VIGRA_NUMPY_VALUETYPE_TRAITS(long long,          NPY_INT64, int64, "DOUBLE")
VIGRA_NUMPY_VALUETYPE_TRAITS(unsigned long long, NPY_UINT64, uint64, "DOUBLE")
# endif
#endif

// Floating-point types. npy_longdouble (and std::complex<npy_longdouble>
// below) only get their own specialization when long double differs in
// size from double -- presumably because numpy then defines npy_longdouble
// as double, which would duplicate the npy_float64 specialization
// (TODO confirm in numpy's npy_common.h).
VIGRA_NUMPY_VALUETYPE_TRAITS(npy_float32, NPY_FLOAT32, float32, "FLOAT")
VIGRA_NUMPY_VALUETYPE_TRAITS(npy_float64, NPY_FLOAT64, float64, "DOUBLE")
#if NPY_SIZEOF_LONGDOUBLE != NPY_SIZEOF_DOUBLE
VIGRA_NUMPY_VALUETYPE_TRAITS(npy_longdouble, NPY_LONGDOUBLE, longdouble, "")
#endif
// Complex types: numpy's C complex types and the matching std::complex
// instantiations are mapped to the same numpy type number.
VIGRA_NUMPY_VALUETYPE_TRAITS(npy_cfloat, NPY_CFLOAT, complex64, "")
VIGRA_NUMPY_VALUETYPE_TRAITS(std::complex<npy_float>, NPY_CFLOAT, complex64, "")
VIGRA_NUMPY_VALUETYPE_TRAITS(npy_cdouble, NPY_CDOUBLE, complex128, "")
VIGRA_NUMPY_VALUETYPE_TRAITS(std::complex<npy_double>, NPY_CDOUBLE, complex128, "")
VIGRA_NUMPY_VALUETYPE_TRAITS(npy_clongdouble, NPY_CLONGDOUBLE, clongdouble, "")
#if NPY_SIZEOF_LONGDOUBLE != NPY_SIZEOF_DOUBLE
VIGRA_NUMPY_VALUETYPE_TRAITS(std::complex<npy_longdouble>, NPY_CLONGDOUBLE, clongdouble, "")
#endif

// The generator macro is only needed for the specializations above.
#undef VIGRA_NUMPY_VALUETYPE_TRAITS
00230 
00231 /********************************************************/
00232 /*                                                      */
00233 /*                  NumpyArrayTraits                    */
00234 /*                                                      */
00235 /********************************************************/
00236 
/** Primary template for the traits that describe how a
    NumpyArray<N, T, Stride> is matched against and mapped onto a
    numpy.ndarray. It is only declared here; the partial specializations
    that follow provide the implementations.
*/
template <unsigned int N, typename T, typename Stride>
struct NumpyArrayTraits;
00239 
00240 /********************************************************/
00241 
00242 template<unsigned int N, class T>
00243 struct NumpyArrayTraits<N, T, StridedArrayTag>
00244 {
00245     typedef T dtype;
00246     typedef T value_type;
00247     typedef NumpyArrayValuetypeTraits<T> ValuetypeTraits;
00248     static NPY_TYPES const typeCode = ValuetypeTraits::typeCode;
00249 
00250     static bool isArray(PyObject * obj)
00251     {
00252         return obj && PyArray_Check(obj);
00253     }
00254 
00255     static bool isValuetypeCompatible(PyArrayObject * obj)  /* obj must not be NULL */
00256     {
00257         return ValuetypeTraits::isValuetypeCompatible(obj);
00258     }
00259 
00260     static bool isShapeCompatible(PyArrayObject * array) /* array must not be NULL */
00261     {
00262         PyObject * obj = (PyObject *)array;
00263         int ndim = PyArray_NDIM(obj);
00264 
00265         return ndim == N;
00266     }
00267 
00268     // The '*Compatible' functions are called whenever a NumpyArray is to be constructed
00269     // from a Python numpy.ndarray to check whether types and memory layout are
00270     // compatible. During overload resolution, boost::python iterates through the list
00271     // of overloads and invokes the first function where all arguments pass this check.
00272     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00273     {
00274         return isShapeCompatible(obj) && isValuetypeCompatible(obj);
00275     }
00276     
00277     // Construct a tagged shape from a 'shape - axistags' pair (called in 
00278     // NumpyArray::taggedShape()).
00279     template <class U>
00280     static TaggedShape taggedShape(TinyVector<U, N> const & shape, PyAxisTags axistags)
00281     {
00282         return TaggedShape(shape, axistags);
00283     }
00284 
00285     // Construct a tagged shape from a 'shape - order' pair by creating
00286     // the appropriate axistags object for that order and NumpyArray type.
00287     // (called in NumpyArray constructors via NumpyArray::init())
00288     template <class U>
00289     static TaggedShape taggedShape(TinyVector<U, N> const & shape,
00290                                    std::string const & /* order */ = "")
00291     {
00292         // We ignore the 'order' parameter, because we don't know the axis meaning
00293         // in a plain array (use Singleband, Multiband, TinyVector etc. instead).
00294         // Since we also have no useful axistags in this case, we enforce
00295         // the result array to be a plain numpy.ndarray by passing empty axistags.
00296         return TaggedShape(shape, PyAxisTags());
00297     }
00298 
00299     // Adjust a TaggedShape that was created by another array to the properties of
00300     // the present NumpyArray type (called in NumpyArray::reshapeIfEmpty()).
00301     static void finalizeTaggedShape(TaggedShape & tagged_shape)
00302     {
00303         vigra_precondition(tagged_shape.size() == N,
00304                   "reshapeIfEmpty(): tagged_shape has wrong size.");
00305     }
00306     
00307     // This function is used to synchronize the axis re-ordering of 'data'
00308     // with that of 'array'. For example, when we want to apply Gaussian smoothing
00309     // with a different scale for each axis, 'data' would contains those scales,
00310     // and permuteLikewise() would make sure that the scales are applied to the right
00311     // axes, regardless of axis re-ordering.
00312     template <class ARRAY>
00313     static void permuteLikewise(python_ptr array, ARRAY const & data, ARRAY & res)
00314     {
00315         vigra_precondition((int)data.size() == N,
00316             "NumpyArray::permuteLikewise(): size mismatch.");
00317         
00318         ArrayVector<npy_intp> permute;
00319         detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
00320                                        AxisInfo::AllAxes, true);
00321 
00322         if(permute.size() != 0)
00323         {
00324             applyPermutation(permute.begin(), permute.end(), data.begin(), res.begin());
00325         }
00326     }
00327     
00328     // This function is called in NumpyArray::setupArrayView() to determine the
00329     // desired axis re-ordering.
00330     template <class U>
00331     static void permutationToSetupOrder(python_ptr array, ArrayVector<U> & permute)
00332     {
00333         detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
00334                                        AxisInfo::AllAxes, true);
00335 
00336         if(permute.size() == 0)
00337         {
00338             permute.resize(N);
00339             linearSequence(permute.begin(), permute.end());
00340         }
00341     }
00342 
00343     // This function is called in NumpyArray::makeUnsafeReference() to create
00344     // a numpy.ndarray view for a block of memory managed by C++.
00345     // The term 'unsafe' should remind you that memory management cannot be done
00346     // automatically, bu must be done explicitly by the programmer.
00347     template <class U>
00348     static python_ptr unsafeConstructorFromData(TinyVector<U, N> const & shape,
00349                                                 T *data, TinyVector<U, N> const & stride)
00350     {
00351         TinyVector<npy_intp, N> npyStride(stride * sizeof(T));
00352         return constructNumpyArrayFromData(shape, npyStride.begin(), 
00353                                                     ValuetypeTraits::typeCode, data);
00354     }
00355 };
00356 
00357 /********************************************************/
00358 
00359 template<unsigned int N, class T>
00360 struct NumpyArrayTraits<N, T, UnstridedArrayTag>
00361 : public NumpyArrayTraits<N, T, StridedArrayTag>
00362 {
00363     typedef NumpyArrayTraits<N, T, StridedArrayTag> BaseType;
00364     typedef typename BaseType::ValuetypeTraits ValuetypeTraits;
00365 
00366     static bool isShapeCompatible(PyArrayObject * array) /* obj must not be NULL */
00367     {
00368         PyObject * obj = (PyObject *)array;
00369         int ndim = PyArray_NDIM(obj);
00370         long channelIndex = pythonGetAttr(obj, "channelIndex", ndim);
00371         long majorIndex = pythonGetAttr(obj, "innerNonchannelIndex", ndim);
00372         npy_intp * strides = PyArray_STRIDES(obj);
00373         
00374         if(channelIndex < ndim)
00375         {
00376             // When we have a channel axis, it will become the innermost dimension
00377             return (ndim == N && strides[channelIndex] == sizeof(T));
00378         }
00379         else if(majorIndex < ndim)
00380         {
00381             // When we have axistags, but no channel axis, the major spatial
00382             // axis will be the innermost dimension
00383             return (ndim == N && strides[majorIndex] == sizeof(T));
00384         }
00385         else 
00386         {
00387             // When we have no axistags, the first axis will be the innermost dimension
00388             return (ndim == N && strides[0] == sizeof(T));
00389         }
00390     }
00391 
00392     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00393     {
00394         return isShapeCompatible(obj) && BaseType::isValuetypeCompatible(obj);
00395     }
00396 };
00397 
00398 /********************************************************/
00399 
00400 template<unsigned int N, class T>
00401 struct NumpyArrayTraits<N, Singleband<T>, StridedArrayTag>
00402 : public NumpyArrayTraits<N, T, StridedArrayTag>
00403 {
00404     typedef NumpyArrayTraits<N, T, StridedArrayTag> BaseType;
00405     typedef typename BaseType::ValuetypeTraits ValuetypeTraits;
00406 
00407     static bool isShapeCompatible(PyArrayObject * array) /* array must not be NULL */
00408     {
00409         PyObject * obj = (PyObject *)array;
00410         int ndim = PyArray_NDIM(obj);
00411         long channelIndex = pythonGetAttr(obj, "channelIndex", ndim);
00412         
00413         // If we have no channel axis (because either we don't have axistags, 
00414         // or the tags do not contain a channel axis), ndim must match.
00415         if(channelIndex == ndim)
00416             return ndim == N;
00417             
00418         // Otherwise, the channel axis must be a singleton axis that we can drop.
00419         return ndim == N+1 && PyArray_DIM(obj, channelIndex) == 1;
00420     }
00421 
00422     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00423     {
00424         return isShapeCompatible(obj) && BaseType::isValuetypeCompatible(obj);
00425     }
00426 
00427     template <class U>
00428     static TaggedShape taggedShape(TinyVector<U, N> const & shape, PyAxisTags axistags)
00429     {
00430         return TaggedShape(shape, axistags).setChannelCount(1);
00431     }
00432 
00433     template <class U>
00434     static TaggedShape taggedShape(TinyVector<U, N> const & shape, std::string const & order = "")
00435     {
00436         return TaggedShape(shape, 
00437                   PyAxisTags(detail::defaultAxistags(shape.size()+1, order))).setChannelCount(1);
00438     }
00439 
00440     static void finalizeTaggedShape(TaggedShape & tagged_shape)
00441     {
00442         if(tagged_shape.axistags.hasChannelAxis())
00443         {
00444             tagged_shape.setChannelCount(1);
00445             vigra_precondition(tagged_shape.size() == N+1,
00446                      "reshapeIfEmpty(): tagged_shape has wrong size.");
00447         }
00448         else
00449         {
00450             tagged_shape.setChannelCount(0);
00451             vigra_precondition(tagged_shape.size() == N,
00452                      "reshapeIfEmpty(): tagged_shape has wrong size.");
00453         }
00454     }
00455     
00456     template <class ARRAY>
00457     static void permuteLikewise(python_ptr array, ARRAY const & data, ARRAY & res)
00458     {
00459         vigra_precondition((int)data.size() == N,
00460             "NumpyArray::permuteLikewise(): size mismatch.");
00461         
00462         ArrayVector<npy_intp> permute;
00463         detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
00464                                        AxisInfo::NonChannel, true);
00465 
00466         if(permute.size() == 0)
00467         {
00468             permute.resize(N);
00469             linearSequence(permute.begin(), permute.end());
00470         }
00471         
00472         applyPermutation(permute.begin(), permute.end(), data.begin(), res.begin());
00473     }
00474     
00475     template <class U>
00476     static void permutationToSetupOrder(python_ptr array, ArrayVector<U> & permute)
00477     {
00478         detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
00479                                        AxisInfo::AllAxes, true);
00480         if(permute.size() == 0)
00481         {
00482             permute.resize(N);
00483             linearSequence(permute.begin(), permute.end());
00484         }
00485         else if(permute.size() == N+1)
00486         {
00487             permute.erase(permute.begin());
00488         }
00489     }
00490 };
00491 
00492 /********************************************************/
00493 
00494 template<unsigned int N, class T>
00495 struct NumpyArrayTraits<N, Singleband<T>, UnstridedArrayTag>
00496 : public NumpyArrayTraits<N, Singleband<T>, StridedArrayTag>
00497 {
00498     typedef NumpyArrayTraits<N, T, UnstridedArrayTag> UnstridedTraits;
00499     typedef NumpyArrayTraits<N, Singleband<T>, StridedArrayTag> BaseType;
00500     typedef typename BaseType::ValuetypeTraits ValuetypeTraits;
00501 
00502     static bool isShapeCompatible(PyArrayObject * array) /* obj must not be NULL */
00503     {
00504         PyObject * obj = (PyObject *)array;
00505         int ndim = PyArray_NDIM(obj);
00506         long channelIndex = pythonGetAttr(obj, "channelIndex", ndim);
00507         long majorIndex = pythonGetAttr(obj, "innerNonchannelIndex", ndim);
00508         npy_intp * strides = PyArray_STRIDES(obj);
00509         
00510         // If we have no axistags, ndim must match, and axis 0 must be unstrided.
00511         if(majorIndex == ndim) 
00512             return N == ndim && strides[0] == sizeof(T);
00513             
00514         // If we have axistags, but no channel axis, ndim must match, 
00515         // and the major non-channel axis must be unstrided.
00516         if(channelIndex == ndim) 
00517             return N == ndim && strides[majorIndex] == sizeof(T);
00518             
00519         // Otherwise, the channel axis must be a singleton axis that we can drop,
00520         // and the major non-channel axis must be unstrided.
00521         return ndim == N+1 && PyArray_DIM(obj, channelIndex) == 1 && 
00522                 strides[majorIndex] == sizeof(T);
00523     }
00524 
00525     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00526     {
00527         return isShapeCompatible(obj) && BaseType::isValuetypeCompatible(obj);
00528     }
00529 };
00530 
00531 /********************************************************/
00532 
/** Traits for NumpyArray<N, Multiband<T>, StridedArrayTag>.

    In the C++ view, the last of the N dimensions is the channel axis.
    Accepted numpy arrays (see isShapeCompatible()):
      - with a channel axis:                 ndim == N
      - with axistags but no channel axis:   ndim == N-1 (a singleton
                                             channel axis is added)
      - without axistags:                    ndim == N or ndim == N-1
*/
template<unsigned int N, class T>
struct NumpyArrayTraits<N, Multiband<T>, StridedArrayTag>
: public NumpyArrayTraits<N, T, StridedArrayTag>
{
    typedef NumpyArrayTraits<N, T, StridedArrayTag> BaseType;
    typedef typename BaseType::ValuetypeTraits ValuetypeTraits;

    // Check the dimension count according to the rules above. The third
    // argument of pythonGetAttr() (ndim) is the fallback value, so an index
    // equal to ndim means "not present".
    static bool isShapeCompatible(PyArrayObject * array) /* array must not be NULL */
    {
        PyObject * obj = (PyObject*)array;
        int ndim = PyArray_NDIM(obj);
        long channelIndex = pythonGetAttr(obj, "channelIndex", ndim);
        long majorIndex = pythonGetAttr(obj, "innerNonchannelIndex", ndim);
        
        if(channelIndex < ndim)
        {
            // When we have a channel axis, ndim must match.
            return ndim == N;
        }
        else if(majorIndex < ndim)
        {
            // When we have axistags, but no channel axis, we must add a singleton axis.
            return ndim == N-1;
        }
        else
        {
            // When we have no axistags, we may add a singleton dimension.
            return ndim == N || ndim == N-1;
        }
    }

    static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
    {
        return isShapeCompatible(obj) && ValuetypeTraits::isValuetypeCompatible(obj);
    }

    // Shape with explicit axistags; the channel axis is marked as the last one.
    template <class U>
    static TaggedShape taggedShape(TinyVector<U, N> const & shape, PyAxisTags axistags)
    {
        return TaggedShape(shape, axistags).setChannelIndexLast();
    }

    // Shape with default axistags for 'order' (one tag per dimension,
    // the channel axis included); the channel axis is marked as the last one.
    template <class U>
    static TaggedShape taggedShape(TinyVector<U, N> const & shape, std::string const & order = "")
    {
        return TaggedShape(shape, 
                    PyAxisTags(detail::defaultAxistags(shape.size(), order))).setChannelIndexLast();
    }

    static void finalizeTaggedShape(TaggedShape & tagged_shape)
    {
        // When there is only one channel, and the axistags don't enforce an
        // explicit channel axis, we return an array without explicit channel axis.
        if(tagged_shape.channelCount() == 1 && !tagged_shape.axistags.hasChannelAxis())
        {
            tagged_shape.setChannelCount(0);
            vigra_precondition(tagged_shape.size() == N-1,
                  "reshapeIfEmpty(): tagged_shape has wrong size.");
        }
        else
        {
            vigra_precondition(tagged_shape.size() == N,
                  "reshapeIfEmpty(): tagged_shape has wrong size.");
        }
    }

    // Permute 'data' into 'res' like the axes of 'array'.
    //  - data.size() == N:   'data' includes an entry for the channel axis;
    //                        'array' must then have N dimensions, and the
    //                        channel entry (first in the obtained permutation)
    //                        is rotated to the last position.
    //  - data.size() == N-1: only the non-channel axes take part.
    // In both cases the identity is used when no permutation is obtained.
    template <class ARRAY>
    static void permuteLikewise(python_ptr array, ARRAY const & data, ARRAY & res)
    {
        ArrayVector<npy_intp> permute;
        
        if((int)data.size() == N)
        {
            vigra_precondition(PyArray_NDIM((PyArrayObject*)array.get()) == N,
                "NumpyArray::permuteLikewise(): input array has no channel axis.");

            detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
                                           AxisInfo::AllAxes, true);

            if(permute.size() == 0)
            {
                permute.resize(N);
                linearSequence(permute.begin(), permute.end());
            }
            else
            {
                // rotate channel axis to last position
                int channelIndex = permute[0];
                for(int k=1; k<N; ++k)
                    permute[k-1] = permute[k];
                permute[N-1] = channelIndex;
            }
        }
        else
        {
            vigra_precondition((int)data.size() == N-1,
                "NumpyArray::permuteLikewise(): size mismatch.");

            detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
                                           AxisInfo::NonChannel, true);

            if(permute.size() == 0)
            {
                permute.resize(N-1);
                linearSequence(permute.begin(), permute.end());
            }
        }
        
        applyPermutation(permute.begin(), permute.end(), data.begin(), res.begin());
    }
    
    // Axis order used by NumpyArray::setupArrayView(). Without a permutation
    // from 'array', the identity over the array's actual ndim (N or N-1) is
    // used; with N entries, the channel entry is rotated to the end.
    template <class U>
    static void permutationToSetupOrder(python_ptr array, ArrayVector<U> & permute)
    {
        detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
                                       AxisInfo::AllAxes, true);

        if(permute.size() == 0)
        {
            permute.resize(PyArray_NDIM((PyArrayObject*)array.get()));
            linearSequence(permute.begin(), permute.end());
        }
        else if(permute.size() == N)
        {
            // if we have a channel axis, rotate it to last position
            int channelIndex = permute[0];
            for(int k=1; k<N; ++k)
                permute[k-1] = permute[k];
            permute[N-1] = channelIndex;
        }
    }
};
00665 
00666 /********************************************************/
00667 
00668 template<unsigned int N, class T>
00669 struct NumpyArrayTraits<N, Multiband<T>, UnstridedArrayTag>
00670 : public NumpyArrayTraits<N, Multiband<T>, StridedArrayTag>
00671 {
00672     typedef NumpyArrayTraits<N, Multiband<T>, StridedArrayTag> BaseType;
00673     typedef typename BaseType::ValuetypeTraits ValuetypeTraits;
00674 
00675     static bool isShapeCompatible(PyArrayObject * array) /* obj must not be NULL */
00676     {
00677         PyObject * obj = (PyObject *)array;
00678         int ndim = PyArray_NDIM(obj);
00679         long channelIndex = pythonGetAttr(obj, "channelIndex", ndim);
00680         long majorIndex = pythonGetAttr(obj, "innerNonchannelIndex", ndim);
00681         npy_intp * strides = PyArray_STRIDES(obj);
00682 
00683         if(channelIndex < ndim)
00684         {
00685             // When we have a channel axis, ndim must match, and the major non-channel
00686             // axis must be unstrided.
00687             return ndim == N && strides[majorIndex] == sizeof(T);
00688         }
00689         else if(majorIndex < ndim)
00690         {
00691             // When we have axistags, but no channel axis, we will add a
00692             // singleton channel axis, and the major non-channel axis must be unstrided.
00693             return ndim == N-1 && strides[majorIndex] == sizeof(T);
00694         }
00695         else
00696         {
00697             // When we have no axistags, axis 0 must be unstrided, but we
00698             // may add a singleton dimension at the end.
00699             return (ndim == N || ndim == N-1) && strides[0] == sizeof(T);
00700         }
00701     }
00702 
00703     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00704     {
00705         return isShapeCompatible(obj) && BaseType::isValuetypeCompatible(obj);
00706     }
00707 };
00708 
00709 /********************************************************/
00710 
00711 template<unsigned int N, int M, class T>
00712 struct NumpyArrayTraits<N, TinyVector<T, M>, StridedArrayTag>
00713 {
00714     typedef T dtype;
00715     typedef TinyVector<T, M> value_type;
00716     typedef NumpyArrayValuetypeTraits<T> ValuetypeTraits;
00717     static NPY_TYPES const typeCode = ValuetypeTraits::typeCode;
00718 
00719     static bool isArray(PyObject * obj)
00720     {
00721         return obj && PyArray_Check(obj);
00722     }
00723 
00724     static bool isValuetypeCompatible(PyArrayObject * obj)  /* obj must not be NULL */
00725     {
00726         return ValuetypeTraits::isValuetypeCompatible(obj);
00727     }
00728 
00729     static bool isShapeCompatible(PyArrayObject * array) /* array must not be NULL */
00730     {
00731         PyObject * obj = (PyObject *)array;
00732         
00733          // We need an extra channel axis.
00734          if(PyArray_NDIM(obj) != N+1)
00735             return false;
00736             
00737         // When there are no axistags, we assume that the last axis represents the channels.
00738         long channelIndex = pythonGetAttr(obj, "channelIndex", N);
00739         npy_intp * strides = PyArray_STRIDES(obj);
00740         
00741         return PyArray_DIM(obj, channelIndex) == M && strides[channelIndex] == sizeof(T);
00742     }
00743 
00744     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00745     {
00746         return isShapeCompatible(obj) && ValuetypeTraits::isValuetypeCompatible(obj);
00747     }
00748 
00749     template <class U>
00750     static TaggedShape taggedShape(TinyVector<U, N> const & shape, PyAxisTags axistags)
00751     {
00752         return TaggedShape(shape, axistags).setChannelCount(M);
00753     }
00754 
00755     template <class U>
00756     static TaggedShape taggedShape(TinyVector<U, N> const & shape, std::string const & order = "")
00757     {
00758         return TaggedShape(shape, 
00759                      PyAxisTags(detail::defaultAxistags(shape.size()+1, order))).setChannelCount(M);
00760     }
00761 
00762     static void finalizeTaggedShape(TaggedShape & tagged_shape)
00763     {
00764         tagged_shape.setChannelCount(M);
00765         vigra_precondition(tagged_shape.size() == N+1,
00766               "reshapeIfEmpty(): tagged_shape has wrong size.");
00767     }
00768 
00769     template <class ARRAY>
00770     static void permuteLikewise(python_ptr array, ARRAY const & data, ARRAY & res)
00771     {
00772         vigra_precondition((int)data.size() == N,
00773             "NumpyArray::permuteLikewise(): size mismatch.");
00774         
00775         ArrayVector<npy_intp> permute;
00776         detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
00777                                        AxisInfo::NonChannel, true);
00778 
00779         if(permute.size() == 0)
00780         {
00781             permute.resize(N);
00782             linearSequence(permute.begin(), permute.end());
00783         }
00784         
00785         applyPermutation(permute.begin(), permute.end(), data.begin(), res.begin());
00786     }
00787     
00788     template <class U>
00789     static void permutationToSetupOrder(python_ptr array, ArrayVector<U> & permute)
00790     {
00791         detail::getAxisPermutationImpl(permute, array, "permutationToNormalOrder", 
00792                                        AxisInfo::AllAxes, true);
00793         if(permute.size() == 0)
00794         {
00795             permute.resize(N);
00796             linearSequence(permute.begin(), permute.end());
00797         }
00798         else if(permute.size() == N+1)
00799         {
00800             permute.erase(permute.begin());
00801         }
00802     }
00803     
00804     template <class U>
00805     static python_ptr unsafeConstructorFromData(TinyVector<U, N> const & shape,
00806                                                 value_type *data, TinyVector<U, N> const & stride)
00807     {
00808         TinyVector<npy_intp, N+1> npyShape;
00809         std::copy(shape.begin(), shape.end(), npyShape.begin());
00810         npyShape[N] = M;
00811 
00812         TinyVector<npy_intp, N+1> npyStride;
00813         std::transform(
00814             stride.begin(), stride.end(), npyStride.begin(),
00815             std::bind2nd(std::multiplies<npy_intp>(), sizeof(value_type)));
00816         npyStride[N] = sizeof(T);
00817 
00818         return constructNumpyArrayFromData(npyShape, npyStride.begin(), 
00819                                                     ValuetypeTraits::typeCode, data);
00820     }
00821 };
00822 
00823 /********************************************************/
00824 
00825 template<unsigned int N, int M, class T>
00826 struct NumpyArrayTraits<N, TinyVector<T, M>, UnstridedArrayTag>
00827 : public NumpyArrayTraits<N, TinyVector<T, M>, StridedArrayTag>
00828 {
00829     typedef NumpyArrayTraits<N, TinyVector<T, M>, StridedArrayTag> BaseType;
00830     typedef typename BaseType::value_type value_type;
00831     typedef typename BaseType::ValuetypeTraits ValuetypeTraits;
00832 
00833     static bool isShapeCompatible(PyArrayObject * array) /* obj must not be NULL */
00834     {
00835         PyObject * obj = (PyObject *)array;
00836         int ndim = PyArray_NDIM(obj);
00837         
00838          // We need an extra channel axis. 
00839         if(ndim != N+1)
00840             return false;
00841             
00842         long channelIndex = pythonGetAttr(obj, "channelIndex", ndim);
00843         long majorIndex = pythonGetAttr(obj, "innerNonchannelIndex", ndim);
00844         npy_intp * strides = PyArray_STRIDES(obj);
00845         
00846         if(majorIndex < ndim)
00847         {
00848             // We have axistags, but no channel axis => cannot be a TinyVector image
00849             if(channelIndex == ndim)
00850                 return false;
00851                 
00852             // We have an explicit channel axis => shapes and strides must match
00853             return PyArray_DIM(obj, channelIndex) == M && 
00854                    strides[channelIndex] == sizeof(T) &&
00855                    strides[majorIndex] == sizeof(TinyVector<T, M>);
00856             
00857             
00858         }
00859         else
00860         {
00861             // we have no axistags => we assume that the channel axis is last
00862             return PyArray_DIM(obj, N) == M && 
00863                    strides[N] == sizeof(T) &&
00864                    strides[0] == sizeof(TinyVector<T, M>);
00865         }
00866     }
00867 
00868     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00869     {
00870         return isShapeCompatible(obj) && BaseType::isValuetypeCompatible(obj);
00871     }
00872 };
00873 
00874 /********************************************************/
00875 
00876 template<unsigned int N, class T>
00877 struct NumpyArrayTraits<N, RGBValue<T>, StridedArrayTag>
00878 : public NumpyArrayTraits<N, TinyVector<T, 3>, StridedArrayTag>
00879 {
00880     typedef T dtype;
00881     typedef RGBValue<T> value_type;
00882     typedef NumpyArrayValuetypeTraits<T> ValuetypeTraits;
00883 };
00884 
00885 /********************************************************/
00886 
00887 template<unsigned int N, class T>
00888 struct NumpyArrayTraits<N, RGBValue<T>, UnstridedArrayTag>
00889 : public NumpyArrayTraits<N, RGBValue<T>, StridedArrayTag>
00890 {
00891     typedef NumpyArrayTraits<N, TinyVector<T, 3>, UnstridedArrayTag> UnstridedTraits;
00892     typedef NumpyArrayTraits<N, RGBValue<T>, StridedArrayTag> BaseType;
00893     typedef typename BaseType::value_type value_type;
00894     typedef typename BaseType::ValuetypeTraits ValuetypeTraits;
00895 
00896     static bool isShapeCompatible(PyArrayObject * obj) /* obj must not be NULL */
00897     {
00898         return UnstridedTraits::isShapeCompatible(obj);
00899     }
00900 
00901     static bool isPropertyCompatible(PyArrayObject * obj) /* obj must not be NULL */
00902     {
00903         return UnstridedTraits::isPropertyCompatible(obj);
00904     }
00905 };
00906 
00907 } // namespace vigra
00908 
00909 #endif // VIGRA_NUMPY_ARRAY_TRAITS_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Tue Nov 6 2012)