Skip to content

Commit

Permalink
feat: dsimproc command
Browse files Browse the repository at this point in the history
Simplification procedures that produce definitionally equal results.

WIP
  • Loading branch information
leodemoura committed Mar 5, 2024
1 parent f986f69 commit b24fbf4
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 40 deletions.
98 changes: 79 additions & 19 deletions src/Init/Simproc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,43 @@ Simplification procedures can be also scoped or local.
-/
syntax (docComment)? attrKind "simproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command

/--
Similar to `simproc`, but resulting expression must be definitionally equal to the input one.
-/
syntax (docComment)? attrKind "dsimproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command

/--
A user-defined simplification procedure declaration. To activate this procedure in `simp` tactic,
we must provide it as an argument, or use the command `attribute` to set its `[simproc]` attribute.
-/
syntax (docComment)? "simproc_decl " ident " (" term ")" " := " term : command

/--
A user-defined defeq simplification procedure declaration. To activate this procedure in `simp` tactic,
we must provide it as an argument, or use the command `attribute` to set its `[simproc]` attribute.
-/
syntax (docComment)? "dsimproc_decl " ident " (" term ")" " := " term : command

/--
A builtin simplification procedure.
-/
syntax (docComment)? attrKind "builtin_simproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command

/--
A builtin defeq simplification procedure.
-/
syntax (docComment)? attrKind "builtin_dsimproc " (Tactic.simpPre <|> Tactic.simpPost)? ("[" ident,* "]")? ident " (" term ")" " := " term : command

/--
A builtin simplification procedure declaration.
-/
syntax (docComment)? "builtin_simproc_decl " ident " (" term ")" " := " term : command

/--
A builtin defeq simplification procedure declaration.
-/
syntax (docComment)? "builtin_dsimproc_decl " ident " (" term ")" " := " term : command

/--
Auxiliary command for associating a pattern with a simplification procedure.
-/
Expand Down Expand Up @@ -86,33 +107,60 @@ macro_rules
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
simproc_pattern% $pattern => $n)

macro_rules
| `($[$doc?:docComment]? dsimproc_decl $n:ident ($pattern:term) := $body) => do
let simprocType := `Lean.Meta.Simp.DSimproc
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
simproc_pattern% $pattern => $n)

macro_rules
| `($[$doc?:docComment]? builtin_simproc_decl $n:ident ($pattern:term) := $body) => do
let simprocType := `Lean.Meta.Simp.Simproc
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
builtin_simproc_pattern% $pattern => $n)

