Skip to content

Commit

Permalink
Face Recognition Implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
rafay-pk committed May 6, 2023
1 parent d2e4c60 commit 6efb6c2
Showing 1 changed file with 113 additions and 8 deletions.
121 changes: 113 additions & 8 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import sys, os, pathlib, sqlite3, face_recognition, threading
import sys, os, pathlib, sqlite3, face_recognition, threading, random, numpy as np, base64
from PyQt6.QtCore import Qt, QSize, QDir, QSize, QUrl, QStringListModel, QItemSelectionModel, QSortFilterProxyModel, QModelIndex
from PyQt6.QtGui import QFileSystemModel, QPixmap, QMovie, QIcon, QAction, QImage, QStandardItemModel, QStandardItem
from PyQt6.QtMultimedia import QMediaPlayer
Expand All @@ -23,7 +23,8 @@
QSplitter,
QAbstractItemView,
QListView,
QMenu
QMenu,
QListWidgetItem
)


Expand Down Expand Up @@ -56,6 +57,11 @@ def close_connection(self):
self.cursor.close()
self.conn.close()

class QDeselectableListWidget(QListWidget):
def mousePressEvent(self, event):
self.clearSelection()
QListWidget.mousePressEvent(self, event)

class QDeselectableTreeView(QTreeView):
# def __init__(self, parent=None):
# super().__init__(parent)
Expand All @@ -82,7 +88,8 @@ def __init__(self, parent=None):
super().__init__(parent)

self.setSelectionMode(QListView.SelectionMode.SingleSelection)

self.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers)

self.model = QStandardItemModel()
self.setModel(self.model)

Expand Down Expand Up @@ -302,6 +309,9 @@ def __init__(self):
# region Tags View
self.dock_tags = QDockWidget("Tags")
self.tags = QListWidget()
self.tags.setEditTriggers(QAbstractItemView.EditTrigger.DoubleClicked)
self.tags.itemDoubleClicked.connect(self.tag_double_clicked)
self.tags.itemChanged.connect(self.tag_update)
self.tag_bar = QLineEdit()
self.tag_bar.setPlaceholderText("Press Ctrl+T to focus")
btn_add_tag = QPushButton("Add Tag")
Expand Down Expand Up @@ -329,23 +339,68 @@ def __init__(self):
self.dock_tags.setWidget(tag_view)
# endregion

# region People View
self.dock_people = QDockWidget("People")
self.people = QListWidget()
self.people.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
# self.people.selectionModel().selectionChanged.connect(self.people_selected)
self.people.itemSelectionChanged.connect(self.people_selected)
people_view = QWidget()
peopleLayout = QVBoxLayout()
people_view.setLayout(peopleLayout)
peopleLayout.addWidget(self.people)
peopleLayout.setContentsMargins(0, 0, 0, 0)
peopleLayout.setSpacing(0)
self.dock_people.setWidget(people_view)
self.people.addItems(self.sql_get_all_people())
# endregion

# region Docking
self.addDockWidget(Qt.DockWidgetArea.TopDockWidgetArea, self.dock_search)
self.addDockWidget(Qt.DockWidgetArea.BottomDockWidgetArea, self.dock_folders)
self.splitDockWidget(self.dock_folders, self.dock_browser, Qt.Orientation.Horizontal)
self.splitDockWidget(self.dock_browser, self.dock_media, Qt.Orientation.Horizontal)
self.splitDockWidget(self.dock_media, self.dock_tags, Qt.Orientation.Vertical)
self.splitDockWidget(self.dock_folders, self.dock_people, Qt.Orientation.Vertical)
self.dock_folders.setMinimumWidth(150)
self.dock_browser.setMinimumWidth(300)
# self.dock_media.setMinimumWidth(200)
# endregion

# region Setup
self.folder_selected()
# self.people_selected()
self.editing = False
self.old_tag = ""
self.show()
# endregion

# region Mechanics
def tag_double_clicked(self):
self.editing = True
self.old_tag = self.tags.selectedItems()[0].text()

def tag_update(self):
if self.editing:
self.editing = False
new_tag = self.tags.selectedItems()[0].text()
self.status.showMessage(f"Tag Updated from {self.old_tag} to {new_tag}")
self.sql_edit_tag(self.old_tag, new_tag)
if self.old_tag in [self.people.item(x).text() for x in range(self.people.count())]:
self.people.clear()
self.people.addItems(self.sql_get_all_people())

def people_selected(self):
self.browser_detailView.clear()
if len(self.people.selectedIndexes()) == 0:
self.browser_detailView.add_strings(self.sql_get_all_files())
self.update_title("Browser")
else:
person = self.people.currentItem().text()
self.status.showMessage(f"Searching for person - {person}")
self.browser_detailView.add_strings(self.sql_search_inclusive([person]))
self.update_title(person)

def folder_context_menu(self, pos):
menu = QMenu()
action1 = QAction("Process Folder", self)
Expand Down Expand Up @@ -422,8 +477,13 @@ def add_folder(self):
self.status.showMessage(f"Added - {path}")

