
I am using Cython to wrap a set of C++ classes so that they have a Python interface. Example code is provided below:

BaseClass.h:

#ifndef __BaseClass__
#define __BaseClass__
#include <stdio.h>
#include <stdlib.h>
#include <string>
using namespace std;
class BaseClass
{
    public:
        BaseClass(){};
        virtual ~BaseClass(){};
        virtual void SetName(string name){printf("in base set name\n");}
        virtual float Evaluate(float time){printf("in base Evaluate\n");return 0;}
        virtual bool DataExists(){printf("in base data exists\n");return false;}
};
#endif /* defined(__BaseClass__) */

DerivedClass.h:

#ifndef __DerivedClass__
#define __DerivedClass__

#include "BaseClass.h"

class DerivedClass:public BaseClass
{
    public:
        DerivedClass(){};
        virtual ~DerivedClass(){};
        virtual float Evaluate(float time){printf("in derived Evaluate\n");return 1;}
        virtual bool DataExists(){printf("in derived data exists\n");return true;}
        virtual void MyFunction(){printf("in my function\n");}
        virtual void SetObject(BaseClass *input){printf("in set object\n");}
};
#endif /* defined(__DerivedClass__) */

NextDerivedClass.h:

#ifndef __NextDerivedClass__
#define __NextDerivedClass__

#include "DerivedClass.h"

class NextDerivedClass:public DerivedClass
{
    public:
        NextDerivedClass(){};
        virtual ~NextDerivedClass(){};
        virtual void SetObject(BaseClass *input){printf("in set object of next derived class\n");}
};
#endif /* defined(__NextDerivedClass__) */

inheritTest.pyx:

cdef extern from "BaseClass.h":
    cdef cppclass BaseClass:
        BaseClass() except +
        void SetName(string)
        float Evaluate(float)
        bool DataExists()

cdef extern from "DerivedClass.h":
    cdef cppclass DerivedClass(BaseClass):
        DerivedClass() except +
        void MyFunction()
        float Evaluate(float)
        bool DataExists()
        void SetObject(BaseClass *)

cdef extern from "NextDerivedClass.h":
    cdef cppclass NextDerivedClass(DerivedClass):
        NextDerivedClass() except +
        # ***  The issue is right here ***
        void SetObject(BaseClass *)

cdef class PyBaseClass:
    cdef BaseClass *thisptr
    def __cinit__(self):
        if type(self) is PyBaseClass:
            self.thisptr = new BaseClass()
    def __dealloc__(self):
        if type(self) is PyBaseClass:
            del self.thisptr

cdef class PyDerivedClass(PyBaseClass):
    cdef DerivedClass *derivedptr
    def __cinit__(self):
        self.derivedptr = self.thisptr = new DerivedClass()
    def __dealloc__(self):
        del self.derivedptr
    # def Evaluate(self, time):
    #     return self.derivedptr.Evaluate(time)
    def SetObject(self, PyBaseClass inputObject):
         self.derivedptr.SetObject(<BaseClass *>inputObject.thisptr)

cdef class PyNextDerivedClass(PyDerivedClass):
    cdef NextDerivedClass *nextDerivedptr
    def __cinit__(self):
        self.nextDerivedptr = self.thisptr = new NextDerivedClass()
    def __dealloc__(self):
        del self.nextDerivedptr
    def SetObject(self, PyBaseClass input):
        self.nextDerivedptr.SetObject(<BaseClass *>input.thisptr)

I want to be able to call SetObject from Python as shown below:

main.py:

from inheritTest import PyBaseClass as base
from inheritTest import PyDerivedClass as der
from inheritTest import PyNextDerivedClass as nextDer

#This works now!
a = der()
b = der()
a.SetObject(b)

#This doesn't work -- keeping the function declaration causes an ambiguous overloaded method error; not keeping it means the call below works, but it calls the inherited implementation (from DerivedClass)
c = nextDer()
c.SetObject(b)

I thought it would work since the classes inherit from each other, but it's giving me the following error:

Argument has incorrect type: expected PyBaseClass, got PyDerivedClass

Not specifying the type in the function definition makes Cython treat inputObject as a plain Python object with no C-level attributes (even though it has them), in which case the error is:

Cannot convert Python object to BaseClass *

A somewhat hacky workaround is to have Python functions with different names that expect different argument types (e.g. SetObjectWithBase, SetObjectWithDerived) and, within their implementations, call the same C++ function after casting the input. I know for a fact this works, but I would like to avoid it as much as possible. If there is a way to catch the TypeError within the function and handle it there, that might work, but I wasn't sure exactly how to implement that.

Hope this question makes sense, let me know if you require additional information.

EDIT: The code has been edited so that basic inheritance works. After playing around with it a bit more, I realized that the problem occurs with multiple levels of inheritance (see the edited code above). Keeping the SetObject declaration for NextDerivedClass causes an "Ambiguous overloaded method" error; not keeping it lets me call the function on the object, but it calls the inherited implementation (from DerivedClass).

jeet.m

3 Answers


After a lot of help from the answers below, and some experimentation, I think I understand how basic inheritance works in Cython. I'm answering my own question to validate/improve my understanding, and hopefully to help anyone who encounters a related issue in the future. If there is anything wrong with this explanation, feel free to correct me in the comments and I will edit it. I don't think this is the only way to do it, and I'm sure alternative methods work, but this is the way that worked for me.

Overview/Things Learnt:

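So basically, from my understanding, as long as the wrapped methods are virtual in C++ and Cython is given the appropriate declarations, calling a method through a base-class pointer ends up in the implementation that matches the actual type of the object you call it on. Cython itself does not walk the hierarchy; it emits an ordinary C++ call such as self.thisptr->Evaluate(time), and C++ virtual dispatch picks the right override.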
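A minimal C++ sketch of that mechanism for the question's SetObject case. It assumes trimmed-down classes whose methods return a label instead of calling printf (so the override that ran can be checked), and call_set_object is a hypothetical stand-in for PyDerivedClass.SetObject, which only ever sees a DerivedClass pointer:

```cpp
#include <cassert>
#include <string>

// Trimmed-down versions of the question's classes: methods return a label
// instead of calling printf, so the override that ran can be checked.
class BaseClass {
public:
    virtual ~BaseClass() {}
};

class DerivedClass : public BaseClass {
public:
    virtual std::string SetObject(BaseClass* input) {
        assert(input != nullptr);
        return "DERIVED";
    }
};

class NextDerivedClass : public DerivedClass {
public:
    std::string SetObject(BaseClass* input) override {
        assert(input != nullptr);
        return "NEXT DERIVED";
    }
};

// Hypothetical stand-in for PyDerivedClass.SetObject: like derivedptr, it
// only knows the static type DerivedClass*.
std::string call_set_object(DerivedClass* self, BaseClass* input) {
    // Virtual call: the override is chosen from the object's dynamic type.
    return self->SetObject(input);
}
```

Because the call goes through the vtable, the NextDerivedClass override runs even though the caller only knows the DerivedClass type, which is why the NextDerivedClass cppclass declaration does not need its own SetObject entry.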

The important thing is to mirror the C++ inheritance structure you are wrapping in your .pyx file. This means ensuring that:

1) The imported C++ classes (the cppclasses declared under cdef extern from) inherit from each other the same way the actual C++ classes do

2) Only unique methods/member variables are declared for each imported class (you should not declare a virtual function in both BaseClass and DerivedClass just because it is implemented differently in the two classes). As long as one inherits from the other, the declaration only needs to be in the base imported class.

3) The Python wrapper classes (i.e. PyBaseClass / PyDerivedClass) also inherit from each other the same way the actual C++ classes do

4) Similarly, the wrapper method for a virtual function only needs to exist in the PyBase wrapper class (don't put it in both classes; the correct implementation will be called when you actually run the code).

5) For each Python wrapper class that is subclassed, you need an if type(self) is class-name: check in both __cinit__() and __dealloc__(). Cython always runs the base class's __cinit__() before the subclass's, and runs every __dealloc__() in the chain, so without the check a base wrapper would allocate or delete an object that the subclass manages, causing leaks, double deletes, segfaults etc. You don't need this check for "leaf nodes" in the hierarchy (classes that won't be subclassed)

6) Make sure that in __dealloc__() you only delete the current pointer (and not any inherited ones)

7) Again, in __cinit__() for inherited classes, make sure you set the current pointer as well as all inherited pointers to one object of the type you are creating (i.e. self.nextDerivedptr = self.derivedptr = self.thisptr = new NextDerivedClass())
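For points 5 to 7, a rough C++ picture of what PyNextDerivedClass holds, assuming a hypothetical live_parts counter is added to the constructors and destructors: three typed pointers alias one heap object, so exactly one delete is allowed, and because ~BaseClass() is virtual, that single delete (through any of the three pointers) runs the whole destructor chain. The Wrapper struct is an illustrative stand-in for the Cython extension type, not Cython's generated code:

```cpp
#include <cassert>

// Hypothetical instrumentation: number of class layers (Base, Derived,
// NextDerived) currently constructed.
static int live_parts = 0;

class BaseClass {
public:
    BaseClass() { ++live_parts; }
    virtual ~BaseClass() { --live_parts; }  // virtual, so delete via BaseClass* is safe
};

class DerivedClass : public BaseClass {
public:
    DerivedClass() { ++live_parts; }
    ~DerivedClass() override { --live_parts; }
};

class NextDerivedClass : public DerivedClass {
public:
    NextDerivedClass() { ++live_parts; }
    ~NextDerivedClass() override { --live_parts; }
};

// Illustrative stand-in for PyNextDerivedClass: three typed views of ONE object.
struct Wrapper {
    BaseClass* thisptr = nullptr;
    DerivedClass* derivedptr = nullptr;
    NextDerivedClass* nextDerivedptr = nullptr;

    Wrapper() {
        // Same idea as: nextDerivedptr = derivedptr = thisptr = new NextDerivedClass()
        nextDerivedptr = new NextDerivedClass();
        derivedptr = nextDerivedptr;
        thisptr = nextDerivedptr;
        assert(thisptr == derivedptr && derivedptr == nextDerivedptr);
    }
    Wrapper(const Wrapper&) = delete;             // a copy would mean a second delete
    Wrapper& operator=(const Wrapper&) = delete;

    ~Wrapper() {
        // Delete exactly once. Any of the three pointers would work because
        // ~BaseClass() is virtual; deleting a second one would be a double delete.
        delete thisptr;
    }
};
```

Deleting derivedptr or nextDerivedptr in addition would be undefined behaviour, which is what the type(self) checks in the Cython __dealloc__() methods guard against.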

Hopefully the above points make a lot more sense when you see the code below; it compiles and works as I intend.

BaseClass.h:

#ifndef __BaseClass__
#define __BaseClass__

#include <stdio.h>
#include <stdlib.h>
#include <string>

using namespace std;

class BaseClass
{
    public:
        BaseClass(){};
        virtual ~BaseClass(){};
        virtual void SetName(string name){printf("BASE: in set name\n");}
        virtual float Evaluate(float time){printf("BASE: in Evaluate\n");return 0;}
        virtual bool DataExists(){printf("BASE: in data exists\n");return false;}
};
#endif /* defined(__BaseClass__) */ 

DerivedClass.h:

#ifndef __DerivedClass__
#define __DerivedClass__

#include "BaseClass.h"
#include <string>

using namespace std;

class DerivedClass:public BaseClass
{
    public:
        DerivedClass(){};
        virtual ~DerivedClass(){};
        virtual void SetName(string name){printf("DERIVED CLASS: in Set name \n");}
        virtual float Evaluate(float time){printf("DERIVED CLASS: in Evaluate\n");return 1.0;}
        virtual bool DataExists(){printf("DERIVED CLASS:in data exists\n");return true;}
        virtual void MyFunction(){printf("DERIVED CLASS: in my function\n");}
        virtual void SetObject(BaseClass *input){printf("DERIVED CLASS: in set object\n");}
};
#endif /* defined(__DerivedClass__) */

NextDerivedClass.h:

#ifndef __NextDerivedClass__
#define __NextDerivedClass__

#include "DerivedClass.h"

class NextDerivedClass:public DerivedClass
{
    public:
        NextDerivedClass(){};
        virtual ~NextDerivedClass(){};
        virtual void SetObject(BaseClass *input){printf("NEXT DERIVED CLASS: in set object\n");}
        virtual bool DataExists(){printf("NEXT DERIVED CLASS: in data exists \n");return true;}
};
#endif /* defined(__NextDerivedClass__) */

inheritTest.pyx:

#Necessary Compilation Options
#distutils: language = c++
#distutils: extra_compile_args = ["-std=c++11", "-g"]

#Import necessary modules
from libcpp cimport bool
from libcpp.string cimport string
from libcpp.map cimport map
from libcpp.pair cimport pair
from libcpp.vector cimport vector

cdef extern from "BaseClass.h":
    cdef cppclass BaseClass:
        BaseClass() except +
        void SetName(string)
        float Evaluate(float)
        bool DataExists()

cdef extern from "DerivedClass.h":
    cdef cppclass DerivedClass(BaseClass):
        DerivedClass() except +
        void MyFunction()
        void SetObject(BaseClass *)

cdef extern from "NextDerivedClass.h":
    cdef cppclass NextDerivedClass(DerivedClass):
        NextDerivedClass() except +

cdef class PyBaseClass:
    cdef BaseClass *thisptr
    def __cinit__(self):
        if type(self) is PyBaseClass:
            self.thisptr = new BaseClass()
    def __dealloc__(self):
        if type(self) is PyBaseClass:
            del self.thisptr
    def SetName(self, name):
        # std::string expects bytes on Python 3, so encode str arguments first
        self.thisptr.SetName(name.encode('utf-8'))
    def Evaluate(self, time):
        return self.thisptr.Evaluate(time)
    def DataExists(self):
        return self.thisptr.DataExists()

cdef class PyDerivedClass(PyBaseClass):
    cdef DerivedClass *derivedptr
    def __cinit__(self):
        if type(self) is PyDerivedClass:
            self.derivedptr = self.thisptr = new DerivedClass()
    def __dealloc__(self):
        if type(self) is PyDerivedClass:
            del self.derivedptr
    def SetObject(self, PyBaseClass inputObject):
        self.derivedptr.SetObject(<BaseClass *>inputObject.thisptr)
    def MyFunction(self):
        self.derivedptr.MyFunction()

cdef class PyNextDerivedClass(PyDerivedClass):
    cdef NextDerivedClass *nextDerivedptr
    def __cinit__(self):
        self.nextDerivedptr = self.derivedptr = self.thisptr = new NextDerivedClass()
    def __dealloc__(self):
        del self.nextDerivedptr

test.py:

from inheritTest import PyBaseClass as base
from inheritTest import PyDerivedClass as der
from inheritTest import PyNextDerivedClass as nextDer

a = der()
b = der()
a.SetObject(b)
c = nextDer()
a.SetObject(c)
c.DataExists()
c.SetObject(b)
c.Evaluate(0.3)


baseSig = base()
signal = der()
baseSig.SetName('test')
signal.SetName('testingone')
baseSig.Evaluate(0.3)
signal.Evaluate(0.5)
signal.SetObject(b)
baseSig.DataExists()
signal.DataExists()

Notice that when I call:

c = nextDer()
c.Evaluate(0.3)

The way it works: PyBaseClass.Evaluate calls self.thisptr.Evaluate(time), which Cython turns into an ordinary C++ call through a BaseClass pointer, and C++ virtual dispatch selects the most-derived override of Evaluate for the object's actual type. If NextDerivedClass.h overrode Evaluate, that version would run (I have tried that and it works); since it doesn't, the override inherited from DerivedClass is used, so the output is:

>> DERIVED CLASS: in Evaluate
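A C++ sketch of the same lookup, assuming simplified classes where Evaluate and DataExists return a label (instead of printing and returning a number), and where evaluate_via_base and data_exists_via_base are hypothetical stand-ins for PyBaseClass.Evaluate and PyBaseClass.DataExists, which only hold a BaseClass pointer:

```cpp
#include <cassert>
#include <string>

// Simplified classes: each method returns a label naming the class whose
// implementation ran (the originals print and return a float/bool).
class BaseClass {
public:
    virtual ~BaseClass() {}
    virtual std::string Evaluate(float) { return "BASE"; }
    virtual std::string DataExists() { return "BASE"; }
};

class DerivedClass : public BaseClass {
public:
    std::string Evaluate(float) override { return "DERIVED CLASS"; }
    std::string DataExists() override { return "DERIVED CLASS"; }
};

class NextDerivedClass : public DerivedClass {
public:
    // No Evaluate override: DerivedClass::Evaluate is the final overrider.
    std::string DataExists() override { return "NEXT DERIVED CLASS"; }
};

// Hypothetical stand-ins for PyBaseClass.Evaluate / PyBaseClass.DataExists,
// which only hold a BaseClass* (thisptr).
std::string evaluate_via_base(BaseClass* thisptr, float time) {
    assert(thisptr != nullptr);
    return thisptr->Evaluate(time);
}

std::string data_exists_via_base(BaseClass* thisptr) {
    assert(thisptr != nullptr);
    return thisptr->DataExists();
}
```

For a NextDerivedClass object, evaluate_via_base falls back to DerivedClass's version, while data_exists_via_base reaches NextDerivedClass's own override through the same base pointer.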

I hope this helps someone in the future. Again, if there are errors in my understanding (or just in grammar/syntax), feel free to comment below and I will try to address them. Big thanks to those who answered below; this is sort of a summary of their answers, to help validate my understanding. Thanks!

the_drow
jeet.m
  • Nice example. Thanks for posting. – IanH Feb 25 '15 at 19:11
  • Note that if your constructor (`__cinit__`) takes any parameters, you'll want all `__cinit__`'s in the hierarchy to swallow extra parameters: `*a, **kw`. – LucasB Nov 09 '15 at 15:40
  • Very very very useful information. I tried to google around for several hours and this is the best answer. Thank you jeet.m! – Shi B. Nov 27 '15 at 08:40
  • This answer should make it to the Cython docs, it is extremely well described. – ibarrond Jul 02 '21 at 16:18

Your code, as written, doesn't compile. I suspect that your real PyDerivedClass doesn't actually derive from PyBaseClass, because if it did, that last line would have to be

(<DerivedClass*>self.thisptr).SetObject(inputObject.thisptr)

This would also explain the type error you're getting, which is a bug I can't reproduce.
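The alternative suggested here keeps a single BaseClass pointer per wrapper and casts it down where a derived method is needed, so there is only one pointer to set and one to delete. Cython's <DerivedClass*>self.thisptr is an unchecked cast, roughly a static_cast in C++, and is only valid if the object really is a DerivedClass or a subclass of it. A C++ sketch of that design, where Wrapper, base() and derived() are hypothetical names and the dynamic_cast assertion is a debug-only sanity check added for illustration:

```cpp
#include <cassert>
#include <memory>
#include <string>

class BaseClass {
public:
    virtual ~BaseClass() {}
};

class DerivedClass : public BaseClass {
public:
    virtual std::string SetObject(BaseClass*) { return "DERIVED"; }
};

class NextDerivedClass : public DerivedClass {
public:
    std::string SetObject(BaseClass*) override { return "NEXT DERIVED"; }
};

// Hypothetical wrapper with a single owning pointer, the C++ analogue of
// keeping only thisptr in PyBaseClass.
class Wrapper {
public:
    explicit Wrapper(BaseClass* obj) : thisptr_(obj) {}

    BaseClass* base() const { return thisptr_.get(); }

    // Analogue of <DerivedClass*>self.thisptr: an unchecked downcast that is
    // only valid when the wrapped object is a DerivedClass or a subclass.
    DerivedClass* derived() const {
        assert(dynamic_cast<DerivedClass*>(thisptr_.get()) != nullptr);  // debug-only check
        return static_cast<DerivedClass*>(thisptr_.get());
    }

private:
    std::unique_ptr<BaseClass> thisptr_;  // the only owner: deleted exactly once
};
```

Holding the object in a std::unique_ptr also makes the wrapper non-copyable, so there is no second owner that could delete the same object again.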

robertwb
  • Yes, I have fixed that, but please see the edit above; I am getting a similar issue for multiple levels of inheritance – jeet.m Feb 20 '15 at 19:20
  • I still had trouble compiling your code as written. Fixed at https://gist.github.com/robertwb/1939f95f579cd8e80505 Once I got it to compile, it works as expected (no type error, calling the right SetObject, i.e. it prints " in set object in set object of next derived class " – robertwb Feb 25 '15 at 08:09
  • I would also note that having a new pointer for each level of inheritance might be dangerous, you don't set self.derivedptr in the last class, which could cause segfaults, and every __dealloc__ is called, so you'd have double (or triple) deletes. – robertwb Feb 25 '15 at 08:15
  • Thanks so much for your help. I have compiled everything I learnt from the answers people posted, as well as from experimenting, in an answer below; it's to better my understanding as well as to help anyone else in the future. Please look at it and let me know if there are any issues. Thanks again! – jeet.m Feb 25 '15 at 19:05

Honestly, this looks like a bug. The object you're passing in is an instance of the desired class, but it still throws an error. You may want to bring it up on the cython-users mailing list so the main developers can look at it.

A possible workaround would be to define a fused type that represents both types of arguments and use that inside the method. That seems like overkill though.

IanH
  • Yeah I just posted on the cython-users group, let's see what happens. I would really like to avoid using fused-types if possible. Thanks for the help, if anyone else doesn't reply with a solution in a few days, or if it turns out to be a bug, I'll just accept this answer. Thanks again – jeet.m Feb 19 '15 at 20:28
  • Please see my edit, your inheritance answer on my other question solved this issue for one level, but I seem to be getting a similar problem for multiple levels of inheritance. Any ideas? – jeet.m Feb 20 '15 at 19:19
    Thanks for your help! I have compiled everything I have found out related to this issue in one answer, please let me know if there are any issues. Thanks again! – jeet.m Feb 25 '15 at 19:05