diff --git a/lsmtool/operations/group.py b/lsmtool/operations/group.py index 9090aa8c3c15992c0d594c7fb40a92390e0a018f..135424eea1cfa1720496254926be19dbc346782e 100644 --- a/lsmtool/operations/group.py +++ b/lsmtool/operations/group.py @@ -51,7 +51,7 @@ def run(step, parset, LSM): def group(LSM, algorithm, targetFlux=None, numClusters=100, FWHM=None, threshold=0.1, applyBeam=False, root='Patch', pad_index=False, method='mid', - facet=""): + facet="", byPatch=False): """ Groups sources into patches. @@ -110,6 +110,9 @@ def group(LSM, algorithm, targetFlux=None, numClusters=100, FWHM=None, - 'zero' => set all positions to [0.0, 0.0] facet : str, optional Facet fits file used with the algorithm 'facet' + byPatch : bool, optional + For the 'tessellate' algorithm, use patches instead of by sources + Examples -------- @@ -157,15 +160,30 @@ def group(LSM, algorithm, targetFlux=None, numClusters=100, FWHM=None, targetFlux = float(parts[0]) if len(parts) == 2: units = parts[1] - LSM.ungroup() - x, y, midRA, midDec = LSM._getXY() - f = LSM.getColValues('I', units=units, applyBeam=applyBeam) + if byPatch: + if not 'Patch' in LSM.table.keys(): + raise ValueError('Sky model must be grouped before "byPatch" can be used.') + x, y, midRA, midDec = LSM._getXY(byPatch=True) + f = LSM.getColValues('I', units=units, applyBeam=applyBeam, aggregate='sum') + else: + LSM.ungroup() + x, y, midRA, midDec = LSM._getXY() + f = LSM.getColValues('I', units=units, applyBeam=applyBeam) vobin = _tessellate.bin2D(np.array(x), np.array(y), f, target_flux=targetFlux) try: vobin.bin_voronoi() patchCol = _tessellate.bins2Patches(vobin, root=root, pad_index=pad_index) - LSM.setColValues('Patch', patchCol, index=2) + if byPatch: + newPatchNames = patchCol.copy() + origPatchNames = LSM.getPatchNames() + patchCol = np.zeros(len(LSM), dtype='S100') + for newPatchName, origPatchName in zip(newPatchNames, origPatchNames): + ind = np.array(LSM.getRowIndex(origPatchName)) + patchCol[ind] = newPatchName + LSM.setColValues('Patch', patchCol, index=2) + else: + LSM.setColValues('Patch', patchCol, index=2) except ValueError: # Catch error in some cases with high target flux relative to # total model flux diff --git a/lsmtool/skymodel.py b/lsmtool/skymodel.py index df5c850d0b62496563dadc0e71f138aac57f15cb..2a272d01e139873f9a3001e4d684815e98bc9684 100644 --- a/lsmtool/skymodel.py +++ b/lsmtool/skymodel.py @@ -619,7 +619,7 @@ class SkyModel(object): raise RuntimeError('Sky model does not have patches.') - def _getXY(self, patchName=None, crdelt=None): + def _getXY(self, patchName=None, crdelt=None, byPatch=False): """ Returns lists of projected x and y values for all sources. @@ -629,6 +629,8 @@ class SkyModel(object): If given, return x and y for specified patch only crdelt: float, optional Delta in degrees for sky grid + byPatch : bool, optional + Use patches instead of by sources Returns ------- @@ -643,12 +645,18 @@ class SkyModel(object): if len(self.table) == 0: return [0], [0], 0, 0 - RA = self.getColValues('Ra') - Dec = self.getColValues('Dec') - if patchName is not None: - ind = self.getRowIndex(patchName) - RA = RA[ind] - Dec = Dec[ind] + if byPatch: + if not 'Patch' in self.table.keys(): + raise ValueError('Sky model must be grouped before "byPatch" can be used.') + RA = self.getColValues('Ra', aggregate='wmean') + Dec = self.getColValues('Dec', aggregate='wmean') + else: + RA = self.getColValues('Ra') + Dec = self.getColValues('Dec') + if patchName is not None: + ind = self.getRowIndex(patchName) + RA = RA[ind] + Dec = Dec[ind] x, y = radec2xy(RA, Dec, crdelt=crdelt) # Refine x and y using midpoint @@ -1974,7 +1982,7 @@ class SkyModel(object): def group(self, algorithm, targetFlux=None, numClusters=100, FWHM=None, threshold=0.1, applyBeam=False, root='Patch', pad_index=False, - method='mid', facet=""): + method='mid', facet="", byPatch=False): """ Groups sources into patches. @@ -2033,6 +2041,8 @@ class SkyModel(object): - 'zero' => set all positions to [0.0, 0.0] facet : str, optional Facet fits file used with the algorithm 'facet' + byPatch : bool, optional + For the 'tessellate' algorithm, use patches instead of by sources Examples -------- @@ -2045,7 +2055,7 @@ class SkyModel(object): operations.group.group(self, algorithm, targetFlux=targetFlux, numClusters=numClusters, FWHM=FWHM, threshold=threshold, applyBeam=applyBeam, root=root, pad_index=pad_index, method=method, - facet=facet) + facet=facet, byPatch=byPatch) def transfer(self, patchSkyModel, matchBy='name', radius=0.1):