diff --git a/src/grung/examples.py b/src/grung/examples.py index 377e271..f93905c 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -1,4 +1,4 @@ -from grung.types import BackReference, Collection, Field, Integer, Record +from grung.types import BackReference, Collection, Field, Integer, Password, Record class User(Record): @@ -9,6 +9,7 @@ class User(Record): Field("name"), Integer("number", default=0), Field("email", unique=True), + Password("password"), BackReference("groups", Group), ] diff --git a/src/grung/types.py b/src/grung/types.py index 48a5004..cdba6ec 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -1,5 +1,7 @@ from __future__ import annotations +import hashlib +import os from collections import namedtuple from dataclasses import dataclass, field from typing import Dict, List @@ -46,6 +48,34 @@ class Integer(Field): return int(value) +@dataclass +class Password(Field): + value_type = str + default: str = None + + # Relatively weak. Consider using stronger initial values in production applications. + salt_size = 4 + digest_size = 16 + + @classmethod + def get_digest(cls, passwd: str, salt: bytes = None): + if not salt: + salt = os.urandom(cls.salt_size) + digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest() + return digest, salt.hex() + + @classmethod + def compare(cls, passwd: value_type, stored: value_type): + stored_salt, stored_digest = stored.split(":") + input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt)) + return input_digest == stored_digest + + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + if value: + digest, salt = self.__class__.get_digest(value) + record[self.name] = f"{salt}:{digest}" + + class Record(Dict[(str, Field)]): """ Base type for a single database record. @@ -72,7 +102,7 @@ class Record(Dict[(str, Field)]): def serialize(self): """ - Serialie every field on the record + Serialize every field on the record """ rec = {} for name, _field in self._metadata.fields.items(): diff --git a/test/test_db.py b/test/test_db.py index 2675092..eb06005 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -118,3 +118,15 @@ def test_search(db): Group = Query() crew = db.Group.search(Group.name == "Crew", recurse=False) assert kirk.reference in crew[0].members + + +def test_password(db): + user = db.save(examples.User(name="john", email="john@foo", password="fnord")) + + assert ":" in user.password + assert user.password != "fnord" + + check = user._metadata.fields["password"].compare + assert check("fnord", user.password) + assert not check("wrong password", user.password) + assert not check("", user.password)