Skip to content
Snippets Groups Projects
Commit 592debb0 authored by David Rafferty's avatar David Rafferty
Browse files

Add option to build C++ extensions

parent 95eba010
No related branches found
No related tags found
1 merge request!18Add c++ version of mean shift algorithm
...@@ -64,6 +64,12 @@ Then install with: ...@@ -64,6 +64,12 @@ Then install with:
cd LSMTool cd LSMTool
python setup.py install python setup.py install
If you have a C++11-compliant compiler, you can build a faster
version of the mean shift grouping algorithm with:
cd LSMTool
python setup.py install --build_c_extentions
### Testing ### Testing
You can test that the installation worked with: You can test that the installation worked with:
......
...@@ -158,7 +158,7 @@ void Grouper::group(py::list l){ ...@@ -158,7 +158,7 @@ void Grouper::group(py::list l){
} }
} }
PYBIND11_MODULE(grouper, m) PYBIND11_MODULE(_grouper, m)
{ {
py::class_<Grouper>(m, "Grouper") py::class_<Grouper>(m, "Grouper")
.def(py::init<>()) .def(py::init<>())
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
import numpy as np import numpy as np
import grouper from . import _grouper
import logging import logging
log = logging.getLogger('LSMTool.Group') log = logging.getLogger('LSMTool.Group')
...@@ -52,7 +52,7 @@ class Grouper(object): ...@@ -52,7 +52,7 @@ class Grouper(object):
self.grouping_distance = grouping_distance self.grouping_distance = grouping_distance
self.past_coords = [np.copy(self.coords)] self.past_coords = [np.copy(self.coords)]
self.clusters = [] self.clusters = []
self.g = grouper.Grouper() self.g = _grouper.Grouper()
self.g.readCoordinates(self.coords, self.fluxes) self.g.readCoordinates(self.coords, self.fluxes)
self.g.setKernelSize(self.kernel_size) self.g.setKernelSize(self.kernel_size)
self.g.setNumberOfIterations(self.n_iterations) self.g.setNumberOfIterations(self.n_iterations)
......
...@@ -150,6 +150,9 @@ def group(LSM, algorithm, targetFlux=None, weightBySize=False, numClusters=100, ...@@ -150,6 +150,9 @@ def group(LSM, algorithm, targetFlux=None, weightBySize=False, numClusters=100,
from . import _tessellate from . import _tessellate
from . import _cluster from . import _cluster
from . import _threshold from . import _threshold
try:
from . import _meanshiftc as _meanshift
except ImportError:
from . import _meanshift from . import _meanshift
import numpy as np import numpy as np
import os import os
......
from __future__ import print_function from __future__ import print_function
from setuptools import setup, Command from setuptools import setup, Command, Extension, Distribution
import os from setuptools.command.build_ext import build_ext
import sys import sys
import lsmtool._version import lsmtool._version
# Flag that determines whether to build the optional (but faster) C++
# extensions. Set to False to install only the pure Python versions.
if "--build_c_extentions" in sys.argv:
build_c_extentions = True
sys.argv.remove("--build_c_extentions")
else:
build_c_extentions = False
# Handle Python 3-only dependencies
if sys.version_info < (3, 0):
reqlist = ['numpy', 'astropy >= 0.4, <3.0']
else:
reqlist = ['numpy', 'astropy >= 0.4']
if build_c_extentions:
reqlist.append('pybind11>=2.2.0')
ext_modules = [Extension('lsmtool.operations._grouper',
['lsmtool/operations/_grouper.cpp'],
language='c++')]
else:
ext_modules = []
class PyTest(Command): class PyTest(Command):
user_options = [] user_options = []
def initialize_options(self): def initialize_options(self):
pass pass
...@@ -14,15 +37,36 @@ class PyTest(Command): ...@@ -14,15 +37,36 @@ class PyTest(Command):
pass pass
def run(self): def run(self):
import sys,subprocess import sys
import subprocess
errno = subprocess.call([sys.executable, 'runtests.py']) errno = subprocess.call([sys.executable, 'runtests.py'])
raise SystemExit(errno) raise SystemExit(errno)
# Handle Python 3-only dependencies
if sys.version_info < (3, 0): class LSMToolDistribution(Distribution):
reqlist = ['numpy','astropy >= 0.4, <3.0']
else: def is_pure(self):
reqlist = ['numpy','astropy >= 0.4'] if self.pure:
return True
def has_ext_modules(self):
return not self.pure
global_options = Distribution.global_options + [
('pure', None, "use pure Python code instead of C++ extensions")]
pure = False
class BuildExt(build_ext):
def build_extensions(self):
opts = ['-std=c++11']
if sys.platform == 'darwin':
opts += ['-stdlib=libc++']
for ext in self.extensions:
ext.extra_compile_args = opts
build_ext.build_extensions(self)
setup( setup(
name='lsmtool', name='lsmtool',
...@@ -41,6 +85,9 @@ setup( ...@@ -41,6 +85,9 @@ setup(
], ],
install_requires=reqlist, install_requires=reqlist,
scripts=['bin/lsmtool'], scripts=['bin/lsmtool'],
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExt},
distclass=LSMToolDistribution,
packages=['lsmtool', 'lsmtool.operations'], packages=['lsmtool', 'lsmtool.operations'],
setup_requires=['pytest-runner'], setup_requires=['pytest-runner'],
tests_require=['pytest'] tests_require=['pytest']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment