cleanup, more tests

This commit is contained in:
evilchili 2022-12-02 00:21:19 -08:00
parent a0de0eef48
commit ed16ebdd0e
7 changed files with 103 additions and 57 deletions

View File

@ -18,21 +18,27 @@ class Playlist:
session: Session, session: Session,
name: str = '', name: str = '',
description: str = '', description: str = '',
create_if_not_exists: bool = False): create_ok=True):
self._session = session self._session = session
self._slug = slug self._slug = slug
self._name = name self._name = name
self._description = description self._description = description
self._record = None
self._entries = None self._entries = None
self._create_if_not_exists = create_if_not_exists self._record = None
self._create_ok = create_ok
self._deleted = False
@property
def deleted(self) -> bool:
return self._deleted
@property @property
def exists(self) -> bool: def exists(self) -> bool:
""" if self.deleted:
True if the playlist exists in the database. return False
""" if not self._record:
return self.record is not None return (self._create_ok and self.record)
return True
@property @property
def summary(self): def summary(self):
@ -43,11 +49,11 @@ class Playlist:
]) ])
@property @property
def slug(self) -> Union[str, None]: def slug(self) -> str:
return self._slug return self._slug
@property @property
def session(self) -> Union[Session, None]: def session(self) -> Session:
return self._session return self._session
@property @property
@ -56,16 +62,7 @@ class Playlist:
Cache the playlist row from the database and return it. Optionally create it if it doesn't exist. Cache the playlist row from the database and return it. Optionally create it if it doesn't exist.
""" """
if not self._record: if not self._record:
try: self._record = self.get_or_create()
self._record = self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one()
logging.debug(f"Retrieved playlist {self._record.id}")
except NoResultFound:
logging.debug(f"Could not find a playlist with slug {self.slug}.")
pass
if not self._record and self._create_if_not_exists:
self._record = self._create()
if not self._record: # pragma: no cover
raise RuntimeError(f"Tried to create a playlist but couldn't read it back using slug {self.slug}")
return self._record return self._record
@property @property
@ -86,7 +83,6 @@ class Playlist:
).order_by( ).order_by(
db.entry.c.track db.entry.c.track
) )
# self._entries = list(db.windowed_query(query, db.entry.c.track_id, 1000))
self._entries = query.all() self._entries = query.all()
return self._entries return self._entries
@ -95,7 +91,7 @@ class Playlist:
""" """
Return a dictionary of the playlist and its entries. Return a dictionary of the playlist and its entries.
""" """
if not self.record: if not self.exists:
return {} return {}
playlist = dict(self.record) playlist = dict(self.record)
playlist['entries'] = [dict(entry) for entry in self.entries] playlist['entries'] = [dict(entry) for entry in self.entries]
@ -103,11 +99,20 @@ class Playlist:
@property @property
def as_string(self) -> str: def as_string(self) -> str:
if not self.exists:
return ''
text = f"{self.summary}\n" text = f"{self.summary}\n"
for entry in self.entries: for entry in self.entries:
text += f" - {entry.track} {entry.artist} - {entry.title}\n" text += f" - {entry.track} {entry.artist} - {entry.title}\n"
return text return text
def _get_tracks_by_path(self, paths: List[str]) -> List:
"""
Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and
MultipleResultsFound are expected in the case of no matches and multiple matches, respectively.
"""
return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths]
def add(self, paths: List[str]) -> int: def add(self, paths: List[str]) -> int:
""" """
Add entries to the playlist. Each path should match one and only one track in the database (case-insensitive). Add entries to the playlist. Each path should match one and only one track in the database (case-insensitive).
@ -145,14 +150,26 @@ class Playlist:
self.session.commit() self.session.commit()
self._record = None self._record = None
self._entries = None self._entries = None
self._deleted = True
return plid return plid
def _get_tracks_by_path(self, paths: List[str]) -> List: def get_or_create(self, create_ok: bool = False) -> Row:
""" try:
Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and return self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one()
MultipleResultsFound are expected in the case of no matches and multiple matches, respectively. except NoResultFound:
""" logging.debug(f"Could not find a playlist with slug {self.slug}.")
return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths] if self.deleted:
raise RuntimeError("Object has been deleted.")
if self._create_ok or create_ok:
return self.save()
def save(self) -> Row:
keys = {'slug': self.slug, 'name': self._name, 'description': self._description}
stmt = db.playlist.update(keys) if self._record else db.playlist.insert(keys)
results = self.session.execute(stmt)
self.session.commit()
logging.debug(f"Saved playlist {results.inserted_primary_key[0]} with slug {self.slug}")
return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one()
def create_entries(self, tracks: List[Row]) -> int: def create_entries(self, tracks: List[Row]) -> int:
""" """
@ -165,7 +182,10 @@ class Playlist:
Returns: Returns:
int: The number of tracks added. int: The number of tracks added.
""" """
maxtrack = self.session.query(func.max(db.entry.c.track)).filter_by(playlist_id=self.record.id).one()[0] or 0 maxtrack = self.session.query(func.max(db.entry.c.track)).filter_by(
playlist_id=self.record.id
).one()[0] or 0
self.session.execute( self.session.execute(
db.entry.insert(), db.entry.insert(),
[ [
@ -177,16 +197,6 @@ class Playlist:
self._entries = None self._entries = None
return len(tracks) return len(tracks)
def _create(self) -> Row:
"""
Insert a new playlist record into the database.
"""
stmt = db.playlist.insert({'slug': self.slug, 'name': self._name, 'description': self._description})
results = self.session.execute(stmt)
self.session.commit()
logging.debug(f"Created new playlist {results.inserted_primary_key[0]} with slug {self.slug}")
return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one()
@classmethod @classmethod
def from_row(cls, row, session): def from_row(cls, row, session):
pl = Playlist(slug=row.slug, session=session) pl = Playlist(slug=row.slug, session=session)

View File

@ -4,4 +4,4 @@ from .help import help
from .browse import browse from .browse import browse
from .stats import stats from .stats import stats
from .playlist import _playlist from .playlist import _playlist
from .create import create from .load import load

View File

@ -1,4 +1,3 @@
from rich import print
from slugify import slugify from slugify import slugify
from groove.db.manager import database_manager from groove.db.manager import database_manager
@ -36,7 +35,7 @@ class CommandPrompt(BasePrompt):
def default_completer(self, document, complete_event): def default_completer(self, document, complete_event):
def _formatter(row): def _formatter(row):
self._playlist = Playlist.from_row(row, self.manager) self._playlist = Playlist.from_row(row, self.manager.session)
return self.playlist.record.name return self.playlist.record.name
return self.manager.fuzzy_table_completer( return self.manager.fuzzy_table_completer(
db.playlist, db.playlist,
@ -48,18 +47,17 @@ class CommandPrompt(BasePrompt):
name = cmd + ' ' + ' '.join(parts) name = cmd + ' ' + ' '.join(parts)
if cmd in self.commands: if cmd in self.commands:
self.commands[cmd].start(name) self.commands[cmd].start(name)
elif not parts:
print(f"Command not understood: {cmd}")
else: else:
slug = slugify(name) slug = slugify(name)
self._playlist = Playlist( self._playlist = Playlist(
slug=slug, slug=slug,
name=name, name=name,
session=self.manager.session, session=self.manager.session,
create_if_not_exists=False create_ok=True
) )
self.commands['_playlist'].start() res = self.commands['_playlist'].start()
self._playlist = None if res is False:
return res
return True return True

View File

@ -5,12 +5,12 @@ from slugify import slugify
from groove.playlist import Playlist from groove.playlist import Playlist
class create(BasePrompt): class load(BasePrompt):
"""Create a new playlist.""" """Create a new playlist."""
@property @property
def usage(self): def usage(self):
return "create PLAYLIST_NAME" return "load PLAYLIST_NAME"
def process(self, cmd, *parts): def process(self, cmd, *parts):
name = ' '.join(parts) name = ' '.join(parts)
@ -22,6 +22,7 @@ class create(BasePrompt):
slug=slug, slug=slug,
name=name, name=name,
session=self.manager.session, session=self.manager.session,
create_if_not_exists=True create_ok=True
) )
print(self.parent.playlist.summary)
return self.parent.commands['_playlist'].start() return self.parent.commands['_playlist'].start()

View File

@ -55,7 +55,8 @@ def get_playlist(slug, db):
Retrieve a playlist and its entries by a slug. Retrieve a playlist and its entries by a slug.
""" """
logging.debug(f"Looking up playlist: {slug}...") logging.debug(f"Looking up playlist: {slug}...")
playlist = Playlist(slug=slug, session=db) playlist = Playlist(slug=slug, session=db, create_ok=False)
print(playlist.record)
if not playlist.exists: if not playlist.exists:
return HTTPResponse(status=404, body="Not found") return HTTPResponse(status=404, body="Not found")
response = json.dumps(playlist.as_dict) response = json.dumps(playlist.as_dict)

View File

@ -4,8 +4,8 @@ from groove import playlist
def test_create(db): def test_create(db):
pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True) pl = playlist.Playlist(slug='test-create-playlist', session=db, create_ok=True)
assert pl.exists assert pl.record.id
@pytest.mark.parametrize('tracks', [ @pytest.mark.parametrize('tracks', [
@ -13,7 +13,7 @@ def test_create(db):
('01 Guns Blazing', '02 UNKLE'), ('01 Guns Blazing', '02 UNKLE'),
]) ])
def test_add(db, tracks): def test_add(db, tracks):
pl = playlist.Playlist(slug='test-create-playlist', session=db, create_if_not_exists=True) pl = playlist.Playlist(slug='test-create-playlist', session=db)
count = pl.add(tracks) count = pl.add(tracks)
assert count == len(tracks) assert count == len(tracks)
@ -32,13 +32,23 @@ def test_delete(db):
pl = playlist.Playlist(slug='playlist-one', session=db) pl = playlist.Playlist(slug='playlist-one', session=db)
expected = pl.record.id expected = pl.record.id
assert pl.delete() == expected assert pl.delete() == expected
assert not pl.as_dict assert not pl.exists
assert pl.deleted
def test_delete_playlist_not_exist(db): def test_delete_playlist_not_exist(db):
pl = playlist.Playlist(slug='playlist-doesnt-exist', session=db) pl = playlist.Playlist(slug='playlist-doesnt-exist', session=db, create_ok=False)
assert not pl.delete() assert not pl.delete()
assert not pl.as_dict assert not pl.exists
assert not pl.deleted
def test_cannot_create_after_delete(db):
pl = playlist.Playlist(slug='playlist-one', session=db)
pl.delete()
with pytest.raises(RuntimeError):
assert pl.record
assert not pl.exists
def test_entries(db): def test_entries(db):
@ -46,3 +56,18 @@ def test_entries(db):
# assert twice for branch coverage of cached values # assert twice for branch coverage of cached values
assert pl.entries assert pl.entries
assert pl.entries assert pl.entries
def test_playlist_not_exist_formatted(db):
pl = playlist.Playlist(slug='fnord', session=db, create_ok=False)
assert not repr(pl)
assert not pl.as_dict
def test_playlist_formatted(db):
pl = playlist.Playlist(slug='playlist-one', session=db)
assert repr(pl)
assert pl.as_string
assert pl.as_dict

View File

@ -13,7 +13,7 @@ def cmd_prompt(in_memory_engine, db):
def response_factory(responses): def response_factory(responses):
return MagicMock(side_effect=responses + ['']) return MagicMock(side_effect=responses + ([''] * 10))
@pytest.mark.parametrize('inputs, expected', [ @pytest.mark.parametrize('inputs, expected', [
@ -58,3 +58,14 @@ def test_help(monkeypatch, capsys, cmd_prompt, inputs, expected):
output = capsys.readouterr() output = capsys.readouterr()
for txt in expected: for txt in expected:
assert txt in output.out assert txt in output.out
@pytest.mark.parametrize('inputs, expected', [
('load A New Playlist', 'a-new-playlist'),
('new playlist', 'new-playlist'),
('load', '')
])
def test_load(monkeypatch, caplog, cmd_prompt, inputs, expected):
monkeypatch.setattr('groove.shell.base.prompt', response_factory([inputs]))
cmd_prompt.start()
assert expected in caplog.text