diff --git a/app.py b/app.py index e21095d..c501d3d 100644 --- a/app.py +++ b/app.py @@ -1,4 +1,4 @@ -import sys, os, pathlib, sqlite3, face_recognition, threading +import sys, os, pathlib, sqlite3, face_recognition, threading, random, numpy as np, base64 from PyQt6.QtCore import Qt, QSize, QDir, QSize, QUrl, QStringListModel, QItemSelectionModel, QSortFilterProxyModel, QModelIndex from PyQt6.QtGui import QFileSystemModel, QPixmap, QMovie, QIcon, QAction, QImage, QStandardItemModel, QStandardItem from PyQt6.QtMultimedia import QMediaPlayer @@ -23,7 +23,8 @@ QSplitter, QAbstractItemView, QListView, - QMenu + QMenu, + QListWidgetItem ) @@ -56,6 +57,11 @@ def close_connection(self): self.cursor.close() self.conn.close() +class QDeselectableListWidget(QListWidget): + def mousePressEvent(self, event): + self.clearSelection() + QListWidget.mousePressEvent(self, event) + class QDeselectableTreeView(QTreeView): # def __init__(self, parent=None): # super().__init__(parent) @@ -82,7 +88,8 @@ def __init__(self, parent=None): super().__init__(parent) self.setSelectionMode(QListView.SelectionMode.SingleSelection) - + self.setEditTriggers(QAbstractItemView.EditTrigger.NoEditTriggers) + self.model = QStandardItemModel() self.setModel(self.model) @@ -302,6 +309,9 @@ def __init__(self): # region Tags View self.dock_tags = QDockWidget("Tags") self.tags = QListWidget() + self.tags.setEditTriggers(QAbstractItemView.EditTrigger.DoubleClicked) + self.tags.itemDoubleClicked.connect(self.tag_double_clicked) + self.tags.itemChanged.connect(self.tag_update) self.tag_bar = QLineEdit() self.tag_bar.setPlaceholderText("Press Ctrl+T to focus") btn_add_tag = QPushButton("Add Tag") @@ -329,12 +339,29 @@ def __init__(self): self.dock_tags.setWidget(tag_view) # endregion + # region People View + self.dock_people = QDockWidget("People") + self.people = QListWidget() + self.people.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + # self.people.selectionModel().selectionChanged.connect(self.people_selected) + self.people.itemSelectionChanged.connect(self.people_selected) + people_view = QWidget() + peopleLayout = QVBoxLayout() + people_view.setLayout(peopleLayout) + peopleLayout.addWidget(self.people) + peopleLayout.setContentsMargins(0, 0, 0, 0) + peopleLayout.setSpacing(0) + self.dock_people.setWidget(people_view) + self.people.addItems(self.sql_get_all_people()) + # endregion + # region Docking self.addDockWidget(Qt.DockWidgetArea.TopDockWidgetArea, self.dock_search) self.addDockWidget(Qt.DockWidgetArea.BottomDockWidgetArea, self.dock_folders) self.splitDockWidget(self.dock_folders, self.dock_browser, Qt.Orientation.Horizontal) self.splitDockWidget(self.dock_browser, self.dock_media, Qt.Orientation.Horizontal) self.splitDockWidget(self.dock_media, self.dock_tags, Qt.Orientation.Vertical) + self.splitDockWidget(self.dock_folders, self.dock_people, Qt.Orientation.Vertical) self.dock_folders.setMinimumWidth(150) self.dock_browser.setMinimumWidth(300) # self.dock_media.setMinimumWidth(200) @@ -342,10 +369,38 @@ def __init__(self): # region Setup self.folder_selected() + # self.people_selected() + self.editing = False + self.old_tag = "" self.show() # endregion # region Mechanics + def tag_double_clicked(self): + self.editing = True + self.old_tag = self.tags.selectedItems()[0].text() + + def tag_update(self): + if self.editing: + self.editing = False + new_tag = self.tags.selectedItems()[0].text() + self.status.showMessage(f"Tag Updated from {self.old_tag} to {new_tag}") + self.sql_edit_tag(self.old_tag, new_tag) + if self.old_tag in [self.people.item(x).text() for x in range(self.people.count())]: + self.people.clear() + self.people.addItems(self.sql_get_all_people()) + + def people_selected(self): + self.browser_detailView.clear() + if len(self.people.selectedIndexes()) == 0: + self.browser_detailView.add_strings(self.sql_get_all_files()) + self.update_title("Browser") + else: + person = self.people.currentItem().text() + self.status.showMessage(f"Searching for person - {person}") + self.browser_detailView.add_strings(self.sql_search_inclusive([person])) + self.update_title(person) + def folder_context_menu(self, pos): menu = QMenu() action1 = QAction("Process Folder", self) @@ -422,8 +477,13 @@ def add_folder(self): self.status.showMessage(f"Added - {path}") def add_tag_to_file(self, tag, file): + if len(self.tags.findItems(tag, Qt.MatchFlag.MatchExactly)) > 0: + self.status.showMessage(f"Tag already exists on media file {file} - {tag}") + return self.sql_add_tag_to_file(tag, file) - self.tags.addItem(tag) + item = QListWidgetItem(tag) + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) + self.tags.addItem(item) self.status.showMessage(f"Added tag - {tag} to file - {file}") def search(self, tags): @@ -446,11 +506,33 @@ def search(self, tags): # print(face_locations) # self.status.showMessage("AI Processing - Finished") + def generate_random_person_name(self): + return f'person_{random.randint(10000, 99999)}' + + def add_person(self, encoding, file): + unique_people = self.sql_get_all_encodings() + result = face_recognition.compare_faces(unique_people, encoding) + check = np.count_nonzero(result) + if check == 0: + name = self.generate_random_person_name() + self.sql_add_face_encoding(name, encoding) + self.people.addItem(name) + self.new_people += 1 + elif check == 1: + name = self.sql_get_person_name(unique_people[result.index(True)]) + elif check > 1: + distances = face_recognition.face_distance(unique_people, encoding) + name = self.sql_get_person_name(unique_people[np.argmin(distances)]) + self.sql_add_tag_to_file(name, file) + self.tags_applied += 1 + def ai_process_folder_start(self): thread = threading.Thread(target=self.ai_process_folder) thread.start() def ai_process_folder(self): + self.tags_applied = 0 + self.new_people = 0 for index in self.folders.selectedIndexes(): sym_path = self.fileSystem.filePath(index) org_path = self.sql_get_files_in_folder(sym_path) @@ -460,8 +542,11 @@ def ai_process_folder(self): if self.filetype_switcher[extension].__name__ != self.display_static.__name__: continue image = face_recognition.load_image_file(file) - print(face_recognition.face_encodings(image, num_jitters=5, model="large")) - self.status.showMessage(f"AI Processing - {org_path} - Finished") + face_encodings = face_recognition.face_encodings(image, num_jitters=5, model="large") + if len(face_encodings) > 0: + for encoding in face_encodings: + self.add_person(encoding, file) + self.status.showMessage(f"AI Processing - {org_path} - Finished - Applied {self.tags_applied} tags - Detected {self.new_people} new people") def show_about(self): self.about = AboutWindow() @@ -474,8 +559,13 @@ def display_media(self): _, extension = os.path.splitext(file_path) if extension in self.filetype_switcher: self.filetype_switcher[extension](file_path) + self.dock_media.setWindowTitle(f'Media - {file_path}') self.tags.clear() - self.tags.addItems([x[0] for x in self.sql_get_file_tags(file_path)]) + tags = self.sql_get_file_tags(file_path) + for tag in tags: + item = QListWidgetItem(tag, self.tags) + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsEditable) + self.tags.addItem(item) else: self.status.showMessage(f"Unsupported file type - {extension}") @@ -507,6 +597,7 @@ def sql_add_new_tag(self, tag): def sql_edit_tag(self, tag, new_tag): self.db.execute_query(f"UPDATE Tags SET name = '{new_tag}' WHERE name = '{tag}'") + self.db.execute_query(f"UPDATE People SET name = '{new_tag}' WHERE name = '{tag}'") def sql_delete_tag(self, tag): self.db.execute_query(f"DELETE FROM Tags WHERE name = '{tag}'") @@ -546,7 +637,7 @@ def sql_get_file_tags(self, file): JOIN Tags t ON ft.tag_id = t.id WHERE f.path = '{file}'""" - return self.db.fetch_data(query) + return [x[0] for x in self.db.fetch_data(query)] def sql_search_inclusive(self, tags): query = f"""SELECT DISTINCT f.path FROM Files f @@ -558,6 +649,20 @@ def sql_search_inclusive(self, tags): for tag in tags[1:]: query += f" OR t.name = '{tag}'" return [x[0] for x in self.db.fetch_data(query)] + + def sql_add_face_encoding(self, name, encoding): + enc = base64.binascii.b2a_base64(encoding).decode("ascii") + self.db.execute_query(f"INSERT INTO People (name, encoding) VALUES ('{name}', '{enc}')") + + def sql_get_all_encodings(self): + return [np.frombuffer(base64.binascii.a2b_base64(x[0].encode("ascii"))) for x in self.db.fetch_data("SELECT encoding FROM People")] + + def sql_get_person_name(self, encoding): + enc = base64.binascii.b2a_base64(encoding).decode("ascii") + return self.db.fetch_data(f"SELECT name FROM People WHERE encoding = '{enc}'")[0][0] + + def sql_get_all_people(self): + return [x[0] for x in self.db.fetch_data("SELECT name FROM People")] # endregion class AboutWindow(QWidget):