@@ -11,6 +11,42 @@ let rangeOfLoc (loc : Location.t) =
1111 let end_ = loc |> Loc. end_ |> mkPosition in
1212 {Protocol. start; end_}
1313
14+ let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
15+ match
16+ expr.Parsetree. pexp_loc
17+ |> CompletionFrontEnd. findTypeOfExpressionAtLoc ~debug ~path ~current File
18+ ~pos Cursor:(Pos. ofLexing expr.Parsetree. pexp_loc.loc_start)
19+ with
20+ | Some (completable , scope ) -> (
21+ let env = SharedTypes.QueryEnv. fromFile full.SharedTypes. file in
22+ let completions =
23+ completable
24+ |> CompletionBackEnd. processCompletable ~debug ~full ~pos ~scope ~env
25+ ~for Hover:true
26+ in
27+ let rawOpens = Scope. getRawOpens scope in
28+ match completions with
29+ | {env} :: _ -> (
30+ let opens =
31+ CompletionBackEnd. getOpens ~debug ~raw Opens ~package: full.package ~env
32+ in
33+ match
34+ CompletionBackEnd. completionsGetCompletionType2 ~debug ~full ~raw Opens
35+ ~opens ~pos completions
36+ with
37+ | Some (typ , _env ) ->
38+ let extractedType =
39+ match typ with
40+ | ExtractedType t -> Some t
41+ | TypeExpr t ->
42+ TypeUtils. extractType t ~env ~package: full.package
43+ |> TypeUtils. getExtractedType
44+ in
45+ extractedType
46+ | None -> None )
47+ | _ -> None )
48+ | _ -> None
49+
1450module IfThenElse = struct
1551 (* Convert if-then-else to switch *)
1652
@@ -324,6 +360,196 @@ module AddTypeAnnotation = struct
324360 | _ -> () ))
325361end
326362
363+ module ExpandCatchAllForVariants = struct
364+ let mkIterator ~pos ~result =
365+ let expr (iterator : Ast_iterator.iterator ) (e : Parsetree.expression ) =
366+ (if e.pexp_loc |> Loc. hasPos ~pos then
367+ match e.pexp_desc with
368+ | Pexp_match (switchExpr , cases ) -> (
369+ let catchAllCase =
370+ cases
371+ |> List. find_opt (fun (c : Parsetree.case ) ->
372+ match c with
373+ | {pc_lhs = {ppat_desc = Ppat_any } } -> true
374+ | _ -> false )
375+ in
376+ match catchAllCase with
377+ | None -> ()
378+ | Some catchAllCase ->
379+ result := Some (switchExpr, catchAllCase, cases))
380+ | _ -> () );
381+ Ast_iterator. default_iterator.expr iterator e
382+ in
383+ {Ast_iterator. default_iterator with expr}
384+
385+ let xform ~path ~pos ~full ~structure ~currentFile ~codeActions ~debug =
386+ let result = ref None in
387+ let iterator = mkIterator ~pos ~result in
388+ iterator.structure iterator structure;
389+ match ! result with
390+ | None -> ()
391+ | Some (switchExpr , catchAllCase , cases ) -> (
392+ if Debug. verbose () then
393+ print_endline
394+ " [codeAction - ExpandCatchAllForVariants] Found target switch" ;
395+ let currentConstructorNames =
396+ cases
397+ |> List. filter_map (fun (c : Parsetree.case ) ->
398+ match c with
399+ | {pc_lhs = {ppat_desc = Ppat_construct ({txt} , _ )} } ->
400+ Some (Longident. last txt)
401+ | {pc_lhs = {ppat_desc = Ppat_variant (name , _ )} } -> Some name
402+ | _ -> None )
403+ in
404+ match
405+ switchExpr
406+ |> extractTypeFromExpr ~debug ~path ~current File ~full
407+ ~pos: (Pos. ofLexing switchExpr.pexp_loc.loc_end)
408+ with
409+ | Some (Tvariant {constructors} ) ->
410+ let missingConstructors =
411+ constructors
412+ |> List. filter (fun (c : SharedTypes.Constructor.t ) ->
413+ currentConstructorNames |> List. mem c.cname.txt = false )
414+ in
415+ if List. length missingConstructors > 0 then
416+ let newText =
417+ missingConstructors
418+ |> List. map (fun (c : SharedTypes.Constructor.t ) ->
419+ c.cname.txt
420+ ^
421+ match c.args with
422+ | Args [] -> " "
423+ | Args _ | InlineRecord _ -> " (_)" )
424+ |> String. concat " | "
425+ in
426+ let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
427+ let codeAction =
428+ CodeActions. make ~title: " Expand catch-all" ~kind: RefactorRewrite
429+ ~uri: path ~new Text ~range
430+ in
431+ codeActions := codeAction :: ! codeActions
432+ else ()
433+ | Some (Tpolyvariant {constructors} ) ->
434+ let missingConstructors =
435+ constructors
436+ |> List. filter (fun (c : SharedTypes.polyVariantConstructor ) ->
437+ currentConstructorNames |> List. mem c.name = false )
438+ in
439+ if List. length missingConstructors > 0 then
440+ let newText =
441+ missingConstructors
442+ |> List. map (fun (c : SharedTypes.polyVariantConstructor ) ->
443+ Res_printer. polyVarIdentToString c.name
444+ ^
445+ match c.args with
446+ | [] -> " "
447+ | _ -> " (_)" )
448+ |> String. concat " | "
449+ in
450+ let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
451+ let codeAction =
452+ CodeActions. make ~title: " Expand catch-all" ~kind: RefactorRewrite
453+ ~uri: path ~new Text ~range
454+ in
455+ codeActions := codeAction :: ! codeActions
456+ else ()
457+ | Some (Toption (env , innerType )) -> (
458+ if Debug. verbose () then
459+ print_endline
460+ " [codeAction - ExpandCatchAllForVariants] Found option type" ;
461+ let innerType =
462+ match innerType with
463+ | ExtractedType t -> Some t
464+ | TypeExpr t -> (
465+ match TypeUtils. extractType ~env ~package: full.package t with
466+ | None -> None
467+ | Some (t , _ ) -> Some t)
468+ in
469+ match innerType with
470+ | Some ((Tvariant _ | Tpolyvariant _ ) as variant ) ->
471+ let currentConstructorNames =
472+ cases
473+ |> List. filter_map (fun (c : Parsetree.case ) ->
474+ match c with
475+ | {
476+ pc_lhs =
477+ {
478+ ppat_desc =
479+ Ppat_construct
480+ ( {txt = Lident " Some" },
481+ Some {ppat_desc = Ppat_construct ({txt}, _)} );
482+ };
483+ } ->
484+ Some (Longident. last txt)
485+ | {
486+ pc_lhs =
487+ {
488+ ppat_desc =
489+ Ppat_construct
490+ ( {txt = Lident " Some" },
491+ Some {ppat_desc = Ppat_variant (name, _)} );
492+ };
493+ } ->
494+ Some name
495+ | _ -> None )
496+ in
497+ let hasNoneCase =
498+ cases
499+ |> List. exists (fun (c : Parsetree.case ) ->
500+ match c.pc_lhs.ppat_desc with
501+ | Ppat_construct ({txt = Lident "None" } , _ ) -> true
502+ | _ -> false )
503+ in
504+ let missingConstructors =
505+ match variant with
506+ | Tvariant {constructors} ->
507+ constructors
508+ |> List. filter_map (fun (c : SharedTypes.Constructor.t ) ->
509+ if currentConstructorNames |> List. mem c.cname.txt = false
510+ then
511+ Some
512+ ( c.cname.txt,
513+ match c.args with
514+ | Args [] -> false
515+ | _ -> true )
516+ else None )
517+ | Tpolyvariant {constructors} ->
518+ constructors
519+ |> List. filter_map
520+ (fun (c : SharedTypes.polyVariantConstructor ) ->
521+ if currentConstructorNames |> List. mem c.name = false then
522+ Some
523+ ( Res_printer. polyVarIdentToString c.name,
524+ match c.args with
525+ | [] -> false
526+ | _ -> true )
527+ else None )
528+ | _ -> []
529+ in
530+ if List. length missingConstructors > 0 || not hasNoneCase then
531+ let newText =
532+ " Some("
533+ ^ (missingConstructors
534+ |> List. map (fun (name , hasArgs ) ->
535+ name ^ if hasArgs then " (_)" else " " )
536+ |> String. concat " | " )
537+ ^ " )"
538+ in
539+ let newText =
540+ if hasNoneCase then newText else newText ^ " | None"
541+ in
542+ let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
543+ let codeAction =
544+ CodeActions. make ~title: " Expand catch-all" ~kind: RefactorRewrite
545+ ~uri: path ~new Text ~range
546+ in
547+ codeActions := codeAction :: ! codeActions
548+ else ()
549+ | _ -> () )
550+ | _ -> () )
551+ end
552+
327553module ExhaustiveSwitch = struct
328554 (* Expand expression to be an exhaustive switch of the underlying value *)
329555 type posType = Single of Pos .t | Range of Pos .t * Pos .t
@@ -336,46 +562,6 @@ module ExhaustiveSwitch = struct
336562 }
337563 | Selection of {expr : Parsetree .expression }
338564
339- module C = struct
340- let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
341- match
342- expr.Parsetree. pexp_loc
343- |> CompletionFrontEnd. findTypeOfExpressionAtLoc ~debug ~path
344- ~current File
345- ~pos Cursor:(Pos. ofLexing expr.Parsetree. pexp_loc.loc_start)
346- with
347- | Some (completable , scope ) -> (
348- let env = SharedTypes.QueryEnv. fromFile full.SharedTypes. file in
349- let completions =
350- completable
351- |> CompletionBackEnd. processCompletable ~debug ~full ~pos ~scope ~env
352- ~for Hover:true
353- in
354- let rawOpens = Scope. getRawOpens scope in
355- match completions with
356- | {env} :: _ -> (
357- let opens =
358- CompletionBackEnd. getOpens ~debug ~raw Opens ~package: full.package
359- ~env
360- in
361- match
362- CompletionBackEnd. completionsGetCompletionType2 ~debug ~full
363- ~raw Opens ~opens ~pos completions
364- with
365- | Some (typ , _env ) ->
366- let extractedType =
367- match typ with
368- | ExtractedType t -> Some t
369- | TypeExpr t ->
370- TypeUtils. extractType t ~env ~package: full.package
371- |> TypeUtils. getExtractedType
372- in
373- extractedType
374- | None -> None )
375- | _ -> None )
376- | _ -> None
377- end
378-
379565 let mkIteratorSingle ~pos ~result =
380566 let expr (iterator : Ast_iterator.iterator ) (exp : Parsetree.expression ) =
381567 (match exp.pexp_desc with
@@ -434,7 +620,7 @@ module ExhaustiveSwitch = struct
434620 | Some (Selection {expr} ) -> (
435621 match
436622 expr
437- |> C. extractTypeFromExpr ~debug ~path ~current File ~full
623+ |> extractTypeFromExpr ~debug ~path ~current File ~full
438624 ~pos: (Pos. ofLexing expr.pexp_loc.loc_start)
439625 with
440626 | None -> ()
@@ -460,7 +646,7 @@ module ExhaustiveSwitch = struct
460646 | Some (Switch {switchExpr; completionExpr; pos} ) -> (
461647 match
462648 completionExpr
463- |> C. extractTypeFromExpr ~debug ~path ~current File ~full ~pos
649+ |> extractTypeFromExpr ~debug ~path ~current File ~full ~pos
464650 with
465651 | None -> ()
466652 | Some extractedType -> (
@@ -743,6 +929,8 @@ let extractCodeActions ~path ~startPos ~endPos ~currentFile ~debug =
743929 match Cmt. loadFullCmtFromPath ~path with
744930 | Some full ->
745931 AddTypeAnnotation. xform ~path ~pos ~full ~structure ~code Actions ~debug ;
932+ ExpandCatchAllForVariants. xform ~path ~pos ~full ~structure ~code Actions
933+ ~current File ~debug ;
746934 ExhaustiveSwitch. xform ~print Expr ~path
747935 ~pos:
748936 (if startPos = endPos then Single startPos
0 commit comments