macro_rules
| `($[$doc?:docComment]? builtin_dsimproc_decl $n:ident ($pattern:term) := $body) => do
let simprocType := `Lean.Meta.Simp.DSimproc
`($[$doc?:docComment]? def $n:ident : $(mkIdent simprocType) := $body
builtin_simproc_pattern% $pattern => $n)

private def mkAttributeCmds
(kind : TSyntax `Lean.Parser.Term.attrKind)
(pre? : Option (TSyntax [`Lean.Parser.Tactic.simpPre, `Lean.Parser.Tactic.simpPost]))
(ids? : Option (Syntax.TSepArray `ident ","))
(n : Ident) : MacroM (Array Syntax) := do
let mut cmds := #[]
let pushDefault (cmds : Array (TSyntax `command)) : MacroM (Array (TSyntax `command)) := do
return cmds.push (← `(attribute [$kind simproc $[$pre?]?] $n))
if let some ids := ids? then
for id in ids.getElems do
let idName := id.getId
let (attrName, attrKey) :=
if idName == `simp then
(`simprocAttr, "simproc")
else if idName == `seval then
(`sevalprocAttr, "sevalproc")
else
let idName := idName.appendAfter "_proc"
(`Parser.Attr ++ idName, idName.toString)
let attrStx : TSyntax `attr := ⟨mkNode attrName #[mkAtom attrKey, mkOptionalNode pre?]⟩
cmds := cmds.push (← `(attribute [$kind $attrStx] $n))
else
cmds ← pushDefault cmds
return cmds

macro_rules
| `($[$doc?:docComment]? $kind:attrKind simproc $[$pre?]? $[ [ $ids?:ident,* ] ]? $n:ident ($pattern:term) := $body) => do
let mut cmds := #[(← `($[$doc?:docComment]? simproc_decl $n ($pattern) := $body))]
let pushDefault (cmds : Array (TSyntax `command)) : MacroM (Array (TSyntax `command)) := do
return cmds.push (← `(attribute [$kind simproc $[$pre?]?] $n))
if let some ids := ids? then
for id in ids.getElems do
let idName := id.getId
let (attrName, attrKey) :=
if idName == `simp then
(`simprocAttr, "simproc")
else if idName == `seval then
(`sevalprocAttr, "sevalproc")
else
let idName := idName.appendAfter "_proc"
(`Parser.Attr ++ idName, idName.toString)
let attrStx : TSyntax `attr := ⟨mkNode attrName #[mkAtom attrKey, mkOptionalNode pre?]⟩
cmds := cmds.push (← `(attribute [$kind $attrStx] $n))
else
cmds ← pushDefault cmds
return mkNullNode cmds
return mkNullNode <|
#[(← `($[$doc?:docComment]? simproc_decl $n ($pattern) := $body))]
++ (← mkAttributeCmds kind pre? ids? n)

macro_rules
| `($[$doc?:docComment]? $kind:attrKind dsimproc $[$pre?]? $[ [ $ids?:ident,* ] ]? $n:ident ($pattern:term) := $body) => do
return mkNullNode <|
#[(← `($[$doc?:docComment]? dsimproc_decl $n ($pattern) := $body))]
++ (← mkAttributeCmds kind pre? ids? n)

macro_rules
| `($[$doc?:docComment]? $kind:attrKind builtin_simproc $[$pre?]? $n:ident ($pattern:term) := $body) => do
Expand All @@ -126,4 +174,16 @@ macro_rules
attribute [$kind builtin_simproc $[$pre?]?] $n
attribute [$kind builtin_sevalproc $[$pre?]?] $n)

macro_rules
| `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? $n:ident ($pattern:term) := $body) => do
`($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body
attribute [$kind builtin_simproc $[$pre?]?] $n)
| `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? [seval] $n:ident ($pattern:term) := $body) => do
`($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body
attribute [$kind builtin_sevalproc $[$pre?]?] $n)
| `($[$doc?:docComment]? $kind:attrKind builtin_dsimproc $[$pre?]? [simp, seval] $n:ident ($pattern:term) := $body) => do
`($[$doc?:docComment]? builtin_dsimproc_decl $n ($pattern) := $body
attribute [$kind builtin_simproc $[$pre?]?] $n
attribute [$kind builtin_sevalproc $[$pre?]?] $n)

end Lean.Parser
12 changes: 7 additions & 5 deletions src/Lean/Elab/Tactic/Simproc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def elabSimprocKeys (stx : Syntax) : MetaM (Array Meta.SimpTheoremKey) := do
let pattern ← elabSimprocPattern stx
DiscrTree.mkPath pattern simpDtConfig

def checkSimprocType (declName : Name) : CoreM Unit := do
def checkSimprocType (declName : Name) : CoreM Bool := do
let decl ← getConstInfo declName
match decl.type with
| .const ``Simproc _ => pure ()
| .const ``Simproc _ => pure false
| .const ``DSimproc _ => pure true
| _ => throwError "unexpected type at '{declName}', 'Simproc' expected"

namespace Command
Expand All @@ -38,17 +39,18 @@ namespace Command
let `(simproc_pattern% $pattern => $declName) := stx | throwUnsupportedSyntax
let declName ← resolveGlobalConstNoOverload declName
liftTermElabM do
checkSimprocType declName
discard <| checkSimprocType declName
let keys ← elabSimprocKeys pattern
registerSimproc declName keys

@[builtin_command_elab Lean.Parser.simprocPatternBuiltin] def elabSimprocPatternBuiltin : CommandElab := fun stx => do
let `(builtin_simproc_pattern% $pattern => $declName) := stx | throwUnsupportedSyntax
let declName ← resolveGlobalConstNoOverload declName
liftTermElabM do
checkSimprocType declName
let dsimp ← checkSimprocType declName
let keys ← elabSimprocKeys pattern
let val := mkAppN (mkConst ``registerBuiltinSimproc) #[toExpr declName, toExpr keys, mkConst declName]
let registerProcName := if dsimp then ``registerBuiltinDSimproc else ``registerBuiltinSimproc
let val := mkAppN (mkConst registerProcName) #[toExpr declName, toExpr keys, mkConst declName]
let initDeclName ← mkFreshUserName (declName ++ `declare)
declareBuiltin initDeclName val

Expand Down
58 changes: 43 additions & 15 deletions src/Lean/Meta/Tactic/Simp/Simproc.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ It contains:
-/
structure BuiltinSimprocs where
keys : HashMap Name (Array SimpTheoremKey) := {}
procs : HashMap Name Simproc := {}
procs : HashMap Name (Sum Simproc DSimproc) := {}
deriving Inhabited

