diff --git a/groove/editor.py b/groove/editor.py new file mode 100644 index 0000000..93788ce --- /dev/null +++ b/groove/editor.py @@ -0,0 +1,63 @@ +import logging +import os +import subprocess +import yaml + +from tempfile import NamedTemporaryFile + + +EDITOR_TEMPLATE = """ +{name}: + description: {description} + entries: + {entries} + +# ------------------------------------------------------------------------------ +# +# Groove On Demand Playlist Editor +# +# This file is in YAML format. Blank lines and lines beginning with # are +# ignored. Here's a complete example: +# +# My Awesome Jams, Vol. 2: +# description: | +# These jams are totally awesome, yo. +# Totally. +# +# yo. +# entries: +# - Beastie Boys - Help Me, Ronda +# - Bob and Doug McKenzie - Messiah (Hallelujah Eh) +# +""" + + +class PlaylistEditor: + """ + A custom ConfigParser that only supports specific headers and ignores all other square brackets. + """ + def __init__(self): + self._path = None + + @property + def path(self): + if not self._path: + self._path = NamedTemporaryFile(prefix='groove_on_demand-', delete=False) + return self._path + + def edit(self, playlist): + with self.path as fh: + fh.write(playlist.as_yaml.encode()) + subprocess.check_call([os.environ['EDITOR'], self.path.name]) + edits = self.read() + self.cleanup() + return edits + + def read(self): + with open(self.path.name, 'rb') as fh: + return yaml.safe_load(fh) + + def cleanup(self): + if self._path: + os.unlink(self._path.name) + self._path = None diff --git a/groove/exceptions.py b/groove/exceptions.py index ff514b6..1cff7df 100644 --- a/groove/exceptions.py +++ b/groove/exceptions.py @@ -21,3 +21,9 @@ class ConfigurationError(Exception): """ An error was discovered with the Groove on Demand configuration. """ + + +class PlaylistImportError(Exception): + """ + An error was discovered in a playlist template. + """ diff --git a/groove/playlist.py b/groove/playlist.py index 6114a09..dddc409 100644 --- a/groove/playlist.py +++ b/groove/playlist.py @@ -1,12 +1,17 @@ +import logging +import os + +from typing import Union, List + from groove import db +from groove.editor import PlaylistEditor, EDITOR_TEMPLATE +from groove.exceptions import PlaylistImportError + +from slugify import slugify from sqlalchemy import func, delete from sqlalchemy.orm.session import Session from sqlalchemy.engine.row import Row from sqlalchemy.exc import NoResultFound, MultipleResultsFound -from typing import Union, List - -import logging -import os class Playlist: @@ -27,6 +32,7 @@ class Playlist: self._record = None self._create_ok = create_ok self._deleted = False + self._editor = PlaylistEditor() @property def deleted(self) -> bool: @@ -43,6 +49,18 @@ class Playlist: return False return True + @property + def editor(self): + return self._editor + + @property + def name(self): + return self._name + + @property + def description(self): + return self._description + @property def summary(self): return ' :: '.join([ @@ -109,6 +127,14 @@ class Playlist: text += f" - {entry.track} {entry.artist} - {entry.title}\n" return text + @property + def as_yaml(self) -> str: + template_vars = self.as_dict + template_vars['entries'] = '' + for entry in self.entries: + template_vars['entries'] += f" - {entry.artist} - {entry.title}\n" + return EDITOR_TEMPLATE.format(**template_vars) + def _get_tracks_by_path(self, paths: List[str]) -> List: """ Retrieve tracks from the database that match the specified path fragments. The exceptions NoResultFound and @@ -116,6 +142,20 @@ class Playlist: """ return [self.session.query(db.track).filter(db.track.c.relpath.ilike(f"%{path}%")).one() for path in paths] + def edit(self): + edits = self.editor.edit(self) + if not edits: + return + new = Playlist.from_yaml(edits, self.session) + if new == self: + logging.debug("No changes detected.") + return + logging.debug(f"Updating {self.slug} with new edits.") + self._slug = new.slug + self._name = new.name.strip() + self._description = new.description.strip() + self._record = self.save() + def add(self, paths: List[str]) -> int: """ Add entries to the playlist. Each path should match one and only one track in the database (case-insensitive). @@ -158,7 +198,7 @@ class Playlist: def get_or_create(self, create_ok: bool = False) -> Row: try: - return self.session.query(db.playlist).filter(db.playlist.c.slug == self.slug).one() + return self._get() except NoResultFound: logging.debug(f"Could not find a playlist with slug {self.slug}.") if self.deleted: @@ -166,18 +206,45 @@ class Playlist: if self._create_ok or create_ok: return self.save() + def _get(self): + return self.session.query(db.playlist).filter( + db.playlist.c.slug == self.slug + ).one() + + def _insert(self, values): + stmt = db.playlist.insert(values) + results = self.session.execute(stmt) + self.session.commit() + logging.debug(f"Saved playlist with slug {self.slug}") + return self.session.query(db.playlist).filter( + db.playlist.c.id == results.inserted_primary_key[0] + ).one() + + def _update(self, values): + stmt = db.playlist.update().where( + db.playlist.c.id == self._record.id + ).values(values) + self.session.execute(stmt) + self.session.commit() + return self.session.query(db.playlist).filter( + db.playlist.c.id == self._record.id + ).one() + + def save(self) -> Row: + values = { + 'slug': self.slug, + 'name': self.name, + 'description': self.description + } + logging.debug(f"Saving values: {values}") + obj = self._update(values) if self._record else self._insert(values) + logging.debug(f"Saved playlist {obj.id} with slug {obj.slug}") + return obj + def load(self): self.get_or_create(create_ok=False) return self - def save(self) -> Row: - keys = {'slug': self.slug, 'name': self._name, 'description': self._description} - stmt = db.playlist.update(keys) if self._record else db.playlist.insert(keys) - results = self.session.execute(stmt) - self.session.commit() - logging.debug(f"Saved playlist {results.inserted_primary_key[0]} with slug {self.slug}") - return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one() - def create_entries(self, tracks: List[Row]) -> int: """ Append a list of tracks to a playlist by populating the entries table with records referencing the playlist and @@ -210,5 +277,26 @@ class Playlist: pl._record = row return pl + @classmethod + def from_yaml(cls, source, session): + try: + name = list(source.keys())[0].strip() + description = (source[name]['description'] or '').strip() + return Playlist( + slug=slugify(name), + name=name, + description=description, + session=session, + ) + except (IndexError, KeyError): + PlaylistImportError("The specified source was not a valid playlist.") + + def __eq__(self, obj): + for key in ('slug', 'name', 'description'): + if getattr(obj, key) != getattr(self, key): + logging.debug(f"{key}: {getattr(obj, key)} != {getattr(self, key)}") + return False + return True + def __repr__(self): return self.as_string diff --git a/groove/shell/interactive_shell.py b/groove/shell/interactive_shell.py index e10103c..4129764 100644 --- a/groove/shell/interactive_shell.py +++ b/groove/shell/interactive_shell.py @@ -48,9 +48,8 @@ class CommandPrompt(BasePrompt): if cmd in self.commands: self.commands[cmd].start(name) else: - slug = slugify(name) self._playlist = Playlist( - slug=slug, + slug=slugify(name), name=name, session=self.manager.session, create_ok=True diff --git a/groove/shell/playlist.py b/groove/shell/playlist.py index 1907e96..11e495d 100644 --- a/groove/shell/playlist.py +++ b/groove/shell/playlist.py @@ -30,6 +30,7 @@ class _playlist(BasePrompt): 'show': self.show, 'delete': self.delete, 'add': self.add, + 'edit': self.edit, } return self._commands @@ -45,11 +46,15 @@ class _playlist(BasePrompt): print(self.parent.playlist) return True + def edit(self, parts): + self.parent.playlist.edit() + return True + def add(self, parts): print("Add tracks one at a time by title. ENTER to finish.") while True: text = prompt( - ' ? ', + ' ?', completer=self.manager.fuzzy_table_completer( db.track, db.track.c.relpath, diff --git a/pyproject.toml b/pyproject.toml index 4d79a40..2811741 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ rich = "^12.6.0" bottle-sqlalchemy = "^0.4.3" music-tag = "^0.4.3" prompt-toolkit = "^3.0.33" +PyYAML = "^6.0" [tool.poetry.dev-dependencies] pytest = "^7.2.0"