diff --git a/pythondp3/PyStep.cc b/pythondp3/PyStep.cc index e4ea295b84329bca8cac7d7d4e70abedf548262a..ce8b1941d5106e084468098582af7247171e8df5 100644 --- a/pythondp3/PyStep.cc +++ b/pythondp3/PyStep.cc @@ -42,8 +42,6 @@ void PyStepImpl::finish() { } bool PyStepImpl::process(const DPBuffer& bufin) { - m_count++; - // Make a deep copy of the buffer to make the data // persistent across multiple process calls // This is not always necessary, but for python Steps @@ -52,26 +50,26 @@ bool PyStepImpl::process(const DPBuffer& bufin) { dpbuffer->copy(bufin); PYBIND11_OVERRIDE_PURE( - bool, /* Return type */ - StepWrapper, /* Parent class */ - process, /* Name of function in C++ (must match Python name) */ - dpbuffer /* Argument(s) */ + bool, /* Return type */ + Step, /* Parent class */ + process, /* Name of function in C++ (must match Python name) */ + dpbuffer /* Argument(s) */ ); } common::Fields PyStepImpl::getRequiredFields() const { - PYBIND11_OVERRIDE_PURE_NAME(common::Fields, StepWrapper, - "get_required_fields", getRequiredFields, ); + PYBIND11_OVERRIDE_PURE_NAME(common::Fields, Step, "get_required_fields", + getRequiredFields, ); } common::Fields PyStepImpl::getProvidedFields() const { - PYBIND11_OVERRIDE_PURE_NAME(common::Fields, StepWrapper, - "get_provided_fields", getProvidedFields, ); + PYBIND11_OVERRIDE_PURE_NAME(common::Fields, Step, "get_provided_fields", + getProvidedFields, ); } void PyStepImpl::updateInfo(const DPInfo& dpinfo) { PYBIND11_OVERRIDE_NAME(void, /* Return type */ - StepWrapper, /* Parent class */ + Step, /* Parent class */ "update_info", /* Name of function in Python */ updateInfo, /* Name of function in C++ */ dpinfo /* Argument(s) */ diff --git a/pythondp3/PyStepImpl.h b/pythondp3/PyStepImpl.h index 6d921c1c8e378d9cd0ba8896f7c4a081e4e71f3f..ceb8d57f375e83001dc65761e484518b53d1ce07 100644 --- a/pythondp3/PyStepImpl.h +++ b/pythondp3/PyStepImpl.h @@ -30,33 +30,18 @@ class ostream_wrapper { std::ostream& os_; }; +/** + * Wrapper class that exposes protected members, so pybind11 can use those. + */ class StepWrapper : public steps::Step { public: using Step::info; - using Step::Step; using Step::updateInfo; - - using Step::getNextStep; - - Step::ShPtr get_next_step() { return Step::ShPtr(getNextStep()); } - - bool process_next_step(const base::DPBuffer& dpbuffer) { - return get_next_step()->process(dpbuffer); - } - - int get_count() { return m_count; } - void set_parset(const common::ParameterSet& parset) { m_parset = parset; }; - void set_name(const string& name) { m_name = name; }; - - protected: - int m_count = 0; - steps::InputStep* m_input; - common::ParameterSet m_parset; - string m_name; - const base::DPBuffer* m_dpbuffer_in; - common::NSTimer m_timer; }; +/** + * Trampoline class for python steps. + */ class PyStepImpl final : public StepWrapper { public: using StepWrapper::StepWrapper; diff --git a/pythondp3/pydp3.cc b/pythondp3/pydp3.cc index 4568cd2003d846d4cea334e88cf8d9c90abee72e..545042a6bdb006b00eaa869ae22b47b49d164c96 100644 --- a/pythondp3/pydp3.cc +++ b/pythondp3/pydp3.cc @@ -16,6 +16,7 @@ using dp3::base::DPBuffer; using dp3::base::DPInfo; using dp3::common::Fields; +using dp3::steps::Step; namespace py = pybind11; @@ -23,9 +24,9 @@ namespace dp3 { namespace pythondp3 { template <typename T> -void register_cube(py::module &m, const char *name) { +void register_cube(py::module& m, const char* name) { py::class_<casacore::Cube<T>>(m, name, py::buffer_protocol()) - .def_buffer([](casacore::Cube<T> &cube) -> py::buffer_info { + .def_buffer([](casacore::Cube<T>& cube) -> py::buffer_info { return py::buffer_info( cube.data(), /* Pointer to buffer */ sizeof(T), /* Size of one scalar */ @@ -41,9 +42,9 @@ void register_cube(py::module &m, const char *name) { } template <typename T> -void register_matrix(py::module &m, const char *name) { +void register_matrix(py::module& m, const char* name) { py::class_<casacore::Matrix<T>>(m, name, py::buffer_protocol()) - .def_buffer([](casacore::Matrix<T> &matrix) -> py::buffer_info { + .def_buffer([](casacore::Matrix<T>& matrix) -> py::buffer_info { return py::buffer_info( matrix.data(), /* Pointer to buffer */ sizeof(T), /* Size of one scalar */ @@ -58,9 +59,9 @@ void register_matrix(py::module &m, const char *name) { } template <typename T> -void register_vector(py::module &m, const char *name) { +void register_vector(py::module& m, const char* name) { py::class_<casacore::Vector<T>>(m, name, py::buffer_protocol()) - .def_buffer([](casacore::Vector<T> &vector) -> py::buffer_info { + .def_buffer([](casacore::Vector<T>& vector) -> py::buffer_info { return py::buffer_info( vector.data(), /* Pointer to buffer */ sizeof(T), /* Size of one scalar */ @@ -86,26 +87,28 @@ PYBIND11_MODULE(pydp3, m) { py::class_<ostream_wrapper>(m, "ostream") .def("write", &ostream_wrapper::write); - py::class_<StepWrapper, PyStepImpl /* <--- trampoline*/ - >(m, "Step") + py::class_<Step, PyStepImpl, std::shared_ptr<Step>>(m, "Step") .def(py::init<>()) - .def("show", &StepWrapper::show, + .def("show", &Step::show, "Show step summary (stdout will be redirected to DPPP's output " "stream during this step)") .def("update_info", &StepWrapper::updateInfo, "Handle metadata") .def("info", &StepWrapper::info, py::return_value_policy::reference, "Get info object (read/write) with metadata") - .def("finish", &StepWrapper::finish, + .def( + "process", + // Use a lambda, since process is overloaded. + [](Step& self, const DPBuffer& buffer) { + return self.process(buffer); + }, + "Process a single buffer") + .def("finish", &Step::finish, "Finish processing (nextstep->finish will be called automatically") - .def("get_next_step", &StepWrapper::get_next_step, - py::return_value_policy::copy, "Get a reference to the next step") - .def("process_next_step", &StepWrapper::process_next_step, - "Process the next step") - .def("get_count", &StepWrapper::get_count, - "Get the number of time slots processed") - .def("get_required_fields", &StepWrapper::getRequiredFields, + .def("get_next_step", &Step::getNextStep, + "Get a reference to the next step") + .def("get_required_fields", &Step::getRequiredFields, "Get the fields required by current step") - .def("get_provided_fields", &StepWrapper::getProvidedFields, + .def("get_provided_fields", &Step::getProvidedFields, "Get the fields provided by current step"); py::class_<DPBuffer, std::shared_ptr<DPBuffer>>(m, "DPBuffer") @@ -132,7 +135,7 @@ PYBIND11_MODULE(pydp3, m) { py::class_<DPInfo>(m, "DPInfo") .def( "antenna_names", - [](DPInfo &self) -> py::array { + [](DPInfo& self) -> py::array { // Convert casa vector of casa strings to std::vector of strings std::vector<std::string> names_casa = self.antennaNames(); std::vector<std::string> names; @@ -145,7 +148,7 @@ PYBIND11_MODULE(pydp3, m) { "Get numpy array with antenna names (read only)") .def( "antenna_positions", - [](DPInfo &self) -> py::array { + [](DPInfo& self) -> py::array { // Convert vector of casa MPositions to std::vector of positions std::vector<casacore::MPosition> positions_casa = self.antennaPos(); std::vector<std::array<double, 3>> positions; @@ -186,21 +189,21 @@ PYBIND11_MODULE(pydp3, m) { .def_property_readonly("fullresflags", &Fields::FullResFlags) .def( "update_requirements", - [](Fields &self, const Fields &a, const Fields &b) { + [](Fields& self, const Fields& a, const Fields& b) { return self.UpdateRequirements(a, b); }, "Updates the current object's Fields based on a step's required and " "provided Fields") .def("__str__", - [](const Fields &a) { + [](const Fields& a) { std::stringstream ss; ss << a; return ss.str(); }) - .def("__eq__", [](const Fields &a, const Fields &b) { return a == b; }) - .def("__neq__", [](const Fields &a, const Fields &b) { return a != b; }) - .def("__or__", [](const Fields &a, const Fields &b) { return a | b; }) - .def("__ior__", [](Fields &a, const Fields &b) { + .def("__eq__", [](const Fields& a, const Fields& b) { return a == b; }) + .def("__neq__", [](const Fields& a, const Fields& b) { return a != b; }) + .def("__or__", [](const Fields& a, const Fields& b) { return a | b; }) + .def("__ior__", [](Fields& a, const Fields& b) { a |= b; return a; }); diff --git a/steps/test/unit/mock/mockpystep.py b/steps/test/unit/mock/mockpystep.py index 18274049e88e2016a8d9021a113c9d09e8c0d99c..feb88a09373867b9e8b6a732bb85a7fcaa65493d 100644 --- a/steps/test/unit/mock/mockpystep.py +++ b/steps/test/unit/mock/mockpystep.py @@ -51,7 +51,7 @@ class MockPyStep(Step): def process(self, dpbuffer): """ - Process one time slot of data. This function MUST call process_next_step. + Process one time slot of data. This function MUST call get_next_step().process(). Args: dpbuffer: DPBuffer object which can contain data, flags and weights @@ -66,7 +66,7 @@ class MockPyStep(Step): weights *= self.weightsfactor # Send processed data to the next step - self.process_next_step(dpbuffer) + self.get_next_step().process(dpbuffer) def finish(self): """