/--
Expand Down Expand Up @@ -79,14 +79,20 @@ Given a declaration name `declName`, store the discrimination tree keys and the
This method is invoked by the command `builtin_simproc_pattern%` elaborator.
-/
def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : Simproc) : IO Unit := do
def registerBuiltinSimprocCore (declName : Name) (key : Array SimpTheoremKey) (proc : Sum Simproc DSimproc) : IO Unit := do
unless (← initializing) do
throw (IO.userError s!"invalid builtin simproc declaration, it can only be registered during initialization")
if (← builtinSimprocDeclsRef.get).keys.contains declName then
throw (IO.userError s!"invalid builtin simproc declaration '{declName}', it has already been declared")
builtinSimprocDeclsRef.modify fun { keys, procs } =>
{ keys := keys.insert declName key, procs := procs.insert declName proc }

def registerBuiltinSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : Simproc) : IO Unit := do
registerBuiltinSimprocCore declName key (.inl proc)

def registerBuiltinDSimproc (declName : Name) (key : Array SimpTheoremKey) (proc : DSimproc) : IO Unit := do
registerBuiltinSimprocCore declName key (.inr proc)

def registerSimproc (declName : Name) (keys : Array SimpTheoremKey) : CoreM Unit := do
let env ← getEnv
unless (env.getModuleIdxFor? declName).isNone do
Expand All @@ -112,14 +118,21 @@ builtin_initialize builtinSEvalprocsRef : IO.Ref Simprocs ← IO.mkRef {}

abbrev SimprocExtension := ScopedEnvExtension SimprocOLeanEntry SimprocEntry Simprocs

unsafe def getSimprocFromDeclImpl (declName : Name) : ImportM Simproc := do
unsafe def getSimprocFromDeclImpl (declName : Name) : ImportM (Sum Simproc DSimproc) := do
let ctx ← read
match ctx.env.evalConstCheck Simproc ctx.opts ``Lean.Meta.Simp.Simproc declName with
| .ok proc => return proc
| .error ex => throw (IO.userError ex)
match ctx.env.find? declName with
| none => throw <| IO.userError ("unknown constant '" ++ toString declName ++ "'")
| some info =>
match info.type with
| .const ``Simproc _ =>
return .inl (← IO.ofExcept <| ctx.env.evalConst Simproc ctx.opts declName)
| .const ``DSimproc _ =>
return .inr (← IO.ofExcept <| ctx.env.evalConst DSimproc ctx.opts declName)
| _ => throw <| IO.userError "unexpected type at simproc"


@[implemented_by getSimprocFromDeclImpl]
opaque getSimprocFromDecl (declName: Name) : ImportM Simproc
opaque getSimprocFromDecl (declName: Name) : ImportM (Sum Simproc DSimproc)

