From ed16ebdd0e328c707b6d925386e48fbe86a27e68 Mon Sep 17 00:00:00 2001 From: evilchili Date: Fri, 2 Dec 2022 00:21:19 -0800 Subject: [PATCH] cleanup, more tests --- groove/playlist.py | 86 ++++++++++++++++------------- groove/shell/__init__.py | 2 +- groove/shell/interactive_shell.py | 12 ++-- groove/shell/{create.py => load.py} | 7 ++- groove/webserver.py | 3 +- test/test_playlists.py | 37 +++++++++++-- test/test_shell.py | 13 ++++- 7 files changed, 103 insertions(+), 57 deletions(-) rename groove/shell/{create.py => load.py} (80%) diff --git a/groove/playlist.py b/groove/playlist.py index a3e03a3..84ab4ff 100644 --- a/groove/playlist.py +++ b/groove/playlist.py @@ -18,21 +18,27 @@ class Playlist: session: Session, name: str = '', description: str = '', - create_if_not_exists: bool = False): + create_ok=True): self._session = session self._slug = slug self._name = name self._description = description - self._record = None self._entries = None - self._create_if_not_exists = create_if_not_exists + self._record = None + self._create_ok = create_ok + self._deleted = False + + @property + def deleted(self) -> bool: + return self._deleted @property def exists(self) -> bool: - """ - True if the playlist exists in the database. - """ - return self.record is not None + if self.deleted: + return False + if not self._record: + return (self._create_ok and self.record) + return True @property def summary(self): @@ -43,11 +49,11 @@ class Playlist: ]) @property - def slug(self) -> Union[str, None]: + def slug(self) -> str: return self._slug @property - def session(self) -> Union[Session, None]: + def session(self) -> Session: return self._session @property @@ -56,16 +62,7 @@ class Playlist: Cache the playlist row from the database and return it. Optionally create it if it doesn't exist. """ if not self._record: - try: - self._record = self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one() - logging.debug(f"Retrieved playlist {self._record.id}") - except NoResultFound: - logging.debug(f"Could not find a playlist with slug {self.slug}.") - pass - if not self._record and self._create_if_not_exists: - self._record = self._create() - if not self._record: # pragma: no cover - raise RuntimeError(f"Tried to create a playlist but couldn't read it back using slug {self.slug}") + self._record = self.get_or_create() return self._record @property @@ -86,7 +83,6 @@ class Playlist: ).order_by( db.entry.c.track ) - # self._entries = list(db.windowed_query(query, db.entry.c.track_id, 1000)) self._entries = query.all() return self._entries @@ -95,7 +91,7 @@ class Playlist: """ Return a dictionary of the playlist and its entries. """ - if not self.record: + if not self.exists: return {} playlist = dict(self.record) playlist['entries'] = [dict(entry) for entry in self.entries] @@ -103,11 +99,20 @@ class Playlist: @property def as_string(self) -> str: + if not self.exists: + return '' text = f"{self.summary}\n" for entry in self.entries: text += f" - {entry.track} {entry.artist} - {entry.title}\n" return text + def _get_tracks_by_path(self, paths: List[str]) -> List: + """ + Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and + MultipleResultsFound are expected in the case of no matches and multiple matches, respectively. + """ + return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths] + def add(self, paths: List[str]) -> int: """ Add entries to the playlist. Each path should match one and only one track in the database (case-insensitive). @@ -145,14 +150,26 @@ class Playlist: self.session.commit() self._record = None self._entries = None + self._deleted = True return plid - def _get_tracks_by_path(self, paths: List[str]) -> List: - """ - Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and - MultipleResultsFound are expected in the case of no matches and multiple matches, respectively. - """ - return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths] + def get_or_create(self, create_ok: bool = False) -> Row: + try: + return self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one() + except NoResultFound: + logging.debug(f"Could not find a playlist with slug {self.slug}.") + if self.deleted: + raise RuntimeError("Object has been deleted.") + if self._create_ok or create_ok: + return self.save() + + def save(self) -> Row: + keys = {'slug': self.slug, 'name': self._name, 'description': self._description} + stmt = db.playlist.update(keys) if self._record else db.playlist.insert(keys) + results = self.session.execute(stmt) + self.session.commit() + logging.debug(f"Saved playlist {results.inserted_primary_key[0]} with slug {self.slug}") + return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one() def create_entries(self, tracks: List[Row]) -> int: """ @@ -165,7 +182,10 @@ class Playlist: Returns: int: The number of tracks added. """ - maxtrack = self.session.query(func.max(db.entry.c.track)).filter_by(playlist_id=self.record.id).one()[0] or 0 + maxtrack = self.session.query(func.max(db.entry.c.track)).filter_by( + playlist_id=self.record.id + ).one()[0] or 0 + self.session.execute( db.entry.insert(), [ @@ -177,16 +197,6 @@ class Playlist: self._entries = None return len(tracks) - def _create(self) -> Row: - """ - Insert a new playlist record into the database. - """ - stmt = db.playlist.insert({'slug': self.slug, 'name': self._name, 'description': self._description}) - results = self.session.execute(stmt) - self.session.commit() - logging.debug(f"Created new playlist {results.inserted_primary_key[0]} with slug {self.slug}") - return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one() - @classmethod def from_row(cls, row, session): pl = Playlist(slug=row.slug, session=session) diff --git a/groove/shell/__init__.py b/groove/shell/__init__.py index ef6e387..5542b20 100644 --- a/groove/shell/__init__.py +++ b/groove/shell/__init__.py @@ -4,4 +4,4 @@ from .help import help from .browse import browse from .stats import stats from .playlist import _playlist -from .create import create +from .load import load diff --git a/groove/shell/interactive_shell.py b/groove/shell/interactive_shell.py index 0181b95..e10103c 100644 --- a/groove/shell/interactive_shell.py +++ b/groove/shell/interactive_shell.py @@ -1,4 +1,3 @@ -from rich import print from slugify import slugify from groove.db.manager import database_manager @@ -36,7 +35,7 @@ class CommandPrompt(BasePrompt): def default_completer(self, document, complete_event): def _formatter(row): - self._playlist = Playlist.from_row(row, self.manager) + self._playlist = Playlist.from_row(row, self.manager.session) return self.playlist.record.name return self.manager.fuzzy_table_completer( db.playlist, @@ -48,18 +47,17 @@ class CommandPrompt(BasePrompt): name = cmd + ' ' + ' '.join(parts) if cmd in self.commands: self.commands[cmd].start(name) - elif not parts: - print(f"Command not understood: {cmd}") else: slug = slugify(name) self._playlist = Playlist( slug=slug, name=name, session=self.manager.session, - create_if_not_exists=False + create_ok=True ) - self.commands['_playlist'].start() - self._playlist = None + res = self.commands['_playlist'].start() + if res is False: + return res return True diff --git a/groove/shell/create.py b/groove/shell/load.py similarity index 80% rename from groove/shell/create.py rename to groove/shell/load.py index bb081ca..b749254 100644 --- a/groove/shell/create.py +++ b/groove/shell/load.py @@ -5,12 +5,12 @@ from slugify import slugify from groove.playlist import Playlist -class create(BasePrompt): +class load(BasePrompt): """Create a new playlist.""" @property def usage(self): - return "create PLAYLIST_NAME" + return "load PLAYLIST_NAME" def process(self, cmd, *parts): name = ' '.join(parts) @@ -22,6 +22,7 @@ class create(BasePrompt): slug=slug, name=name, session=self.manager.session, - create_if_not_exists=True + create_ok=True ) + print(self.parent.playlist.summary) return self.parent.commands['_playlist'].start() diff --git a/groove/webserver.py b/groove/webserver.py index d26f50c..7fbc197 100644 --- a/groove/webserver.py +++ b/groove/webserver.py @@ -55,7 +55,8 @@ def get_playlist(slug, db): Retrieve a playlist and its entries by a slug. """ logging.debug(f"Looking up playlist: {slug}...") - playlist = Playlist(slug=slug, session=db) + playlist = Playlist(slug=slug, session=db, create_ok=False) + print(playlist.record) if not playlist.exists: return HTTPResponse(status=404, body="Not found") response = json.dumps(playlist.as_dict) diff --git a/test/test_playlists.py b/test/test_playlists.py index f3ae1ed..e7b2cf2 100644 --- a/test/test_playlists.py +++ b/test/test_playlists.py @@ -4,8 +4,8 @@ from groove import playlist def test_create(db): - pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True) - assert pl.exists + pl = playlist.Playlist(slug='test-create-playlist', session=db, create_ok=True) + assert pl.record.id @pytest.mark.parametrize('tracks', [ @@ -13,7 +13,7 @@ def test_create(db): ('01 Guns Blazing', '02 UNKLE'), ]) def test_add(db, tracks): - pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True) + pl = playlist.Playlist(slug='test-create-playlist', session=db) count = pl.add(tracks) assert count == len(tracks) @@ -32,13 +32,23 @@ def test_delete(db): pl = playlist.Playlist(slug='playlist-one', session=db) expected = pl.record.id assert pl.delete() == expected - assert not pl.as_dict + assert not pl.exists + assert pl.deleted def test_delete_playlist_not_exist(db): - pl = playlist.Playlist(slug='playlist-doesnt-exist', session=db) + pl = playlist.Playlist(slug='playlist-doesnt-exist', session=db, create_ok=False) assert not pl.delete() - assert not pl.as_dict + assert not pl.exists + assert not pl.deleted + + +def test_cannot_create_after_delete(db): + pl = playlist.Playlist(slug='playlist-one', session=db) + pl.delete() + with pytest.raises(RuntimeError): + assert pl.record + assert not pl.exists def test_entries(db): @@ -46,3 +56,18 @@ def test_entries(db): # assert twice for branch coverage of cached values assert pl.entries assert pl.entries + + +def test_playlist_not_exist_formatted(db): + pl = playlist.Playlist(slug='fnord', session=db, create_ok=False) + assert not repr(pl) + assert not pl.as_dict + + +def test_playlist_formatted(db): + pl = playlist.Playlist(slug='playlist-one', session=db) + assert repr(pl) + assert pl.as_string + assert pl.as_dict + + diff --git a/test/test_shell.py b/test/test_shell.py index d808a96..7987cca 100644 --- a/test/test_shell.py +++ b/test/test_shell.py @@ -13,7 +13,7 @@ def cmd_prompt(in_memory_engine, db): def response_factory(responses): - return MagicMock(side_effect=responses + ['']) + return MagicMock(side_effect=responses + ([''] * 10)) @pytest.mark.parametrize('inputs, expected', [ @@ -58,3 +58,14 @@ def test_help(monkeypatch, capsys, cmd_prompt, inputs, expected): output = capsys.readouterr() for txt in expected: assert txt in output.out + + +@pytest.mark.parametrize('inputs, expected', [ + ('load A New Playlist', 'a-new-playlist'), + ('new playlist', 'new-playlist'), + ('load', '') +]) +def test_load(monkeypatch, caplog, cmd_prompt, inputs, expected): + monkeypatch.setattr('groove.shell.base.prompt', response_factory([inputs])) + cmd_prompt.start() + assert expected in caplog.text