(*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 *)

open! IStd
module L = Logging
open PulseDomainInterface
open PulseOperationResult.Import

(* Captured from the global [Config] module here, because the local [module Config] defined
   just below shadows it for the rest of this file. *)
let pulse_transitive_access_verbose : bool = Config.pulse_transitive_access_verbose

module Config : sig
  val fieldname_must_be_monitored : Fieldname.t -> bool
  (** Whether the name of the given field (as returned by [Fieldname.get_field_name]) is listed
      verbatim in the [fieldnames_to_monitor] of the loaded transitive-access configuration.
      Always [false] when no configuration was loaded. *)

  val procname_must_be_monitored : Tenv.t -> Procname.t -> bool
  (** Whether calls to the given procedure must be recorded as transitive accesses, according to
      the [procnames_to_monitor] entries of the loaded configuration. Always [false] when no
      configuration was loaded. *)

  type context =
    { initial_caller_class_extends: string list
    ; initial_caller_class_does_not_extend: string list
    ; final_class_only: bool
    ; description: string
    ; tag: string }
  (** A reporting context. A procedure matches it when its class has one of the supertypes
      listed in [initial_caller_class_extends], none of those in
      [initial_caller_class_does_not_extend], and is final if [final_class_only] is set. The
      last two filters are bypassed in verbose mode. [description] and [tag] are attached to the
      issues reported for this context. *)

  val find_matching_context : Tenv.t -> Procname.t -> context option
  (** The first context, in configuration order, matched by the class of the given procedure.
      [None] when no configuration was loaded, the procedure has no class, or nothing matches. *)