def toSimprocEntry (e : SimprocOLeanEntry) : ImportM SimprocEntry := do
return { toSimprocOLeanEntry := e, proc := (← getSimprocFromDecl e.declName) }
Expand All @@ -136,7 +149,7 @@ def addSimprocAttrCore (ext : SimprocExtension) (declName : Name) (kind : Attrib
throwError "invalid [simproc] attribute, '{declName}' is not a simproc"
ext.add { declName, post, keys, proc } kind

def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Simproc) : Simprocs :=
def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : Simprocs :=
let s := { s with simprocNames := s.simprocNames.insert declName, erased := s.erased.erase declName }
if post then
{ s with post := s.post.insertCore keys { declName, keys, post, proc } }
Expand All @@ -146,15 +159,21 @@ def Simprocs.addCore (s : Simprocs) (keys : Array SimpTheoremKey) (declName : Na
/--
Implements attributes `builtin_simproc` and `builtin_sevalproc`.
-/
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Simproc) : IO Unit := do
def addSimprocBuiltinAttrCore (ref : IO.Ref Simprocs) (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit := do
let some keys := (← builtinSimprocDeclsRef.get).keys.find? declName |
throw (IO.userError "invalid [builtin_simproc] attribute, '{declName}' is not a builtin simproc")
ref.modify fun s => s.addCore keys declName post proc

def addSimprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit :=
addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc
addSimprocBuiltinAttrCore builtinSimprocsRef declName post (.inl proc)

def addSEvalprocBuiltinAttr (declName : Name) (post : Bool) (proc : Simproc) : IO Unit :=
addSimprocBuiltinAttrCore builtinSEvalprocsRef declName post (.inl proc)

def addSimprocBuiltinAttrNew (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit :=
addSimprocBuiltinAttrCore builtinSimprocsRef declName post proc

def addSEvalprocBuiltinAttrNew (declName : Name) (post : Bool) (proc : Sum Simproc DSimproc) : IO Unit :=
addSimprocBuiltinAttrCore builtinSEvalprocsRef declName post proc

def Simprocs.add (s : Simprocs) (declName : Name) (post : Bool) : CoreM Simprocs := do
Expand All @@ -179,8 +198,13 @@ def SimprocEntry.try (s : SimprocEntry) (numExtraArgs : Nat) (e : Expr) : SimpM
extraArgs := extraArgs.push e.appArg!
e := e.appFn!
extraArgs := extraArgs.reverse
let s ← s.proc e
s.addExtraArgs extraArgs
match s.proc with
| .inl proc =>
let s ← proc e
s.addExtraArgs extraArgs
| .inr proc =>
let s ← proc e
s.toStep.addExtraArgs extraArgs

def simprocCore (post : Bool) (s : SimprocTree) (erased : PHashSet Name) (e : Expr) : SimpM Step := do
let candidates ← s.getMatchWithExtra e (getDtConfig (← getConfig))
Expand Down Expand Up @@ -315,7 +339,11 @@ builtin_initialize simprocSEvalExtension : SimprocExtension ← registerSimprocA
private def addBuiltin (declName : Name) (stx : Syntax) (addDeclName : Name) : AttrM Unit := do
let go : MetaM Unit := do
let post := if stx[1].isNone then true else stx[1][0].getKind == ``Lean.Parser.Tactic.simpPost
let val := mkAppN (mkConst addDeclName) #[toExpr declName, toExpr post, mkConst declName]
let procExpr ← match (← getConstInfo declName).type with
| .const ``Simproc _ => pure <| mkApp3 (mkConst ``Sum.inl [0, 0]) (mkConst ``Simproc) (mkConst ``DSimproc) (mkConst declName)
| .const ``DSimproc _ => pure <| mkApp3 (mkConst ``Sum.inr [0, 0]) (mkConst ``Simproc) (mkConst ``DSimproc) (mkConst declName)
| _ => throwError "unexpected type at simproc"
let val := mkAppN (mkConst addDeclName) #[toExpr declName, toExpr post, procExpr]
let initDeclName ← mkFreshUserName (declName ++ `declare)
declareBuiltin initDeclName val
go.run' {}
Expand All @@ -327,7 +355,7 @@ builtin_initialize
descr := "Builtin simplification procedure"
applicationTime := AttributeApplicationTime.afterCompilation
erase := fun _ => throwError "Not implemented yet, [-builtin_simproc]"
add := fun declName stx _ => addBuiltin declName stx ``addSimprocBuiltinAttr
add := fun declName stx _ => addBuiltin declName stx ``addSimprocBuiltinAttrNew
}

builtin_initialize
Expand All @@ -337,7 +365,7 @@ builtin_initialize
descr := "Builtin symbolic evaluation procedure"
applicationTime := AttributeApplicationTime.afterCompilation
erase := fun _ => throwError "Not implemented yet, [-builtin_sevalproc]"
add := fun declName stx _ => addBuiltin declName stx ``addSEvalprocBuiltinAttr
add := fun declName stx _ => addBuiltin declName stx ``addSEvalprocBuiltinAttrNew
}

def getSimprocs : CoreM Simprocs :=
Expand Down
14 changes: 13 additions & 1 deletion src/Lean/Meta/Tactic/Simp/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ See `Step`.
-/
abbrev Simproc := Expr → SimpM Step

/--
Similar to `Simproc`, but resulting expression should be definitionally equal to the input one.
-/
abbrev DSimproc := Expr → SimpM TransformStep

def _root_.Lean.TransformStep.toStep (s : TransformStep) : Step :=
match s with
| .done e => .done { expr := e }
| .visit e => .visit { expr := e }
| .continue (some e) => .continue (some { expr := e })
| .continue none => .continue none

def mkEqTransResultStep (r : Result) (s : Step) : MetaM Step :=
match s with
| .done r' => return .done (← mkEqTransOptProofResult r.proof? r.cache r')
Expand Down Expand Up @@ -189,7 +201,7 @@ structure SimprocEntry extends SimprocOLeanEntry where
Recall that we cannot store `Simproc` into .olean files because it is a closure.
Given `SimprocOLeanEntry.declName`, we convert it into a `Simproc` by using the unsafe function `evalConstCheck`.
-/
proc : Simproc
proc : Sum Simproc DSimproc

abbrev SimprocTree := DiscrTree SimprocEntry

Expand Down

0 comments on commit b24fbf4

Please sign in to comment.