【SwiftUI】Core ML Stable Diffusionをアプリに実装する サンプルコード

Swift

以下の記事で作成したコードに少し手を加えたものを、せっかくなので公開します。
ざっくり作っただけなので粗はありますが、参考までにどうぞ。

Sample1

prompt・seed・disableSafetyの入力欄を追加し、
seedに負の値（-1など）を指定するとランダムなseedで生成するようにしたコードの全文です。

import SwiftUI
import StableDiffusion

/// Minimal Stable Diffusion demo screen (Sample1).
///
/// The user enters a prompt and a seed, can toggle the safety checker, starts a
/// generation with "Generate" and can cancel it with "Stop". A negative seed
/// (the default is "-1") makes each run pick a random seed.
struct ContentView: View {
    
    /// Last generated image; `nil` shows the placeholder frame instead.
    @State var dispImage:CGImage?
    /// True while a generation is running, so a second one cannot be started.
    @State var disableGenerateButton = false
    
    /// Current step and total step count reported by the pipeline's progress handler.
    @State var dispStep = 0
    @State var dispStepCount = 0
    
    /// Value returned from the progress handler; "Stop" clears it to cancel the run.
    @State var continueGenerate = false
    
    @State var prompt = "a photo of an astronaut riding a horse on mars"
    /// Seed as typed by the user; a negative value means "random".
    @State var seed = "-1"
    @State var disableSafety = false

    /// Progress / error message shown under the buttons.
    @State var status = ""
    
    var body: some View {
        VStack {
            if let image = dispImage {
                Image(image,scale: 1 , label: Text("Generated Image"))
            }else{
                Text("Generated Image")
                    .frame(width: 512,height: 512)
                    .border(.black)
            }
            
            TextField("prompt", text: $prompt)
            TextField("seed", text: $seed)
            
            Toggle("DisableSafety", isOn: $disableSafety)
            
            HStack {
                Button("Generate"){
                    disableGenerateButton = true
                    continueGenerate = true
                    
                    Task{
                        await generateImage()
                        
                        disableGenerateButton = false
                        continueGenerate = false
                    }
                }.disabled(disableGenerateButton)
                
                Button("Stop"){
                    continueGenerate = false
                }.disabled(!continueGenerate)
            }
            
            Text("Status:\(status)")
            Text("Step:\(Int(dispStep))/\(Int(dispStepCount))")
        }
        .padding()
    }
    
    /// Loads the Core ML pipeline from the app bundle and generates one image for
    /// the current prompt/seed, publishing progress and the result to the UI.
    ///
    /// Invalid input and pipeline errors are reported through `status` (and printed)
    /// instead of leaving a stale status message on screen.
    func generateImage() async {
        guard let resourceURL = Bundle.main.resourceURL else{
            await MainActor.run { status = "Error: resource URL not found" }
            return
        }
        
        guard var seed = Int(self.seed) else{
            print("seed error")
            await MainActor.run { status = "seed error" }
            return
        }
        
        // A negative seed means "pick a random one" for this run.
        if seed < 0 {
            seed = Int(UInt32.random(in: UInt32.min...UInt32.max))
        }
        
        // Snapshot the inputs so edits made while the (long) generation runs
        // cannot change what this run uses.
        let prompt = self.prompt
        let disableSafety = self.disableSafety
        
        do{
            // Awaited directly (not wrapped in an unstructured Task) so the status
            // messages are applied in order.
            await MainActor.run {
                dispStep = 0
                status = "Model Loading"
            }
            let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)

            await MainActor.run { status = "Images Generating" }
            
            let image = try pipeline.generateImages(prompt: prompt
                                                    , seed: seed
                                                    , disableSafety: disableSafety ){ progress in
                // The handler is synchronous, so the UI update has to hop to the
                // main actor through a Task.
                Task{
                    await MainActor.run {
                        dispStepCount = progress.stepCount
                        dispStep = progress.step
                    }
                }
                
                // Returning false (after "Stop" is pressed) asks the pipeline to stop.
                return continueGenerate
            }.first ?? nil
            
            await MainActor.run {
                dispImage = image
                status = (image == nil) ? "Finished (no image)" : "Done"
            }
            
        } catch {
            print(error.localizedDescription)
            await MainActor.run { status = "Error: \(error.localizedDescription)" }
        }
    }
}

Sample2

さらに、ステップ数（stepCount）の設定、ProgressViewによる進捗表示、前回生成時のパラメータの表示を追加したコードです。

import SwiftUI
import StableDiffusion

/// Extended Stable Diffusion demo screen (Sample2).
///
/// Adds a step-count field, a `ProgressView`, and a read-only panel showing the
/// parameters (prompt, seed, step count, safety toggle) of the latest run.
/// A negative seed (the default is "-1") makes each run pick a random seed.
struct CompleteView: View {
    
    /// Last generated image; `nil` shows the placeholder frame instead.
    @State var dispImage:CGImage?
    /// True while a generation is running, so a second one cannot be started.
    @State var disableGenerateButton = false
    
    /// Current step and total step count from the progress handler (Double for `ProgressView`).
    @State var dispStep = 0.0
    @State var dispStepCount = 0.0
    
