Search code examples
Tags: validation, flask, wtforms

WTForms and Flask - Validate file size in another field


I have a structure like this in my Flask app:

class UploadForm(FlaskForm):
    """Form for uploading a .txt URL list together with project metadata.

    ``scan_now`` asks for the project to be sent to the faster priority queue
    rather than the regular one.
    """

    # Account / contact details.
    username = StringField(
        'Username',
        validators=[DataRequired(), Length(min=2, max=20)],
    )
    email = StringField(
        'Email address',
        validators=[DataRequired(), Email()],
    )

    # Project settings.
    domain_name = StringField('Project name', validators=[DataRequired()])
    dir_level = IntegerField('Folder level', validators=[DataRequired()])

    # The upload itself: mandatory, and restricted to the .txt extension.
    file = FileField(
        'Upload textfile',
        validators=[FileRequired(), FileAllowed(['txt'])],
    )

    scan_now = BooleanField('Scan website now')
    submit = SubmitField('Submit')

The scan_now field submits the project into a different, faster queue and it is used for priority reports.

Given that scan_now is more computationally expensive, I would like to limit its usage to .txt files that have no more than 500 rows, and raise an error if a user checks the box but has uploaded a file with more than 500 rows.

How can I do that?

Thank you!

EDIT:

Since you asked for my manager views file, here is the logic behind it:

# Largest upload (counted in lines, including any header line) for which a
# priority "Scan website now" request is accepted.
MAX_SCAN_NOW_ROWS = 500


@manager.route("/manage_url", methods=["GET", "POST"])
@auth.login_required
def manage_url():
    """Accept a URL-list upload and render the URL management page.

    On a valid submission the uploaded file is saved to ``uploads_dir`` under
    its ``secure_filename`` name, an entry is appended to
    ``ROOT_DIR/config.json`` and, if "Scan website now" was ticked, a scan is
    started in a background daemon thread; the request then redirects back to
    this page. A priority scan is refused (reported as an error on the
    ``scan_now`` field, nothing is saved) when the upload has more than
    ``MAX_SCAN_NOW_ROWS`` lines.

    Otherwise the page is rendered with every uploaded file, the registered
    domain found in each file, and the contents of config.json.
    """
    form = UploadForm()
    config_path = os.path.join(ROOT_DIR, 'config.json')

    # validate_on_submit() already checks that the request is a submission.
    if form.validate_on_submit():
        file = request.files['file']
        scan_now = form.scan_now.data
        if scan_now and _count_upload_lines(file) > MAX_SCAN_NOW_ROWS:
            # Fall through to the normal page render with the error attached.
            # NOTE(review): assumes manage_url.html displays field errors.
            form.scan_now.errors.append(
                f"'Scan website now' is only available for files with at "
                f"most {MAX_SCAN_NOW_ROWS} rows."
            )
        else:
            # Use the sanitised name everywhere (disk, config key, scanner)
            # so all three refer to the same file.
            filename = secure_filename(file.filename)
            file.save(os.path.join(uploads_dir, filename))

            # Record the upload in the config file.
            config_entry = {
                filename: {
                    "date": datetime.today().strftime('%d/%m/%Y'),
                    "email": request.form['email'],
                    "user": request.form['username'],
                    "domain_name": request.form['domain_name'].replace(" ", "_"),
                    # The form field is `dir_level`; the stored key stays
                    # `livello_dir` so existing config readers keep working.
                    "livello_dir": request.form['dir_level'],
                }
            }
            _append_config_entry(config_path, config_entry)

            if scan_now:
                def on_the_fly_scan():
                    Executor().start_process(file=filename)

                Thread(target=on_the_fly_scan, daemon=True).start()
            return redirect(url_for('manager.manage_url'))

    uploads_path = os.path.join(app.instance_path, 'uploads')
    try:
        files = os.listdir(uploads_path)
    except OSError:
        # No uploads folder yet: show an empty list instead of crashing.
        files = []
    domini = [_detected_domain(os.path.join(uploads_path, name)) for name in files]

    data = _load_config(config_path)
    content_package = list(zip(domini, files))

    return render_template("manage_url.html",
                           content_package=content_package,
                           data=data,
                           config_entries=[list(d) for d in data],
                           form=form,
                           contact="",
                           operator_name="",
                           livello_dir="")


def _count_upload_lines(upload):
    """Count the lines of an uploaded file without consuming it.

    The stream is rewound to where it started so that ``upload.save()`` still
    writes the complete content afterwards.
    """
    stream = upload.stream
    start = stream.tell()
    count = sum(1 for _ in stream)
    stream.seek(start)
    return count


def _load_config(config_path):
    """Return the JSON list stored at *config_path*, or ``[]`` if it is missing."""
    try:
        # The file is written as UTF-8 (ensure_ascii=False), so read it as such.
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return []


def _append_config_entry(config_path, config_entry):
    """Append *config_entry* to the list in *config_path*, creating the file if needed."""
    config = _load_config(config_path)
    config.append(config_entry)
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=4)


def _detected_domain(path):
    """Return the registered domain of the first cell of row 1 of a tab-separated file.

    Row 0 is skipped (presumably a header). Returns ``''`` when the file has
    no such row, instead of raising IndexError.
    """
    with open(path, 'r', encoding='unicode_escape') as csv_file:
        rows = list(csv.reader(csv_file, delimiter='\t'))
    if len(rows) < 2 or not rows[1]:
        return ''
    return tldextract.extract(rows[1][0]).registered_domain

Solution

  • You can validate the number of rows manually before saving the file. Add these lines before `file.save(os.path.join(uploads_dir, secure_filename(filename)))`:

    if scan_now:
        # The upload has not been written to disk yet at this point, so count
        # its lines straight from the uploaded stream instead of opening
        # `filename` (which would not exist, or would resolve against the CWD).
        row_count = sum(1 for _ in file.stream)
        file.stream.seek(0)  # rewind, otherwise file.save() writes an empty file
        if row_count > 500:
            # Prefer reporting this back to the user (e.g. as a form error);
            # an unhandled exception surfaces as an HTTP 500.
            raise ValueError('too many rows')
    

    And here's the `count_rows()` helper:

    def count_rows(filename, has_header=False):
        """Return the number of lines in a text file.

        Args:
            filename: Path of the file to count, or an already-open file-like
                object (for example an uploaded file's ``stream``). A stream is
                read from its current position and then put back at that
                position, so it can still be saved afterwards.
            has_header: If True, do not count the first line (e.g. a CSV
                header). The result never goes below 0.

        Returns:
            The number of lines, minus one when ``has_header`` is set.
        """
        if hasattr(filename, 'read'):
            start = filename.tell()
            total = sum(1 for _ in filename)
            filename.seek(start)
        else:
            # Only line breaks matter for counting, so decode leniently rather
            # than failing on files in an unexpected encoding.
            with open(filename, 'r', encoding='utf-8', errors='replace') as f:
                total = sum(1 for _ in f)
        if has_header:
            total = max(total - 1, 0)
        return total