diff --git a/alsNet/dataset.py b/alsNet/dataset.py index cce70b0..bca64e9 100644 --- a/alsNet/dataset.py +++ b/alsNet/dataset.py @@ -232,8 +232,8 @@ def getBatches(self, batch_size=1): for i in range(batch_size): if self.currIdx >= self.num_batches: break - centers.append([self.xmin + self.spacing/2 + (self.currIdx // self.num_cols) * self.spacing, - self.ymin + self.spacing/2 + (self.currIdx % self.num_cols) * self.spacing]) + centers.append([self.xmin + self.spacing/2 + (self.currIdx // self.num_rows) * self.spacing, + self.ymin + self.spacing/2 + (self.currIdx % self.num_rows) * self.spacing]) self.currIdx += 1 if centers: _, idx = self.tree.query(centers, k=self.k) @@ -242,7 +242,7 @@ def getBatches(self, batch_size=1): return None, None def getBatchByIdx(self, batch_idx): - centers = [[self.xmin + self.spacing / 2 + (batch_idx // self.num_cols) * self.spacing, - self.ymin + self.spacing / 2 + (batch_idx % self.num_cols) * self.spacing]] + centers = [[self.xmin + self.spacing / 2 + (batch_idx // self.num_rows) * self.spacing, + self.ymin + self.spacing / 2 + (batch_idx % self.num_rows) * self.spacing]] _, idx = self.tree.query(centers, k=self.k) return self.points_and_features[idx, :], self.labels[idx]