Search code examples
pythonflaskdownloadplotly-dash

Button to download a file with Plotly Dash


I've built a Plotly Dash app that allows users to navigate through directories and download files. The files are .log files and are converted into .csv format before download.

The issue I'm facing is with the download functionality. When I first click the download button, it downloads the previously requested file (or first time it will download html page instead). Only when I click the download button for the second time, it downloads the correct file.

Here's the code, where file_path is the path to the log file to be converted and downloaded (note update_download_link callback is the one that does not work correctly):

import datetime
import os
from pathlib import Path
import dash_bootstrap_components as dbc
import pandas as pd
from dash import ALL, Dash, Input, Output, State, callback_context, html, dcc
from dash.exceptions import PreventUpdate
from icons import icons
import io
import time
import uuid


def serve_layout():
    app_layout = html.Div([
        html.Link(
            rel="stylesheet",
            href="https://cdnjs.cloudflare.com/ajax/libs/github-fork-ribbon-css/0.2.3/gh-fork-ribbon.min.css"),
        html.Br(), html.Br(),
        dbc.Row([
            dbc.Col(lg=1, sm=1, md=1),
            dbc.Col([
                dcc.Store(id='stored_cwd', data=os.getcwd()),
                html.H1('File Browser'),
                html.Hr(), html.Br(), html.Br(), html.Br(),
                html.H5(html.B(html.A("⬆️ Parent directory", href='#',
                                    id='parent_dir'))),
                html.H3([html.Code(os.getcwd(), id='cwd')]),
                        html.Br(), html.Br(),
                html.Div(id='cwd_files',
                        style={'height': 500, 'overflow': 'scroll'}),
            ], lg=10, sm=11, md=10)
        ]),
        dcc.Download(id="download"),
        html.A(
            "Download CSV",
            id="download_csv",
            className="btn btn-outline-secondary btn-sm",
            href="",
            download=""
        )
    ] + [html.Br() for _ in range(15)])

    return app_layout



@app.callback(
    Output('cwd', 'children'),
    Input('stored_cwd', 'data'),
    Input('parent_dir', 'n_clicks'),
    Input('cwd', 'children'),
    prevent_initial_call=True)
def get_parent_directory(stored_cwd, n_clicks, currentdir):
    triggered_id = callback_context.triggered_id
    if triggered_id == 'stored_cwd':
        return stored_cwd
    parent = Path(currentdir).parent.as_posix()
    return parent


@app.callback(
    Output('cwd_files', 'children'),
    Input('cwd', 'children'))
def list_cwd_files(cwd):
    path = Path(cwd)
    all_file_details = []
    if path.is_dir():
        files = sorted(os.listdir(path), key=str.lower)
        for i, file in enumerate(files):
            filepath = Path(file)
            full_path=os.path.join(cwd, filepath.as_posix())
            is_dir = Path(full_path).is_dir()
            link = html.A([
                html.Span(
                file, id={'type': 'listed_file', 'index': i},
                title=full_path,
                style={'fontWeight': 'bold', 'fontSize': 18} if is_dir else {}
            )], href='#')
            details = file_info(Path(full_path))
            details['filename'] = link
            if is_dir:
                details['extension'] = html.Img(
                    src=app.get_asset_url('icons/default_folder.svg'),
                    width=25, height=25)
            else:
                details['extension'] = icon_file(details['extension'][1:])
            all_file_details.append(details)

    df = pd.DataFrame(all_file_details)
    df = df.rename(columns={"extension": ''})
    table = dbc.Table.from_dataframe(df, striped=False, bordered=False,
                                    hover=True, size='sm')
    return html.Div(table)



@app.callback(
    Output('stored_cwd', 'data'),  # note the change here
    Input({'type': 'listed_file', 'index': ALL}, 'n_clicks'),
    State({'type': 'listed_file', 'index': ALL}, 'title'))
def store_clicked_file(n_clicks, title):
    if not n_clicks or set(n_clicks) == {None}:
        raise PreventUpdate
    ctx = callback_context
    index = ctx.triggered_id['index']
    file_path = title[index]
    return file_path  # always returning the file path now



@app.callback(
    Output('download_csv', 'href'),
    Output('download_csv', 'download'),
    Input('stored_cwd', 'data'),
    Input('download_csv', 'n_clicks'),
    prevent_initial_call=True
)
def update_download_link(file_path, n_clicks):
    # when there is no click, do not proceed
    if n_clicks is None:
        raise PreventUpdate
    
    if file_path.endswith(".log"):
        with open(file_path, "r") as f:
            log_content = f.read()
        csv_data = import__(log_content)

        temp_filename = save_file(csv_data)  
        
        # delay and then rename the temp file
        time.sleep(10)
        filename = f'{uuid.uuid1()}.csv'
        os.rename(os.path.join('downloads', temp_filename), os.path.join('downloads', filename))
        
        download_link = f'/download_csv?value={filename}'
        return download_link, filename
    else:
        return "#", ""  

I am using temp_filename because without it files bigger than 1mb does not getting downloaded at all for some reason.

helper functions:

