diff --git a/cardvault/application.py b/cardvault/application.py index d37f664..4ce60ab 100644 --- a/cardvault/application.py +++ b/cardvault/application.py @@ -31,7 +31,6 @@ class Application: self.ui.add_from_file(util.get_ui_filename("dialogs.glade")) self.current_page = None - self.unsaved_changes = False self.current_lib_tag = "Untagged" self.db = database.CardVaultDB(util.get_root_filename(util.DB_NAME)) @@ -200,21 +199,22 @@ class Application: else: return value + def unsaved_changes(self) -> bool: + """Check if database is in transaction""" + return self.db.db_unsaved_changes() + def save_config(self): cf = util.get_root_filename("config.json") util.save_config(self.config, cf) util.log("Config saved to '{}'".format(cf), util.LogLevel.Info) def save_data(self): - # util.log("Saving Data to database", util.LogLevel.Info) - # start = time.time() - # self.db.save_library(self.library) - # self.db.save_tags(self.tags) - # self.db.save_wants(self.wants) - # end = time.time() - # util.log("Finished in {}s".format(str(round(end - start, 3))), util.LogLevel.Info) - # self.unsaved_changes = False - # self.push_status("All data saved.") + util.log("Saving Data to database", util.LogLevel.Info) + start = time.time() + self.db.db_save_changes() + end = time.time() + util.log("Finished in {}s".format(str(round(end - start, 3))), util.LogLevel.Info) + self.push_status("All data saved.") pass def load_user_data(self): @@ -287,7 +287,6 @@ class Application: del self.tags[old] self.db.tag_rename(old, new) util.log("Tag '" + old + "' renamed to '" + new + "'", util.LogLevel.Info) - self.unsaved_changes = True def get_wanted_card_ids(self) -> List[str]: all_ids = [] diff --git a/cardvault/database.py b/cardvault/database.py index bda845a..7e972fd 100644 --- a/cardvault/database.py +++ b/cardvault/database.py @@ -9,6 +9,7 @@ class CardVaultDB: """Data access class for sqlite3""" def __init__(self, db_file: str): self.db_file = db_file + self.connection = sqlite3.connect(self.db_file) # Database operations ############################################################################################## @@ -37,7 +38,8 @@ class CardVaultDB: "booster TEXT, oldcode TEXT)") def db_card_insert(self, card: Card): - # Connect to database + """Insert single card data into database""" + # Use own connection so that inserts are commited directly con = sqlite3.connect(self.db_file) try: with con: @@ -54,13 +56,12 @@ class CardVaultDB: pass def db_get_all(self): + """Return data of all cards in database""" sql = 'SELECT * FROM cards' - con = sqlite3.connect(self.db_file) - cur = con.cursor() + cur = self.connection.cursor() cur.row_factory = sqlite3.Row cur.execute(sql) rows = cur.fetchall() - con.close() output = [] for row in rows: card = self.table_to_card_mapping(row) @@ -85,7 +86,7 @@ class CardVaultDB: set = Set(data) s_rows.append(self.set_to_table_mapping(set)) - # Connect to database + # Use separate connection to commit changes immediately con = sqlite3.connect(self.db_file) try: with con: @@ -117,12 +118,10 @@ class CardVaultDB: def lib_get_all(self) -> dict: """Load library from database""" - con = sqlite3.connect(self.db_file) - cur = con.cursor() + cur = self.connection.cursor() cur.row_factory = sqlite3.Row cur.execute('SELECT * FROM `library` INNER JOIN `cards` ON library.multiverseid = cards.multiverseid') rows = cur.fetchall() - con.close() return self.rows_to_card_dict(rows) @@ -138,8 +137,7 @@ class CardVaultDB: def tag_get_all(self) -> dict: """Loads a dict from database with all tags and the card ids tagged""" - con = sqlite3.connect(self.db_file) - cur = con.cursor() + cur = self.connection.cursor() cur.row_factory = sqlite3.Row # First load all tags @@ -179,8 +177,7 @@ class CardVaultDB: def tag_card_check_tagged(self, card) -> tuple: """Check if a card is tagged. Return True/False and a list of tags.""" - con = sqlite3.connect(self.db_file) - cur = con.cursor() + cur = self.connection.cursor() cur.row_factory = sqlite3.Row cur.execute('SELECT `tag` FROM `tags` WHERE tags.multiverseid = ? ', (card.multiverse_id,)) rows = cur.fetchall() @@ -197,8 +194,7 @@ class CardVaultDB: def wants_get_all(self) -> dict: """Load all wants lists from database""" - con = sqlite3.connect(self.db_file) - cur = con.cursor() + cur = self.connection.cursor() cur.row_factory = sqlite3.Row # First load all lists @@ -294,7 +290,6 @@ class CardVaultDB: cur = con.cursor() cur.row_factory = sqlite3.Row - # First load all tags cur.execute("SELECT * FROM sets") rows = cur.fetchall() sets = [] @@ -313,16 +308,27 @@ class CardVaultDB: output[card.multiverse_id] = card return output - def db_operation(self, sql: str, parms: tuple=()): + def db_operation(self, sql: str, args: tuple=()): """Perform an arbitrary sql operation on the database""" - con = sqlite3.connect(self.db_file) + cur = self.connection.cursor() try: - with con: - con.execute(sql, parms) + cur.execute(sql, args) except sqlite3.OperationalError as err: util.log("Database Error", util.LogLevel.Error) util.log(str(err), util.LogLevel.Error) + def db_save_changes(self): + try: + self.connection.commit() + except sqlite3.Error as err: + self.connection.rollback() + util.log("Database Error", util.LogLevel.Error) + util.log(str(err), util.LogLevel.Error) + + def db_unsaved_changes(self) -> bool: + """Checks if database is currently in transaction""" + return self.connection.in_transaction + @staticmethod def filter_colors_list(mana: list) -> str: symbols = util.unique_list(mana) diff --git a/cardvault/handlers.py b/cardvault/handlers.py index 23c561f..6475635 100644 --- a/cardvault/handlers.py +++ b/cardvault/handlers.py @@ -138,10 +138,10 @@ class Handlers(SearchHandlers, LibraryHandlers, WantsHandlers): self.app.ui.get_object("mainWindow").set_title(app_title) def do_delete_event(self, arg1, arg2): - if self.app.unsaved_changes: + if self.app.unsaved_changes(): response = self.app.show_dialog_ync("Unsaved Changes", - "You have unsaved changes in your library. " - "Save before exiting?") + "You have unsaved changes in your library. " + "Save before exiting?") if response == Gtk.ResponseType.YES: self.app.save_data() return False @@ -177,7 +177,7 @@ class Handlers(SearchHandlers, LibraryHandlers, WantsHandlers): def do_download_card_data(self, item: Gtk.MenuItem): """Download button was pressed in the menu bar. Starts a thread to load data from the internet""" - info_string = "Start downloading card information from the internet?\n" \ + info_string = "Start downloading card information from the internet?\n" \ "You can cancel the download at any point." response = self.app.show_dialog_yn("Download Card Data", info_string) if response == Gtk.ResponseType.NO: