diff --git a/groove/db/manager.py b/groove/db/manager.py deleted file mode 100644 index 0a50029..0000000 --- a/groove/db/manager.py +++ /dev/null @@ -1,43 +0,0 @@ -import os - -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from . import metadata - - -class DatabaseManager: - """ - A context manager for working with sqllite database. - """ - - def __init__(self): - self._engine = None - self._session = None - - @property - def engine(self): - if not self._engine: - self._engine = create_engine(f"sqlite:///{os.environ.get('DATABASE_PATH')}", future=True) - return self._engine - - @property - def session(self): - if not self._session: - Session = sessionmaker(bind=self.engine, future=True) - self._session = Session() - return self._session - - def import_from_filesystem(self): - pass - - def __enter__(self): - metadata.create_all(bind=self.engine) - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self.session: - self.session.close() - - -database_manager = DatabaseManager diff --git a/groove/playlist.py b/groove/playlist.py index 6d62bdc..d7cec36 100644 --- a/groove/playlist.py +++ b/groove/playlist.py @@ -1,6 +1,9 @@ from groove import db from sqlalchemy import func, delete -from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm.session import Session +from sqlalchemy.engine.row import Row +from sqlalchemy.exc import NoResultFound, MultipleResultsFound +from typing import Union, List import logging @@ -8,43 +11,52 @@ class Playlist: """ CRUD operations and convenience methods for playlists. """ - def __init__(self, slug, connection, create_if_not_exists=False): - self._conn = connection + def __init__(self, slug: str, session: Session, create_if_not_exists: bool = False): + self._session = session self._slug = slug self._record = None self._entries = None self._create_if_not_exists = create_if_not_exists @property - def exists(self): + def exists(self) -> bool: + """ + True if the playlist exists in the database. + """ return self.record is not None @property - def slug(self): + def slug(self) -> Union[str, None]: return self._slug @property - def conn(self): - return self._conn + def session(self) -> Union[Session, None]: + return self._session @property - def record(self): + def record(self) -> Union[Row, None]: + """ + Cache the playlist row from the database and return it. Optionally create it if it doesn't exist. + """ if not self._record: try: - self._record = self.conn.query(db.playlist).filter(db.playlist.c.slug == self.slug).one() + self._record = self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one() logging.debug(f"Retrieved playlist {self._record.id}") except NoResultFound: pass if self._create_if_not_exists: self._record = self._create() - if not self._record: + if not self._record: # pragma: no cover raise RuntimeError(f"Tried to create a playlist but couldn't read it back using slug {self.slug}") return self._record @property - def entries(self): - if not self._entries: - self._entries = self.conn.query( + def entries(self) -> Union[List, None]: + """ + Cache the list of entries on this playlist and return it. + """ + if not self._entries and self.record: + self._entries = self.session.query( db.entry, db.track ).filter( @@ -59,46 +71,87 @@ class Playlist: @property def as_dict(self) -> dict: """ - Retrieve a playlist and its entries by its slug. + Return a dictionary of the playlist and its entries. """ - playlist = {} + if not self.record: + return {} playlist = dict(self.record) playlist['entries'] = [dict(entry) for entry in self.entries] return playlist - def add(self, paths) -> int: - return self._create_entries(self._get_tracks_by_path(paths)) + def add(self, paths: List[str]) -> int: + """ + Add entries to the playlist. Each path should match one and only one track in the database (case-insensitive). + If a path doesn't match any track, or if a path matches multiple tracks, nothing is added to the playlist. - def delete(self): + Args: + paths (list): A list of partial paths to add. + + Returns: + int: The number of tracks added. + """ + try: + return self._create_entries(self._get_tracks_by_path(paths)) + except NoResultFound: + logging.error("One or more of the specified paths do not match any tracks in the database.") + return 0 + except MultipleResultsFound: + logging.error("One or more of the specified paths matches multiple tracks in the database.") + return 0 + + def delete(self) -> Union[int, None]: + """ + Delete a playlist and its entries from the database, then clear the cached values. + """ + if not self.record: + return None plid = self.record.id stmt = delete(db.entry).where(db.entry.c.playlist_id == plid) logging.debug(f"Deleting entries associated with playlist {plid}: {stmt}") - self.conn.execute(stmt) + self.session.execute(stmt) stmt = delete(db.playlist).where(db.playlist.c.id == plid) logging.debug(f"Deleting playlist {plid}: {stmt}") - self.conn.execute(stmt) - self.conn.commit() + self.session.execute(stmt) + self.session.commit() + self._record = None + self._entries = None return plid - def _get_tracks_by_path(self, paths): - return [self.conn.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths] + def _get_tracks_by_path(self, paths: List[str]) -> List: + """ + Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and + MultipleResultsFound are expected in the case of no matches and multiple matches, respectively. + """ + return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths] - def _create_entries(self, tracks): + def _create_entries(self, tracks: List[Row]) -> int: + """ + Append a list of tracks to a playlist by populating the entries table with records referencing the playlist and + the specified tracks. - maxtrack = self.conn.query(func.max(db.entry.c.track)).filter_by(playlist_id=self.record.id).one()[0] - self.conn.execute( + Args: + tracks (list): A list of Row objects from the track table. + + Returns: + int: The number of tracks added. + """ + maxtrack = self.session.query(func.max(db.entry.c.track)).filter_by(playlist_id=self.record.id).one()[0] or 0 + self.session.execute( db.entry.insert(), [ {'playlist_id': self.record.id, 'track_id': obj.id, 'track': idx} for (idx, obj) in enumerate(tracks, start=maxtrack+1) ] ) - self.conn.commit() + self.session.commit() return len(tracks) - def _create(self): + def _create(self) -> Row: + """ + Insert a new playlist record into the database. + """ stmt = db.playlist.insert({'slug': self.slug}) - results = self.conn.execute(stmt) - self.conn.commit() - logging.debug(f"Created new playlist {results.inserted_primary_key} with slug {self.slug}") - return self.conn.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key).one() + results = self.session.execute(stmt) + self.session.commit() + logging.debug(f"Created new playlist {results.inserted_primary_key[0]} with slug {self.slug}") + return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one() diff --git a/groove/webserver.py b/groove/webserver.py index 3867e27..deb54a4 100644 --- a/groove/webserver.py +++ b/groove/webserver.py @@ -45,7 +45,7 @@ def get_playlist(slug, db): Retrieve a playlist and its entries by a slug. """ logging.debug(f"Looking up playlist: {slug}...") - playlist = Playlist(slug=slug, conn=db) + playlist = Playlist(slug=slug, session=db) if not playlist.exists: return HTTPResponse(status=404, body="Not found") response = json.dumps(playlist.as_dict) diff --git a/test/conftest.py b/test/conftest.py index a075d1d..5a7fa7a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -38,7 +38,8 @@ def db(in_memory_db): in_memory_db.execute(query, [ {'id': 1, 'name': 'playlist one', 'description': 'the first one', 'slug': 'playlist-one'}, {'id': 2, 'name': 'playlist two', 'description': 'the second one', 'slug': 'playlist-two'}, - {'id': 3, 'name': 'playlist three', 'description': 'the threerd one', 'slug': 'playlist-three'} + {'id': 3, 'name': 'playlist three', 'description': 'the threerd one', 'slug': 'playlist-three'}, + {'id': 4, 'name': 'empty playlist', 'description': 'no tracks', 'slug': 'empty-playlist'} ]) # populate the playlists diff --git a/test/test_playlists.py b/test/test_playlists.py index 3b047bb..f3ae1ed 100644 --- a/test/test_playlists.py +++ b/test/test_playlists.py @@ -1,12 +1,48 @@ +# 70, 73-81, 84, 88-97, 100-104 +import pytest +from groove import playlist -def test_create_playlist(): - pass +def test_create(db): + pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True) + assert pl.exists -def test_update_playlist(): - pass +@pytest.mark.parametrize('tracks', [ + ('01 Guns Blazing', ), + ('01 Guns Blazing', '02 UNKLE'), +]) +def test_add(db, tracks): + pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True) + count = pl.add(tracks) + assert count == len(tracks) -def delete_playlist(): - pass +def test_add_no_matches(db): + pl = playlist.Playlist(slug='playlist-one', session=db) + assert pl.add(('no match', )) == 0 + + +def test_add_multiple_matches(db): + pl = playlist.Playlist(slug='playlist-one', session=db) + assert pl.add('UNKLE',) == 0 + + +def test_delete(db): + pl = playlist.Playlist(slug='playlist-one', session=db) + expected = pl.record.id + assert pl.delete() == expected + assert not pl.as_dict + + +def test_delete_playlist_not_exist(db): + pl = playlist.Playlist(slug='playlist-doesnt-exist', session=db) + assert not pl.delete() + assert not pl.as_dict + + +def test_entries(db): + pl = playlist.Playlist(slug='playlist-one', session=db) + # assert twice for branch coverage of cached values + assert pl.entries + assert pl.entries diff --git a/test/test_scanner.py b/test/test_scanner.py index 227e9fa..7e45274 100644 --- a/test/test_scanner.py +++ b/test/test_scanner.py @@ -1,4 +1,5 @@ import pytest +import os from pathlib import Path from unittest.mock import MagicMock from sqlalchemy import func @@ -47,3 +48,9 @@ def test_scanner(monkeypatch, in_memory_db, media): # verify idempotency assert test_scanner.scan() == 0 + + +def test_scanner_no_media_root(in_memory_db): + del os.environ['MEDIA_ROOT'] + with pytest.raises(SystemExit): + assert scanner.media_scanner(root=None, db=in_memory_db)