Skip to content

Commit

Permalink
Hard training session has been setted
Browse files Browse the repository at this point in the history
  • Loading branch information
eliainnocenti committed Jul 10, 2024
1 parent 073cc45 commit 1faf5cc
Show file tree
Hide file tree
Showing 8 changed files with 876 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
__pycache__/
**/__pycache__/

# models
models/

# git # TODO: unignore when ready
.gitattributes

Expand Down
58 changes: 55 additions & 3 deletions inference/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
base_path = "../../../Data/"

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path='../models/model.tflite')
interpreter = tf.lite.Interpreter(model_path='../models/model2.tflite')
interpreter.allocate_tensors()

# Get input and output details
Expand Down Expand Up @@ -76,13 +76,54 @@ def visualize_detections(image_path, boxes, classes_scores, threshold=0.5):
plt.show()


def main():
def train_images():
"""
:return:
"""
train_path = '../data/rparis6k/sets/train/train.txt'

if not os.path.exists(train_path):
print(f"Error: Train file not found: {train_path}")
return

with open(train_path, 'r') as f:
train_images = [line.strip() for line in f]

for image_name in train_images[:5]:
image_path = os.path.join(base_path, 'datasets', 'rparis6k', 'images', image_name)
image_np = load_image_into_numpy_array(image_path)
boxes, classes_scores = run_inference(image_np)
visualize_detections(image_path, boxes, classes_scores)


def validation_images():
"""
:return:
"""
validation_path = '../data/rparis6k/sets/validation/val.txt'

if not os.path.exists(validation_path):
print(f"Error: Validation file not found: {validation_path}")
return

with open(validation_path, 'r') as f:
validation_images = [line.strip() for line in f]

for image_name in validation_images[:5]:
image_path = os.path.join(base_path, 'datasets', 'rparis6k', 'images', image_name)
image_np = load_image_into_numpy_array(image_path)
boxes, classes_scores = run_inference(image_np)
visualize_detections(image_path, boxes, classes_scores)


def test_images():
"""
test_path = '../data/rparis6k/sets/test.txt'
:return:
"""
test_path = '../data/rparis6k/sets/test/test.txt'

if not os.path.exists(test_path):
print(f"Error: Test file not found: {test_path}")
Expand All @@ -98,5 +139,16 @@ def main():
visualize_detections(image_path, boxes, classes_scores)


def main():
"""
:return:
"""

#train_images()
#validation_images()
#test_images()


if __name__ == '__main__':
main()
Binary file removed models/model.tflite
Binary file not shown.
26 changes: 26 additions & 0 deletions scripts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Data scripts

This directory contains scripts that are used to get, process and analyze
the data used in the project.

## Get the data

```bash
python get_data.py
```

## Process the data

```bash
python prepare_dataset.py
```

```mermaid
graph TD;
A-->B;
A-->C;
B-->D;
C-->D;
```


This file was deleted.

Loading

0 comments on commit 1faf5cc

Please sign in to comment.