#!/usr/bin/env python

"""
For a given list of topics and directory names, save the image and
point cloud messages for that topic to the specified directories. The
created file names will be in the form: <timestamp>.<extension>, with
the timestamp being the double-precision number of seconds since epoch
read from message header. Can specify a list of timestamps for which
to extract the data.
"""

import argparse, os, re, sys, rosbag, cv2, sensor_msgs, cv_bridge
from sensor_msgs import point_cloud2
import numpy as np

def read_timestamps(filename):
    """Read the timestamps listed in a text file.

    Each line is searched for the first number of the form
    <digits>.<digits>. Lines having no such number are ignored.

    Args:
        filename: Path of the text file to read.

    Returns:
        A dict mapping each timestamp (as a float) to the full text of
        the line it was found on. If the same timestamp shows up on
        several lines, the last such line is kept.
    """

    # A raw string keeps the backslashes intact for the regex engine.
    # Searching for the number itself, rather than matching through to a
    # trailing newline, also accepts a last line that is not
    # newline-terminated.
    pattern = re.compile(r"\d+\.\d+")

    timestamp_map = {}
    with open(filename, 'r') as f:
        for line in f:
            m = pattern.search(line)
            if m:
                timestamp_map[float(m.group(0))] = line

    return timestamp_map

def sanity_checks(args, topic_list, dir_list):
    """Validate the command-line inputs.

    Prints a message and exits the program with status 1 on the first
    problem found. Returns None if all checks pass.

    Args:
        args: The parsed options. The fields read are bag,
            timestamp_tol, and approx_timestamp.
        topic_list: The list of topics to extract.
        dir_list: The list of output directories, one per topic.
    """
    # An option that was not given may show up as None rather than "",
    # so test for emptiness rather than for equality with "".
    if not args.bag:
        print("Must specify the input bag.")
        sys.exit(1)
    if len(topic_list) != len(dir_list):
        print("There must exists as many topics as output directories.")
        sys.exit(1)
    if len(topic_list) == 0:
        print("Must specify at least one topic and output directory.")
        sys.exit(1)
    if args.timestamp_tol is not None and args.timestamp_tol <= 0:
        print("The timestamp tolerance must be positive.")
        sys.exit(1)
    if args.approx_timestamp and args.timestamp_tol is None:
        print("When saving with an approximate timestamp, must specify the tolerance.")
        sys.exit(1)

def find_closest_messages_within_tol(args, topic_list, timestamp_map):
    """
    Given a list of timestamps and a tolerance, find, for each input
    topic, the closest message in the bag to every timestamp in the
    list, within given tolerance.

    Args:
        args: The parsed options. The fields read are bag (the path of
            the bag to open) and timestamp_tol (the tolerance, in
            seconds).
        topic_list: The topics to examine.
        timestamp_map: A dict whose keys are the requested timestamps.
            It must not be empty.

    Returns:
        A tuple (closest_stamps_set, exact_to_approx). For each topic,
        closest_stamps_set[topic] is the set of header timestamps of the
        messages that were selected, and exact_to_approx[topic] maps
        each such header timestamp to the requested timestamp it was
        selected for.
    """
    print("Doing a first pass through the bag to find the closest messages to "
          "desired ones, with tolerance.")

    tol = args.timestamp_tol

    # The timestamps close to which need to find data in the bag. Must
    # be sorted.
    req_timestamps = sorted(timestamp_map)
    min_list_time = req_timestamps[0]
    max_list_time = req_timestamps[-1]

    # For each topic, the index in req_timestamps from which to start
    # looking. It only moves forward, on the assumption that the messages
    # of a topic travel forward in time too.
    req_index = {}
    # For each topic, the best candidate found so far for every requested
    # timestamp. Starts with a value taken to be smaller than any
    # timestamp.
    closest_stamps = {}
    for topic in topic_list:
        req_index[topic] = 0
        closest_stamps[topic] = [-1000.0] * len(req_timestamps)

    with rosbag.Bag(args.bag, "r") as bag:

        for topic, msg, _ in bag.read_messages(topic_list):

            # Read the header timestamp. Skip only the messages which do
            # not have one, rather than hiding every possible error.
            try:
                stamp = msg.header.stamp.to_sec()
            except AttributeError:
                continue

            if stamp < min_list_time - tol:
                continue # too early
            if stamp > max_list_time + tol:
                # Past the desired times. This ends the pass for all the
                # topics, so it relies on the bag giving the messages in
                # increasing order of timestamp.
                break

            # See if this timestamp is closer to any of the req_timestamps
            # than the existing candidates. Note that we stop as soon as
            # we can to keep the overall complexity linear and not quadratic.
            candidates = closest_stamps[topic]
            found_first_fit = False
            for it in range(req_index[topic], len(req_timestamps)):
                if req_timestamps[it] > stamp + tol:
                    break # no point in continuing

                curr_diff = abs(stamp - req_timestamps[it])
                prev_diff = abs(candidates[it] - req_timestamps[it])
                if curr_diff < tol and curr_diff < prev_diff:
                    candidates[it] = stamp

                    if not found_first_fit:
                        found_first_fit = True
                        # Found the first good fit index. There won't be good
                        # fits going forward for indices to the left of this.
                        req_index[topic] = it

    # Put in a set and print some stats
    closest_stamps_set = {}
    exact_to_approx = {}
    for topic in topic_list:
        print("Topic is " + str(topic))
        closest_stamps_set[topic] = set()
        exact_to_approx[topic] = {}
        for req_stamp, found_stamp in zip(req_timestamps, closest_stamps[topic]):
            diff = abs(found_stamp - req_stamp)
            msg = ""
            if diff > tol:
                # Nothing was found within tolerance for this timestamp
                diff = "-1"
                msg = " (failed)"
            else:
                closest_stamps_set[topic].add(found_stamp)
                exact_to_approx[topic][found_stamp] = req_stamp
            print("For timestamp " + str(req_stamp) + \
                  ", closest found message is within " + str(diff) + " seconds" + msg)

    return (closest_stamps_set, exact_to_approx)
            
