refactor scanner, add progress bar

This commit is contained in:
evilchili 2022-12-21 15:17:13 -08:00
parent fe671194a0
commit 7c82226ff9
13 changed files with 339 additions and 173 deletions

View File

@ -5,99 +5,45 @@ import typer
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from slugify import slugify
from rich import print
import rich.table
from rich.logging import RichHandler
from groove.shell import interactive_shell
from groove.playlist import Playlist
from groove import db
from groove.db.manager import database_manager
from groove.db.scanner import media_scanner
from groove.webserver import webserver
playlist_app = typer.Typer()
app = typer.Typer()
app.add_typer(playlist_app, name='playlist', help='Manage playlists.')
def initialize():
load_dotenv()
debug = os.getenv('DEBUG', None)
logging.basicConfig(format='%(asctime)s - %(message)s',
level=logging.DEBUG if debug else logging.INFO)
logging.basicConfig(
format='%(message)s',
level=logging.DEBUG if debug else logging.INFO,
handlers=[
RichHandler(rich_tracebacks=True, tracebacks_suppress=[typer])
]
)
logging.getLogger('asyncio').setLevel(logging.ERROR)
@playlist_app.command()
@app.command()
def list():
"""
List all Playlists
"""
initialize()
with database_manager() as manager:
query = manager.session.query(db.playlist)
table = rich.table.Table(
*[rich.table.Column(k.name.title()) for k in db.playlist.columns]
)
for row in db.windowed_query(query, db.playlist.c.id, 1000):
columns = tuple(Playlist.from_row(row, manager.session).record)[0:-1]
table.add_row(*[str(col) for col in columns])
print()
print(table)
print()
@playlist_app.command()
def delete(
name: str = typer.Argument(
...,
help="The name of the playlist to create."
),
no_dry_run: bool = typer.Option(
False,
help="If True, actually delete the playlist, Otherwise, show what would be deleted."
)
):
"""
Delete a playlist
"""
initialize()
with database_manager() as manager:
pl = Playlist(slug=slugify(name), session=manager.session, create_if_not_exists=False)
if not pl.exists:
logging.info(f"No playlist named '{name}' could be found.")
return
if no_dry_run is False:
entry_count = 0 if not pl.entries else len([e for e in pl.entries])
print(f"Would delete playlist {pl.record.id}, which contains {entry_count} tracks.")
return
deleted_playlist = pl.delete()
print(f"Playlist {deleted_playlist} deleted.")
@playlist_app.command()
def get(
slug: str = typer.Argument(
...,
help="The slug of the playlist to retrieve."
),
):
initialize()
with database_manager() as manager:
pl = Playlist(slug=slug, session=manager.session)
print(pl.as_dict)
shell = interactive_shell.InteractiveShell(manager)
shell.list(None)
@app.command()
def scan(
root: Optional[Path] = typer.Option(
None,
help="The path to the root of your media."
),
debug: bool = typer.Option(
False,
help='Enable debugging output'
path: Optional[Path] = typer.Option(
'',
help="A path to scan, relative to your MEDIA_ROOT. "
"If not specified, the entire MEDIA_ROOT will be scanned."
),
):
"""
@ -105,9 +51,9 @@ def scan(
"""
initialize()
with database_manager() as manager:
scanner = media_scanner(root=root, db=manager.session)
count = scanner.scan()
logging.info(f"Imported {count} new tracks.")
shell = interactive_shell.InteractiveShell(manager)
shell.console.print("Starting the Groove on Demand scanner...")
shell.scan([str(path)])
@app.command()
@ -135,7 +81,6 @@ def server(
Start the Groove on Demand playlsit server.
"""
initialize()
print("Starting Groove On Demand...")
with database_manager() as manager:
manager.import_from_filesystem()
webserver.start(host=host, port=port, debug=debug)

View File

@ -3,11 +3,14 @@ import os
from configparser import ConfigParser
from pathlib import Path
from textwrap import dedent
from typing import Union, List
import rich.repr
from rich.console import Console as _Console
from rich.markdown import Markdown
from rich.theme import Theme
from rich.table import Table
from rich.table import Table, Column
from prompt_toolkit import prompt as _toolkit_prompt
from prompt_toolkit.formatted_text import ANSI
@ -23,7 +26,13 @@ BASE_STYLE = {
}
def console_theme(theme_name=None):
def console_theme(theme_name: Union[str, None] = None) -> dict:
"""
Return a console theme as a dictionary.
Args:
theme_name (str):
"""
cfg = ConfigParser()
cfg.read_dict({'styles': BASE_STYLE})
cfg.read(theme(
@ -32,18 +41,54 @@ def console_theme(theme_name=None):
return cfg['styles']
@rich.repr.auto
class Console(_Console):
"""
SYNOPSIS
Subclasses a rich.console.Console to provide an instance with a
reconfigured themes, and convenience methods and attributes.
USAGE
Console([ARGS])
ARGS
theme The name of a theme to load. Defaults to DEFAULT_THEME.
EXAMPLES
Console().print("Can I kick it?")
>>> Can I kick it?
INSTANCE ATTRIBUTES
theme The current theme
"""
def __init__(self, *args, **kwargs):
self._console_theme = console_theme(kwargs.get('theme', None))
self._overflow = 'ellipsis'
kwargs['theme'] = Theme(self._console_theme, inherit=False)
super().__init__(*args, **kwargs)
@property
def theme(self):
def theme(self) -> Theme:
return self._console_theme
def prompt(self, lines, **kwargs):
def prompt(self, lines: List, **kwargs) -> str:
"""
Print a list of lines, using the final line as a prompt.
Example:
Console().prompt(["Can I kick it?", "[Y/n] ")
>>> Can I kick it?
[Y/n]>
"""
for line in lines[:-1]:
super().print(line)
with self.capture() as capture:
@ -51,17 +96,29 @@ class Console(_Console):
rendered = ANSI(capture.get())
return _toolkit_prompt(rendered, **kwargs)
def mdprint(self, txt, **kwargs):
def mdprint(self, txt: str, **kwargs) -> None:
"""
Like print(), but support markdown. Text will be dedented.
"""
self.print(Markdown(dedent(txt), justify='left'), **kwargs)
def print(self, txt, **kwargs):
super().print(txt, **kwargs)
def print(self, txt: str, **kwargs) -> None:
"""
Print text to the console, possibly truncated with an ellipsis.
"""
super().print(txt, overflow=self._overflow, **kwargs)
def error(self, txt, **kwargs):
super().print(dedent(txt), style='error')
def error(self, txt: str, **kwargs) -> None:
"""
Print text to the console with the current theme's error style applied.
"""
self.print(dedent(txt), style='error')
def table(self, *cols, **params):
if os.environ['CONSOLE_THEMES']:
def table(self, *cols: List[Column], **params) -> None:
"""
Print a rich table to the console with theme elements and styles applied.
parameters and keyword arguments are passed to rich.table.Table.
"""
background_style = f"on {self.theme['background']}"
params.update(
header_style=background_style,

View File

@ -1,9 +1,9 @@
import os
from prompt_toolkit.completion import Completion, FuzzyCompleter
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import groove.path
from . import metadata
@ -37,7 +37,8 @@ class DatabaseManager:
@property
def engine(self):
if not self._engine:
self._engine = create_engine(f"sqlite:///{os.environ.get('DATABASE_PATH')}?check_same_thread=False", future=True)
path = groove.path.database()
self._engine = create_engine(f"sqlite:///{path}?check_same_thread=False", future=True)
return self._engine
@property

View File

@ -1,48 +1,111 @@
import asyncio
import logging
import os
import music_tag
from itertools import chain
from pathlib import Path
from typing import Callable, Union, Iterable
import music_tag
import rich.repr
from rich.console import Console
from rich.progress import (
Progress,
TextColumn,
BarColumn,
SpinnerColumn,
TimeRemainingColumn
)
from sqlalchemy import func
from sqlalchemy.exc import NoResultFound
import groove.db
import groove.path
from groove.exceptions import InvalidPathError
@rich.repr.auto(angular=True)
class MediaScanner:
"""
Scan a directory structure containing audio files and import them into the database.
SYNOPSIS
Scan a directory structure containing audio files and import track entries
into the Groove on Demand database. Existing tracks will be ignored.
USAGE
MediaScanner(db=DB, [ARGS])
ARGS
db An sqlalchemy databse session
console A rich console instance
glob A pattern to search for. Defaults to MEDIA_GLOB. Multiple
patterns can be specifed as a comma-separated-list.
path The path to scan. Defaults to MEDIA_ROOT.
root The media root, as specified by MEDIA_ROOT
EXAMPLES
MediaScanner(db=DB, path='Kid Koala', glob='*.mp3').scan()
>>> 15
INSTANCE ATTRIBUTES
db The databse session
console The rich console instance
glob The globs to search for
path The path to be scanned
root The media root
"""
def __init__(self, root: Union[Path, None], db: Callable, glob: Union[str, None] = None) -> None:
def __init__(
self,
db: Callable,
path: Union[Path, None] = None,
glob: Union[str, None] = None,
console: Union[Console, None] = None,
) -> None:
self._db = db
self._glob = tuple((glob or os.environ.get('MEDIA_GLOB')).split(','))
self._root = root or groove.path.media_root()
logging.debug(f"Configured media scanner for root: {self._root}")
self._root = groove.path.media_root()
self._console = console or Console()
self._scanned = 0
self._imported = 0
self._total = 0
self._path = self._configure_path(path)
@property
def db(self) -> Callable:
return self._db
@property
def console(self) -> Console:
return self._console
@property
def root(self) -> Path:
return self._root
@property
def path(self) -> Path:
return self._path
@property
def glob(self) -> tuple:
return self._glob
def find_sources(self, pattern):
return self.root.rglob(pattern) # pragma: no cover
def import_tracks(self, sources: Iterable) -> None:
async def _do_import():
logging.debug("Scanning filesystem (this may take a minute)...")
for path in sources:
asyncio.create_task(self._import_one_track(path))
asyncio.run(_do_import())
self.db.commit()
def _configure_path(self, path):
if not path: # pragma: no cover
return self._root
fullpath = Path(self._root) / Path(path)
if not (fullpath.exists() and fullpath.is_dir()):
raise InvalidPathError( # pragma: no cover
f"[b]{fullpath}[/b] does not exist or is not a directory."
)
return fullpath
def _get_tags(self, path): # pragma: no cover
tags = music_tag.load_file(path)
@ -51,12 +114,83 @@ class MediaScanner:
'title': str(tags['title']),
}
async def _import_one_track(self, path):
tags = self._get_tags(path)
tags['relpath'] = str(path.relative_to(self.root))
stmt = groove.db.track.insert(tags).prefix_with('OR IGNORE')
logging.debug(f"{tags['artist']} - {tags['title']}")
self.db.execute(stmt)
def find_sources(self, pattern):
"""
Recursively search the instance path for files matching the pattern.
"""
entrypoint = self._path if self._path else self._root
for path in entrypoint.rglob(pattern): # pragma: no cover
if not path.is_dir():
yield path
def import_tracks(self, sources: Iterable) -> None:
"""
Step through the specified source files and schedule async tasks to
import them, reporting progress via a rich progress bar.
"""
async def _do_import(progress, scanner):
tasks = set()
for path in sources:
self._total += 1
progress.update(scanner, total=self._total)
tasks.add(asyncio.create_task(
self._import_one_track(path, progress, scanner)))
progress.start_task(scanner)
progress = Progress(
TimeRemainingColumn(compact=True, elapsed_when_finished=True),
BarColumn(bar_width=15),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%", justify="left"),
TextColumn("[dim]|"),
TextColumn("[title]{task.total:-6d}[/title] [b]total", justify="right"),
TextColumn("[dim]|"),
TextColumn("[title]{task.fields[imported]:-6d}[/title] [b]new", justify="right"),
TextColumn("[dim]|"),
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=self.console,
)
with progress:
scanner = progress.add_task(
f"[bright]Scanning [link]{self.path}[/link] (this may take some time)...",
imported=0,
total=0,
start=False
)
asyncio.run(_do_import(progress, scanner))
progress.update(
scanner,
completed=self._total,
description=f"[bright]Scan of [link]{self.path}[/link] complete!",
)
async def _import_one_track(self, path, progress, scanner):
"""
Import a single audo file into the databse, unless it already exists.
"""
self._scanned += 1
relpath = str(path.relative_to(self.root))
try:
self.db.query(groove.db.track).filter(
groove.db.track.c.relpath == relpath).one()
return
except NoResultFound:
pass
columns = self._get_tags(path)
columns['relpath'] = relpath
logging.debug(f"Importing: {columns}")
self.db.execute(groove.db.track.insert(columns))
self.db.commit()
self._imported += 1
progress.update(
scanner,
imported=self._imported,
completed=self._scanned,
description=f"[bright]Imported [artist]{columns['artist']}[/artist]: [title]{columns['title']}[/title]",
)
def scan(self) -> int:
"""
@ -64,12 +198,9 @@ class MediaScanner:
found. Existing entries will be ignored.
"""
count = self.db.query(func.count(groove.db.track.c.relpath)).scalar()
logging.debug(f"Track table currently contains {count} entries.")
for pattern in self.glob:
self.import_tracks(self.find_sources(pattern))
combined_sources = chain.from_iterable(
self.find_sources(pattern) for pattern in self.glob
)
self.import_tracks(combined_sources)
newcount = self.db.query(func.count(groove.db.track.c.relpath)).scalar() - count
logging.debug(f"Inserted {newcount} new tracks so far this run...")
return newcount
media_scanner = MediaScanner

View File

@ -33,3 +33,9 @@ class TrackNotFoundError(Exception):
"""
The specified track doesn't exist.
"""
class InvalidPathError(Exception):
"""
The specified path was invalid -- either it was not the expected type or wasn't accessible.
"""

View File

@ -102,4 +102,13 @@ def available_themes():
def database():
return root() / Path(os.environ.get('DATABASE_PATH', 'groove_on_demand.db'))
path = os.environ.get('DATABASE_PATH', None)
if not path:
path = root()
else: # pragma: no cover
path = Path(path).expanduser()
if not path.exists() or not path.is_dir():
raise ConfigurationError(
"DATABASE_PATH doesn't exist or isn't a directory.\n\n{_setup_hint}"
)
return path / Path('groove_on_demand.db')

View File

@ -95,7 +95,7 @@ class BasePrompt(Completer):
def autocomplete_values(self):
return self._autocomplete_values
def get_completions(self, document, complete_event):
def get_completions(self, document, complete_event): # pragma: no cover
word = document.get_word_before_cursor()
found = False
for value in self.autocomplete_values:

View File

@ -1,11 +1,13 @@
from slugify import slugify
from groove.db.manager import database_manager
from groove.shell.base import BasePrompt, command, register_command
from groove.db.scanner import MediaScanner
from groove.shell.base import BasePrompt, command
from groove.exceptions import InvalidPathError
from groove import db
from groove.playlist import Playlist
from rich.table import Table, Column
from rich.table import Column
from rich import box
from sqlalchemy import func
@ -60,6 +62,34 @@ class InteractiveShell(BasePrompt):
name = cmd + ' ' + ' '.join(parts)
self.load([name.strip()])
@command("""
[title]SCANNING YOUR MEDIA[/title]
Use the [b]scan[/b] function to scan your media root for new, changed, and
deleted audio files. This process may take some time if you have a large
library!
Instead of scanning the entire MEDIA_ROOT, you can specify a PATH, which
must be a subdirectory of your MEDIA_ROOT. This is useful to import that
new new.
[title]USAGE[/title]
[link]> scan [PATH][/link]
""")
def scan(self, parts):
"""
Scan your MEDIA_ROOT for changes.
"""
path = ' '.join(parts) if parts else None
try:
scanner = MediaScanner(path=path, db=self.manager.session, console=self.console)
except InvalidPathError as e:
self.console.error(str(e))
return True
scanner.scan()
@command("""
[title]LISTS FOR THE LIST LOVER[/title]
@ -75,7 +105,6 @@ class InteractiveShell(BasePrompt):
"""
List all playlists.
"""
count = self.manager.session.query(func.count(db.playlist.c.id)).scalar()
table = self.console.table(
Column('#', justify='right', width=4),
Column('Name'),
@ -182,6 +211,7 @@ class InteractiveShell(BasePrompt):
super().help(parts)
return True
def start(): # pragma: no cover
with database_manager() as manager:
InteractiveShell(manager).start()

View File

@ -19,6 +19,7 @@ def env():
load_dotenv(Path('test/fixtures/env'))
os.environ['GROOVE_ON_DEMAND_ROOT'] = str(root)
os.environ['MEDIA_ROOT'] = str(root / Path('media'))
os.environ['DATABASE_PATH'] = ''
return os.environ

3
test/fixtures/env vendored
View File

@ -2,9 +2,6 @@
GROOVE_ON_DEMAND_ROOT=.
MEDIA_ROOT=.
# where to store the database
DATABASE_PATH=test.db
# Admin user credentials
USERNAME=test_username
PASSWORD=test_password

View File

@ -14,10 +14,12 @@ def test_missing_media_root(monkeypatch, root):
with pytest.raises(ConfigurationError):
path.media_root()
def test_static(monkeypatch):
assert path.static('foo')
assert path.static('foo', theme=themes.load_theme('default_theme'))
@pytest.mark.parametrize('root', ['/dev/null/missing', None])
def test_missing_theme_root(monkeypatch, root):
broken_env = {k: v for (k, v) in os.environ.items()}
@ -32,5 +34,9 @@ def test_theme_no_path():
path.theme('nope')
def test_database_default(env):
assert path.database().relative_to(path.root())
def test_database(env):
assert env['DATABASE_PATH'] in path.database().name
assert env['DATABASE_PATH'] in str(path.database().absolute())

View File

@ -8,55 +8,23 @@ import groove.exceptions
from groove.db import scanner, track
fixture_tracks = [
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 01 Terra Magnifica.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 02 These Days Are Old.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 03 Crystal Cradle.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 04 Running Away.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 05 Welcome to the House of Food.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 06 Wendy McDonald.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 07 The Size of You.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 08 Its Not What You Do Its You.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 09 Mars.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 10 Leave the City.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 11 Growing Up is Over.flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 12 Donate Your Heart to a Stranger....flac",
"/test/Spookey Ruben/Modes of Transportation, Volume 1/Spookey Ruben - Modes of Transportation, Volume 1 - 13 Life Insurance.flac",
]
@pytest.fixture
def media():
def fixture():
for t in fixture_tracks:
yield Path(t)
return fixture
def test_scanner(monkeypatch, in_memory_db, media):
# replace the filesystem glob with the test fixture generator
monkeypatch.setattr(scanner.MediaScanner, 'find_sources', MagicMock(return_value=media()))
def test_scanner(monkeypatch, in_memory_db):
def mock_loader(path):
return {
'artist': 'foo',
'title': 'bar',
}
# replace music_tag so it doesn't try to read things
monkeypatch.setattr(scanner.MediaScanner, '_get_tags', MagicMock(side_effect=mock_loader))
test_scanner = scanner.media_scanner(root=Path('/test'), db=in_memory_db)
expected = len(fixture_tracks)
test_scanner = scanner.MediaScanner(path=Path('UNKLE'), db=in_memory_db)
# verify all entries are scanned
assert test_scanner.scan() == expected
assert test_scanner.scan() == 1
# readback; verify entries are in the db
query = func.count(track.c.relpath)
query = query.filter(track.c.relpath.ilike('%Spookey%'))
assert in_memory_db.query(query).scalar() == expected
query = query.filter(track.c.relpath.ilike('%UNKLE%'))
assert in_memory_db.query(query).scalar() == 1
# verify idempotency
assert test_scanner.scan() == 0
@ -65,4 +33,4 @@ def test_scanner(monkeypatch, in_memory_db, media):
def test_scanner_no_media_root(in_memory_db):
del os.environ['MEDIA_ROOT']
with pytest.raises(groove.exceptions.ConfigurationError):
assert scanner.media_scanner(root=None, db=in_memory_db)
assert scanner.MediaScanner(path=None, db=in_memory_db)

View File

@ -17,5 +17,20 @@ help = #999999
background = #001321
info = #88FF88
error = #FF8888
danger = #FF8888
log.time = #9999FF
log.message = #f1f2f6
log.path = #9999FF
bar.back = #555555
bar.finished = #70bc45
bar.complete = #70bc45
bar.pulse = #f1f2f6
progress.description = #999999
progress.percentage = #70bc45
progress.spinner = #70bc45