Skip to content

Commit

Permalink
v3.4.8 fix bug when classes >44
Browse files Browse the repository at this point in the history
  • Loading branch information
nkarasiak committed Jun 12, 2019
1 parent 71fc144 commit 98eb504
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 20 deletions.
4 changes: 3 additions & 1 deletion metadata.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
name=dzetsaka : Classification tool
qgisMinimumVersion=3.0
description=Fast and Easy Classification plugin for Qgis
version=3.4.7
version=3.4.8
author=Nicolas Karasiak
[email protected]

Expand All @@ -29,6 +29,8 @@ repository=http://www.github.com/lennepkade/dzetsaka

# Uncomment the following line and add your changelog:
changelog=
3.4.8
* Fix errors when number of classes > 44 (problem of datatype in sample extraction).
3.4.7
* Support more than 255 classes to predict (if n > 255, raster datatype will be set to uint16)
3.4.6
Expand Down
21 changes: 11 additions & 10 deletions scripts/function_dataraster.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def pixel2coord(coord):

# Read block data
X = np.array([]).reshape(0, d)
Y = np.array([]).reshape(0, 1)
STD = np.array([]).reshape(0, 1)
Y = np.array([],dtype=np.uint16).reshape(0, 1)
STD = np.array([],dtype=np.uint16).reshape(0, 1)

for i in range(0, nl, y_block_size):
if i + y_block_size < nl: # Check for size consistency in Y
Expand All @@ -289,11 +289,11 @@ def pixel2coord(coord):
if t[0].size > 0:
Y = np.concatenate(
(Y, ROI[t].reshape(
(t[0].shape[0], 1)).astype('uint8')))
(t[0].shape[0], 1))))
if stand_name:
STD = np.concatenate(
(STD, STAND[t].reshape(
(t[0].shape[0], 1)).astype('uint8')))
(t[0].shape[0], 1))))
if getCoords:
#coords = sp.append(coords,(i,j))
#coordsTp = sp.array(([[cols,lines]]))
Expand Down Expand Up @@ -609,7 +609,7 @@ def rasterize(data, vectorSrc, field, outFile):
dataSrc.RasterXSize,
dataSrc.RasterYSize,
1,
gdal.GDT_Byte)
gdal.GDT_UInt16)
dst_ds.SetGeoTransform(dataSrc.GetGeoTransform())
dst_ds.SetProjection(dataSrc.GetProjection())
if field is None:
Expand Down Expand Up @@ -660,12 +660,13 @@ def scale(x, M=None, m=None): # TODO: DO IN PLACE SCALING


if __name__ == "__main__":
Raster = "/mnt/DATA/Test/DA/SENTINEL_20170516.tif"
ROI = '/tmp/testroi.tif'
rasterize(Raster, '/mnt/DATA/Test/DA/ROI_2154.sqlite', 'level2', ROI)

X, Y, coords = get_samples_from_roi(Raster, ROI, getCoords=True)
Raster = "/mnt/DATA/Test/dzetsaka/map.tif"
ROI = '/home/nicolas/Bureau/train_300class.gpkg'
rasterize(Raster, ROI, 'Class', '/tmp/roi.tif')

# X, Y, coords = get_samples_from_roi(Raster, '/tmp/roi.tif', getCoords=True)
X, Y = get_samples_from_roi(Raster, '/tmp/roi.tif')
print(np.amax(Y))
"""
import accuracy_index as ai
print(X.shape)
Expand Down
2 changes: 1 addition & 1 deletion scripts/gmm_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def learn(self, x, y):
self.cov = np.empty((C, d, d)) # Matrix of covariance
self.Q = np.empty((C, d, d)) # Matrix of eigenvectors
self.L = np.empty((C, d)) # Vector of eigenvalues
self.classnum = np.empty(C).astype('uint8')
self.classnum = np.empty(C).astype('uint16')
self.classes_ = self.classnum
# Learn the parameter of the model for each class
for c, cR in enumerate(np.unique(y)):
Expand Down
19 changes: 11 additions & 8 deletions scripts/mainfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(self, inRaster, inVector, inField='Class', outModel=None, inSplit=1
xt, N, n = self.scale(Xt)
#x,y = dataraster.get_samples_from_roi(inRaster,ROI,getCoords=True,convertTo4326=True)
y = Y

# Create temporary data set
if SPLIT == 'SLOO':

Expand Down Expand Up @@ -162,7 +162,7 @@ def __init__(self, inRaster, inVector, inField='Class', outModel=None, inSplit=1
except BaseException:
X, Y, coords = dataraster.get_samples_from_roi(
inRaster, ROI, getCoords=True)

distanceArray = distMatrix(coords)
# np.save(os.path.splitext(distanceFile)[0],distanceArray)

Expand Down Expand Up @@ -193,17 +193,19 @@ def __init__(self, inRaster, inVector, inField='Class', outModel=None, inSplit=1

elif needXY:
X, Y = dataraster.get_samples_from_roi(inRaster, ROI)


except BaseException:
msg = "Problem with getting samples from ROI \n \
Are you sure to have only integer values in your " + str(inField) + " field ?\n "
pushFeedback(msg, feedback=feedback)

[n, d] = X.shape
C = int(Y.max())
SPLIT = inSplit

try:
#pushFeedback(str(ROI),feedback=feedback)
os.remove(ROI)
except BaseException:
pass
Expand Down Expand Up @@ -261,7 +263,7 @@ def __init__(self, inRaster, inVector, inField='Class', outModel=None, inSplit=1
pushFeedback(
'This step could take a lot of time... So be patient, even if the progress bar stucks at 20% :)',
feedback=feedback)

if feedback == 'gui':
progress.addStep() # Add Step to ProgressBar
# Train Classifier
Expand All @@ -275,7 +277,6 @@ def __init__(self, inRaster, inVector, inField='Class', outModel=None, inSplit=1
# tau=10.0**sp.arange(-8,8,0.5)
model = gmmr.GMMR()
model.learn(x, y)

# htau,err = model.cross_validation(x,y,tau)
# model.tau = htau
except BaseException:
Expand Down Expand Up @@ -812,10 +813,12 @@ def predict_image(self, inRaster, outRaster, model=None, inMask=None, confidence

driver = gdal.GetDriverByName('GTiff')

if len(model.classes_)>255:
dtype = gdal.GDT_Uint16
if np.amax(model.classes_)>255:
dtype = gdal.GDT_UInt16
else:
dtype = gdal.GDT_Byte


dst_ds = driver.Create(outRaster, nc, nl, 1, dtype)
dst_ds.SetGeoTransform(GeoTransform)
dst_ds.SetProjection(Projection)
Expand Down Expand Up @@ -992,7 +995,7 @@ def rasterize(inRaster, inShape, inField):
data.RasterXSize,
data.RasterYSize,
1,
gdal.GDT_Byte)
gdal.GDT_UInt16)
dst_ds.SetGeoTransform(data.GetGeoTransform())
dst_ds.SetProjection(data.GetProjection())
OPTIONS = 'ATTRIBUTE=' + inField
Expand Down

0 comments on commit 98eb504

Please sign in to comment.