Ozgur Unlu committed on
Commit
62a1dcb
1 Parent(s): 8e69166

made some UI changes, still on CPU

Browse files
Files changed (3) hide show
  1. analyzer.py +140 -0
  2. app.py +3 -146
  3. interface.py +37 -0
analyzer.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import easyocr
2
+ import torch
3
+ import numpy as np
4
+ from compliance_rules import ComplianceRules
5
+
6
# Shared EasyOCR reader for English text; the GPU is used only when
# torch reports that CUDA is available, otherwise it runs on CPU.
reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

# Rule set queried by check_compliance (regions, prohibited terms,
# required disclaimers, risk scoring).
compliance_rules = ComplianceRules()
11
+
12
def extract_text_from_image(image):
    """Extract text from image using EasyOCR.

    The image is converted with ``np.array`` and handed to the module-level
    ``reader``. Element 1 of every entry returned by ``readtext`` is joined
    with single spaces into one string.

    Any exception raised along the way is printed and the fixed string
    ``"Error extracting text from image"`` is returned instead of raising.
    """
    try:
        detections = reader.readtext(np.array(image))
        fragments = (entry[1] for entry in detections)
        return " ".join(fragments)
    except Exception as exc:
        print(f"Error in text extraction: {exc}")
        return "Error extracting text from image"
20
+
21
def check_compliance(text):
    """Check text for compliance across all regions.

    For every region returned by ``compliance_rules.get_all_rules()`` the
    text is scanned for:

    * prohibited terms - the term itself or any of its listed variations,
      matched as case-insensitive substrings;
    * required disclaimers - satisfied when at least one of the
      disclaimer's accepted texts occurs (case-insensitive substring).

    Args:
        text: The ad copy to analyse (a ``str``).

    Returns:
        A dict with the keys

        * ``"compliant"`` - ``True`` unless a prohibited term was found
          (missing disclaimers only produce warnings);
        * ``"violations"`` - one dict per prohibited term hit, with
          ``region``, ``type``, ``term`` and ``severity``;
        * ``"warnings"`` - one dict per missing disclaimer, with
          ``region``, ``type``, ``disclaimer_type`` and ``severity``;
        * ``"channel_risks"`` - for each of ``email``, ``social`` and
          ``print`` an accumulated ``score`` and a list of ``details``.
    """
    rules = compliance_rules.get_all_rules()
    report = {
        "compliant": True,
        "violations": [],
        "warnings": [],
        "channel_risks": {
            "email": {"score": 0, "details": []},
            "social": {"score": 0, "details": []},
            "print": {"score": 0, "details": []},
        },
    }
    channel_risks = report["channel_risks"]

    # Lower-case the input once: every comparison below is a
    # case-insensitive substring test, so recomputing it per term,
    # variation and disclaimer text is pure loop-invariant work.
    lowered = text.lower()

    for region, region_rules in rules.items():
        # Check prohibited terms
        for term_info in region_rules["prohibited_terms"]:
            term = term_info["term"].lower()
            term_present = term in lowered or any(
                variation.lower() in lowered
                for variation in term_info["variations"]
            )
            if not term_present:
                continue

            report["compliant"] = False
            violation = f"{region}: Prohibited term '{term}' found"
            report["violations"].append({
                "region": region,
                "type": "prohibited_term",
                "term": term,
                "severity": term_info["severity"],
            })

            # Update channel risks.
            # NOTE(review): the score call takes no channel argument, so it
            # looks channel-independent; it is still made once per channel
            # because calculate_risk_score is not visible here - confirm it
            # is pure before hoisting.
            for channel, risk in channel_risks.items():
                risk_score = compliance_rules.calculate_risk_score(
                    [violation], [], region
                )
                risk["score"] += risk_score
                risk["details"].append(
                    f"Prohibited term '{term}' increases {channel} risk"
                )

        # Check required disclaimers
        for disclaimer in region_rules["required_disclaimers"]:
            # assumes disclaimer["text"] is an iterable of accepted
            # phrasings - TODO confirm against compliance_rules
            disclaimer_found = any(
                accepted.lower() in lowered
                for accepted in disclaimer["text"]
            )
            if disclaimer_found:
                continue

            warning = f"{region}: Missing {disclaimer['type']} disclaimer"
            report["warnings"].append({
                "region": region,
                "type": "missing_disclaimer",
                "disclaimer_type": disclaimer["type"],
                "severity": disclaimer["severity"],
            })

            # Update channel risks (same per-channel call pattern as above).
            for channel, risk in channel_risks.items():
                risk_score = compliance_rules.calculate_risk_score(
                    [], [warning], region
                )
                risk["score"] += risk_score
                risk["details"].append(
                    f"Missing {disclaimer['type']} disclaimer affects {channel} risk"
                )

    return report
