diff --git a/app.py b/app.py index 23c72b2..d3d5bc7 100644 --- a/app.py +++ b/app.py @@ -174,7 +174,7 @@ def season_overview(season_id: int): model = { "title": f"{season.code} {season.game}", "season_info": infos, - "episodes": episodes + "episodes": episodes, } return render_template("season_overview.html", model=model) @@ -187,9 +187,7 @@ def episode_list(season_id: int): sql, args = db.load_episodes(season_id) episodes = db.query_db(sql, args, cls=models.Episode) - model = { - "season_id": season_id, - "season_code": season.code} + model = {"season_id": season_id, "season_code": season.code} return render_template("episode_list.html", model=model) @@ -220,7 +218,18 @@ def episode_edit(season_id: int, episode_id: int): sql, args = db.load_episode(episode_id) episode: models.Episode = db.query_db(sql, args, one=True, cls=models.Episode) + sql, args = db.load_episode_players(episode_id) + ep_players = db.query_db(sql, args, cls=models.Player) + form = forms.EpisodeForm() + form.season_id.data = episode.season_id + form.episode_id.data = episode.id + form.code.data = episode.code + form.date.data = episode.date + form.start.data = episode.start + form.end.data = episode.end + form.title.data = episode.title + form.players.data = [p.id for p in ep_players] model.form_title = f"Edit Episode '{episode.code}: {episode.title}'" return render_template("generic_form.html", model=model, form=form) @@ -231,9 +240,31 @@ def episode_edit(season_id: int, episode_id: int): model.errors = form.errors return render_template("generic_form.html", model=model, form=form) + errors = False episode = models.Episode.from_form(form) sql, args = db.save_episode(episode) - errors = db.update_db(sql, args) + + last_key = db.update_db(sql, args, return_key=True) + + episode_id = episode.id if episode.id else last_key + + form_ids = form.players.data + + sql, args = db.load_episode_players(episode_id) + ep_players = db.query_db(sql, args, cls=models.Player) + pids = [p.id for p in ep_players] + + new_ids = [pid for pid in form_ids if pid not in pids] + removed_ids = [pid for pid in pids if pid not in form_ids] + + if removed_ids: + sql, args = db.remove_episode_player(episode_id, removed_ids) + errors = db.update_db(sql, args) + + if new_ids: + sql, args = db.save_episode_players(episode_id, new_ids) + errors = db.update_db(sql, args) + if errors: model.errors = {"Error saving episode": [errors]} return render_template("generic_form.html", model=model, form=form) diff --git a/db.py b/db.py index e1d1ba1..14c5166 100644 --- a/db.py +++ b/db.py @@ -8,6 +8,10 @@ import models from config import Config +class DataBaseError(Exception): + """General exception class for SQL errors""" + + def connect_db(): """Create a new sqlite3 connection and register it in 'g._database'""" db = getattr(g, "_database", None) @@ -30,20 +34,34 @@ def query_db(query, args=(), one=False, cls=None): return (rv[0] if rv else None) if one else rv -def update_db(query, args=()): +def update_db(query, args=(), return_key: bool = False): """ Runs an changing query on the database Returns either False if no error has occurred, or an sqlite3 Exception + :param query: An SQL query string + :param args: Tuple for inserting into a row + :param return_key: Changes return behavior of the function: + If used function will return last row id. + Exceptions will be raised instead of returned. """ log.debug(f"Running query ({query}) with arguments ({args})") with connect_db() as con: + cur = con.cursor() + + multi_args = any(isinstance(i, tuple) for i in args) + try: - con.cursor().execute(query, args) + if multi_args: + cur.executemany(query, args) + else: + cur.execute(query, args) except sqlite3.Error as err: - return err + if not return_key: + return err + raise else: con.commit() - return False + return cur.lastrowid if return_key else False def init_db(): @@ -189,9 +207,31 @@ def load_episodes(season_id: int = None): return sql, args +def load_episode_player_links(episode_id: int): + sql = "select * from episode_player where episode_id = ?" + args = (episode_id,) + return sql, args + + +def load_episode_players(episode_id: int): + sql = "select player.* " \ + "from player " \ + "left join episode_player ep on player.id = ep.player_id " \ + "where ep.episode_id = ?" + args = (episode_id,) + return sql, args + + def save_episode_players(episode_id: int, player_ids: Sequence[int]): - sql = "insert into episode_player values (?, ?)" - args = [(episode_id, i) for i in player_ids] + sql = "insert into episode_player values (?, ?, ?)" + args = tuple((None, episode_id, i) for i in player_ids) + return sql, args + + +def remove_episode_player(episode_id: int, player_ids: Sequence[int]): + sql = "delete from episode_player " \ + "where episode_id = ? and player_id = ?" + args = tuple((episode_id, pid) for pid in player_ids) return sql, args diff --git a/models.py b/models.py index 5169e33..d2f46f1 100644 --- a/models.py +++ b/models.py @@ -114,8 +114,8 @@ class Episode: code: str def __post_init__(self): - if isinstance(self.date, Rational): - self.date = datetime.date.fromtimestamp(self.date) + if isinstance(self.date, str): + self.date = datetime.datetime.strptime(self.date, util.DATE_FMT).date() if isinstance(self.start, Rational): self.start = datetime.datetime.fromtimestamp(self.start) if isinstance(self.end, Rational): diff --git a/schema.sql b/schema.sql index f671489..97d4711 100644 --- a/schema.sql +++ b/schema.sql @@ -78,11 +78,16 @@ create unique index if not exists episode_id_uindex create table if not exists episode_player ( + link_id integer not null + constraint episode_player_pk + primary key autoincrement, episode_id integer not null - constraint episode_player_episode_id_fk - references episode, + references episode, player_id integer not null - constraint episode_player_player_id_fk - references player + references player ); +create unique index if not exists episode_player_link_id_uindex + on episode_player (link_id); + +