File size: 4,156 Bytes
d93884d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import List, Optional, Tuple

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.x509.oid import NameOID

def gen_pvt(key_type: str, key_size: Optional[int] = None, key_curve: Optional[str] = None) -> bytes:
    """Generate a new private key and return it as unencrypted PEM bytes.

    Args:
        key_type: "ec" or "rsa" (case-insensitive).
        key_size: RSA modulus size in bits. Only 2048 and 4096 are honoured;
            any other value (including None) falls back to 4096. Ignored for EC.
        key_curve: EC curve name, matched case-insensitively:
            "SECP256R1"/"ec256" or "SECP384R1"/"ec384". Any other value
            (including None) falls back to SECP256R1. Ignored for RSA.

    Returns:
        The private key serialized as PEM in the traditional OpenSSL format,
        without encryption.

    Raises:
        ValueError: if key_type is not "ec" or "rsa" (including None).
    """
    kind = key_type.lower() if isinstance(key_type, str) else ""

    if kind == "ec":
        # Match curve names case-insensitively, consistent with key_type, so
        # "secp384r1" / "EC384" do not silently fall back to P-256.
        curve_name = key_curve.lower() if isinstance(key_curve, str) else ""
        if curve_name in ("secp384r1", "ec384"):
            curve = ec.SECP384R1()
        else:
            # "secp256r1"/"ec256" and any unknown or missing name use P-256.
            curve = ec.SECP256R1()
        key = ec.generate_private_key(curve, default_backend())
    elif kind == "rsa":
        if key_size not in (2048, 4096):
            key_size = 4096
        key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend(),
        )
    else:
        raise ValueError("Unsupported key type or parameters")

    # Both key kinds share the same serialization: PEM, traditional OpenSSL
    # layout, no passphrase.
    return key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.TraditionalOpenSSL,
        encryption_algorithm=serialization.NoEncryption(),
    )

def _is_usable_email(email: Optional[str]) -> bool:
    """Return True if *email* looks like a real, non-placeholder mailbox address.

    Requires a string with exactly one "@", a non-empty local part, a dot in
    the domain part, and a domain other than the placeholders "demo.com" /
    "example.com" (compared case-insensitively).
    """
    if not isinstance(email, str) or email.count("@") != 1:
        return False
    local, _, domain = email.strip().partition("@")
    return bool(local) and "." in domain and domain.lower() not in ("demo.com", "example.com")


def _strip_wildcard(name: str) -> str:
    """Drop a leading "*." wildcard label so *name* can be used as a plain domain."""
    return name[2:] if name.startswith("*.") else name


def gen_csr(private_key: bytes, domains: List[str], email: Optional[str], common_name: Optional[str] = None,
            country: Optional[str] = None, state: Optional[str] = None, locality: Optional[str] = None,
            organization: Optional[str] = None, organization_unit: Optional[str] = None) -> bytes:
    """Build a PEM-encoded certificate signing request for *domains*.

    Args:
        private_key: Unencrypted PEM private key (e.g. the output of gen_pvt)
            used to sign the CSR.
        domains: DNS names placed in the subjectAltName extension. Surrounding
            whitespace is stripped and blank entries are dropped; the first
            remaining name is the default common name.
        email: Subject e-mail address. If missing, malformed, or on a
            placeholder domain ("demo.com", "example.com"), it is replaced with
            "admin@<first domain>" (any leading "*." removed).
        common_name: Subject CN; defaults to the first domain.
        country: Defaults to "IN".
        state: Defaults to "Maharashtra".
        locality: Defaults to "Mumbai".
        organization: Defaults to the first label of the common name (any
            leading "*." removed).
        organization_unit: Defaults to "IT Department".

    Returns:
        The CSR, signed with SHA-256, serialized as PEM bytes.

    Raises:
        ValueError: if *domains* contains no non-blank name.
        Errors raised by ``cryptography`` (e.g. for an unparsable key) propagate.
    """
    san_names = [name.strip() for name in (domains or []) if name and name.strip()]
    if not san_names:
        raise ValueError("At least one non-empty domain is required")
    primary = san_names[0]

    signing_key = serialization.load_pem_private_key(private_key, password=None, backend=default_backend())

    if _is_usable_email(email):
        email = email.strip()
    else:
        # Only warn about addresses the caller actually supplied; None simply
        # means "use the default".
        if email is not None:
            print("Invalid email address")
        email = f"admin@{_strip_wildcard(primary)}"

    common_name = common_name or primary
    organization = organization or _strip_wildcard(common_name).split(".")[0]

    subject = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, country or "IN"),
        x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, state or "Maharashtra"),
        x509.NameAttribute(NameOID.LOCALITY_NAME, locality or "Mumbai"),
        x509.NameAttribute(NameOID.EMAIL_ADDRESS, email),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
        x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, organization_unit or "IT Department"),
        x509.NameAttribute(NameOID.COMMON_NAME, common_name),
    ])

    csr = (
        x509.CertificateSigningRequestBuilder()
        .subject_name(subject)
        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(name) for name in san_names]),
            critical=False,
        )
        .sign(signing_key, hashes.SHA256(), default_backend())
    )
    return csr.public_bytes(serialization.Encoding.PEM)

def gen_pvt_csr(domains: List[str], key_type: str, key_size: Optional[int] = None, key_curve: Optional[str] = None,
                email: Optional[str] = None, common_name: Optional[str] = None, country: Optional[str] = None,
                state: Optional[str] = None, locality: Optional[str] = None, organization: Optional[str] = None,
                organization_unit: Optional[str] = None) -> Tuple[bytes, bytes]:
    """Generate a private key and a matching CSR in one call.

    Args:
        domains: DNS names for the CSR (forwarded to gen_csr).
        key_type: "ec" or "rsa" (forwarded to gen_pvt).
        key_size: RSA key size, used only when key_type is "rsa".
        key_curve: EC curve name, used only when key_type is "ec".
        email, common_name, country, state, locality, organization,
        organization_unit: Subject fields forwarded to gen_csr, which applies
            its own defaults/validation.

    Returns:
        A ``(private_key_pem, csr_pem)`` tuple of bytes.

    Raises:
        ValueError: propagated from gen_pvt for an unsupported key_type.
    """
    # Pass the sizing options by keyword: key_curve is gen_pvt's THIRD
    # parameter, so passing it as the second positional argument would bind it
    # to key_size and the requested curve would be silently ignored. gen_pvt
    # ignores whichever option does not apply to key_type.
    private_key = gen_pvt(key_type, key_size=key_size, key_curve=key_curve)
    csr = gen_csr(
        private_key,
        domains,
        email,
        common_name=common_name,
        country=country,
        state=state,
        locality=locality,
        organization=organization,
        organization_unit=organization_unit,
    )
    return private_key, csr