81
+
82
def format_severity(severity):
    """Render a severity level as an HTML ``<span>``.

    The span carries the CSS class ``severity-<level>`` and shows the level
    in upper case.

    Args:
        severity: The severity level. Strings are used as-is; any other
            value (for example a numeric level) is converted with ``str()``
            first so that ``.upper()`` cannot raise ``AttributeError``.

    Returns:
        str: ``<span class="severity-{level}">{LEVEL}</span>``.
    """
    level = str(severity)
    return f'<span class="severity-{level}">{level.upper()}</span>'
84
+
85
def generate_html_report(compliance_report):
    """Generate formatted HTML report.

    Args:
        compliance_report: A dict with the keys ``"compliant"``,
            ``"violations"``, ``"warnings"`` and ``"channel_risks"``, in the
            shape produced by ``check_compliance``.

    Returns:
        str: One ``<div class="report-container">`` fragment containing the
        overall status, an optional violations section, an optional
        warnings section and the per-channel risk assessment. A channel
        score below 3 is rendered as low risk, below 6 as medium, otherwise
        high.
    """
    # Read the flag once; it drives the CSS class, the icon and the label.
    if compliance_report["compliant"]:
        status_class, status_icon, status_label = "compliant", "✅", "Compliant"
    else:
        status_class, status_icon, status_label = "non-compliant", "❌", "Non-Compliant"

    # Collect fragments and join once at the end instead of growing a string
    # with repeated "+=".
    parts = [
        '<div class="report-container">',
        f'<div class="status {status_class}">{status_icon} Overall Status: {status_label}</div>',
    ]

    # Violations
    violations = compliance_report["violations"]
    if violations:
        parts.append('<div class="section">')
        parts.append('<div class="section-title">🚫 Violations Found:</div>')
        for violation in violations:
            parts.append(
                f'<div class="item">• {violation["region"]}: {violation["type"]} - '
                f"'{violation['term']}' "
                f"(Severity: {format_severity(violation['severity'])})</div>"
            )
        parts.append('</div>')

    # Warnings
    warnings = compliance_report["warnings"]
    if warnings:
        parts.append('<div class="section">')
        parts.append('<div class="section-title">⚠️ Warnings:</div>')
        for warning in warnings:
            parts.append(
                f'<div class="item">• {warning["region"]}: {warning["disclaimer_type"]} '
                f"(Severity: {format_severity(warning['severity'])})</div>"
            )
        parts.append('</div>')

    # Channel Risk Assessment
    parts.append('<div class="section">')
    parts.append('<div class="section-title">📊 Channel Risk Assessment:</div>')

    for channel, risk_info in compliance_report["channel_risks"].items():
        score = risk_info["score"]
        if score < 3:
            risk_level = "low"
        elif score < 6:
            risk_level = "medium"
        else:
            risk_level = "high"

        parts.append(f'<div class="channel risk-{risk_level}">')
        parts.append(
            f'<strong>{channel.capitalize()}</strong>: '
            f'{risk_level.upper()} Risk (Score: {score})'
        )

        details = risk_info["details"]
        if details:
            parts.append('<div class="details">')
            parts.extend(f'<div>• {detail}</div>' for detail in details)
            parts.append('</div>')
        parts.append('</div>')

    # Closes the risk section and the outer report container.
    parts.append('</div></div>')
    return "".join(parts)
