updated CRUD tests, added type hinting.

This commit is contained in:
evilchili 2022-11-25 12:20:43 -08:00
parent 5f8ab8fe25
commit 267af75cb4
6 changed files with 137 additions and 83 deletions

View File

@ -1,43 +0,0 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from . import metadata
class DatabaseManager:
"""
A context manager for working with sqllite database.
"""
def __init__(self):
self._engine = None
self._session = None
@property
def engine(self):
if not self._engine:
self._engine = create_engine(f"sqlite:///{os.environ.get('DATABASE_PATH')}", future=True)
return self._engine
@property
def session(self):
if not self._session:
Session = sessionmaker(bind=self.engine, future=True)
self._session = Session()
return self._session
def import_from_filesystem(self):
pass
def __enter__(self):
metadata.create_all(bind=self.engine)
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.session:
self.session.close()
database_manager = DatabaseManager

View File

@ -1,6 +1,9 @@
from groove import db
from sqlalchemy import func, delete
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm.session import Session
from sqlalchemy.engine.row import Row
from sqlalchemy.exc import NoResultFound, MultipleResultsFound
from typing import Union, List
import logging
@ -8,43 +11,52 @@ class Playlist:
"""
CRUD operations and convenience methods for playlists.
"""
def __init__(self, slug, connection, create_if_not_exists=False):
self._conn = connection
def __init__(self, slug: str, session: Session, create_if_not_exists: bool = False):
self._session = session
self._slug = slug
self._record = None
self._entries = None
self._create_if_not_exists = create_if_not_exists
@property
def exists(self):
def exists(self) -> bool:
"""
True if the playlist exists in the database.
"""
return self.record is not None
@property
def slug(self):
def slug(self) -> Union[str, None]:
return self._slug
@property
def conn(self):
return self._conn
def session(self) -> Union[Session, None]:
return self._session
@property
def record(self):
def record(self) -> Union[Row, None]:
"""
Cache the playlist row from the database and return it. Optionally create it if it doesn't exist.
"""
if not self._record:
try:
self._record = self.conn.query(db.playlist).filter(db.playlist.c.slug == self.slug).one()
self._record = self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one()
logging.debug(f"Retrieved playlist {self._record.id}")
except NoResultFound:
pass
if self._create_if_not_exists:
self._record = self._create()
if not self._record:
if not self._record: # pragma: no cover
raise RuntimeError(f"Tried to create a playlist but couldn't read it back using slug {self.slug}")
return self._record
@property
def entries(self):
if not self._entries:
self._entries = self.conn.query(
def entries(self) -> Union[List, None]:
"""
Cache the list of entries on this playlist and return it.
"""
if not self._entries and self.record:
self._entries = self.session.query(
db.entry,
db.track
).filter(
@ -59,46 +71,87 @@ class Playlist:
@property
def as_dict(self) -> dict:
"""
Retrieve a playlist and its entries by its slug.
Return a dictionary of the playlist and its entries.
"""
playlist = {}
if not self.record:
return {}
playlist = dict(self.record)
playlist['entries'] = [dict(entry) for entry in self.entries]
return playlist
def add(self, paths) -> int:
return self._create_entries(self._get_tracks_by_path(paths))
def add(self, paths: List[str]) -> int:
"""
Add entries to the playlist. Each path should match one and only one track in the database (case-insensitive).
If a path doesn't match any track, or if a path matches multiple tracks, nothing is added to the playlist.
def delete(self):
Args:
paths (list): A list of partial paths to add.
Returns:
int: The number of tracks added.
"""
try:
return self._create_entries(self._get_tracks_by_path(paths))
except NoResultFound:
logging.error("One or more of the specified paths do not match any tracks in the database.")
return 0
except MultipleResultsFound:
logging.error("One or more of the specified paths matches multiple tracks in the database.")
return 0
def delete(self) -> Union[int, None]:
"""
Delete a playlist and its entries from the database, then clear the cached values.
"""
if not self.record:
return None
plid = self.record.id
stmt = delete(db.entry).where(db.entry.c.playlist_id == plid)
logging.debug(f"Deleting entries associated with playlist {plid}: {stmt}")
self.conn.execute(stmt)
self.session.execute(stmt)
stmt = delete(db.playlist).where(db.playlist.c.id == plid)
logging.debug(f"Deleting playlist {plid}: {stmt}")
self.conn.execute(stmt)
self.conn.commit()
self.session.execute(stmt)
self.session.commit()
self._record = None
self._entries = None
return plid
def _get_tracks_by_path(self, paths):
return [self.conn.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths]
def _get_tracks_by_path(self, paths: List[str]) -> List:
"""
Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and
MultipleResultsFound are expected in the case of no matches and multiple matches, respectively.
"""
return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths]
def _create_entries(self, tracks):
def _create_entries(self, tracks: List[Row]) -> int:
"""
Append a list of tracks to a playlist by populating the entries table with records referencing the playlist and
the specified tracks.
maxtrack = self.conn.query(func.max(db.entry.c.track)).filter_by(playlist_id=self.record.id).one()[0]
self.conn.execute(
Args:
tracks (list): A list of Row objects from the track table.
Returns:
int: The number of tracks added.
"""
maxtrack = self.session.query(func.max(db.entry.c.track)).filter_by(playlist_id=self.record.id).one()[0] or 0
self.session.execute(
db.entry.insert(),
[
{'playlist_id': self.record.id, 'track_id': obj.id, 'track': idx}
for (idx, obj) in enumerate(tracks, start=maxtrack+1)
]
)
self.conn.commit()
self.session.commit()
return len(tracks)
def _create(self):
def _create(self) -> Row:
"""
Insert a new playlist record into the database.
"""
stmt = db.playlist.insert({'slug': self.slug})
results = self.conn.execute(stmt)
self.conn.commit()
logging.debug(f"Created new playlist {results.inserted_primary_key} with slug {self.slug}")
return self.conn.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key).one()
results = self.session.execute(stmt)
self.session.commit()
logging.debug(f"Created new playlist {results.inserted_primary_key[0]} with slug {self.slug}")
return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one()

