from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding


def _oaep_padding() -> padding.OAEP:
    """
    Build the OAEP padding used by both rsa_encrypt and rsa_decrypt.

    Encryption and decryption must use identical OAEP parameters (MGF1 with
    SHA-256, SHA-256 digest, no label), otherwise decryption fails. Keeping
    them in a single place prevents the two call sites from drifting apart.

    :return: a new OAEP padding instance.
    """
    return padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None,
    )


def rsa_create_key_pair(
    key_size: int = 2048,
    public_exponent: int = 65537,
) -> tuple[bytes, bytes]:
    """
    Create a pair of private and public RSA keys.

    The private key is serialized as unencrypted DER in PKCS#8 format. The
    public key is serialized as DER in PKCS#1 format, i.e. a bare
    RSAPublicKey structure rather than a SubjectPublicKeyInfo wrapper.

    :param key_size: length of the RSA modulus in bits (default 2048).
    :param public_exponent: RSA public exponent (default 65537).
    :return: a ``(private_key_data, public_key_data)`` tuple of DER bytes.
    """
    # create a private key
    private_key = rsa.generate_private_key(
        public_exponent=public_exponent,
        key_size=key_size,
    )

    # serialize the private key (no password protection)
    private_key_data = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # derive and serialize the matching public key
    public_key_data = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.PKCS1,
    )

    return private_key_data, public_key_data


def rsa_encrypt(data: bytes, public_key_data: bytes) -> bytes:
    """
    Encrypt data with RSA-OAEP (SHA-256) using a DER-encoded public key.

    RSA-OAEP can only encrypt a single short message: with SHA-256 the
    plaintext may be at most ``key_size_in_bytes - 66`` bytes (190 bytes for
    a 2048-bit key). Longer input is rejected by the cryptography library.

    :param data: the data to encrypt.
    :param public_key_data: the DER-encoded RSA public key to encrypt with.
    :return: the encrypted data.
    :raises ValueError: if the key data cannot be parsed or is not an RSA
        public key.
    """
    # load the public key
    public_key = serialization.load_der_public_key(public_key_data)

    # verify that the key is an RSA public key
    if not isinstance(public_key, rsa.RSAPublicKey):
        raise ValueError("Could not load the public key.")

    # encrypt the data with the key
    return public_key.encrypt(data, _oaep_padding())


def rsa_decrypt(encrypted_data: bytes, private_key_data: bytes) -> bytes:
    """
    Decrypt RSA-OAEP (SHA-256) data with a DER-encoded private key.

    The private key is loaded without a password, matching the unencrypted
    PKCS#8 output of rsa_create_key_pair.

    :param encrypted_data: the data to decrypt.
    :param private_key_data: the DER-encoded RSA private key.
    :return: the decrypted data.
    :raises ValueError: if the key data cannot be parsed or is not an RSA
        private key.
    """
    # load the private key (None: no password)
    private_key = serialization.load_der_private_key(private_key_data, None)

    # verify that the key is an RSA private key
    if not isinstance(private_key, rsa.RSAPrivateKey):
        raise ValueError("Could not load the private key.")

    # decrypt the data with the same padding used for encryption
    return private_key.decrypt(encrypted_data, _oaep_padding())