130
+
131
def analyze_ad_copy(image):
    """Main function to analyze ad copy.

    Runs the three stages in order: text extraction from the image, the
    compliance check on that text, and HTML rendering of the resulting
    report. Returns the HTML string.
    """
    extracted_text = extract_text_from_image(image)
    findings = check_compliance(extracted_text)
    return generate_html_report(findings)
app.py CHANGED
@@ -1,149 +1,6 @@
1
- import gradio as gr
2
- import easyocr
3
- import torch
4
- from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
5
- import numpy as np
6
- from PIL import Image
7
- import json
8
- from compliance_rules import ComplianceRules
9
 
10
- # Print GPU information for debugging
11
- print(f"Is CUDA available: {torch.cuda.is_available()}")
12
- if torch.cuda.is_available():
13
- print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
14
- else:
15
- print("Running on CPU")
16
-
17
- # Initialize OCR reader
18
- reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
19
-
20
- # Initialize compliance rules
21
- compliance_rules = ComplianceRules()
22
-
23
- def extract_text_from_image(image):
24
- """Extract text from image using EasyOCR"""
25
- try:
26
- result = reader.readtext(np.array(image))
27
- return " ".join([text[1] for text in result])
28
- except Exception as e:
29
- print(f"Error in text extraction: {str(e)}")
30
- return "Error extracting text from image"
31
-
32
- def check_compliance(text):
33
- """Check text for compliance across all regions"""
34
- rules = compliance_rules.get_all_rules()
35
- report = {
36
- "compliant": True,
37
- "violations": [],
38
- "warnings": [],
39
- "channel_risks": {
40
- "email": {"score": 0, "details": []},
41
- "social": {"score": 0, "details": []},
42
- "print": {"score": 0, "details": []}
43
- }
44
- }
45
-
46
- for region, region_rules in rules.items():
47
- # Check prohibited terms
48
- for term_info in region_rules["prohibited_terms"]:
49
- term = term_info["term"].lower()
50
- if term in text.lower() or any(var.lower() in text.lower() for var in term_info["variations"]):
51
- report["compliant"] = False
52
- violation = f"{region}: Prohibited term '{term}' found"
53
- report["violations"].append({
54
- "region": region,
55
- "type": "prohibited_term",
56
- "term": term,
57
- "severity": term_info["severity"]
58
- })
59
-
60
- # Update channel risks
61
- for channel in report["channel_risks"]:
62
- risk_score = compliance_rules.calculate_risk_score([violation], [], region)
63
- report["channel_risks"][channel]["score"] += risk_score
64
- report["channel_risks"][channel]["details"].append(
65
- f"Prohibited term '{term}' increases {channel} risk"
66
- )
67
-
68
- # Check required disclaimers
69
- for disclaimer in region_rules["required_disclaimers"]:
70
- disclaimer_found = any(
71
- disc_text.lower() in text.lower()
72
- for disc_text in disclaimer["text"]
73
- )
74
- if not disclaimer_found:
75
- warning = f"{region}: Missing {disclaimer['type']} disclaimer"
76
- report["warnings"].append({
77
- "region": region,
78
- "type": "missing_disclaimer",
79
- "disclaimer_type": disclaimer["type"],
80
- "severity": disclaimer["severity"]
81
- })
82
-
83
- # Update channel risks
84
- for channel in report["channel_risks"]:
85
- risk_score = compliance_rules.calculate_risk_score([], [warning], region)
86
- report["channel_risks"][channel]["score"] += risk_score
87
- report["channel_risks"][channel]["details"].append(
88
- f"Missing {disclaimer['type']} disclaimer affects {channel} risk"
89
- )
90
-
91
- return report
92
-
93
- def analyze_ad_copy(image):
94
- """Main function to analyze ad copy"""
95
- # Extract text from image
96
- text = extract_text_from_image(image)
97
-
98
- # Check compliance
99
- compliance_report = check_compliance(text)
100
-
101
- # Generate readable report
102
- report_text = "Compliance Analysis Report\n\n"
103
- report_text += f"Overall Status: {'✅ Compliant' if compliance_report['compliant'] else '❌ Non-Compliant'}\n\n"
104
-
105
- if compliance_report["violations"]:
106
- report_text += "Violations Found:\n"
107
- for violation in compliance_report["violations"]:
108
- report_text += f"• {violation['region']}: {violation['type']} - '{violation['term']}' (Severity: {violation['severity']})\n"
109
- report_text += "\n"
110
-
111
- if compliance_report["warnings"]:
112
- report_text += "Warnings:\n"
113
- for warning in compliance_report["warnings"]:
114
- report_text += f"• {warning['region']}: {warning['disclaimer_type']} (Severity: {warning['severity']})\n"
115
- report_text += "\n"
116
-
117
- report_text += "Channel Risk Assessment:\n"
118
- for channel, risk_info in compliance_report["channel_risks"].items():
119
- score = risk_info["score"]
120
- risk_level = "Low" if score < 3 else "Medium" if score < 6 else "High"
121
- report_text += f"• {channel.capitalize()}: {risk_level} Risk (Score: {score})\n"
122
- if risk_info["details"]:
123
- for detail in risk_info["details"]:
124
- report_text += f" - {detail}\n"
125
-
126
- return report_text
127
-
128
- # Create Gradio interface with updated parameters
129
- demo = gr.Interface(
130
- fn=analyze_ad_copy,
131
- inputs=gr.Image(
132
- type="numpy",
133
- label="Upload Marketing Material",
134
- height=300,
135
- width=400,
136
- ),
137
- outputs=gr.Textbox(
138
- label="Compliance Report",
139
- lines=10,
140
- max_lines=20
141
- ),
142
- title="Marketing Campaign Compliance Checker",
143
- description="Upload marketing material to check compliance with US (SEC), UK (FCA), and EU financial regulations.",
144
- theme=gr.themes.Default()
145
- )
146
-
147
- # Launch the app
148
  if __name__ == "__main__":
 