View File

@ -45,7 +45,7 @@ def get_playlist(slug, db):
Retrieve a playlist and its entries by a slug.
"""
logging.debug(f"Looking up playlist: {slug}...")
playlist = Playlist(slug=slug, conn=db)
playlist = Playlist(slug=slug, session=db)
if not playlist.exists:
return HTTPResponse(status=404, body="Not found")
response = json.dumps(playlist.as_dict)

View File

@ -38,7 +38,8 @@ def db(in_memory_db):
in_memory_db.execute(query, [
{'id': 1, 'name': 'playlist one', 'description': 'the first one', 'slug': 'playlist-one'},
{'id': 2, 'name': 'playlist two', 'description': 'the second one', 'slug': 'playlist-two'},
{'id': 3, 'name': 'playlist three', 'description': 'the threerd one', 'slug': 'playlist-three'}
{'id': 3, 'name': 'playlist three', 'description': 'the threerd one', 'slug': 'playlist-three'},
{'id': 4, 'name': 'empty playlist', 'description': 'no tracks', 'slug': 'empty-playlist'}
])
# populate the playlists

View File

@ -1,12 +1,48 @@
# 70, 73-81, 84, 88-97, 100-104
import pytest
from groove import playlist
def test_create_playlist():
pass
def test_create(db):
pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True)
assert pl.exists
def test_update_playlist():
pass
@pytest.mark.parametrize('tracks', [
('01 Guns Blazing', ),
('01 Guns Blazing', '02 UNKLE'),
])
def test_add(db, tracks):
pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True)
count = pl.add(tracks)
assert count == len(tracks)
def delete_playlist():
pass
def test_add_no_matches(db):
pl = playlist.Playlist(slug='playlist-one', session=db)
assert pl.add(('no match', )) == 0
def test_add_multiple_matches(db):
pl = playlist.Playlist(slug='playlist-one', session=db)
assert pl.add('UNKLE',) == 0
def test_delete(db):
pl = playlist.Playlist(slug='playlist-one', session=db)
expected = pl.record.id
assert pl.delete() == expected
assert not pl.as_dict
def test_delete_playlist_not_exist(db):
pl = playlist.Playlist(slug='playlist-doesnt-exist', session=db)
assert not pl.delete()
assert not pl.as_dict
def test_entries(db):
pl = playlist.Playlist(slug='playlist-one', session=db)
# assert twice for branch coverage of cached values
assert pl.entries
assert pl.entries

View File

@ -1,4 +1,5 @@
import pytest
import os
from pathlib import Path
from unittest.mock import MagicMock
from sqlalchemy import func
@ -47,3 +48,9 @@ def test_scanner(monkeypatch, in_memory_db, media):
# verify idempotency
assert test_scanner.scan() == 0
def test_scanner_no_media_root(in_memory_db):
del os.environ['MEDIA_ROOT']
with pytest.raises(SystemExit):
assert scanner.media_scanner(root=None, db=in_memory_db)