    /// Value returned from the progress handler; "Stop" clears it to cancel the run.
    @State var continueGenerate = false
    
    @State var prompt = "a photo of an astronaut riding a horse on mars"
    /// Seed as typed by the user; a negative value means "random".
    @State var seed = "-1"
    
    /// Progress / error message shown under the buttons.
    @State var status = ""

    /// Step count as typed by the user; non-positive values fall back to 50.
    @State var stepCount = "50"
    @State var disableSafety = false
    
    /// Parameters actually used by the latest run, shown read-only below the progress bar.
    @State var generatedPrompt = ""
    @State var generatedSeed = ""
    @State var generatedStepCount = ""
    @State var generatedDisableSafety = false
    
    
    var body: some View {
        VStack {
            if let image = dispImage {
                Image(image,scale: 1 , label: Text("Generated Image"))
            }else{
                Text("Generated Image")
                    .frame(width: 512,height: 512)
                    .border(.black)
            }
            
            VStack {
                TextField("prompt", text: $prompt)
                TextField("seed", text: $seed)
                TextField("StepCount", text: $stepCount)
                
                HStack {
                    Toggle("DisableSafety", isOn: $disableSafety)
                    Spacer()
                }
            }
            
            HStack {
                Button("Generate"){
                    disableGenerateButton = true
                    continueGenerate = true
                    
                    Task{
                        await generateImage()
                        
                        disableGenerateButton = false
                        continueGenerate = false
                    }
                }.disabled(disableGenerateButton)
                
                Button("Stop"){
                    continueGenerate = false
                }.disabled(!continueGenerate)
            }
            

            Text("Status:\(status)")

            // `total` must be positive, so use 1 before the first progress report.
            ProgressView(value: dispStep, total: (dispStepCount == 0) ? 1.0 : dispStepCount){
                Text("Step:\(Int(dispStep))/\(Int(dispStepCount))")
            }

            // Read-only display of the previous run's parameters.
            VStack {
                TextField("Prompt", text: .constant(generatedPrompt))
                TextField("Seed", text: .constant(generatedSeed))
                TextField("StepCount", text: .constant(generatedStepCount))
                HStack {
                    Toggle("DisableSafety", isOn: .constant(generatedDisableSafety)).disabled(true)
                    Spacer()
                }
            }

        }
        .padding()
    }
    
    
    /// Loads the Core ML pipeline from the app bundle and generates one image,
    /// publishing progress, the parameters used, and the result to the UI.
    ///
    /// Invalid input and pipeline errors are reported through `status` (and printed)
    /// instead of leaving a stale status message on screen.
    func generateImage() async {
        guard let resourceURL = Bundle.main.resourceURL else{
            await MainActor.run { status = "Error: resource URL not found" }
            return
        }
        
        guard var seed = Int(self.seed) else{
            print("seed error")
            await MainActor.run { status = "seed error" }
            return
        }
        
        // A negative seed means "pick a random one" for this run.
        if seed < 0 {
            seed = Int(UInt32.random(in: UInt32.min...UInt32.max))
        }
        
        guard var stepCount = Int(self.stepCount) else {
            print("StepCount Error")
            await MainActor.run { status = "StepCount Error" }
            return
        }
        
        // Zero steps is not a usable request either, so it gets the same
        // default as negative values (previously only `< 0` was caught).
        if stepCount <= 0 {
            stepCount = 50
        }
        
        // Snapshot every input once. The pipeline and the "previous parameters"
        // panel both use these values, so the panel always matches the run even
        // if the user edits the fields meanwhile. (`let` copies are also required
        // for capture in the @Sendable `MainActor.run` closures.)
        let usedPrompt = self.prompt
        let usedSeed = seed
        let usedStepCount = stepCount
        let usedDisableSafety = self.disableSafety
        
        do{
            // Awaited directly (not wrapped in an unstructured Task) so the UI
            // updates are applied in order.
            await MainActor.run {
                dispStep = 0
                status = "Model Loading"
            }
            let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)

            await MainActor.run {
                generatedPrompt = usedPrompt
                generatedSeed = String(usedSeed)
                generatedStepCount = String(usedStepCount)
                generatedDisableSafety = usedDisableSafety
                status = "Images Generating"
            }
            
            let image = try pipeline.generateImages(prompt: usedPrompt
                                                    , stepCount: usedStepCount
                                                    , seed: usedSeed
                                                    , disableSafety: usedDisableSafety ){ progress in
                // The handler is synchronous, so the UI update has to hop to the
                // main actor through a Task.
                Task{
                    await MainActor.run {
                        dispStepCount = Double(progress.stepCount)
                        dispStep = Double(progress.step)
                    }
                }
                
                // Returning false (after "Stop" is pressed) asks the pipeline to stop.
                return continueGenerate
            }.first ?? nil
            
            await MainActor.run {
                dispImage = image
                status = (image == nil) ? "Finished (no image)" : "Done"
            }
            
        } catch {
            print(error.localizedDescription)
            await MainActor.run { status = "Error: \(error.localizedDescription)" }
        }
    }
}

コメント

  1. […] […]

タイトルとURLをコピーしました