149
  demo.launch()
 
1
+ from interface import create_interface
 
 
 
 
 
 
 
2
 
3
+ # Create and launch the interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
if __name__ == "__main__":
    # Build the Gradio app only when run as a script, then start serving it.
    demo = create_interface()
    demo.launch()
interface.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from analyzer import analyze_ad_copy
3
+
4
def create_interface():
    """Build the Gradio interface for the compliance checker.

    The interface takes one uploaded image (passed to ``analyze_ad_copy``
    as a numpy array) and shows the HTML report that function returns. The
    custom CSS styles the class names used in that report.

    Returns:
        The configured ``gr.Interface``; the caller is expected to launch it.
    """
    # Styles for the report markup: status banner, sections, severity
    # colours, per-channel risk backgrounds and detail lists.
    report_css = """
    .report-container { font-family: 'Arial', sans-serif; padding: 20px; }
    .status { font-size: 1.2em; margin-bottom: 20px; padding: 10px; border-radius: 5px; }
    .compliant { background-color: #e7f5e7; color: #0d5f0d; }
    .non-compliant { background-color: #fce8e8; color: #c41e3a; }
    .section { margin: 15px 0; }
    .section-title { font-weight: bold; color: #2c3e50; margin: 10px 0; }
    .item { margin: 5px 0 5px 20px; }
    .severity-high { color: #c41e3a; }
    .severity-medium { color: #f39c12; }
    .severity-low { color: #27ae60; }
    .risk-high { background-color: #fce8e8; }
    .risk-medium { background-color: #fff3cd; }
    .risk-low { background-color: #e7f5e7; }
    .channel { margin: 10px 0; padding: 10px; border-radius: 5px; }
    .details { margin-left: 20px; font-size: 0.9em; color: #555; }
    """

    image_input = gr.Image(
        type="numpy",
        label="Upload Marketing Material",
        height=300,
        width=400,
    )
    report_output = gr.HTML(label="Compliance Report")

    return gr.Interface(
        fn=analyze_ad_copy,
        inputs=image_input,
        outputs=report_output,
        title="Marketing Campaign Compliance Checker",
        description="Upload marketing material to check compliance with US (SEC), UK (FCA), and EU financial regulations.",
        theme=gr.themes.Default(),
        css=report_css,
    )