Spaces:
Sleeping
Sleeping
from cryptography.hazmat.primitives import serialization, hashes | |
from cryptography import x509 | |
from cryptography.hazmat.primitives.asymmetric import ec, rsa | |
from cryptography.hazmat.backends import default_backend | |
from cryptography.x509.oid import NameOID | |
from typing import List, Tuple | |
def gen_pvt(key_type: str, key_size: int = None, key_curve: str = None) -> bytes: | |
if key_type.lower() == "ec": | |
if key_curve == 'SECP256R1' or key_curve == 'ec256': | |
key = ec.generate_private_key(ec.SECP256R1(), default_backend()) | |
elif key_curve == 'SECP384R1' or key_curve == 'ec384': | |
key = ec.generate_private_key(ec.SECP384R1(), default_backend()) | |
else: | |
key = ec.generate_private_key(ec.SECP256R1(), default_backend()) | |
private_key = key.private_bytes( | |
encoding=serialization.Encoding.PEM, | |
format=serialization.PrivateFormat.TraditionalOpenSSL, | |
encryption_algorithm=serialization.NoEncryption() | |
) | |
elif key_type.lower() == "rsa": | |
if key_size not in [2048, 4096]: | |
key_size = 4096 | |
key = rsa.generate_private_key( | |
public_exponent=65537, | |
key_size=key_size, | |
backend=default_backend() | |
) | |
private_key = key.private_bytes( | |
encoding=serialization.Encoding.PEM, | |
format=serialization.PrivateFormat.TraditionalOpenSSL, | |
encryption_algorithm=serialization.NoEncryption() | |
) | |
else: | |
raise ValueError("Unsupported key type or parameters") | |
return private_key | |
def gen_csr(private_key: bytes, domains: List[str], email: str, common_name: str = None, country: str = None, | |
state: str = None, locality: str = None, organization: str = None, organization_unit: str = None) -> bytes: | |
ssl_domains = [x509.DNSName(domain.strip()) for domain in domains] | |
private_key_obj = serialization.load_pem_private_key(private_key, password=None, backend=default_backend()) | |
try: | |
if email.split("@")[1] in ["demo.com", "example.com"] or email.count("@") > 1 or email.count(".") < 1 or email is None: | |
print("Invalid email address") | |
email = f"admin@{domains[0]}" | |
except Exception as e: | |
print(f"Error in email address: {e}") | |
email = f"admin@{domains[0]}" | |
country: str = country or "IN" | |
state: str = state or "Maharashtra" | |
locality: str = locality or "Mumbai" | |
organization_unit: str = organization_unit or "IT Department" | |
common_name: str = common_name or domains[0] | |
organization: str = organization or common_name.split(".")[0] | |
subject = x509.Name([ | |
x509.NameAttribute(NameOID.COUNTRY_NAME, country), | |
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state), | |
x509.NameAttribute(NameOID.LOCALITY_NAME, locality), | |
x509.NameAttribute(NameOID.EMAIL_ADDRESS, email), | |
x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization), | |
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, organization_unit), | |
x509.NameAttribute(NameOID.COMMON_NAME, common_name), | |
]) | |
builder = x509.CertificateSigningRequestBuilder() | |
builder = builder.subject_name(subject) | |
builder = builder.add_extension( | |
x509.SubjectAlternativeName(ssl_domains), | |
critical=False, | |
) | |
csr = builder.sign(private_key_obj, hashes.SHA256(), default_backend()) | |
return csr.public_bytes(serialization.Encoding.PEM) | |
def gen_pvt_csr(domains: List[str], key_type: str, key_size: int = None, key_curve: str = None, email: str = None, | |
common_name: str = None, country: str = None, state: str = None, locality: str = None, | |
organization: str = None, organization_unit: str = None) -> Tuple[bytes, bytes]: | |
if key_type.lower() == "rsa": | |
private_key = gen_pvt(key_type, key_size) | |
else: | |
private_key = gen_pvt(key_type, key_curve) | |
csr = gen_csr(private_key, domains, email, common_name, country, state, locality, organization, organization_unit) | |
return private_key, csr | |