project-gatekeeper / genPVTCSR.py
raannakasturi's picture
Upload 12 files
d93884d verified
from cryptography.hazmat.primitives import serialization, hashes
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.backends import default_backend
from cryptography.x509.oid import NameOID
from typing import List, Tuple
def gen_pvt(key_type: str, key_size: int = None, key_curve: str = None) -> bytes:
if key_type.lower() == "ec":
if key_curve == 'SECP256R1' or key_curve == 'ec256':
key = ec.generate_private_key(ec.SECP256R1(), default_backend())
elif key_curve == 'SECP384R1' or key_curve == 'ec384':
key = ec.generate_private_key(ec.SECP384R1(), default_backend())
else:
key = ec.generate_private_key(ec.SECP256R1(), default_backend())
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
elif key_type.lower() == "rsa":
if key_size not in [2048, 4096]:
key_size = 4096
key = rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
backend=default_backend()
)
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
else:
raise ValueError("Unsupported key type or parameters")
return private_key
def gen_csr(private_key: bytes, domains: List[str], email: str, common_name: str = None, country: str = None,
state: str = None, locality: str = None, organization: str = None, organization_unit: str = None) -> bytes:
ssl_domains = [x509.DNSName(domain.strip()) for domain in domains]
private_key_obj = serialization.load_pem_private_key(private_key, password=None, backend=default_backend())
try:
if email.split("@")[1] in ["demo.com", "example.com"] or email.count("@") > 1 or email.count(".") < 1 or email is None:
print("Invalid email address")
email = f"admin@{domains[0]}"
except Exception as e:
print(f"Error in email address: {e}")
email = f"admin@{domains[0]}"
country: str = country or "IN"
state: str = state or "Maharashtra"
locality: str = locality or "Mumbai"
organization_unit: str = organization_unit or "IT Department"
common_name: str = common_name or domains[0]
organization: str = organization or common_name.split(".")[0]
subject = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, country),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state),
x509.NameAttribute(NameOID.LOCALITY_NAME, locality),
x509.NameAttribute(NameOID.EMAIL_ADDRESS, email),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, organization_unit),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
])
builder = x509.CertificateSigningRequestBuilder()
builder = builder.subject_name(subject)
builder = builder.add_extension(
x509.SubjectAlternativeName(ssl_domains),
critical=False,
)
csr = builder.sign(private_key_obj, hashes.SHA256(), default_backend())
return csr.public_bytes(serialization.Encoding.PEM)
def gen_pvt_csr(domains: List[str], key_type: str, key_size: int = None, key_curve: str = None, email: str = None,
common_name: str = None, country: str = None, state: str = None, locality: str = None,
organization: str = None, organization_unit: str = None) -> Tuple[bytes, bytes]:
if key_type.lower() == "rsa":
private_key = gen_pvt(key_type, key_size)
else:
private_key = gen_pvt(key_type, key_curve)
csr = gen_csr(private_key, domains, email, common_name, country, state, locality, organization, organization_unit)
return private_key, csr