diff --git a/src/grung/db.py b/src/grung/db.py index abcb4b9..452433f 100644 --- a/src/grung/db.py +++ b/src/grung/db.py @@ -34,10 +34,25 @@ class RecordTable(table.Table): doc.after_insert(self._db) return doc.deserialize(self._db) - def get(self, doc_id: int, recurse: bool = False): - document = super().get(doc_id=doc_id) - if document: - return document.deserialize(self._db, recurse=recurse) + def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs): + """ + Return exactly zero or one records from the database matching the supplied criteria. + If more than one records match the criteria, return the first one. Criteria are ignored + if doc_id is specified. + + Usage: + Table.get(doc_id=1) + Table.get(where("uid") == "abcdef") + + """ + if doc_id: + document = super().get(doc_id=doc_id) + if document: + return document.deserialize(self._db, recurse=recurse) + + matches = self.search(*args, recurse=recurse, **kwargs) + if matches: + return matches[0] def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]: results = super().search(*args, **kwargs) diff --git a/test/test_db.py b/test/test_db.py index ddfc49b..efc538d 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -116,8 +116,8 @@ def test_search(db): assert ricky in crew.members Group = Query() - crew = db.Group.search(Group.name == "Crew", recurse=False) - assert kirk.reference in crew[0].members + crew = db.Group.get(Group.name == "Crew", recurse=False) + assert kirk.reference in crew.members def test_password(db):