1- # -*- coding: utf-8 -*-
2- #
31# Copyright 2017 Gehirn Inc.
42#
53# Licensed under the Apache License, Version 2.0 (the "License");
1816import hmac
1917from typing import (
2018 Any ,
21- Dict ,
2219 Callable ,
2320 Optional ,
2421)
3532
3633
3734def std_hash_by_alg (alg : str ) -> Callable [[bytes ], object ]:
38- if alg .endswith (' S256' ):
35+ if alg .endswith (" S256" ):
3936 return hashlib .sha256
40- if alg .endswith (' S384' ):
37+ if alg .endswith (" S384" ):
4138 return hashlib .sha384
42- if alg .endswith (' S512' ):
39+ if alg .endswith (" S512" ):
4340 return hashlib .sha512
44- raise ValueError ('{ } is not supported' . format ( alg ) )
41+ raise ValueError (f" { alg } is not supported" )
4542
4643
4744class AbstractSigningAlgorithm :
4845
4946 def sign (self , message : bytes , key : Optional [AbstractJWKBase ]) -> bytes :
5047 raise NotImplementedError () # pragma: no cover
5148
52- def verify (self , message : bytes , key : Optional [AbstractJWKBase ],
53- signature : bytes ) -> bool :
49+ def verify (
50+ self ,
51+ message : bytes ,
52+ key : Optional [AbstractJWKBase ],
53+ signature : bytes ,
54+ ) -> bool :
5455 raise NotImplementedError () # pragma: no cover
5556
5657
5758class NoneAlgorithm (AbstractSigningAlgorithm ):
5859
5960 def sign (self , message : bytes , key : Optional [AbstractJWKBase ]) -> bytes :
60- return b''
61+ return b""
6162
62- def verify (self , message : bytes , key : Optional [AbstractJWKBase ],
63- signature : bytes ) -> bool :
64- return hmac .compare_digest (signature , b'' )
63+ def verify (
64+ self ,
65+ message : bytes ,
66+ key : Optional [AbstractJWKBase ],
67+ signature : bytes ,
68+ ) -> bool :
69+ return hmac .compare_digest (signature , b"" )
6570
6671
6772none = NoneAlgorithm ()
@@ -73,8 +78,8 @@ def __init__(self, hash_fun: Callable[[], Any]) -> None:
7378 self .hash_fun = hash_fun
7479
7580 def _check_key (self , key : Optional [AbstractJWKBase ]) -> AbstractJWKBase :
76- if not key or key .get_kty () != ' oct' :
77- raise InvalidKeyTypeError (' Octet key is required' )
81+ if not key or key .get_kty () != " oct" :
82+ raise InvalidKeyTypeError (" Octet key is required" )
7883 return key
7984
8085 def _sign (self , message : bytes , key : bytes ) -> bytes :
@@ -84,8 +89,12 @@ def sign(self, message: bytes, key: Optional[AbstractJWKBase]) -> bytes:
8489 key = self ._check_key (key )
8590 return key .sign (message , signer = self ._sign )
8691
87- def verify (self , message : bytes , key : Optional [AbstractJWKBase ],
88- signature : bytes ) -> bool :
92+ def verify (
93+ self ,
94+ message : bytes ,
95+ key : Optional [AbstractJWKBase ],
96+ signature : bytes ,
97+ ) -> bool :
8998 key = self ._check_key (key )
9099 return key .verify (message , signature , signer = self ._sign )
91100
@@ -105,17 +114,19 @@ def _check_key(
105114 key : Optional [AbstractJWKBase ],
106115 must_sign_key : bool = False ,
107116 ) -> AbstractJWKBase :
108- if not key or key .get_kty () != ' RSA' :
109- raise InvalidKeyTypeError (' RSA key is required' )
117+ if not key or key .get_kty () != " RSA" :
118+ raise InvalidKeyTypeError (" RSA key is required" )
110119 if must_sign_key and not key .is_sign_key ():
111120 raise InvalidKeyTypeError (
112- 'a RSA private key is required, but passed is RSA public key' )
121+ "a RSA private key is required, but passed is RSA public key"
122+ )
113123 return key
114124
115125 def sign (self , message : bytes , key : Optional [AbstractJWKBase ]) -> bytes :
116126 key = self ._check_key (key , must_sign_key = True )
117- return key .sign (message , hash_fun = self .hash_fun ,
118- padding = padding .PKCS1v15 ())
127+ return key .sign (
128+ message , hash_fun = self .hash_fun , padding = padding .PKCS1v15 ()
129+ )
119130
120131 def verify (
121132 self ,
@@ -124,8 +135,12 @@ def verify(
124135 signature : bytes ,
125136 ) -> bool :
126137 key = self ._check_key (key )
127- return key .verify (message , signature , hash_fun = self .hash_fun ,
128- padding = padding .PKCS1v15 ())
138+ return key .verify (
139+ message ,
140+ signature ,
141+ hash_fun = self .hash_fun ,
142+ padding = padding .PKCS1v15 (),
143+ )
129144
130145
131146RS256 = RSAAlgorithm (SHA256 )
@@ -142,11 +157,12 @@ def _check_key(
142157 key : Optional [AbstractJWKBase ],
143158 must_sign_key : bool = False ,
144159 ) -> AbstractJWKBase :
145- if not key or key .get_kty () != ' RSA' :
146- raise InvalidKeyTypeError (' RSA key is required' )
160+ if not key or key .get_kty () != " RSA" :
161+ raise InvalidKeyTypeError (" RSA key is required" )
147162 if must_sign_key and not key .is_sign_key ():
148163 raise InvalidKeyTypeError (
149- 'a RSA private key is required, but passed is RSA public key' )
164+ "a RSA private key is required, but passed is RSA public key"
165+ )
150166 return key
151167
152168 def sign (self , message : bytes , key : Optional [AbstractJWKBase ]) -> bytes :
@@ -164,7 +180,7 @@ def verify(
164180 self ,
165181 message : bytes ,
166182 key : Optional [AbstractJWKBase ],
167- signature : bytes
183+ signature : bytes ,
168184 ) -> bool :
169185 key = self ._check_key (key )
170186 return key .verify (
@@ -183,16 +199,16 @@ def verify(
183199PS512 = PSSRSAAlgorithm (SHA512 )
184200
185201
186- def supported_signing_algorithms () -> Dict [str , AbstractSigningAlgorithm ]:
202+ def supported_signing_algorithms () -> dict [str , AbstractSigningAlgorithm ]:
187203 # NOTE(yosida95): exclude vulnerable 'none' algorithm by default.
188204 return {
189- ' HS256' : HS256 ,
190- ' HS384' : HS384 ,
191- ' HS512' : HS512 ,
192- ' RS256' : RS256 ,
193- ' RS384' : RS384 ,
194- ' RS512' : RS512 ,
195- ' PS256' : PS256 ,
196- ' PS384' : PS384 ,
197- ' PS512' : PS512 ,
205+ " HS256" : HS256 ,
206+ " HS384" : HS384 ,
207+ " HS512" : HS512 ,
208+ " RS256" : RS256 ,
209+ " RS384" : RS384 ,
210+ " RS512" : RS512 ,
211+ " PS256" : PS256 ,
212+ " PS384" : PS384 ,
213+ " PS512" : PS512 ,
198214 }
0 commit comments