### Deploy any Transformer-based model to Databricks MLflow

This notebook shows how to deploy a Transformer-based model to Databricks MLflow and get predictions through both batch inference and the online HTTP serving API.

In [0]:
%pip install sentence-transformers 
%pip install cloudpickle==2.0.0

Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.


In [0]:
import time
import os
import logging

os.environ["DATABRICKS_TOKEN"] = "<your_token_here>"

import cloudpickle
import requests
import pandas as pd
import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature
from mlflow.tracking.client import MlflowClient
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
from mlflow.utils.environment import _mlflow_conda_env

import sentence_transformers
from sentence_transformers import SentenceTransformer

#### Wrap the Transformer model in the MLflow API and log it to MLflow.

Before deploying the model, we need to make it compatible with MLflow.  
This is done by wrapping the model in a `mlflow.pyfunc.PythonModel` subclass, creating a conda environment (`conda_env`), and logging the model to MLflow.

In [0]:
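class TransformerWrapper(mlflow.pyfunc.PythonModel):
    def __init__(self):
        # The model is loaded here, so it is pickled together with the wrapper.
        self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L12-v2")

    def predict(self, context, model_input):
        # model_input is a DataFrame with a "text" column; returns one embedding per row.
        inputs = model_input["text"]
        return self.model.encode(inputs)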


with mlflow.start_run() as run:
    pyfunc_model = TransformerWrapper()

    input_schema = mlflow.types.Schema([
        mlflow.types.ColSpec(name="text", type=mlflow.types.DataType.string)
    ])
    output_schema = mlflow.types.Schema([
        mlflow.types.ColSpec(name="embedding", type=mlflow.types.DataType.float)
    ])
    signature = mlflow.models.signature.ModelSignature(
        inputs=input_schema,
        outputs=output_schema,
    )

    conda_env = _mlflow_conda_env(
        additional_conda_deps=None,
        additional_pip_deps=[
            f"cloudpickle=={cloudpickle.__version__}",
            "torch==1.13.1",
            f"sentence-transformers=={sentence_transformers.__version__}",
        ],
        additional_conda_channels=None,
    )

    model_name = "transformer-model"
    log_result = mlflow.pyfunc.log_model(
        model_name,
        python_model=pyfunc_model,
        signature=signature,
        conda_env=conda_env,
    )
    run_id = mlflow.active_run().info.run_id

Uploading artifacts:   0%|          | 0/5 [00:00<?, ?it/s]

#### Register the model in the MLflow Model Registry

In [0]:
model_details = mlflow.register_model(model_uri=log_result.model_uri, name=model_name)

Registered model 'transformer-model' already exists. Creating a new version of this model...
2023/11/14 11:30:31 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: transformer-model, version 28
Created version '28' of model 'transformer-model'.


In [0]:
def wait_until_ready(model_name, model_version):
    client = MlflowClient()
    for _ in range(10):
        model_version_details = client.get_model_version(
            name=model_name,
            version=model_version,
        )
        status = ModelVersionStatus.from_string(model_version_details.status)
        print("Model status: %s" % ModelVersionStatus.to_string(status))
        if status == ModelVersionStatus.READY:
            break
        time.sleep(1)

wait_until_ready(model_details.name, model_details.version)

Model status: READY
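
The polling loop above returns silently after ten attempts even if the model never becomes ready. A minimal sketch of a stricter variant follows. The helper name `wait_for_status`, its parameters, and the idea of passing the status lookup as a callable are assumptions made for illustration; they are not part of the original notebook or of MLflow.

```python
import time
from typing import Callable, Sequence


def wait_for_status(
    get_status: Callable[[], str],
    target: str = "READY",
    failed: Sequence[str] = ("FAILED_REGISTRATION",),
    timeout_s: float = 300.0,
    poll_interval_s: float = 1.0,
) -> None:
    """Poll get_status() until it returns `target`.

    Raises RuntimeError on a failure status and TimeoutError when the
    deadline passes, instead of returning silently.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status()
        if status == target:
            return
        if status in failed:
            raise RuntimeError(f"model version entered status {status!r}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"status is still {status!r} after {timeout_s} s")
        time.sleep(poll_interval_s)
```

In this notebook it could be called as `wait_for_status(lambda: MlflowClient().get_model_version(model_details.name, model_details.version).status)`, since the model version's `status` field is a plain string such as `'READY'`.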


#### Move model to "Production" stage. 

This step is optional: models outside the Production stage can be used as well. A specific version is addressed with a URI like `models:/transformer-model/<model_version>`, whereas the production model is addressed with `models:/transformer-model/production`.

Archiving the previous version is also optional. If there is no previous version, or you want to keep multiple versions deployed, feel free to comment out the archiving call.
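
The two URI forms can be built with a small helper. This is only a sketch: the function name `model_uri` is hypothetical and simply formats the `models:/<name>/<version-or-stage>` pattern described above.

```python
from typing import Optional, Union


def model_uri(
    name: str,
    version: Optional[Union[int, str]] = None,
    stage: Optional[str] = None,
) -> str:
    """Build a registry URI that addresses a model by version or by stage."""
    if (version is None) == (stage is None):
        raise ValueError("pass exactly one of `version` or `stage`")
    suffix = stage if version is None else str(version)
    return f"models:/{name}/{suffix}"


print(model_uri("transformer-model", version=28))
print(model_uri("transformer-model", stage="production"))
```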

In [0]:
client = MlflowClient()

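# Archive the previous version. This assumes version N-1 exists;
# comment this call out for the first version of a model.
client.transition_model_version_stage(
    name=model_details.name,
    version=str(int(model_details.version) - 1),
    stage="Archived",
)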

client.transition_model_version_stage(
    name=model_details.name,
    version=model_details.version,
    stage='Production',
)


<ModelVersion: aliases=[], creation_timestamp=1699961431238, current_stage='Production', description='', last_updated_timestamp=1699961443390, name='transformer-model', run_id='730460b252fc449081aed709316d075b', run_link='', source='dbfs:/databricks/mlflow-tracking/150882855439300/730460b252fc449081aed709316d075b/artifacts/transformer-model', status='READY', status_message='', tags={}, user_id='6738268205827713', version='28'>

In [0]:
model_version_details = client.get_model_version(
    name=model_details.name,
    version=model_details.version,
)
model_version_details.current_stage

'Production'
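
Archiving "version minus one" fails when no such version exists and misses the case where an older version is the one in Production. The sketch below archives whatever is currently in Production instead. The function name `promote_to_production` is hypothetical; it assumes `client` is an `MlflowClient` (or any object exposing `get_latest_versions` and `transition_model_version_stage` with the same arguments).

```python
def promote_to_production(client, name: str, version: str) -> None:
    """Archive the versions currently in Production, then promote `version`."""
    version = str(version)
    for current in client.get_latest_versions(name, stages=["Production"]):
        if current.version != version:
            client.transition_model_version_stage(
                name=name,
                version=current.version,
                stage="Archived",
            )
    client.transition_model_version_stage(
        name=name,
        version=version,
        stage="Production",
    )
```

Usage here would be `promote_to_production(MlflowClient(), model_details.name, model_details.version)`. `transition_model_version_stage` also accepts `archive_existing_versions=True`, which asks MLflow to archive the existing Production versions as part of the same call.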

#### Get production model

In [0]:
model = mlflow.pyfunc.load_model(f"models:/{model_name}/production")

Downloading artifacts:   0%|          | 0/5 [00:00<?, ?it/s]

 - torch (current: 2.0.1+cpu, required: torch==1.13.1)
To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.
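
The warning above reports that the logged environment requires `torch==1.13.1` while the runtime has `2.0.1+cpu`, because the `torch` pin was hard-coded when the model was logged. One way to avoid this drift, sketched below, is to derive every pin from the packages actually installed. The helper name `pinned_requirements` is hypothetical, and dropping the local version suffix (such as `+cpu`) is an assumption that may not suit every setup.

```python
from importlib import metadata
from typing import Iterable, List


def pinned_requirements(packages: Iterable[str]) -> List[str]:
    """Return 'name==version' pins for the packages installed in this environment."""
    pins = []
    for name in packages:
        try:
            version = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed here: skip it rather than guess a version.
            continue
        # Assumption: drop a local version suffix such as "+cpu",
        # which a package index may not be able to resolve.
        public_version = version.split("+", 1)[0]
        pins.append(f"{name}=={public_version}")
    return pins
```

The result could then be passed as `additional_pip_deps=pinned_requirements(["cloudpickle", "torch", "sentence-transformers"])` when building the conda environment.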


#### Compute prediction 

Run a quick prediction to make sure the model works as expected.

In [0]:
model.predict(pd.DataFrame({"text": ["Hello", "world"]}))

array([[-7.49069974e-02, -1.44330626e-02,  4.86498363e-02,
        -2.71380823e-02, -4.30882052e-02, -1.47763416e-01,
         6.88330382e-02, -1.37038296e-02, -5.35686500e-02,
         2.69752834e-02, -6.28339546e-03, -3.98834497e-02,
         7.65673770e-03, -3.78089435e-02, -1.17557272e-02,
        -3.46409529e-02,  1.29357114e-01, -2.52778418e-02,
        -1.52099021e-02,  7.30297179e-04, -8.06888193e-02,
         2.69378778e-02, -9.87357870e-02, -3.41076069e-02,
        -2.70293076e-02, -7.32003748e-02,  5.08588217e-02,
        -1.72563046e-02,  7.28218630e-02, -1.23288728e-01,
        -1.09712547e-02,  1.54055757e-02,  1.61925405e-01,
        -3.71392109e-02, -1.72857773e-02,  2.61619966e-02,
         7.30954856e-02, -4.66715954e-02, -4.06146385e-02,
         5.91708459e-02,  1.65670700e-02,  7.34634837e-03,
        -4.89753895e-02,  6.89893588e-02,  4.28648442e-02,
        -7.57247433e-02, -1.50943510e-02,  4.19088565e-02,
         5.49149886e-02, -5.63954599e-02, -1.02605252e-0

#### Call http inference endpoint 

This option is only available after enabling serving for the model and waiting for the deployment to finish.

In [0]:
response = requests.post(
    '<your_databricks_workspace_uri>/model/transformer-model/Production/invocations',
    headers={
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json',
    },
    json={"inputs": {"text": ["hello", "world"]}},
)
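
The pieces of the request above can be assembled in one place, which makes it easier to reuse the call for other models or stages. This is a sketch only: `build_invocation_request` is a hypothetical helper, and it merely reproduces the URL pattern, headers, and `inputs` payload used in the cell above.

```python
from typing import Any, Dict, Iterable


def build_invocation_request(
    workspace_uri: str,
    model_name: str,
    stage: str,
    texts: Iterable[str],
    token: str,
) -> Dict[str, Any]:
    """Return keyword arguments for an HTTP POST to the serving endpoint."""
    return {
        "url": f"{workspace_uri.rstrip('/')}/model/{model_name}/{stage}/invocations",
        "headers": {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        "json": {"inputs": {"text": list(texts)}},
    }
```

The result can be unpacked into the call, for example `requests.post(**build_invocation_request(...), timeout=30)`; passing `timeout` keeps the notebook from hanging if the endpoint is not up yet.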

In [0]:
response.status_code

200
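
Before using the response, it helps to check it explicitly rather than reading the status code by eye. The sketch below assumes the response body is a JSON object with a `predictions` list holding one embedding per input text, as the endpoint returns in this notebook; the helper name `extract_embeddings` is hypothetical.

```python
from typing import Any, List, Mapping


def extract_embeddings(
    status_code: int,
    body: Mapping[str, Any],
    expected_rows: int,
) -> List[List[float]]:
    """Validate a serving response and return its list of embeddings."""
    if status_code != 200:
        raise RuntimeError(f"serving endpoint returned HTTP {status_code}: {body!r}")
    if "predictions" not in body:
        raise KeyError("response body has no 'predictions' field")
    predictions = body["predictions"]
    if len(predictions) != expected_rows:
        raise ValueError(
            f"expected {expected_rows} embeddings, got {len(predictions)}"
        )
    return predictions
```

With the request above, this would be called as `extract_embeddings(response.status_code, response.json(), expected_rows=2)`.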

In [0]:
response.json()

{'predictions': [[-0.07490702718496323,
   -0.01443308126181364,
   0.04864983260631561,
   -0.027138084173202515,
   -0.043088190257549286,
   -0.14776340126991272,
   0.06883306056261063,
   -0.013703802600502968,
   -0.05356861650943756,
   0.026975270360708237,
   -0.006283464841544628,
   -0.039883457124233246,
   0.007656700909137726,
   -0.037808965891599655,
   -0.011755705811083317,
   -0.034640949219465256,
   0.12935706973075867,
   -0.025277869775891304,
   -0.015209854580461979,
   0.000730327214114368,
   -0.08068884164094925,
   0.02693787030875683,
   -0.0987357422709465,
   -0.034107550978660583,
   -0.027029285207390785,
   -0.07320033013820648,
   0.050858885049819946,
   -0.017256317660212517,
   0.07282191514968872,
   -0.12328874319791794,
   -0.010971251875162125,
   0.015405580401420593,
   0.16192536056041718,
   -0.03713918477296829,
   -0.017285751178860664,
   0.026161983609199524,
   0.07309548556804657,
   -0.04667157679796219,
   -0.040614672005176544,
  