v1.0.0
3
skills/mlops/models/DESCRIPTION.md
Normal file
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Specific model architectures and tools — image segmentation (Segment Anything / SAM) and audio generation (AudioCraft / MusicGen). Additional model skills (CLIP, Stable Diffusion, Whisper, LLaVA) are available as optional skills.
|
||||
---
|
||||
568
skills/mlops/models/audiocraft/SKILL.md
Normal file
@@ -0,0 +1,568 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: "AudioCraft: MusicGen text-to-music, AudioGen text-to-sound."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
platforms: [linux, macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectrogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
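In the Transformers example above, `max_new_tokens` determines clip length. The sketch below converts seconds into a token budget, assuming the 50 Hz token frame rate of the 32 kHz MusicGen checkpoints (so 256 tokens is roughly 5 seconds). The helper name is hypothetical and not part of Transformers.

```python
# Hypothetical helper: convert a target duration into a max_new_tokens budget.
# Assumption: the audio codec emits 50 token frames per second of audio.
# Check model.config.audio_encoder.frame_rate on your checkpoint to confirm.
FRAME_RATE_HZ = 50


def seconds_to_max_new_tokens(seconds: float, frame_rate: int = FRAME_RATE_HZ) -> int:
    """Return the number of decoder steps needed for roughly `seconds` of audio."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return int(round(seconds * frame_rate))


print(seconds_to_max_new_tokens(5))   # 250
print(seconds_to_max_new_tokens(10))  # 500
```

The result can be passed as `max_new_tokens` to `model.generate(...)`.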
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
import torchaudio
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
┌──────────────────────────────────────────────────────────────┐
│                      Text Encoder (T5)                       │
│                        │                                     │
│                       Text Embeddings                        │
└────────────────────────┬─────────────────────────────────────┘
                         │
┌────────────────────────▼─────────────────────────────────────┐
│                   Transformer Decoder (LM)                   │
│           Auto-regressively generates audio tokens           │
│         Using efficient token interleaving patterns          │
└────────────────────────┬─────────────────────────────────────┘
                         │
┌────────────────────────▼─────────────────────────────────────┐
│                    EnCodec Audio Decoder                     │
│            Converts tokens back to audio waveform            │
└──────────────────────────────────────────────────────────────┘
|
||||
```
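The "token interleaving patterns" in the diagram refer to how the parallel EnCodec codebook streams are laid out for the decoder. The MusicGen paper describes a delay pattern in which codebook `k` is shifted by `k` steps. The sketch below illustrates that layout in plain Python. It is not AudioCraft's implementation (that lives in `audiocraft.modules.codebooks_patterns`), and the function name and `PAD` value are assumptions made for illustration.

```python
# Illustrative only: lay out K codebook streams with the "delay" pattern,
# where codebook k is shifted right by k steps. PAD marks empty slots.
PAD = -1


def apply_delay_pattern(codes: list[list[int]]) -> list[list[int]]:
    """codes[k][t] is the token of codebook k at frame t; returns the delayed rows."""
    num_codebooks = len(codes)
    num_frames = len(codes[0])
    total_steps = num_frames + num_codebooks - 1
    delayed = []
    for k, row in enumerate(codes):
        # k pads on the left, and enough pads on the right to reach total_steps.
        shifted = [PAD] * k + list(row) + [PAD] * (total_steps - k - num_frames)
        delayed.append(shifted)
    return delayed


codes = [[10, 11, 12], [20, 21, 22], [30, 31, 32]]
delayed = apply_delay_pattern(codes)
for row in delayed:
    print(row)
# [10, 11, 12, -1, -1]
# [-1, 20, 21, 22, -1]
# [-1, -1, 30, 31, 32]
```

With this layout the decoder predicts one token per codebook at every step, so a clip of T frames needs about T + K - 1 steps rather than T * K.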
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|-------|------|-------------|----------|
| `musicgen-small` | 300M | Text-to-music | Quick generation |
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration` | 8.0 | Length in seconds (1-120) |
| `top_k` | 250 | Top-k sampling |
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
| `temperature` | 1.0 | Sampling temperature |
| `cfg_coef` | 3.0 | Classifier-free guidance |
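
To make `cfg_coef`, `top_k`, and `temperature` concrete, here is a minimal sketch of classifier-free guidance followed by top-k sampling on a toy logit vector. It uses the common textbook formulation in plain Python. AudioCraft's own sampling code may order or implement these steps differently, and both function names are hypothetical.

```python
import math
import random


def apply_cfg(cond_logits, uncond_logits, cfg_coef=3.0):
    """Classifier-free guidance: move logits away from the unconditional prediction."""
    return [u + cfg_coef * (c - u) for c, u in zip(cond_logits, uncond_logits)]


def sample_top_k(logits, top_k=250, temperature=1.0, rng=random):
    """Sample one token index from the top_k highest logits after temperature scaling."""
    scaled = [x / temperature for x in logits]
    # Keep only the indices of the top_k largest logits.
    keep = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]
    peak = max(scaled[i] for i in keep)
    # Unnormalized softmax weights, shifted by the peak for numerical stability.
    weights = [math.exp(scaled[i] - peak) for i in keep]
    return rng.choices(keep, weights=weights, k=1)[0]


cond = [2.0, 0.5, -1.0, 0.0]    # logits given the text prompt (toy values)
uncond = [1.0, 0.5, 0.0, 0.0]   # logits without the prompt (toy values)
guided = apply_cfg(cond, uncond, cfg_coef=3.0)
print(guided)  # [4.0, 0.5, -3.0, 0.0]

token = sample_top_k(guided, top_k=2, temperature=1.0)
print(token)   # 0 or 1, since only the two largest logits survive
```

A larger `cfg_coef` widens the gap between prompt-favoured tokens and the rest, while a lower `temperature` or smaller `top_k` makes sampling more deterministic.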
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
    torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
    text=["continue with an epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
import torchaudio
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
# The style excerpt is passed through the same entry point as melody conditioning
wav = model.generate_with_chroma(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_chroma([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
    torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
if sr != 32000:
    resampler = torchaudio.transforms.Resample(sr, 32000)
    wav = resampler(wav)

# Encode to tokens
with torch.no_grad():
    encoded = model.encode(wav.unsqueeze(0))
    codes = encoded[0]  # Audio codes

# Decode back to audio
with torch.no_grad():
    decoded = model.decode(codes)

torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
    def __init__(self, model_name="facebook/musicgen-medium"):
        self.model = MusicGen.get_pretrained(model_name)
        self.sample_rate = 32000

    def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
        self.model.set_generation_params(
            duration=duration,
            top_k=250,
            temperature=temperature,
            cfg_coef=cfg
        )

        with torch.no_grad():
            wav = self.model.generate([prompt])

        return wav[0].cpu()

    def generate_batch(self, prompts, duration=30):
        self.model.set_generation_params(duration=duration)

        with torch.no_grad():
            wav = self.model.generate(prompts)

        return wav.cpu()

    def save(self, audio, path):
        torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
    """
    Generate multiple sounds from specifications.

    Args:
        sound_specs: list of {"name": str, "description": str, "duration": float}
        output_dir: output directory path
    """
    model = AudioGen.get_pretrained('facebook/audiogen-medium')
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)

    results = []

    for spec in sound_specs:
        model.set_generation_params(duration=spec.get("duration", 5))

        wav = model.generate([spec["description"]])

        output_path = output_dir / f"{spec['name']}.wav"
        torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)

        results.append({
            "name": spec["name"],
            "path": str(output_path),
            "description": spec["description"]
        })

    return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
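Workflow 2 imports `json` but never writes the returned `results` anywhere. One possible follow-up, sketched here with a hypothetical `write_manifest` helper and file name, is to store the list next to the generated files so later tooling can locate each sound by name.

```python
import json
import tempfile
from pathlib import Path


def write_manifest(results, output_dir, filename="manifest.json"):
    """Write a list of {"name", "path", "description"} dicts to a JSON file."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = output_dir / filename
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    return manifest_path


# Example entry shaped like the output of batch_generate_sounds()
example = [
    {"name": "door", "path": "sound_effects/door.wav",
     "description": "wooden door creaking and closing"},
]

# A temporary directory keeps the demonstration free of side effects.
with tempfile.TemporaryDirectory() as tmp:
    path = write_manifest(example, tmp)
    loaded = json.loads(path.read_text(encoding="utf-8"))

print(loaded[0]["name"])  # door
```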
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
    model.set_generation_params(
        duration=duration,
        temperature=temperature,
        cfg_coef=cfg_coef
    )

    with torch.no_grad():
        wav = model.generate([prompt])

    # Save to temp file
    path = "temp_output.wav"
    torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
    return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
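# Half precision: the MusicGen wrapper is not an nn.Module and has no .half().
# On CUDA, get_pretrained already loads the language model weights in float16.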
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
    wav = model.generate([desc])  # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|-------|-----------|-----------|
| musicgen-small | ~4GB | ~2GB |
| musicgen-medium | ~8GB | ~4GB |
| musicgen-large | ~16GB | ~8GB |
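
The figures above are approximate totals. As a sanity check, the weights of the language model alone give a lower bound: parameter count times bytes per parameter. The sketch below computes that bound from the parameter counts listed under "Model variants". The remainder of each table entry would be activations, attention caches, the T5 text encoder, and EnCodec, which this sketch does not model.

```python
# Lower-bound estimate: memory for the language model weights only.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}


def weight_memory_gb(num_params: float, dtype: str = "fp32") -> float:
    """Memory needed for the weights alone, in GB (10^9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


sizes = [("musicgen-small", 300e6), ("musicgen-medium", 1.5e9), ("musicgen-large", 3.3e9)]
for name, params in sizes:
    fp32 = round(weight_memory_gb(params, "fp32"), 1)
    fp16 = round(weight_memory_gb(params, "fp16"), 1)
    print(f"{name}: {fp32} GB fp32, {fp16} GB fp16")
# musicgen-small: 1.2 GB fp32, 0.6 GB fp16
# musicgen-medium: 6.0 GB fp32, 3.0 GB fp16
# musicgen-large: 13.2 GB fp32, 6.6 GB fp16
```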
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|-------|----------|
| CUDA OOM | Use smaller model, reduce duration |
| Poor quality | Increase cfg_coef, better prompts |
| Generation too short | Check max duration setting |
| Audio artifacts | Try different temperature |
| Stereo not working | Use stereo model variant |
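
The "reduce duration" advice for CUDA OOM can be automated. The sketch below retries with progressively shorter durations. `generate_fn` is a hypothetical callable that would wrap `model.set_generation_params(...)` and `model.generate(...)`. It is simulated here so the example runs without a GPU.

```python
def generate_with_fallback(generate_fn, prompt, durations=(30, 15, 8)):
    """Try progressively shorter durations until generation fits in memory.

    generate_fn(prompt, duration) is a hypothetical callable supplied by the caller.
    Returns (result, duration_used).
    """
    last_error = None
    for duration in durations:
        try:
            return generate_fn(prompt, duration), duration
        except RuntimeError as err:
            # PyTorch's CUDA OOM error is a RuntimeError whose message mentions
            # "out of memory"; anything else is re-raised untouched.
            if "out of memory" not in str(err).lower():
                raise
            last_error = err
    raise RuntimeError("all durations ran out of memory") from last_error


def fake_generate(prompt, duration):
    """Stand-in for a real model call: pretends anything over 10 s does not fit."""
    if duration > 10:
        raise RuntimeError("CUDA out of memory (simulated)")
    return f"{duration}s of audio for {prompt!r}"


audio, used = generate_with_fallback(fake_generate, "lo-fi beat")
print(used)  # 8
```

In a real setup it would also make sense to call `torch.cuda.empty_cache()` between attempts.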
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
666
skills/mlops/models/audiocraft/references/advanced-usage.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
    """
    Prepare dataset for MusicGen fine-tuning.

    Directory structure:
    output_dir/
    ├── audio/
    │   ├── 0001.wav
    │   ├── 0002.wav
    │   └── ...
    └── metadata.json
    """
    output_dir = Path(output_dir)
    audio_output = output_dir / "audio"
    audio_output.mkdir(parents=True, exist_ok=True)

    # Load metadata (format: {"path": "...", "description": "..."})
    with open(metadata_file) as f:
        metadata = json.load(f)

    processed = []

    for idx, item in enumerate(metadata):
        audio_path = Path(audio_dir) / item["path"]

        # Load and resample to 32kHz
        wav, sr = torchaudio.load(str(audio_path))
        if sr != 32000:
            resampler = torchaudio.transforms.Resample(sr, 32000)
            wav = resampler(wav)

        # Convert to mono if stereo
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Save processed audio
        output_path = audio_output / f"{idx:04d}.wav"
        torchaudio.save(str(output_path), wav, sample_rate=32000)

        processed.append({
            "path": str(output_path.relative_to(output_dir)),
            "description": item["description"],
            "duration": wav.shape[1] / 32000
        })

    # Save processed metadata
    with open(output_dir / "metadata.json", "w") as f:
        json.dump(processed, f, indent=2)

    print(f"Processed {len(processed)} samples")
    return processed
|
||||
```
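Fine-tuning configurations usually take separate training and validation sets, whereas `prepare_dataset` returns a single list. A minimal sketch of a reproducible split follows. The function name, the 10% default, and the seed are assumptions for illustration, not AudioCraft conventions.

```python
import random


def split_metadata(items, valid_fraction=0.1, seed=0):
    """Shuffle a copy of `items` and split it into (train, valid) lists."""
    if not 0.0 < valid_fraction < 1.0:
        raise ValueError("valid_fraction must be between 0 and 1")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_valid = max(1, int(len(shuffled) * valid_fraction))
    return shuffled[n_valid:], shuffled[:n_valid]


# Toy records standing in for the dicts returned by prepare_dataset()
records = [{"path": f"audio/{i:04d}.wav"} for i in range(20)]
train, valid = split_metadata(records, valid_fraction=0.1)
print(len(train), len(valid))  # 18 2
```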
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from audiocraft.models import MusicGen

def setup(rank, world_size):
    # Assumes MASTER_ADDR and MASTER_PORT are set in the environment
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size):
    setup(rank, world_size)

    model = MusicGen.get_pretrained('facebook/musicgen-small')
    model.lm = model.lm.to(rank)
    model.lm = DDP(model.lm, device_ids=[rank])

    # Training loop
    # ...

    dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
    """Custom conditioner for additional control signals."""

    def __init__(self, dim, output_dim):
        super().__init__(dim, output_dim)
        self.embed = torch.nn.Linear(dim, output_dim)

    def forward(self, x):
        return self.embed(x)

    def tokenize(self, x):
        # Tokenize input for conditioning
        return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
    """Extract chroma features from audio."""
    import librosa

    # Compute chroma
    chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)

    return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Trade quality for bitrate by limiting the number of active codebooks
# Fewer codebooks = more compression, lower quality
encodec.set_num_codebooks(2)  # must not exceed encodec.total_codebooks
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
    """Encode audio in streaming fashion."""
    all_codes = []

    for chunk in audio_stream:
        # Ensure chunk is right shape
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            codes = encodec.encode(chunk)[0]
            all_codes.append(codes)

    return torch.cat(all_codes, dim=-1)

def decode_streaming(codes_stream, output_stream):
    """Decode codes in streaming fashion."""
    for codes in codes_stream:
        with torch.no_grad():
            audio = encodec.decode(codes)
            output_stream.write(audio.cpu().numpy())
|
||||
```
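`encode_streaming` expects an iterable of chunks but does not show how to produce one. The sketch below is a hypothetical chunking generator. It relies only on `len()` and slicing, so it should work on a 1-D tensor of samples as well as on the plain list used here for demonstration.

```python
def iter_chunks(samples, chunk_size=32000):
    """Yield consecutive slices of `samples`, each at most chunk_size long."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(samples), chunk_size):
        yield samples[start:start + chunk_size]


# Demonstration on a plain list; the last chunk is shorter than the others.
chunks = list(iter_chunks(list(range(10)), chunk_size=4))
print([len(c) for c in chunks])  # [4, 4, 2]
```

Chunks encoded independently carry no context across boundaries, so audible seams at chunk edges are possible. Overlapping chunks or padding the final short chunk are options to evaluate.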
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
import torch
from audiocraft.models import MusicGen, MultiBandDiffusion

# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')

# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()

model.set_generation_params(duration=10)

descriptions = ["epic orchestral music"]

with torch.no_grad():
    # return_tokens=True yields both the standard EnCodec decode and the tokens
    wav_standard, gen_tokens = model.generate(descriptions, return_tokens=True)

    # Decode the same tokens with MBD
    wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
async def load_model():
    global model
    model = MusicGen.get_pretrained('facebook/musicgen-small')
    model.set_generation_params(duration=10)

class GenerateRequest(BaseModel):
    prompt: str
    duration: float = 10.0
    temperature: float = 1.0
    cfg_coef: float = 3.0

class GenerateResponse(BaseModel):
    audio_base64: str
    sample_rate: int
    duration: float

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        model.set_generation_params(
            duration=min(request.duration, 30),
            temperature=request.temperature,
            cfg_coef=request.cfg_coef
        )

        with torch.no_grad():
            wav = model.generate([request.prompt])

        # Convert to bytes
        buffer = io.BytesIO()
        torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
        buffer.seek(0)

        audio_base64 = base64.b64encode(buffer.read()).decode()

        return GenerateResponse(
            audio_base64=audio_base64,
            sample_rate=32000,
            duration=wav.shape[-1] / 32000
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
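On the client side, the `audio_base64` field of `GenerateResponse` has to be decoded back into WAV bytes. The sketch below shows that step using only the standard library. The payload is simulated, the helper name is hypothetical, and the HTTP call itself is left to whichever client library you use.

```python
import base64


def decode_audio_response(payload: dict) -> bytes:
    """Return the raw WAV bytes carried in a /generate response body."""
    return base64.b64decode(payload["audio_base64"])


# Simulated response body; a real one would come from POST /generate.
fake_wav = b"RIFF....WAVEfmt "
payload = {
    "audio_base64": base64.b64encode(fake_wav).decode(),
    "sample_rate": 32000,
    "duration": 10.0,
}

wav_bytes = decode_audio_response(payload)
print(wav_bytes[:4])  # b'RIFF'
```

The decoded bytes can be written straight to a `.wav` file opened in binary mode.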
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
from concurrent.futures import ThreadPoolExecutor
import torch
from audiocraft.models import MusicGen

class MusicGenService:
    def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
        self.model = MusicGen.get_pretrained(model_name)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.lock = asyncio.Lock()

    async def generate_async(self, prompt, duration=10):
        """Async generation with thread pool."""
        loop = asyncio.get_running_loop()

        def _generate():
            with torch.no_grad():
                self.model.set_generation_params(duration=duration)
                return self.model.generate([prompt])

        # The model holds shared generation params, so serialize access to it
        async with self.lock:
            wav = await loop.run_in_executor(self.executor, _generate)
        return wav[0].cpu()

    async def generate_batch_async(self, prompts, duration=10):
        """Queue multiple prompts; the lock runs them one at a time."""
        tasks = [self.generate_async(p, duration) for p in prompts]
        return await asyncio.gather(*tasks)

# Usage
async def main():
    # Create the service inside the running event loop so the lock binds to it
    service = MusicGenService()
    prompts = ["jazz piano", "rock guitar", "electronic beats"]
    results = await service.generate_batch_async(prompts)
    return results

results = asyncio.run(main())
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from typing import Any
from langchain.tools import BaseTool
import torch
import torchaudio
from audiocraft.models import MusicGen
import tempfile

class MusicGeneratorTool(BaseTool):
    name: str = "music_generator"
    description: str = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
    # BaseTool is a pydantic model, so the MusicGen handle must be a declared field
    model: Any = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = MusicGen.get_pretrained('facebook/musicgen-small')
        self.model.set_generation_params(duration=15)

    def _run(self, description: str) -> str:
        with torch.no_grad():
            wav = self.model.generate([description])

        # Save to temp file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
            return f"Generated music saved to: {f.name}"

    async def _arun(self, description: str) -> str:
        return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
    def __init__(self, sample_rate=32000):
        self.sample_rate = sample_rate

    def normalize(self, audio, target_db=-14.0):
        """Normalize audio to a target RMS level in dB (relative to full scale)."""
        rms = torch.sqrt(torch.mean(audio ** 2))
        target_rms = 10 ** (target_db / 20)
        gain = target_rms / (rms + 1e-8)
        return audio * gain

    def fade_in_out(self, audio, fade_duration=0.1):
        """Apply a linear fade in/out."""
        fade_samples = int(fade_duration * self.sample_rate)

        # Create fade curves
        fade_in = torch.linspace(0, 1, fade_samples)
        fade_out = torch.linspace(1, 0, fade_samples)

        # Apply fades on a copy so the caller's tensor is not modified
        audio = audio.clone()
        audio[..., :fade_samples] *= fade_in
        audio[..., -fade_samples:] *= fade_out

        return audio

    def apply_reverb(self, audio, decay=0.5):
        """Apply a simple echo-style reverb to audio of shape [channels, samples]."""
        impulse = torch.zeros(int(self.sample_rate * 0.5))
        impulse[0] = 1.0
        impulse[int(self.sample_rate * 0.1)] = decay * 0.5
        impulse[int(self.sample_rate * 0.2)] = decay * 0.25

        # conv1d computes cross-correlation, so flip the impulse and left-pad
        # to get a causal convolution that preserves the input length.
        # Channels are treated as the batch dimension so stereo also works.
        kernel = impulse.flip(0).view(1, 1, -1)
        padded = torch.nn.functional.pad(audio.unsqueeze(1), (kernel.shape[-1] - 1, 0))
        return torch.nn.functional.conv1d(padded, kernel).squeeze(1)

    def process(self, audio):
        """Full processing pipeline."""
        audio = self.normalize(audio)
        audio = self.fade_in_out(audio)
        return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
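The `normalize` method above turns a decibel target into a linear RMS level with `10 ** (target_db / 20)`. The arithmetic can be checked without PyTorch; the following is a minimal stdlib-only sketch, and the helper names `db_to_amplitude` and `gain_for_target` are illustrative rather than part of AudioCraft.

```python
import math


def db_to_amplitude(db: float) -> float:
    """Convert a level in dB (relative to full scale 1.0) to linear amplitude."""
    return 10 ** (db / 20)


def gain_for_target(samples: list[float], target_db: float = -14.0) -> float:
    """Return the multiplier that brings the RMS of `samples` to `target_db`."""
    rms = math.sqrt(sum(s * s for s in samples) / len(samples))
    return db_to_amplitude(target_db) / (rms + 1e-8)


# A square-like signal at +/-0.1 has RMS 0.1 (-20 dB),
# so reaching -14 dB needs a gain of roughly 2x.
print(round(gain_for_target([0.1, -0.1, 0.1, -0.1]), 3))
```

Because this normalizes RMS rather than peak level, a signal with large transients can end up with peaks above 1.0 after the gain is applied, so it is worth checking the maximum amplitude before saving.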
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
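import torch
from audiocraft.data.audio import audio_read
from audiocraft.metrics import CLAPTextConsistencyMetric

# Placeholder path: the metric needs a CLAP checkpoint (and the laion_clap
# package) rather than a no-argument constructor
CLAP_CHECKPOINT = "/path/to/clap_checkpoint.pt"


def evaluate_generation(audio_path, text_prompt):
    """Evaluate generated audio quality."""
    # Load audio as [channels, samples]
    wav, sr = audio_read(audio_path)

    # CLAP consistency (text-audio alignment). The metric accumulates
    # state via update() and reports the average via compute().
    clap_metric = CLAPTextConsistencyMetric(model_path=CLAP_CHECKPOINT)
    clap_metric.update(
        wav.unsqueeze(0),  # [batch, channels, samples]
        [text_prompt],
        sizes=torch.tensor([wav.shape[-1]]),
        sample_rates=torch.tensor([sr]),
    )
    clap_score = clap_metric.compute()

    return {
        "clap_score": clap_score,
        "duration": wav.shape[-1] / sr
    }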
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s clip) | VRAM |
|-------|------------|----------------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
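
The VRAM column above can drive a simple model-selection check. The sketch below copies the approximate figures from the table; the names `VARIANT_VRAM_GB` and `variants_that_fit`, and the 1 GB headroom default, are assumptions for illustration and not part of AudioCraft.

```python
# Approximate VRAM needs in GB, copied from the table above
VARIANT_VRAM_GB = {
    "musicgen-small": 2,
    "musicgen-medium": 4,
    "musicgen-melody": 4,
    "musicgen-stereo-medium": 5,
    "musicgen-large": 8,
}


def variants_that_fit(vram_gb: float, headroom_gb: float = 1.0) -> list[str]:
    """List variants whose estimated VRAM need plus headroom fits in `vram_gb`."""
    return [
        name
        for name, need in VARIANT_VRAM_GB.items()
        if need + headroom_gb <= vram_gb
    ]


print(variants_that_fit(6))
```

Actual memory use also depends on batch size and clip duration, so treat the result as a starting point rather than a guarantee.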
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
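The structure in the last comment above can be turned into a small prompt builder. This is a minimal sketch; `build_prompt` is a hypothetical helper, not an AudioCraft function.

```python
def build_prompt(mood: str, genre: str, instruments: list[str], tempo: str = "") -> str:
    """Assemble a prompt as: [mood] [genre] with [instruments] at [tempo/style]."""
    parts = [f"{mood} {genre}".strip()]
    if instruments:
        parts.append("with " + ", ".join(instruments))
    if tempo:
        parts.append(f"at {tempo}")
    return " ".join(parts)


print(build_prompt(
    "upbeat",
    "electronic dance music",
    ["synthesizer leads", "punchy drums"],
    "128 bpm",
))
```

Keeping the fields separate makes it easy to vary one attribute (for example the tempo) while holding the rest of the description fixed.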
|
||||
504
skills/mlops/models/audiocraft/references/troubleshooting.md
Normal file
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python, set the variable before importing audiocraft:
#   import os
#   os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
#   from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# If the GPU cannot hold the model, load and run it on the CPU (slower).
# MusicGen takes the device at load time; the wrapper has no .to() method.
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
wav, tokens = model.generate(["prompt"], return_tokens=True)
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.duration}s")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: the models were trained on clips of up to 30s; longer requests are
# generated with a sliding window and may lose coherence
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
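Generating one prompt at a time is the safest option for memory, but small batches are usually faster when they fit. A minimal sketch of splitting a prompt list into fixed-size batches follows; `chunked` is a hypothetical helper, and the right `batch_size` is something to find by experiment on your GPU.

```python
from typing import Iterator


def chunked(prompts: list[str], batch_size: int) -> Iterator[list[str]]:
    """Yield successive batches of at most `batch_size` prompts."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(prompts), batch_size):
        yield prompts[start:start + batch_size]


prompts = ["jazz piano", "rock guitar", "electronic beats", "ambient pads", "lofi hip hop"]
for batch in chunked(prompts, batch_size=2):
    # Each batch would be passed to model.generate(batch)
    print(batch)
```

If a batch still triggers an out-of-memory error, halving `batch_size` and retrying is a simple fallback.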
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
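The "wrong speed" symptom follows directly from the ratio between the rate the audio was generated at and the rate written to the file header. A small stdlib-only sketch of that arithmetic (function names are illustrative):

```python
def playback_speed(true_rate: int, header_rate: int) -> float:
    """Speed factor heard when audio generated at `true_rate` is saved as `header_rate`."""
    return header_rate / true_rate


def playback_seconds(num_samples: int, header_rate: int) -> float:
    """Duration a player reports for `num_samples` at the saved rate."""
    return num_samples / header_rate


# 10 s of MusicGen audio (32 kHz) saved with AudioGen's 16 kHz rate
samples = 10 * 32000
print(playback_speed(32000, 16000))      # half speed, pitched down
print(playback_seconds(samples, 16000))  # plays for twice as long
```

Reading the rate from the loaded model, rather than hard-coding it, avoids this class of mistake when switching between MusicGen and AudioGen.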
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
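To see what the tanh soft clipper does to individual sample values, here is a scalar version using only the standard library. It mirrors the `soft_clip` formula above and is a sketch for intuition, not a replacement for the tensor version.

```python
import math


def soft_clip(x: float, threshold: float = 0.9) -> float:
    """Scalar tanh soft clipper: the output magnitude never exceeds `threshold`."""
    return math.tanh(x / threshold) * threshold


for sample in (0.1, 0.5, 1.0, 2.0):
    print(f"{sample:>4} -> {soft_clip(sample):.3f}")
```

Note that the curve also compresses values below the threshold (0.5 maps to roughly 0.45), so it changes the overall dynamics and is not a pure limiter.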
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
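The duration-to-token conversion above can be wrapped in a pair of helpers. This sketch assumes the roughly 50 tokens per second stated above; the function names are illustrative.

```python
import math

FRAME_RATE_HZ = 50  # tokens per second, per the estimate above


def duration_to_max_new_tokens(seconds: float, frame_rate: int = FRAME_RATE_HZ) -> int:
    """Number of tokens to request for a clip of `seconds`."""
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return math.ceil(seconds * frame_rate)


def tokens_to_duration(tokens: int, frame_rate: int = FRAME_RATE_HZ) -> float:
    """Approximate clip length in seconds for a given token budget."""
    return tokens / frame_rate


print(duration_to_max_new_tokens(10))
print(tokens_to_duration(256))
```

By this estimate, the `max_new_tokens=256` used in the first example yields only about five seconds of audio, which is a common cause of clips that seem too short.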
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
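# Use GPU (pass the device when loading; the MusicGen wrapper has no .to())
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cuda")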
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Explicitly load on GPU (the MusicGen wrapper has no .to() method)
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
506
skills/mlops/models/segment-anything/SKILL.md
Normal file
@@ -0,0 +1,506 @@
|
||||
---
|
||||
name: segment-anything-model
|
||||
description: "SAM: zero-shot image segmentation via points, boxes, masks."
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [segment-anything, transformers>=4.30.0, torch>=1.7.0]
|
||||
platforms: [linux, macos, windows]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Image Segmentation, Computer Vision, SAM, Zero-Shot]
|
||||
|
||||
---
|
||||
|
||||
# Segment Anything Model (SAM)
|
||||
|
||||
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
|
||||
|
||||
## When to use SAM
|
||||
|
||||
**Use SAM when:**
|
||||
- Need to segment any object in images without task-specific training
|
||||
- Building interactive annotation tools with point/box prompts
|
||||
- Generating training data for other vision models
|
||||
- Need zero-shot transfer to new image domains
|
||||
- Building object detection/segmentation pipelines
|
||||
- Processing medical, satellite, or domain-specific images
|
||||
|
||||
**Key features:**
|
||||
- **Zero-shot segmentation**: Works on any image domain without fine-tuning
|
||||
- **Flexible prompts**: Points, bounding boxes, or previous masks
|
||||
- **Automatic segmentation**: Generate all object masks automatically
|
||||
- **High quality**: Trained on 1.1 billion masks from 11 million images
|
||||
- **Multiple model sizes**: ViT-B (fastest), ViT-L, ViT-H (most accurate)
|
||||
- **ONNX export**: Deploy in browsers and edge devices
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **YOLO/Detectron2**: For real-time object detection with classes
|
||||
- **Mask2Former**: For semantic/panoptic segmentation with categories
|
||||
- **GroundingDINO + SAM**: For text-prompted segmentation
|
||||
- **SAM 2**: For video segmentation tasks
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From GitHub
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
|
||||
# Optional dependencies
|
||||
pip install opencv-python pycocotools matplotlib
|
||||
|
||||
# Or use HuggingFace transformers
|
||||
pip install transformers
|
||||
```
|
||||
|
||||
### Download checkpoints
|
||||
|
||||
```bash
|
||||
# ViT-H (largest, most accurate) - 2.4GB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
|
||||
# ViT-L (medium) - 1.2GB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
|
||||
|
||||
# ViT-B (smallest, fastest) - 375MB
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
|
||||
```
|
||||
|
||||
### Basic usage with SamPredictor
|
||||
|
||||
```python
|
||||
import cv2
import numpy as np
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
|
||||
# Load model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to(device="cuda")
|
||||
|
||||
# Create predictor
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
# Set image (computes embeddings once)
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
predictor.set_image(image)
|
||||
|
||||
# Predict with point prompts
|
||||
input_point = np.array([[500, 375]]) # (x, y) coordinates
|
||||
input_label = np.array([1]) # 1 = foreground, 0 = background
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True # Returns 3 mask options
|
||||
)
|
||||
|
||||
# Select best mask
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
```
|
||||
|
||||
### HuggingFace Transformers
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import SamModel, SamProcessor
|
||||
|
||||
# Load model and processor
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-huge")
|
||||
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||
model.to("cuda")
|
||||
|
||||
# Process image with point prompt
|
||||
image = Image.open("image.jpg")
|
||||
input_points = [[[450, 600]]] # Batch of points
|
||||
|
||||
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
||||
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||
|
||||
# Generate masks
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Post-process masks to original size
|
||||
masks = processor.image_processor.post_process_masks(
|
||||
outputs.pred_masks.cpu(),
|
||||
inputs["original_sizes"].cpu(),
|
||||
inputs["reshaped_input_sizes"].cpu()
|
||||
)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Model architecture
|
||||
|
||||
<!-- ascii-guard-ignore -->
|
||||
```
|
||||
SAM Architecture:
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Image Encoder │────▶│ Prompt Encoder │────▶│ Mask Decoder │
|
||||
│ (ViT) │ │ (Points/Boxes) │ │ (Transformer) │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
Image Embeddings Prompt Embeddings Masks + IoU
|
||||
(computed once) (per prompt) predictions
|
||||
```
|
||||
<!-- ascii-guard-ignore-end -->
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Registry key | Checkpoint size | Speed | Accuracy |
|-------|--------------|-----------------|-------|----------|
| ViT-H | `vit_h` | 2.4 GB | Slowest | Best |
| ViT-L | `vit_l` | 1.2 GB | Medium | Good |
| ViT-B | `vit_b` | 375 MB | Fastest | Good |
|
||||
|
||||
### Prompt types
|
||||
|
||||
| Prompt | Description | Use Case |
|--------|-------------|----------|
| Point (foreground) | Click on object | Single object selection |
| Point (background) | Click outside object | Exclude regions |
| Bounding box | Rectangle around object | Larger objects |
| Previous mask | Low-res mask input | Iterative refinement |
|
||||
|
||||
## Interactive segmentation
|
||||
|
||||
### Point prompts
|
||||
|
||||
```python
|
||||
# Single foreground point
|
||||
input_point = np.array([[500, 375]])
|
||||
input_label = np.array([1])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Multiple points (foreground + background)
|
||||
input_points = np.array([[500, 375], [600, 400], [450, 300]])
|
||||
input_labels = np.array([1, 1, 0]) # 2 foreground, 1 background
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_points,
|
||||
point_labels=input_labels,
|
||||
multimask_output=False # Single mask when prompts are clear
|
||||
)
|
||||
```
|
||||
|
||||
### Box prompts
|
||||
|
||||
```python
|
||||
# Bounding box [x1, y1, x2, y2]
|
||||
input_box = np.array([425, 600, 700, 875])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
box=input_box,
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
### Combined prompts
|
||||
|
||||
```python
|
||||
# Box + points for precise control
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
box=np.array([400, 300, 700, 600]),
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
### Iterative refinement
|
||||
|
||||
```python
|
||||
# Initial prediction
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Refine with additional point using previous mask
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=np.array([[500, 375], [550, 400]]),
|
||||
point_labels=np.array([1, 0]), # Add background point
|
||||
mask_input=logits[np.argmax(scores)][None, :, :], # Use best mask
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
## Automatic mask generation
|
||||
|
||||
### Basic automatic segmentation
|
||||
|
||||
```python
|
||||
from segment_anything import SamAutomaticMaskGenerator
|
||||
|
||||
# Create generator
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
|
||||
# Generate all masks
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
# Each mask contains:
|
||||
# - segmentation: binary mask
|
||||
# - bbox: [x, y, w, h]
|
||||
# - area: pixel count
|
||||
# - predicted_iou: quality score
|
||||
# - stability_score: robustness score
|
||||
# - point_coords: generating point
|
||||
```
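The automatic generator reports `bbox` as `[x, y, w, h]`, while the box prompt shown earlier uses `[x1, y1, x2, y2]`. If a generated box is fed back in as a prompt, it needs converting first. A minimal sketch follows; the helper names are illustrative and not part of the `segment_anything` package.

```python
def xywh_to_xyxy(bbox: list[float]) -> list[float]:
    """Convert [x, y, width, height] to [x1, y1, x2, y2]."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def xyxy_to_xywh(box: list[float]) -> list[float]:
    """Convert [x1, y1, x2, y2] back to [x, y, width, height]."""
    x1, y1, x2, y2 = box
    return [x1, y1, x2 - x1, y2 - y1]


print(xywh_to_xyxy([425, 600, 275, 275]))
```

Wrapping the converted list in `np.array(...)` gives the form used by the box prompt example.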
|
||||
|
||||
### Customized generation
|
||||
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=32, # Grid density (more = more masks)
|
||||
pred_iou_thresh=0.88, # Quality threshold
|
||||
stability_score_thresh=0.95, # Stability threshold
|
||||
crop_n_layers=1, # Multi-scale crops
|
||||
crop_n_points_downscale_factor=2,
|
||||
min_mask_region_area=100, # Remove tiny masks
|
||||
)
|
||||
|
||||
masks = mask_generator.generate(image)
|
||||
```
|
||||
|
||||
### Filtering masks
|
||||
|
||||
```python
|
||||
# Sort by area (largest first)
|
||||
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
||||
|
||||
# Filter by predicted IoU
|
||||
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]
|
||||
|
||||
# Filter by stability score
|
||||
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
|
||||
```
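The three filters above are often applied together. The sketch below combines them into one function operating on plain dictionaries with the same keys the generator returns; the sample records are made up for illustration, and `filter_masks` is a hypothetical helper.

```python
def filter_masks(masks, min_iou=0.9, min_stability=0.95, min_area=0):
    """Keep masks passing all thresholds, sorted by area (largest first)."""
    kept = [
        m for m in masks
        if m["predicted_iou"] > min_iou
        and m["stability_score"] > min_stability
        and m["area"] >= min_area
    ]
    return sorted(kept, key=lambda m: m["area"], reverse=True)


# Made-up records with the keys produced by the mask generator
sample = [
    {"area": 1200, "predicted_iou": 0.97, "stability_score": 0.98},
    {"area": 50, "predicted_iou": 0.95, "stability_score": 0.96},
    {"area": 9000, "predicted_iou": 0.85, "stability_score": 0.99},
    {"area": 4000, "predicted_iou": 0.93, "stability_score": 0.97},
]

for m in filter_masks(sample, min_area=100):
    print(m["area"], m["predicted_iou"])
```

Filtering after generation is cheap, so it can be easier to generate with looser thresholds and tighten them here than to rerun the generator.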
|
||||
|
||||
## Batched inference
|
||||
|
||||
### Multiple images
|
||||
|
||||
```python
|
||||
# Process multiple images efficiently
|
||||
# cv2.imread returns BGR; SamPredictor expects RGB by default
images = [cv2.cvtColor(cv2.imread(f"image_{i}.jpg"), cv2.COLOR_BGR2RGB) for i in range(10)]
|
||||
|
||||
all_masks = []
|
||||
for image in images:
|
||||
predictor.set_image(image)
|
||||
masks, _, _ = predictor.predict(
|
||||
point_coords=np.array([[500, 375]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
all_masks.append(masks)
|
||||
```
|
||||
|
||||
### Multiple prompts per image
|
||||
|
||||
```python
|
||||
# Process multiple prompts efficiently (one image encoding)
|
||||
predictor.set_image(image)
|
||||
|
||||
# Batch of point prompts
|
||||
points = [
|
||||
np.array([[100, 100]]),
|
||||
np.array([[200, 200]]),
|
||||
np.array([[300, 300]])
|
||||
]
|
||||
|
||||
all_masks = []
|
||||
for point in points:
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point,
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
all_masks.append(masks[np.argmax(scores)])
|
||||
```
|
||||
|
||||
## ONNX deployment
|
||||
|
||||
### Export model
|
||||
|
||||
```bash
|
||||
python scripts/export_onnx_model.py \
|
||||
--checkpoint sam_vit_h_4b8939.pth \
|
||||
--model-type vit_h \
|
||||
--output sam_onnx.onnx \
|
||||
--return-single-mask
|
||||
```
|
||||
|
||||
### Use ONNX model
|
||||
|
||||
```python
|
||||
import onnxruntime
|
||||
|
||||
# Load ONNX model
|
||||
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")
|
||||
|
||||
# Run inference (image embeddings computed separately)
|
||||
masks, iou_predictions, low_res_masks = ort_session.run(
|
||||
None,
|
||||
{
|
||||
"image_embeddings": image_embeddings,
|
||||
"point_coords": point_coords,
|
||||
"point_labels": point_labels,
|
||||
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
|
||||
"has_mask_input": np.array([0], dtype=np.float32),
|
||||
"orig_im_size": np.array([h, w], dtype=np.float32)
|
||||
}
|
||||
)
|
||||
```
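
The exported decoder expects point prompts in the resized image frame (longest side 1024) and, when no box is supplied, one extra padding point with label -1, as in the official ONNX example notebook. The sketch below reimplements that preparation in NumPy so the arithmetic is visible. `prepare_onnx_points` is a hypothetical helper; in practice `predictor.transform.apply_coords` performs the rescaling.

```python
import numpy as np


def prepare_onnx_points(points, labels, orig_hw, target_length=1024):
    """Build point_coords / point_labels arrays for the exported decoder (sketch)."""
    h, w = orig_hw
    scale = target_length / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)

    # Append the padding point (0, 0) with label -1, used when no box prompt is given
    coords = np.concatenate(
        [np.asarray(points, dtype=np.float32), np.zeros((1, 2), np.float32)]
    )
    labs = np.concatenate(
        [np.asarray(labels, dtype=np.float32), np.array([-1.0], np.float32)]
    )

    # Rescale (x, y) from the original image frame to the resized frame
    coords[:, 0] *= new_w / w
    coords[:, 1] *= new_h / h

    # Add the batch dimension expected by the ONNX inputs
    return coords[None, :, :], labs[None, :]


point_coords, point_labels = prepare_onnx_points([[500, 375]], [1], orig_hw=(750, 1000))
```

The resulting arrays can be passed as the `point_coords` and `point_labels` inputs shown above.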
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Annotation tool
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Load model
|
||||
predictor = SamPredictor(sam)
|
||||
predictor.set_image(image)
|
||||
|
||||
def on_click(event, x, y, flags, param):
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Foreground point
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([[x, y]]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
# Display best mask
|
||||
display_mask(masks[np.argmax(scores)])
|
||||
```
|
||||
|
||||
### Workflow 2: Object extraction
|
||||
|
||||
```python
|
||||
def extract_object(image, point):
|
||||
"""Extract object at point with transparent background."""
|
||||
predictor.set_image(image)
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([point]),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
|
||||
# Create RGBA output
|
||||
rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
|
||||
rgba[:, :, :3] = image
|
||||
rgba[:, :, 3] = best_mask * 255
|
||||
|
||||
return rgba
|
||||
```
|
||||
|
||||
### Workflow 3: Medical image segmentation
|
||||
|
||||
```python
|
||||
# Process medical images (grayscale to RGB)
|
||||
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
|
||||
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
predictor.set_image(rgb_image)
|
||||
|
||||
# Segment region of interest
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]), # ROI bounding box
|
||||
multimask_output=True
|
||||
)
|
||||
```
|
||||
|
||||
## Output format
|
||||
|
||||
### Mask data structure
|
||||
|
||||
```python
|
||||
# SamAutomaticMaskGenerator output
|
||||
{
|
||||
"segmentation": np.ndarray, # H×W binary mask
|
||||
"bbox": [x, y, w, h], # Bounding box
|
||||
"area": int, # Pixel count
|
||||
"predicted_iou": float, # 0-1 quality score
|
||||
"stability_score": float, # 0-1 robustness score
|
||||
"crop_box": [x, y, w, h], # Generation crop region
|
||||
"point_coords": [[x, y]], # Input point
|
||||
}
|
||||
```
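
Because `bbox` is stored as `[x, y, w, h]` while `SamPredictor.predict` takes box prompts as `[x1, y1, x2, y2]`, a generated mask's box needs converting before it can be reused as a prompt. A minimal sketch (the helper name `xywh_to_xyxy` and the sample record are illustrative):

```python
def xywh_to_xyxy(bbox):
    """Convert an [x, y, w, h] box to [x1, y1, x2, y2]."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


# Synthetic record with the same keys as the generator output
record = {"bbox": [10, 20, 30, 40], "area": 900}
box_prompt = xywh_to_xyxy(record["bbox"])
```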
|
||||
|
||||
### COCO RLE format
|
||||
|
||||
```python
|
||||
from pycocotools import mask as mask_utils
|
||||
|
||||
# Encode mask to RLE
|
||||
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
|
||||
rle["counts"] = rle["counts"].decode("utf-8")
|
||||
|
||||
# Decode RLE to mask
|
||||
decoded_mask = mask_utils.decode(rle)
|
||||
```
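
To make the format less opaque, here is a sketch of uncompressed COCO-style RLE in plain NumPy: run lengths are counted over the mask in column-major order, and the first run always counts zeros. This is only an illustration of the idea. `pycocotools` additionally compresses `counts` into a byte string, so the output here is not byte-compatible with `mask_utils.encode`, and the function names are hypothetical.

```python
import numpy as np


def rle_encode(mask):
    """Uncompressed COCO-style RLE: column-major run lengths, starting with zeros."""
    flat = np.asarray(mask, dtype=np.uint8).flatten(order="F")
    # Indices where the value changes, plus both ends of the array
    change = np.flatnonzero(np.diff(flat)) + 1
    bounds = np.concatenate([[0], change, [flat.size]])
    counts = np.diff(bounds).tolist()
    if flat.size and flat[0] == 1:
        counts = [0] + counts  # runs must begin with a (possibly empty) run of zeros
    return {"size": list(mask.shape), "counts": counts}


def rle_decode(rle):
    """Rebuild the boolean mask from the run lengths."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=np.uint8)
    pos = 0
    for i, run in enumerate(rle["counts"]):
        if i % 2 == 1:  # odd-indexed runs are foreground
            flat[pos:pos + run] = 1
        pos += run
    return flat.reshape((h, w), order="F").astype(bool)


mask = np.zeros((3, 4), dtype=bool)
mask[1:, 1:3] = True
rle = rle_encode(mask)
restored = rle_decode(rle)
```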
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### GPU memory
|
||||
|
||||
```python
|
||||
# Use smaller model for limited VRAM
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Process images in batches
|
||||
# Clear CUDA cache between large batches
|
||||
torch.cuda.empty_cache()
|
||||
```
|
||||
|
||||
### Speed optimization
|
||||
|
||||
```python
|
||||
# Use half precision
|
||||
sam = sam.half()
|
||||
|
||||
# Reduce points for automatic generation
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Default is 32
|
||||
)
|
||||
|
||||
# Use ONNX for deployment
|
||||
# Export with --return-single-mask for faster inference
|
||||
```
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Out of memory | Use ViT-B model, reduce image size |
|
||||
| Slow inference | Use ViT-B, reduce points_per_side |
|
||||
| Poor mask quality | Try different prompts, use box + points |
|
||||
| Edge artifacts | Use stability_score filtering |
|
||||
| Small objects missed | Increase points_per_side |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, fine-tuning, integration
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/segment-anything
|
||||
- **Paper**: https://arxiv.org/abs/2304.02643
|
||||
- **Demo**: https://segment-anything.com
|
||||
- **SAM 2 (Video)**: https://github.com/facebookresearch/segment-anything-2
|
||||
- **HuggingFace**: https://huggingface.co/facebook/sam-vit-huge
|
||||
@@ -0,0 +1,589 @@
|
||||
# Segment Anything Advanced Usage Guide
|
||||
|
||||
## SAM 2 (Video Segmentation)
|
||||
|
||||
### Overview
|
||||
|
||||
SAM 2 extends SAM to video segmentation using a streaming memory architecture. Install it from GitHub:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/facebookresearch/segment-anything-2.git
|
||||
```
|
||||
|
||||
### Video segmentation
|
||||
|
||||
```python
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
|
||||
|
||||
# Initialize with video
|
||||
state = predictor.init_state(video_path="video.mp4")
|
||||
|
||||
# Add prompt on first frame
|
||||
predictor.add_new_points(
|
||||
    inference_state=state,
|
||||
    frame_idx=0,
|
||||
    obj_id=1,
|
||||
    points=[[100, 200]],
|
||||
    labels=[1]
|
||||
)
|
||||
|
||||
# Propagate through video
|
||||
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
|
||||
    # mask_logits holds one logit map per tracked object; threshold at 0 for binary masks
|
||||
    masks = mask_logits > 0.0
|
||||
    process_frame(frame_idx, masks)
|
||||
```
|
||||
|
||||
### SAM 2 vs SAM comparison
|
||||
|
||||
| Feature | SAM | SAM 2 |
|
||||
|---------|-----|-------|
|
||||
| Input | Images only | Images + Videos |
|
||||
| Architecture | ViT + Decoder | Hiera + Memory |
|
||||
| Memory | Per-image | Streaming memory bank |
|
||||
| Tracking | No | Yes, across frames |
|
||||
| Models | ViT-B/L/H | Hiera-T/S/B+/L |
|
||||
|
||||
## Grounded SAM (Text-Prompted Segmentation)
|
||||
|
||||
### Setup
|
||||
|
||||
```bash
|
||||
pip install groundingdino-py
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
```
|
||||
|
||||
### Text-to-mask pipeline
|
||||
|
||||
```python
|
||||
from groundingdino.util.inference import load_model, predict
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
import cv2
|
||||
|
||||
# Load Grounding DINO
|
||||
# load_model takes the config path first, then the checkpoint path
grounding_model = load_model("GroundingDINO_SwinT_OGC.py", "groundingdino_swint_ogc.pth")
|
||||
|
||||
# Load SAM
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
def text_to_mask(image, text_prompt, box_threshold=0.3, text_threshold=0.25):
|
||||
"""Generate masks from text description."""
|
||||
# Get bounding boxes from text
|
||||
boxes, logits, phrases = predict(
|
||||
model=grounding_model,
|
||||
image=image,
|
||||
caption=text_prompt,
|
||||
box_threshold=box_threshold,
|
||||
text_threshold=text_threshold
|
||||
)
|
||||
|
||||
# Generate masks with SAM
|
||||
predictor.set_image(image)
|
||||
|
||||
masks = []
|
||||
for box in boxes:
|
||||
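        # Convert normalized (cx, cy, w, h) box to pixel (x1, y1, x2, y2)
|
||||
        h, w = image.shape[:2]
|
||||
        cx, cy, bw, bh = box.numpy() * np.array([w, h, w, h])
|
||||
        box_pixels = np.array([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2])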
|
||||
|
||||
mask, score, _ = predictor.predict(
|
||||
box=box_pixels,
|
||||
multimask_output=False
|
||||
)
|
||||
masks.append(mask[0])
|
||||
|
||||
return masks, boxes, phrases
|
||||
|
||||
# Usage
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
masks, boxes, phrases = text_to_mask(image, "person . dog . car")
|
||||
```
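
Grounding DINO's `predict` returns boxes as normalized `(cx, cy, w, h)`, whereas SAM box prompts are pixel `(x1, y1, x2, y2)`. The conversion is small enough to show on its own. This is a dependency-free sketch, and the function name is illustrative:

```python
def cxcywh_norm_to_xyxy_pixels(box, width, height):
    """Convert a normalized (cx, cy, w, h) box to pixel (x1, y1, x2, y2)."""
    cx, cy, bw, bh = box
    x1 = (cx - bw / 2) * width
    y1 = (cy - bh / 2) * height
    x2 = (cx + bw / 2) * width
    y2 = (cy + bh / 2) * height
    return [x1, y1, x2, y2]


# A box centred in a 640x480 image, a quarter of the width and half the height
xyxy = cxcywh_norm_to_xyxy_pixels([0.5, 0.5, 0.25, 0.5], width=640, height=480)
```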
|
||||
|
||||
## Batched Processing
|
||||
|
||||
### Efficient multi-image processing
|
||||
|
||||
```python
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
class BatchedSAM:
|
||||
def __init__(self, checkpoint, model_type="vit_h", device="cuda"):
|
||||
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
self.sam.to(device)
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
self.device = device
|
||||
|
||||
def process_batch(self, images, prompts):
|
||||
"""Process multiple images with corresponding prompts."""
|
||||
results = []
|
||||
|
||||
for image, prompt in zip(images, prompts):
|
||||
self.predictor.set_image(image)
|
||||
|
||||
if "point" in prompt:
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=prompt["point"],
|
||||
point_labels=prompt["label"],
|
||||
multimask_output=True
|
||||
)
|
||||
elif "box" in prompt:
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
box=prompt["box"],
|
||||
multimask_output=False
|
||||
)
|
||||
|
||||
results.append({
|
||||
"masks": masks,
|
||||
"scores": scores,
|
||||
"best_mask": masks[np.argmax(scores)]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
batch_sam = BatchedSAM("sam_vit_h_4b8939.pth")
|
||||
|
||||
# cv2.imread returns BGR; SamPredictor expects RGB by default
images = [cv2.cvtColor(cv2.imread(f"image_{i}.jpg"), cv2.COLOR_BGR2RGB) for i in range(10)]
|
||||
prompts = [{"point": np.array([[100, 100]]), "label": np.array([1])} for _ in range(10)]
|
||||
|
||||
results = batch_sam.process_batch(images, prompts)
|
||||
```
|
||||
|
||||
### Parallel automatic mask generation
|
||||
|
||||
```python
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
||||
|
||||
def generate_masks_parallel(images, num_workers=4):
|
||||
"""Generate masks for multiple images in parallel."""
|
||||
# Note: Each worker needs its own model instance
|
||||
def worker_init():
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
return SamAutomaticMaskGenerator(sam)
|
||||
|
||||
generators = [worker_init() for _ in range(num_workers)]
|
||||
|
||||
def process_image(args):
|
||||
idx, image = args
|
||||
generator = generators[idx % num_workers]
|
||||
return generator.generate(image)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
results = list(executor.map(process_image, enumerate(images)))
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Custom Integration
|
||||
|
||||
### FastAPI service
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
import numpy as np
|
||||
import cv2
|
||||
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model once
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cuda")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
class PointPrompt(BaseModel):
|
||||
x: int
|
||||
y: int
|
||||
label: int = 1
|
||||
|
||||
@app.post("/segment/point")
|
||||
async def segment_with_point(
|
||||
file: UploadFile = File(...),
|
||||
points: list[PointPrompt] = []
|
||||
):
|
||||
# Read image
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Set image
|
||||
predictor.set_image(image)
|
||||
|
||||
# Prepare prompts
|
||||
point_coords = np.array([[p.x, p.y] for p in points])
|
||||
point_labels = np.array([p.label for p in points])
|
||||
|
||||
# Generate masks
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point_coords,
|
||||
point_labels=point_labels,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_idx = np.argmax(scores)
|
||||
|
||||
return {
|
||||
"mask": masks[best_idx].tolist(),
|
||||
"score": float(scores[best_idx]),
|
||||
"all_scores": scores.tolist()
|
||||
}
|
||||
|
||||
@app.post("/segment/auto")
|
||||
async def segment_automatic(file: UploadFile = File(...)):
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
return {
|
||||
"num_masks": len(masks),
|
||||
"masks": [
|
||||
{
|
||||
"bbox": m["bbox"],
|
||||
"area": m["area"],
|
||||
"predicted_iou": m["predicted_iou"],
|
||||
"stability_score": m["stability_score"]
|
||||
}
|
||||
for m in masks
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Gradio interface
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from segment_anything import SamPredictor, sam_model_registry
|
||||
|
||||
# Load model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
predictor = SamPredictor(sam)
|
||||
|
||||
def segment_image(image, evt: gr.SelectData):
|
||||
"""Segment object at clicked point."""
|
||||
predictor.set_image(image)
|
||||
|
||||
point = np.array([[evt.index[0], evt.index[1]]])
|
||||
label = np.array([1])
|
||||
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=point,
|
||||
point_labels=label,
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
|
||||
# Overlay mask on image
|
||||
overlay = image.copy()
|
||||
overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([255, 0, 0]) * 0.5
|
||||
|
||||
return overlay
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# SAM Interactive Segmentation")
|
||||
gr.Markdown("Click on an object to segment it")
|
||||
|
||||
with gr.Row():
|
||||
input_image = gr.Image(label="Input Image", interactive=True)
|
||||
output_image = gr.Image(label="Segmented Image")
|
||||
|
||||
input_image.select(segment_image, inputs=[input_image], outputs=[output_image])
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Fine-Tuning SAM
|
||||
|
||||
### LoRA fine-tuning (experimental)
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
import torch
|
||||
from transformers import SamModel
|
||||
|
||||
# Load model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["qkv"], # Attention layers
|
||||
lora_dropout=0.1,
|
||||
bias="none",
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
# Training loop (simplified)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(
|
||||
pixel_values=batch["pixel_values"],
|
||||
input_points=batch["input_points"],
|
||||
input_labels=batch["input_labels"]
|
||||
)
|
||||
|
||||
# Custom loss (e.g., IoU loss with ground truth)
|
||||
loss = compute_loss(outputs.pred_masks, batch["gt_masks"])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### MedSAM (Medical imaging)
|
||||
|
||||
```python
|
||||
# MedSAM is a fine-tuned SAM for medical images
|
||||
# https://github.com/bowang-lab/MedSAM
|
||||
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Load MedSAM checkpoint
|
||||
medsam = sam_model_registry["vit_b"](checkpoint="medsam_vit_b.pth")
|
||||
medsam.to("cuda")
|
||||
|
||||
predictor = SamPredictor(medsam)
|
||||
|
||||
# Process medical image
|
||||
# Convert grayscale to RGB if needed
|
||||
medical_image = cv2.imread("ct_scan.png", cv2.IMREAD_GRAYSCALE)
|
||||
rgb_image = np.stack([medical_image] * 3, axis=-1)
|
||||
|
||||
predictor.set_image(rgb_image)
|
||||
|
||||
# Segment with box prompt (common for medical imaging)
|
||||
masks, scores, _ = predictor.predict(
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=False
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Mask Processing
|
||||
|
||||
### Mask refinement
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from scipy import ndimage
|
||||
|
||||
def refine_mask(mask, kernel_size=5, iterations=2):
|
||||
"""Refine mask with morphological operations."""
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
|
||||
|
||||
# Close small holes
|
||||
closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iterations)
|
||||
|
||||
# Remove small noise
|
||||
opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=iterations)
|
||||
|
||||
return opened.astype(bool)
|
||||
|
||||
def fill_holes(mask):
|
||||
"""Fill holes in mask."""
|
||||
filled = ndimage.binary_fill_holes(mask)
|
||||
return filled
|
||||
|
||||
def remove_small_regions(mask, min_area=100):
|
||||
"""Remove small disconnected regions."""
|
||||
labeled, num_features = ndimage.label(mask)
|
||||
sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
|
||||
|
||||
# Keep only regions larger than min_area
|
||||
mask_clean = np.zeros_like(mask)
|
||||
for i, size in enumerate(sizes, 1):
|
||||
if size >= min_area:
|
||||
mask_clean[labeled == i] = True
|
||||
|
||||
return mask_clean
|
||||
```
|
||||
|
||||
### Mask to polygon conversion
|
||||
|
||||
```python
|
||||
import cv2
|
||||
|
||||
def mask_to_polygons(mask, epsilon_factor=0.01):
|
||||
"""Convert binary mask to polygon coordinates."""
|
||||
contours, _ = cv2.findContours(
|
||||
mask.astype(np.uint8),
|
||||
cv2.RETR_EXTERNAL,
|
||||
cv2.CHAIN_APPROX_SIMPLE
|
||||
)
|
||||
|
||||
polygons = []
|
||||
for contour in contours:
|
||||
epsilon = epsilon_factor * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
        # reshape keeps a list of [x, y] pairs even for single-point contours
        polygon = approx.reshape(-1, 2).tolist()
|
||||
if len(polygon) >= 3: # Valid polygon
|
||||
polygons.append(polygon)
|
||||
|
||||
return polygons
|
||||
|
||||
def polygons_to_mask(polygons, height, width):
|
||||
"""Convert polygons back to binary mask."""
|
||||
mask = np.zeros((height, width), dtype=np.uint8)
|
||||
for polygon in polygons:
|
||||
pts = np.array(polygon, dtype=np.int32)
|
||||
cv2.fillPoly(mask, [pts], 1)
|
||||
return mask.astype(bool)
|
||||
```
|
||||
|
||||
### Multi-scale segmentation
|
||||
|
||||
```python
|
||||
def multiscale_segment(image, predictor, point, scales=(0.5, 1.0, 2.0)):
|
||||
"""Generate masks at multiple scales and combine."""
|
||||
h, w = image.shape[:2]
|
||||
masks_all = []
|
||||
|
||||
for scale in scales:
|
||||
# Resize image
|
||||
new_h, new_w = int(h * scale), int(w * scale)
|
||||
scaled_image = cv2.resize(image, (new_w, new_h))
|
||||
scaled_point = (point * scale).astype(int)
|
||||
|
||||
# Segment
|
||||
predictor.set_image(scaled_image)
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=scaled_point.reshape(1, 2),
|
||||
point_labels=np.array([1]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Resize mask back
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
original_mask = cv2.resize(best_mask.astype(np.uint8), (w, h)) > 0.5
|
||||
|
||||
masks_all.append(original_mask)
|
||||
|
||||
# Combine masks (majority voting)
|
||||
combined = np.stack(masks_all, axis=0)
|
||||
final_mask = np.sum(combined, axis=0) >= len(scales) // 2 + 1
|
||||
|
||||
return final_mask
|
||||
```
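
The combination step is a per-pixel majority vote: a pixel is kept when more than half of the scales marked it as foreground. A small self-contained sketch with synthetic 2×2 masks (the helper name `majority_vote` is illustrative):

```python
import numpy as np


def majority_vote(masks):
    """Return True where a strict majority of the boolean masks are True."""
    stack = np.stack([np.asarray(m, dtype=bool) for m in masks], axis=0)
    return stack.sum(axis=0) >= len(masks) // 2 + 1


a = np.array([[1, 1], [0, 0]], dtype=bool)
b = np.array([[1, 0], [0, 1]], dtype=bool)
c = np.array([[1, 1], [0, 1]], dtype=bool)
voted = majority_vote([a, b, c])
```

With an even number of scales, ties count as background under this rule.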
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### TensorRT acceleration
|
||||
|
||||
```python
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit
|
||||
|
||||
def export_to_tensorrt(onnx_path, engine_path, fp16=True):
|
||||
"""Convert ONNX model to TensorRT engine."""
|
||||
logger = trt.Logger(trt.Logger.WARNING)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
||||
with open(onnx_path, 'rb') as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
return None
|
||||
|
||||
config = builder.create_builder_config()
|
||||
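    # max_workspace_size is deprecated; set_memory_pool_limit is available from TensorRT 8.4
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GB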
|
||||
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
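    # build_engine is deprecated; build_serialized_network (TensorRT >= 8.0) returns the serialized plan
|
||||
    serialized_engine = builder.build_serialized_network(network, config)
|
||||
    if serialized_engine is None:
|
||||
        return None
|
||||
|
||||
    with open(engine_path, 'wb') as f:
|
||||
        f.write(serialized_engine)
|
||||
|
||||
    return serialized_engine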
|
||||
```
|
||||
|
||||
### Memory-efficient inference
|
||||
|
||||
```python
|
||||
class MemoryEfficientSAM:
|
||||
def __init__(self, checkpoint, model_type="vit_b"):
|
||||
self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
|
||||
self.sam.eval()
|
||||
self.predictor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.sam.to("cuda")
|
||||
self.predictor = SamPredictor(self.sam)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.sam.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def segment(self, image, points, labels):
|
||||
self.predictor.set_image(image)
|
||||
masks, scores, _ = self.predictor.predict(
|
||||
point_coords=points,
|
||||
point_labels=labels,
|
||||
multimask_output=True
|
||||
)
|
||||
return masks, scores
|
||||
|
||||
# Usage with context manager (auto-cleanup)
|
||||
with MemoryEfficientSAM("sam_vit_b_01ec64.pth") as sam:
|
||||
masks, scores = sam.segment(image, points, labels)
|
||||
# CUDA memory freed automatically
|
||||
```
|
||||
|
||||
## Dataset Generation
|
||||
|
||||
### Create segmentation dataset
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
import cv2
|
||||
|
||||
def generate_dataset(images_dir, output_dir, mask_generator):
|
||||
"""Generate segmentation dataset from images."""
|
||||
annotations = []
|
||||
|
||||
for img_path in Path(images_dir).glob("*.jpg"):
|
||||
image = cv2.imread(str(img_path))
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Generate masks
|
||||
masks = mask_generator.generate(image)
|
||||
|
||||
# Filter high-quality masks
|
||||
good_masks = [m for m in masks if m["predicted_iou"] > 0.9]
|
||||
|
||||
# Save annotations
|
||||
for i, mask_data in enumerate(good_masks):
|
||||
annotation = {
|
||||
"image_id": img_path.stem,
|
||||
"mask_id": i,
|
||||
"bbox": mask_data["bbox"],
|
||||
"area": mask_data["area"],
|
||||
"segmentation": mask_to_rle(mask_data["segmentation"]),
|
||||
"predicted_iou": mask_data["predicted_iou"],
|
||||
"stability_score": mask_data["stability_score"]
|
||||
}
|
||||
annotations.append(annotation)
|
||||
|
||||
# Save dataset
|
||||
    with open(Path(output_dir) / "annotations.json", "w") as f:
|
||||
json.dump(annotations, f)
|
||||
|
||||
return annotations
|
||||
```
|
||||
@@ -0,0 +1,484 @@
|
||||
# Segment Anything Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### CUDA not available
|
||||
|
||||
**Error**: `RuntimeError: CUDA not available`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check CUDA availability
|
||||
import torch
|
||||
print(torch.cuda.is_available())
|
||||
print(torch.version.cuda)
|
||||
|
||||
# Install PyTorch with CUDA (run in a shell, not in Python):
|
||||
#   pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# If CUDA works but SAM doesn't use it
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cuda") # Explicitly move to GPU
|
||||
```
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'segment_anything'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from GitHub
|
||||
pip install git+https://github.com/facebookresearch/segment-anything.git
|
||||
|
||||
# Or clone and install
|
||||
git clone https://github.com/facebookresearch/segment-anything.git
|
||||
cd segment-anything
|
||||
pip install -e .
|
||||
|
||||
# Verify installation
|
||||
python -c "from segment_anything import sam_model_registry; print('OK')"
|
||||
```
|
||||
|
||||
### Missing dependencies
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'cv2'` or similar
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install all optional dependencies
|
||||
pip install opencv-python pycocotools matplotlib onnxruntime onnx
|
||||
|
||||
# For pycocotools on Windows
|
||||
pip install pycocotools-windows
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Checkpoint not found
|
||||
|
||||
**Error**: `FileNotFoundError: checkpoint file not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Download correct checkpoint
|
||||
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
|
||||
# Verify file integrity
|
||||
md5sum sam_vit_h_4b8939.pth
|
||||
# Expected: 4b8939a88964f0f4ff5f5b2642c598a6 (the filename suffix is the MD5 prefix)
|
||||
|
||||
# Use an absolute path when loading the checkpoint (Python):
|
||||
#   sam = sam_model_registry["vit_h"](checkpoint="/full/path/to/sam_vit_h_4b8939.pth")
|
||||
```
|
||||
|
||||
### Model type mismatch
|
||||
|
||||
**Error**: `RuntimeError: Error(s) in loading state_dict for Sam` (unexpected keys or size mismatches)
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Ensure model type matches checkpoint
|
||||
# vit_h checkpoint → vit_h model
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
|
||||
# vit_l checkpoint → vit_l model
|
||||
sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
|
||||
|
||||
# vit_b checkpoint → vit_b model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
```
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `CUDA out of memory` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Load to CPU first, then move
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
sam.to("cuda")
|
||||
|
||||
# Use half precision
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam = sam.half()
|
||||
sam.to("cuda")
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Image format errors
|
||||
|
||||
**Error**: `ValueError: expected input to have 3 channels`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import cv2
|
||||
|
||||
# Ensure RGB format
|
||||
image = cv2.imread("image.jpg")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR to RGB
|
||||
|
||||
# Convert grayscale to RGB
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
|
||||
# Handle RGBA
|
||||
if image.shape[2] == 4:
|
||||
image = image[:, :, :3] # Drop alpha channel
|
||||
```
|
||||
|
||||
### Coordinate errors
|
||||
|
||||
**Error**: `IndexError: index out of bounds` or incorrect mask location
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Ensure points are (x, y) not (row, col)
|
||||
# x = column index, y = row index
|
||||
point = np.array([[x, y]]) # Correct
|
||||
|
||||
# Verify coordinates are within image bounds
|
||||
h, w = image.shape[:2]
|
||||
assert 0 <= x < w and 0 <= y < h, "Point outside image"
|
||||
|
||||
# For bounding boxes: [x1, y1, x2, y2]
|
||||
box = np.array([x1, y1, x2, y2])
|
||||
assert x1 < x2 and y1 < y2, "Invalid box coordinates"
|
||||
```
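
When prompts come from user input, clamping them is often friendlier than asserting. The sketch below clips points to the image bounds and reorders box corners so that `x1 <= x2` and `y1 <= y2`. Both helper names are hypothetical and not part of `segment_anything`.

```python
import numpy as np


def clip_points(points, image_shape):
    """Clamp (x, y) points to the valid pixel range of an H x W image."""
    h, w = image_shape[:2]
    pts = np.asarray(points, dtype=np.float32).reshape(-1, 2).copy()
    pts[:, 0] = np.clip(pts[:, 0], 0, w - 1)  # x is the column index
    pts[:, 1] = np.clip(pts[:, 1], 0, h - 1)  # y is the row index
    return pts


def normalize_box(box):
    """Reorder corners so the box is [x1, y1, x2, y2] with x1 <= x2 and y1 <= y2."""
    x1, y1, x2, y2 = box
    return np.array(
        [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)], dtype=np.float32
    )


clipped = clip_points([[700, -5], [10, 20]], image_shape=(480, 640, 3))
fixed_box = normalize_box([300, 50, 100, 200])
```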
|
||||
|
||||
### Empty or incorrect masks
|
||||
|
||||
**Problem**: Masks don't match expected object
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Try multiple prompts
|
||||
input_points = np.array([[x1, y1], [x2, y2]])
|
||||
input_labels = np.array([1, 1]) # Multiple foreground points
|
||||
|
||||
# Add background points
|
||||
input_points = np.array([[obj_x, obj_y], [bg_x, bg_y]])
|
||||
input_labels = np.array([1, 0]) # 1=foreground, 0=background
|
||||
|
||||
# Use box prompt for large objects
|
||||
box = np.array([x1, y1, x2, y2])
|
||||
masks, scores, _ = predictor.predict(box=box, multimask_output=False)
|
||||
|
||||
# Combine box and point
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=np.array([[center_x, center_y]]),
|
||||
point_labels=np.array([1]),
|
||||
box=np.array([x1, y1, x2, y2]),
|
||||
multimask_output=True
|
||||
)
|
||||
|
||||
# Check scores and select best
|
||||
print(f"Scores: {scores}")
|
||||
best_mask = masks[np.argmax(scores)]
|
||||
```
|
||||
|
||||
### Slow inference
|
||||
|
||||
**Problem**: Prediction takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
|
||||
|
||||
# Reuse image embeddings
|
||||
predictor.set_image(image) # Compute once
|
||||
for point in points:
|
||||
masks, _, _ = predictor.predict(...) # Fast, reuses embeddings
|
||||
|
||||
# Reduce automatic generation points
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Default is 32
|
||||
)
|
||||
|
||||
# Use ONNX for deployment
|
||||
# Export: python scripts/export_onnx_model.py --return-single-mask
|
||||
```
|
||||
|
||||
## Automatic Mask Generation Issues
|
||||
|
||||
### Too many masks
|
||||
|
||||
**Problem**: Generating thousands of overlapping masks
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=16, # Reduce from 32
|
||||
pred_iou_thresh=0.92, # Increase from 0.88
|
||||
stability_score_thresh=0.98, # Increase from 0.95
|
||||
box_nms_thresh=0.5, # More aggressive NMS
|
||||
min_mask_region_area=500, # Remove small masks
|
||||
)
|
||||
```
|
||||
|
||||
### Too few masks
|
||||
|
||||
**Problem**: Missing objects in automatic generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
points_per_side=64, # Increase density
|
||||
pred_iou_thresh=0.80, # Lower threshold
|
||||
stability_score_thresh=0.85, # Lower threshold
|
||||
crop_n_layers=2, # Add multi-scale
|
||||
min_mask_region_area=0, # Keep all masks
|
||||
)
|
||||
```
|
||||
|
||||
### Small objects missed
|
||||
|
||||
**Problem**: Automatic generation misses small objects
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use crop layers for multi-scale detection
|
||||
mask_generator = SamAutomaticMaskGenerator(
|
||||
model=sam,
|
||||
crop_n_layers=2,
|
||||
crop_n_points_downscale_factor=1, # Don't reduce points in crops
|
||||
min_mask_region_area=10, # Very small minimum
|
||||
)
|
||||
|
||||
# Or process image patches
|
||||
def segment_with_patches(image, patch_size=512, overlap=64):
|
||||
h, w = image.shape[:2]
|
||||
all_masks = []
|
||||
|
||||
for y in range(0, h, patch_size - overlap):
|
||||
for x in range(0, w, patch_size - overlap):
|
||||
patch = image[y:y+patch_size, x:x+patch_size]
|
||||
masks = mask_generator.generate(patch)
|
||||
|
||||
# Offset masks to original coordinates
|
||||
for m in masks:
|
||||
m['bbox'][0] += x
|
||||
m['bbox'][1] += y
|
||||
# Offset segmentation mask too
|
||||
|
||||
all_masks.extend(masks)
|
||||
|
||||
return all_masks
|
||||
```
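
The patch loop above offsets each `bbox` but leaves the segmentation mask in patch coordinates. One way to complete that step is to paste the patch mask into a full-size canvas at the patch origin. A sketch, with the hypothetical helper name `paste_patch_mask`:

```python
import numpy as np


def paste_patch_mask(patch_mask, x, y, full_shape):
    """Place a patch-sized boolean mask at (x, y) inside a full-size empty mask."""
    full = np.zeros(full_shape[:2], dtype=bool)
    ph, pw = patch_mask.shape
    full[y:y + ph, x:x + pw] = patch_mask
    return full


# A 2x3 patch mask pasted at column 4, row 1 of a 5x8 image
patch_mask = np.ones((2, 3), dtype=bool)
full_mask = paste_patch_mask(patch_mask, x=4, y=1, full_shape=(5, 8))
```

Allocating a full-size array per mask is memory-hungry for large images, so storing the patch mask together with its offset may be preferable at scale.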

## Memory Issues

### CUDA out of memory

**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`

**Solutions**:
```python
import cv2
import torch

# Use smaller model
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Clear cache between images
torch.cuda.empty_cache()

# Process images sequentially, not batched
for image in images:
    predictor.set_image(image)
    masks, _, _ = predictor.predict(...)
    torch.cuda.empty_cache()

# Reduce image size
max_size = 1024
h, w = image.shape[:2]
if max(h, w) > max_size:
    scale = max_size / max(h, w)
    image = cv2.resize(image, (int(w*scale), int(h*scale)))

# Use CPU for large batch processing
sam.to("cpu")
```
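
If the image is downscaled before segmentation, any point or box prompts given in original pixel coordinates need the same scale factor applied. A minimal sketch under that assumption; the helper names `downscale_factor` and `scale_points` are hypothetical and not part of SAM:

```python
import numpy as np

def downscale_factor(h, w, max_size=1024):
    """Scale factor (<= 1.0) that brings the longer side down to max_size."""
    return min(1.0, max_size / max(h, w))

def scale_points(points, scale):
    """Map (x, y) prompt coordinates from the original image to the resized one."""
    return np.asarray(points, dtype=np.float32) * scale

# A 3000x4000 (h x w) image: scale = 1024 / 4000 = 0.256
scale = downscale_factor(3000, 4000)
points_small = scale_points([[2000, 1500]], scale)

# Images already within the limit are left untouched
no_change = downscale_factor(500, 400)
```

Masks predicted on the resized image have the resized resolution, so they need to be upsampled if results are required at the original size.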

### RAM out of memory

**Problem**: System runs out of RAM

**Solutions**:
```python
import gc
import cv2

# Process images one at a time
for img_path in image_paths:
    # cv2 loads BGR; SAM expects RGB
    image = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    masks = process_image(image)  # your own processing function
    save_results(masks)           # your own persistence function
    del image, masks
    gc.collect()

# Use generators instead of lists
def generate_masks_lazy(image_paths):
    for path in image_paths:
        image = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        masks = mask_generator.generate(image)
        yield path, masks
```
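
When many masks must stay in memory at once, storing them as packed bits rather than boolean arrays cuts their footprint by roughly a factor of eight. This is a sketch using plain NumPy; `pack_mask` and `unpack_mask` are illustrative helpers, not SAM functions.

```python
import numpy as np

def pack_mask(mask):
    """Pack a boolean HxW mask into bits; returns (packed bytes, original shape)."""
    return np.packbits(mask.astype(bool), axis=None), mask.shape

def unpack_mask(packed, shape):
    """Restore the boolean mask produced by pack_mask."""
    n = shape[0] * shape[1]
    return np.unpackbits(packed, count=n).reshape(shape).astype(bool)

mask = np.zeros((480, 640), dtype=bool)
mask[100:200, 150:300] = True

packed, shape = pack_mask(mask)
restored = unpack_mask(packed, shape)

print(mask.nbytes, "->", packed.nbytes)
```

`SamAutomaticMaskGenerator` also accepts an `output_mode` argument with run-length-encoded options, which may be a better fit if the masks are mostly written to disk.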

## ONNX Export Issues

### Export fails

**Error**: `scripts/export_onnx_model.py` fails, often because of mismatched `onnx`/`onnxruntime` versions or an unsupported opset

**Solutions**:
```bash
# Install correct ONNX version
pip install onnx==1.14.0 onnxruntime==1.15.0

# Use correct opset version
python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam.onnx \
    --opset 17
```

### ONNX runtime errors

**Error**: `ONNXRuntimeError` during inference

**Solutions**:
```python
import onnxruntime

# Check available providers
print(onnxruntime.get_available_providers())

# Use CPU provider if GPU fails
session = onnxruntime.InferenceSession(
    "sam.onnx",
    providers=['CPUExecutionProvider']
)

# Verify input shapes (avoid shadowing the built-in `input`)
for inp in session.get_inputs():
    print(f"{inp.name}: {inp.shape}")
```
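
Once the expected input shapes are known, they can be compared against the arrays about to be fed to the session. The sketch below is a standalone helper, not an ONNX Runtime API; `shape_mismatches` is a hypothetical name, and the shapes in the example are illustrative values. Dynamic dimensions, which ONNX Runtime reports as strings or `None`, are skipped.

```python
import numpy as np

def shape_mismatches(expected, feed):
    """List problems between expected input shapes and a feed dict.

    expected: {name: shape}, dynamic dims given as None or a string
    feed:     {name: np.ndarray}
    """
    problems = []
    for name, shape in expected.items():
        if name not in feed:
            problems.append(f"missing input: {name}")
            continue
        actual = feed[name].shape
        if len(actual) != len(shape):
            problems.append(f"{name}: rank {len(actual)} != {len(shape)}")
            continue
        for i, (want, got) in enumerate(zip(shape, actual)):
            # Only fixed (integer) dimensions can be checked
            if isinstance(want, int) and want != got:
                problems.append(f"{name}: dim {i} is {got}, expected {want}")
    return problems

# Illustrative shapes; in practice build `expected` from session.get_inputs()
expected = {
    "image_embeddings": [1, 256, 64, 64],
    "point_coords": [1, "num_points", 2],
}
feed = {
    "image_embeddings": np.zeros((1, 256, 64, 64), dtype=np.float32),
    "point_coords": np.zeros((1, 3, 3), dtype=np.float32),  # last dim is wrong
}
problems = shape_mismatches(expected, feed)
print(problems)
```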

## HuggingFace Integration Issues

### Processor errors

**Error**: `SamProcessor` rejects the inputs, or post-processed masks have the wrong size

**Solutions**:
```python
import torch
from transformers import SamModel, SamProcessor

# Use matching processor and model
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Ensure input format
input_points = [[[x, y]]]  # Nested list for batch dimension
inputs = processor(image, input_points=input_points, return_tensors="pt")

# Run the model before post-processing
with torch.no_grad():
    outputs = model(**inputs)

# Post-process correctly
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)
```

## Quality Issues

### Jagged mask edges

**Problem**: Masks have rough, pixelated edges

**Solutions**:
```python
import cv2
import numpy as np
from scipy import ndimage

def smooth_mask(mask, sigma=2):
    """Smooth mask edges."""
    # Gaussian blur, then re-threshold to get a boolean mask
    smooth = ndimage.gaussian_filter(mask.astype(float), sigma=sigma)
    return smooth > 0.5

def refine_edges(mask, kernel_size=5):
    """Refine mask edges with morphological operations."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    # Close small gaps
    closed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
    # Open to remove noise
    opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)
    return opened.astype(bool)
```

### Incomplete segmentation

**Problem**: Mask doesn't cover entire object

**Solutions**:
```python
# Add multiple points
input_points = np.array([
    [obj_center_x, obj_center_y],
    [obj_left_x, obj_center_y],
    [obj_right_x, obj_center_y],
    [obj_center_x, obj_top_y],
    [obj_center_x, obj_bottom_y]
])
input_labels = np.array([1, 1, 1, 1, 1])  # 1 = foreground
masks, scores, _ = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False
)

# Use bounding box
masks, _, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),
    multimask_output=False
)

# Iterative refinement
mask_input = None
for point in points:
    masks, scores, logits = predictor.predict(
        point_coords=point.reshape(1, 2),
        point_labels=np.array([1]),
        mask_input=mask_input,
        multimask_output=False
    )
    mask_input = logits  # Feed low-res logits into the next step
```
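
The five foreground points in the first snippet can be derived from a rough bounding box around the object. A minimal sketch, assuming the box is given as `(x1, y1, x2, y2)` in pixels; `points_from_box` and the `inset` fraction are hypothetical choices, not part of SAM:

```python
import numpy as np

def points_from_box(x1, y1, x2, y2, inset=0.25):
    """Five foreground prompts: the box center plus one point inset
    from each side, so all points are likely to fall on the object."""
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    dx, dy = (x2 - x1) * inset, (y2 - y1) * inset
    coords = np.array([
        [cx, cy],       # center
        [x1 + dx, cy],  # left
        [x2 - dx, cy],  # right
        [cx, y1 + dy],  # top
        [cx, y2 - dy],  # bottom
    ], dtype=np.float32)
    labels = np.ones(len(coords), dtype=np.int32)  # 1 = foreground
    return coords, labels

coords, labels = points_from_box(100, 200, 300, 400)
```

The resulting arrays have the shapes that `predictor.predict` takes for `point_coords` and `point_labels`. For thin or concave objects some of these points may land on background, so check them visually before relying on the result.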

## Common Error Messages

| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | GPU memory full | Use smaller model, clear cache |
| `expected 3 channels` | Wrong image format | Convert to RGB |
| `index out of bounds` | Invalid coordinates | Check point/box bounds |
| `checkpoint not found` | Wrong path | Use absolute path |
| `unexpected key` | Model/checkpoint mismatch | Match model type |
| `invalid box coordinates` | x1 > x2 or y1 > y2 | Fix box format |
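
Two rows of the table above, invalid coordinates and a malformed box, can be caught before calling the predictor. A sketch under the assumption that boxes are `[x1, y1, x2, y2]` and points are `(x, y)` in pixels; both helper names are illustrative.

```python
import numpy as np

def normalize_box(box, width, height):
    """Return [x1, y1, x2, y2] with ordered corners, clipped to the image."""
    x1, y1, x2, y2 = box
    x1, x2 = sorted((x1, x2))
    y1, y2 = sorted((y1, y2))
    upper = [width - 1, height - 1, width - 1, height - 1]
    return np.clip([x1, y1, x2, y2], 0, upper).astype(np.float32)

def points_in_bounds(points, width, height):
    """True if every (x, y) point lies inside the image."""
    pts = np.asarray(points)
    inside_x = (pts[:, 0] >= 0) & (pts[:, 0] < width)
    inside_y = (pts[:, 1] >= 0) & (pts[:, 1] < height)
    return bool(np.all(inside_x & inside_y))

# Swapped x-corners and a y2 beyond the image height get repaired
box = normalize_box([300, 50, 100, 700], width=640, height=480)

# The second point lies outside a 640-pixel-wide image
ok = points_in_bounds([[10, 20], [700, 20]], width=640, height=480)
```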

## Getting Help

1. **GitHub Issues**: https://github.com/facebookresearch/segment-anything/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2304.02643

### Reporting Issues

Include:
- Python version
- PyTorch version: `python -c "import torch; print(torch.__version__)"`
- CUDA version: `python -c "import torch; print(torch.version.cuda)"`
- SAM model type (vit_b/l/h)
- Full error traceback
- Minimal reproducible code
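
The version details in the list above can be gathered in one step. The following sketch reads installed package metadata without importing the packages; the distribution names in `packages` are assumptions about how they were installed (for example, `opencv-python` rather than a headless build) and may need adjusting.

```python
import platform
from importlib import metadata

def collect_env(packages=("torch", "torchvision", "numpy",
                          "opencv-python", "segment-anything")):
    """Collect Python, OS, and package versions for a bug report."""
    info = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for name in packages:
        try:
            info[name] = str(metadata.version(name))
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info

env = collect_env()
for key, value in env.items():
    print(f"{key}: {value}")
```

The CUDA version still has to come from `torch.version.cuda`, since it is not part of the package metadata.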