Donut model fine-tuned for US IRS tax document classification
This Donut model has been fine-tuned to classify US IRS tax documents. It can distinguish up to 28 different types of IRS documents, targeting the forms most commonly filed with tax returns, including:
- 1040 U.S. Individual Income Tax Return
- 1040-NR U.S. Nonresident Alien Income Tax Return
- 1040-NR SCHEDULE OI Other Information
- 1040 SCHEDULE 1 Additional Income and Adjustments to Income
- 1040 SCHEDULE 2 Additional Taxes
- 1040 SCHEDULE 3 Additional Credits and Payments
- 1040 SCHEDULE 8812 Credits for Qualifying Children and Other Dependents
- 1040 SCHEDULE A Itemized Deductions
- 1040 SCHEDULE B Interest and Ordinary Dividends
- 1040 SCHEDULE C Profit or Loss From Business
- 1040 SCHEDULE D Capital Gains and Losses
- 1040 SCHEDULE E Supplemental Income and Loss
- 1040 SCHEDULE SE Self-Employment Tax
- Form 1125-A Cost of Goods Sold
- Form 8949 Sales and Other Dispositions of Capital Assets
- Form 8959 Additional Medicare Tax
- Form 8960 Net Investment Income Tax — Individuals, Estates, and Trusts
- Form 8995 Qualified Business Income Deduction Simplified Computation
- Form 8995-A SCHEDULE A Specified Service Trades or Businesses
- Form W-2 Wage and Tax Statement
Model Details & Description
The base model is 'naver-clova-ix/donut-base-finetuned-rvlcdip', which was fine-tuned on a training set of more than 3,000 documents. The label2id mapping in config.json has been updated to cover every label the model can predict.
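The model outputs an integer class index, which is turned into a document name through the label mapping in config.json (in transformers, the same mappings are available as model.config.label2id and model.config.id2label). The plain-Python sketch below shows that inversion and lookup. It assumes a hypothetical config excerpt with placeholder labels LABEL_A to LABEL_C, and invert_label_map is a hypothetical helper; the real names and ids must be read from this model's own config.json.

```python
import json

# Hypothetical config.json excerpt with placeholder labels; the real
# label names and ids for this model come from its own config.json.
CONFIG_JSON = """
{
  "label2id": {"LABEL_A": 0, "LABEL_B": 1, "LABEL_C": 2}
}
"""


def invert_label_map(label2id: dict) -> dict:
    """Build an id -> label mapping, rejecting ids that are assigned twice."""
    id2label = {}
    for label, idx in label2id.items():
        idx = int(idx)  # ids may be stored as strings in JSON
        if idx in id2label:
            raise ValueError(
                f"id {idx} assigned to both {id2label[idx]!r} and {label!r}"
            )
        id2label[idx] = label
    return id2label


config = json.loads(CONFIG_JSON)
id2label = invert_label_map(config["label2id"])
print(id2label[1])
```

Rejecting duplicate ids catches a label2id that was edited by hand and accidentally maps two document types to the same class index.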
For inference, resize input images to 1920 px wide by 2560 px high, the size the model was fine-tuned with.
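PIL's Image.resize takes the target size as (width, height), whereas PyTorch image batches, including the pixel_values tensor printed in the sample code, use a channels-first (batch, channels, height, width) layout, so the two dimensions appear in opposite orders. The sketch below illustrates the conversion; expected_batch_shape is a hypothetical helper, and it assumes the processor keeps the 1920 x 2560 size for an RGB image.

```python
# Hypothetical helper: PIL sizes are (width, height), while PyTorch image
# batches are laid out as (batch, channels, height, width).
def expected_batch_shape(width: int, height: int, channels: int = 3, batch: int = 1) -> tuple:
    """Return the channels-first batch shape for images of the given PIL size."""
    return (batch, channels, height, width)


TARGET_SIZE = (1920, 2560)  # (width, height), as passed to Image.resize
print(expected_batch_shape(*TARGET_SIZE))  # (1, 3, 2560, 1920)
```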
Sample Code for Document Inference
```python
# load dependencies
import torch
from torch import nn
from PIL import Image
from transformers import DonutProcessor, DonutSwinModel, DonutSwinPreTrainedModel


# Donut Swin encoder with a dropout + linear classification head on the pooled output
class DonutForImageClassification(DonutSwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.swin = DonutSwinModel(config)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(self.swin.num_features, config.num_labels)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        outputs = self.swin(pixel_values)
        pooled_output = outputs[1]  # pooler_output: average-pooled encoder features
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


sModelName = 'hsarfraz/donut-irs-tax-docs-classifier'
processor = DonutProcessor.from_pretrained(sModelName)
model = DonutForImageClassification.from_pretrained(sModelName)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # disables dropout for inference

# load test image
sTestImagePath = 'replace this with document image path'
# open image
img = Image.open(sTestImagePath)
# resize image to width 1920 and height 2560, the size the model was fine-tuned with
img_new = img.resize((1920, 2560), Image.Resampling.LANCZOS)

# perform inference
predicted_label = ''
with torch.no_grad():
    pixel_values = processor(img_new.convert("RGB"), return_tensors="pt").pixel_values
    print(pixel_values.shape)
    pixel_values = pixel_values.to(device)
    outputs = model(pixel_values)
    # torch.max over the class dimension returns (max values, arg-max indices)
    _, predicted = torch.max(outputs, 1)
    pval = predicted.item()
    predicted_label = model.config.id2label[pval]
print('---------------------------------- ')
print('Document Image Classification: ', predicted_label)
```
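The sample keeps only the arg-max class. If a confidence score or a ranked shortlist is also useful, the logits can be passed through a softmax first. The plain-Python sketch below shows the idea; top_k_predictions is a hypothetical helper, and the logits and labels are made-up values for illustration. With PyTorch tensors, torch.softmax(outputs, dim=1) followed by torch.topk gives the same ranking.

```python
import math


def top_k_predictions(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs from raw logits."""
    # Numerically stable softmax: subtract the largest logit before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical logits and labels, for illustration only
example_logits = [1.0, 3.0, 0.5]
example_id2label = {0: "LABEL_A", 1: "LABEL_B", 2: "LABEL_C"}
for label, p in top_k_predictions(example_logits, example_id2label, k=2):
    print(f"{label}: {p:.3f}")
```

Subtracting the maximum logit leaves the softmax result unchanged but keeps math.exp from overflowing when logits are large.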