# The command-line options. The description shown in the help is the
# docstring of this module.
parser = argparse.ArgumentParser(description = __doc__,
                                 formatter_class = argparse.ArgumentDefaultsHelpFormatter)
# Default to an empty string, as for the other string options, so that
# a missing --bag is a string which the emptiness check in
# sanity_checks() can catch, rather than None.
parser.add_argument("--bag", dest = "bag", default = "", help = "Input bag.")
parser.add_argument("--topics", dest = "topics", default = "",  
                    help = "A list of topics, in quotes, having image, compressed " +
                    "image, or point cloud (depth) data.")
parser.add_argument("--dirs", dest = "dirs", default = "",
                    help = "A list of directories, in quotes, one for each topic, " +
                    "in which to save the data for that topic.")

parser.add_argument("--timestamp_list", dest = "timestamp_list", default = "",  
                    help = "Extract data for the timestamps in this list. " +
                    "If not set, extract all data. Each line in " +
                    "this file must contain a number of the form " +
                    "<digits>.<digits> (the timestamp), and perhaps other " +
                    "text as well, or it will be ignored. So, a " +
                    "filename containing a timestamp as part of " +
                    "its name will be accepted.")

# No default here: None means that no tolerance was requested.
parser.add_argument("--timestamp_tol", dest = "timestamp_tol", type = float,
                    default = None,  
                    help = "If set, extract the data for each of the given topics " +
                    "whose timestamps are closest to the " +
                    "ones in the input list, within this " +
                    "tolerance, in seconds. This should be kept small. It is assumed " +
                    "the bag stores the data for each topic in increasing value of " +
                    "timestamp.")

parser.add_argument('--approx_timestamp', dest='approx_timestamp', action='store_true',
                    help = "If using --timestamp_tol, change the timestamp of " +
                    "the data being saved (which becomes part of the output filename) " +
                    "to be the closest timestamp in the input list.")

# Parse and validate the command-line options
args = parser.parse_args()

# The topics and directories are given as whitespace-separated lists
topic_list = args.topics.split()
dir_list = args.dirs.split()
sanity_checks(args, topic_list, dir_list)

# The map of all required timestamps, if specified
timestamp_map = {}
if args.timestamp_list != "":
    timestamp_map = read_timestamps(args.timestamp_list)
    if len(timestamp_map) == 0:
        print("No timestamps read from " + args.timestamp_list)
        sys.exit(1)

# The tolerance is None when not set. Test for that explicitly, as
# comparing None with a number raises a TypeError in Python 3.
if args.timestamp_tol is not None and len(timestamp_map) == 0:
    print("When using a timestamp tolerance, an input list must be provided.")
    sys.exit(1)
    
# TODO(oalexan1): Make this a function.
# Associate each topic with the directory having the same position in
# the list of directories. A topic may be listed only once.
topic_to_dir = {}
for topic, path in zip(topic_list, dir_list):
    if topic in topic_to_dir:
        print("Duplicate topic: " + topic)
        sys.exit(1)
    topic_to_dir[topic] = path

# TODO(oalexan1): Make this a function.
# Make the output directories. It is fine if a directory exists already.
for path in dir_list:
    try:
        os.makedirs(path)
    except OSError:
        # The failure is a problem only if the directory is not there
        if not os.path.isdir(path):
            raise Exception("Could not make directory: " + path)

