Skip to content

Commit

Permalink
Create curve pkg and organize remaining files (#394)
Browse files Browse the repository at this point in the history
* group files into new pkgs

* fix broken files after merge

* move keystore functionalities to account pkg

* move all stark curve methods to main curve file

* move all curve tests to a single test file

* refactor hash methods on accounts

---------

Co-authored-by: Rian Hughes <[email protected]>
  • Loading branch information
cicr99 and rianhughes committed Oct 6, 2023
1 parent ce25833 commit 61e7f8f
Show file tree
Hide file tree
Showing 23 changed files with 709 additions and 1,001 deletions.
55 changes: 40 additions & 15 deletions account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (
"time"

"github.com/NethermindEth/juno/core/felt"
starknetgo "github.com/NethermindEth/starknet.go"
"github.com/NethermindEth/starknet.go/curve"
"github.com/NethermindEth/starknet.go/hash"
"github.com/NethermindEth/starknet.go/rpc"
"github.com/NethermindEth/starknet.go/utils"
)
Expand Down Expand Up @@ -47,10 +48,10 @@ type Account struct {
ChainId *felt.Felt
AccountAddress *felt.Felt
publicKey string
ks starknetgo.Keystore
ks Keystore
}

func NewAccount(provider rpc.RpcProvider, accountAddress *felt.Felt, publicKey string, keystore starknetgo.Keystore) (*Account, error) {
func NewAccount(provider rpc.RpcProvider, accountAddress *felt.Felt, publicKey string, keystore Keystore) (*Account, error) {
account := &Account{
provider: provider,
AccountAddress: accountAddress,
Expand Down Expand Up @@ -134,7 +135,7 @@ func (account *Account) TransactionHashDeployAccount(tx rpc.DeployAccountTxn, co
}
calldata := []*felt.Felt{tx.ClassHash, tx.ContractAddressSalt}
calldata = append(calldata, tx.ConstructorCalldata...)
calldataHash, err := computeHashOnElementsFelt(calldata)
calldataHash, err := hash.ComputeHashOnElementsFelt(calldata)
if err != nil {
return nil, err
}
Expand All @@ -145,7 +146,7 @@ func (account *Account) TransactionHashDeployAccount(tx rpc.DeployAccountTxn, co
}

// https://docs.starknet.io/documentation/architecture_and_concepts/Network_Architecture/transactions/#deploy_account_hash_calculation
return calculateTransactionHashCommon(
return hash.CalculateTransactionHashCommon(
PREFIX_DEPLOY_ACCOUNT,
versionFelt,
contractAddress,
Expand All @@ -166,7 +167,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt,
return nil, ErrNotAllParametersSet
}

calldataHash, err := computeHashOnElementsFelt(txn.Calldata)
calldataHash, err := hash.ComputeHashOnElementsFelt(txn.Calldata)
if err != nil {
return nil, err
}
Expand All @@ -175,7 +176,7 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt,
if err != nil {
return nil, err
}
return calculateTransactionHashCommon(
return hash.CalculateTransactionHashCommon(
PREFIX_TRANSACTION,
txnVersionFelt,
txn.ContractAddress,
Expand All @@ -191,15 +192,15 @@ func (account *Account) TransactionHashInvoke(tx rpc.InvokeTxnType) (*felt.Felt,
return nil, ErrNotAllParametersSet
}

calldataHash, err := computeHashOnElementsFelt(txn.Calldata)
calldataHash, err := hash.ComputeHashOnElementsFelt(txn.Calldata)
if err != nil {
return nil, err
}
txnVersionFelt, err := new(felt.Felt).SetString(string(txn.Version))
if err != nil {
return nil, err
}
return calculateTransactionHashCommon(
return hash.CalculateTransactionHashCommon(
PREFIX_TRANSACTION,
txnVersionFelt,
txn.SenderAddress,
Expand All @@ -224,7 +225,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel
return nil, ErrNotAllParametersSet
}

calldataHash, err := computeHashOnElementsFelt([]*felt.Felt{txn.ClassHash})
calldataHash, err := hash.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash})
if err != nil {
return nil, err
}
Expand All @@ -233,7 +234,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel
if err != nil {
return nil, err
}
return calculateTransactionHashCommon(
return hash.CalculateTransactionHashCommon(
PREFIX_DECLARE,
txnVersionFelt,
txn.SenderAddress,
Expand All @@ -248,7 +249,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel
return nil, ErrNotAllParametersSet
}

calldataHash, err := computeHashOnElementsFelt([]*felt.Felt{txn.ClassHash})
calldataHash, err := hash.ComputeHashOnElementsFelt([]*felt.Felt{txn.ClassHash})
if err != nil {
return nil, err
}
Expand All @@ -257,7 +258,7 @@ func (account *Account) TransactionHashDeclare(tx rpc.DeclareTxnType) (*felt.Fel
if err != nil {
return nil, err
}
return calculateTransactionHashCommon(
return hash.CalculateTransactionHashCommon(
PREFIX_DECLARE,
txnVersionFelt,
txn.SenderAddress,
Expand All @@ -284,10 +285,10 @@ func (account *Account) PrecomputeAddress(deployerAddress *felt.Felt, salt *felt
})

constructorCalldataBigIntArr := utils.FeltArrToBigIntArr(constructorCalldata)
constructorCallDataHashInt, _ := starknetgo.Curve.ComputeHashOnElements(constructorCalldataBigIntArr)
constructorCallDataHashInt, _ := curve.Curve.ComputeHashOnElements(constructorCalldataBigIntArr)
bigIntArr = append(bigIntArr, constructorCallDataHashInt)

preBigInt, err := starknetgo.Curve.ComputeHashOnElements(bigIntArr)
preBigInt, err := curve.Curve.ComputeHashOnElements(bigIntArr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -418,3 +419,27 @@ func (account *Account) TransactionByBlockIdAndIndex(ctx context.Context, blockI
func (account *Account) TransactionByHash(ctx context.Context, hash *felt.Felt) (rpc.Transaction, error) {
return account.provider.TransactionByHash(ctx, hash)
}

/*
Formats the multicall transactions in a format which can be signed and verified by the network and OpenZeppelin account contracts
*/
func FmtCalldata(fnCalls []rpc.FunctionCall) []*felt.Felt {
callArray := []*felt.Felt{}
callData := []*felt.Felt{new(felt.Felt).SetUint64(uint64(len(fnCalls)))}

for _, tx := range fnCalls {
callData = append(callData, tx.ContractAddress, tx.EntryPointSelector)

if len(tx.Calldata) == 0 {
callData = append(callData, &felt.Zero, &felt.Zero)
continue
}

callData = append(callData, new(felt.Felt).SetUint64(uint64(len(callArray))), new(felt.Felt).SetUint64(uint64(len(tx.Calldata))+1))
callArray = append(callArray, tx.Calldata...)
}
callData = append(callData, new(felt.Felt).SetUint64(uint64(len(callArray)+1)))
callData = append(callData, callArray...)
callData = append(callData, new(felt.Felt).SetUint64(0))
return callData
}
22 changes: 10 additions & 12 deletions account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,16 @@ import (
"testing"
"time"

"github.com/NethermindEth/juno/core/felt"
starknetgo "github.com/NethermindEth/starknet.go"
"github.com/golang/mock/gomock"
"github.com/joho/godotenv"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/starknet.go/account"
"github.com/NethermindEth/starknet.go/contracts"
"github.com/NethermindEth/starknet.go/devnet"
"github.com/NethermindEth/starknet.go/hash"
"github.com/NethermindEth/starknet.go/mocks"
"github.com/NethermindEth/starknet.go/rpc"
"github.com/NethermindEth/starknet.go/utils"
"github.com/golang/mock/gomock"
"github.com/test-go/testify/require"
)

Expand Down Expand Up @@ -134,7 +132,7 @@ func TestTransactionHashInvoke(t *testing.T) {
for _, test := range testSet {

t.Run("Transaction hash", func(t *testing.T) {
ks := starknetgo.NewMemKeystore()
ks := account.NewMemKeystore()
if test.SetKS {
privKeyBI, ok := new(big.Int).SetString(test.PrivKey.String(), 0)
require.True(t, ok)
Expand Down Expand Up @@ -225,7 +223,7 @@ func TestChainIdMOCK(t *testing.T) {

for _, test := range testSet {
mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainID, nil)
account, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", starknetgo.NewMemKeystore())
account, err := account.NewAccount(mockRpcProvider, &felt.Zero, "pubkey", account.NewMemKeystore())
require.NoError(t, err)
require.Equal(t, account.ChainId.String(), test.ExpectedID)
}
Expand Down Expand Up @@ -256,7 +254,7 @@ func TestChainId(t *testing.T) {
require.NoError(t, err, "Error in rpc.NewClient")
provider := rpc.NewProvider(client)

account, err := account.NewAccount(provider, &felt.Zero, "pubkey", starknetgo.NewMemKeystore())
account, err := account.NewAccount(provider, &felt.Zero, "pubkey", account.NewMemKeystore())
require.NoError(t, err)
require.Equal(t, account.ChainId.String(), test.ExpectedID)
}
Expand Down Expand Up @@ -297,7 +295,7 @@ func TestSignMOCK(t *testing.T) {
for _, test := range testSet {
privKeyBI, ok := new(big.Int).SetString(test.PrivKey.String(), 0)
require.True(t, ok)
ks := starknetgo.NewMemKeystore()
ks := account.NewMemKeystore()
ks.Put(test.Address.String(), privKeyBI)

mockRpcProvider.EXPECT().ChainID(context.Background()).Return(test.ChainId, nil)
Expand Down Expand Up @@ -388,7 +386,7 @@ func TestAddInvoke(t *testing.T) {
provider := rpc.NewProvider(client)

// Set up ks
ks := starknetgo.NewMemKeystore()
ks := account.NewMemKeystore()
if test.SetKS {
fakePrivKeyBI, ok := new(big.Int).SetString(test.PrivKey.String(), 0)
require.True(t, ok)
Expand Down Expand Up @@ -425,7 +423,7 @@ func TestAddDeployAccountDevnet(t *testing.T) {
fakeUserPub := utils.TestHexToFelt(t, fakeUser.PublicKey)

// Set up ks
ks := starknetgo.NewMemKeystore()
ks := account.NewMemKeystore()
fakePrivKeyBI, ok := new(big.Int).SetString(fakeUser.PrivateKey, 0)
require.True(t, ok)
ks.Put(fakeUser.PublicKey, fakePrivKeyBI)
Expand Down Expand Up @@ -471,7 +469,7 @@ func TestTransactionHashDeployAccountTestnet(t *testing.T) {

ExpectedHash := utils.TestHexToFelt(t, "0x5b6b5927cd70ad7a80efdbe898244525871875c76540b239f6730118598b9cb")
ExpectedPrecomputeAddr := utils.TestHexToFelt(t, "0x88d0038623a89bf853c70ea68b1062ccf32b094d1d7e5f924cda8404dc73e1")
ks := starknetgo.NewMemKeystore()
ks := account.NewMemKeystore()
fakePrivKeyBI, ok := new(big.Int).SetString(PrivKey.String(), 0)
require.True(t, ok)
ks.Put(PubKey.String(), fakePrivKeyBI)
Expand Down Expand Up @@ -515,7 +513,7 @@ func TestTransactionHashDeclare(t *testing.T) {
require.NoError(t, err, "Error in rpc.NewClient")
provider := rpc.NewProvider(client)

acnt, err := account.NewAccount(provider, &felt.Zero, "", starknetgo.NewMemKeystore())
acnt, err := account.NewAccount(provider, &felt.Zero, "", account.NewMemKeystore())
require.NoError(t, err)

tx := rpc.DeclareTxnV2{
Expand Down
67 changes: 0 additions & 67 deletions account/hash.go

This file was deleted.

6 changes: 4 additions & 2 deletions keystore.go → account/keystore.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package starknetgo
package account

import (
"context"
"errors"
"fmt"
"math/big"
"sync"

"github.com/NethermindEth/starknet.go/curve"
)

type Keystore interface {
Expand Down Expand Up @@ -68,7 +70,7 @@ func sign(ctx context.Context, msgHash *big.Int, key *big.Int) (x *big.Int, y *b
err = ctx.Err()

default:
x, y, err = Curve.Sign(msgHash, key)
x, y, err = curve.Curve.Sign(msgHash, key)
}
return x, y, err
}
Loading

0 comments on commit 61e7f8f

Please sign in to comment.