diff --git a/bemani/api/app.py b/bemani/api/app.py index d496dfb..3e8becd 100644 --- a/bemani/api/app.py +++ b/bemani/api/app.py @@ -286,10 +286,17 @@ def lookup(protoversion: str, requestgame: str, requestversion: str) -> Dict[str # Don't support this version! abort(404) - idtype = requestdata['type'] - ids = requestdata['ids'] - if idtype not in [APIConstants.ID_TYPE_CARD, APIConstants.ID_TYPE_SONG, APIConstants.ID_TYPE_INSTANCE, APIConstants.ID_TYPE_SERVER]: + # Attempt to coerce ID type. If we fail, provide the correct failure message. + idtype = None + try: + idtype = APIConstants(requestdata['type']) + except ValueError: + pass + if idtype is None: raise APIException('Invalid ID type provided!') + + # Validate the provided IDs given the ID type above. + ids = requestdata['ids'] if idtype == APIConstants.ID_TYPE_CARD and len(ids) == 0: raise APIException('Invalid number of IDs given!') if idtype == APIConstants.ID_TYPE_SONG and len(ids) not in [1, 2]: diff --git a/bemani/api/objects/base.py b/bemani/api/objects/base.py index 75e546a..c8144b6 100644 --- a/bemani/api/objects/base.py +++ b/bemani/api/objects/base.py @@ -1,7 +1,7 @@ from typing import List, Any, Dict from bemani.api.exceptions import APIException -from bemani.common import GameConstants +from bemani.common import APIConstants, GameConstants from bemani.data import Data @@ -20,5 +20,5 @@ class BaseObject: self.version = version self.omnimix = omnimix - def fetch_v1(self, idtype: str, ids: List[str], params: Dict[str, Any]) -> Any: + def fetch_v1(self, idtype: APIConstants, ids: List[str], params: Dict[str, Any]) -> Any: raise APIException('Object fetch not supported for this version!') diff --git a/bemani/api/objects/catalog.py b/bemani/api/objects/catalog.py index 74b4347..7738a6f 100644 --- a/bemani/api/objects/catalog.py +++ b/bemani/api/objects/catalog.py @@ -192,7 +192,7 @@ class CatalogObject(BaseObject): else: return self.version - def fetch_v1(self, idtype: str, ids: List[str], params: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]: + def fetch_v1(self, idtype: APIConstants, ids: List[str], params: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]: # Verify IDs if idtype != APIConstants.ID_TYPE_SERVER: raise APIException( diff --git a/bemani/api/objects/profile.py b/bemani/api/objects/profile.py index b1467a7..08ff330 100644 --- a/bemani/api/objects/profile.py +++ b/bemani/api/objects/profile.py @@ -73,7 +73,7 @@ class ProfileObject(BaseObject): return base - def fetch_v1(self, idtype: str, ids: List[str], params: Dict[str, Any]) -> List[Dict[str, Any]]: + def fetch_v1(self, idtype: APIConstants, ids: List[str], params: Dict[str, Any]) -> List[Dict[str, Any]]: # Fetch the profiles profiles: List[Tuple[UserID, ValidatedDict]] = [] if idtype == APIConstants.ID_TYPE_SERVER: diff --git a/bemani/api/objects/records.py b/bemani/api/objects/records.py index ddc90b6..7564deb 100644 --- a/bemani/api/objects/records.py +++ b/bemani/api/objects/records.py @@ -230,7 +230,7 @@ class RecordsObject(BaseObject): else: return self.version - def fetch_v1(self, idtype: str, ids: List[str], params: Dict[str, Any]) -> List[Dict[str, Any]]: + def fetch_v1(self, idtype: APIConstants, ids: List[str], params: Dict[str, Any]) -> List[Dict[str, Any]]: since = params.get('since') until = params.get('until') diff --git a/bemani/api/objects/statistics.py b/bemani/api/objects/statistics.py index 9d51da5..9b1fab8 100644 --- a/bemani/api/objects/statistics.py +++ b/bemani/api/objects/statistics.py @@ -180,7 +180,7 @@ class StatisticsObject(BaseObject): return retval - def fetch_v1(self, idtype: str, ids: List[str], params: Dict[str, Any]) -> List[Dict[str, Any]]: + def fetch_v1(self, idtype: APIConstants, ids: List[str], params: Dict[str, Any]) -> List[Dict[str, Any]]: retval: List[Dict[str, Any]] = [] # Fetch the attempts diff --git a/bemani/common/constants.py b/bemani/common/constants.py index 1a55f74..9fa865a 100644 --- a/bemani/common/constants.py +++ b/bemani/common/constants.py @@ -132,11 +132,9 @@ class VersionConstants: SDVX_HEAVENLY_HAVEN: Final[int] = 4 -class APIConstants: +class APIConstants(Enum): """ The four types of IDs found in a BEMAPI request or response. - - TODO: These should be an enum. """ ID_TYPE_SERVER: Final[str] = 'server' ID_TYPE_CARD: Final[str] = 'card' diff --git a/bemani/data/api/client.py b/bemani/data/api/client.py index 41bac8b..6a56331 100644 --- a/bemani/data/api/client.py +++ b/bemani/data/api/client.py @@ -2,7 +2,7 @@ import json import requests from typing import Tuple, Dict, List, Any, Optional -from bemani.common import GameConstants, VersionConstants, DBConstants, ValidatedDict +from bemani.common import APIConstants, GameConstants, VersionConstants, DBConstants, ValidatedDict class APIException(Exception): @@ -194,7 +194,7 @@ class APIClient: 'versions': resp['versions'], }) - def get_profiles(self, game: GameConstants, version: int, idtype: str, ids: List[str]) -> List[Dict[str, Any]]: + def get_profiles(self, game: GameConstants, version: int, idtype: APIConstants, ids: List[str]) -> List[Dict[str, Any]]: # Allow remote servers to be disabled if not self.allow_scores: return [] @@ -205,7 +205,7 @@ class APIClient: f'{self.API_VERSION}/{servergame}/{serverversion}', { 'ids': ids, - 'type': idtype, + 'type': idtype.value, 'objects': ['profile'], }, ) @@ -218,7 +218,7 @@ class APIClient: self, game: GameConstants, version: int, - idtype: str, + idtype: APIConstants, ids: List[str], since: Optional[int]=None, until: Optional[int]=None, @@ -231,7 +231,7 @@ class APIClient: servergame, serverversion = self.__translate(game, version) data: Dict[str, Any] = { 'ids': ids, - 'type': idtype, + 'type': idtype.value, 'objects': ['records'], } if since is not None: @@ -247,7 +247,7 @@ class APIClient: # Couldn't talk to server, assume empty records return [] - def get_statistics(self, game: GameConstants, version: int, idtype: str, ids: List[str]) -> List[Dict[str, Any]]: + def get_statistics(self, game: GameConstants, version: int, idtype: APIConstants, ids: List[str]) -> List[Dict[str, Any]]: # Allow remote servers to be disabled if not self.allow_stats: return [] @@ -258,7 +258,7 @@ class APIClient: f'{self.API_VERSION}/{servergame}/{serverversion}', { 'ids': ids, - 'type': idtype, + 'type': idtype.value, 'objects': ['statistics'], }, )