diff --git a/pyproject.toml b/pyproject.toml index e664103..aa1669c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = 'dnd-rolltable' -version = '1.1.5' +version = '1.1.8' license = 'The Unlicense' authors = ['Greg Boyington '] description = 'Generate roll tables using weighted random distributions' @@ -14,6 +14,7 @@ typer = "latest" rich = "latest" pyyaml = "latest" pytest = "latest" +csv2md = "latest" [tool.poetry.scripts] diff --git a/rolltable/cli.py b/rolltable/cli.py index 039e535..2fea9ad 100644 --- a/rolltable/cli.py +++ b/rolltable/cli.py @@ -1,5 +1,6 @@ from rolltable import tables import typer +from enum import Enum from rich import print from rich.table import Table from pathlib import Path @@ -9,6 +10,12 @@ from typing import List app = typer.Typer() +class OUTPUT_FORMATS(Enum): + text = 'text' + yaml = 'yaml' + markdown = 'markdown' + + @app.command("roll-table") def create( sources: List[Path] = typer.Argument( @@ -20,28 +27,37 @@ def create( die: int = typer.Option( 20, help='The size of the die for which to create a table'), + hide_rolls: bool = typer.Option( + False, + help='If True, do not show the Roll column.', + ), collapsed: bool = typer.Option( True, help='If True, collapse multiple die values with the same option.'), - yaml: bool = typer.Option( - False, - help='Render output as yaml.') + width: int = typer.Option( + 120, + help='Width of the table.'), + output: OUTPUT_FORMATS = typer.Option( + 'text', + help='The output format to use.', + ) ): """ CLI for creating roll tables. """ - rt = tables.RollTable([Path(s).read_text() for s in sources], frequency=frequency, die=die) + rt = tables.RollTable([Path(s).read_text() for s in sources], frequency=frequency, die=die, hide_rolls=hide_rolls) - if yaml: + if output == OUTPUT_FORMATS.yaml: print(rt.as_yaml()) - return - - rows = rt.rows if collapsed else rt.expanded_rows - table = Table(*rows[0]) - for row in rows[1:]: - table.add_row(*row) - print(table) + elif output == OUTPUT_FORMATS.markdown: + print(rt.as_markdown) + else: + rows = rt.rows if collapsed else rt.expanded_rows + table = Table(*rows[0], width=width) + for row in rows[1:]: + table.add_row(*row) + print(table) if __name__ == '__main__': diff --git a/rolltable/tables.py b/rolltable/tables.py index e98f3e9..77de980 100644 --- a/rolltable/tables.py +++ b/rolltable/tables.py @@ -1,5 +1,6 @@ import yaml import random +from csv2md.table import Table from collections.abc import Iterable from typing import Optional, List, IO @@ -83,10 +84,11 @@ class RollTable: """ def __init__(self, sources: List[str], frequency: str = 'default', - die: Optional[int] = 20) -> None: + die: Optional[int] = 20, hide_rolls: bool = False) -> None: self._sources = sources self._frequency = frequency self._die = die + self._hide_rolls = hide_rolls self._data = None self._rows = None self._headers = None @@ -157,6 +159,7 @@ class RollTable: lastrow = None offset = 0 self._rows = [self._column_filter(['Roll'] + self.headers)] + for face in range(self._die): row = self._values[face] if not lastrow: @@ -180,7 +183,7 @@ class RollTable: @property def as_markdown(self) -> str: - return '' + return Table(self.rows).markdown() def _config(self): """ @@ -204,12 +207,14 @@ class RollTable: self._header_excludes = [] for i in range(len(self._headers)): if self.headers[i] is None: - self._header_excludes.append(i+1) # +1 to account for the 'Roll' column + self._header_excludes.append(i) def _column_filter(self, row): - cols = [col for (pos, col) in enumerate(row) if pos not in self._header_excludes] + cols = [col or '' for (pos, col) in enumerate(row) if pos not in self._header_excludes] # pad the row with empty columns if there are more headers than columns - return cols + [''] * (1 + len(self.headers) - len(row)) + cols = cols + [''] * (1 + len(self.headers) - len(row)) + # strip the leading column if we're hiding the dice rolls + return cols[1:] if self._hide_rolls else cols def _flatten(self, obj: List) -> List: for member in obj: @@ -220,6 +225,5 @@ class RollTable: def __repr__(self) -> str: rows = list(self.rows) - print(rows) str_format = '\t'.join(['{:10s}'] * len(rows[0])) - return "\n".join([str_format.format(*row) for row in rows]) + return "\n".join([str_format.format(*[r or '' for r in row]) for row in rows]) diff --git a/tests/test_tables.py b/tests/test_tables.py index 3581423..dc1b9b1 100644 --- a/tests/test_tables.py +++ b/tests/test_tables.py @@ -166,3 +166,10 @@ def test_yaml(): assert tables.RollTable([fixture_one_choice]).as_yaml() assert tables.RollTable([fixture_metadata + fixture_source]).as_yaml() assert tables.RollTable([fixture_source]).as_yaml() + + +def test_text(): + assert repr(tables.RollTable([fixture_no_options])) + assert repr(tables.RollTable([fixture_one_choice])) + assert repr(tables.RollTable([fixture_metadata + fixture_source])) + assert repr(tables.RollTable([fixture_source]))