# TensorFlow Lite Model Informative Classes Generation

## Software installation

Our objective is to construct Python classes that accurately represent the data structures defined within TensorFlow Lite Flatbuffer files. Achieving this requires the following dependencies:
 - The `flatc` compiler: Responsible for generating ***Model Informative Classes*** from the text schema describing the model format.
 - The text schema: Defines the data structure of the model format.
 - The Flatbuffer Python library: Serves as the runtime dependency for the generated accessor classes.

Notably, the `flatc` compiler is not installed with the Python library and, in this setup, is compiled from source. To ensure compatibility, the compiler version must match the version of the Flatbuffer Python library installed on the system; a mismatch can produce generated code that fails at runtime because of API inconsistencies. For this work, we use version 1.12.0 of the Flatbuffer Python library, so we build `flatc` from the GitHub snapshot tagged v1.12.0. Keeping both components on the same version makes the setup functional and reproducible.

Remark: The latest versions of the Flatbuffer Python library and the `flatc` compiler can be used as well, but consistency between the two versions must be ensured to maintain functionality.
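Because a version mismatch usually only surfaces when the generated code runs, it can help to compare the two versions up front. The sketch below uses hypothetical helper names. It assumes `flatc --version` prints a string such as `flatc version 1.12.0`, and it pads short version strings with zeros because an installed Python package may report a shortened version such as `1.12`.

```python
import re
import subprocess
from importlib import metadata


def parse_version(text):
    """Return the first dotted version in `text` as a (major, minor, patch) tuple, or None."""
    match = re.search(r"\d+(?:\.\d+)*", text)
    if match is None:
        return None
    parts = [int(p) for p in match.group(0).split(".")][:3]
    parts += [0] * (3 - len(parts))  # "1.12" -> (1, 12, 0)
    return tuple(parts)


def versions_match(python_lib_version, flatc_version_output):
    """True when the Python library and flatc report the same version triple."""
    lib = parse_version(python_lib_version)
    return lib is not None and lib == parse_version(flatc_version_output)


def check_installed_versions():
    """Compare the installed flatbuffers package with the flatc binary on PATH."""
    try:
        lib_version = metadata.version("flatbuffers")
    except metadata.PackageNotFoundError:
        return "flatbuffers Python package not installed"
    try:
        result = subprocess.run(["flatc", "--version"],
                                capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return "flatc not found on PATH"
    if versions_match(lib_version, result.stdout):
        return f"OK: library {lib_version} matches {result.stdout.strip()}"
    return f"MISMATCH: library {lib_version} vs {result.stdout.strip()}"


print(check_installed_versions())
```

Running this after both installation steps below should report a match; a `MISMATCH` line indicates that the generated classes may not work with the installed runtime.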

### Install Flatbuffer Python Library

In [None]:
!pip install flatbuffers==1.12.0
import flatbuffers



### Build the 'flatc' Compiler

The flatc compiler is required to generate ***Model Informative Classes*** for reading and writing serialized files. As prebuilt binaries are not readily available, the source code for the appropriate version is obtained and compiled directly. This process may take a few minutes.

After successfully building the flatc binary, it should be moved to the `/usr/local/bin` directory to ensure it is readily accessible as a system command.

In [None]:
# Build and install the Flatbuffer compiler.
%cd /content/
!rm -rf flatbuffers*
!curl -L "https://github.com/google/flatbuffers/archive/v1.12.0.zip" -o flatbuffers.zip
!unzip -q flatbuffers.zip
!mv flatbuffers-1.12.0 flatbuffers
%cd flatbuffers
!cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release
!make -j 8
!cp flatc /usr/local/bin/

/content
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   124  100   124    0     0    821      0 --:--:-- --:--:-- --:--:--   815
100 1463k    0 1463k    0     0  2631k      0 --:--:-- --:--:-- --:--:-- 2631k
/content/flatbuffers
-- The C compiler identification is GNU 7.5.0
-- The CXX compiler identification is GNU 7.5.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for strtof_l
-- Looking for strtof_l - found

### Fetch On-device Model Schema

The TFLite model schema, which defines the data structures of a model file, is part of the TensorFlow source code and can be viewed in [this repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/schema/schema.fbs). To stay compatible with models produced by recent TensorFlow releases, we retrieve the latest version of the schema directly from the GitHub repository using a shallow clone.

In [None]:
%cd /content/
!rm -rf tensorflow
!git clone --depth 1 https://github.com/tensorflow/tensorflow

/content
Cloning into 'tensorflow'...
remote: Enumerating objects: 24788, done.
remote: Counting objects: 100% (24788/24788), done.
remote: Compressing objects: 100% (17969/17969), done.
remote: Total 24788 (delta 9056), reused 11254 (delta 6292), pack-reused 0
Receiving objects: 100% (24788/24788), 59.68 MiB | 12.75 MiB/s, done.
Resolving deltas: 100% (9056/9056), done.
Checking out files: 100% (24939/24939), done.
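Cloning the whole TensorFlow repository downloads tens of megabytes only to use a single file. As a lighter alternative, the schema can be fetched on its own from GitHub's raw-file host. This is a sketch under the assumption that the schema stays at the path used above; the file has moved between TensorFlow releases, so the path may need adjusting for older tags. The function names are hypothetical.

```python
import urllib.request
from pathlib import Path

# Assumption: location of the schema inside the TensorFlow repository.
SCHEMA_PATH = "tensorflow/compiler/mlir/lite/schema/schema.fbs"


def raw_schema_url(ref="master"):
    """Build the raw.githubusercontent.com URL of the schema at a branch or tag."""
    return f"https://raw.githubusercontent.com/tensorflow/tensorflow/{ref}/{SCHEMA_PATH}"


def download_schema(dest="schema.fbs", ref="master"):
    """Download only schema.fbs instead of cloning the full repository."""
    with urllib.request.urlopen(raw_schema_url(ref)) as response:
        Path(dest).write_bytes(response.read())
    return Path(dest)
```

After `download_schema("/content/schema.fbs")`, the classes would be generated with `flatc --python --gen-object-api /content/schema.fbs`. Passing a release tag as `ref` instead of `master` pins the schema and makes the result reproducible.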


### Generate Model Informative Classes

The `flatc` compiler reads the schema and generates Model Informative Classes for reading and writing data in serialized Flatbuffer files. Because the schema declares the `tflite` namespace, the generated modules are written to a `tflite` folder. The `--gen-object-api` flag additionally produces mutable object classes, such as `ModelT` in `Model.py`, whose members give access to, and allow modification of, the data structures described by the schema.

In [None]:
!flatc --python --gen-object-api tensorflow/tensorflow/compiler/mlir/lite/schema/schema.fbs
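Since `flatc` reports most problems only as console output, a quick check that the expected modules exist can catch a failed generation before the imports below break. This is a minimal sketch: the function name is hypothetical, and the expected file names are assumed from table names in the TFLite schema (`Model`, `Buffer`, `Tensor`, `SubGraph`, `Operator`).

```python
from pathlib import Path

# Assumption: one generated module per schema table; extend as needed.
EXPECTED_MODULES = ("Model.py", "Buffer.py", "Tensor.py", "SubGraph.py", "Operator.py")


def missing_generated_modules(folder="tflite", expected=EXPECTED_MODULES):
    """Return the expected module files that flatc did not write into `folder`."""
    folder = Path(folder)
    return [name for name in expected if not (folder / name).is_file()]
```

Calling `missing_generated_modules("/content/tflite")` should return an empty list when generation succeeded.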

### TFLite Model Reading and Writing

The provided wrapper functions illustrate how to load data from a file, convert it into a `ModelT` Python object for modification, and save the updated object to a new file.

In [None]:
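import sys
import flatbuffers

# flatc wrote the generated modules into the `tflite` folder under /content.
# Add that folder to the import path so `import Model` finds the generated Model.py.
sys.path.append('/content/tflite')
import Model

def load_model_from_file(model_filename):
  # Read the serialized FlatBuffer bytes from disk.
  with open(model_filename, "rb") as file:
    buffer_data = file.read()
  # Access the root table, then unpack it into a mutable ModelT object.
  model_obj = Model.Model.GetRootAsModel(buffer_data, 0)
  model = Model.ModelT.InitFromObj(model_obj)
  return model

def save_model_to_file(model, model_filename):
  # Serialize the ModelT object; the builder grows beyond 1024 bytes as needed.
  builder = flatbuffers.Builder(1024)
  model_offset = model.Pack(builder)
  # 'TFL3' is the file identifier declared in the TFLite schema.
  builder.Finish(model_offset, file_identifier=b'TFL3')
  model_data = builder.Output()
  with open(model_filename, 'wb') as out_file:
    out_file.write(model_data)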
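`save_model_to_file` passes `file_identifier=b'TFL3'` to `Finish`. In the FlatBuffers binary layout, the file identifier occupies bytes 4 to 7, directly after the 4-byte offset of the root table, so a saved file can be sanity-checked without the generated classes. The helper names below are hypothetical.

```python
TFLITE_FILE_IDENTIFIER = b"TFL3"


def has_tflite_identifier(data):
    """Return True if FlatBuffer bytes carry the TFLite file identifier."""
    # Bytes 0-3: little-endian offset of the root table.
    # Bytes 4-7: the optional 4-character file identifier written by Finish().
    return len(data) >= 8 and bytes(data[4:8]) == TFLITE_FILE_IDENTIFIER


def file_has_tflite_identifier(path):
    """Check the identifier of a file on disk by reading only its first 8 bytes."""
    with open(path, "rb") as f:
        return has_tflite_identifier(f.read(8))
```

For example, `file_has_tflite_identifier('MobileNetV2_cifar10_modified.tflite')` should return `True` for a file written by `save_model_to_file`.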

In [None]:
import numpy as np

# Load the pre-trained MobileNetV2 TFLite model as a ModelT object.
model = load_model_from_file('MobileNetV2_cifar10.tflite')

# Iterate over all buffer objects containing weights in the model.
for buffer in model.buffers:
  # Skip buffers that are empty or small, as these are unlikely to hold significant weights.
  if buffer.data is not None and len(buffer.data) > 1024:
    # Read the weights from the model and cast them to 32-bit floats, as this is
    # the known data type for all weights in this specific model. In a real-world DL app,
    # the data type should be validated using the tensor metadata to ensure correctness.
    original_weights = np.frombuffer(buffer.data, dtype=np.float32)

    # Here is where Model Reweighting can be applied: round each weight
    # to the nearest multiple of 0.02.
    munged_weights = np.round(original_weights * (1/0.02)) * 0.02

    # Write the modified weights back into the model. Buffer.data is declared
    # as [ubyte] in the schema, so store the raw float32 bytes as a uint8 array;
    # assigning the float array directly would make Pack() record the number of
    # floats instead of the number of bytes as the vector length.
    buffer.data = munged_weights.astype(np.float32).view(np.uint8)

# Save the modified model to a new TensorFlow Lite file.
save_model_to_file(model, 'MobileNetV2_cifar10_modified.tflite')
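The comment in the loop notes that a real application should check tensor metadata rather than assume every large buffer holds float32 weights. A minimal sketch of that check follows. It assumes the object-API attribute names `subgraphs`, `tensors`, `type` and `buffer` on `ModelT` and its children, and it hard-codes FLOAT32 as value 0 of the schema's `TensorType` enum (the generated `TensorType` module exposes the same constant). Both function names are hypothetical.

```python
import numpy as np

# Assumption: FLOAT32 is value 0 of TensorType in the TFLite schema.
TENSOR_TYPE_FLOAT32 = 0


def float32_buffer_indices(model):
    """Return the indices into model.buffers that back FLOAT32 tensors.

    `model` is expected to look like a ModelT: each model.subgraphs[i].tensors[j]
    exposes `.type` and `.buffer` (an index into model.buffers).
    """
    indices = set()
    for subgraph in model.subgraphs or []:
        for tensor in subgraph.tensors or []:
            if tensor.type == TENSOR_TYPE_FLOAT32:
                indices.add(tensor.buffer)
    return indices


def snap_buffer_to_grid(raw_bytes, step=0.02):
    """Decode float32 bytes, round each value to a multiple of `step`, re-encode as uint8."""
    weights = np.frombuffer(raw_bytes, dtype=np.float32)
    snapped = (np.round(weights / step) * step).astype(np.float32)
    # Buffer.data is a [ubyte] vector, so hand back the raw bytes.
    return snapped.view(np.uint8)
```

With these helpers, the loop could iterate over `float32_buffer_indices(model)` and apply `snap_buffer_to_grid` to `model.buffers[index].data` only when that data is non-empty. The emptiness check matters because tensors without constant data, such as activations, typically point at an empty buffer.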