add CLI helpers, switch to bottle-sqlalchemy

This commit is contained in:
evilchili 2022-11-25 15:40:24 -08:00
parent 267af75cb4
commit 7ca1f69100
7 changed files with 159 additions and 21 deletions

View File

@ -6,10 +6,12 @@ from pathlib import Path
from typing import Optional, List
from dotenv import load_dotenv
from slugify import slugify
from pprint import pprint
from rich import print
import rich.table
from groove import webserver
from groove.playlist import Playlist
from groove import db
from groove.db.manager import database_manager
from groove.db.scanner import media_scanner
@ -26,6 +28,25 @@ def initialize():
level=logging.DEBUG if debug else logging.INFO)
@playlist_app.command()
def list():
"""
List all Playlists
"""
initialize()
with database_manager() as manager:
query = manager.session.query(db.playlist)
table = rich.table.Table(
*[rich.table.Column(k.name.title()) for k in db.playlist.columns]
)
for row in db.windowed_query(query, db.playlist.c.id, 1000):
columns = tuple(Playlist.from_row(row, manager.session).record)[0:-1]
table.add_row(*[str(col) for col in columns])
print()
print(table)
print()
@playlist_app.command()
def delete(
name: str = typer.Argument(
@ -42,24 +63,42 @@ def delete(
"""
initialize()
with database_manager() as manager:
pl = Playlist(slug=slugify(name), connection=manager.session, create_if_not_exists=False)
pl = Playlist(slug=slugify(name), session=manager.session, create_if_not_exists=False)
if not pl.exists:
logging.info(f"No playlist named '{name}' could be found.")
return
if no_dry_run is False:
print(f"Would delete playlist {pl.record.id}, which contains {len(pl.entries)} tracks.")
entry_count = 0 if not pl.entries else len([e for e in pl.entries])
print(f"Would delete playlist {pl.record.id}, which contains {entry_count} tracks.")
return
deleted_playlist = pl.delete()
print(f"Playlist {deleted_playlist} deleted.")
@playlist_app.command()
def get(
slug: str = typer.Argument(
...,
help="The slug of the playlist to retrieve."
),
):
initialize()
with database_manager() as manager:
pl = Playlist(slug=slug, session=manager.session)
print(pl.as_dict)
@playlist_app.command()
def add(
name: str = typer.Argument(
...,
help="The name of the playlist to create."
),
description: str = typer.Option(
None,
help="The description of the playlist."
),
tracks: List[str] = typer.Option(
None,
help="A list of tracks to add to the playlist."
@ -78,14 +117,19 @@ def add(
"""
initialize()
with database_manager() as manager:
pl = Playlist(slug=slugify(name), connection=manager.session, create_if_not_exists=True)
pl = Playlist(
slug=slugify(name),
session=manager.session,
name=name,
description=description,
create_if_not_exists=True)
if pl.exists:
if not exists_ok:
raise RuntimeError(f"Playlist with slug {pl.slug} already exists!")
logging.debug(pl.as_dict)
if tracks:
pl.add(tracks)
pprint(pl.as_dict)
print(pl.as_dict)
@app.command()

View File

@ -1 +1,2 @@
from groove.db.schema import metadata, track, playlist, entry
from groove.db.helpers import windowed_query

20
groove/db/helpers.py Normal file
View File

@ -0,0 +1,20 @@
def windowed_query(query, column, window_size):
""""
Break a Query into chunks on a given column.
see: https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
"""
single_entity = query.is_single_entity
query = query.add_columns(column).order_by(column)
last_id = None
while True:
sub_query = query
if last_id is not None:
sub_query = sub_query.filter(column > last_id)
chunk = sub_query.limit(window_size).all()
if not chunk:
break
last_id = chunk[-1][-1]
for row in chunk:
yield row

43
groove/db/manager.py Normal file
View File

@ -0,0 +1,43 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from . import metadata
class DatabaseManager:
"""
A context manager for working with sqllite database.
"""
def __init__(self):
self._engine = None
self._session = None
@property
def engine(self):
if not self._engine:
self._engine = create_engine(f"sqlite:///{os.environ.get('DATABASE_PATH')}", future=True)
return self._engine
@property
def session(self):
if not self._session:
Session = sessionmaker(bind=self.engine, future=True)
self._session = Session()
return self._session
def import_from_filesystem(self):
pass
def __enter__(self):
metadata.create_all(bind=self.engine)
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.session:
self.session.close()
database_manager = DatabaseManager

View File

@ -11,9 +11,16 @@ class Playlist:
"""
CRUD operations and convenience methods for playlists.
"""
def __init__(self, slug: str, session: Session, create_if_not_exists: bool = False):
def __init__(self,
slug: str,
session: Session,
name: str = '',
description: str = '',
create_if_not_exists: bool = False):
self._session = session
self._slug = slug
self._name = name
self._description = description
self._record = None
self._entries = None
self._create_if_not_exists = create_if_not_exists
@ -44,7 +51,7 @@ class Playlist:
logging.debug(f"Retrieved playlist {self._record.id}")
except NoResultFound:
pass
if self._create_if_not_exists:
if not self._record and self._create_if_not_exists:
self._record = self._create()
if not self._record: # pragma: no cover
raise RuntimeError(f"Tried to create a playlist but couldn't read it back using slug {self.slug}")
@ -56,7 +63,7 @@ class Playlist:
Cache the list of entries on this playlist and return it.
"""
if not self._entries and self.record:
self._entries = self.session.query(
query = self.session.query(
db.entry,
db.track
).filter(
@ -65,7 +72,8 @@ class Playlist:
db.entry.c.playlist_id == db.playlist.c.id
).filter(
db.entry.c.track_id == db.track.c.id
).all()
)
self._entries = db.windowed_query(query, db.entry.c.track_id, 1000)
return self._entries
@property
@ -90,6 +98,7 @@ class Playlist:
Returns:
int: The number of tracks added.
"""
logging.debug(f"Attempting to add tracks matching: {paths}")
try:
return self._create_entries(self._get_tracks_by_path(paths))
except NoResultFound:
@ -144,14 +153,24 @@ class Playlist:
]
)
self.session.commit()
self._entries = None
return len(tracks)
def _create(self) -> Row:
"""
Insert a new playlist record into the database.
"""
stmt = db.playlist.insert({'slug': self.slug})
stmt = db.playlist.insert({'slug': self.slug, 'name': self._name, 'description': self._description})
results = self.session.execute(stmt)
self.session.commit()
logging.debug(f"Created new playlist {results.inserted_primary_key[0]} with slug {self.slug}")
return self.session.query(db.playlist).filter(db.playlist.c.id == results.inserted_primary_key[0]).one()
@classmethod
def from_row(cls, row, session):
pl = Playlist(slug=row.slug, session=session)
pl._record = row
return pl
def __repr__(self):
return str(self.as_dict)

View File

@ -4,9 +4,11 @@ import os
import bottle
from bottle import HTTPResponse
from bottle.ext import sqlite
from bottle.ext import sqlalchemy
from groove.auth import is_authenticated
from groove.db.manager import database_manager
from groove.db import metadata
from groove.playlist import Playlist
server = bottle.Bottle()
@ -17,15 +19,23 @@ def start(host: str, port: int, debug: bool) -> None: # pragma: no cover
Start the Bottle app.
"""
logging.debug(f"Configuring sqllite using {os.environ.get('DATABASE_PATH')}")
server.install(sqlite.Plugin(dbfile=os.environ.get('DATABASE_PATH')))
logging.debug(f"Configuring webserver with host={host}, port={port}, debug={debug}")
server.run(
host=os.getenv('HOST', host),
port=os.getenv('PORT', port),
debug=debug,
server='paste',
quiet=True
)
with database_manager() as manager:
server.install(sqlalchemy.Plugin(
manager.engine,
metadata,
keyword='db',
create=True,
commit=True,
))
logging.debug(f"Configuring webserver with host={host}, port={port}, debug={debug}")
server.run(
host=os.getenv('HOST', host),
port=os.getenv('PORT', port),
debug=debug,
server='paste',
quiet=True
)
@server.route('/')

View File

@ -14,9 +14,10 @@ bottle = "^0.12.23"
typer = "^0.7.0"
python-dotenv = "^0.21.0"
Paste = "^3.5.2"
bottle-sqlite = "^0.2.0"
SQLAlchemy = "^1.4.44"
python-slugify = "^7.0.0"
rich = "^12.6.0"
bottle-sqlalchemy = "^0.4.3"
[tool.poetry.dev-dependencies]
pytest = "^7.2.0"