# When a tolerance is given, do a first pass through the bag to find
# which messages are closest to the requested timestamps. The tolerance
# is None when not set. Test for that explicitly, as comparing None
# with a number raises a TypeError in Python 3.
if args.timestamp_tol is not None:
    (closest_stamps_set, exact_to_approx)\
                         = find_closest_messages_within_tol(args, topic_list, timestamp_map)

if args.approx_timestamp:
    print("Replacing on saving the actual timestamps with the closest from the list.")
    
# TODO(oalexan1): Make this a function
# Read the bag
print("Reading: " + args.bag)
print("Writing the data to: " + args.dirs)

# Note that from here on the name cv_bridge refers to this CvBridge
# instance and no longer to the module.
cv_bridge = cv_bridge.CvBridge()

# The timestamp format is consistent with what rig_calibrator uses
fmt_str = "{:10.7f}"

# Write the data. Go through the messages for the desired topics, keep
# the ones having the desired timestamps, and save each one to a file
# named after its timestamp, in the directory for its topic.
with rosbag.Bag(args.bag, "r") as bag:

    info = bag.get_type_and_topic_info()

    for topic, msg, _t in bag.read_messages(topic_list):

        # Read the header timestamp. Skip only the messages which do not
        # have one, rather than hiding every possible error.
        try:
            stamp = msg.header.stamp.to_sec()
        except AttributeError:
            continue

        # Note that we search the bag exhaustively. We do not assume
        # timestamps are in increasing order of time. Sometimes
        # that assumption can be violated.
        if len(timestamp_map) > 0:
            # When extracting only a subset of the timestamps
            if args.timestamp_tol is None:
                # Exact timestamp
                if stamp not in timestamp_map:
                    continue
            else:
                # Nearby timestamp, as found in the first pass
                if stamp not in closest_stamps_set[topic]:
                    continue

        if args.approx_timestamp:
            # Replacing with the approximate one
            approx_stamp = exact_to_approx[topic][stamp]
            print("Saving data for timestamp " + str(stamp) + " as if it were for " + \
                  "timestamp " + str(approx_stamp))
            stamp = approx_stamp

        # Messages of any other type than the ones below are ignored
        msg_type = info.topics[topic].msg_type
        file_prefix = topic_to_dir[topic] + "/" + fmt_str.format(stamp)

        if msg_type == "sensor_msgs/Image":
            # Write an image
            filename = file_prefix + ".jpg"
            try:
                cv_image = cv_bridge.imgmsg_to_cv2(msg, desired_encoding = 'passthrough')
                if len(cv_image.shape) == 2:
                    # Try again, as mono. There should be a better way.
                    cv_image = cv_bridge.imgmsg_to_cv2(msg, desired_encoding = 'mono8')
                print("Writing: " + filename)
                # A failure to write can be reported via the return value
                if not cv2.imwrite(filename, cv_image):
                    print("Failed to write: " + filename)
            except Exception as e:
                print("Failed to write: " + filename)
                print("Error is " + str(e))

        elif msg_type == "sensor_msgs/CompressedImage":
            # Write a compressed image. It is assumed to be in color.
            filename = file_prefix + ".jpg"
            try:
                cv_image = cv_bridge.compressed_imgmsg_to_cv2(msg, desired_encoding
                                                              = 'passthrough')
                print("Writing: " + filename)
                if not cv2.imwrite(filename, cv_image):
                    print("Failed to write: " + filename)
            except Exception as e:
                print("Failed to write: " + filename)
                print("Error is " + str(e))

        elif msg_type == "sensor_msgs/PointCloud2":
            # Write a point cloud in the format at:
            # https://stereopipeline.readthedocs.io/en/latest/tools/rig_calibrator.html
            # (section on point cloud format).
            # TODO(oalexan1): Test this!
            filename = file_prefix + ".pc"
            try:
                print("Writing: " + filename)

                # Read all the points before creating the file, so that a
                # failure here does not leave behind a file having only
                # the header.
                cloud_points = list(
                    point_cloud2.read_points(msg,
                                             skip_nans = False, # read all
                                             field_names = ("x", "y", "z")))

                # The header has the height, width, and number of channels
                dims = np.array([msg.height, msg.width, 3], dtype=np.int32)

                # One row per point. Written in row-major order, this
                # gives x, y, z for the first point, then for the next
                # one, etc., as when writing the points one at a time.
                xyz = np.array(cloud_points, dtype=np.float32)

                # The file is closed on leaving this block, even on error
                with open(filename, "wb") as fh:
                    fh.write(dims.tobytes())
                    fh.write(xyz.tobytes())

            except Exception as e:
                print("Failed to write: " + filename)
                print("Error is " + str(e))
        
            
            
