Skip to content

Commit

Permalink
fix(kb): fixed some bugs in file-to-embedding process (#35)
Browse files Browse the repository at this point in the history
Because

1. wronng source table and uid
2. wrong destinaiton
3. fail to save vector type in db

This commit

fixed the bug above
  • Loading branch information
Yougigun committed Jul 9, 2024
1 parent 66307c7 commit 703bb0b
Show file tree
Hide file tree
Showing 10 changed files with 663 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,4 @@ tmp
/config/config_local.yaml
test_.pdf
test_.md
test_pdf_base64.txt
10 changes: 5 additions & 5 deletions pkg/client/grpc/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ package grpcclient
// if err != nil {
// fmt.Println(err)
// }
// fmt.Println("current working diretor:", dir)
// fmt.Println("current working director:", dir)
// pipelinePublicGrpcConn, err := NewGRPCConn("localhost:8081", "", "")
// if err != nil {
// t.Fatalf("failed to create grpc connection: %v", err)
Expand All @@ -123,7 +123,7 @@ package grpcclient
// ctx := metadata.NewOutgoingContext(context.Background(), md)
// pipelinePublicServiceClient := pipelinev1beta.NewPipelinePublicServiceClient(pipelinePublicGrpcConn)

// base64PDF, err := readPDFtoBase64("../../../test_.pdf")
// base64PDF, err := readFileToBase64("../../../test_.pdf")
// if err != nil {
// t.Fatalf("failed to read pdf file: %v", err)
// }
Expand All @@ -139,8 +139,8 @@ package grpcclient
// fmt.Println("convert result\n", res.Outputs[0].GetFields()["convert_result"].GetStringValue()[:100])
// }

// // readPDFtoBase64 read the pdf file and convert it to base64
// func readPDFtoBase64(path string) (string, error) {
// // readFileToBase64 read the pdf file and convert it to base64
// func readFileToBase64(path string) (string, error) {
// // Open the file
// file, err := os.Open(path)
// if err != nil {
Expand Down Expand Up @@ -237,7 +237,7 @@ package grpcclient
// if err != nil {
// fmt.Println(err)
// }
// fmt.Println("current working diretor:", dir)
// fmt.Println("current working director:", dir)
// pipelinePublicGrpcConn, err := NewGRPCConn("localhost:8081", "", "")
// if err != nil {
// t.Fatalf("failed to create grpc connection: %v", err)
Expand Down
1 change: 0 additions & 1 deletion pkg/handler/knowledgebasefiles.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ func checkValidFileType(t artifactpb.FileType) bool {
func (ph *PublicHandler) ListKnowledgeBaseFiles(ctx context.Context, req *artifactpb.ListKnowledgeBaseFilesRequest) (*artifactpb.ListKnowledgeBaseFilesResponse, error) {

log, _ := logger.GetZapLogger(ctx)
fmt.Println("ListKnowledgeBaseFiles>>>", req)
uid, err := getUserUIDFromContext(ctx)
if err != nil {
log.Error("failed to get user id from header", zap.Error(err))
Expand Down
511 changes: 501 additions & 10 deletions pkg/mock/repository_i_mock.gen.go

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pkg/repository/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
)

type TextChunkI interface {
TextChunkTableName() string
DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) ([]*TextChunk, error)
DeleteChunksBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) error
DeleteChunksByUIDs(ctx context.Context, chunkUIDs []uuid.UUID) error
Expand Down Expand Up @@ -63,7 +64,7 @@ var TextChunkColumn = TextChunkColumns{
}

// TableName returns the table name of the TextChunk
func (TextChunk) TableName() string {
func (r *Repository) TextChunkTableName() string {
return "text_chunk"
}

Expand Down
25 changes: 21 additions & 4 deletions pkg/repository/convertedfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

type ConvertedFileI interface {
ConvertedFileTableName() string
CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (*ConvertedFile, error)
CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (*ConvertedFile, error)
DeleteConvertedFile(ctx context.Context, uid uuid.UUID) error
GetConvertedFileByFileUID(ctx context.Context, fileUID uuid.UUID) (*ConvertedFile, error)
}
Expand Down Expand Up @@ -58,7 +58,7 @@ func (r *Repository) ConvertedFileTableName() string {
return "converted_file"
}

func (r *Repository) CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) error) (*ConvertedFile, error) {
func (r *Repository) CreateConvertedFile(ctx context.Context, cf ConvertedFile, callExternalService func(convertedFileUID uuid.UUID) (map[string]any, error)) (*ConvertedFile, error) {
err := r.db.Transaction(func(tx *gorm.DB) error {
// Check if file_uid exists
var existingFile ConvertedFile
Expand Down Expand Up @@ -88,9 +88,17 @@ func (r *Repository) CreateConvertedFile(ctx context.Context, cf ConvertedFile,

if callExternalService != nil {
// Call the external service using the created record's UID
if err := callExternalService(cf.UID); err != nil {
if output, err := callExternalService(cf.UID); err != nil {
// If the external service returns an error, return the error to trigger a rollback
return err
} else {
// get dest from output and update the record
if dest, ok := output[ConvertedFileColumn.Destination].(string); ok {
update := map[string]any{ConvertedFileColumn.Destination: dest}
if err := tx.Model(&cf).Updates(update).Error; err != nil {
return err
}
}
}
}

Expand All @@ -113,7 +121,6 @@ func (r *Repository) GetConvertedFileByFileUID(ctx context.Context, fileUID uuid
return &cf, nil
}


// DeleteConvertedFile deletes the record by UID
func (r *Repository) DeleteConvertedFile(ctx context.Context, uid uuid.UUID) error {
err := r.db.Transaction(func(tx *gorm.DB) error {
Expand All @@ -129,3 +136,13 @@ func (r *Repository) DeleteConvertedFile(ctx context.Context, uid uuid.UUID) err
}
return nil
}

// UpdateConvertedFile updates the record by UID using update map.
func (r *Repository) UpdateConvertedFile(ctx context.Context, uid uuid.UUID, update map[string]any) error {
// Specify the condition to find the record by its UID
where := fmt.Sprintf("%s = ?", ConvertedFileColumn.UID)
if err := r.db.WithContext(ctx).Model(&ConvertedFile{}).Where(where, uid).Updates(update).Error; err != nil {
return err
}
return nil
}
65 changes: 64 additions & 1 deletion pkg/repository/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package repository

import (
"context"
"database/sql/driver"
"encoding/json"
"fmt"
"time"

Expand All @@ -16,17 +18,68 @@ type EmbeddingI interface {
UpsertEmbeddings(ctx context.Context, embeddings []Embedding, externalServiceCall func(embUIDs []string) error) ([]Embedding, error)
DeleteEmbeddingsBySource(ctx context.Context, sourceTable string, sourceUID uuid.UUID) error
DeleteEmbeddingsByUIDs(ctx context.Context, embUIDs []uuid.UUID) error
// GetEmbeddingByUIDs fetches embeddings by their UIDs.
GetEmbeddingByUIDs(ctx context.Context, embUIDs []uuid.UUID) ([]Embedding, error)
}
type Embedding struct {
UID uuid.UUID `gorm:"column:uid;type:uuid;default:gen_random_uuid();primaryKey" json:"uid"`
SourceUID uuid.UUID `gorm:"column:source_uid;type:uuid;not null" json:"source_uid"`
SourceTable string `gorm:"column:source_table;size:255;not null" json:"source_table"`
Vector []float32 `gorm:"column:vector;type:jsonb;not null" json:"vector"`
Vector Vector `gorm:"column:vector;type:jsonb;not null" json:"vector"`
Collection string `gorm:"column:collection;size:255;not null" json:"collection"`
CreateTime *time.Time `gorm:"column:create_time;not null;default:CURRENT_TIMESTAMP" json:"create_time"`
UpdateTime *time.Time `gorm:"column:update_time;not null;default:CURRENT_TIMESTAMP" json:"update_time"`
}

type Vector []float32

func (v Vector) Value() (driver.Value, error) {
if v == nil {
return nil, nil
}
r, err := json.Marshal(v)
if err != nil {
return nil, err
}
return string(r), nil
}

func (v *Vector) Scan(value interface{}) error {
if value == nil {
*v = nil
return nil
}

b, ok := value.([]byte)
if !ok {
return fmt.Errorf("type assertion to []byte failed")
}

return json.Unmarshal(b, v)
}

// MarshalJSON implements the json.Marshaler interface
func (v Vector) MarshalJSON() ([]byte, error) {
if v == nil {
return []byte("null"), nil
}
return json.Marshal([]float32(v))
}

// UnmarshalJSON implements the json.Unmarshaler interface
func (v *Vector) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
*v = nil
return nil
}
var slice []float32
if err := json.Unmarshal(data, &slice); err != nil {
return err
}
*v = Vector(slice)
return nil
}

type EmbeddingColumns struct {
UID string
SourceUID string
Expand Down Expand Up @@ -106,3 +159,13 @@ func (r *Repository) DeleteEmbeddingsByUIDs(ctx context.Context, embUIDs []uuid.
where := fmt.Sprintf("%s IN (?)", EmbeddingColumn.UID)
return r.db.WithContext(ctx).Where(where, embUIDs).Delete(&Embedding{}).Error
}

// GetEmbeddingByUIDs fetches embeddings by their UIDs.
func (r *Repository) GetEmbeddingByUIDs(ctx context.Context, embUIDs []uuid.UUID) ([]Embedding, error) {
var embeddings []Embedding
where := fmt.Sprintf("%s IN (?)", EmbeddingColumn.UID)
if err := r.db.WithContext(ctx).Where(where, embUIDs).Find(&embeddings).Error; err != nil {
return nil, err
}
return embeddings, nil
}
37 changes: 37 additions & 0 deletions pkg/repository/embedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package repository

// import (
// "context"
// "fmt"
// "os"
// "testing"

// "github.com/google/uuid"
// "github.com/instill-ai/artifact-backend/config"
// "github.com/instill-ai/artifact-backend/pkg/db"
// )

// func TestGetEmbeddingByUIDs(t *testing.T) {
// // set file flag
// // fs := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
// os.Args = []string{"", "-file", "../../config/config_local.yaml"}
// config.Init()
// // get db connection
// db := db.GetConnection()
// // get repository
// repo := NewRepository(db)
// // get embeddings
// uid := "006db525-ad0f-4951-8dd0-d226156b789b"
// // turn uid into uuid
// uidUUID, err := uuid.Parse(uid)
// if err != nil {
// t.Fatalf("Failed to parse uid: %v", err)
// }

// embeddings, err := repo.GetEmbeddingByUIDs(context.TODO(), []uuid.UUID{uidUUID})
// if err != nil {
// t.Fatalf("Failed to get embeddings: %v", err)
// }
// fmt.Println(embeddings)

// }
12 changes: 10 additions & 2 deletions pkg/service/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@ import (
"context"
"errors"


"github.com/google/uuid"
"github.com/instill-ai/artifact-backend/pkg/logger"
pipelinev1beta "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta"
"go.uber.org/zap"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/structpb"
)

// ConcertPDFToMD using converting pipeline to convert PDF to MD and consume caller's credits
func (s *Service) ConcertPDFToMD(ctx context.Context, caller uuid.UUID, pdfBase64 string) (string, error) {

// ConvertPDFToMD using converting pipeline to convert PDF to MD and consume caller's credits
func (s *Service) ConvertPDFToMD(ctx context.Context, caller uuid.UUID, pdfBase64 string) (string, error) {
logger, _ := logger.GetZapLogger(ctx)
md := metadata.New(map[string]string{"Instill-User-Uid": caller.String(), "Instill-Auth-Type": "user"})
ctx = metadata.NewOutgoingContext(ctx, md)

req := &pipelinev1beta.TriggerOrganizationPipelineReleaseRequest{
Name: "organizations/preset/pipelines/indexing-convert-pdf/releases/v1.0.0",
Inputs: []*structpb.Struct{
Expand All @@ -26,10 +32,12 @@ func (s *Service) ConcertPDFToMD(ctx context.Context, caller uuid.UUID, pdfBase6
}
resp, err := s.PipelinePub.TriggerOrganizationPipelineRelease(ctx, req)
if err != nil {
logger.Error("failed to trigger pipeline", zap.Error(err))
return "", err
}
result, err := getConvertResult(resp)
if err != nil {
logger.Error("failed to get convert result", zap.Error(err))
return "", err
}
return result, nil
Expand Down
28 changes: 22 additions & 6 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"fmt"
"runtime/debug"
"sync"
"time"

Expand Down Expand Up @@ -110,6 +111,19 @@ func (wp *fileToEmbWorkerPool) startWorker(ctx context.Context, workerID int) {
logger, _ := logger.GetZapLogger(ctx)
logger.Info("Worker started", zap.Int("WorkerID", workerID))
defer wp.wg.Done()
// Defer a function to catch panics
defer func() {
if r := recover(); r != nil {
logger.Error("Panic recovered in worker",
zap.Int("WorkerID", workerID),
zap.Any("panic", r),
zap.String("stack", string(debug.Stack())))
// Start a new worker
logger.Info("Restarting worker after panic", zap.Int("WorkerID", workerID))
wp.wg.Add(1)
go wp.startWorker(ctx, workerID)
}
}()
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -352,7 +366,7 @@ func (wp *fileToEmbWorkerPool) processConvertingFile(ctx context.Context, file r
base64Data := base64.StdEncoding.EncodeToString(data)

// convert the pdf file to md
convertedMD, err := wp.svc.ConcertPDFToMD(ctx, file.CreatorUID, base64Data)
convertedMD, err := wp.svc.ConvertPDFToMD(ctx, file.CreatorUID, base64Data)
if err != nil {
logger.Error("Failed to convert pdf to md.", zap.String("File path", fileInMinIOPath))
return nil, artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_UNSPECIFIED, err
Expand Down Expand Up @@ -565,8 +579,8 @@ func (wp *fileToEmbWorkerPool) processEmbeddingFile(ctx context.Context, file re
embeddings := make([]repository.Embedding, len(vectors))
for i, v := range vectors {
embeddings[i] = repository.Embedding{
SourceUID: sourceUID,
SourceTable: sourceTable,
SourceTable: wp.svc.Repository.TextChunkTableName(),
SourceUID: chunks[i].UID,
Vector: v,
Collection: collection,
}
Expand Down Expand Up @@ -602,13 +616,15 @@ func (wp *fileToEmbWorkerPool) saveConvertedFile(ctx context.Context, kbUID, fil
_, err := wp.svc.Repository.CreateConvertedFile(
ctx,
repository.ConvertedFile{KbUID: kbUID, FileUID: fileUID, Name: name, Type: "text/markdown", Destination: "destination"},
func(convertedFileUID uuid.UUID) error {
func(convertedFileUID uuid.UUID) (map[string]any, error) {
// save the converted file into object storage
err := wp.svc.MinIO.SaveConvertedFile(ctx, kbUID.String(), convertedFileUID.String(), "md", convertedFile)
if err != nil {
return err
return nil, err
}
return nil
output := make(map[string]any)
output[repository.ConvertedFileColumn.Destination] = wp.svc.MinIO.GetConvertedFilePathInKnowledgeBase(kbUID.String(), convertedFileUID.String(), "md")
return output, nil
})
if err != nil {
logger.Error("Failed to save converted file into object storage and metadata into database.", zap.String("FileUID", fileUID.String()))
Expand Down

0 comments on commit 703bb0b

Please sign in to comment.