def add_tag_to_file(self, tag, file):
if len(self.tags.findItems(tag, Qt.MatchFlag.MatchExactly)) > 0:
self.status.showMessage(f"Tag already exists on media file {file} - {tag}")
return
self.sql_add_tag_to_file(tag, file)
self.tags.addItem(tag)
item = QListWidgetItem(tag)
item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable)
self.tags.addItem(item)
self.status.showMessage(f"Added tag - {tag} to file - {file}")

def search(self, tags):
Expand All @@ -446,11 +506,33 @@ def search(self, tags):
# print(face_locations)
# self.status.showMessage("AI Processing - Finished")

def generate_random_person_name(self):
return f'person_{random.randint(10000, 99999)}'

def add_person(self, encoding, file):
unique_people = self.sql_get_all_encodings()
result = face_recognition.compare_faces(unique_people, encoding)
check = np.count_nonzero(result)
if check == 0:
name = self.generate_random_person_name()
self.sql_add_face_encoding(name, encoding)
self.people.addItem(name)
self.new_people += 1
elif check == 1:
name = self.sql_get_person_name(unique_people[result.index(True)])
elif check > 1:
distances = face_recognition.face_distance(unique_people, encoding)
name = self.sql_get_person_name(unique_people[np.argmin(distances)])
self.sql_add_tag_to_file(name, file)
self.tags_applied += 1

def ai_process_folder_start(self):
thread = threading.Thread(target=self.ai_process_folder)
thread.start()

def ai_process_folder(self):
self.tags_applied = 0
self.new_people = 0
for index in self.folders.selectedIndexes():
sym_path = self.fileSystem.filePath(index)
org_path = self.sql_get_files_in_folder(sym_path)
Expand All @@ -460,8 +542,11 @@ def ai_process_folder(self):
if self.filetype_switcher[extension].__name__ != self.display_static.__name__:
continue
image = face_recognition.load_image_file(file)
print(face_recognition.face_encodings(image, num_jitters=5, model="large"))
self.status.showMessage(f"AI Processing - {org_path} - Finished")
face_encodings = face_recognition.face_encodings(image, num_jitters=5, model="large")
if len(face_encodings) > 0:
for encoding in face_encodings:
self.add_person(encoding, file)
self.status.showMessage(f"AI Processing - {org_path} - Finished - Applied {self.tags_applied} tags - Detected {self.new_people} new people")

def show_about(self):
self.about = AboutWindow()
Expand All @@ -474,8 +559,13 @@ def display_media(self):
_, extension = os.path.splitext(file_path)
if extension in self.filetype_switcher:
self.filetype_switcher[extension](file_path)
self.dock_media.setWindowTitle(f'Media - {file_path}')
self.tags.clear()
self.tags.addItems([x[0] for x in self.sql_get_file_tags(file_path)])
tags = self.sql_get_file_tags(file_path)
for tag in tags:
item = QListWidgetItem(tag, self.tags)
item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable)
self.tags.addItem(item)
else:
self.status.showMessage(f"Unsupported file type - {extension}")

Expand Down Expand Up @@ -507,6 +597,7 @@ def sql_add_new_tag(self, tag):

def sql_edit_tag(self, tag, new_tag):
self.db.execute_query(f"UPDATE Tags SET name = '{new_tag}' WHERE name = '{tag}'")
self.db.execute_query(f"UPDATE People SET name = '{new_tag}' WHERE name = '{tag}'")

def sql_delete_tag(self, tag):
self.db.execute_query(f"DELETE FROM Tags WHERE name = '{tag}'")
Expand Down Expand Up @@ -546,7 +637,7 @@ def sql_get_file_tags(self, file):
JOIN Tags t
ON ft.tag_id = t.id
WHERE f.path = '{file}'"""
return self.db.fetch_data(query)
return [x[0] for x in self.db.fetch_data(query)]

def sql_search_inclusive(self, tags):
query = f"""SELECT DISTINCT f.path FROM Files f
Expand All @@ -558,6 +649,20 @@ def sql_search_inclusive(self, tags):
for tag in tags[1:]:
query += f" OR t.name = '{tag}'"
return [x[0] for x in self.db.fetch_data(query)]

def sql_add_face_encoding(self, name, encoding):
enc = base64.binascii.b2a_base64(encoding).decode("ascii")
self.db.execute_query(f"INSERT INTO People (name, encoding) VALUES ('{name}', '{enc}')")

def sql_get_all_encodings(self):
return [np.frombuffer(base64.binascii.a2b_base64(x[0].encode("ascii"))) for x in self.db.fetch_data("SELECT encoding FROM People")]

def sql_get_person_name(self, encoding):
enc = base64.binascii.b2a_base64(encoding).decode("ascii")
return self.db.fetch_data(f"SELECT name FROM People WHERE encoding = '{enc}'")[0][0]

def sql_get_all_people(self):
return [x[0] for x in self.db.fetch_data("SELECT name FROM People")]
# endregion

class AboutWindow(QWidget):
Expand Down

0 comments on commit 6efb6c2

Please sign in to comment.