def import__(file_content):
    # Convert the file content string to a StringIO object
    file_io = io.StringIO(file_content)

    # Split the file content into lines
    lines = file_content.splitlines()

    # Search for the header row number
    headerline = 0
    for n, line in enumerate(lines):
        if "Header" in line:
            headerline = n
            break

    # Go back to the start of the StringIO object before reading with pandas
    file_io.seek(0)

    # Read the content using pandas
    # Use the StringIO object (file_io) and set the 'skiprows' parameter
    data = pd.read_csv(file_io, sep='|', header = headerline) # header=None, skiprows=headerline)
    data = data.drop(data.index[-1])
    return data


def save_file(df):
    """Save DataFrame to a .csv file and return the file's name."""
    filename = f'{uuid.uuid1()}.csv'
    filepath = os.path.join('downloads', filename)  # assuming the script has permission to write to this location
    print(f"Saving to {filepath}")
    df.to_csv(filepath, index=False)
    return filename

also Flask API is:

@app.server.route('/download_csv')
def download_csv():
    """Provide the DataFrame for csv download."""
    value = request.args.get('value')
    file_path = os.path.join('downloads', value) # Compute the file path
    df = pd.read_csv(file_path) # Read the CSV data
    csv = df.to_csv(index=False, encoding='utf-8') # Convert DataFrame to CSV

    # Create a string response
    return Response(
        csv,
        mimetype="text/csv",
        headers={"Content-disposition": f"attachment; filename={value}"}
    )

Here are screenshots:

1

1

2

2

3

3

4

4

5

5

I'm not sure why the file ready for download is always one step behind. I put some sort of delay time.sleep(10) to ensure the file write operation is completed before the download begins, but it does not work.

Is there any way I can ensure that the correct file is downloaded on the first button click?