end = struct
  (* Raw JSON shape of one [procnames_to_monitor] entry; every field is optional. *)
  type procname_to_monitor_spec =
    { class_names: string list option [@yojson.option]
    ; method_names: string list option [@yojson.option]
    ; class_name_regex: string option [@yojson.option] }
  [@@deriving of_yojson]

  (* Validated form of a [procnames_to_monitor] entry. *)
  type procname_to_monitor =
    | ClassAndMethodNames of {class_names: string list; method_names: string list}
    | ClassNameRegex of {class_name_regex: Str.regexp}

  (* An entry providing both [class_names] and [method_names] is a [ClassAndMethodNames] matcher
     (this takes precedence over a [class_name_regex] given in the same entry). Otherwise a
     [class_name_regex] entry is compiled with [Str.regexp]. Any other shape is a fatal error. *)
  let procname_to_monitor_of_yojson json =
    match procname_to_monitor_spec_of_yojson json with
    | {class_names= Some class_names; method_names= Some method_names} ->
        ClassAndMethodNames {class_names; method_names}
    | {class_name_regex= Some class_name_regex} ->
        let class_name_regex = Str.regexp class_name_regex in
        ClassNameRegex {class_name_regex}
    | _ ->
        L.die UserError "parsing of transitive-access config has failed:@\n %a" Yojson.Safe.pp json


  type context =
    { initial_caller_class_extends: string list
    ; initial_caller_class_does_not_extend: string list
    ; final_class_only: bool [@yojson.default false]
    ; description: string
    ; tag: string }
  [@@deriving of_yojson]

  (* Contents of one configuration file, and also of the merge of all of them. *)
  type t =
    { fieldnames_to_monitor: string list
    ; procnames_to_monitor: procname_to_monitor list
    ; contexts: context list }
  [@@deriving of_yojson]

  let empty = {fieldnames_to_monitor= []; procnames_to_monitor= []; contexts= []}

  (* The configuration in effect: [None] until [set] is called by the loader at the bottom of
     this module, which only happens when at least one config file was given. *)
  let get, set =
    let current = ref (None : t option) in
    ((fun () -> !current), fun config -> current := Some config)


  let fieldname_must_be_monitored fieldname =
    match get () with
    | None ->
        false
    | Some {fieldnames_to_monitor} ->
        List.exists fieldnames_to_monitor ~f:(String.equal (Fieldname.get_field_name fieldname))


  let procname_must_be_monitored tenv procname =
    let class_name = Procname.get_class_type_name procname in
    let method_name = Procname.get_method procname in
    (* Some supertype (as enumerated by [PatternMatch.supertype_exists]) has a name containing
       one of [names] as a substring. *)
    let match_class_name names =
      Option.exists class_name ~f:(fun class_name ->
          PatternMatch.supertype_exists tenv
            (fun class_name _ ->
              let class_name_string = Typ.Name.name class_name in
              List.exists names ~f:(fun typ -> String.is_substring ~substring:typ class_name_string)
              )
            class_name )
    in
    (* Unanchored search: [regexp] may match anywhere in [name]. *)
    let regexp_match regexp name =
      match Str.search_forward regexp name 0 with _ -> true | exception Caml.Not_found -> false
    in
    let match_class_name_regex regexp =
      Option.exists class_name ~f:(fun class_name ->
          PatternMatch.supertype_exists tenv
            (fun class_name _ -> regexp_match regexp (Typ.Name.name class_name))
            class_name )
    in
    match get () with
    | None ->
        false
    | Some {procnames_to_monitor} ->
        List.exists procnames_to_monitor ~f:(function
          | ClassAndMethodNames {class_names; method_names} ->
              match_class_name class_names && List.mem ~equal:String.equal method_names method_name
          | ClassNameRegex {class_name_regex} ->
              match_class_name_regex class_name_regex )


  (* Unlike [procname_must_be_monitored], supertypes are compared by exact name equality here,
     using the names enumerated by [Tenv.fold_supers]. *)
  let is_matching_context tenv procname
      {initial_caller_class_extends; initial_caller_class_does_not_extend; final_class_only} =
    match Procname.get_class_type_name procname with
    | Some type_name ->
        let check_final_status () =
          let is_final () =
            Tenv.lookup tenv type_name
            |> Option.exists ~f:(fun {Struct.annots} -> Annot.Item.is_final annots)
          in
          (* verbose mode deliberately ignores the [final_class_only] filter *)
          pulse_transitive_access_verbose || (not final_class_only) || is_final ()
        in
        let check_parents () =
          let has_parents =
            let parents =
              Tenv.fold_supers tenv type_name ~init:String.Set.empty ~f:(fun parent _ acc ->
                  String.Set.add acc (Typ.Name.name parent) )
            in
            fun classes -> List.exists classes ~f:(String.Set.mem parents)
          in
          (* an empty [initial_caller_class_extends] list never matches *)
          let check_extends = has_parents initial_caller_class_extends in
          let check_does_not_extend () =
            (* verbose mode deliberately ignores the exclusion list *)
            pulse_transitive_access_verbose
            || not (has_parents initial_caller_class_does_not_extend)
          in
          check_extends && check_does_not_extend ()
        in
        check_final_status () && check_parents ()
    | None ->
        false


  let find_matching_context tenv procname =
    let open IOption.Let_syntax in
    let* {contexts} = get () in
    List.find contexts ~f:(is_matching_context tenv procname)


  (* Read and decode one configuration file, dying with a message naming the file on any
     failure. *)
  let parse_config_file config_file =
    match Utils.read_safe_json_file config_file with
    | Ok (`List []) ->
        L.die ExternalError "The content of transitive-access JSON config in %s is empty@."
          config_file
    | Ok json -> (
      try t_of_yojson json
      with exn ->
        (* Keep the underlying reason (e.g. the message raised by [procname_to_monitor_of_yojson]
           or by the derived decoder) instead of discarding it. *)
        L.die ExternalError "Could not read or parse transitive-access JSON config in %s:@\n%s@."
          config_file (Exn.to_string exn) )
    | Error msg ->
        L.die ExternalError "Could not read or parse transitive-access JSON config in %s:@\n%s@."
          config_file msg


  (* Load every configuration file at module initialisation and merge them. Lists of successive
     files are concatenated in command-line order: each file's entries are [rev_append]ed onto
     the accumulator, and the accumulated lists are reversed once at the end. *)
  let () =
    match Config.pulse_transitive_access_config with
    | [] ->
        ()
    | config_files ->
        let rev_config =
          List.fold config_files ~init:empty ~f:(fun merged_config config_file ->
              let new_config = parse_config_file config_file in
              { fieldnames_to_monitor=
                  List.rev_append new_config.fieldnames_to_monitor
                    merged_config.fieldnames_to_monitor
              ; procnames_to_monitor=
                  List.rev_append new_config.procnames_to_monitor merged_config.procnames_to_monitor
              ; contexts= List.rev_append new_config.contexts merged_config.contexts } )
        in
        { fieldnames_to_monitor= List.rev rev_config.fieldnames_to_monitor
        ; procnames_to_monitor= List.rev rev_config.procnames_to_monitor
        ; contexts= List.rev rev_config.contexts }
        |> set
end

(** When [rhs_exp] reads a field whose name is monitored by the configuration, record a
    transitive access at [location] in every [ContinueProgram] state of [astates]. Other execution
    states are kept as they are. When the load is not from a monitored field, [astates] is
    returned unchanged. *)
let record_load rhs_exp location astates =
  let reads_monitored_field =
    match rhs_exp with
    | Exp.Lfield (_, fieldname, _) ->
        Config.fieldname_must_be_monitored fieldname
    | _ ->
        false
  in
  if not reads_monitored_field then astates
  else
    List.map astates ~f:(fun exec_state ->
        match exec_state with
        | ContinueProgram astate ->
            ContinueProgram (AbductiveDomain.record_transitive_access location astate)
        | other_state ->
            other_state )


(** Record a transitive access at [location] in [astate] when the callee [procname] is known and
    monitored by the configuration. Otherwise return [astate] unchanged. *)
let record_call tenv procname location astate =
  match procname with
  | None ->
      astate
  | Some callee ->
      if Config.procname_must_be_monitored tenv callee then
        AbductiveDomain.record_transitive_access location astate
      else astate


(** If the analyzed procedure matches a configured context, report one [TransitiveAccess] issue
    per recorded access trace. Traces are taken first from every [ContinueProgram] disjunct of the
    summary, then from the non-disjunctive part when it is not top. In verbose mode, when nothing
    was reported and the non-disjunctive part is not top, a single "NO ACCESS FOUND" issue is
    emitted carrying the non-disjunctive callees and missed captures. *)
let report_errors tenv proc_desc err_log {PulseSummary.pre_post_list; non_disj} =
  let procname = Procdesc.get_proc_name proc_desc in
  match Config.find_matching_context tenv procname with
  | None ->
      ()
  | Some {Config.tag; description} ->
      let emit ~tag ~description ~transitive_callees ~transitive_missed_captures call_trace =
        PulseReport.report ~is_suppressed:false ~latent:false tenv proc_desc err_log
          (Diagnostic.TransitiveAccess
             {tag; description; call_trace; transitive_callees; transitive_missed_captures} )
      in
      let has_reported = ref false in
      let report_all_accesses {PulseTransitiveInfo.accesses; callees; missed_captures} =
        PulseTrace.Set.iter
          (fun call_trace ->
            has_reported := true ;
            emit ~tag ~description ~transitive_callees:callees
              ~transitive_missed_captures:missed_captures call_trace )
          accesses
      in
      List.iter pre_post_list ~f:(fun exec_state ->
          match exec_state with
          | ContinueProgram astate ->
              report_all_accesses (AbductiveDomain.Summary.get_transitive_info astate)
          | _ ->
              () ) ;
      ( match NonDisjDomain.Summary.get_transitive_info_if_not_top non_disj with
      | None ->
          ()
      | Some ({PulseTransitiveInfo.callees; missed_captures; _} as transitive_info) ->
          report_all_accesses transitive_info ;
          if (not !has_reported) && pulse_transitive_access_verbose then
            let call_trace : PulseTrace.t =
              Immediate {location= Location.dummy; history= PulseValueHistory.epoch}
            in
            emit ~tag:"NO ACCESS FOUND" ~description:"" ~transitive_callees:callees
              ~transitive_missed_captures:missed_captures call_trace )