Solution

  • Here is a solution extending your app code, based exactly on the helpful insight from @EricLavault:

    1. Remove download_csv.n_clicks from the callback inputs
    2. Update the update_download_link callback to respond to changes in stored_cwd.data
    3. Move the .log to .csv conversion logic and file copying/renaming (i.e., saving to ./downloads) to the download_csv Flask endpoint

    Having implemented those changes, the download link updates as soon as a new file is selected, ensuring that the correct file is downloaded upon first button click. Demo below:

    """Main module for a Dash app providing file browsing and
       CSV download functionality.
    
    Attributes:
        app (Dash): The main Dash app instance.
    """
    import datetime
    import os
    import time
    import uuid
    
    from pathlib import Path
    
    import dash_bootstrap_components as dbc
    import io
    import pandas as pd
    
    from dash import ALL, Dash, Input, Output, State
    from dash import callback_context
    from dash import dcc, html
    from dash.exceptions import PreventUpdate
    
    from flask import Response
    from flask import request
    
    
    # Ensure the 'downloads' directory exists
    if not os.path.exists("downloads"):
        os.makedirs("downloads")
    
    
    app = Dash(external_stylesheets=[dbc.themes.BOOTSTRAP])
    
    
    def generate_log_files(num_files=5):
        """Generate mock log files for testing.
    
        Args:
            num_files (int, optional): Number of log files to generate.
                                       Defaults to 5.
        """
        for i in range(num_files):
            with open(f"test_{time.time_ns()}_{i}.log", "w") as f:
                f.write("Header1|Header2|Header3\n")
                for j in range(10):
                    f.write(
                        f"Value1_{i}_{j}|Value2_{uuid.uuid1()}|{uuid.uuid1()}\n"
                    )
    
    
    def serve_layout():
        """Generate the main layout for the Dash app.
    
        Returns:
            html.Div: The main layout Div containing the app components.
        """
        app_layout = html.Div(
            [
                html.Link(
                    rel="stylesheet",
                    href="https://cdnjs.cloudflare.com/ajax/libs/github-fork-ribbon-css/0.2.3/gh-fork-ribbon.min.css",
                ),
                html.Br(),
                html.Br(),
                dbc.Row(
                    [
                        dbc.Col(lg=1, sm=1, md=1),
                        dbc.Col(
                            [
                                dcc.Store(id="stored_cwd", data=os.getcwd()),
                                html.H1("File Browser"),
                                html.Hr(),
                                html.Br(),
                                html.H5(
                                    html.B(
                                        html.A(
                                            "⬆️ Parent directory",
                                            href="#",
                                            id="parent_dir",
                                        )
                                    )
                                ),
                                html.H3([html.Code(os.getcwd(), id="cwd")]),
                                html.Br(),
                                html.Br(),
                                html.Div(
                                    id="cwd_files",
                                    style={"height": 300, "overflow": "scroll"},
                                ),
                            ],
                            lg=10,
                            sm=11,
                            md=10,
                        ),
                    ]
                ),
                html.Br(),
                dcc.Download(id="download"),
                html.Div(
                    [
                        html.A(
                            "Download CSV",
                            id="download_csv",
                            className="btn btn-outline-secondary btn-sm",
                            href="",
                            download="",
                        )
                    ],
                    style={"textAlign": "center"},
                ),
            ]
            + [html.Br() for _ in range(5)],
        )
    
        return app_layout
    
    
    @app.callback(
        Output("cwd", "children"),
        Input("stored_cwd", "data"),
        Input("parent_dir", "n_clicks"),
        Input("cwd", "children"),
        prevent_initial_call=True,
    )
    def get_parent_directory(stored_cwd, n_clicks, currentdir):
        """Fetch the parent directory based on the current directory or the stored directory.
    
        Args:
            stored_cwd (str): The stored current working directory.
            n_clicks (int): Number of times the parent directory link was clicked.
            currentdir (str): The current directory being displayed.
    
        Returns:
            str: The parent directory path.
        """
        triggered_id = callback_context.triggered_id
        if triggered_id == "stored_cwd":
            return stored_cwd
        parent = Path(currentdir).parent.as_posix()
        return parent
    
    
    @app.callback(Output("cwd_files", "children"), Input("cwd", "children"))
    def list_cwd_files(cwd):
        """List the files in the provided directory and generate a table for display.
    
        Args:
            cwd (str): The current directory whose files are to be listed.
    
        Returns:
            html.Div: A Div containing the table of files in the directory.
        """
        path = Path(cwd)
        all_file_details = []
        if path.is_dir():
            files = sorted(os.listdir(path), key=str.lower)
            for i, file in enumerate(files):
                filepath = Path(file)
                full_path = os.path.join(cwd, filepath.as_posix())
                is_dir = Path(full_path).is_dir()
                link = html.A(
                    [
                        html.Span(
                            file,
                            id={"type": "listed_file", "index": i},
                            title=full_path,
                            style={"fontWeight": "bold", "fontSize": 18}
                            if is_dir
                            else {},
                        )
                    ],
                    href="#",
                )
                details = {"filename": link}
                all_file_details.append(details)
        df = pd.DataFrame(all_file_details)
        table = dbc.Table.from_dataframe(
            df, striped=False, bordered=False, hover=True, size="sm"
        )
        return html.Div(table)
    
    
    @app.callback(
        Output("stored_cwd", "data"),
        Input({"type": "listed_file", "index": ALL}, "n_clicks"),
        State({"type": "listed_file", "index": ALL}, "title"),
    )
    def store_clicked_file(n_clicks, title):
        """Store the path of the clicked file.
    
        Args:
            n_clicks (list[int]): List of click counts for each file.
            title (list[str]): List of file paths.
    
        Returns:
            str: Path of the clicked file.
    
        Raises:
            PreventUpdate: Raised if no file has been clicked.
        """
        if not n_clicks or set(n_clicks) == {None}:
            raise PreventUpdate
        ctx = callback_context
        index = ctx.triggered_id["index"]
        file_path = title[index]
        return file_path
    
    
    @app.callback(
        Output("download_csv", "href"),
        Output("download_csv", "download"),
        Input("stored_cwd", "data"),
        prevent_initial_call=True,
    )
    def update_download_link(file_path):
        """Update the download link for the provided file path.
    
        Args:
            file_path (str): Path of the selected file.
    
        Returns:
            tuple: A tuple containing the download link and filename.
        """
        if file_path.endswith(".log"):
            download_link = f"/download_csv?file_path={file_path}"
            filename = f"{uuid.uuid1()}.csv"
            return download_link, filename
        else:
            return "#", ""
    
    
    @app.server.route("/download_csv")
    def download_csv():
        """Provide the DataFrame for CSV download from a log file.
    
        Returns:
            Response: A Flask response containing the CSV data.
        """
        file_path = request.args.get("file_path")
    
        with open(file_path, "r") as f:
            log_content = f.read()
        csv_data = import__(log_content)
    
        filename = save_file(csv_data)
    
        df = pd.read_csv(os.path.join("downloads", filename))
        csv = df.to_csv(index=False, encoding="utf-8")
    
        return Response(
            csv,
            mimetype="text/csv",
            headers={"Content-disposition": f"attachment; filename={filename}"},
        )
    
    
    def import__(file_content):
        """Convert log file content into a Pandas DataFrame.
    
        Args:
            file_content (str): Content of the log file.
    
        Returns:
            pd.DataFrame: DataFrame containing the parsed content of the log file.
        """
        file_io = io.StringIO(file_content)
        lines = file_content.splitlines()
        headerline = 0
        for n, line in enumerate(lines):
            if "Header" in line:
                headerline = n
                break
        file_io.seek(0)
        data = pd.read_csv(file_io, sep="|", header=headerline)
        data = data.drop(data.index[-1])
        return data
    
    
    def save_file(df):
        """Save DataFrame to a .csv file and return the file's name.
    
        Args:
            df (pd.DataFrame): DataFrame to save as CSV.
    
        Returns:
            str: Filename where the DataFrame was saved.
        """
        filename = f"{uuid.uuid1()}.csv"
        filepath = os.path.join("downloads", filename)
        df.to_csv(filepath, index=False)
        return filename
    
    
    app.layout = serve_layout
    
    if __name__ == "__main__":
        generate_log_files(1)
        app.run_server(debug=True, dev_tools_hot_reload=True)
    
    

    produces, for example, app functionality like:

    Video screenshot recording of demo Dash app functionality showing file browsing and csv download